In [5]:
import numpy as np
import torch
#import torchvision
#import torchvision.transforms as transforms
import matplotlib.pyplot as plt
%matplotlib inline
import math
import os
from tnn_utils import *

## MNIST data import 

In [None]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_trainset = torchvision.datasets.MNIST(root="data", train=True, transform=transform)
mnist_testset = torchvision.datasets.MNIST(root="data", train=False, transform=transform)

In [None]:
n_row, n_column = 4, 8
img_show(mnist_trainset, n_row, n_column)

## Building blocks of the Model

### Forward
\begin{align*}
&f^{l}(x) = \frac{1}{m}\sum_{i}{[A^{s_{1}}_{\alpha_{0}\alpha_{1}}\cdots A^{s_{n}}_{\alpha_{n-1}\alpha_{n}}\Phi^{s_{1}s_{2}\cdots s_{n}}(x^{i})]}\\
&L = \frac{1}{m}\sum_{i}{[f^{l}(x_{i}) - \delta^{l}_{n}]^{2}}
\end{align*}

In [6]:
def tnn_cell_forward_right(X, Y, parameters):
    m, n = X.shape
    tensors, l = parameters['tensors'], parameters['l']
    psi = torch.ones([m, 1, 1])
    assert l < n-1, "invalid l!"
    for site in range(l):
        psi = psi @ tensors[site][:, X[:, site], :].permute(1, 0, 2)
    left_psi = psi
    psi = psi @ tensors[l][:, X[:, l], :, :].permute(2, 1, 0, 3) @ tensors[l+1][:, X[:, l+1], :].permute(1, 0, 2)
    if l == n - 2:
        right_psi = torch.ones([m, 1, 1])
    else:
        right_psi = torch.ones([m, tensors[l+1].shape[2], tensors[l+1].shape[2]])
        for site in range(l+2, n):
            psi = psi @ tensors[site][:, X[:, site], :].permute(1, 0, 2)
            right_psi = right_psi @ tensors[site][:, X[:, site], :].permute(1, 0, 2)
    cache = (left_psi, right_psi, psi)
    loss = torch.sum(torch.pow(psi.squeeze()-Y.transpose(0, 1), 2)) / m
    return loss, cache

In [7]:
def tnn_cell_forward_left(X, Y, parameters):
    m, n = X.shape
    tensors, l = parameters['tensors'], parameters['l']
    psi = torch.ones([m, 1, 1])
    assert l > 0, "invalid l!"
    for site in range(l-1):
        psi = psi @ tensors[site][:, X[:, site], :].permute(1, 0, 2)
    left_psi = psi
    psi = psi @ tensors[l-1][:, X[:, l-1], :].permute(1, 0, 2) @ tensors[l][:, X[:, l], :, :].permute(2, 1, 0, 3)
    if l == n - 1:
        right_psi = torch.ones([m, 1, 1])
    else:
        right_psi = torch.ones([m, tensors[l].shape[3], tensors[l].shape[3]])
        for site in range(l+1, n):
            psi = psi @ tensors[site][:, X[:, site], :].permute(1, 0, 2)
            right_psi = right_psi @ tensors[site][:, X[:, site], :].permute(1, 0, 2)
    cache = (left_psi, right_psi, psi)
    loss = torch.sum(torch.pow(psi.squeeze()-Y.transpose(0, 1), 2)) / m
    return loss, cache

In [90]:
#test for tnn_cell_forward_right() module
torch.manual_seed(1)
m, n, l, index, Dmax = 5, 10, 8, 10, 10
parameters = {'m': m, 'n': n, 'Dmax': Dmax, 'index': index, 'l': l}
X = torch.randint(2, (m, n))
Y = np.random.binomial(index, 0.5, m)
Y_onehot = one_hot(Y, index)
parameters = tensor_initialize(parameters)
loss, cache = tnn_cell_forward_right(X, Y_onehot, parameters)
(left_psi, right_psi, psi) = cache
print("number of sample in one batch is: ", m)
print("number of features of one sample is: ", n)
print("output shape is: ", psi.shape)
print("shape of left tensor of Al is: ", left_psi.shape)
print("shape of right tensor of Al is: ", right_psi.shape)
print("the loss is: ", loss)
print("the psi is: ", psi.squeeze())

number of sample in one batch is:  5
number of features of one sample is:  10
output shape is:  torch.Size([10, 5, 1, 1])
shape of left tensor of Al is:  torch.Size([5, 1, 10])
shape of right tensor of Al is:  torch.Size([5, 1, 1])
the loss is:  tensor(1.0000)
the psi is:  tensor([[ 6.7132e-07, -7.9290e-08,  2.3963e-07, -1.1292e-07, -1.1057e-07],
        [-3.3133e-07,  1.5196e-08, -1.3075e-07, -1.7225e-07, -1.3029e-07],
        [ 4.3088e-07, -3.0903e-08, -3.3417e-08, -1.0526e-07, -1.5662e-07],
        [ 8.2614e-08,  2.3404e-09,  3.6240e-08, -5.2079e-08, -1.9988e-07],
        [-5.4127e-07,  9.4454e-08,  8.8346e-09, -8.7452e-08, -4.7404e-08],
        [ 1.5456e-09, -8.6599e-08, -1.3526e-07, -1.3640e-07, -2.3207e-08],
        [ 2.1806e-07, -5.2565e-08, -3.4004e-08,  5.8837e-07,  5.6295e-07],
        [-1.5755e-07,  2.8022e-08,  2.6864e-09,  6.8115e-07,  2.8664e-07],
        [-1.3863e-07,  8.6743e-08, -6.8020e-08,  8.0445e-09, -1.7507e-07],
        [-2.0160e-07,  9.4974e-09,  4.2146e-08,  7.

In [94]:
#test for tnn_cell_forward_left() module
torch.manual_seed(1)
m, n, l, index, Dmax = 5, 15, 8, 10, 8
parameters = {'m': m, 'n': n, 'Dmax': Dmax, 'index': 10, 'l': l}
parameters = tensor_initialize(parameters)
X = torch.randint(1, (m, n))
loss, cache = tnn_cell_forward_left(X, Y_onehot, parameters)
(left_psi, right_psi, psi) = cache
print("number of sample in one batch is: ", m)
print("number of features of one sample is: ", n)
print("output shape is: ", psi.shape)
print("shape of left tensor of Al is: ", left_psi.shape)
print("shape of right tensor of Al is: ", right_psi.shape)
print("the loss is: ", loss)

number of sample in one batch is:  5
number of features of one sample is:  15
output shape is:  torch.Size([10, 5, 1, 1])
shape of left tensor of Al is:  torch.Size([5, 1, 8])
shape of right tensor of Al is:  torch.Size([5, 8, 1])
the loss is:  tensor(1.)


### Backward
\begin{align*}
&\frac{\partial{L}}{\partial{f^{l}(x^{(i)})}} = 2(f^{l}(x_{i}) - \delta^{l}_{n})\\
&\frac{\partial{f^{l}(x^{(i)})}}{\partial{B^{l,s_{k}s_{k+1}}_{\alpha_{k-1}\alpha_{k+1}}}} = \tilde{A}^{s_{1}s_{2}\cdots s_{k-1}}_{\alpha_{k-1}}\tilde{B}^{s_{k+2}\cdots s_{n}}_{\alpha_{k+1}}\Phi^{s_{1}s_{2}\cdots s_{n}}(x^{(i)})
\end{align*}

In [24]:
def tnn_cell_backward_right(X, Y, cache, parameters):
    """
    Calculate the gradients of the loss function wrt mps.
    
    Arguments:
    Y -- labels of the samples, pytorch tensor with dimension (m, index)
    
    Return:
    dAl -- gradient of the tensor, pytorch tensor with dimension (m, index, alpha_lf, alpha_rig)
    
    
    """
    left_psi, right_psi, psi = cache
    l =  parameters['l']
    left_index, right_index = left_psi.shape[-1], right_psi.shape[-2]
    m = X.shape[0]
    index = parameters['index']
    dfl = 2 * (psi - Y.reshape_as(psi))
    dBl_m = dfl * (left_psi.transpose(1, 2) @ right_psi.transpose(1, 2))
    index1, index2 = (X[:, l] == 0).view(1, -1, 1, 1), (X[:, l+1] == 0).view(1, -1, 1, 1)
    index3, index4 = 1 - index1, 1 - index2
    dBl_1 = (dBl_m * (index1 * index2).type(torch.float)).mean(1, True)
    dBl_2 = (dBl_m * (index1 * index4).type(torch.float)).mean(1, True)
    dBl_3 = (dBl_m * (index3 * index2).type(torch.float)).mean(1, True)
    dBl_4 = (dBl_m * (index3 * index4).type(torch.float)).mean(1, True)
    dBl = torch.cat((dBl_1, dBl_2, dBl_3, dBl_4), 1).reshape(index, 2, 2, left_index, right_index)
    
    gradients = {'dBl': dBl}
    return gradients

In [25]:
def tnn_cell_backward_left(X, Y, cache, parameters):
    """
    Calculate the gradients of the loss function wrt mps.
    
    Arguments:
    Y -- labels of the samples, pytorch tensor with dimension (m, index)
    
    Return:
    dAl -- gradient of the tensor, pytorch tensor with dimension (m, index, alpha_lf, alpha_rig)
    
    
    """
    left_psi, right_psi, psi = cache
    l =  parameters['l']
    left_index, right_index = left_psi.shape[-1], right_psi.shape[-2]
    m = X.shape[0]
    index = parameters['index']
    dfl = 2 * (psi - Y.reshape_as(psi))
    dBl_m = dfl * (left_psi.transpose(1, 2) @ right_psi.transpose(1, 2))
    index1, index2 = (X[:, l-1] == 0).view(1, -1, 1, 1), (X[:, l] == 0).view(1, -1, 1, 1)
    index3, index4 = 1 - index1, 1 - index2
    dBl_1 = (dBl_m * (index1 * index2).type(torch.float)).mean(1, True)
    dBl_2 = (dBl_m * (index1 * index4).type(torch.float)).mean(1, True)
    dBl_3 = (dBl_m * (index3 * index2).type(torch.float)).mean(1, True)
    dBl_4 = (dBl_m * (index3 * index4).type(torch.float)).mean(1, True)
    dBl = torch.cat((dBl_1, dBl_2, dBl_3, dBl_4), 1).reshape(index, 2, 2, left_index, right_index)
    
    gradients = {'dBl': dBl}
    return gradients

In [27]:
#test for tnn_cell_backward_right()
gradients = tnn_cell_backward_right(X, Y_onehot, cache, parameters)
dBl = gradients['dBl']
print("shape of gradients of tensor Bl: ", dBl.shape)
print("gradients: ", dBl.squeeze())

shape of gradients of tensor Bl:  torch.Size([10, 2, 2, 5, 5])
gradients:  tensor([[[[[ 1.4910e-10,  1.4910e-10,  1.4910e-10,  1.4910e-10,  1.4910e-10],
           [-1.5182e-10, -1.5182e-10, -1.5182e-10, -1.5182e-10, -1.5182e-10],
           [-5.6894e-10, -5.6894e-10, -5.6894e-10, -5.6894e-10, -5.6894e-10],
           [ 4.7469e-10,  4.7469e-10,  4.7469e-10,  4.7469e-10,  4.7469e-10],
           [-6.4108e-11, -6.4108e-11, -6.4108e-11, -6.4108e-11, -6.4108e-11]],

          [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]],


         [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+0

In [95]:
#test for tnn_cell_backward_left()
gradients = tnn_cell_backward_left(X, Y_onehot, cache, parameters)
dBl = gradients['dBl']
print("shape of gradients of tensor Bl: ", dBl.shape)
print("gradients: ", dBl.squeeze())

shape of gradients of tensor Bl:  torch.Size([10, 2, 2, 8, 8])
gradients:  tensor([[[[[-9.0759e-17, -9.0759e-17, -9.0759e-17,  ..., -9.0759e-17,
            -9.0759e-17, -9.0759e-17],
           [-3.4739e-17, -3.4739e-17, -3.4739e-17,  ..., -3.4739e-17,
            -3.4739e-17, -3.4739e-17],
           [ 7.1288e-17,  7.1288e-17,  7.1288e-17,  ...,  7.1288e-17,
             7.1288e-17,  7.1288e-17],
           ...,
           [ 5.8770e-17,  5.8770e-17,  5.8770e-17,  ...,  5.8770e-17,
             5.8770e-17,  5.8770e-17],
           [-6.8740e-17, -6.8740e-17, -6.8740e-17,  ..., -6.8740e-17,
            -6.8740e-17, -6.8740e-17],
           [ 1.7138e-18,  1.7138e-18,  1.7138e-18,  ...,  1.7138e-18,
             1.7138e-18,  1.7138e-18]],

          [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
             0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
             0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+0

\begin{align*}
\Delta B^{l,s_{k-1}s_{k}}_{\alpha_{k-1}\alpha_{k+1}}\to
\begin{bmatrix}
B^{l,x_{m}}_{\alpha_{k-1}\alpha_{k+1}}\\
\cdots\\
B^{l,x_{n}}_{\alpha_{k-1}\alpha_{k+1}}
\end{bmatrix}
\end{align*}

In [127]:
def update_parameters_right(X, parameters, gradients, bond, learning_rate):
    
    m, n, index, tensors, l = parameters['m'], parameters['n'], parameters['index'], parameters['tensors'], parameters['l']
    assert l < n, "invalid l"
    dBl = gradients['dBl']
    Bl = tensors[l].permute(2, 1, 0, 3).unsqueeze(2) @ tensors[l+1].permute(1, 0, 2).unsqueeze(0)
    Bl +=  dBl
    alpha_left, alpha_right = dBl.shape[-2], dBl.shape[-1]
    u, s, vt = torch.svd(Bl.permute(1, 3, 0, 2, 4).contiguous().view(2*alpha_left, 2*alpha_right*index))
    if alpha_left * 2 > bond:
        tensors[l] = ((u * s).split(bond, 1)[0]).reshape(2, alpha_left, bond).permute(1, 0, 2)
        tensors[l+1] = (((vt.transpose(0, 1)).split(bond, 0))[0]).reshape(bond, index, 2, alpha_right).permute(0, 2, 1, 3)
    else:    
        tensors[l] = (u * s).reshape(2, alpha_left, -1).permute(1, 0, 2)
        tensors[l+1] = (vt.transpose(0, 1)).reshape(-1, index, 2, alpha_right).permute(0, 2, 1, 3)
    parameters['l'] = l + 1      
    
    return parameters  

In [128]:
def update_parameters_left(X, parameters, gradients, bond, learning_rate):
    
    m, n, index, tensors, l = parameters['m'], parameters['n'], parameters['index'], parameters['tensors'], parameters['l']
    assert l > 0, "invalid l"
    dBl = gradients['dBl']
    Bl = tensors[l-1].permute(1, 0, 2).unsqueeze(1) @ tensors[l].permute(2, 1, 0, 3).unsqueeze(1)
    Bl +=  dBl
    alpha_left, alpha_right = dBl.shape[-2], dBl.shape[-1]        
    u, s, vt = torch.svd(Bl.permute(1, 3, 0, 2, 4).contiguous().view(2*alpha_left*index, 2*alpha_right))
    if alpha_right * 2 > bond:
        tensors[l-1] = ((u * s).split(bond, 1)[0]).reshape(2, alpha_left, index, bond).permute(1, 0, 2, 3)
        tensors[l] = ((vt.transpose(0, 1).split(bond, 0))[0]).reshape(bond, 2, alpha_right)
    else:
        tensors[l-1] = (u * s).reshape(2, alpha_left, index, -1).permute(1, 0, 2, 3)
        tensors[l] = vt.transpose(0, 1).reshape(-1, 2, alpha_right)
    parameters['l'] = l - 1
    
    return parameters  

In [129]:
#test for update_parameters_right()
torch.manual_seed(1)
m, n, l, index, Dmax = 5, 10, 0, 10, 15
parameters = {'m': m, 'n': n, 'Dmax': Dmax, 'index': 10, 'l': l}
parameters = tensor_initialize(parameters)
X = torch.randint(1, (m, n))
Y = np.random.binomial(index, 0.5, m)
Y_onehot = one_hot(Y, index)
loss, cache = tnn_cell_forward_right(X, Y_onehot, parameters)
(left_psi, right_psi, psi) = cache
gradients = tnn_cell_backward_right(X, Y_onehot, cache, parameters)
bond, learning_rate, n = 3, 0.01, parameters['n']
parameters = update_parameters_right(X, parameters, gradients, bond, learning_rate)
tensors = parameters['tensors']
print("shape of updated A^{l} tensor: ", tensors[l].shape)
print("shape updated A^{l+1} tensor: ", tensors[l+1].shape)
for i in range(n):
    print(tensors[i].shape)
print("updated A^{l} tensor: ", tensors[l])

shape of updated A^{l} tensor:  torch.Size([1, 2, 2])
shape updated A^{l+1} tensor:  torch.Size([2, 2, 10, 15])
torch.Size([1, 2, 2])
torch.Size([2, 2, 10, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 1])
updated A^{l} tensor:  tensor([[[-0.1756,  0.0945],
         [ 0.1110,  0.1495]]])


In [130]:
#test for update_parameters_left()
torch.manual_seed(1)
m, n, l, index, Dmax = 5, 10, 9, 10, 15
parameters = {'m': m, 'n': n, 'Dmax': Dmax, 'index': 10, 'l': l}
parameters = tensor_initialize(parameters)
X = torch.randint(1, (m, n))
Y = np.random.binomial(index, 0.5, m)
Y_onehot = one_hot(Y, index)
loss, cache = tnn_cell_forward_left(X, Y_onehot, parameters)
(left_psi, right_psi, psi) = cache
gradients = tnn_cell_backward_left(X, Y_onehot, cache, parameters)
bond, learning_rate, n = 10, 0.01, parameters['n']
parameters = update_parameters_left(X, parameters, gradients, bond, learning_rate)
tensors = parameters['tensors']
print("shape of updated A^{l} tensor: ", tensors[l].shape)
print("updated A^{l-1} tensor: ", tensors[l-1].shape)
for i in range(n):
    print(tensors[i].shape)
print("updated A^{l-1} tensor: ", tensors[l-1])

shape of updated A^{l} tensor:  torch.Size([2, 2, 1])
updated A^{l-1} tensor:  torch.Size([15, 2, 10, 2])
torch.Size([1, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 15])
torch.Size([15, 2, 10, 2])
torch.Size([2, 2, 1])
updated A^{l-1} tensor:  tensor([[[[-1.2375e-02,  1.3121e-02],
          [-1.1876e-02, -5.6852e-03],
          [ 9.3243e-03,  6.9387e-03],
          [-8.5409e-03, -2.1884e-04],
          [-5.4437e-03, -6.4340e-03],
          [-1.4182e-02, -7.6442e-03],
          [ 2.5098e-02,  1.0793e-02],
          [ 9.8194e-03, -5.4158e-03],
          [-2.5233e-04,  1.5520e-02],
          [-3.7443e-03,  1.7248e-03]],

         [[ 2.1381e-03, -6.1619e-03],
          [-3.8698e-03, -5.7639e-03],
          [ 2.4331e-03,  7.5532e-03],
          [-6.8991e-03, -1.5349e-02],
          [-1.5763e-03, -7.5634e-03],
          [-2.0055e-02,  9.8211e-03],
          [-2.5412e

In [131]:
def optimize(X, Y, parameters, bond, learning_rate=0.01):
    
    n, l, Dmax = parameters['n'], parameters['l'], parameters['Dmax']
    parameters = tensor_initialize(parameters)
    for i in range(l, n-1):
        loss, cache = tnn_cell_forward_right(X, Y, parameters)
        gradients = tnn_cell_backward_right(X, Y, cache, parameters)
        parameters = update_parameters_right(X, parameters, gradients, bond, learning_rate)
    for i in range(n, 1, -1):
        loss, cache = tnn_cell_forward_left(X, Y, parameters)
        gradients = tnn_cell_backward_left(X, Y, cache, parameters)
        parameters = update_parameters_left(X, parameters, gradients, bond, learning_rate)
        
    return parameters, loss, gradients

In [147]:
#test for optimize()
torch.manual_seed(1)
m, n, l, index, Dmax, bond = 15, 25, 2, 10, 10, 8
X = torch.randint(1, (m, n))
Y = np.random.binomial(index, 0.5, m)
Y_onehot = one_hot(Y, index)
parameters = {'m': m, 'n': n, 'Dmax': Dmax, 'index': index, 'l': l}
parameters, loss, gradients = optimize(X, Y_onehot, parameters, bond, learning_rate)
n, tensors = parameters['n'], parameters['tensors']
dAl = gradients['dBl']
for i in range(n):
    print(parameters['tensors'][i].shape)
print(dAl.shape)
print(dAl)

torch.Size([1, 2, 10, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 8])
torch.Size([8, 2, 4])
torch.Size([4, 2, 2])
torch.Size([2, 2, 1])
torch.Size([10, 2, 2, 1, 8])
tensor([[[[[0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003]],

          [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],


         [[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

          [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]],



        [[[[0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005]],

          [[0.00