In [1]:
from torchvision.datasets import CIFAR10
import numpy as np
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, Subset
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SPLIT=False
MODEL_ID=2
BATCH_SIZE=2**10
SEED=1

In [3]:
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.use_deterministic_algorithms(False)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
train_dataset = CIFAR10('.', train=True, transform=ToTensor(), download=True)
valid_dataset = CIFAR10('.', train=False, transform=ToTensor())

if SPLIT:
    labels = np.array(train_dataset.targets)
    condition = labels < 5 if MODEL_ID == 1 else labels >= 5
    train_dataset = Subset(train_dataset, condition.nonzero()[0])


train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)

Files already downloaded and verified


In [6]:
from torchvision.models import resnet18

class Conv2dLayerNorm(nn.LayerNorm):        
    def forward(self, x):
        transposed = x.transpose(1,3)
        result = super().forward(transposed)
        
        return result.transpose(3,1)

model = resnet18(num_classes=10, norm_layer=Conv2dLayerNorm).to(device)            

In [7]:
from torch.optim import SGD
optimizer = SGD(model.parameters(), lr=1e-3)

In [8]:
from torch import nn
loss_fn = nn.CrossEntropyLoss()

In [9]:
import torch
import torch.nn.functional as F
def train():
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        outputs = model(data)
        loss = loss_fn(outputs, target)
        loss.backward()
        optimizer.step()
    return loss.item()

def validate():
    model.eval()
    total_samples = 0
    val_score = 0
    for data, target in valid_loader:
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        outputs = model(data)
        pred = outputs.argmax(dim=1)
        val_score += pred.eq(target).sum().cpu().numpy()
        total_samples += len(target)
    return val_score / total_samples

for epoch in range(20):
    loss = train()
    acc = validate()
    print(f'Epoch {epoch+1}:\tTrain loss:{loss:4f}\tTest accuracy:{acc:4f}')

Epoch 1:	Train loss:2.189628	Test accuracy:0.241500
Epoch 2:	Train loss:2.057933	Test accuracy:0.273200
Epoch 3:	Train loss:2.004574	Test accuracy:0.295100
Epoch 4:	Train loss:1.841234	Test accuracy:0.341200
Epoch 5:	Train loss:1.825524	Test accuracy:0.359700
Epoch 6:	Train loss:1.842961	Test accuracy:0.343200
Epoch 7:	Train loss:1.855386	Test accuracy:0.352000
Epoch 8:	Train loss:1.798304	Test accuracy:0.356200
Epoch 9:	Train loss:1.744208	Test accuracy:0.372300
Epoch 10:	Train loss:1.741127	Test accuracy:0.386500
Epoch 11:	Train loss:1.637662	Test accuracy:0.407900
Epoch 12:	Train loss:1.650924	Test accuracy:0.408700
Epoch 13:	Train loss:1.670300	Test accuracy:0.417500
Epoch 14:	Train loss:1.588423	Test accuracy:0.424600
Epoch 15:	Train loss:1.631279	Test accuracy:0.414400
Epoch 16:	Train loss:1.539430	Test accuracy:0.433900
Epoch 17:	Train loss:1.470001	Test accuracy:0.452900
Epoch 18:	Train loss:1.606215	Test accuracy:0.408000
Epoch 19:	Train loss:1.529489	Test accuracy:0.463200
Ep

In [10]:
import pickle
with open(f'model_{MODEL_ID}.pkl','wb') as f:
    pickle.dump({n: p.numpy(force=True) for n,p in model.named_parameters()}, f)    