# K-MNIST experiment for FF models

## Imports

In [None]:
from models import ff_eucl, ff_hyp
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import geoopt
from time import time
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
torch.cuda.is_available()

#Disable Debugging APIs
# Switch autograd anomaly detection off.
torch.autograd.set_detect_anomaly(False)
# NOTE(review): the two calls below only construct (disabled) context-manager
# objects and never enter them, so they presumably have no lasting effect --
# confirm against the PyTorch profiler docs before relying on them.
torch.autograd.profiler.profile(False)
torch.autograd.profiler.emit_nvtx(False)

#cuDNN Autotuner
# Allow cuDNN to benchmark candidate algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True

## CUDA check

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Bare expression so the notebook cell displays the selected device.
device

## Data Transformation

In [None]:
# Pre-processing pipeline applied to every image; currently a single step
# that converts the loaded image into a tensor.
transform = transforms.Compose([transforms.ToTensor()])

## Training, validation and test data

In [None]:
# Download (when missing) the KMNIST train and test splits, applying `transform`
# to every image.
train_set = datasets.KMNIST('PATH_TO_STORE_TRAINSET', download=True, train=True, transform=transform)
test_set = datasets.KMNIST('PATH_TO_STORE_TESTSET', download=True, train=False, transform=transform)

size = len(train_set)
print(size)

# Hold out 20% of the training split for validation. Integer arithmetic
# guarantees the two lengths always add up to `size`; the float-based
# `int(size * 0.2)` form can be off by one for some sizes, which makes
# `random_split` raise.
val_size = size // 5
train_size = size - val_size
train_data, val_data = torch.utils.data.random_split(train_set, [train_size, val_size])

# Only the training loader needs shuffling; evaluation metrics do not depend
# on sample order, so validation/test batches are served deterministically.
trainloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True, num_workers=6, pin_memory=True)
valloader = torch.utils.data.DataLoader(val_data, batch_size=64, shuffle=False, num_workers=6, pin_memory=True)
testloader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False, num_workers=6, pin_memory=True)

In [None]:
def get_num_correct(preds, labels):
    """
    Count the correct predictions in a batch.

    The predicted class of each row of ``preds`` is the index of its largest
    value along dimension 1; the result is the number of rows whose predicted
    class equals the matching entry of ``labels``, as a Python number.
    """
    predicted_classes = preds.argmax(dim=1)
    hits = predicted_classes == labels
    return hits.sum().item()

def _summed_batch_loss(loss, criterion, batch_size):
    """
    Return the loss of one batch summed over its samples.

    Criteria with ``reduction='sum'`` already return a summed loss; otherwise
    the loss is treated as a per-sample mean and scaled by ``batch_size`` so
    that dividing the epoch total by the number of samples gives a true
    per-sample average.
    """
    if getattr(criterion, "reduction", "mean") == "sum":
        return loss.item()
    return loss.item() * batch_size


def train_epoch(model, dataloader, optimizer, criterion):
    """
    Train ``model`` for one pass over ``dataloader``.

    Each batch is moved to the module-level ``device``, flattened to one
    vector per image, and used for a single optimizer step.

    Returns a tuple ``(mean_loss, total_correct)``: the loss averaged over
    ``len(dataloader.dataset)`` samples and the number of correctly
    classified samples seen during the epoch.
    """
    model.train()
    train_loss = 0
    total_correct = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        # Flatten each image into a single vector (one row per sample).
        images = images.view(images.shape[0], -1)

        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        # Accumulate a per-sample sum so the final division by the dataset
        # size is a real average rather than "sum of batch means / samples".
        train_loss += _summed_batch_loss(loss, criterion, images.shape[0])
        total_correct += get_num_correct(output, labels)

        # Backpropagation followed by the weight update.
        loss.backward()
        optimizer.step()

    return train_loss/len(dataloader.dataset), total_correct


def val_epoch(model, dataloader, criterion):
    """
    Evaluate ``model`` on ``dataloader`` without updating its weights.

    Runs in eval mode with gradient tracking disabled, using the same
    batch handling as ``train_epoch``.

    Returns a tuple ``(mean_loss, val_correct)``: the loss averaged over
    ``len(dataloader.dataset)`` samples and the number of correctly
    classified samples.
    """
    model.eval()
    val_loss = 0
    val_correct = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            # Flatten each image into a single vector (one row per sample).
            images = images.view(images.shape[0], -1)
            output = model(images)
            loss = criterion(output, labels)
            # Same per-sample accumulation as in training so both curves
            # are reported on the same scale.
            val_loss += _summed_batch_loss(loss, criterion, images.shape[0])
            val_correct += get_num_correct(output, labels)

    return val_loss/len(dataloader.dataset), val_correct


def model_eval(model, epochs, trainloader, valloader, optimizer, criterion):
    """
    Train and validate ``model`` for ``epochs`` epochs, logging every epoch.

    After each epoch the training/validation loss and accuracy are written
    to TensorBoard and printed. Accuracy is the number of correct
    predictions divided by the size of the dataset behind the corresponding
    loader, so the function no longer depends on the module-level
    ``train_data`` / ``val_data`` variables.

    Returns ``(t_loss, v_loss, t_accuracy, v_accuracy, epoch_values)``, five
    lists with one entry per epoch.
    """
    n_train = len(trainloader.dataset)
    n_val = len(valloader.dataset)

    t_loss = []
    v_loss = []
    t_accuracy = []
    v_accuracy = []
    epoch_values = []

    tb = SummaryWriter()
    try:
        for epoch in range(epochs):
            train_loss, total_correct = train_epoch(model, trainloader, optimizer, criterion)
            val_loss, val_correct = val_epoch(model, valloader, criterion)
            train_acc = total_correct/n_train
            val_acc = val_correct/n_val

            t_loss.append(train_loss)
            t_accuracy.append(train_acc)
            v_loss.append(val_loss)
            v_accuracy.append(val_acc)
            epoch_values.append(epoch)

            tb.add_scalar("Training Loss", train_loss, epoch)
            tb.add_scalar("Validation Loss", val_loss, epoch)
            tb.add_scalar("Training Accuracy", train_acc, epoch)
            tb.add_scalar("Validation Accuracy", val_acc, epoch)
            print("epoch:", epoch, "training loss:", train_loss, "validation loss:", val_loss,
                  "training accuracy:", train_acc, "validation accuracy:", val_acc)
    finally:
        # Flush pending events and release the log file even when training
        # is interrupted or raises.
        tb.close()

    return t_loss, v_loss, t_accuracy, v_accuracy, epoch_values

## Initialize and train model

In [None]:
#model = ff_eucl.EuclFF(784, 512, 256, 10, nn.ReLU())
model = ff_hyp.HypFF(784, 512, 256, 10, nn.ReLU())
# Every batch is moved to `device` inside the training/validation loops, so
# the model has to live on the same device; this is a no-op when it already
# does. It must happen before the optimizer is built from the parameters.
model = model.to(device)
print(model)

epochs = 10
#Hyperparameter tuning
#hparams_tune(epochs)
#Model evaluation
lr=0.01
criterion = torch.nn.CrossEntropyLoss()
#criterion = torch.nn.NLLLoss()

#optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
optimizer = geoopt.optim.RiemannianSGD(model.parameters(), lr=lr, momentum=0.9)
t_loss ,v_loss, t_accuracy, v_accuracy, epoch_values = model_eval(model, epochs, trainloader, valloader, optimizer, criterion)

## Curve Plotting

In [None]:

# One figure with two side-by-side panels: losses (left), accuracies (right).
fig, (ax0, ax1) = plt.subplots(1, 2)

# Left panel: training vs. validation loss per epoch.
ax0.set_title('Loss Curves')
ax0.set_xlabel('Epochs')
ax0.set_ylabel('Losses')
ax0.plot(epoch_values, t_loss, 'bo-', label='train')
ax0.plot(epoch_values, v_loss, 'ro-', label='val')
ax0.legend()

# Right panel: training vs. validation accuracy per epoch.
ax1.set_title('Accuracy Curves')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Accuracy')
ax1.plot(epoch_values, t_accuracy, 'bo-', label='train')
ax1.plot(epoch_values, v_accuracy, 'ro-', label='val')
# Fix the accuracy axis to the 0.7-1.0 window with a tick every 0.02.
ax1.yaxis.set_ticks(np.arange(0.7, 1.0, 0.02))
ax1.set_ylim(0.7, 1.0)
ax1.legend()

fig.suptitle('no. of epochs = {}, lr = {}, batch size = 64'.format(epochs, lr))
fig.tight_layout()