In [2]:
import os
import requests
import random
import math
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection
from tqdm import tqdm
import gc

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import prune
from torch import optim
from torch.utils.data import DataLoader, TensorDataset

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", DEVICE.upper())

# manual random seed is used for dataset partitioning
# to ensure reproducible results across runs
SEED = 42
RNG = torch.Generator().manual_seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

import sys
sys.path.append('../')
from utils import *

Running on device: CUDA


In [3]:
import ssl

# Create an unverified SSL context
ssl._create_default_https_context = ssl._create_unverified_context

In [4]:
batch_size = 64

# download and pre-process CIFAR10
normalize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

train_set = torchvision.datasets.CIFAR10(
    root="../example notebooks/data", train=True, download=False, transform=normalize
)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)

# we split held out data into test and validation set
held_out = torchvision.datasets.CIFAR10(
    root="../example notebooks/data", train=False, download=False, transform=normalize
)
test_set, val_set = torch.utils.data.random_split(held_out, [0.5, 0.5], generator=RNG)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=1)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=1)

# download the forget and retain index split
local_path = "../example notebooks/forget_idx.npy"
# if not os.path.exists(local_path):
#     response = requests.get(
#         "https://storage.googleapis.com/unlearning-challenge/" + local_path
#     )
#     open(local_path, "wb").write(response.content)
forget_idx = np.load(local_path)

# construct indices of retain from those of the forget set
forget_mask = np.zeros(len(train_set.targets), dtype=bool)
forget_mask[forget_idx] = True
retain_idx = np.arange(forget_mask.size)[~forget_mask]

# split train set into a forget and a retain set
forget_set = torch.utils.data.Subset(train_set, forget_idx)
retain_set = torch.utils.data.Subset(train_set, retain_idx)

forget_loader = torch.utils.data.DataLoader(
    forget_set, batch_size=batch_size, shuffle=False, drop_last=True
)
retain_loader = torch.utils.data.DataLoader(
    retain_set, batch_size=batch_size, shuffle=True, generator=RNG, drop_last=True
)

In [8]:
# load model with pre-trained weights
net = resnet18(weights=None, num_classes=10)
net.to(DEVICE);

epochs = 50
val_loss = np.inf


current_batch = 0
total_samples = len(retain_loader.dataset)
batch_size = retain_loader.batch_size
batches_per_epoch  = math.ceil(total_samples / batch_size)
total_batches = epochs * batches_per_epoch
initial_lr = 0.01
warmup_batches = math.ceil(10*batches_per_epoch)


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=initial_lr, momentum=0.9, weight_decay=5e-3)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
net.train()

for ep in range(epochs):

    net.train()

    for inputs, targets in retain_loader:
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        current_batch += 1

        # Warm-up for the first 'warmup_batches' batches
        if current_batch <= warmup_batches:
            adjust_learning_rate(optimizer, current_batch, warmup_batches, initial_lr)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        nn.utils.clip_grad_value_(net.parameters(), 10)

        optimizer.step()


    net.eval()  # handle drop-out/batch norm layers
    loss = 0
    with torch.no_grad():
        for x,y in val_loader:
            out = net(x.to(DEVICE))  # only forward pass - NO gradients!!
            loss += criterion(out, y.to(DEVICE))
    # total loss - divide by number of batches
    temp_loss = loss / len(val_loader)

    print('--------'*5)
    print(f'Epoch: {ep}')
    print(f'Val loss: {temp_loss}')
    
    val_acc = accuracy(net, val_loader)
    # print(f"Retain set accuracy: {100.0 * accuracy(net, retain_loader):0.1f}%")
    # print(f"Forget set accuracy: {100.0 * accuracy(net, forget_loader):0.1f}%")
    print(f"Val set accuracy: {100.0 * val_acc:0.1f}%")
    # print(f"Test set accuracy: {100.0 * accuracy(net, test_loader):0.1f}%")

    # if temp_loss < val_loss:
    #     val_loss = temp_loss
    
    scheduler.step(temp_loss)

    gc.collect()
    torch.cuda.empty_cache()

torch.save({
    'net': net.state_dict(),
}, f'./checkpoints/checkpoint25.pth')
# }, f'../example notebooks/weights/internal_weights_resnet18_cifar10.pth')

----------------------------------------
Epoch: 0
Val loss: 1.5058852434158325
Val set accuracy: 44.5%
----------------------------------------
Epoch: 1
Val loss: 1.2743005752563477
Val set accuracy: 52.6%
----------------------------------------
Epoch: 2
Val loss: 1.16933274269104
Val set accuracy: 57.7%
----------------------------------------
Epoch: 3
Val loss: 1.1101582050323486
Val set accuracy: 62.1%
----------------------------------------
Epoch: 4
Val loss: 1.0009515285491943
Val set accuracy: 64.8%
----------------------------------------
Epoch: 5
Val loss: 0.9413677453994751
Val set accuracy: 67.4%
----------------------------------------
Epoch: 6
Val loss: 0.932471752166748
Val set accuracy: 67.1%
----------------------------------------
Epoch: 7
Val loss: 0.8454344868659973
Val set accuracy: 70.5%
----------------------------------------
Epoch: 8
Val loss: 0.8583778738975525
Val set accuracy: 70.4%
----------------------------------------
Epoch: 9
Val loss: 0.91984957456588