# Code for training with validation

# In [4]:
import torch
import copy
from sklearn.metrics import confusion_matrix
import time


def train_model(model, train_data, train_labels, val_data, val_labels, epochs, batch_size, loss_func, optimizer):
    """Train ``model`` with mini-batches and restore the best-validating weights.

    Each epoch runs one pass over ``train_data`` in consecutive slices of
    ``batch_size`` rows, then records the loss on the full training set, the
    loss on the validation set, and the validation accuracy. The state dict
    of the epoch with the highest validation accuracy is remembered (ties go
    to the later epoch). After the last epoch the three curves are plotted,
    the best epoch is printed, and the remembered weights are loaded back
    into ``model``.

    NOTE(review): ``plt`` is not imported in the visible part of this file;
    presumably ``matplotlib.pyplot`` is imported elsewhere -- confirm.

    Args:
        model: Network to train; modified in place.
        train_data: Training inputs, sliced along the first axis.
        train_labels: Targets aligned with ``train_data``.
        val_data: Validation inputs, evaluated in a single forward pass.
        val_labels: Targets aligned with ``val_data``.
        epochs: Number of passes over the training data.
        batch_size: Number of rows per optimisation step; must be >= 1.
        loss_func: Callable ``loss_func(outputs, targets)`` returning a
            scalar tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.

    Returns:
        None. The result is the state left in ``model``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    best_model = copy.deepcopy(model.state_dict())
    best_model_epoch = 0
    best_val_accuracy = 0
    train_losses = []
    validation_losses = []
    validation_accuracies = []
    num_samples = len(train_data)

    for epoch in range(epochs):
        model.train()  # Set model to train mode

        # Stepping by batch_size covers a trailing partial batch and never
        # produces an empty slice, so batch_size need not divide the data
        # length evenly.
        for start in range(0, num_samples, batch_size):
            end = start + batch_size
            train_outputs = model(train_data[start:end])
            loss = loss_func(train_outputs, train_labels[start:end])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()  # Set model to evaluation mode

        # Metrics only: no gradients are needed, so skip building the graph.
        with torch.no_grad():
            validation_losses.append(loss_func(model(val_data), val_labels).item())
            train_losses.append(loss_func(model(train_data), train_labels).item())

        val_accuracy = accuracy_check(model, val_data, val_labels)
        validation_accuracies.append(val_accuracy)

        # Check for new best model; ">=" lets a later epoch win a tie.
        if val_accuracy >= best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model_epoch = epoch
            best_model = copy.deepcopy(model.state_dict())

    plt.plot(train_losses, label="train loss")
    plt.plot(validation_losses, label="val loss")
    plt.plot(validation_accuracies, label="val acc")
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
    print("Best model found on epoch: ", best_model_epoch)
    model.load_state_dict(best_model)  # Set model to best performing one.
    


def accuracy_check(network, data, labels):
    """Return the fraction of rows in ``data`` that ``network`` classifies correctly.

    The network is switched to evaluation mode and left in that mode. The
    predicted class of a row is the index of the largest value along
    dimension 1 of the network output.

    Args:
        network: Model called as ``network(data)``.
        data: Inputs passed to the network in a single forward pass.
        labels: Tensor of true class indices, compared element-wise with the
            predicted indices.

    Returns:
        Number of matching predictions divided by the number of predictions,
        as a Python float.

    Raises:
        ZeroDivisionError: If the network produces no predictions (empty
            ``data``).
    """
    network.eval()

    # Pure evaluation: disable autograd so no graph is built for the pass.
    with torch.no_grad():
        scores = network(data)

    predicted = torch.max(scores, 1)[1]
    num_correct = torch.sum(predicted == labels).item()
    return num_correct / len(predicted)

def get_conf_matrix(model, data, labels, num_classes=4):
    """Count (true class, predicted class) pairs for ``model`` on ``data``.

    The model is switched to evaluation mode and left in that mode. The
    predicted class of a sample is the argmax over the last dimension of the
    model output.

    Args:
        model: Model called as ``model(data)``.
        data: Inputs passed to the model in a single forward pass.
        labels: Tensor of true class indices; ``labels[i]`` belongs to
            sample ``i``.
        num_classes: Side length of the square result. Defaults to 4, the
            size that was previously hard-coded.

    Returns:
        A ``(num_classes, num_classes)`` float NumPy array whose entry
        ``[t, p]`` is the number of samples with true class ``t`` that were
        predicted as class ``p``.
    """
    model.eval()

    # Pure evaluation: disable autograd and take every argmax in one call.
    with torch.no_grad():
        predicted = torch.argmax(model(data), dim=-1)

    # Named "matrix" so the local does not shadow sklearn's confusion_matrix
    # imported at the top of the file.
    matrix = np.zeros((num_classes, num_classes))
    for i, guess in enumerate(predicted.tolist()):
        matrix[labels[i].item(), guess] += 1

    return matrix

def save_model(model, name):
    """Serialise ``model`` with ``torch.save`` to ``models/<name>``.

    The path is relative to the current working directory. The target
    directory is created when it does not exist yet.

    Args:
        model: Object handed to ``torch.save``.
        name: File name (string) appended to ``'models/'``.

    Returns:
        None.
    """
    import os

    filename = 'models/' + name
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # torch.save writes binary data. Passing the path lets it open the file
    # in binary mode and close it again; the earlier text-mode handle could
    # not accept bytes and was never closed.
    torch.save(model, filename)
    