In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10, FashionMNIST

from tqdm.autonotebook import tqdm, trange

from utils.nets import *
from utils.model_tools import train, test
from utils.dataset_tools import split_training_data

In [None]:
# Pick the compute device: use CUDA when PyTorch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# FashionMNIST
# ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# CIFAR-10
# ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

In [None]:
# Experiment configuration: the values the rest of the notebook branches on.
batch_size = 32  # mini-batch size handed to every DataLoader

holdout_classes = [8, 9]  # label ids passed to split_training_data as the held-out set
dataset_selection = "fashionmnist"  # one of: cifar10 | fashionmnist
model_selection = "linear"  # one of: linear | cnn | vgg

# Data Preparation

In [None]:
# Build the per-image preprocessing pipeline for the selected dataset:
# convert to a tensor, then normalize every channel with mean 0.5 and std 0.5.
if dataset_selection == 'fashionmnist':
    # Images are grayscale -> 1 channel. Note the trailing commas: `(0.5)` is just the
    # float 0.5, whereas `(0.5,)` is the one-element sequence Normalize documents.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])
elif dataset_selection == 'cifar10':
    # Three statistics per argument, one for each channel.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
else:
    # Fail here with a clear message instead of silently applying the 3-channel
    # pipeline to a misspelled or unsupported dataset name.
    raise ValueError(
        f"Unknown dataset_selection {dataset_selection!r}; expected 'cifar10' or 'fashionmnist'")

## Load Dataset

In [None]:
# Download (if needed) and load the train/test splits of the selected dataset into ./data,
# applying the preprocessing pipeline built above to every image.
if dataset_selection == 'cifar10':
    train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_data = CIFAR10(root='./data', train=False, download=True, transform=transform)
elif dataset_selection == 'fashionmnist':
    train_data = FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_data = FashionMNIST(root='./data', train=False, download=True, transform=transform)
else:
    # Without this branch an unsupported name leaves train_data/test_data undefined and
    # only surfaces later as a confusing NameError.
    raise ValueError(
        f"Unknown dataset_selection {dataset_selection!r}; expected 'cifar10' or 'fashionmnist'")

In [None]:
# Number of distinct labels present in the full training set.
# FashionMNIST exposes `targets` as a tensor, but CIFAR10 exposes a plain Python list,
# which torch.unique does not accept; torch.as_tensor handles both (and does not copy
# when the input is already a tensor).
total_classes = len(torch.unique(torch.as_tensor(train_data.targets)))

## Create Subsets

In [None]:
# Split the training set by label: split_training_data returns two datasets, presumably
# (samples outside holdout_classes, samples inside holdout_classes) -- confirm in utils.dataset_tools.
included_data, excluded_data = split_training_data(train_data, holdout_classes)

train_inc_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)
# The held-out loader must iterate excluded_data; it previously wrapped included_data,
# which made it an exact duplicate of train_inc_loader.
train_exc_loader = DataLoader(excluded_data, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:
# Same split for the test set. NOTE: this rebinds included_data / excluded_data, so after
# this cell those names refer to the test subsets, not the training ones.
included_data, excluded_data = split_training_data(test_data, holdout_classes)

test_inc_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)
# The held-out loader must iterate excluded_data; it previously wrapped included_data,
# which made it an exact duplicate of test_inc_loader.
test_exc_loader = DataLoader(excluded_data, batch_size=batch_size, shuffle=True, num_workers=2)

# Train Model

## Load Architecture

In [None]:
# Size of the classifier output: every class in the dataset minus the held-out ones.
# set() keeps a duplicated id in holdout_classes from being subtracted twice.
# NOTE(review): this assumes the remaining labels are contiguous from 0 (true for the
# default holdout of [8, 9]) or that split_training_data remaps them -- confirm.
num_classes = total_classes - len(set(holdout_classes))

if model_selection == 'linear':
    # NOTE(review): judging by its name this network targets FashionMNIST inputs;
    # verify it before pairing 'linear' with cifar10.
    model = LinearFashionMNIST(num_classes)
elif model_selection in ('cnn', 'vgg'):
    # These options are advertised in the configuration but not wired up yet. Raise
    # instead of silently leaving `model` undefined (or stale from an earlier run).
    raise NotImplementedError(f"model_selection {model_selection!r} is not implemented yet")
else:
    raise ValueError(
        f"Unknown model_selection {model_selection!r}; expected 'linear', 'cnn' or 'vgg'")

# train/test are given `device`; moving the model here is idempotent and protects against
# helpers that only move the batches -- NOTE(review): confirm against utils.model_tools.
model = model.to(device)

## Hyperparameters

In [None]:
# Destination of the trained weights; the file name encodes the model, the dataset and
# the held-out class list.
weight_dir = './weights/'
model_file = f"{weight_dir}{model_selection}_{dataset_selection}_holdout_{holdout_classes}.pt"

# Training length and the two ends of the learning-rate schedule.
num_epochs = 15
initial_learning_rate = 0.001
final_learning_rate = 0.0001

# Multiplicative decay per scheduler step, solved from
#   initial_lr * decay_rate ** num_epochs == final_lr
decay_rate = (final_learning_rate / initial_learning_rate) ** (1 / num_epochs)

loss_fn = torch.nn.CrossEntropyLoss()
# Adam is the active optimizer; swap in torch.optim.SGD or torch.optim.AdamW with the
# same arguments to compare optimizers.
optimizer = torch.optim.Adam(params=model.parameters(), lr=initial_learning_rate)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate)

## Training Loop

In [None]:
# Per-epoch history of whatever train()/test() return (presumably the mean epoch loss --
# confirm in utils.model_tools).
train_losses = []
test_losses = []

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    # One pass over the included-class training data, then an evaluation pass on the
    # included-class test data.
    train_loss = train(train_inc_loader, model, loss_fn, optimizer, device)
    test_loss = test(test_inc_loader, model, loss_fn, device)
    train_losses.append(train_loss)
    test_losses.append(test_loss)

    # Decay the learning rate once per epoch.
    lr_scheduler.step()

# torch.save does not create missing parent directories; make sure the target folder
# exists so a full training run is not lost to a FileNotFoundError at the very end.
os.makedirs(os.path.dirname(model_file) or ".", exist_ok=True)
torch.save(model.state_dict(), model_file)
print("Done!")