In [10]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10, FashionMNIST

from tqdm.autonotebook import tqdm, trange

from utils.nets import *
from utils.model_tools import train, test, add_output_nodes
from utils.dataset_tools import split_training_data

In [None]:
# Choose the compute device: use the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [8]:
# Experiment configuration: everything a reader is expected to tune lives here.

# NOTE: the model-building cells in this notebook only branch on
# 'linear' and 'cnn-demo'; any other value is not handled by them.
model_selection = 'linear'  # linear | cnn-demo
dataset_selection = 'fashionmnist'  # cifar10 | fashionmnist

# Class labels that appear in the checkpoint filename
# (presumably the classes withheld when that checkpoint was trained -- confirm).
holdout_classes = [8, 9]
# Label of the class the incremental-learning experiments add
# (presumably one of holdout_classes -- confirm).
new_class = 8

# Derive the checkpoint path from the selections above so the two cannot
# drift apart. For the defaults this yields exactly
# './weights/linear_fashionmnist_holdout_[8, 9].pt'.
ckpt_file = f'./weights/{model_selection}_{dataset_selection}_holdout_{holdout_classes}.pt'
batch_size = 32

#### Hyperparameters

In [None]:
# Training schedule shared by every experiment in this notebook.
num_epochs = 15

initial_learning_rate = 0.001
final_learning_rate = 0.0001

# Per-epoch multiplicative decay, chosen so that
# initial_lr * decay_rate^num_epochs = final_lr
decay_rate = (final_learning_rate/initial_learning_rate)**(1/num_epochs)

loss_fn = torch.nn.CrossEntropyLoss()

# No optimizer / LR scheduler is created here: no `model` exists at this
# point of the notebook, so building one raised NameError on a fresh kernel.
# Each experiment section constructs its own optimizer and scheduler for its
# own model right before training.

# Data Preparation

In [3]:
# Per-channel normalisation statistics (mean and std are both 0.5).
if dataset_selection == 'fashionmnist':
    # Images are grayscale -> 1 channel. The trailing comma matters:
    # `(0.5)` is just the float 0.5, not a one-element sequence.
    channel_stats = (0.5,)
else:
    # Any other selection is treated as 3-channel input.
    channel_stats = (0.5, 0.5, 0.5)

# Convert images to tensors, then normalise every channel with the stats above.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_stats, channel_stats)])

In [4]:
# Map each supported selection to its torchvision dataset class.
dataset_classes = {'cifar10': CIFAR10, 'fashionmnist': FashionMNIST}

# Fail immediately on an unknown selection instead of leaving train_data /
# test_data undefined and hitting an unrelated NameError in a later cell.
if dataset_selection not in dataset_classes:
    raise ValueError(
        f"Unsupported dataset_selection {dataset_selection!r}; "
        f"expected one of {sorted(dataset_classes)}")

dataset_cls = dataset_classes[dataset_selection]
# Download (if needed) into ./data and apply the shared `transform` to both splits.
train_data = dataset_cls(root='./data', train=True, download=True, transform=transform)
test_data = dataset_cls(root='./data', train=False, download=True, transform=transform)

In [5]:
# Number of distinct labels in the training set.
# CIFAR10 stores `targets` as a plain Python list while FashionMNIST stores a
# tensor; torch.unique only accepts tensors, so normalise the container first
# (torch.as_tensor leaves an existing tensor untouched).
total_classes = len(torch.unique(torch.as_tensor(train_data.targets)))

## FOL

In [11]:
# Build the FOL model from the checkpoint via add_output_nodes (presumably
# this restores the trained weights and extends the output layer -- confirm
# against utils.model_tools), then freeze the early layers.
if model_selection == 'linear':
    fol_model = add_output_nodes(ckpt_file, arch='linear')
    frozen_layers = [fol_model.input_layer]
elif model_selection == 'cnn-demo':
    fol_model = add_output_nodes(ckpt_file, arch='cnn-demo')
    frozen_layers = [fol_model.conv1, fol_model.conv2, fol_model.fc1]
else:
    # Without this branch an unhandled selection surfaced later as a
    # confusing NameError on `fol_model`.
    raise ValueError(
        f"Unsupported model_selection {model_selection!r}; "
        "expected 'linear' or 'cnn-demo'")

# Frozen layers receive no gradients, so training leaves them unchanged.
for layer in frozen_layers:
    layer.requires_grad_(False)

fol_model = fol_model.to(device)

input_size 784
num_outputs 9


In [None]:
# Adam optimizer dedicated to the FOL model, starting at the initial learning rate.
fol_optimizer = torch.optim.Adam(
    params=fol_model.parameters(),
    lr=initial_learning_rate,
)
# Multiply the learning rate by decay_rate every time the scheduler is stepped.
fol_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(fol_optimizer, gamma=decay_rate)

In [None]:
# FOL trains and evaluates on the second ("excluded") split returned for the
# new class. Presumably that split holds only the samples of the listed
# classes -- confirm against utils.dataset_tools.split_training_data.
# The class comes from the `new_class` config value instead of a hard-coded 8.
included_data, excluded_data = split_training_data(train_data, [new_class])
train_fol_loader = DataLoader(excluded_data, batch_size=batch_size, shuffle=True, num_workers=2)

# Same split applied to the held-out test set.
included_data, excluded_data = split_training_data(test_data, [new_class])
test_fol_loader = DataLoader(excluded_data, batch_size=batch_size, shuffle=True, num_workers=2)

### Training Loop

In [None]:
# Per-epoch history of the values returned by train() / test() for the FOL run
# (assumed to be losses, going by the names -- confirm in utils.model_tools).
train_losses, test_losses = [], []
t = range(num_epochs)

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")

    # One pass over the FOL training loader, then an evaluation pass.
    epoch_train_loss = train(train_fol_loader, fol_model, loss_fn, fol_optimizer, device)
    epoch_test_loss = test(test_fol_loader, fol_model, loss_fn, device)

    train_losses.append(epoch_train_loss)
    test_losses.append(epoch_test_loss)

    # Apply the exponential learning-rate decay once per epoch.
    fol_lr_scheduler.step()

# The FOL weights are intentionally not written to disk in this cell.
print("Done!")

## SWIL

In [11]:
# Build the SWIL model from the checkpoint via add_output_nodes (presumably
# this restores the trained weights and extends the output layer -- confirm
# against utils.model_tools), then freeze the early layers.
if model_selection == 'linear':
    swil_model = add_output_nodes(ckpt_file, arch='linear')
    swil_frozen_layers = [swil_model.input_layer]
elif model_selection == 'cnn-demo':
    swil_model = add_output_nodes(ckpt_file, arch='cnn-demo')
    swil_frozen_layers = [swil_model.conv1, swil_model.conv2, swil_model.fc1]
else:
    # Without this branch an unhandled selection surfaced later as a
    # confusing NameError on `swil_model`.
    raise ValueError(
        f"Unsupported model_selection {model_selection!r}; "
        "expected 'linear' or 'cnn-demo'")

# Frozen layers receive no gradients, so training leaves them unchanged.
for frozen_layer in swil_frozen_layers:
    frozen_layer.requires_grad_(False)

swil_model = swil_model.to(device)

input_size 784
num_outputs 9


In [None]:
# Adam optimizer bound to the SWIL model's parameters, starting at the initial learning rate.
optimizer = torch.optim.Adam(
    params=swil_model.parameters(),
    lr=initial_learning_rate,
)
# Multiply the learning rate by decay_rate every time the scheduler is stepped.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate)

### Training Loop

In [None]:
# Per-epoch history of the values returned by train() / test() for the SWIL run.
train_losses = []
test_losses = []
t = range(num_epochs)

# NOTE(review): train_inc_loader, test_inc_loader and model_file are not
# defined anywhere in the visible notebook; they must be created in an
# earlier cell before this one can run on a fresh kernel.
for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    # Train and evaluate swil_model: `optimizer` was built over
    # swil_model.parameters(), so passing any other model here would leave
    # the optimizer stepping parameters that never receive gradients.
    train_loss = train(train_inc_loader, swil_model, loss_fn, optimizer, device)
    test_loss = test(test_inc_loader, swil_model, loss_fn, device)
    train_losses.append(train_loss)
    test_losses.append(test_loss)

    # Apply the exponential learning-rate decay once per epoch.
    lr_scheduler.step()

# Persist the weights of the model that was actually trained.
torch.save(swil_model.state_dict(), model_file)
print("Done!")

## G-SWIL

In [11]:
# Build the G-SWIL model from the checkpoint via add_output_nodes (presumably
# this restores the trained weights and extends the output layer -- confirm
# against utils.model_tools), then freeze the early layers.
if model_selection == 'linear':
    gswil_model = add_output_nodes(ckpt_file, arch='linear')
    gswil_frozen_layers = [gswil_model.input_layer]
elif model_selection == 'cnn-demo':
    gswil_model = add_output_nodes(ckpt_file, arch='cnn-demo')
    gswil_frozen_layers = [gswil_model.conv1, gswil_model.conv2, gswil_model.fc1]
else:
    # Without this branch an unhandled selection surfaced later as a
    # confusing NameError on `gswil_model`.
    raise ValueError(
        f"Unsupported model_selection {model_selection!r}; "
        "expected 'linear' or 'cnn-demo'")

# Frozen layers receive no gradients, so training leaves them unchanged.
for frozen_layer in gswil_frozen_layers:
    frozen_layer.requires_grad_(False)

gswil_model = gswil_model.to(device)

input_size 784
num_outputs 9


In [None]:
# Adam optimizer bound to the G-SWIL model's parameters, starting at the initial learning rate.
optimizer = torch.optim.Adam(
    params=gswil_model.parameters(),
    lr=initial_learning_rate,
)
# Multiply the learning rate by decay_rate every time the scheduler is stepped.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate)

### Training Loop

In [None]:
# Per-epoch history of the values returned by train() / test() for the G-SWIL run.
train_losses = []
test_losses = []
t = range(num_epochs)

# NOTE(review): train_inc_loader, test_inc_loader and model_file are not
# defined anywhere in the visible notebook; they must be created in an
# earlier cell before this one can run on a fresh kernel.
for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    # Train and evaluate gswil_model: `optimizer` was built over
    # gswil_model.parameters(), so passing any other model here would leave
    # the optimizer stepping parameters that never receive gradients.
    train_loss = train(train_inc_loader, gswil_model, loss_fn, optimizer, device)
    test_loss = test(test_inc_loader, gswil_model, loss_fn, device)
    train_losses.append(train_loss)
    test_losses.append(test_loss)

    # Apply the exponential learning-rate decay once per epoch.
    lr_scheduler.step()

# Persist the weights of the model that was actually trained.
torch.save(gswil_model.state_dict(), model_file)
print("Done!")