## Tensorized Neural Networks (TNN)

Computes $f_n \circ \cdots \circ f_0 (x)$, where $x$ is an MPS tensor and each $f_i$ is an MPO followed by a rank dropout and non-linearity 

Next steps:
- Add regression and classification heads


In [24]:
from dataclasses import dataclass
from typing import List, Union, Optional
import torch

def mpo_contract(
    mpo: Union[List[torch.Tensor], torch.nn.ParameterList], 
    mps: Union[List[torch.Tensor], torch.Tensor]
    ) -> Union[List[torch.Tensor], torch.Tensor]: # returns a new mps
    """Perform a tensor network contraction of an MPO and an MPS. Returns a new MPS.

    Args:
        mpo (List[torch.Tensor]): A list of tensors representing the MPO. Shape: (Rl, Di, Do, Rr)
        mps (List[torch.Tensor]): A list of tensors representing the MPS. Shape: (B, Rl, Di, Rr)

    Returns:
        List[torch.Tensor]: A list of tensors representing the new MPS. Shape: (B, Do, Rl, Rr)
    """
    out = []
    for i in range(len(mpo)):
        mps_prime = torch.einsum('rios,bpiq->brposq', mps[i], mpo[i])
        B, R, P, Do, S, Q = mps_prime.shape
        out.append(mps_prime.reshape(B, R*P, Do, S*Q))
    return out


def mps_add_single_core(mps_a: torch.Tensor, mps_b: torch.Tensor) -> torch.Tensor:
    # mps_a (Rl, D, Rr), mps_b (Rl, D, Rr)
    mps_c_top = torch.cat([mps_a, torch.zeros_like(mps_b)], dim=-1)  # (Rl, D, Rr+Rr)
    mps_c_bottom = torch.cat([torch.zeros_like(mps_a), mps_b], dim=-1)  # (Rl, D, Rr+Rr)
    mps_c = torch.cat([mps_c_top, mps_c_bottom], dim=0)  # (2Rl, D, Rr+Rr)
    return mps_c

def mps_norm(mps: List[torch.Tensor]) -> torch.Tensor:
    return torch.norm(torch.cat([mps[0], mps[-1]], dim=0))

mps_add_single_core_batch = torch.vmap(mps_add_single_core, in_dims=(0, 0))


def mps_add(mps_a: List[torch.Tensor], mps_b: List[torch.Tensor]) -> List[torch.Tensor]:
    mps_c = []
    # mps_a[0] (B, 1, D, Rr), mps_b[0] (B, 1, D, Rr)
    mps_c.append(torch.cat([mps_a[0], mps_b[0]], dim=-1))  # (B, 1, D, Rr+Rr)
    for i in range(1, len(mps_a) - 1):
        mps_c.append(mps_add_single_core(mps_a[i], mps_b[i]))
    mps_c.append(torch.cat([mps_a[-1], mps_b[-1]], dim=1))  # (B, Rl+Rl, D, 1)
    return mps_c


def mps_norm(
    g: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    use_scale_factors: bool = True,
    norm: str = "l2",
    **kwargs
):
    """Marginalize a Born MPS tensor.

    Args:
        g (torch.Tensor): g tensor. Shape: (N, R, D, R)
        a (torch.Tensor): a tensor. Shape: (N, R)
        b (torch.Tensor): b tensor. Shape: (N, R)

    Returns:
        torch.Tensor: Marginalized tensor. Shape: (1,)
    """
    H, _, _, _ = g.shape
    scale_factors = []
    norm_fn = {
        "l2": torch.linalg.norm,
        "linf": torch.amax,
    }[norm]
    L = torch.einsum("p,pdq,r,rds->qs", a, g[0], a, g[0])
    for h in range(1, H):
        L = torch.einsum("pdq,pr,rds ->qs", g[h], L, g[h])
        if use_scale_factors:
            sf = norm_fn(L.abs())
            scale_factors.append(sf)
            L = L / sf
    L = torch.einsum("pq,p,q->", L, b, b)
    if not use_scale_factors:
        scale_factors = [torch.tensor(1.0)]
    return L, torch.stack(scale_factors)  # (1,), (N,)
mps_norm_batch = torch.vmap(mps_norm, in_dims=(0, 0, 0))

class MPO(torch.nn.Module):
    def __init__(self, in_features: List[int], out_features: List[int], ranks: List[int], max_rank: int = 2):
        super(MPO, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ranks = ranks
        self.max_rank = max_rank
        self.g = torch.nn.ParameterList([
            torch.nn.Parameter(torch.randn(ranks[i], in_features[i], out_features[i], ranks[i+1]))
            for i in range(len(in_features))  # (Rk, Ik, Ok, Rk+1)
        ])

    def forward(self, mps_x: List[torch.Tensor], mps_y: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        # mps_x:[(B, Rl, Di, Rr), ...], mps_y:[(B, Rl, Do, Rr), ...]

        # MPO x MPS
        mps_y_hat = mpo_contract(self.g, mps_x)  

        # Rank dropout
        mps_y_hat_reduced = []
        for i in range(len(mps_y_hat)):
            # print(mps_y_hat[i].shape, self.max_rank)
            Rl = torch.randint(0, mps_y_hat[i].shape[1], (self.max_rank,))  # (B, Rl, Do, Rr)
            Rr = torch.randint(0, mps_y_hat[i].shape[3], (self.max_rank,))
            # print(Rl, Rr)
            mps_y_hat_reduced.append(mps_y_hat[i][:, Rl][:, :, :, Rr])

        # Non-linearity
        mps_y_hat_reduced = [torch.relu(mps_y_hat_reduced[i]) for i in range(len(mps_y_hat_reduced))]
        
        return mps_y_hat_reduced


@dataclass
class TNNOutput:
    mps_y_hat: List[torch.Tensor]
    loss: Optional[torch.Tensor] = None

class TNN(torch.nn.Module):
    def __init__(self, in_features: List[int], out_features: List[int], n_layers: int=4, max_rank: int = 2):
        super(TNN, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_layers = n_layers
        self.max_rank = max_rank
        self.layers = torch.nn.ModuleList([
            MPO(
                in_features=in_features, 
                out_features=out_features, 
                ranks=[2] * (n_layers + 1), 
                max_rank=max_rank
            ) for _ in range(n_layers)
        ])
    
    def forward(self, mps_x: List[torch.Tensor], mps_y: Optional[List[torch.Tensor]] = None):
        mps_y_hat = mps_x
        for i, layer in enumerate(self.layers):
            mps_y_hat = layer(mps_y_hat, mps_y)

        if mps_y is not None:
            mps_y_hat_ = [-m for m in mps_y_hat]
            mps_y_hat_ = mps_add(mps_y, mps_y_hat_)  #[(B, Rl, Do, Rr), ...]
            loss, _ = mps_norm_batch(torch.stack(mps_y_hat_[1:-1], dim=1), mps_y_hat_[0], mps_y_hat_[-1])
        return TNNOutput(mps_y_hat=mps_y_hat)


In [28]:
# ----------------------------------------------------------
# Test TNN (Uses 8 MPO layers w/ non-linearities
# ----------------------------------------------------------

# Hyperparameters
batch_size, n_layers, max_rank, = 2, 8, 2
in_features, out_features, ranks = [2, 2], [2, 2], [2, 2, 2]

# Init model
mps_x = [torch.randn(batch_size, ranks[i], in_features[i], ranks[i+1]) for i in range(len(in_features))]
ttn = TNN(in_features=in_features, out_features=out_features, n_layers=n_layers, max_rank=max_rank)

# Forward pass
output = ttn(mps_x, mps_x)
print(f"MPSx cores: " + ', '.join([str(m.shape) for m in output.mps_y_hat]))
print(f"MPSy cores: " + ', '.join([str(m.shape) for m in output.mps_y_hat]))


RuntimeError: stack expects a non-empty TensorList

In [20]:
B, Rl, D, Rr = 8, 2, 2, 2
mps_a = torch.randn(B, Rl, D, Rr)
mps_b = torch.randn(B, Rl, D, Rr)
mps_add_single_core_batch(mps_a, mps_b).shape

torch.Size([8, 4, 2, 4])

In [None]:
# B, N, Rl, D, Rr = 8, 2, 2, 2, 2
mps_a = [torch.randn(B, Rl, D, Rr) for _ in range(N)]
# mps_b = [torch.randn(B, Rl, D, Rr) for _ in range(N)]
# mps_c = mps_add(mps_a, mps_b)

mps_y_hat_ = mps_add(mps_a, mps_a)  #[(B, Rl, Do, Rr), ...]
loss, _ = mps_norm_batch(torch.stack(mps_y_hat_[1:-1], dim=1), mps_y_hat_[0], mps_y_hat_[-1])