In [None]:
import os
import random
import numpy as np
import torch

In [None]:
seed = 42

# Seed every RNG this notebook draws from. These must be *calls*: assigning
# (e.g. `random.seed = seed`) would replace the seeding function with an int,
# seed nothing, and break any later call to that function.
# PYTHONHASHSEED only affects interpreters started after this point (e.g. worker
# subprocesses); the current process's hash seed was fixed at startup.
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the RNG for all devices, including CUDA
# NOTE(review): bit-exact GPU reproducibility would additionally require deterministic cuDNN settings.
print('Random seed: ', seed)

In [None]:
# Run on a CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device: ', device)

In [None]:
# Assumes the notebook is started from a direct subdirectory of the project
# (e.g. notebooks/), so the project root is the working directory's parent.
PATH_ROOT = os.path.split(os.getcwd())[0]
PATH_DATA = os.path.join(PATH_ROOT, 'data')  # dataset download/cache location

print('Project root: ', PATH_ROOT)
print('Project data: ', PATH_DATA)

In [None]:
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms

In [None]:
class CustomDataset(Dataset):
    """Wrap a map-style dataset so every item also reports its own index.

    The wrapped dataset must return ``(data, target)`` pairs. This wrapper returns
    ``(data, target, idx)``, which lets per-sample statistics be keyed by row.

    Args:
        dataset: the dataset to wrap (stored as ``self.cifar10``).
    """

    def __init__(self, dataset):
        super().__init__()
        # Attribute name kept as-is; any (data, target) dataset can be wrapped.
        self.cifar10 = dataset

    def __len__(self):
        return len(self.cifar10)

    def __getitem__(self, idx):
        sample, label = self.cifar10[idx]
        return sample, label, idx

In [None]:
mean, std = 0.5, 0.5

# Normalize each of the three RGB channels as (x - mean) / std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((mean,) * 3, (std,) * 3),
    ]
)

# CIFAR-10 training split, stored under PATH_DATA (downloaded if not already present).
# The wrapper makes each item also carry its dataset index.
cifar10_dataset = datasets.CIFAR10(PATH_DATA, train=True, transform=transform, download=True)
custom_dataset = CustomDataset(cifar10_dataset)

In [None]:
# Loader settings; the per-party loaders built later reuse these values.
BATCH_SIZE = 1 << 6  # 64
NUM_WORKERS = 0  # 0 = load batches in the main process
print('Batch size: ', BATCH_SIZE)
print('Num workers: ', NUM_WORKERS)

trainloader = DataLoader(
    custom_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    With padding 1 the spatial size is preserved at ``stride=1``. Otherwise it
    becomes ceil(H / stride). The bias is omitted because every use in this
    notebook is followed by BatchNorm.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + BatchNorm layers (used by ResNet18).

    The first convolution carries ``stride``. When the stride or the channel count
    changes, the identity shortcut is replaced by a strided 1x1 convolution +
    BatchNorm projection, so the two branches can be summed.

    Args:
        in_planes: channels of the incoming tensor.
        planes: channels produced by the block (times ``expansion``, which is 1).
        stride: stride of the first 3x3 convolution and of the projection.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        width_out = self.expansion * planes
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == width_out:
            self.shortcut = nn.Sequential()  # identity
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + self.shortcut(x))

In [None]:
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The block outputs ``expansion * planes`` (= 4 * planes) channels, and the 3x3
    convolution carries ``stride``. When the stride or the channel count changes,
    the identity shortcut is replaced by a strided 1x1 convolution + BatchNorm
    projection.

    Args:
        in_planes: channels of the incoming tensor.
        planes: inner (bottleneck) width.
        stride: stride of the 3x3 convolution and of the projection.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        width_out = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        if stride == 1 and in_planes == width_out:
            self.shortcut = nn.Sequential()  # identity
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return F.relu(y + self.shortcut(x))

In [None]:
class ResNet(nn.Module):
    """CIFAR-style ResNet: a 3x3 stem followed by four residual stages and a linear head.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``). It must expose
            an ``expansion`` attribute and accept ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages (length-4 sequence).
        num_classes: size of the output logit vector.

    The stem expects 3-channel NCHW input. Stages 2-4 each halve the spatial size.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        # Channel count fed to the next block; _make_layer advances it as stages are built.
        self.in_planes = 64

        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks; only the first applies ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pool to 1x1. For 32x32 inputs layer4 is 4x4, so this equals a
        # fixed avg_pool2d(out, 4). Unlike a fixed kernel, it averages the whole map for
        # any input size and always yields the 512 * expansion features `linear` expects.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [None]:
def ResNet18(num_classes=10):
    """Build ResNet-18: four stages of two ``BasicBlock``s each.

    Args:
        num_classes: size of the output logit vector.
    """
    stage_depths = [2] * 4
    return ResNet(BasicBlock, stage_depths, num_classes)

In [None]:
# Multi-class cross-entropy on raw logits. The loss has no learnable parameters;
# it is moved to `device` only for symmetry with the model.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [None]:
import time
import pandas as pd
import torch.optim as optim

from datetime import timedelta

### --- Resume from Here ---

In [None]:
# One bookkeeping row per training sample: its dataset index and label, plus
# integer counters (all starting at 0) that the training loops fill in.
d = {'indices': np.arange(len(custom_dataset)),
     'targets': cifar10_dataset.targets}
df = pd.DataFrame(d).assign(predictions=0, learn=0, forget=0, forgettable=0)

display(df.head())

# Independent copy so the federated run starts from the same zeroed counters.
df_fed = df.copy()

In [None]:
# SGD hyper-parameters; the federated parties' optimizers reuse them.
LR = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
NESTEROV = True
# A MultiStepLR schedule (milestones [60, 120, 160], gamma 0.2) is disabled, so the
# learning rate stays constant at LR.

model = ResNet18().to(device)
optimizer = optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
    nesterov=NESTEROV,
)
# scheduler = MultiStepLR(optimizer, milestones=MILESTONES, gamma=GAMMA)

In [None]:
def train(model, optimizer, criterion, images, labels):
    """Take one optimizer step on a single mini-batch.

    Args:
        model: network to update; it is put into training mode.
        optimizer: optimizer over the model's parameters.
        criterion: loss called as ``criterion(outputs, labels)``.
        images: input batch, already on the model's device.
        labels: targets for ``images``.

    Returns:
        The same ``model`` object, updated in place.
    """
    model.train()
    model.zero_grad()
    batch_loss = criterion(model(images), labels)
    batch_loss.backward()
    optimizer.step()
    return model

In [None]:
def record(df, indices, preds, learn):
    """Store batch predictions in ``df`` and count learning or forgetting events.

    Args:
        df: bookkeeping frame whose index is the dataset row, with integer columns
            ``targets``, ``predictions``, ``learn`` and ``forget``. Updated in place.
        indices: dataset row indices of the batch (torch tensor, ndarray or list).
        preds: predicted class per sample, aligned with ``indices`` (torch tensor).
        learn: behaviour depends on the flag.
            - True: batch rows that are now correct and have never been marked
              learned get ``learn`` incremented, so ``learn`` is at most 1.
            - False: batch rows already learned that are now misclassified get
              ``forget`` incremented.
    """
    # The DataLoader yields indices as a torch tensor. Convert them to a plain NumPy
    # array so pandas label lookups and Index.isin see ordinary integers. Otherwise
    # isin may treat the elements as opaque 0-d tensor objects and match nothing.
    if torch.is_tensor(indices):
        indices = indices.detach().cpu().numpy()
    idx = np.asarray(indices)

    df.loc[idx, 'predictions'] = preds.detach().to('cpu').numpy()

    in_batch = df.index.isin(idx)
    correct = df['targets'] == df['predictions']
    if learn:
        df.loc[in_batch & (df['learn'] == 0) & correct, 'learn'] += 1
    else:
        df.loc[in_batch & (df['learn'] > 0) & ~correct, 'forget'] += 1

In [None]:
def test(model, images, labels, indices, df, learn):
    """Score one mini-batch in eval mode and log the predictions via ``record``.

    Args:
        model: network to evaluate. It is left in eval mode on return.
        images: input batch, already on the model's device.
        labels: targets for ``images``. Not used here, because ``record`` compares
            against ``df['targets']``; kept for interface compatibility.
        indices: dataset row indices of the samples in ``images``.
        df: bookkeeping DataFrame, updated in place by ``record``.
        learn: True to log learning events, False to log forgetting events.
    """
    model.eval()
    # Inference only: no_grad skips building an autograd graph, which saves memory
    # and time without changing the predictions.
    with torch.no_grad():
        outputs = model(images)
    _, preds = torch.max(outputs, 1)
    preds = preds.view(-1)

    record(df, indices, preds, learn)

In [None]:
# Number of passes over the full training set for the centralized baseline.
EPOCHS = 100

st = time.time()
for ep in range(EPOCHS):
    print(f'[Epoch {ep + 1} / {EPOCHS}]')
    
    # Every mini-batch of this epoch, kept (as produced by the loader) in loader order
    # so the next epoch can re-score the batch that sat at the same position.
    # NOTE(review): this holds a full copy of the epoch's image tensors in memory.
    curr_images = []
    curr_labels = []
    curr_indices = []
    
    for batch_idx, (images, labels, indices) in enumerate(trainloader):
        curr_images.append(images)
        curr_labels.append(labels)
        curr_indices.append(indices)

        images, labels = images.to(device), labels.to(device)
        # One SGD step, then score the same batch. Rows that are now correct and were
        # never learned get a learning event.
        model = train(model, optimizer, criterion, images, labels)
        test(model, images, labels, indices, df, learn=True)
        if ep > 0:
            # Re-score the batch stored at the same position last epoch. trainloader
            # shuffles, so these are generally different samples. Previously-learned
            # rows that are now misclassified get a forgetting event. This relies on
            # every epoch yielding the same number of batches.
            test(model, prev_images[batch_idx].to(device), prev_labels[batch_idx].to(device), prev_indices[batch_idx], df, learn=False)
            
    # Shallow copies: the curr_* lists are rebuilt each epoch, so this just keeps
    # this epoch's batches for the next one.
    prev_images = curr_images.copy()
    prev_labels = curr_labels.copy()
    prev_indices = curr_indices.copy()
    
    print(f'| Learning Events: {df["learn"].sum()} | Forgetting Events: {df["forget"].sum()}')
    print(f'|-- Elapsed time: {timedelta(seconds=time.time()-st)}')

In [None]:
# A sample counts as forgettable if it was forgotten at least once, or if it was never
# learned at all (no learning and no forgetting events recorded).
df.loc[(df['forget'] > 0) | ((df['learn'] == 0) & (df['forget'] == 0)), 'forgettable'] = 1

print('Number of Forgettable Samples')
df['forgettable'].value_counts()

In [None]:
def chunks(lst, n):
    """Yield consecutive slices of ``lst`` of length ``n``; the last one may be shorter."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

In [None]:
NUM_PARTIES = 10

indices = list(range(len(custom_dataset)))
random.shuffle(indices)

# Split the shuffled indices into exactly NUM_PARTIES contiguous, near-equal shards.
# The first `remainder` parties get one extra sample, so every index is assigned even
# when the dataset size is not divisible by NUM_PARTIES. Chunking by
# int(len / NUM_PARTIES) would instead produce an extra short shard that the
# federated loop, which iterates range(NUM_PARTIES), never trains on.
# For an evenly divisible dataset this gives the same shards as fixed-size chunking.
shard_size, remainder = divmod(len(indices), NUM_PARTIES)
parties = []
start = 0
for party_id in range(NUM_PARTIES):
    stop = start + shard_size + (1 if party_id < remainder else 0)
    parties.append(indices[start:stop])
    start = stop

# One shuffling loader per party, over that party's shard only.
trainloaders = []
for p in parties:
    train_subset = Subset(custom_dataset, p)
    trainloaders.append(
        DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    )

In [None]:
import copy

In [None]:
def average_weights(w):
    """Return the element-wise mean of several state dicts (equal-weight FedAvg).

    Args:
        w: non-empty list of state dicts with identical keys and tensor shapes,
            e.g. one ``model.state_dict()`` per party. The inputs are not modified.

    Returns:
        A deep copy of ``w[0]`` whose every entry is the mean of the corresponding
        entries across ``w``. Each entry keeps its original dtype, so integer buffers
        (such as BatchNorm's ``num_batches_tracked`` counter) stay integers and the
        fractional part of their mean is dropped.
    """
    w_avg = copy.deepcopy(w[0])
    n = float(len(w))
    for key in w_avg.keys():
        total = w_avg[key]
        for other in w[1:]:
            total += other[key]
        # torch.div returns a floating tensor even for integer inputs; cast back so the
        # averaged state dict has the same dtypes as the originals.
        w_avg[key] = torch.div(total, n).to(total.dtype)
    return w_avg

In [None]:
# One federated round: every party trains a copy of the global model for one pass over
# its own shard. The global model is then replaced by the plain average of the party weights.
ROUNDS = EPOCHS

fed_model = ResNet18().to(device)
# Overwritten by the averaged weights at the end of every round.
fed_weights = fed_model.state_dict()

# Per-party batch caches keyed by party id: this round's batches, and last round's
# batches (re-scored to detect forgetting, mirroring the centralized loop).
curr_images, curr_labels, curr_indices = {}, {}, {}
prev_images, prev_labels, prev_indices = {}, {}, {}

st = time.time()
for r in range(ROUNDS):
    print(f'[Round {r + 1} / {ROUNDS}]')
    local_weights = []
    
    for i in range(NUM_PARTIES):
        curr_images[i] = []
        curr_labels[i] = []
        curr_indices[i] = []
        
        # Fresh local copy and a fresh optimizer each round, so optimizer state does not
        # carry over between rounds.
        # NOTE(review): this rebinds the notebook-level `optimizer` used by the centralized run.
        local_model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(local_model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY, nesterov=NESTEROV)

        for batch_idx, (images, labels, indices) in enumerate(trainloaders[i]):
            curr_images[i].append(images)
            curr_labels[i].append(labels)
            curr_indices[i].append(indices)

            images, labels = images.to(device), labels.to(device)
            # Events are scored with the party's local model right after its step and
            # logged into df_fed.
            local_model = train(local_model, optimizer, criterion, images, labels)
            test(local_model, images, labels, indices, df_fed, learn=True)
            if r > 0:
                # Re-score the batch this party had at the same position last round.
                test(local_model, prev_images[i][batch_idx].to(device), prev_labels[i][batch_idx].to(device), prev_indices[i][batch_idx], df_fed, learn=False)

        prev_images[i] = curr_images[i].copy()
        prev_labels[i] = curr_labels[i].copy()
        prev_indices[i] = curr_indices[i].copy()
        
        # Deep copy so later parties' training cannot alias this snapshot.
        local_weights.append(copy.deepcopy(local_model.state_dict()))
        print('|---- [Party {:>2}] Complete'.format(i + 1))
        
    # Equal-weight average of every state-dict entry (parameters and buffers alike).
    fed_weights = average_weights(local_weights)
    fed_model.load_state_dict(fed_weights)

    print(f'| Learning Events: {df_fed["learn"].sum()} | Forgetting Events: {df_fed["forget"].sum()}')
    print(f'|-- Elapsed time: {timedelta(seconds=time.time()-st)}')

In [None]:
# A sample counts as forgettable if it was forgotten at least once, or if it was never
# learned at all (no learning and no forgetting events recorded).
df_fed.loc[(df_fed['forget'] > 0) | ((df_fed['learn'] == 0) & (df_fed['forget'] == 0)), 'forgettable'] = 1

print('Number of Forgettable Samples')
df_fed['forgettable'].value_counts()

In [None]:
# Forgettable counts for the centralized run (df), then the federated run (df_fed).
for frame in (df, df_fed):
    print(frame['forgettable'].value_counts())