In [None]:
import os

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt # visualization
import seaborn as sns # visualization
# machine learning
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, datasets, models

import progressbar

In [None]:
from BasicClassifier import BasicClassifier
from DataAugment import DataAug
from Metrics import Metrics

In [None]:
# DataAug instance; its methods (rand_row, rand_column, ...) are looked up
# further down to build the tables of functions applied to input batches.
daug = DataAug()

In [None]:
# Use the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Progress-bar layout handed to progressbar.ProgressBar inside train():
# "[<timer>] <percentage> <bar> (<ETA>)".
widgets = [
    ' [', progressbar.Timer(), '] ',
    progressbar.Percentage(), ' ',
    progressbar.Bar(),
    ' (', progressbar.ETA(), ') ',
]

In [None]:
# MNIST train/test splits, downloaded into ./data when missing and served as
# shuffled mini-batches of tensors.
dataset = "MNIST"
BATCH_SIZE = 500
num_classes = 10

# Training split.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Held-out split (shuffled as well, exactly like the training loader).
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data',
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Model Trainer

In [None]:
def train(model,train_loader,test_loader,proportion,funcs,func_proportions,NUM_CLASSSES=10,NUM_EPOCHS=25):
    """Train ``model`` on corrupted batches and evaluate it after every epoch.

    Every batch (training and test) is passed through ``func_divider`` before
    it reaches the model, so both phases see inputs altered by the same mix of
    functions.

    Args:
        model: Callable network exposing ``parameters()``, ``train()``,
            ``eval()`` and a ``metric`` attribute with
            ``reset_confusion_matrix``, ``update_confusion_matrix`` and
            ``classification_metrics``.
        train_loader: Iterable of ``(inputs, labels)`` training batches; must
            support ``len()``.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        proportion: Value forwarded unchanged to every function in ``funcs``.
        funcs: Mapping of name -> callable, as expected by ``func_divider``.
        func_proportions: Per-function share of each batch, as expected by
            ``func_divider``.
        NUM_CLASSSES: Number of classes used to reset the confusion matrix
            (spelling kept so existing keyword callers keep working).
        NUM_EPOCHS: Number of passes over ``train_loader``.

    Returns:
        Tuple ``(epoch_metrics, epoch_test_metrics)``: two dicts keyed by the
        epoch index, holding whatever ``model.metric.classification_metrics()``
        returned after the training pass and after the test pass.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),lr=3e-5)

    epoch_metrics = {}
    epoch_test_metrics = {}

    batches_per_epoch = len(train_loader)
    bar = progressbar.ProgressBar(NUM_EPOCHS*batches_per_epoch,widgets=widgets).start()
    for epoch in range(NUM_EPOCHS):
        # Training pass: make sure layers such as dropout/batch-norm are in
        # training mode (the evaluation pass below switches them off).
        model.train()
        model.metric.reset_confusion_matrix(NUM_CLASSSES)

        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = model(func_divider(inputs,proportion,funcs,func_proportions).to(device))
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            # statistics (detached: the bookkeeping does not need the graph)
            model.metric.update_confusion_matrix(outputs.detach().to('cpu'),labels)
            # progressbar
            bar.update(epoch*batches_per_epoch+i)
        epoch_metrics[epoch] = model.metric.classification_metrics()

        # Evaluation pass on the test split, without gradient tracking.
        model.eval()
        with torch.no_grad():
            model.metric.reset_confusion_matrix(NUM_CLASSSES)
            for (inputs, labels) in test_loader:
                # The original called an undefined name (`datafunc`) here,
                # which raised NameError at the end of the first epoch; the
                # test batches now get the same treatment as training batches.
                outputs = model(func_divider(inputs,proportion,funcs,func_proportions).to(device))
                # statistics
                model.metric.update_confusion_matrix(outputs.to('cpu'),labels)
            epoch_test_metrics[epoch] = model.metric.classification_metrics()

    # Close the bar so later cell output starts on a fresh line.
    bar.finish()
    return (epoch_metrics,epoch_test_metrics)

def func_divider(inputs,proportion,funcs,func_proportions):
    """Apply each function in ``funcs`` to its own contiguous slice of a batch.

    The batch is split along dimension 0. Function ``i`` (in the iteration
    order of ``funcs``) receives ``ceil(func_proportions[i] * batch_size)``
    consecutive samples, starting where the previous function stopped.
    Samples not covered by any slice are returned unchanged.

    Args:
        inputs: Batch tensor; dimension 0 indexes the samples. It is cloned,
            so the result is a separate tensor.
        proportion: Value forwarded as second argument to every function.
        funcs: Mapping of name -> callable ``f(batch_slice, proportion)``
            returning a tensor assignable to that slice.
        func_proportions: Sequence/array with one share per entry of
            ``funcs``, in the same order.

    Returns:
        A new tensor shaped like ``inputs`` with the processed slices.

    Raises:
        IndexError: If ``funcs`` has more entries than ``func_proportions``.
    """
    batch_size = inputs.shape[0]
    func_num = np.ceil(np.asarray(func_proportions,dtype=float)*batch_size).astype(int)
    lossyinputs = torch.clone(inputs)
    h = 0  # start (inclusive) of the current slice
    for i, func in enumerate(funcs):
        # ceil() can make the counts add up to more than the batch holds,
        # so the end index is clamped to the batch size.
        t = min(h + int(func_num[i]), batch_size)
        if t > h:  # skip empty slices instead of calling func on nothing
            lossyinputs[h:t] = funcs[func](lossyinputs[h:t],proportion)
        # The next slice starts where this one ended. The original did
        # `h += t` with a cumulative `t`, which skipped samples as soon as
        # more than two functions were supplied.
        h = t
    return lossyinputs
    
def display_training_metrics(name,epoch_metrics):
    """Plot precision, recall and accuracy per epoch on the current axes.

    Args:
        name: Title of the plot.
        epoch_metrics: Dict keyed by epoch; every value is a sequence whose
            first three entries are read as accuracy (index 0), precision
            (index 1) and recall (index 2).
    """
    epochs = list(epoch_metrics.keys())
    # Convert once instead of rebuilding the same array for every curve.
    values = np.array(list(epoch_metrics.values()),dtype=float)
    # Labels make seaborn draw a legend; without them the three curves
    # cannot be told apart.
    sns.lineplot(x=epochs,y=values[:,1],label='precision')
    sns.lineplot(x=epochs,y=values[:,2],label='recall')
    sns.lineplot(x=epochs,y=values[:,0],label='accuracy')
    plt.xlabel('epoch')
    plt.title(name)
    
def display_testing_metrics(name,epoch_metrics):
    """Plot precision, recall and accuracy against the dict keys on [0, 1] axes.

    Args:
        name: Title of the plot.
        epoch_metrics: Dict whose keys are used as x values (the x axis is
            labelled '% loss' and limited to [0, 1]); every value is a
            sequence whose first three entries are read as accuracy
            (index 0), precision (index 1) and recall (index 2).
    """
    x_values = list(epoch_metrics.keys())
    # Convert once instead of rebuilding the same array for every curve.
    values = np.array(list(epoch_metrics.values()),dtype=float)
    # Labels make seaborn draw a legend; without them the three curves
    # cannot be told apart.
    sns.lineplot(x=x_values,y=values[:,1],label='precision')
    sns.lineplot(x=x_values,y=values[:,2],label='recall')
    sns.lineplot(x=x_values,y=values[:,0],label='accuracy')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel('% loss')
    plt.title(name)
    
def display_testing_metrics_hist(name,metrics):
    """Draw the first three entries of ``metrics`` as bars on a [0, 1] y axis.

    The bars are labelled 'accuracy', 'precision' and 'recall', in that
    order, and the plot is titled ``name``.
    """
    bar_labels = ['accuracy','precision','recall']
    bar_heights = metrics[:3]
    sns.barplot(x=bar_labels,y=bar_heights)
    plt.ylim([0, 1])
    plt.title(name)

In [None]:
# Lookup table: name -> bound DataAug method.
func_dict = {'rand_pixel':daug.rand_pixel,
             'rand_row':daug.rand_row,
             'rand_column':daug.rand_column,
             'rand_rowcol':daug.rand_rowcol,
             # NOTE(review): the original mapped 'rand_block' to rand_rowcol,
             # which looks like a copy-paste slip. Prefer DataAug.rand_block
             # when it exists and keep the previous mapping otherwise.
             # TODO confirm that DataAug defines rand_block.
             'rand_block':getattr(daug, 'rand_block', daug.rand_rowcol),
             'pattern_checkerboard':daug.pattern_checkerboard,
             'pattern_column':daug.pattern_column,
             'pattern_row':daug.pattern_row}

# Value forwarded as second argument to every selected function
# (presumably the fraction of the image that is dropped - TODO confirm).
proportion = 0.5

# Functions used for this run, taken from the table above so each name is
# spelled in one place only.
funcs = {name: func_dict[name] for name in ('rand_row', 'rand_column')}

# Share of each batch given to each entry of `funcs`, in the same order.
func_proportions = np.array([0.5,0.5])

In [None]:
# Build a fresh classifier and move it to the selected device.
model = BasicClassifier(num_classes)
model.to(device)

# Train for 30 epochs; `train` returns the per-epoch training metrics and the
# per-epoch test metrics.
train_metrics, test_metrics = train(
    model,
    train_loader,
    test_loader,
    proportion,
    funcs,
    func_proportions,
    NUM_CLASSSES=num_classes,
    NUM_EPOCHS=30,
)