In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from einops.layers.torch import Rearrange

In [5]:
# --- Training configuration -------------------------------------------------
BATCH_SIZE = 256       # samples per mini-batch (read by get_mnist_loaders)
LEARNING_RATE = 0.002  # optimizer step size (training loop not shown here)
EPOCHS = 10            # passes over the training set
BOND_DIM = 64          # width of each patch embedding / tree-node vector
CP_RANK = 16           # presumably the `rank` passed to CPLowRankLayer -- confirm
DROPOUT = 0.1          # presumably the `dropout_p` passed to CPLowRankLayer -- confirm
PATCH_SIZE = 4         # side length (pixels) of the square image patches

# Seed torch's global RNG so parameter init, shuffling and dropout are
# reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [3]:
# Tokenizer: cut each image into non-overlapping PATCH_SIZE x PATCH_SIZE
# patches, flatten each patch, then linearly embed it into BOND_DIM features.
# Output shape: (b, (H / PATCH_SIZE) * (W / PATCH_SIZE), BOND_DIM).
# Rearrange flattens (p1 p2 c) per patch, while the Linear expects only
# PATCH_SIZE * PATCH_SIZE inputs, so this assumes single-channel images (c == 1).
patch_layer = nn.Sequential(
    Rearrange(
        'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
        p1=PATCH_SIZE,
        p2=PATCH_SIZE,
    ),
    nn.Linear(in_features=PATCH_SIZE * PATCH_SIZE, out_features=BOND_DIM),
)

In [4]:
def get_mnist_loaders(batch_size=None, root='./data', num_workers=0):
    """Build the MNIST train and test DataLoaders.

    Each image is resized to 32x32, converted to a tensor and normalized with
    mean 0.1307 / std 0.3081. The 32x32 size is presumably chosen so that
    PATCH_SIZE patches give a power-of-two patch count for the pairwise tree
    layers -- confirm. The normalization statistics describe the original
    28x28 images, so after resizing they are only approximate.

    Args:
        batch_size: samples per batch. Defaults to the notebook-level
            ``BATCH_SIZE``, looked up at call time (same as before).
        root: directory where torchvision stores/downloads MNIST.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        ``(train_loader, test_loader)``. Only the training loader is shuffled.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_set = datasets.MNIST(root, train=True, download=True, transform=transform)
    test_set = datasets.MNIST(root, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)

    return train_loader, test_loader

In [6]:

class CPLowRankLayer(nn.Module):
    """One level of tree nodes whose merge tensors are stored in CP form.

    Each of the ``num_nodes`` nodes merges a (left, right) pair of child
    vectors into one output vector. Conceptually, node ``n`` owns a dense
    tensor of shape (out_dim, in_dim, in_dim). Instead of storing it, the
    layer keeps its rank-``rank`` CP decomposition::

        T_n[o, i, j] = sum_r scale[n, r] * factor_out[n, r, o]
                                         * factor_left[n, r, i]
                                         * factor_right[n, r, j]

    This uses rank * (2 * in_dim + out_dim + 1) parameters per node instead of
    out_dim * in_dim ** 2.

    When ``in_dim == out_dim``, the residual ``x_left + x_right`` is added to
    each node's output. For other widths the residual is skipped, because the
    shapes cannot be summed.

    Args:
        num_nodes: number of nodes in this layer. The input must hold
            ``2 * num_nodes`` child vectors per sample.
        in_dim: width of each child vector.
        out_dim: width of each node's output vector.
        rank: CP rank.
        dropout_p: dropout probability on the rank-space merged features
            during training. Must be in [0, 1].

    Raises:
        ValueError: if ``dropout_p`` is outside [0, 1].
    """

    def __init__(self, num_nodes, in_dim, out_dim, rank, dropout_p=0.0):
        super().__init__()
        if not 0.0 <= dropout_p <= 1.0:
            raise ValueError(f"dropout_p must be in [0, 1], got {dropout_p}")
        self.num_nodes = num_nodes
        self.rank = rank
        self.dropout_p = dropout_p
        # The residual x_l + x_r can only be added to outputs of the same width.
        self.use_residual = in_dim == out_dim

        self.factor_left = nn.Parameter(torch.randn(num_nodes, rank, in_dim))
        self.factor_right = nn.Parameter(torch.randn(num_nodes, rank, in_dim))
        self.factor_out = nn.Parameter(torch.randn(num_nodes, rank, out_dim))

        # Per-(node, rank) weight of each CP term.
        self.scale = nn.Parameter(torch.ones(num_nodes, rank))

        self._initialize()

    def _initialize(self):
        """Scale every factor vector to unit L2 norm; draw scale ~ N(1, 0.02)."""
        eps = 1e-8  # avoids division by zero for a (vanishingly unlikely) zero vector
        with torch.no_grad():
            for factor in (self.factor_left, self.factor_right, self.factor_out):
                factor.div_(factor.norm(dim=-1, keepdim=True) + eps)
            self.scale.normal_(1.0, 0.02)

    def forward(self, x):
        """Merge adjacent child pairs into one vector per node.

        Args:
            x: tensor whose per-sample elements reshape to
                (num_nodes, 2, in_dim), e.g. (B, 2 * num_nodes, in_dim), where
                positions 2n and 2n + 1 are the left/right children of node n.

        Returns:
            Tensor of shape (B, num_nodes, out_dim).
        """
        batch = x.size(0)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        pairs = x.reshape(batch, self.num_nodes, 2, -1)
        x_l = pairs[:, :, 0, :]  # (B, N, in_dim)
        x_r = pairs[:, :, 1, :]  # (B, N, in_dim)

        # Project each child onto its rank-R factor vectors: (B, N, R).
        proj_l = torch.einsum('bni,nri->bnr', x_l, self.factor_left)
        proj_r = torch.einsum('bni,nri->bnr', x_r, self.factor_right)

        # Elementwise product in rank space == contraction with the CP tensor.
        merged = self.scale.unsqueeze(0) * proj_l * proj_r

        if self.training and self.dropout_p > 0:
            # F.dropout rescales kept entries by 1/(1 - p) and, unlike a manual
            # mask / (1 - p), returns zeros instead of NaN when p == 1.
            merged = nn.functional.dropout(merged, p=self.dropout_p, training=True)

        out = torch.einsum('bnr,nro->bno', merged, self.factor_out)
        if self.use_residual:
            out = out + x_l + x_r

        return out
