## Tensorized Neural Networks (TNN)

Computes $(f_n \circ \cdots \circ f_0)(x)$, where $x$ is a matrix product state (MPS) and each $f_i$ is a matrix product operator (MPO) contraction followed by rank dropout and a non-linearity.

Next steps:
- Add regression and classification heads


In [3]:
from typing import List, Union, Optional
import torch

def mpo_contract(
    mpo: Union[List[torch.Tensor], torch.nn.ParameterList], 
    mps: Union[List[torch.Tensor], torch.Tensor]
    ) -> Union[List[torch.Tensor], torch.Tensor]: # returns a new mps
    """Contract an MPO with an MPS site by site and return the resulting MPS.

    For every site the shared physical (input) index is summed out and the
    bond indices of the operator and the state are merged, so the bond
    dimensions of the result are the products of the operands' bond
    dimensions.

    Args:
        mpo: One core per site, each of shape (Rl, Di, Do, Rr).
        mps: One core per site, each of shape (B, Pl, Di, Pr). Only the
            first ``len(mpo)`` entries are read.

    Returns:
        A list with one core per MPO site, each of shape
        (B, Rl * Pl, Do, Rr * Pr).
    """
    out = []
    for site, mpo_core in enumerate(mpo):
        # r/s: MPO bonds, p/q: MPS bonds, i: contracted physical index,
        # o: output physical index, b: batch.  The operator core must be the
        # first operand so the subscripts line up with the documented shapes.
        merged = torch.einsum('rios,bpiq->brposq', mpo_core, mps[site])
        B, R, P, Do, S, Q = merged.shape
        # Fuse (MPO bond, MPS bond) pairs into single left/right bonds.
        out.append(merged.reshape(B, R * P, Do, S * Q))
    return out


class MPO(torch.nn.Module):
    """Matrix product operator layer: contraction, random rank reduction, ReLU.

    The layer owns one learnable core per site, stored in ``self.g`` with
    shape (ranks[k], in_features[k], out_features[k], ranks[k + 1]).

    Args:
        in_features: Physical input dimension of every site.
        out_features: Physical output dimension of every site; needs at least
            ``len(in_features)`` entries (extra entries are ignored).
        ranks: Bond dimensions of the operator; needs at least
            ``len(in_features) + 1`` entries (extra entries are ignored).
        max_rank: Number of bond indices kept on each side of every output
            core after the contraction.

    Raises:
        ValueError: If ``out_features`` or ``ranks`` is too short for the
            number of sites given by ``in_features``.
    """

    def __init__(self, in_features: List[int], out_features: List[int], ranks: List[int], max_rank: int = 2):
        super().__init__()
        n_sites = len(in_features)
        # Fail early with a readable message instead of an IndexError raised
        # half-way through building the parameter list.
        if len(out_features) < n_sites or len(ranks) < n_sites + 1:
            raise ValueError(
                f"MPO with {n_sites} sites needs at least {n_sites} out_features "
                f"and {n_sites + 1} ranks, got {len(out_features)} and {len(ranks)}"
            )
        self.in_features = in_features
        self.out_features = out_features
        self.ranks = ranks
        self.max_rank = max_rank
        self.g = torch.nn.ParameterList([
            torch.nn.Parameter(torch.randn(ranks[k], in_features[k], out_features[k], ranks[k + 1]))
            for k in range(n_sites)  # (Rk, Ik, Ok, Rk+1)
        ])

    def forward(self, mps_x: List[torch.Tensor], mps_y: Optional[List[torch.Tensor]] = None) -> List[torch.Tensor]:
        """Apply the operator to ``mps_x`` and return the reduced, rectified MPS.

        Args:
            mps_x: Input cores, one per site, each (B, Rl, Di, Rr).
            mps_y: Accepted for interface compatibility; not used here.

        Returns:
            One core per site, each of shape (B, max_rank, Do, max_rank).
        """
        # MPO x MPS: bond dimensions grow to the product of both operands'.
        contracted = mpo_contract(self.g, mps_x)

        reduced = []
        for core in contracted:  # core: (B, Rl, Do, Rr)
            # Rank dropout: keep `max_rank` randomly drawn indices per bond.
            # torch.randint draws independently, so an index may repeat.
            # This runs on every call (there is no training-mode check), which
            # keeps the bond size fixed from layer to layer.
            left_idx = torch.randint(0, core.shape[1], (self.max_rank,))
            right_idx = torch.randint(0, core.shape[3], (self.max_rank,))
            # Non-linearity applied to the reduced core.
            reduced.append(torch.relu(core[:, left_idx][..., right_idx]))

        return reduced

class TNN(torch.nn.Module):
    """Tensorized neural network: a stack of ``n_layers`` MPO layers.

    Args:
        in_features: Physical dimension of every site of the input MPS.
        out_features: Physical dimension of every site produced by each layer.
        n_layers: Number of stacked MPO layers.
        max_rank: Bond size each layer reduces its output cores to.
    """

    def __init__(self, in_features: List[int], out_features: List[int], n_layers: int=4, max_rank: int = 2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_layers = n_layers
        self.max_rank = max_rank
        # An MPO over N sites has N + 1 bonds; this count depends on the
        # number of sites, not on the depth of the network.
        n_bonds = len(in_features) + 1
        self.layers = torch.nn.ModuleList([
            MPO(
                # Only the first layer sees the raw input; every later layer
                # consumes cores whose physical dims are `out_features`.
                in_features=in_features if depth == 0 else out_features,
                out_features=out_features,
                ranks=[2] * n_bonds,
                max_rank=max_rank,
            ) for depth in range(n_layers)
        ])
    
    def forward(self, mps_x: List[torch.Tensor], mps_y: Optional[List[torch.Tensor]] = None) -> List[torch.Tensor]:
        """Feed ``mps_x`` through every layer in order and return the final MPS.

        ``mps_y`` is passed unchanged to each layer.
        """
        for layer in self.layers:
            mps_x = layer(mps_x, mps_y)
        return mps_x


In [None]:
# ----------------------------------------------------------
# Test TNN (uses 8 MPO layers with non-linearities)
# ----------------------------------------------------------

# Hyperparameters
batch_size = 2
n_layers = 8
max_rank = 2
in_features = [2, 2]
out_features = [2, 2]
ranks = [2, 2, 2]

# Random input MPS: one (batch, left rank, physical dim, right rank) core per site
mps_x = [
    torch.randn(batch_size, rank_left, dim, rank_right)
    for rank_left, dim, rank_right in zip(ranks, in_features, ranks[1:])
]

# Init model
ttn = TNN(
    in_features=in_features,
    out_features=out_features,
    n_layers=n_layers,
    max_rank=max_rank,
)

# Forward pass, then report the core shapes before and after
mps_y = ttn(mps_x)
print("MPSx cores: " + ', '.join(str(core.shape) for core in mps_x))
print("MPSy cores: " + ', '.join(str(core.shape) for core in mps_y))


MPSx cores: torch.Size([2, 2, 2, 2]), torch.Size([2, 2, 2, 2])
MPSy cores: torch.Size([2, 2, 2, 2]), torch.Size([2, 2, 2, 2])
