In [1]:
from torchvision.datasets import CIFAR10
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SPLIT=True
MODEL_ID=0
BATCH_SIZE=2**9
NUM_EPOCHS=100

In [3]:
torch.manual_seed(MODEL_ID)
np.random.seed(MODEL_ID)
torch.use_deterministic_algorithms(False)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
train_transform = transforms.Compose([
    transforms.RandomAffine(0, scale=[0.8, 1.2]),
    transforms.RandomCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(30),
    transforms.ToTensor()
])

valid_transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = CIFAR10('.', train=True, transform=train_transform, download=True)
valid_dataset = CIFAR10('.', train=False, transform=valid_transform)
def get_non_iid_indices(labels):
    condition = labels < 5 if MODEL_ID == 1 else labels >= 5
    return condition.nonzero()[0]
def get_iid_indices(labels):
    indices = np.random.permutation(len(labels))
    indices = np.array_split(indices, 2)[MODEL_ID]
    return indices

if SPLIT:
    labels = np.array(train_dataset.targets)
    indices = get_iid_indices(labels)
    train_dataset = Subset(train_dataset, indices)


train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)

Files already downloaded and verified


In [6]:
from torchvision.models import resnet18

class Conv2dLayerNorm(nn.LayerNorm):        
    def forward(self, x):
        transposed = x.transpose(1,3)
        result = super().forward(transposed)
        
        return result.transpose(3,1)
    
def gn(channels):
    return nn.GroupNorm(32, channels)

model = resnet18(num_classes=10,norm_layer=gn).to(device)    

In [7]:
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = SGD(model.parameters(), momentum=0.9, lr=5e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)

In [8]:
from torch import nn
loss_fn = nn.CrossEntropyLoss()

In [None]:
import torch
import torch.nn.functional as F
def train():
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        outputs = model(data)
        loss = loss_fn(outputs, target)
        loss.backward()
        optimizer.step()
    scheduler.step()
    return loss.item()

def validate():
    model.eval()
    total_samples = 0
    val_score = 0
    for data, target in valid_loader:
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        outputs = model(data)
        loss = loss_fn(outputs, target)
        pred = outputs.argmax(dim=1)
        val_score += pred.eq(target).sum().cpu().numpy()
        total_samples += len(target)
    return loss.item(), val_score / total_samples

for epoch in range(NUM_EPOCHS):
    train_loss = train()
    val_loss, acc = validate()
    print(f'Epoch {epoch+1}:\tTrain loss: {train_loss:4f}\tTest loss: {val_loss:4f}\tTest accuracy: {acc:4f}')

Epoch 1:	Train loss: 1.944900	Test loss: 1.895434	Test accuracy: 0.332800
Epoch 2:	Train loss: 1.719854	Test loss: 1.686871	Test accuracy: 0.400700
Epoch 3:	Train loss: 1.642070	Test loss: 1.564700	Test accuracy: 0.439600
Epoch 4:	Train loss: 1.580784	Test loss: 1.511770	Test accuracy: 0.467000
Epoch 5:	Train loss: 1.544345	Test loss: 1.450397	Test accuracy: 0.479900
Epoch 6:	Train loss: 1.353594	Test loss: 1.456356	Test accuracy: 0.489600
Epoch 7:	Train loss: 1.335266	Test loss: 1.436106	Test accuracy: 0.506200
Epoch 8:	Train loss: 1.438748	Test loss: 1.347981	Test accuracy: 0.531200
Epoch 9:	Train loss: 1.280161	Test loss: 1.367846	Test accuracy: 0.532500
Epoch 10:	Train loss: 1.313264	Test loss: 1.364547	Test accuracy: 0.530900
Epoch 11:	Train loss: 1.314696	Test loss: 1.368729	Test accuracy: 0.542500
Epoch 12:	Train loss: 1.176945	Test loss: 1.295809	Test accuracy: 0.563400
Epoch 13:	Train loss: 1.177987	Test loss: 1.300059	Test accuracy: 0.564600
Epoch 14:	Train loss: 1.239163	Tes

In [None]:
import pickle
with open(f'model_{MODEL_ID}.pkl','wb') as f:
    pickle.dump({n: p.numpy(force=True) for n,p in model.state_dict().items()}, f) 

In [None]:
for n,p in model.state_dict().items():
    print(f'{n}: {p.size()}')

In [None]:
torch.tensor([64.3]).numpy(force=True)