In [1]:
import torch as tn
import torch.nn as nn
import torchtt as tntt
from torch import optim
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# Run on the GPU when CUDA reports itself available, otherwise on the CPU.
device = tn.device('cuda') if tn.cuda.is_available() else tn.device('cpu')

In [2]:
# MNIST train and test splits, both rooted at ./data and converted with ToTensor().
# Only the training split asks torchvision to download the files.
train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=ToTensor(),
    download=True,
)
test_data = datasets.MNIST(
    root='data',
    train=False,
    transform=ToTensor(),
)

In [3]:
# Mini-batch loaders: 1000 samples per training batch, 100 per evaluation batch.
# Only the training loader shuffles; the accuracy computed over the whole test
# set does not depend on batch order, so the test loader iterates in dataset order.
dataloader_train = DataLoader(train_data, batch_size=1000, shuffle=True, num_workers=1)
dataloader_test = DataLoader(test_data, batch_size=100, shuffle=False, num_workers=1)

In [4]:

class BasicTT(nn.Module):
    """Classifier built from two torchtt ``LinearLayerTT`` layers and a dense head.

    ``forward`` applies ``ttl1`` -> ReLU -> ``ttl2`` -> ReLU, flattens the
    result to ``8 * 81`` features per sample, applies a dense layer with 10
    outputs and returns the log-softmax over dimension 1.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the three list arguments look like (input mode sizes,
        # output mode sizes, TT ranks) -- confirm against the torchtt docs.
        self.ttl1 = tntt.nn.LinearLayerTT([1, 7, 4, 7, 4], [8, 10, 10, 10, 10], [1, 4, 2, 2, 2, 1])
        self.ttl2 = tntt.nn.LinearLayerTT([8, 10, 10, 10, 10], [8, 3, 3, 3, 3], [1, 2, 2, 2, 2, 1])
        # 8 * 81 == 8 * 3 * 3 * 3 * 3, the product of the second list given to ttl2.
        self.linear = nn.Linear(8 * 81, 10, dtype=tn.float32)
        self.logsoftmax = nn.LogSoftmax(1)

    def forward(self, x):
        """Return log-softmax scores for ``x`` (one row of 10 values per sample)."""
        hidden = tn.relu(self.ttl1(x))
        hidden = tn.relu(self.ttl2(hidden))
        # Collapse everything except the leading dimension before the dense head.
        flat = hidden.view(-1, 8 * 81)
        return self.logsoftmax(self.linear(flat))



In [5]:
model = BasicTT().to(device)
# BasicTT.forward already ends in LogSoftmax, so the matching criterion is
# NLLLoss. CrossEntropyLoss would apply log-softmax a second time, which is
# numerically a no-op on log-probabilities but redundant and misleading.
loss_function = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [6]:
n_epochs = 30
n_batches = len(dataloader_train)

# Make sure the model is in training mode even if this cell is re-run after
# an evaluation cell.
model.train()

for epoch in range(n_epochs):
    # `images` / `labels` rather than `input` / `label`: in a notebook these loop
    # variables are globals, and `input` would shadow the builtin for later cells.
    for i, (images, labels) in enumerate(dataloader_train):
        # Reshape each batch to [-1, 1, 7, 4, 7, 4], matching the first list
        # passed to the model's first LinearLayerTT.
        images = tn.reshape(images.to(device), [-1, 1, 7, 4, 7, 4])
        labels = labels.to(device)

        optimizer.zero_grad()
        output = model(images)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()

        # .item() hands a plain Python float to the format string.
        print('Epoch %d/%d iteration %d/%d loss %e' % (epoch + 1, n_epochs, i + 1, n_batches, loss.item()))
        
        

Epoch 1/30 iteration 1/60 loss 2.296816e+00
Epoch 1/30 iteration 2/60 loss 2.264281e+00
Epoch 1/30 iteration 3/60 loss 2.238789e+00
Epoch 1/30 iteration 4/60 loss 2.214024e+00
Epoch 1/30 iteration 5/60 loss 2.176440e+00
Epoch 1/30 iteration 6/60 loss 2.150805e+00
Epoch 1/30 iteration 7/60 loss 2.117272e+00
Epoch 1/30 iteration 8/60 loss 2.078816e+00
Epoch 1/30 iteration 9/60 loss 2.054431e+00
Epoch 1/30 iteration 10/60 loss 2.013903e+00
Epoch 1/30 iteration 11/60 loss 1.967319e+00
Epoch 1/30 iteration 12/60 loss 1.929525e+00
Epoch 1/30 iteration 13/60 loss 1.872452e+00
Epoch 1/30 iteration 14/60 loss 1.838095e+00
Epoch 1/30 iteration 15/60 loss 1.798223e+00
Epoch 1/30 iteration 16/60 loss 1.765143e+00
Epoch 1/30 iteration 17/60 loss 1.700099e+00
Epoch 1/30 iteration 18/60 loss 1.637113e+00
Epoch 1/30 iteration 19/60 loss 1.578486e+00
Epoch 1/30 iteration 20/60 loss 1.514828e+00
Epoch 1/30 iteration 21/60 loss 1.470776e+00
Epoch 1/30 iteration 22/60 loss 1.418589e+00
Epoch 1/30 iteratio

In [7]:
n_correct = 0
n_total = 0

# Evaluation needs no gradients: switch to eval mode and disable autograd so
# no graph is built (and kept alive) for every test batch.
model.eval()
with tn.no_grad():
    # `images` / `labels` rather than `input` / `label` to avoid shadowing the
    # builtin `input` in the notebook's global namespace.
    for images, labels in dataloader_test:
        images = tn.reshape(images.to(device), [-1, 1, 7, 4, 7, 4])

        # Index of the largest score per sample is the predicted class; move the
        # predictions to the CPU, where the labels from the loader live.
        predictions = model(images).argmax(dim=1).cpu()

        n_correct += tn.sum(predictions == labels)
        n_total += labels.shape[0]

print('Test accuracy ', n_correct/n_total)


Test accuracy  tensor(0.9759)
