In [None]:
import torch
from torch.nn import Module
# Fix the global RNG seed so weight initialisation and shuffling are repeatable.
SEED = 42
torch.manual_seed(SEED)
# Shown as the cell output; the cells below move the model and batches to CUDA.
torch.cuda.is_available()

In [None]:
# Number of input channels fed to the first convolution of the (commented-out) conv model.
channels_in: int = 3
# Number of output classes of the classifier.
num_classes: int = 10

class Tanh(Module):
  """Tanh activation that keeps a reference to its most recent output.

  The cached tensor is stored in `self.out` so later cells can inspect the
  activation distribution of the layer.
  """

  def forward(self, x):
    """Apply tanh element-wise to `x`, cache the result in `self.out`, and return it."""
    activated = torch.tanh(x)
    self.out = activated
    return activated

class ConvBlock(Module):
    """Conv2d -> BatchNorm2d -> optional activation.

    Args:
        channels_in: number of input channels of the convolution.
        channels_out: number of output channels (also the BatchNorm2d width).
        kernel_size, stride, padding: forwarded to `torch.nn.Conv2d`.
        act: callable applied after batch norm, or None to return the
            normalised output directly.

    Attributes:
        save: output of the most recent forward pass; only set when `act`
            is not None.
    """

    def __init__(self, channels_in, channels_out, kernel_size=3, stride=1, padding=1, act=torch.tanh):
        super(ConvBlock, self).__init__()
        self.conv = torch.nn.Conv2d(channels_in, channels_out, kernel_size, stride=stride, padding=padding)
        self.bn = torch.nn.BatchNorm2d(channels_out)
        self.act = act

    def forward(self, x):
        """Run conv and batch norm, then the activation if one is configured."""
        x = self.bn(self.conv(x))
        if self.act is None:
            return x
        # Keep the activation around so it can be inspected after the forward pass.
        self.save = self.act(x)
        return self.save

    # NOTE: no `parameters()` override. `conv` and `bn` are registered submodules,
    # so the inherited `Module.parameters()` already yields their parameters.
    # The previous override tried to add two generators, which raises TypeError.


class ResidualBlock(Module):
    """Two stacked ConvBlocks with an identity skip connection.

    `forward` adds the input to the output of the second block, so the two
    tensors must have compatible shapes. NOTE(review): with the defaults this
    presumably requires channels_in == channels_out; `stride` is passed to both
    convolutions, so stride != 1 also changes the spatial size - confirm before
    using non-default values.

    Args:
        channels_in: input channels of the first ConvBlock.
        channels_out: output channels of both ConvBlocks.
        kernel_size, stride, padding: forwarded to both ConvBlocks.
    """

    def __init__(self, channels_in, channels_out, kernel_size=3, stride=1, padding=1):
        # Module.__init__ must run before submodules are assigned; without it
        # the assignments below raise AttributeError.
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvBlock(channels_in, channels_out, kernel_size, stride, padding)
        self.conv2 = ConvBlock(channels_out, channels_out, kernel_size, stride, padding)

    def forward(self, x):
        """Return `x` plus the output of conv2(conv1(x))."""
        y = self.conv1(x)
        y = self.conv2(y)
        return x + y

    # NOTE: no `parameters()` override. The inherited `Module.parameters()`
    # walks the registered submodules; the previous override added two
    # iterators together, which raises TypeError.
    
#model = torch.nn.Sequential(
#    ConvBlock(channels_in, 32, stride = 2), # 3x32x32 -> 32x16x16
#    ConvBlock(32, 64), # 32x16x16 -> 64x16x16
#    ConvBlock(64, 128, stride = 2), # 64x16x16 -> 128x8x8
#    ConvBlock(128, 256), # 128x8x8 -> 256x8x8
#    ConvBlock(256, num_classes, kernel_size=1, padding=0, stride =2, act=None), # 256x8x8 -> 10x4x4
#    torch.nn.AvgPool2d(4)).cuda() # 10x4x4 -> 10x1x1

# Fully connected classifier on flattened 3x32x32 images (3072 features):
# three Linear -> BatchNorm1d -> Tanh stages followed by a 10-way Linear head.
layer_widths = [32*32*3, 1024, 512, 256]
mlp_layers = []
for fan_in, fan_out in zip(layer_widths[:-1], layer_widths[1:]):
    # Layers are created in the same order as when written out by hand, so the
    # seeded weight initialisation is unchanged.
    mlp_layers += [torch.nn.Linear(fan_in, fan_out), torch.nn.BatchNorm1d(fan_out), Tanh()]
mlp_layers.append(torch.nn.Linear(256, 10))  # 256 features -> 10 class logits
model = torch.nn.Sequential(*mlp_layers).cuda()

In [None]:
import torchvision
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# Batch sizes for the training and evaluation loaders.
batch_size = 128
val_batch_size = 128

# CIFAR-10 is downloaded into ./data when missing; images are converted to tensors.
to_tensor = torchvision.transforms.ToTensor()
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=to_tensor)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=to_tensor)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# The evaluation loader uses val_batch_size (previously defined but unused).
test_dataloader = DataLoader(test_dataset, batch_size=val_batch_size, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()

train_loss_log = []

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

In [None]:
def eval_net(model, dataloader, loss_fn=None):
  """Evaluate `model` on every batch yielded by `dataloader`.

  Each input batch is flattened to (batch, -1) before the forward pass, and
  the model output is flattened the same way before the loss and argmax.

  The model is switched to eval mode for the duration of the call, so layers
  such as BatchNorm use their running statistics and do not update them on
  evaluation data. The previous train/eval mode is restored afterwards.

  Args:
    model: the `torch.nn.Module` to evaluate.
    dataloader: iterable of `(inputs, targets)` batches.
    loss_fn: callable `(logits, targets) -> scalar tensor`. Defaults to the
      notebook-level `criterion`.

  Returns:
    `(accuracy, mean_loss)` as 0-dim tensors: the fraction of correctly
    classified samples and the mean of the per-batch losses.

  Raises:
    ValueError: if `dataloader` yields no batches.
  """
  if loss_fn is None:
    loss_fn = criterion

  # Run on whatever device the model lives on instead of assuming CUDA.
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else None

  batch_losses = []
  num_correct = 0
  num_seen = 0

  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for x, y in dataloader:
        if device is not None:
          x, y = x.to(device), y.to(device)
        logits = model(x.view(x.shape[0], -1))
        logits = logits.view(logits.shape[0], -1)

        batch_losses.append(loss_fn(logits, y))
        # Running counts avoid re-concatenating a growing tensor every batch.
        num_correct = num_correct + (torch.argmax(logits, dim=1) == y).sum()
        num_seen += y.shape[0]
  finally:
    model.train(was_training)

  if not batch_losses:
    raise ValueError("eval_net received an empty dataloader")

  accuracy = num_correct.float() / num_seen
  # Average over all batches; previously only the last batch's loss was returned.
  mean_loss = torch.stack(batch_losses).mean()
  return accuracy, mean_loss

In [None]:
# Per-epoch history: mean training loss, validation accuracy, validation loss.
train_loss_log, val_acc_log, val_loss_log = [], [], []

for epoch in tqdm(range(2)):
  # One pass over the training set, recording every batch loss.
  batch_losses = []
  for x, y in train_dataloader:
    flat_inputs = x.view(x.shape[0], -1).cuda()
    logits = model(flat_inputs)
    loss = criterion(logits.view(logits.shape[0], -1), y.cuda())
    batch_losses.append(loss.item())

    # Clear stale gradients, backpropagate, then take an SGD step.
    model.zero_grad()
    loss.backward()
    optimizer.step()

  # Validate on the test split and log the epoch metrics as Python floats.
  val_acc, val_loss = eval_net(model, test_dataloader)
  val_acc_log.append(val_acc.cpu().item())
  val_loss_log.append(val_loss.cpu().item())
  epoch_train_loss = torch.Tensor(batch_losses).mean()
  train_loss_log.append(epoch_train_loss.cpu().item())

  print(f"Train loss: {train_loss_log[-1]:.4}")
  print(f"Val loss: {val_loss_log[-1]:.4}") 
  print(f"Accuracy: {val_acc_log[-1]:.2}")

In [None]:
import matplotlib.pyplot as plt

#plt.figure()
# One curve per logged metric. When only y-values are given, matplotlib uses
# the element index (here: the epoch) as x, so no explicit x list is needed.
plt.plot(train_loss_log, label="train loss")
plt.plot(val_loss_log, label="val loss")
plt.plot(val_acc_log, label="val accuracy")
plt.xlabel("epoch")
plt.legend()

In [None]:
# Density histogram of the first Tanh layer's most recent activations.
# Only Tanh modules cache their output in `.out`; model[1] is a BatchNorm1d
# layer, so indexing it raised AttributeError. Look the Tanh layer up by type.
first_tanh = next(layer for layer in model if isinstance(layer, Tanh))
hy, hx = torch.histogram(first_tanh.out.detach().flatten().cpu(), density=True)
# hx holds the bin edges (one more than bins); plot against the left edges.
plt.plot(hx[:-1], hy)

In [None]:
# Print every layer; for each Linear layer also report the mean/std of its
# weight gradient and overlay the gradient density histogram.
for i, layer in enumerate(model):
  print(layer)
  if not isinstance(layer, torch.nn.Linear):
    continue
  # Despite the name, this is the gradient of a Linear layer's weight matrix.
  conv_grad = layer.weight.grad.detach().cpu().flatten()
  print(f"layer {i}, mean: {conv_grad.mean():.4}, std: {conv_grad.std()}")
  density, bin_edges = torch.histogram(conv_grad, density=True)
  # bin_edges has one more entry than density; plot against the left edges.
  plt.plot(bin_edges[:-1], density, label=f"layer {i}")
plt.legend()

In [None]:
# Print every layer except the final one; for each Tanh layer report mean, std
# and the fraction of saturated units (|activation| > 0.95) of its cached
# output, and overlay the activation density histogram.
for i, layer in enumerate(model[:-1]):
  print(layer)
  if not isinstance(layer, Tanh):
    continue
  # Despite the name, this holds the Tanh layer's cached activations.
  conv_grad = layer.out.detach().cpu().flatten()
  print(f"layer {i}, mean: {conv_grad.mean():.4}, std: {conv_grad.std()} sat: {(conv_grad.abs() > 0.95).sum()/len(conv_grad)}")
  density, bin_edges = torch.histogram(conv_grad, density=True)
  # bin_edges has one more entry than density; plot against the left edges.
  plt.plot(bin_edges[:-1], density, label=f"layer {i}")
plt.legend()

In [None]:
#hy, hx = torch.histogram(model[1].save.flatten().cpu(), density=True)
#plt.plot(hx[:-1], hy)
# Show the learned BatchNorm scale of the second entry in the model.
# In the MLP, model[1] is itself a BatchNorm1d layer and has no `.bn`;
# a ConvBlock (conv variant of the model) wraps its batch norm in `.bn`.
getattr(model[1], "bn", model[1]).weight