In [None]:
%load_ext autoreload
%autoreload 2
import tinygrad
from tinygrad.nn import Conv2d, BatchNorm2d, Tensor, optim
import tinygrad.nn as nn
from tinygrad.nn.state import get_state_dict, get_parameters
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import trange
import matplotlib.pyplot as plt
import pickle
import numpy as np
try:
  from helpers import get_model_size, estimate_loss, normalize_tensor
except ModuleNotFoundError:
  import sys
  sys.path.append("../")
  from helpers import get_model_size, estimate_loss, normalize_tensor

In [None]:
class CiFaData(Dataset):
  """One CIFAR-10 split, held fully in memory as tinygrad tensors.

  Args:
    stage: "train" (data_batch_1..4), "val" (data_batch_5) or "test" (test_batch).
    transform: optional callable applied to every image returned by ``__getitem__``.
    device: device on which both the image and the label tensor are created.

  Raises:
    ValueError: if ``stage`` is not one of "train", "val" or "test".
  """
  def __init__(self, stage="train", transform=None, device="cpu"):
    self.device = device
    self.base_folder = "cifar-10-batches-py"
    self.transform = transform
    if stage == "train":
      # batches 1-4 are used for training; batch 5 is held out for validation
      batch_collection = [f"data_batch_{i}" for i in range(1, 5)]
    elif stage == "val":
      batch_collection = ["data_batch_5"]
    elif stage == "test":
      batch_collection = ["test_batch"]
    else:
      raise ValueError("Invalid stage, choose from train, val, test.")
    images = []
    labels = []
    for batch in batch_collection:
      # NOTE: pickle can execute arbitrary code while loading; only point this at the trusted CIFAR-10 archive.
      with open(f"../data/{self.base_folder}/{batch}", "rb") as f:
        data = pickle.load(f, encoding="latin1")
      images.extend(data["data"])
      labels.extend(data["labels"])
    self.y_data = Tensor(labels, device=self.device)
    # stack the flat rows and reshape to (N, 3, 32, 32); created on the same device as the labels
    pixels = np.vstack(images).reshape(-1, 3, 32, 32)
    # normalize_tensor is a project helper; presumably rescales the pixel values -- confirm in helpers
    self.x_data = normalize_tensor(Tensor(pixels, device=self.device))
  def __len__(self):
    """Number of samples in the split."""
    return self.y_data.shape[0]
  def __getitem__(self, idx):
    """Return the ``(image, label)`` pair at ``idx``, with ``transform`` applied to the image if one was given."""
    x = self.x_data[idx]
    if self.transform is not None:
      x = self.transform(x)
    return x, self.y_data[idx]
  

In [None]:
# Use the GPU when torch can see one, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# One dataset per split, all created on the same device.
train_ds, val_ds, test_ds = (CiFaData(stage=split, device=device) for split in ("train", "val", "test"))

# Only the training loader shuffles; evaluation loaders keep a fixed order.
# NOTE(review): CiFaData yields tinygrad Tensors -- confirm torch's default collate can batch them,
# otherwise a custom collate_fn is required here.
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
val_loader, test_loader = (DataLoader(ds, batch_size=256, shuffle=False) for ds in (val_ds, test_ds))

In [None]:
# Sanity check: print the inputs of the first training batch, then stop iterating.
for first_batch in train_loader:
  print(first_batch[0])
  break

In [None]:
class InitBlock:
  """Network stem: 7x7 stride-2 convolution (3 -> 64 channels), batch norm, ReLU, max pool."""
  def __init__(self):
    self.conv1 = nn.Conv2d(in_channels=3, kernel_size=7, out_channels=64, stride=2, padding=3)
    self.bnorm = nn.BatchNorm2d(64)
  def __call__(self, x):
    """Run the stem on ``x`` and return the pooled feature map."""
    x = self.conv1(x)
    # the batch-norm layer was created but previously never applied
    x = self.bnorm(x)
    # max_pool2d() is called with its default window, as before
    return x.relu().max_pool2d()

class SubBlock:
  """Two conv + batch-norm layers with a ReLU in between (no activation after the second norm).

  Args:
    inchannels: channels of the incoming feature map.
    outchannels: channels produced by both convolutions.
    stride: stride of the first convolution; the second always uses stride 1.
    kernelsize: kernel size of both convolutions.
    padding: padding of both convolutions.
  """
  def __init__(self, inchannels, outchannels, stride, kernelsize=3, padding=1):
    self.conv1 = nn.Conv2d(in_channels=inchannels, out_channels=outchannels, kernel_size=kernelsize, padding=padding, stride=stride)
    self.bnorm1 = nn.BatchNorm2d(outchannels)
    self.conv2 = nn.Conv2d(in_channels=outchannels, out_channels=outchannels, kernel_size=kernelsize, padding=padding, stride=1)
    self.bnorm2 = nn.BatchNorm2d(outchannels)
  def __call__(self, x):
    """Apply conv1 -> bnorm1 -> ReLU -> conv2 -> bnorm2 and return the result."""
    x = self.conv1(x)
    x = self.bnorm1(x).relu()
    x = self.conv2(x)
    # call the layer (it was previously only assigned) and return the output (was implicitly None)
    x = self.bnorm2(x)
    return x

class Resblock:
  """Two SubBlocks chained back to back; only the first one may change channels or stride."""
  def __init__(self, inchannels, outchannels, stride):
    self.block1 = SubBlock(inchannels=inchannels, outchannels=outchannels, stride=stride)
    self.block2 = SubBlock(inchannels=outchannels, outchannels=outchannels, stride=1)
  def __call__(self, x):
    """Feed ``x`` through block1 and then block2."""
    return self.block2(self.block1(x))

In [None]:
class ResNet18:
  """ResNet-18-style classifier: stem, four residual stages (64/128/256/512 channels), 10-way linear head.

  Stages 2-4 halve the spatial size and change the channel count, so their skip
  path goes through a 1x1 stride-2 ``matchdim`` convolution to match the main path.
  """
  def __init__(self):
    self.block0 = InitBlock()
    self.block1 = Resblock(64, 64, stride=1)
    self.matchdim1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, stride=2)
    self.block2 = Resblock(64, 128, stride=2)
    self.matchdim2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, stride=2)
    self.block3 = Resblock(128, 256, stride=2)
    self.matchdim3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1, stride=2)
    self.block4 = Resblock(256, 512, stride=2)
    self.fc = nn.Linear(512, 10)
  def __call__(self, x):
    """Return the 10 class scores for a batch of images ``x``."""
    # NOTE(review): block0 already ends with relu + max_pool2d, so the stem output is pooled twice here -- confirm this is intended.
    x_skip = self.block0(x).relu().max_pool2d(kernel_size=3, stride=2) # no padding option
    x = (self.block1(x_skip) + x_skip).relu()
    stages = (
      (self.matchdim1, self.block2),
      (self.matchdim2, self.block3),
      (self.matchdim3, self.block4),
    )
    for matchdim, block in stages:
      # Both paths start from the same stage input: the block expects the un-projected channel
      # count, while the 1x1 conv projects the skip path to the block's output shape.
      x_skip = matchdim(x)
      x = (block(x) + x_skip).relu()
    # Global average over the spatial axes always yields 512 features, whatever the final map size.
    x = x.mean(axis=(2, 3))
    return self.fc(x)

In [None]:
# Hyper-parameters
epochs = 10
lr = 1e-5

# Model and optimizer (tinygrad); AdamW updates every tensor returned by get_parameters.
res18 = ResNet18()
params = get_parameters(res18)
optimizer = optim.AdamW(params=params, lr=lr)

def criterion(logits, labels):
  """Cross-entropy between raw class scores and integer class labels.

  Uses tinygrad's own loss: the model outputs tinygrad Tensors, which
  torch.nn.CrossEntropyLoss cannot consume.
  """
  return logits.sparse_categorical_crossentropy(labels)

In [None]:
# training loop
losses = []      # mean training loss of each epoch
raw_losses = []  # training loss of every optimisation step
val_losses = []  # validation loss after each epoch

for i in (t:=trange(epochs)):
  epoch_loss = []
  # Training mode must be on for the parameter updates; it is left again before validation.
  with Tensor.train():
    for x, y in train_loader:
      predictions = res18(x)
      loss = criterion(predictions, y)
      optimizer.zero_grad()  # tinygrad's zero_grad takes no arguments
      loss.backward()
      optimizer.step()
      step_loss = loss.item()  # realise the loss once and reuse the float
      raw_losses.append(step_loss)
      epoch_loss.append(step_loss)

  # only one per iteration
  losses.append(np.mean(epoch_loss))
  val_losses.append(estimate_loss(res18, val_loader, criterion))
  # tinygrad optimizers expose the learning rate as the `lr` tensor (there is no `param_groups`)
  t.set_description(f"epoch {i+1} | training loss: {losses[-1]:.4f} | validation loss: {val_losses[-1]:.4f} | current lr: {optimizer.lr.item():.6f}")

test_loss = estimate_loss(res18, test_loader, criterion)
print(f'final test loss is : {test_loss}')