# Crafting worst-case initial model parameters through pre-training

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
import os
from copy import deepcopy

from models import Models
from utils.data import load_data
from audit_model import test_model

# Compute device for every tensor and model below: use the GPU when one is
# available, otherwise fall back to the CPU. (The literal string 'device' is
# not a valid torch device, so every `.to(device)` call would raise.)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## MNIST
### Pre-train on half of MNIST

In [None]:
# MNIST pre-training hyper-parameters: learning rate, epoch count, batch size.
data_name = 'mnist'
lr, n_epochs, batch_size = 0.01, 5, 32

# Fix NumPy's and PyTorch's global RNG seeds so runs are repeatable.
np.random.seed(0)
torch.manual_seed(0)

In [None]:
# Load both MNIST splits onto `device`; `out_dim` (presumably the number of
# classes) is taken from the training split.
X, y, out_dim = load_data(
    data_name, None, device=device, split='train'
)
X_test, y_test, _ = load_data(
    data_name, None, device=device, split='test'
)

len(X)

In [None]:
# Pre-train on the first half of the training set only; the second half is
# reserved for fine-tuning so the two never overlap.
n_pretrain = len(X) // 2
X_train, y_train = X[:n_pretrain], y[:n_pretrain]

len(X_train)

In [5]:
# CNN sized from the training tensor's shape, trained with cross-entropy + SGD.
model = Models['cnn'](X_train.shape, out_dim)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=lr)

In [None]:
# train model
# NOTE(review): shuffle=False means every epoch sees the batches in the same
# order — confirm this is intentional.
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=False)

pbar = tqdm(range(n_epochs))
losses = []
# Numbers of *completed* epochs after which an intermediate snapshot is kept.
save_model_epochs = [1, 2, 3, 4]
saved_models = []
for curr_epoch in pbar:
    for curr_X, curr_y in train_loader:
        optimizer.zero_grad()

        output = model(curr_X)
        loss = criterion(output, curr_y)
        loss.backward()

        optimizer.step()

        # .item() already copies the scalar to the host; no explicit .cpu() needed.
        losses.append(loss.item())
        pbar.set_postfix({'loss': losses[-1]})

    # `curr_epoch` is 0-based, so `curr_epoch + 1` epochs have finished here.
    # Testing `curr_epoch` itself stored every snapshot one epoch late (the one
    # later written as "1epochs.pt" had actually been trained for 2 epochs).
    if curr_epoch + 1 in save_model_epochs:
        saved_models.append(deepcopy(model))

In [None]:
# Restore the saved half-MNIST weights into `model` (this replaces whatever
# weights the model currently holds, e.g. those just trained above).
checkpoint_path = 'pretrained_models/cnn_mnist_half.pt'
model.load_state_dict(torch.load(checkpoint_path))

In [None]:
# Evaluate on the held-out test split and report the accuracy in percent.
test_acc = 100 * test_model(model, X_test, y_test)
print(f'Test accuracy (%): {test_acc:.3f}')

In [9]:
# save model
# Create the snapshot directory (and its parent) up front; torch.save does not.
os.makedirs('pretrained_models/cnn_mnist_half_epochs', exist_ok=True)
torch.save(model.cpu().state_dict(), 'pretrained_models/cnn_mnist_half.pt')
# Use a separate loop name: reusing `model` here rebound it to the last
# intermediate snapshot, so later cells would silently use the wrong weights.
for save_model_epoch, snapshot in zip(save_model_epochs, saved_models):
    torch.save(snapshot.cpu().state_dict(), f'pretrained_models/cnn_mnist_half_epochs/{save_model_epoch}epochs.pt')
# `.cpu()` moved the trained model in place; move it back so it can still be
# evaluated against the on-device test tensors.
model = model.to(device)

In [4]:
# Persist the second (unused) half of the training set together with the full
# test split, so fine-tuning data is disjoint from the pre-training half.
folder = f'data/{data_name}_finetune_half/'
os.makedirs(folder, exist_ok=True)

X_finetune, y_finetune = X[len(X)//2:], y[len(y)//2:]

for array_name, tensor in (('X_train', X_finetune), ('y_train', y_finetune),
                           ('X_test', X_test), ('y_test', y_test)):
    np.save(f'{folder}/{array_name}.npy', tensor.cpu().numpy())

## CIFAR-10
### Pre-train on CIFAR-100

In [None]:
# Hyper-parameters for pre-training on CIFAR-100.
data_name = 'cifar100'
# (start epoch, learning rate) pairs for a step-decay schedule
lr_schedule = [(0, 0.1), (128, 0.01), (192, 0.001)]
momentum, nesterov, weight_decay = 0.9, True, 5e-4
n_epochs, batch_size = 300, 128
# Random 32x32 crop with 4px reflect padding, then a random horizontal flip.
augment = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
])

# Seed NumPy and PyTorch for reproducibility.
np.random.seed(0)
torch.manual_seed(0)

In [None]:
# Pre-training uses the entire CIFAR-100 training split.
X, y, out_dim = load_data(
    'cifar100', None, device=device, split='train'
)

len(X)

In [None]:
# define model
# Size the network from the CIFAR-100 tensor `X` loaded above. `X_train` at
# this point is still the half-MNIST split from the previous section (or is
# undefined if this section is run on its own), so its shape is wrong here.
model = Models['cnn'](X.shape, out_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr_schedule[0][1], momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)

In [None]:
# train model
train_loader = DataLoader(TensorDataset(X, y), batch_size=batch_size, shuffle=False)

pbar = tqdm(range(n_epochs))
losses = []
lr_schedule_idx = 1
for curr_epoch in pbar:
    # Step-decay the learning rate at the epochs listed in `lr_schedule`.
    # Update the existing optimizer in place: building a new SGD optimizer
    # would also throw away the accumulated momentum buffers at every decay.
    if lr_schedule_idx < len(lr_schedule) and curr_epoch == lr_schedule[lr_schedule_idx][0]:
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_schedule[lr_schedule_idx][1]
        lr_schedule_idx += 1

    for curr_X, curr_y in train_loader:
        optimizer.zero_grad()

        # NOTE(review): `augment` is applied to the whole batch tensor in one
        # call; torchvision's RandomCrop/RandomHorizontalFlip draw one random
        # crop/flip per call, so all images in a batch share the same
        # augmentation. Confirm this weaker augmentation is intended.
        output = model(augment(curr_X))
        loss = criterion(output, curr_y)
        loss.backward()

        optimizer.step()

        losses.append(loss.item())
        pbar.set_postfix({'loss': losses[-1]})

# save model
os.makedirs('pretrained_models', exist_ok=True)
torch.save(model.state_dict(), 'pretrained_models/cnn_cifar100_pretrained.pt')

### Fine-tune on half of CIFAR-10

In [None]:
# Hyper-parameters for fine-tuning on half of CIFAR-10.
data_name = 'cifar10'
# (start epoch, learning rate) pairs: 0.1 at first, 0.01 from epoch 25 on
lr_schedule = [(0, 0.1), (25, 0.01)]
n_epochs, batch_size = 100, 256
# Same augmentation as pre-training: padded random crop + horizontal flip.
augment = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
])

# Seed NumPy and PyTorch for reproducibility.
np.random.seed(0)
torch.manual_seed(0)

In [None]:
# Load both CIFAR-10 splits onto `device`; `out_dim` comes from the train split.
X, y, out_dim = load_data(
    data_name, None, device=device, split='train'
)
X_test, y_test, _ = load_data(
    data_name, None, device=device, split='test'
)

len(X)

In [None]:
# Fine-tune on the first half of the CIFAR-10 training set only.
n_finetune = len(X) // 2
X_train, y_train = X[:n_finetune], y[:n_finetune]

len(X_train)

In [None]:
# Load the model pre-trained on CIFAR-100 from the same path it was saved to.
# `map_location` lets a checkpoint written from a GPU model load on any device.
pretrain_model_state = torch.load('pretrained_models/cnn_cifar100_pretrained.pt', map_location=device)

# initialize new model for fine-tuning
model = Models['cnn'](X_train.shape, out_dim).to(device)

# Import state from the pre-trained model, but keep this model's freshly
# initialised final classifier layer (presumably sized by `out_dim`, which
# differs between CIFAR-100 and CIFAR-10).
model_state = model.state_dict()
layer_name = 'net.classifier.2'
for param_name in ('weight', 'bias'):
    key = f'{layer_name}.{param_name}'
    pretrain_model_state[key] = model_state[key]
model.load_state_dict(pretrain_model_state)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr_schedule[0][1])

In [None]:
# train model
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=False)

pbar = tqdm(range(n_epochs))
lr_schedule_idx = 1
losses = []
test_accs = []
for curr_epoch in pbar:
    # Record test accuracy (percent) at the start of every epoch.
    test_accs.append(test_model(model, X_test, y_test) * 100)
    # Restore training mode in case test_model() switches the model to eval
    # mode (its body is not visible here); otherwise any dropout/batch-norm
    # layers would stay in inference mode for the whole epoch.
    model.train()

    # Decay the learning rate in place at the epochs listed in `lr_schedule`.
    if lr_schedule_idx < len(lr_schedule) and curr_epoch == lr_schedule[lr_schedule_idx][0]:
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_schedule[lr_schedule_idx][1]
        lr_schedule_idx += 1

    for curr_X, curr_y in train_loader:
        optimizer.zero_grad()

        output = model(augment(curr_X))
        loss = criterion(output, curr_y)
        loss.backward()

        optimizer.step()

        losses.append(loss.item())
        pbar.set_postfix({'test acc': test_accs[-1], 'loss': losses[-1]})

In [None]:
# Evaluate the fine-tuned model on the CIFAR-10 test split, in percent.
test_acc = 100 * test_model(model, X_test, y_test)
print(f'Test accuracy (%): {test_acc:.3f}')

In [None]:
# Persist the fine-tuned weights.
# NOTE(review): unlike the other checkpoints this file goes to the working
# directory rather than pretrained_models/ — confirm that is intended.
finetuned_state = model.state_dict()
torch.save(finetuned_state, 'cnn_cifar100_cifar10_half.pt')

## Pre-train logistic regression on last-layer activations of a WRN-28-10 on CIFAR-10

In [None]:
# Hyper-parameters for the logistic-regression model on extracted features.
data_name = 'cifar10_half_finetune_last'
out_dim = 10  # CIFAR-10 label count
lr, n_epochs, batch_size = 0.1, 20, 32

# Seed NumPy and PyTorch for reproducibility.
np.random.seed(0)
torch.manual_seed(0)

In [None]:
# Load the pre-extracted feature arrays (last-layer activations of the
# WRN-28-10 backbone) from data/<data_name>/ and move them to `device`.
data_name = 'cifar10_half_finetune_last'


def _load_feature_tensor(file_stem):
    """Read data/<data_name>/<file_stem>.npy into a tensor on `device`."""
    return torch.from_numpy(np.load(f'data/{data_name}/{file_stem}.npy')).to(device)


X_train, y_train = _load_feature_tensor('X_pretrain'), _load_feature_tensor('y_pretrain')
X_test, y_test = _load_feature_tensor('X_test'), _load_feature_tensor('y_test')

len(X_train)

In [10]:
# Logistic-regression model over the feature vectors, trained with SGD.
model = Models['lr'](X_train.shape, out_dim)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=lr)

In [None]:
# Mini-batch SGD over the feature dataset for `n_epochs` epochs, showing the
# most recent batch loss on the progress bar.
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=False)

pbar = tqdm(range(n_epochs))
losses = []
for _ in pbar:
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_X), batch_y)
        batch_loss.backward()
        optimizer.step()

        losses.append(batch_loss.cpu().item())
        pbar.set_postfix({'loss': losses[-1]})

In [None]:
# Evaluate the logistic-regression model on the test features, in percent.
test_acc = 100 * test_model(model, X_test, y_test)
print(f'Test accuracy (%): {test_acc:.3f}')

In [7]:
# Save the trained logistic-regression weights (model moved to CPU first)
# under pretrained_models/, named after the dataset.
checkpoint_path = f'pretrained_models/{data_name}.pt'
torch.save(model.cpu().state_dict(), checkpoint_path)