In [None]:
# Mount Google Drive at /content/gdrive so that files stored there can be used by later cells.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
%cd / content/gdrive/MyDrive/cs394n_project/CS394N
! pip3 install -r requirements.txt

In [None]:
# Update path for custom module support in Google Colab
import sys

# Directory added to the module search path.
src_dir = '/content/gdrive/MyDrive/cs394n_project/CS394N/src'
# Guard the append so that re-running this cell does not add duplicate entries.
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10, FashionMNIST

from tqdm.autonotebook import tqdm, trange

from utils.nets import *
from utils.model_tools import *
from utils.dataset_tools import split_training_data
from utils.feature_extractor import *
from utils.cosine_similarity import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Choose the compute device: "cuda" when PyTorch reports an available GPU, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cpu device


In [16]:
# Output locations, relative to the current working directory.
model_dir = './models/'
log_dir = './logs'

# Architecture and dataset; both are used below to pick code paths and to build file names.
model_selection = 'cnn' # linear | cnn | vgg
dataset_selection = 'cifar10' # cifar10 | fashionmnist

# NOTE(review): presumably the labels withheld when the checkpoint below was trained — confirm.
holdout_classes = [8, 9]
# Label passed to split_training_data when building the FOL training split.
new_class = 8

# Derive the checkpoint name from the selections above instead of hard-coding it, so that
# changing model_selection / dataset_selection / holdout_classes picks the matching file.
# With the values above this evaluates to './models/cnn_cifar10_holdout_[8, 9].pt'.
ckpt_file = model_dir + f"{model_selection}_{dataset_selection}_holdout_{holdout_classes}.pt"
batch_size = 32
# Class count handed to get_recall_per_epoch.
num_classes = 9
# 'plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck'

#### Hyperparameters

In [10]:
# Number of passes over the training data.
num_epochs = 1 #15

# Learning-rate endpoints for the exponential schedule.
initial_learning_rate = 0.001
final_learning_rate = 0.0001

# Per-epoch multiplicative factor: the num_epochs-th root of final/initial.
decay_rate = pow(final_learning_rate / initial_learning_rate, 1 / num_epochs)

# Loss shared by every training loop in this notebook.
loss_fn = torch.nn.CrossEntropyLoss()

# Data Preparation

In [11]:
# Preprocessing pipeline: convert each image to a tensor, then normalise every channel
# with mean 0.5 and std 0.5.
# FashionMNIST images are grayscale -> 1 channel; the other branch uses 3 channels.
num_channels = 1 if dataset_selection == 'fashionmnist' else 3
# Build the statistics as a real tuple with one entry per channel. Note the trailing comma:
# (0.5) without it is a plain float, not a one-element sequence.
channel_stats = (0.5,) * num_channels
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_stats, channel_stats)])

In [12]:
# Map every supported dataset_selection value to the torchvision dataset class that loads it.
dataset_classes = {'cifar10': CIFAR10, 'fashionmnist': FashionMNIST}
if dataset_selection not in dataset_classes:
    # Fail here with a clear message instead of a NameError on train_data in a later cell.
    raise ValueError(
        f"Unsupported dataset_selection {dataset_selection!r}; "
        f"expected one of {sorted(dataset_classes)}")

dataset_cls = dataset_classes[dataset_selection]
# Both splits are stored under ./data, downloaded when missing, and share the same transform.
train_data = dataset_cls(root='./data', train=True, download=True, transform=transform)
test_data = dataset_cls(root='./data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [13]:
# Hard-coded class count; the trailing commented-out expression is an alternative that
# would compute it from train_data.targets.
total_classes = 10 #len(torch.unique(train_data.targets))

## FOL

In [None]:
# Build the FOL model from the checkpoint via add_output_nodes, then disable gradients
# on the layers listed per architecture.
if model_selection == 'linear':
    fol_model = add_output_nodes(ckpt_file, arch='linear')
    fol_frozen_layers = [fol_model.input_layer]
elif model_selection in ('cnn', 'cnn-demo'):
    # Accept 'cnn' (the value the config cell offers) as well as 'cnn-demo', and pass the
    # selection through as the architecture name.
    fol_model = add_output_nodes(ckpt_file, arch=model_selection)
    fol_frozen_layers = [fol_model.conv1, fol_model.conv2, fol_model.fc1]
else:
    # Fail with a clear message instead of a NameError on fol_model below.
    raise ValueError(f"Unsupported model_selection for FOL: {model_selection!r}")

for layer in fol_frozen_layers:
    layer.requires_grad_(False)

fol_model = fol_model.to(device)

In [None]:
# Adam over the FOL model's parameters, starting at initial_learning_rate.
fol_optimizer = torch.optim.Adam(params=fol_model.parameters(), lr=initial_learning_rate)
# Exponential schedule: each scheduler step scales the learning rate by decay_rate.
fol_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(fol_optimizer, gamma=decay_rate)

In [None]:
# Build the FOL loaders. included_data / excluded_data are reassigned by the second call,
# so after this cell they refer to the test split.
# just train on the new class
# (the second value returned by split_training_data is what gets used for training)
included_data, excluded_data = split_training_data(train_data, [new_class]) 
train_fol_loader = DataLoader(excluded_data, batch_size=batch_size, shuffle=True, num_workers=2)

# but test on the full 9 classes (old classes + new one, still excluding one)
# NOTE(review): the label list is hard-coded to [8] here instead of using new_class —
# confirm the two are meant to stay in sync.
included_data, excluded_data = split_training_data(test_data, [8])
test_fol_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:
# Directory for saved weights and metrics. Defined here so this cell also works on a
# fresh kernel (Restart & Run All), where no later cell has defined it yet.
weight_dir = './weights/'

# Common prefix of every FOL artefact. There is deliberately no separator between '[8]'
# and 'fol', which keeps the resulting file names unchanged.
fol_file_prefix = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + 'fol'

model_file_fol = fol_file_prefix + '.pt'
recall_file_fol = fol_file_prefix + '_recall.npy'
train_losses_file_fol = fol_file_prefix + '_train_loss.txt'
test_losses_file_fol = fol_file_prefix + '_test_loss.txt'

### Training Loop

In [None]:
# Train the FOL model for num_epochs epochs, evaluating after each one, then persist the
# weights, the per-epoch recall and both loss curves.
train_losses, test_losses = [], []
y_preds, y_actuals = [], []
t = range(num_epochs)

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loss = train(train_fol_loader, fol_model, loss_fn, fol_optimizer, device)
    # swap / swap_labels are forwarded to test(); their effect is defined by that helper.
    test_loss, y_pred, y_actual = test(test_fol_loader, fol_model, loss_fn, device, swap=True, swap_labels=[9,8])
    print('y_pred:', y_pred[:2])
    print('y_actual:', y_actual[:2])

    # Record this epoch's results in the matching history lists.
    for history, value in (
        (train_losses, train_loss),
        (test_losses, test_loss),
        (y_preds, y_pred),
        (y_actuals, y_actual),
    ):
        history.append(value)

    fol_lr_scheduler.step()

torch.save(fol_model.state_dict(), model_file_fol)

recalls = get_recall_per_epoch(y_actuals, y_preds, num_classes)
np.save(recall_file_fol, recalls)

# Write each loss curve as one value per line.
for loss_path, loss_values in (
    (train_losses_file_fol, train_losses),
    (test_losses_file_fol, test_losses),
):
    with open(loss_path, 'w') as fp:
        fp.writelines("%s\n" % loss for loss in loss_values)

print("Done!")

**TODO:** Data loading for FOL is still not working correctly, and it is not clear that the FOL baseline is needed at all.

## SWIL

In [17]:
# Build the SWIL model from the checkpoint via add_output_nodes, then disable gradients
# on the layers listed per architecture.
# NOTE(review): the recorded output of this cell is a RuntimeError about deserialising a
# CUDA checkpoint on a CPU-only machine; that presumably has to be fixed inside
# add_output_nodes (torch.load(..., map_location=...)) — confirm in that helper.
if model_selection == 'linear':
    swil_model = add_output_nodes(ckpt_file, arch='linear')
    swil_frozen_layers = [swil_model.input_layer]
elif model_selection in ('cnn', 'cnn-demo'):
    # Pass the selection through as the architecture name ('cnn' behaves as before).
    swil_model = add_output_nodes(ckpt_file, arch=model_selection)
    swil_frozen_layers = [swil_model.conv1, swil_model.conv2, swil_model.fc1]
else:
    # Fail with a clear message instead of a NameError on swil_model below.
    raise ValueError(f"Unsupported model_selection for SWIL: {model_selection!r}")

for layer in swil_frozen_layers:
    layer.requires_grad_(False)

swil_model = swil_model.to(device)

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [None]:
# Adam over the SWIL model's parameters, starting at initial_learning_rate.
swil_optimizer = torch.optim.Adam(params=swil_model.parameters(), lr=initial_learning_rate)
# Exponential schedule: each scheduler step scales the learning rate by decay_rate.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(swil_optimizer, gamma=decay_rate)

In [None]:
# Labels handed to generate_dls: every label from 0 to 9 except 1.
cnn_classes = [label for label in range(10) if label != 1]

# generate_dls' three results are unpacked here; class_idxs is indexed per class position below.
class_subsets, class_idxs, subset_size = generate_dls(train_data, cnn_classes)

In [None]:
# Read one similarity score per line. Blank lines (for example a trailing newline at the
# end of the file) are skipped so that float() does not fail on an empty string.
with open(r'./data/cifar10_sim_scores_cat.txt', 'r') as fp:
    sim_scores = [float(line) for line in fp if line.strip()]

sim_sum = sum(sim_scores)

# Normalise the scores so that they sum to 1.
sim_norms = [score / sim_sum for score in sim_scores]

cat_sample_size = 720
# One sample size per score, proportional to its normalised similarity, followed by
# cat_sample_size itself as the last entry.
# NOTE(review): 3.52 is an undocumented scaling factor — TODO explain it or name it as a constant.
sim_sample_sizes = [int(x * cat_sample_size*3.52) for x in sim_norms] + [cat_sample_size]

In [None]:
from random import sample

# For every class position in cnn_classes (the list class_idxs was generated from), draw
# sim_sample_sizes[i] training indices without replacement.
sampled_idxs = []

for i in range(len(cnn_classes)):
    idx_sample = sample(class_idxs[i].tolist(), sim_sample_sizes[i])
    sampled_idxs += idx_sample

swil_train_subset = torch.utils.data.Subset(train_data, sampled_idxs)

# The SWIL training loader yields one sample per batch.
swil_train_dl = torch.utils.data.DataLoader(swil_train_subset, batch_size=1, shuffle=True, num_workers=2)

# Evaluation uses the first value returned by split_training_data on the test set.
included_data, excluded_data = split_training_data(test_data, [8])
test_swil_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)

### Training Loop

In [None]:
weight_dir = './weights/'
# All SWIL artefacts share one prefix, so every file name carries the same '_swil' marker
# (the test-loss file used to lack the underscore before 'swil').
swil_file_prefix = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + '_swil'
model_file = swil_file_prefix + '.pt'
recall_file = swil_file_prefix + '_recall.npy'
train_losses_file = swil_file_prefix + '_train_loss.txt'
test_losses_file = swil_file_prefix + '_test_loss.txt'

# Per-epoch histories.
train_losses = []
test_losses = []
y_preds = []
y_actuals = []
t = range(num_epochs)

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    # swap / swap_labels are forwarded to train() and test(); their effect is defined by those helpers.
    train_loss = train(swil_train_dl, swil_model, loss_fn, swil_optimizer, device, swap=True, swap_labels=[9,8])
    test_loss, y_pred, y_actual = test(test_swil_loader, swil_model, loss_fn, device, swap=True, swap_labels=[9,8])
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    y_preds.append(y_pred)
    y_actuals.append(y_actual)

    lr_scheduler.step()

torch.save(swil_model.state_dict(), model_file)

# Use the configured class count (as the FOL cell does) rather than a literal 9.
recalls = get_recall_per_epoch(y_actuals, y_preds, num_classes)
np.save(recall_file, recalls)

# Write each loss curve as one value per line.
with open(train_losses_file, 'w') as fp:
    for s in train_losses:
        fp.write("%s\n" % s)

with open(test_losses_file, 'w') as fp:
    for x in test_losses:
        fp.write("%s\n" % x)

print("Done!")

## G-SWIL

In [None]:
# Build the G-SWIL model from the checkpoint via add_output_nodes, then disable gradients
# on the layers listed per architecture.
if model_selection == 'linear':
    gswil_model = add_output_nodes(ckpt_file, arch='linear')
    gswil_frozen_layers = [gswil_model.input_layer]
elif model_selection in ('cnn', 'cnn-demo'):
    # Accept 'cnn' (the value the config cell offers) as well as 'cnn-demo', and pass the
    # selection through as the architecture name.
    gswil_model = add_output_nodes(ckpt_file, arch=model_selection)
    gswil_frozen_layers = [gswil_model.conv1, gswil_model.conv2, gswil_model.fc1]
else:
    # Fail with a clear message instead of a NameError on gswil_model below.
    raise ValueError(f"Unsupported model_selection for G-SWIL: {model_selection!r}")

for layer in gswil_frozen_layers:
    layer.requires_grad_(False)

gswil_model = gswil_model.to(device)

In [None]:
# Adam over the G-SWIL model's parameters, starting at initial_learning_rate.
optimizer = torch.optim.Adam(params=gswil_model.parameters(), lr=initial_learning_rate)
# Exponential schedule: each scheduler step scales the learning rate by decay_rate.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate)

### Training Loop

In [None]:
# Train and evaluate the G-SWIL model, then save its weights.
# NOTE(review): train_inc_loader and test_inc_loader are not defined in any visible cell
# of this notebook — TODO build the G-SWIL data loaders before running this cell.
train_losses = []
test_losses = []
t = range(num_epochs)

# Save under a G-SWIL-specific name; reusing model_file would overwrite the SWIL checkpoint.
gswil_model_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + '_gswil.pt'

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    # Train and evaluate gswil_model, the model whose parameters `optimizer` was built from.
    train_loss = train(train_inc_loader, gswil_model, loss_fn, optimizer, device)
    test_result = test(test_inc_loader, gswil_model, loss_fn, device)
    # Other cells unpack test() as (loss, y_pred, y_actual); keep only the loss for the
    # curve, while still accepting a bare loss value.
    test_loss = test_result[0] if isinstance(test_result, tuple) else test_result
    train_losses.append(train_loss)
    test_losses.append(test_loss)

    lr_scheduler.step()

torch.save(gswil_model.state_dict(), gswil_model_file)
print("Done!")