In [1]:
# Colab-only: mount Google Drive at /content/gdrive so the project folder under MyDrive is reachable.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
%cd /content/gdrive/MyDrive/research/SWIL-Comparisons
! pip3 install -r requirements.txt

/content/gdrive/MyDrive/research/SWIL-Comparisons


In [3]:
# Update path for custom module support in Google Colab
import sys

# Directory holding the project's own packages (imported below as `utils.*`).
src_dir = '/content/gdrive/MyDrive/research/SWIL-Comparisons/src'

# Only append once: re-running this cell must not pile up duplicate sys.path entries.
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [4]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10, FashionMNIST

#from tqdm.autonotebook import tqdm, trange

from utils.nets import *
from utils.model_tools import train, test, get_recall_per_epoch
from utils.dataset_tools import split_training_data, reorder_classes

In [5]:
# Choose the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


# Data Preparation

In [6]:
# Experiment configuration: these two strings select the network architecture and the
# dataset used by the cells below. Allowed values are listed next to each setting;
# note that 'vgg' is listed but the model-construction cell has no implementation for it.
model_selection = 'cnn' # linear | cnn | cnn-demo | vgg
dataset_selection = 'cifar10' # cifar10 | fashionmnist

In [7]:
# Per-channel normalization statistics (mean 0.5, std 0.5 for every channel).
if dataset_selection == 'fashionmnist':
    # Images are grayscale -> 1 channel. The trailing commas matter: `(0.5)` is only a
    # parenthesized float, while `(0.5,)` is the one-element sequence intended here.
    channel_mean, channel_std = (0.5,), (0.5,)
else:
    # Every other selection (i.e. CIFAR-10) uses three channels.
    channel_mean, channel_std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

# Preprocessing pipeline applied to every image: convert to a tensor, then normalize.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std)])

## Load Dataset

In [8]:
# Map every supported selection to its torchvision dataset class.
dataset_classes = {
    'cifar10': CIFAR10,
    'fashionmnist': FashionMNIST,
}

if dataset_selection not in dataset_classes:
    # Fail loudly instead of leaving train_data/test_data undefined (or stale from an
    # earlier run of this cell) when the selection is misspelled.
    raise ValueError(
        f"Unknown dataset_selection {dataset_selection!r}; "
        f"expected one of {sorted(dataset_classes)}"
    )

# Both splits are downloaded to ./data if missing and share the same preprocessing.
dataset_cls = dataset_classes[dataset_selection]
train_data = dataset_cls(root='./data', train=True, download=True, transform=transform)
test_data = dataset_cls(root='./data', train=False, download=True, transform=transform)

# Number of distinct labels present in the training split.
total_classes = len(np.unique(train_data.targets))

Files already downloaded and verified
Files already downloaded and verified


## Reorder Classes

In [None]:
# FashionMNIST (torchvision & paper)
# ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# CIFAR-10
# (torchvision): ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# (paper): ['bird', 'deer', 'dog', 'frog', 'horse', 'airplane','ship', 'truck', 'cat', 'automobile']


In [9]:
# CIFAR10 match torchvision with paper
# Class names in the index order used by torchvision and by the paper, respectively.
torchvision_order = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
paper_order = ['bird', 'deer', 'dog', 'frog', 'horse', 'airplane', 'ship', 'truck', 'cat', 'automobile']

# torchvision class index -> (paper class index, flag).
# The boolean flag is always False here; its meaning is defined by reorder_classes
# (not visible in this notebook) -- TODO confirm.
ordering = {
    tv_index: (paper_order.index(class_name), False)
    for tv_index, class_name in enumerate(torchvision_order)
}

In [10]:
# The permutation in `ordering` is specific to CIFAR-10. FashionMNIST already uses the
# paper's class order in torchvision, so applying the CIFAR-10 permutation to it would
# scramble its labels; use an identity mapping for it instead.
# NOTE(review): assumes reorder_classes treats {i: (i, False)} as "leave class i in
# place" -- confirm against utils.dataset_tools.
if dataset_selection == 'cifar10':
    class_ordering = ordering
else:
    class_ordering = {idx: (idx, False) for idx in range(total_classes)}

# Relabel both splits with the same mapping so train and test labels stay consistent.
for split in (train_data, test_data):
    targets, classes = reorder_classes(split, class_ordering)
    split.targets = targets
    split.classes = classes

## Create Subsets

In [11]:
# Class labels handed to split_training_data to separate "included" from "excluded"
# samples. With the CIFAR-10 paper ordering listed above, 8 = cat and 9 = automobile.
holdout_classes = [8, 9]
# Mini-batch size shared by all DataLoaders created below.
batch_size = 32

In [12]:
# Split the training set by class membership in holdout_classes.
# Presumably returns (samples of the retained classes, samples of the held-out classes),
# as the variable names suggest -- verify against utils.dataset_tools.
included_data, excluded_data = split_training_data(train_data, holdout_classes)

train_inc_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)
# Fix: this loader used to wrap included_data as well, so the excluded split was never served.
train_exc_loader = DataLoader(excluded_data, batch_size=batch_size, shuffle=True, num_workers=2)

In [13]:
# Same split for the test set. Note that this rebinds included_data/excluded_data,
# which until now referred to the training split.
included_data, excluded_data = split_training_data(test_data, holdout_classes)

test_inc_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)
# Fix: this loader used to wrap included_data as well, so the excluded split was never served.
test_exc_loader = DataLoader(excluded_data, batch_size=batch_size, shuffle=True, num_workers=2)

# Train Model

## Load Architecture

In [14]:
# The network only gets outputs for the classes that are not held out.
num_classes = total_classes - len(holdout_classes)

if model_selection == 'linear':
    # Flattened size of one input sample; fetch the sample once instead of three times.
    channels, height, width = train_data[0][0].shape
    input_size = channels * height * width
    model = LinearFashionMNIST_alt(input_size, num_classes)
elif model_selection == 'cnn':
    model = CNN_3B(num_classes)
elif model_selection == 'cnn-demo':  # fixed typo: was `model_selction` (NameError)
    model = CNN_demo(num_classes)
elif model_selection == 'vgg':  # fixed typo: was `model_slection` (NameError)
    # Stop here: merely printing would let the cell fall through to model.to(device)
    # with `model` undefined (or left over from an earlier run).
    raise NotImplementedError('Model not implemented')
else:
    raise ValueError(
        f"Unknown model_selection {model_selection!r}; "
        "expected one of 'linear', 'cnn', 'cnn-demo', 'vgg'"
    )

# Move the parameters to the selected device; as the last expression of the cell this
# also displays the model structure.
model.to(device)

CNN_3B(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (conv_block1): Sequential(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU()
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2): ReLU()
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (mpool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (conv_block2): Sequential(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU()
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2): ReLU()
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (mpool): MaxPool2d

## File Paths

In [18]:
# Output locations: weights go to the project folder on Drive, logs to ./logs/
# relative to the current working directory.
weight_dir = '/content/gdrive/MyDrive/research/SWIL-Comparisons/src/models/'
log_dir = './logs/'

# Stem shared by all artifacts of this run, e.g. 'cnn_cifar10_holdout_[8, 9]'
# (the holdout list is rendered exactly as str() would render it).
run_stem = f"{model_selection}_{dataset_selection}_holdout_{holdout_classes}"

# NOTE(review): the log file names have no separator between the stem and their
# suffix ('...[8, 9]recall.npy'); kept as-is so existing files keep matching.
model_file = f"{weight_dir}{run_stem}.pt"
recall_file = f"{log_dir}{run_stem}recall.npy"
train_losses_file = f"{log_dir}{run_stem}train_loss.txt"
test_losses_file = f"{log_dir}{run_stem}test_loss.txt"

## Hyperparameters

In [16]:
# Length of the training run, in epochs.
num_epochs = 15

# Learning-rate schedule endpoints.
initial_learning_rate = 0.001
final_learning_rate = 0.0001

# Per-step multiplier chosen so that
# initial_learning_rate * decay_rate**num_epochs == final_learning_rate.
lr_ratio = final_learning_rate / initial_learning_rate
decay_rate = lr_ratio ** (1 / num_epochs)

# Loss, optimizer over all model parameters, and the exponential LR decay driven by decay_rate.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=initial_learning_rate)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate)

## Training Loop

In [17]:
# Train on the included classes for num_epochs epochs, keeping the per-epoch values
# returned by train() and test() for later inspection.
train_losses = []
test_losses = []
t = range(num_epochs)  # epoch indices

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loss = train(train_inc_loader, model, loss_fn, optimizer, device)
    test_loss = test(test_inc_loader, model, loss_fn, device)
    train_losses.append(train_loss)
    test_losses.append(test_loss)

    # One scheduler step per epoch.
    lr_scheduler.step()

    # Checkpoint after every completed epoch rather than only once after the loop, so
    # a crash or Colab disconnect mid-run does not lose all progress. After the final
    # epoch the file holds the same weights the single end-of-loop save used to write.
    torch.save(model.state_dict(), model_file)

# Loss logging to disk is currently disabled:
#with open(train_losses_file, 'w') as fp:
#    for s in train_losses:
#        fp.write("%s\n" % s)

#with open(test_losses_file, 'w') as fp:
#    for x in test_losses:
#        fp.write("%s\n" % x)

print("Done!")

Epoch 1
-------------------------------
loss: 2.085680  [    0/40000]
loss: 1.090821  [32000/40000]
Test Error: 
 Accuracy: 71.5%, Avg loss: 0.831643 

Epoch 2
-------------------------------
loss: 0.569671  [    0/40000]
loss: 0.640264  [32000/40000]
Test Error: 
 Accuracy: 80.1%, Avg loss: 0.576776 

Epoch 3
-------------------------------
loss: 0.408754  [    0/40000]
loss: 0.484026  [32000/40000]
Test Error: 
 Accuracy: 83.6%, Avg loss: 0.477641 

Epoch 4
-------------------------------
loss: 0.311924  [    0/40000]
loss: 0.734633  [32000/40000]
Test Error: 
 Accuracy: 85.5%, Avg loss: 0.456581 

Epoch 5
-------------------------------
loss: 0.163965  [    0/40000]
loss: 0.122116  [32000/40000]
Test Error: 
 Accuracy: 85.2%, Avg loss: 0.481593 

Epoch 6
-------------------------------
loss: 0.170392  [    0/40000]
loss: 0.082805  [32000/40000]
Test Error: 
 Accuracy: 86.8%, Avg loss: 0.500714 

Epoch 7
-------------------------------
loss: 0.069022  [    0/40000]
loss: 0.002612  [3

RuntimeError: ignored

In [19]:
# Manually write the model's current state_dict to model_file (the same call the training cell makes).
# NOTE(review): the recorded training output above ends in a RuntimeError during epoch 7,
# so this presumably stores weights from a partially completed run -- confirm before relying on the file.
torch.save(model.state_dict(), model_file)

In [None]:
# Load the recall array stored at recall_file and print its comparison with the in-memory `recalls`.
# NOTE(review): no visible cell defines `recalls` or writes recall_file, and this cell has
# no execution count -- confirm where both come from before running on a fresh kernel.
recalls_loaded = np.load(recall_file)
print(recalls == recalls_loaded)
# plots