In [None]:
%load_ext autoreload
%autoreload 2

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import pickle
from tqdm import trange
try:
  from helpers import get_model_size, estimate_loss, get_parameters, CiFaData
except ModuleNotFoundError:
  import sys
  sys.path.append("../")
  from helpers import get_model_size, estimate_loss, get_parameters, CiFaData

torch.manual_seed(42)

In [None]:
# Training hyper-parameters
EPOCHS = 100            # full passes over the training set
BATCH_SIZE = 256        # samples per mini-batch
LR = 1e-1               # initial learning rate handed to the optimizer
MOMENTUM = 0.875        # SGD momentum
WEIGHT_DECAY = 1.25e-3  # weight decay (L2 penalty) handed to the optimizer

In [None]:
# Run once to download CIFAR-10 into ../data/ (presumably where CiFaData reads it from);
# the returned dataset objects themselves are not used.
# torchvision.datasets.CIFAR10(train=True, download=True, root='../data/', transform=transforms.ToTensor())
# torchvision.datasets.CIFAR10(train=False, download=True, root='../data/', transform=transforms.ToTensor())

# Training-time augmentation only. Normalization is presumably applied inside CiFaData via
# `dataset_params` (TODO confirm), so no Normalize transform is added here.
# NOTE(review): p=0.58 differs from the usual 0.5 flip probability — confirm it is intended.
# NOTE(review): RandomResizedCrop's default scale=(0.08, 1.0) can crop down to 8% of a 32x32 image.
tf = transforms.Compose([transforms.RandomResizedCrop((32,32)),
                         transforms.RandomHorizontalFlip(p=0.58)])

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Presumably per-channel (mean, std) of the whole dataset. They were computed once with
#   params = get_parameters(DataLoader(CiFaData(stage="all"), batch_size=BATCH_SIZE, shuffle=False, num_workers=14))
# and hard-coded here to skip that full pass over the data.
params = torch.tensor([0.4919, 0.4827, 0.4472]), torch.tensor([0.2470, 0.2434, 0.2616])
print(f"normalized parameters of the dataset: {params}")

train_ds = CiFaData(stage="train", transform=tf, dataset_params=params)
val_ds = CiFaData(stage="val", dataset_params=params)
test_ds = CiFaData(stage="test", dataset_params=params)

NUM_WORKERS = 14
# pin_memory makes the loader return batches in page-locked host (CPU) memory. That speeds up
# host->GPU copies and allows them to be asynchronous, so it only helps when a GPU is used.
PIN_MEMORY = device == 'cuda'

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

## ResNet-18

In [None]:
# From the paper (He et al., 2015, "Deep Residual Learning for Image Recognition"):
# "We adopt batch normalization (BN) [16] right after each convolution and
# before activation, following [16]."

class PrepBlock(nn.Module):
  """Network stem: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool (3 -> 64 channels).

  Both stride-2 stages halve the spatial size, so a 32x32 input leaves this
  block as a 64-channel 8x8 feature map.
  """
  # fixed channels for cifar: 3 input channels, 64 output channels
  def __init__(self):
    super().__init__()
    stem_layers = [
      nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
      nn.BatchNorm2d(64),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    ]
    self.prep_block = nn.Sequential(*stem_layers)

  def forward(self, x):
    return self.prep_block(x)

  def init_weights(self):
    """Kaiming-normal (fan_out, ReLU gain) init for the stem conv; zero its bias if it has one."""
    convs = (m for m in self.prep_block if isinstance(m, nn.Conv2d))
    for conv in convs:
      nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
      if conv.bias is not None:
        nn.init.zeros_(conv.bias)

class ComputeBlock(nn.Module):
  """Basic residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, add the skip path, then ReLU.

  Args:
    inchannels: channels of the input feature map.
    outchannels: channels produced by the block.
    stride: stride of the first 3x3 conv; values > 1 shrink the spatial size.
    downsample: optional module applied to the skip path. When omitted and the
      block changes the channel count or the stride is not 1, a 1x1 conv + BN
      projection using the same stride is created, so the skip path has the same
      shape as the main path.
  """
  def __init__(self, inchannels, outchannels, stride, downsample=None):
    super().__init__()
    self.convblock = nn.Sequential(
      nn.Conv2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3, padding=1, stride=stride, bias=False),
      nn.BatchNorm2d(outchannels),
      nn.ReLU(),
      nn.Conv2d(in_channels=outchannels, out_channels=outchannels, kernel_size=3, padding=1, stride=1, bias=False),
      nn.BatchNorm2d(outchannels)
    )
    # The skip path needs a projection whenever the main path changes the channel count OR
    # the resolution, and it must use the block's own stride. Keying only on channels with a
    # fixed stride of 2 breaks `x + x_skip` for e.g. (64->64, stride=2) or (64->128, stride=1).
    # A caller-supplied projection is respected rather than overwritten.
    if downsample is None and (stride != 1 or inchannels != outchannels):
      downsample = nn.Sequential(
        nn.Conv2d(in_channels=inchannels, out_channels=outchannels, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(outchannels)
      )
    self.downsample = downsample
    self.relu = nn.ReLU()

  def forward(self, x):
    identity = x
    out = self.convblock(x)
    if self.downsample is not None:
      identity = self.downsample(identity)
    return self.relu(out + identity)

  def init_weights(self):
    """Kaiming-normal (fan_out, ReLU gain) init for every conv in the block; zero any conv bias."""
    containers = [self.convblock]
    if self.downsample is not None:
      containers.append(self.downsample)
    for container in containers:
      # .modules() also works for a bare (non-Sequential) downsample module
      for layer in container.modules():
        if isinstance(layer, nn.Conv2d):
          nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
          if layer.bias is not None:
            nn.init.zeros_(layer.bias)
  
class ResNet18New(nn.Module):
  """ResNet-style image classifier: PrepBlock stem + four stages of ComputeBlocks + linear head.

  Layout for a (B, 3, 32, 32) input with the default arguments:
    PrepBlock                    -> (B, 64, 8, 8)
    ComputeBlock(64, 64, s=1)    -> (B, 64, 8, 8)
    ComputeBlock(64, 128, s=2)   -> (B, 128, 4, 4)
    ComputeBlock(128, 256, s=2)  -> (B, 256, 2, 2)
    ComputeBlock(256, 512, s=2)  -> (B, 512, 1, 1)
    AdaptiveAvgPool2d + Flatten  -> (B, 512)
    Linear                       -> (B, num_classes)

  NOTE: with the default ``blocks_per_stage=(1, 1, 1, 1)`` the main path has 10 weighted
  layers (1 stem conv + 4 blocks x 2 convs + 1 linear), not 18. The ResNet-18 layout from
  He et al. (2015) uses two blocks per stage: ``ResNet18New(blocks_per_stage=(2, 2, 2, 2))``.

  Args:
    blocks_per_stage: number of ComputeBlocks in each of the four stages (64, 128, 256 and
      512 channels). The first block of stages 2-4 uses stride 2; every other block uses
      stride 1. The default keeps the original one-block-per-stage layout.
    num_classes: number of output logits. Defaults to 10 (CIFAR-10).

  Raises:
    ValueError: if ``blocks_per_stage`` does not contain exactly four positive counts.
  """
  _STAGE_CHANNELS = (64, 128, 256, 512)

  def __init__(self, blocks_per_stage=(1, 1, 1, 1), num_classes=10):
    super().__init__()
    blocks_per_stage = tuple(blocks_per_stage)
    if len(blocks_per_stage) != len(self._STAGE_CHANNELS) or any(n < 1 for n in blocks_per_stage):
      raise ValueError(f"blocks_per_stage must hold {len(self._STAGE_CHANNELS)} positive counts, got {blocks_per_stage}")
    layers = [PrepBlock()]
    in_ch = self._STAGE_CHANNELS[0]  # PrepBlock outputs 64 channels
    for stage_idx, (out_ch, n_blocks) in enumerate(zip(self._STAGE_CHANNELS, blocks_per_stage)):
      for block_idx in range(n_blocks):
        # Only the first block of stages 2-4 halves the resolution.
        stride = 2 if (stage_idx > 0 and block_idx == 0) else 1
        layers.append(ComputeBlock(in_ch, out_ch, stride=stride))
        in_ch = out_ch
    layers += [
      nn.AdaptiveAvgPool2d((1,1)),
      nn.Flatten(start_dim=1),
      nn.Linear(in_ch, num_classes),
    ]
    # One flat Sequential: with the default layout the module indices (resnet.0 = stem,
    # resnet.1-4 = blocks, resnet.7 = linear) match the original, so saved state_dicts still load.
    self.resnet = nn.Sequential(*layers)

  def forward(self, x):
    return self.resnet(x)

  def init_weights(self):
    """Initialize conv weights via each block's init_weights and the linear head via Kaiming normal."""
    for module in self.modules():
      if isinstance(module, (ComputeBlock, PrepBlock)):
        module.init_weights()
      elif isinstance(module, nn.Linear):
        # NOTE(review): Kaiming init with the ReLU gain and mode='fan_out' (fan_out = num_classes)
        # on the output layer gives a much larger weight std than PyTorch's default Linear init,
        # i.e. large initial logits. A small normal or the default init is more common here.
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

In [None]:
model = ResNet18New()
model.init_weights()
model.to(device)

optimizer = optim.SGD(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, momentum=MOMENTUM)
# NOTE: the scheduler is built but `scheduler.step()` is never called below, so the learning
# rate stays at LR for every epoch. Enabled as-is (step_size=5, gamma=0.1, stepped per epoch) it
# would cut the lr 10x every 5 epochs, down to LR * 1e-19 in the final epoch. Choose a schedule
# that fits EPOCHS (e.g. milestones or cosine) before turning it on.
scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.1)
criterion = nn.CrossEntropyLoss()

model_size = get_model_size(model)

# training loop
losses = []      # mean training loss per epoch
raw_losses = []  # training loss of every mini-batch
val_losses = []  # validation loss per epoch
val_accs = []    # second value returned by estimate_loss, per epoch

for i in (t:=trange(EPOCHS)):
  model.train()
  epoch_loss = []
  for x, y in train_loader:
    # non_blocking lets the host->device copy overlap with compute when batches are pinned
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
    predictions = model(x)
    loss = criterion(predictions, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # .item() synchronizes with the device; read the value once per step
    batch_loss = loss.item()
    raw_losses.append(batch_loss)
    epoch_loss.append(batch_loss)

  # only one per iteration
  losses.append(np.mean(epoch_loss))
  val_loss, val_acc = estimate_loss(model, val_loader, criterion, device)
  val_losses.append(val_loss)
  val_accs.append(val_acc)
  # scheduler.step()  # intentionally disabled, see the note where the scheduler is created
  t.set_description(f"epoch {i+1} | training loss: {losses[-1]:.4f} | validation loss: {val_losses[-1]:.4f} | current lr: {optimizer.param_groups[0]['lr']:.6f}")

# Test-set evaluation is left out until the hyper-parameters are final.
# estimate_loss returns two values (loss, accuracy), so unpack both:
# test_loss, test_acc = estimate_loss(model, test_loader, criterion, device)
# print(f'test loss : {test_loss}')

# 1-based x-axis so the plot matches the epoch numbers shown in the progress bar
epochs_axis = range(1, EPOCHS + 1)
best_epoch = int(np.argmin(val_losses)) + 1

plt.figure()
plt.title(f'batchnorm  lr={LR}')
plt.plot(epochs_axis, losses, label='training')
plt.plot(epochs_axis, val_losses, label='validation')
plt.axhline(np.min(val_losses), color='r', label=f'minimum val loss at epoch {best_epoch}')
plt.xlabel('epoch')
plt.ylabel('cross-entropy loss')
plt.legend()
plt.show()

In [None]:
# inspect the graph
from torch.utils.tensorboard import SummaryWriter

# Trace with a validation batch rather than the `x` left over from the training loop, so this
# cell does not depend on loop state. add_graph runs a forward pass; eval() keeps BatchNorm
# from updating its running statistics during that pass.
sample_x, _ = next(iter(val_loader))
model.eval()
# the context manager closes the writer even if tracing fails
with SummaryWriter("torchlogs/") as writer:
  writer.add_graph(model, sample_x.to(device))

In [None]:
# batches.meta is a pickled dict from the CIFAR-10 download; its 'label_names' entry maps
# class index -> class name. Only unpickle files from a trusted source.
meta_path = '../data/cifar-10-batches-py/batches.meta'
with open(meta_path, 'rb') as meta_file:
  meta = pickle.load(meta_file)
meta['label_names']

In [None]:
# Show one sample from the last training batch (`x`, `y` are left over from the training loop).
# NOTE(review): these images went through the training transform and presumably CiFaData's
# normalization with `params`, so imshow may clip values outside [0, 1]. Un-normalize if the colours look off.
n = 9
label_idx = y[n]
sample_img = x[n].cpu().permute(1, 2, 0).numpy()  # assumes (C, H, W); imshow wants (H, W, C)

plt.title(meta['label_names'][label_idx])
plt.imshow(sample_img)

# TODOs
## Extract the feature-detection layers
## Increase model size: build a ResNet-50
### Add bottleneck blocks

# visualize feature maps

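We can loop over the layers with `model.children()` (this yields only the direct submodules; use `model.modules()` to recurse into nested blocks), or address individual layers by indexing, e.g. `model.block0[n]` (for `ResNet18New` the layers live in `model.resnet[n]`).
Indexing works down to any `nn.Sequential` container, which can then be indexed or sliced further.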


This looks like a good guide:
[Visualizing feature maps using PyTorch](https://ravivaishnav20.medium.com/visualizing-feature-maps-using-pytorch-12a48cd1e573)

In [None]:
# Plot the kernels of one conv layer (res18.block4.block[1].block[0]) as an image grid.
# NOTE(review): `res18` and `normalize_tensor` are not defined anywhere in this notebook —
# presumably leftovers from an earlier model/session. This cell raises NameError on a fresh kernel.
weights = res18.block4.block[1].block[0].weight.detach().clone()
print(weights.shape)
# presumably rescales the weights into a displayable range; normalize_tensor is not shown here, confirm
weights = normalize_tensor(weights)
# assumes a Conv2d weight (out, in, kh, kw): one tile per output filter, nrow ~ sqrt(#filters) for a square-ish grid
filter_img = torchvision.utils.make_grid(weights, nrow=int(np.sqrt(weights.shape[0])))
# NOTE(review): make_grid treats dim 1 (the kernel's input channels) as image channels, so imshow
# presumably only renders this if the layer has 1 or 3 input channels. Confirm.
plt.imshow(filter_img.permute(1,2,0))

In [None]:
def visTensor(tensor, ch=0, allkernels=False, nrow=8, padding=1): 
    n,c,w,h = tensor.shape

    if allkernels: tensor = tensor.view(n*c, -1, w, h)
    elif c != 3: tensor = tensor[:,ch,:,:].unsqueeze(dim=1)

    rows = np.min((tensor.shape[0] // nrow + 1, 64))    
    grid = torchvision.utils.make_grid(tensor, nrow=nrow, normalize=True, padding=padding)
    plt.figure( figsize=(nrow,rows) )
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

In [None]:
# Show the first-layer (stem) conv kernels of the trained model. `res18` is not defined in this
# notebook; the first layer of `model` is its PrepBlock's conv. Those kernels have 3 input
# channels, so visTensor renders each one as a small RGB tile.
# .cpu() because visTensor converts the grid to numpy, which needs a CPU tensor.
stem_weights = model.resnet[0].prep_block[0].weight.detach().cpu().clone()
visTensor(stem_weights)
plt.axis('off')
plt.ioff()
plt.show()

In [None]:
# Visualize the kernels of the first conv in `block1` of `res18`.
# NOTE(review): `res18` and `normalize_tensor` are not defined anywhere in this notebook
# (leftovers from an earlier model/session), so this cell raises NameError on a fresh kernel.
# `detach` is a method and must be called: `.weight.detach.clone()` raises AttributeError.
# `.cpu()` lets matplotlib read the grid when the model lives on the GPU.
first_block = res18.block1.block[0].weight.detach().cpu().clone()
print(first_block.shape)
first_block = normalize_tensor(first_block)
# one tile per output filter, arranged in a roughly square grid
filter_img = torchvision.utils.make_grid(first_block, nrow=int(np.sqrt(first_block.shape[0])))
plt.axis('off')
plt.ioff()
plt.imshow(filter_img.permute(1,2,0))
plt.show()

In [None]:
# Collect every Conv2d in the model together with its weight. model.children() yields only the
# direct submodules, so the convs nested inside the PrepBlock / ComputeBlock containers were
# never reached. model.modules() recurses through the whole module tree.
# (`res18` is not defined in this notebook; use the trained `model`.)
conv_layers = [module for module in model.modules() if isinstance(module, nn.Conv2d)]
weights = [conv.weight for conv in conv_layers]

cnt = len(conv_layers)
print(cnt)