In [None]:
import torch as tn
from torchvision import datasets, transforms
import torchtt as tntt
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import datetime

In [None]:
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device_name = 'cuda:0' if tn.cuda.is_available() else 'cpu'
# Image folders in torchvision ImageFolder layout (one sub-directory per class).
data_dir_test = 'seg_test/'
data_dir_train = 'seg_train/'
# Factorisation of the (square) image side: 15 * 10 = 150 pixels. Each spatial
# axis of an input image is split into these two modes for the TT layers.
N_shape = [15,10]

In [None]:
# Resize the shorter side to prod(N_shape) pixels and centre-crop to a square of
# that size, so every image can later be reshaped to [3] + N_shape + N_shape.
image_side = N_shape[0]*N_shape[1]
# Train and test use identical preprocessing, so the pipeline is built once.
# Optional normalisation (per-channel values kept from an earlier experiment,
# presumably dataset mean/std -- verify before enabling):
#   transforms.Normalize(tn.tensor([0.4885, 0.4525, 0.4163]), tn.tensor([0.2549, 0.2476, 0.2495]))
transform_train = transforms.Compose([transforms.Resize(image_side), transforms.CenterCrop(image_side), transforms.ToTensor()])
transform_test = transform_train

dataset_train = datasets.ImageFolder(data_dir_train, transform=transform_train)
dataloader_train = tn.utils.data.DataLoader(dataset_train, batch_size=32, shuffle=True, pin_memory=True, num_workers=16)

dataset_test = datasets.ImageFolder(data_dir_test, transform=transform_test)
# Evaluation only aggregates over the whole test set, so shuffling it is unnecessary.
dataloader_test = tn.utils.data.DataLoader(dataset_test, batch_size=32, shuffle=False, pin_memory=True, num_workers=16)

# ImageFolder derives label indices from the sub-directory names; both splits must
# agree, otherwise test labels would not mean the same thing as training labels.
assert dataset_train.classes == dataset_test.classes, 'train/test class folders differ'

In [None]:
class BasicTT(nn.Module):
    """Image classifier built from a stack of tensor-train (TT) linear layers.

    Six ``tntt.nn.LinearLayerTT`` layers (``ttl1`` ... ``ttl6``), each followed by
    dropout (``dropout1`` ... ``dropout6``) and a ReLU, map a tensorised image with
    modes ``[3] + N_shape + N_shape`` to modes ``[3, 3, 3, 3, 3]``. That result is
    flattened to ``3**5`` features, and a dense layer produces 6 class scores.
    The scores are returned as log-probabilities (``LogSoftmax`` over dim 1).
    """

    def __init__(self):
        super().__init__()
        drop_prob = 0.5
        # (input modes, output modes, TT ranks) for ttl1 ... ttl6, in order.
        layer_specs = [
            ([3]+N_shape+N_shape, [16]+N_shape+N_shape, [1,9,3,3,2,1]),
            ([16]+N_shape+N_shape, [32,8,8,8,8], [1,8,4,3,2,1]),
            ([32,8,8,8,8], [8,4,4,4,4], [1,4,4,4,4,1]),
            ([8,4,4,4,4], [4,4,4,4,4], [1,2,2,2,2,1]),
            ([4,4,4,4,4], [4,4,4,4,4], [1,2,2,2,2,1]),
            ([4,4,4,4,4], [3,3,3,3,3], [1,3,3,3,3,1]),
        ]
        # Register each TT layer followed by its dropout, so the submodule order
        # (and hence parameter order / state_dict keys) is ttl1, dropout1, ttl2, ...
        for idx, (in_modes, out_modes, ranks) in enumerate(layer_specs, start=1):
            setattr(self, 'ttl%d' % idx, tntt.nn.LinearLayerTT(in_modes, out_modes, ranks, initializer='He'))
            setattr(self, 'dropout%d' % idx, nn.Dropout(drop_prob))
        self.linear = nn.Linear(3**5, 6, dtype=tn.float32)
        self.logsoftmax = nn.LogSoftmax(1)

    def forward(self, x):
        """Return per-class log-probabilities for the tensorised batch ``x``."""
        # Six identical stages: TT layer -> dropout -> ReLU.
        for idx in range(1, 7):
            tt_layer = getattr(self, 'ttl%d' % idx)
            drop = getattr(self, 'dropout%d' % idx)
            x = tn.relu(drop(tt_layer(x)))
        logits = self.linear(x.view(-1, 3**5))
        return self.logsoftmax(logits)



In [None]:
model = BasicTT()
model.to(device_name)

print('Number of parameters', sum(tn.numel(p) for p in model.parameters()))

optimizer = tn.optim.SGD(model.parameters(), lr=0.001, momentum=0.1)
# Alternative optimizer: tn.optim.Adam(model.parameters(), lr=0.005)
# Halve the learning rate every 30 scheduler steps (the training loop steps it once per epoch).
scheduler = tn.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

# BasicTT ends in LogSoftmax, so its outputs are already log-probabilities and the
# matching criterion is NLLLoss. CrossEntropyLoss would apply log_softmax a second
# time: the loss value is the same (log_softmax is idempotent), but the work is redundant.
loss_function = tn.nn.NLLLoss()


In [None]:
def do_epoch(i):
    """Run one training epoch over ``dataloader_train`` and update ``model`` in place.

    Args:
        i: Epoch index. Currently unused; kept so existing calls keep working.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats. ``mean_loss`` is the mean of
        the per-batch losses. ``accuracy`` is the fraction of training samples
        classified correctly, measured on the fly while the weights are updated
        (and with dropout active when the model is in training mode). It is
        therefore not identical to a separate evaluation pass over the training set.
    """
    loss_total = 0.0
    n_total = 0
    n_correct = 0
    for data in dataloader_train:
        inputs, labels = data[0].to(device_name), data[1].to(device_name)
        # Split each spatial axis into the N_shape modes expected by the first TT
        # layer: per sample [3, H, W] -> [3] + N_shape + N_shape.
        inputs = tn.reshape(inputs, [-1, 3] + 2*N_shape)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        # Optional L2 regularisation (disabled):
        #   loss = loss + 0.005 * sum(p.pow(2.0).sum() for p in model.parameters())
        loss.backward()
        optimizer.step()

        # .item() keeps the counters as plain Python numbers, so the returned
        # accuracy is a float rather than a tensor.
        n_correct += (outputs.argmax(1) == labels).sum().item()
        n_total += inputs.shape[0]
        loss_total += loss.item()

    return loss_total/len(dataloader_train), n_correct/n_total

def test_loss():
    """Return the mean per-batch loss of ``model`` on ``dataloader_test`` (float).

    Runs under ``tn.no_grad()``, so no autograd graph is built. The model is
    temporarily switched to eval mode (dropout off), and its previous train/eval
    mode is restored afterwards. Do not rebind the name ``test_loss`` to a value
    elsewhere in the notebook, or this function becomes unreachable.
    """
    was_training = model.training
    model.eval()
    loss_total = 0.0
    try:
        with tn.no_grad():
            for data in dataloader_test:
                inputs, labels = data[0].to(device_name), data[1].to(device_name)
                inputs = tn.reshape(inputs, [-1, 3] + 2*N_shape)
                outputs = model(inputs)
                loss_total += loss_function(outputs, labels).item()
    finally:
        model.train(was_training)

    return loss_total/len(dataloader_test)

def test_data():
    """Evaluate ``model`` on ``dataloader_test``.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats. ``mean_loss`` is the mean of
        the per-batch losses. ``accuracy`` is the fraction of test samples whose
        arg-max prediction equals the label.

    Runs under ``tn.no_grad()`` with the model temporarily in eval mode (dropout
    off); the previous train/eval mode is restored afterwards. Counters are kept
    as Python numbers (via ``.item()``) so the result is not a GPU tensor. A GPU
    tensor could not be turned into a NumPy array for plotting.
    """
    was_training = model.training
    model.eval()
    n_total = 0
    n_correct = 0
    loss_total = 0.0
    try:
        with tn.no_grad():
            for data in dataloader_test:
                inputs, labels = data[0].to(device_name), data[1].to(device_name)
                inputs = tn.reshape(inputs, [-1, 3] + 2*N_shape)
                outputs = model(inputs)
                loss_total += loss_function(outputs, labels).item()
                n_correct += (outputs.argmax(1) == labels).sum().item()
                n_total += inputs.shape[0]
    finally:
        model.train(was_training)

    return loss_total/len(dataloader_test), n_correct/n_total

def train_accuracy():
    """Return the fraction of ``dataloader_train`` samples classified correctly (float).

    Unlike the running accuracy reported by ``do_epoch``, this is a separate
    evaluation pass with fixed weights. It runs under ``tn.no_grad()`` with the
    model temporarily in eval mode (dropout off); the previous train/eval mode is
    restored afterwards.
    """
    was_training = model.training
    model.eval()
    n_total = 0
    n_correct = 0
    try:
        with tn.no_grad():
            for data in dataloader_train:
                inputs, labels = data[0].to(device_name), data[1].to(device_name)
                inputs = tn.reshape(inputs, [-1, 3] + 2*N_shape)
                outputs = model(inputs)
                n_correct += (outputs.argmax(1) == labels).sum().item()
                n_total += inputs.shape[0]
    finally:
        model.train(was_training)

    return n_correct/n_total

In [None]:
n_epochs = 150

history_test_accuracy = []
history_test_loss = []
history_train_accuracy = []
history_train_loss = []

for epoch in range(n_epochs):
    print('Epoch %d/%d'%(epoch+1,n_epochs))

    time_epoch = datetime.datetime.now()

    model.train(True)
    epoch_train_loss, epoch_train_acc = do_epoch(epoch)
    model.train(False)

    # Use distinct names: assigning to `test_loss` here would shadow the
    # `test_loss()` helper function defined earlier in the notebook.
    epoch_test_loss, epoch_test_acc = test_data()
    # StepLR counts calls to step(): advance it once per epoch, after the optimizer steps.
    scheduler.step()

    time_epoch = datetime.datetime.now() - time_epoch

    # float() converts any 0-d tensor (CPU or GPU) returned by the helpers into a
    # plain number, so the histories can be passed to NumPy/matplotlib later.
    epoch_train_loss, epoch_train_acc = float(epoch_train_loss), float(epoch_train_acc)
    epoch_test_loss, epoch_test_acc = float(epoch_test_loss), float(epoch_test_acc)

    print('\tTraining loss %e training accuracy %5.4f test loss %e test accuracy %5.4f'%(epoch_train_loss,epoch_train_acc,epoch_test_loss,epoch_test_acc))
    print('\tTime for the epoch',time_epoch)
    history_test_accuracy.append(epoch_test_acc)
    history_test_loss.append(epoch_test_loss)
    history_train_accuracy.append(epoch_train_acc)
    history_train_loss.append(epoch_train_loss)
    


In [None]:
# float() accepts plain numbers as well as 0-d tensors; np.array on a list of
# CUDA tensors would raise, so convert explicitly before plotting.
train_acc_curve = np.array([float(v) for v in history_train_accuracy])
test_acc_curve = np.array([float(v) for v in history_test_accuracy])
train_loss_curve = np.array([float(v) for v in history_train_loss])
test_loss_curve = np.array([float(v) for v in history_test_loss])

fig_acc, ax_acc = plt.subplots()
ax_acc.plot(np.arange(len(train_acc_curve))+1, train_acc_curve)
ax_acc.plot(np.arange(len(test_acc_curve))+1, test_acc_curve)
ax_acc.set(title='Classification accuracy per epoch', xlabel='Epoch', ylabel='Accuracy')
ax_acc.legend(['training','test'])

fig_loss, ax_loss = plt.subplots()
ax_loss.plot(np.arange(len(train_loss_curve))+1, train_loss_curve)
ax_loss.plot(np.arange(len(test_loss_curve))+1, test_loss_curve)
ax_loss.set(title='Mean batch loss per epoch', xlabel='Epoch', ylabel='Loss')
ax_loss.legend(['training','test'])
plt.show()
    

In [None]:
# Best test accuracy over all epochs, shown as a plain number even if the history holds 0-d tensors.
max(float(acc) for acc in history_test_accuracy)