In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10, FashionMNIST

from tqdm.autonotebook import tqdm, trange

from utils.nets import *
from utils.model_tools import *
from utils.dataset_tools import split_training_data
from utils.feature_extractor import *
from utils.cosine_similarity import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cpu device


In [3]:
# Run configuration shared by every section of this notebook.
model_dir = './weights/'  # directory the base checkpoint is loaded from
log_dir = './logs'

# The model-building cells in this notebook only handle 'linear' and 'cnn-demo'.
# NOTE(review): this comment previously listed "cnn | vgg"; no cell below
# has a branch for those values.
model_selection = 'linear' # linear | cnn-demo
dataset_selection = 'fashionmnist' # cifar10 | fashionmnist

# Classes named in the base checkpoint's file name ("holdout_[8, 9]"), and the
# one class from that list that the incremental-learning sections add back.
holdout_classes = [8, 9]
new_class = 9

# Derive the checkpoint path from the settings above so it cannot drift out of
# sync with them. With the current values this evaluates to
# './weights/linear_fashionmnist_holdout_[8, 9].pt'.
ckpt_file = f"{model_dir}{model_selection}_{dataset_selection}_holdout_{holdout_classes}.pt"
batch_size = 32
# Number of classes passed to get_recall_per_epoch after training.
num_classes = 9

#### Hyperparameters

In [4]:
# Number of passes over the training loader in each training loop.
num_epochs = 15

# Learning rate at the first epoch and the value it should decay to.
initial_learning_rate = 0.001
final_learning_rate = 0.0001

# Per-epoch multiplicative decay factor, chosen so that
#   initial_learning_rate * decay_rate**num_epochs == final_learning_rate
lr_ratio = final_learning_rate / initial_learning_rate
decay_rate = lr_ratio ** (1 / num_epochs)

# Shared loss; each section below builds its own optimizer and
# ExponentialLR scheduler for its own model.
loss_fn = torch.nn.CrossEntropyLoss()

# Data Preparation

In [5]:
# Per-channel mean/std handed to Normalize. Both are given as sequences with
# one entry per image channel; note that `(0.5)` is just the float 0.5, the
# one-element tuple needs the trailing comma: `(0.5,)`.
if dataset_selection == 'fashionmnist':
    # Images are grayscale -> 1 channel
    channel_mean, channel_std = (0.5,), (0.5,)
else:
    # Any other selection is treated as 3-channel input.
    channel_mean, channel_std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

# Convert each image to a tensor, then normalize every channel as
# (x - mean) / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

In [6]:
# Resolve the torchvision dataset class for the configured dataset. An unknown
# name is rejected here with a clear message instead of surfacing later as a
# NameError on `train_data`.
if dataset_selection == 'cifar10':
    dataset_cls = CIFAR10
elif dataset_selection == 'fashionmnist':
    dataset_cls = FashionMNIST
else:
    raise ValueError(
        f"Unsupported dataset_selection {dataset_selection!r}; "
        "expected 'cifar10' or 'fashionmnist'"
    )

# Download (if needed) and open the train and test splits under ./data,
# applying the `transform` pipeline to every image.
train_data = dataset_cls(root='./data', train=True, download=True, transform=transform)
test_data = dataset_cls(root='./data', train=False, download=True, transform=transform)

In [7]:
# Count the distinct labels in the training set. FashionMNIST exposes `targets`
# as a tensor but CIFAR10 exposes it as a plain Python list, which torch.unique
# rejects, so convert first (as_tensor is a no-op for an existing tensor).
total_classes = len(torch.unique(torch.as_tensor(train_data.targets)))

## FOL

In [8]:
# Build the FOL model from the base checkpoint via add_output_nodes and freeze
# its earlier layers so only the remaining parameters receive gradients.
# NOTE(review): add_output_nodes comes from a utils star-import; presumably it
# loads `ckpt_file` and extends the output layer -- confirm against its source.
if model_selection == 'linear':
    fol_model = add_output_nodes(ckpt_file, arch='linear')
    fol_model.input_layer.requires_grad_(False)
elif model_selection == 'cnn-demo':
    fol_model = add_output_nodes(ckpt_file, arch='cnn-demo')
    fol_model.conv1.requires_grad_(False)
    fol_model.conv2.requires_grad_(False)
    fol_model.fc1.requires_grad_(False)
else:
    # Without this branch an unsupported value only fails on the next line
    # with a confusing NameError for `fol_model`.
    raise ValueError(
        f"Unsupported model_selection {model_selection!r}; "
        "expected 'linear' or 'cnn-demo'"
    )

fol_model = fol_model.to(device)

input_size 784
num_outputs 9


In [9]:
# Adam over the FOL model's parameters, starting at the initial learning rate.
fol_optimizer = torch.optim.Adam(
    params=fol_model.parameters(),
    lr=initial_learning_rate,
)
# Multiply the learning rate by `decay_rate` on every scheduler step.
fol_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
    fol_optimizer,
    gamma=decay_rate,
)

In [13]:
# just train on the new class
included_data, excluded_data = split_training_data(train_data, [new_class]) 
train_fol_loader = DataLoader(excluded_data, batch_size=batch_size, shuffle=True, num_workers=2)

# but test on the full 9 classes (old classes + new one, still excluding one)
included_data, excluded_data = split_training_data(test_data, [8])
test_fol_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)

In [14]:
# Output paths for the FOL run. These are built from `model_dir` (set in the
# configuration cell) rather than `weight_dir`, which is only assigned further
# down in the SWIL section and is therefore undefined here on a fresh
# top-to-bottom run. Both hold './weights/', so the resulting paths are the same.
# NOTE(review): there is no '_' between '[8]' and 'fol', unlike most of the
# SWIL file names; kept as-is so existing files keep their names.
fol_prefix = model_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + 'fol'
model_file_fol = fol_prefix + '.pt'
recall_file_fol = fol_prefix + '_recall.npy'
train_losses_file_fol = fol_prefix + '_train_loss.txt'
test_losses_file_fol = fol_prefix + '_test_loss.txt'

### Training Loop

In [10]:
# FOL training loop: train on the new-class loader, evaluate on the test
# loader after every epoch, then persist the weights, per-epoch recalls and
# loss curves.
train_losses = []
test_losses = []
t = range(num_epochs)
y_preds = []
y_actuals = []

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    # Pass the same label swap used for evaluation (and by the SWIL loop's
    # train call). The training data here is class 9 only, so without the swap
    # the raw label 9 would reach a model that printed "num_outputs 9".
    # NOTE(review): assumes swap_labels=[9,8] remaps label 9 to index 8, as the
    # SWIL usage suggests -- confirm in train().
    train_loss = train(train_fol_loader, fol_model, loss_fn, fol_optimizer, device, swap=True, swap_labels=[9,8])
    test_loss, y_pred, y_actual = test(test_fol_loader, fol_model, loss_fn, device, swap=True, swap_labels=[9,8])
    # Quick sanity peek at the first two predictions/labels of this epoch.
    print('y_pred:', y_pred[:2])
    print('y_actual:', y_actual[:2])
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    y_preds.append(y_pred)
    y_actuals.append(y_actual)

    fol_lr_scheduler.step()

torch.save(fol_model.state_dict(), model_file_fol)

# Per-class recall for every epoch, saved as a .npy array.
recalls = get_recall_per_epoch(y_actuals, y_preds, num_classes)
np.save(recall_file_fol, recalls)

# Loss curves: one value per line.
with open(train_losses_file_fol, 'w') as fp:
    fp.writelines("%s\n" % loss for loss in train_losses)

with open(test_losses_file_fol, 'w') as fp:
    fp.writelines("%s\n" % loss for loss in test_losses)

print("Done!")

# TODO: data loading for FOL still isn't working correctly, but it's not clear we even need FOL.

Epoch 1
-------------------------------


NameError: name 'train_fol_loader' is not defined

## SWIL

In [8]:
# Build the SWIL model from the base checkpoint via add_output_nodes and freeze
# its earlier layers so only the remaining parameters receive gradients.
# NOTE(review): add_output_nodes comes from a utils star-import; presumably it
# loads `ckpt_file` and extends the output layer -- confirm against its source.
if model_selection == 'linear':
    swil_model = add_output_nodes(ckpt_file, arch='linear')
    swil_model.input_layer.requires_grad_(False)
elif model_selection == 'cnn-demo':
    swil_model = add_output_nodes(ckpt_file, arch='cnn-demo')
    swil_model.conv1.requires_grad_(False)
    swil_model.conv2.requires_grad_(False)
    swil_model.fc1.requires_grad_(False)
else:
    # Without this branch an unsupported value only fails on the next line
    # with a confusing NameError for `swil_model`.
    raise ValueError(
        f"Unsupported model_selection {model_selection!r}; "
        "expected 'linear' or 'cnn-demo'"
    )

swil_model = swil_model.to(device)

input_size 784
num_outputs 9


In [11]:
# Adam over the SWIL model's parameters, starting at the initial learning rate.
swil_optimizer = torch.optim.Adam(
    params=swil_model.parameters(),
    lr=initial_learning_rate,
)
# Exponential decay of the SWIL learning rate, one step per epoch.
# Note: the G-SWIL section later rebinds the name `lr_scheduler`.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
    swil_optimizer,
    gamma=decay_rate,
)

In [12]:
# Class ids used for SWIL: labels 0-7 plus 9 (label 8 is left out).
fmnist_classes = [label for label in range(10) if label != 8]

# Loaders over the complete train and test sets.
# might not need these
full_loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=2)
FMNIST_trainloader_gen = DataLoader(train_data, **full_loader_kwargs)
FMNIST_testloader_gen = DataLoader(test_data, **full_loader_kwargs)

# Per-class structures from generate_dls; `class_idxs` is indexed by position
# in `fmnist_classes` by the sampling cell further down.
class_subsets, class_idxs, subset_size = generate_dls(train_data, fmnist_classes)

In [14]:
# Read the similarity scores, one float per line. Blank lines (for example a
# trailing newline at the end of the file) are skipped instead of crashing
# float().
with open(r'./data/fmnist_sim_scores_boot.txt', 'r') as fp:
    sim_scores = [float(line) for line in fp if line.strip()]

if not sim_scores:
    raise ValueError("No similarity scores found in ./data/fmnist_sim_scores_boot.txt")

sim_sum = sum(sim_scores)
if sim_sum == 0:
    raise ValueError("Similarity scores sum to zero; cannot normalize them")

# Normalize the scores so they sum to 1.
sim_norms = [x/sim_sum for x in sim_scores]

# Sample-size rule: a normalized score below 0.2 gets a fixed 27 samples,
# otherwise the size scales with the score. The final entry uses the full
# boots_sample_size (previously a second literal 75).
# NOTE(review): the constants 27, 0.2 and 3.52 are undocumented -- confirm
# where they come from. The last entry presumably belongs to the new class,
# since `fmnist_classes` ends with 9.
boots_sample_size = 75
sim_sample_sizes = [27 if x < 0.2 else int(x * boots_sample_size*3.52) for x in sim_norms] + [boots_sample_size]

In [15]:
from random import sample

# Draw, for each class position in `fmnist_classes`, the configured number of
# training-set indices without replacement and concatenate them in class order.
sampled_idxs = [
    picked
    for i in range(len(fmnist_classes))
    for picked in sample(class_idxs[i].tolist(), sim_sample_sizes[i])
]

# SWIL training set: just the sampled indices, served one example per batch.
swil_train_subset = torch.utils.data.Subset(train_data, sampled_idxs)
swil_train_dl = DataLoader(swil_train_subset, batch_size=1, shuffle=True, num_workers=2)

# SWIL test set: the "included" part of the test split with class 8 listed.
included_data, excluded_data = split_training_data(test_data, [8])
test_swil_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)

### Training Loop

In [16]:
# Output paths for the SWIL run; all share the same directory and stem.
weight_dir = './weights/'
swil_stem = f"{weight_dir}{model_selection}_{dataset_selection}_holdout_[8]"
model_file = f"{swil_stem}_swil.pt"
recall_file = f"{swil_stem}_swil_recall.npy"
train_losses_file = f"{swil_stem}_swil_train_loss.txt"
# NOTE(review): unlike the three names above there is no '_' before 'swil'
# here, which looks like a typo; kept byte-identical so existing files and any
# readers of them keep working.
test_losses_file = f"{swil_stem}swil_test_loss.txt"
# SWIL training loop: train on the similarity-weighted sample, evaluate on the
# test loader after every epoch, then persist the weights, per-epoch recalls
# and loss curves.
train_losses = []
test_losses = []
t = range(num_epochs)
y_preds = []
y_actuals = []

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    # Training and evaluation use the same label swap.
    train_loss = train(swil_train_dl, swil_model, loss_fn, swil_optimizer, device, swap=True, swap_labels=[9,8])
    test_loss, y_pred, y_actual = test(test_swil_loader, swil_model, loss_fn, device, swap=True, swap_labels=[9,8])
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    y_preds.append(y_pred)
    y_actuals.append(y_actual)

    # `lr_scheduler` is the scheduler created with swil_optimizer; the G-SWIL
    # section rebinds this name, so this cell must run before that one.
    lr_scheduler.step()

torch.save(swil_model.state_dict(), model_file)

# Use the configured class count (9) rather than a bare literal, matching the
# FOL loop.
recalls = get_recall_per_epoch(y_actuals, y_preds, num_classes)
np.save(recall_file, recalls)

# Loss curves: one value per line.
with open(train_losses_file, 'w') as fp:
    fp.writelines("%s\n" % loss for loss in train_losses)

with open(test_losses_file, 'w') as fp:
    fp.writelines("%s\n" % loss for loss in test_losses)

print("Done!")

Epoch 1
-------------------------------
loss: 0.602320  [    0/  349]
Test Error: 
 Accuracy: 72.3%, Avg loss: 1.310042 

Epoch 2
-------------------------------
loss: 0.232019  [    0/  349]
Test Error: 
 Accuracy: 76.4%, Avg loss: 0.925488 

Epoch 3
-------------------------------
loss: 0.003179  [    0/  349]
Test Error: 
 Accuracy: 76.6%, Avg loss: 0.847425 

Epoch 4
-------------------------------
loss: 0.000000  [    0/  349]
Test Error: 
 Accuracy: 76.5%, Avg loss: 0.804749 

Epoch 5
-------------------------------
loss: 0.113733  [    0/  349]
Test Error: 
 Accuracy: 78.7%, Avg loss: 0.752019 

Epoch 6
-------------------------------
loss: 0.676546  [    0/  349]
Test Error: 
 Accuracy: 78.1%, Avg loss: 0.730093 

Epoch 7
-------------------------------
loss: 0.946021  [    0/  349]
Test Error: 
 Accuracy: 78.1%, Avg loss: 0.716133 

Epoch 8
-------------------------------
loss: 0.215354  [    0/  349]
Test Error: 
 Accuracy: 78.9%, Avg loss: 0.699688 

Epoch 9
----------------

[[0.81199998 0.81       0.65100002 0.57200003 0.86000001 0.75599998
  0.67500001 0.80500001 0.67500001 0.77899998 0.79900002 0.73699999
  0.72899997 0.77700001 0.74000001]
 [0.95200002 0.94599998 0.949      0.95200002 0.94700003 0.949
  0.94800001 0.94800001 0.949      0.94999999 0.949      0.94999999
  0.94999999 0.94999999 0.95099998]
 [0.588      0.66299999 0.66799998 0.64899999 0.68699998 0.61000001
  0.69400001 0.70499998 0.66299999 0.62699997 0.66900003 0.66399997
  0.67000002 0.68000001 0.671     ]
 [0.84299999 0.84799999 0.85600001 0.85900003 0.85100001 0.87
  0.85399997 0.81900001 0.84299999 0.85299999 0.84299999 0.85500002
  0.85399997 0.84799999 0.84600002]
 [0.52700001 0.62800002 0.71200001 0.73100001 0.69099998 0.67799997
  0.70099998 0.70300001 0.70300001 0.73199999 0.704      0.704
  0.70499998 0.70099998 0.70200002]
 [0.63700002 0.72299999 0.76300001 0.764      0.73699999 0.78799999
  0.815      0.81300002 0.82800001 0.82700002 0.82499999 0.82700002
  0.83099997 0.82599

## G-SWIL

In [11]:
# Build the G-SWIL model from the base checkpoint via add_output_nodes and
# freeze its earlier layers so only the remaining parameters receive gradients.
# NOTE(review): add_output_nodes comes from a utils star-import; presumably it
# loads `ckpt_file` and extends the output layer -- confirm against its source.
if model_selection == 'linear':
    gswil_model = add_output_nodes(ckpt_file, arch='linear')
    gswil_model.input_layer.requires_grad_(False)
elif model_selection == 'cnn-demo':
    gswil_model = add_output_nodes(ckpt_file, arch='cnn-demo')
    gswil_model.conv1.requires_grad_(False)
    gswil_model.conv2.requires_grad_(False)
    gswil_model.fc1.requires_grad_(False)
else:
    # Without this branch an unsupported value only fails on the next line
    # with a confusing NameError for `gswil_model`.
    raise ValueError(
        f"Unsupported model_selection {model_selection!r}; "
        "expected 'linear' or 'cnn-demo'"
    )

gswil_model = gswil_model.to(device)

input_size 784
num_outputs 9


In [None]:
# Adam over the G-SWIL model's parameters, starting at the initial learning rate.
optimizer = torch.optim.Adam(
    params=gswil_model.parameters(),
    lr=initial_learning_rate,
)
# Exponential learning-rate decay; this rebinds the `lr_scheduler` name that
# the SWIL section also uses.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma=decay_rate,
)

### Training Loop

In [None]:
# G-SWIL training loop (never executed so far).
# NOTE(review): train_inc_loader and test_inc_loader are not defined anywhere
# in the visible notebook; they must be created before this cell can run. It
# is also unclear whether the swap/swap_labels arguments used by the FOL and
# SWIL loops are needed here -- that depends on which classes those loaders
# contain.

# Dedicated checkpoint path so this run does not overwrite the SWIL weights
# stored at `model_file`.
gswil_model_file = model_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + '_gswil.pt'

train_losses = []
test_losses = []
t = range(num_epochs)

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    # Train and evaluate `gswil_model` -- the model whose parameters the
    # optimizer was built from -- rather than the undefined name `model`.
    train_loss = train(train_inc_loader, gswil_model, loss_fn, optimizer, device)
    # test() is unpacked into (loss, predictions, labels) by the other loops,
    # so keep only the loss here instead of appending the whole tuple.
    test_loss, _y_pred, _y_actual = test(test_inc_loader, gswil_model, loss_fn, device)
    train_losses.append(train_loss)
    test_losses.append(test_loss)

    lr_scheduler.step()

torch.save(gswil_model.state_dict(), gswil_model_file)
print("Done!")