In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
%cd / content/gdrive/MyDrive/cs394n_project/CS394N
! pip3 install -r requirements.txt

In [None]:
# Update path for custom module support in Google Colab
import sys
sys.path.append('/content/drive/MyDrive/cs394n_project/CS394N/src')

In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10, FashionMNIST

from tqdm.autonotebook import tqdm, trange

from utils.nets import *
from utils.model_tools import train, test, get_recall_per_epoch
from utils.dataset_tools import split_training_data

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


In [3]:
# FashionMNIST
# ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# CIFAR-10
# 

In [4]:
model_selection = 'linear' # linear | cnn | vgg
dataset_selection = 'fashionmnist' # cifar10 | fashionmnist
holdout_classes = [8, 9]

batch_size = 32

# Data Preparation

In [5]:
if dataset_selection == 'fashionmnist':
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5))]) # Images are grayscale -> 1 channel
else:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

## Load Dataset

In [6]:
if dataset_selection == 'cifar10':
    train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_data = CIFAR10(root='./data', train=False, download=True, transform=transform)
elif dataset_selection == 'fashionmnist':
    train_data = FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_data = FashionMNIST(root='./data', train=False, download=True, transform=transform)

In [7]:
total_classes = len(torch.unique(train_data.targets))

## Create Subsets

In [8]:
included_data, excluded_data = split_training_data(train_data, holdout_classes) 

train_inc_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)
train_exc_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)

In [9]:
included_data, excluded_data = split_training_data(test_data, holdout_classes)

test_inc_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)
test_exc_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)

# Train Model

## Load Architecture

In [10]:
num_classes = total_classes - len(holdout_classes)

if model_selection == 'linear':
    model = LinearFashionMNIST_alt(28*28, num_classes)
elif model_selection == 'cnn':
    pass

## Hyperparameters

In [11]:
weight_dir = './weights/'
model_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + str(holdout_classes) + '.pt'
recall_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + str(holdout_classes) + 'recall.npy'
train_losses_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + str(holdout_classes) + 'train_loss.txt'
test_losses_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + str(holdout_classes) + 'test_loss.txt'

num_epochs = 15

initial_learning_rate = 0.001
final_learning_rate = 0.0001

# initial_lr * decay_rate^num_epochs = final_lr
decay_rate = (final_learning_rate/initial_learning_rate)**(1/num_epochs)

loss_fn = torch.nn.CrossEntropyLoss()
#optimizer = torch.optim.SGD(model.parameters(), lr=initial_learning_rate)
optimizer = torch.optim.Adam(model.parameters(), lr=initial_learning_rate)
#optimizer = torch.optim.AdamW(model.parameters(), lr=initial_learning_rate)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decay_rate)

## Training Loop

In [12]:
train_losses = []
test_losses = []
#t = trange(num_epochs)
t = range(num_epochs)
y_preds = []
y_actuals = []

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loss = train(train_inc_loader, model, loss_fn, optimizer, device)
    test_loss, y_pred, y_actual = test(test_inc_loader, model, loss_fn, device)
    print(y_pred[:2])
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    y_preds.append(y_pred)
    y_actuals.append(y_actual)
    
    #t.set_description(f"Epoch {epoch} train loss: {epoch_loss_train[-1]:.3f}")
    lr_scheduler.step()
    
torch.save(model.state_dict(), model_file)

recalls = get_recall_per_epoch(y_actuals, y_preds, num_classes)
np.save(recall_file, recalls)

with open(train_losses_file, 'w') as fp:
    for s in train_losses:
        fp.write("%s\n" % s)
        
with open(test_losses_file, 'w') as fp:
    for x in test_losses:
        fp.write("%s\n" % x)

print("Done!")

Epoch 1
-------------------------------
loss: 2.070697  [    0/48000]
loss: 0.595658  [32000/48000]
Test Error: 
 Accuracy: 80.3%, Avg loss: 0.531100 

Epoch 2
-------------------------------
loss: 0.589653  [    0/48000]
loss: 0.282939  [32000/48000]
Test Error: 
 Accuracy: 82.2%, Avg loss: 0.488802 

Epoch 3
-------------------------------
loss: 0.448936  [    0/48000]
loss: 0.593283  [32000/48000]
Test Error: 
 Accuracy: 81.9%, Avg loss: 0.496413 

Epoch 4
-------------------------------
loss: 0.403181  [    0/48000]
loss: 0.496525  [32000/48000]
Test Error: 
 Accuracy: 81.8%, Avg loss: 0.493737 

Epoch 5
-------------------------------
loss: 0.689814  [    0/48000]
loss: 0.445743  [32000/48000]
Test Error: 
 Accuracy: 82.5%, Avg loss: 0.490080 

Epoch 6
-------------------------------
loss: 0.384188  [    0/48000]
loss: 0.367459  [32000/48000]
Test Error: 
 Accuracy: 83.1%, Avg loss: 0.465148 

Epoch 7
-------------------------------
loss: 0.290445  [    0/48000]
loss: 0.291866  [3

In [14]:
recalls_loaded = np.load(recall_file)
print(recalls == recalls_loaded)
# plots

[[ True]
 [ True]
 [ True]
 [ True]
 [ True]
 [ True]
 [ True]
 [ True]]
