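# Comparing weight-initialization methods for a small ResNet on CIFAR-10

The same network is trained with several initializations and its training curves (running loss, per-epoch train/test accuracy and loss) are saved:

- scaled Kaiming-uniform initialization (several gains);
- the data-driven Sylvester initialization of Das et al. (2021);
- PyTorch's default initialization, as a baseline.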

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

In [None]:
import torch.nn as nn

In [None]:
# Per-channel normalization (x - 0.5) / 0.5 applied after ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

batch_size = 256


def _cifar10_split(is_train):
    """Download (if needed) one CIFAR-10 split into ./data and wrap it in a DataLoader.

    Only the training split is shuffled.
    """
    dataset = torchvision.datasets.CIFAR10(root='./data', train=is_train,
                                           download=True, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=is_train, num_workers=2)
    return dataset, loader


trainset, trainloader = _cifar10_split(is_train=True)
testset, testloader = _cifar10_split(is_train=False)

# Class names in label-index order.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 74703575.20it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
class ResidualBlock(nn.Module):
    """Two convolution + BatchNorm layers with an additive skip connection.

    forward computes  relu(bn2(conv2(relu(bn1(conv1(x)))))) + shortcut(x).
    The ReLU is applied before the addition, so the block output itself is not
    rectified.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        kernel_size, padding: shared by both convolutions. The skip connection only
            lines up spatially with "same" padding (e.g. kernel_size=3, padding=1).
        stride: stride of the first convolution.

    When stride == 1 and in_channels == out_channels the skip is the identity and
    no extra module is created, so parameters, state_dict keys and RNG use are the
    same as a plain identity-skip block. Otherwise `x` is projected with a strided
    1x1 convolution + BatchNorm. Without the projection, `residual + out` would fail
    with a shape mismatch.

    NOTE(review): in PyTorch, BatchNorm `momentum` is the weight of the *new* batch
    statistic (default 0.1). A value of 0.9 looks like a Keras-style setting and makes
    the running statistics track mostly the latest batch. Confirm this is intended.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, stride):
        super(ResidualBlock, self).__init__()
        self.conv_res1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   padding=padding, stride=stride, bias=False)
        self.conv_res1_bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.9)
        self.conv_res2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   padding=padding, bias=False)
        self.conv_res2_bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.9)
        self.relu = nn.ReLU(inplace=False)
        if stride != 1 or in_channels != out_channels:
            # Projection shortcut so the skip matches the residual branch's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(num_features=out_channels, momentum=0.9),
            )
        else:
            # Identity skip. None is stored as a plain attribute, not a submodule.
            self.shortcut = None

    def forward(self, x):
        residual = x if self.shortcut is None else self.shortcut(x)
        out = self.relu(self.conv_res1_bn(self.conv_res1(x)))
        out = self.relu(self.conv_res2_bn(self.conv_res2(out)))
        return residual + out


class ResNet(nn.Module):
    """
    Residual network for 3x32x32 inputs with BatchNorm after every plain convolution.

    Layout of ``self.conv`` (index: layers -> output size):
      0-2    conv 3->64, BN, ReLU        (64,32,32)
      3-5    conv 64->128, BN, ReLU      (128,32,32)
      6      max-pool                    (128,16,16)
      7      ResidualBlock(128)          (128,16,16)
      8-10   conv 128->256, BN, ReLU     (256,16,16)
      11     max-pool                    (256,8,8)
      12-14  conv 256->256, BN, ReLU     (256,8,8)
      15     max-pool                    (256,4,4)
      16     ResidualBlock(256)          (256,4,4)
      17     max-pool                    (256,2,2)
    followed by ``self.fc``: 256*2*2 = 1024 features -> 10 logits.
    The index order is kept stable because code elsewhere in the notebook may
    address layers by position.
    """
    def __init__(self):
        super(ResNet, self).__init__()

        def conv_bn_relu(c_in, c_out):
            # 3x3 convolution with padding 1 (spatial size preserved) -> BatchNorm -> ReLU.
            return [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(num_features=c_out, momentum=0.9),
                nn.ReLU(inplace=True),
            ]

        def halve():
            # 2x2 max-pooling halves height and width.
            return nn.MaxPool2d(kernel_size=2, stride=2)

        stages = [
            *conv_bn_relu(3, 64),
            *conv_bn_relu(64, 128),
            halve(),
            ResidualBlock(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            *conv_bn_relu(128, 256),
            halve(),
            *conv_bn_relu(256, 256),
            halve(),
            ResidualBlock(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            halve(),
        ]
        self.conv = nn.Sequential(*stages)

        self.fc = nn.Linear(in_features=1024, out_features=10, bias=True)

    def forward(self, x):
        features = self.conv(x)
        # Flatten all but the batch dimension: (N, 256, 2, 2) -> (N, 1024) for 32x32 inputs.
        return self.fc(torch.flatten(features, start_dim=1))

In [None]:
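# Same architecture as above, but with every BatchNorm layer removed.
# Running this cell redefines (shadows) ResidualBlock and ResNet for the rest of the notebook.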

class ResidualBlock(nn.Module):
    """Two convolutions (no normalization) with an additive skip connection.

    forward computes  relu(conv2(relu(conv1(x)))) + shortcut(x).
    The ReLU is applied before the addition, so the block output itself is not
    rectified.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        kernel_size, padding: shared by both convolutions. The skip connection only
            lines up spatially with "same" padding (e.g. kernel_size=3, padding=1).
        stride: stride of the first convolution.

    When stride == 1 and in_channels == out_channels the skip is the identity and
    no extra module is created, so parameters, state_dict keys and RNG use are the
    same as a plain identity-skip block. Otherwise `x` is projected with a strided
    1x1 convolution. Without the projection, `residual + out` would fail with a
    shape mismatch.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, stride):
        super(ResidualBlock, self).__init__()
        self.conv_res1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   padding=padding, stride=stride, bias=False)
        self.conv_res2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   padding=padding, bias=False)
        self.relu = nn.ReLU(inplace=False)
        if stride != 1 or in_channels != out_channels:
            # Projection shortcut so the skip matches the residual branch's shape.
            self.shortcut = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                                      stride=stride, bias=False)
        else:
            # Identity skip. None is stored as a plain attribute, not a submodule.
            self.shortcut = None

    def forward(self, x):
        residual = x if self.shortcut is None else self.shortcut(x)
        out = self.relu(self.conv_res1(x))
        out = self.relu(self.conv_res2(out))
        return residual + out


class ResNet(nn.Module):
    """
    Residual network for 3x32x32 inputs without any normalization layers.

    Layout of ``self.conv`` (index: layers -> output size):
      0-1    conv 3->64, ReLU            (64,32,32)
      2-3    conv 64->128, ReLU          (128,32,32)
      4      max-pool                    (128,16,16)
      5      ResidualBlock(128)          (128,16,16)
      6-7    conv 128->256, ReLU         (256,16,16)
      8      max-pool                    (256,8,8)
      9-10   conv 256->256, ReLU         (256,8,8)
      11     max-pool                    (256,4,4)
      12     ResidualBlock(256)          (256,4,4)
      13     max-pool                    (256,2,2)
    followed by ``self.fc``: 256*2*2 = 1024 features -> 10 logits.
    The index order is kept stable because code elsewhere in the notebook may
    address layers by position.
    """
    def __init__(self):
        super(ResNet, self).__init__()

        def conv_relu(c_in, c_out):
            # 3x3 convolution with padding 1 (spatial size preserved) -> ReLU.
            return [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.ReLU(inplace=True),
            ]

        def halve():
            # 2x2 max-pooling halves height and width.
            return nn.MaxPool2d(kernel_size=2, stride=2)

        stages = [
            *conv_relu(3, 64),
            *conv_relu(64, 128),
            halve(),
            ResidualBlock(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            *conv_relu(128, 256),
            halve(),
            *conv_relu(256, 256),
            halve(),
            ResidualBlock(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            halve(),
        ]
        self.conv = nn.Sequential(*stages)

        self.fc = nn.Linear(in_features=1024, out_features=10, bias=True)

    def forward(self, x):
        features = self.conv(x)
        # Flatten all but the batch dimension: (N, 256, 2, 2) -> (N, 1024) for 32x32 inputs.
        return self.fc(torch.flatten(features, start_dim=1))

In [None]:
import torch.optim as optim

In [None]:
import numpy as np
def get_accuracy(model,loader,dataset,criterion):
  """Evaluate `model` on every batch yielded by `loader`.

  Args:
    model: classifier returning one row of logits per sample.
    loader: iterable of (inputs, labels) batches.
    dataset: unused; kept so existing call sites keep working.
    criterion: loss function, assumed to average over the batch
      (reduction='mean', the nn.CrossEntropyLoss default). Each batch loss is
      re-weighted by batch size, so the returned loss is a per-sample mean.

  Returns:
    (accuracy, loss): fraction of correctly classified samples and the mean
    per-sample loss, both as Python floats.

  Raises:
    ValueError: if `loader` yields no samples.
  """
  # Send batches to wherever the model's parameters live, rather than assuming
  # the model is on CUDA whenever CUDA is available. Parameter-less models get
  # the batches unchanged.
  first_param = next(model.parameters(), None)
  device = None if first_param is None else first_param.device

  # Evaluate in eval mode, so BatchNorm uses its running statistics and does not
  # update them. The caller's mode is restored afterwards.
  was_training = model.training
  model.eval()
  n_correct = 0
  n_samples = 0
  loss_sum = 0.0
  try:
    with torch.no_grad():
      for inputs, labels in loader:
        if device is not None:
          inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        n = labels.size(0)
        # Weight each batch-mean loss by its size. The last batch is usually
        # smaller, so a plain mean of batch losses would over-weight its samples.
        loss_sum += criterion(outputs, labels).item() * n
        n_correct += (outputs.argmax(dim=1) == labels).sum().item()
        n_samples += n
  finally:
    model.train(was_training)

  if n_samples == 0:
    raise ValueError("loader yielded no samples")
  return n_correct / n_samples, loss_sum / n_samples

In [None]:
import math
def init_network_kaiming_in(gain=1, model=None):
  """Re-initialize every conv layer of a ResNet with scaled Kaiming-uniform weights.

  Each ``nn.Conv2d`` weight is drawn with
  ``kaiming_uniform_(mode='fan_in', nonlinearity='relu')``, whose bound is
  sqrt(6 / fan_in). It is then divided by sqrt(6) and multiplied by ``gain``, so
  the final weights are uniform on [-gain / sqrt(fan_in), gain / sqrt(fan_in)].
  Other layers (BatchNorm, the final Linear) keep their default initialization.

  Args:
    gain: multiplier applied to the rescaled weights.
    model: optional module to initialize in place. Defaults to a fresh ``ResNet()``.

  Returns:
    The initialized module.
  """
  torch.manual_seed(0)
  resnet = ResNet() if model is None else model

  # Collect conv layers by type rather than by Sequential index. The indices
  # differ between the BatchNorm and no-BatchNorm ResNet definitions (e.g. conv[7]
  # is a ReLU in the latter), so hard-coded positions broke whenever the other
  # variant was active. modules() yields submodules in registration order, which
  # is the same order as the former hand-written list.
  conv_layers = [m for m in resnet.modules() if isinstance(m, nn.Conv2d)]

  scale = gain / math.sqrt(6)
  for layer in conv_layers:
    nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')
    with torch.no_grad():
      layer.weight.mul_(scale)
  return resnet

In [None]:
from sklearn.decomposition import PCA
from scipy.linalg import solve_sylvester
import numpy as np

# Sylvester Data driven initialization
# Based on Das et al. (2021)
def sylvester_init(inputs,c_out,kernel_size=3,alpha=10,latent_representation = 'PCA',n_patches=int(3e3)):
  """Data-driven convolution weights obtained from a Sylvester equation.

  Samples `n_patches` random kernel_size x kernel_size patches X from `inputs`,
  encodes them with a latent representation S (PCA with `c_out` components),
  and solves  (S^T S) W + W (alpha X^T X) = (1 + alpha) S^T X  for W.

  Args:
    inputs: tensor of size (n, c_in, height, width), i.e. the activations the layer
      being initialized receives.
    c_out: number of output channels (= number of PCA components). PCA requires
      c_out <= min(n_patches, c_in*kernel_size*kernel_size).
    kernel_size: side of the square convolution kernel.
    alpha: weight of the encoding error, see Das et al.
    latent_representation: only 'PCA' is implemented.
    n_patches: number of patches sampled, drawn with the global NumPy RNG.

  Returns:
    tensor of size (c_out, c_in, kernel_size, kernel_size) with the dtype of `inputs`.

  Raises:
    ValueError: if `latent_representation` is not supported. The previous
      version silently returned None, which failed later at the caller.
  """
  if latent_representation != 'PCA':
    # Other methods are possible but not implemented.
    raise ValueError(f"Unsupported latent_representation: {latent_representation!r}")

  n, c_in, height, width = inputs.size()
  k = kernel_size
  images = inputs.detach().cpu()

  # Independent row and column offsets. The previous version reused one offset
  # for both axes, so it only ever sampled patches on the image diagonal.
  img_idx = np.random.randint(n, size=n_patches).tolist()
  rows = np.random.randint(height - k + 1, size=n_patches).tolist()
  cols = np.random.randint(width - k + 1, size=n_patches).tolist()
  # One stack instead of repeated torch.cat (which was quadratic in n_patches).
  patches = torch.stack([images[i, :, r:r + k, c:c + k]
                         for i, r, c in zip(img_idx, rows, cols)])
  # From size (n_patches, c_in, k, k) to (n_patches, c_in*k*k)
  X = patches.reshape(n_patches, c_in * k * k)

  S = torch.tensor(PCA(c_out).fit_transform(X.numpy()), dtype=X.dtype)  # (n_patches, c_out)
  A = (S.T @ S).numpy()                  # (c_out, c_out)
  B = (alpha * X.T @ X).numpy()          # (c_in*k*k, c_in*k*k)
  C = ((1 + alpha) * S.T @ X).numpy()    # (c_out, c_in*k*k)

  # Solve equation : AW + WB = C
  W = solve_sylvester(A, B, C)
  # Pin the dtype so the result can be copied into the layer's weight directly.
  return torch.as_tensor(W, dtype=inputs.dtype).reshape(c_out, c_in, k, k)

In [None]:
from tqdm import tqdm
class _StopForward(Exception):
  """Raised by a forward pre-hook to abort a forward pass once the wanted input is captured."""


def _capture_layer_input(model, layer, images):
  """Run `model` on `images` without gradients and return a copy of the tensor fed to `layer`.

  The forward pass is aborted as soon as `layer` is reached, so later layers do not run.
  """
  captured = []

  def hook(module, args):
    captured.append(args[0].detach().clone())
    raise _StopForward

  handle = layer.register_forward_pre_hook(hook)
  try:
    with torch.no_grad():
      model(images)
  except _StopForward:
    pass
  finally:
    handle.remove()
  if not captured:
    raise RuntimeError("layer was not reached during the forward pass of model")
  return captured[0]


def init_network_sylvester(alpha=10):
  """Build a ResNet and re-initialize its conv layers with the Sylvester data-driven method.

  Conv layers are processed in network order. Each one is calibrated on the
  activations it actually receives: one training batch propagated through every
  preceding layer, including ReLU, pooling, normalization and the already
  re-initialized convolutions. A layer keeps its default initialization when it
  has more output channels than patch features (c_out > c_in*k*k, e.g. the first
  3->64 conv), since PCA cannot produce that many components.

  Args:
    alpha: encoding-error weight passed to sylvester_init.

  Returns:
    The initialized ResNet.
  """
  torch.manual_seed(0)
  np.random.seed(0)
  resnet = ResNet()

  # Collect conv layers by type, in registration order. This works for both the
  # BatchNorm and the no-BatchNorm ResNet definitions, whose Sequential indices differ.
  conv_layers = [m for m in resnet.modules() if isinstance(m, nn.Conv2d)]

  # One (shuffled) training batch serves as calibration data.
  images = next(iter(trainloader))[0]

  for layer in tqdm(conv_layers):
    c_out, c_in, k_h, k_w = layer.weight.shape
    if c_out > c_in * k_h * k_w:
      continue
    # The previous version chained only the conv layers (inputs = layer(inputs)),
    # skipping activations, pooling and BatchNorm. Every later layer was therefore
    # calibrated on activations with the wrong distribution and spatial size.
    # NOTE: with BatchNorm layers in training mode, these calibration passes also
    # update their running statistics.
    layer_inputs = _capture_layer_input(resnet, layer, images)
    # sylvester_init samples square patches; assumes k_h == k_w (3x3 in ResNet).
    W = sylvester_init(layer_inputs, c_out, kernel_size=k_h, alpha=alpha)
    with torch.no_grad():
      layer.weight.copy_(W.to(dtype=layer.weight.dtype))
  return resnet

In [None]:
# Display the layer layout (and module indices) of the ResNet currently defined.
resnet = ResNet()
resnet.conv  # last expression of the cell: rendered as its output

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): ResidualBlock(
    (conv_res1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (conv_res2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (relu): ReLU()
  )
  (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (7): ReLU(inplace=True)
  (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (9): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (10): ReLU(inplace=True)
  (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (12): ResidualBlock(
    (conv_res1): Conv2d(256, 2

In [None]:
def train_model(model,n_epoch=50,verbose=True,lr=0.001,weight_decay=1e-5):
  """Train `model` with Adam on `trainloader`, evaluating on train and test sets after each epoch.

  Args:
    model: network to train, updated in place and moved to CUDA when available.
      Previously this argument was ignored and the global `resnet` was trained
      instead; that only worked because every caller passed that same object.
    n_epoch: number of passes over `trainloader`.
    verbose: if truthy, print train/test accuracy and loss after every epoch.
      If > 1, also print the mean loss of every 49 mini-batches.
    lr, weight_decay: Adam hyper-parameters. The defaults are the values that
      were previously hard-coded.

  Returns:
    (memory_running_loss, memory_train_accuracy, memory_test_accuracy,
     memory_train_loss, memory_test_loss).
    memory_running_loss holds the *sum* of the losses of each group of 49
    mini-batches; the printed value is that sum divided by 49. The other lists
    hold one value per epoch, as returned by get_accuracy.
  """
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  model.to(device)
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
  log_every = 49

  memory_running_loss = []
  memory_train_loss = []
  memory_test_loss = []
  memory_train_accuracy = []
  memory_test_accuracy = []
  for epoch in range(n_epoch):  # loop over the dataset multiple times
    model.train()  # BatchNorm (if any) must use batch statistics while training
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()
      loss = criterion(model(inputs), labels)
      loss.backward()
      optimizer.step()

      running_loss += loss.item()
      if i % log_every == log_every - 1:
        if verbose > 1:
          print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
        memory_running_loss.append(running_loss)
        running_loss = 0.0

    model.eval()  # evaluate with running statistics, not the evaluation batches' statistics
    with torch.no_grad():
      train_accuracy, train_loss = get_accuracy(model, trainloader, trainset, criterion)
      memory_train_accuracy.append(train_accuracy)
      memory_train_loss.append(train_loss)
      test_accuracy, test_loss = get_accuracy(model, testloader, testset, criterion)
      memory_test_accuracy.append(test_accuracy)
      memory_test_loss.append(test_loss)
    if verbose:
      print(f"\nepoch : {epoch+1}, train accuracy : {train_accuracy}, test accuracy : {test_accuracy}\n")
      print(f"train loss : {train_loss}, test loss : {test_loss}\n")
  return memory_running_loss,memory_train_accuracy,memory_test_accuracy,memory_train_loss,memory_test_loss

In [None]:
from tqdm import tqdm
import numpy as np
import os

tested_gains = [1/100,1/10,1/math.sqrt(6),1/math.sqrt(3),1,math.sqrt(3),math.sqrt(6),10,100]

# np.save does not create missing directories, so make sure the output folder exists.
os.makedirs("./results2", exist_ok=True)

for gain in tqdm(tested_gains):
  resnet = init_network_kaiming_in(gain=gain)
  memory_running_loss,memory_train_accuracy,memory_test_accuracy,memory_train_loss,memory_test_loss = train_model(resnet,n_epoch=20,verbose=1)

  # One .npy per curve, e.g. ./results2/test_accuracy_0.408.npy
  results = {
      "running_loss": memory_running_loss,
      "train_accuracy": memory_train_accuracy,
      "test_accuracy": memory_test_accuracy,
      "train_loss": memory_train_loss,
      "test_loss": memory_test_loss,
  }
  for name, values in results.items():
    np.save(f"./results2/{name}_{gain:.3f}", values)


In [None]:
import shutil
# Bundle the contents of ./results2 into one zip; the returned archive path is shown as the cell output.
shutil.make_archive(base_name="./outputs saves 2", format='zip', root_dir='./results2')

'/content/outputs saves 2.zip'

In [None]:
import os

# np.save does not create missing directories. The saved output of this cell
# shows the resulting FileNotFoundError, so create the folder first.
os.makedirs("./results3", exist_ok=True)

for alpha in [0.1,1,10]:
  resnet = init_network_sylvester(alpha=alpha)
  memory_running_loss,memory_train_accuracy,memory_test_accuracy,memory_train_loss,memory_test_loss = train_model(resnet,n_epoch=20,verbose=1)

  # One .npy per curve, e.g. ./results3/test_accuracy_0.1.npy
  results = {
      "running_loss": memory_running_loss,
      "train_accuracy": memory_train_accuracy,
      "test_accuracy": memory_test_accuracy,
      "train_loss": memory_train_loss,
      "test_loss": memory_test_loss,
  }
  for name, values in results.items():
    np.save(f"./results3/{name}_{alpha:.1f}", values)

100%|██████████| 8/8 [02:43<00:00, 20.45s/it]



epoch : 1, train accuracy : 0.3946399986743927, test accuracy : 0.39800000190734863

train loss : 1.683637232804785, test loss : 1.6699026644229888


epoch : 2, train accuracy : 0.5346400141716003, test accuracy : 0.5200999975204468

train loss : 1.3310637352417927, test loss : 1.3460669189691543


epoch : 3, train accuracy : 0.6069999933242798, test accuracy : 0.588100016117096

train loss : 1.114758890806412, test loss : 1.162567573785782


epoch : 4, train accuracy : 0.6823400259017944, test accuracy : 0.6448000073432922

train loss : 0.9008758363066888, test loss : 1.001242396235466


epoch : 5, train accuracy : 0.7673599720001221, test accuracy : 0.704200029373169

train loss : 0.6807498782873154, test loss : 0.8473420888185501


epoch : 6, train accuracy : 0.8090400099754333, test accuracy : 0.7335000038146973

train loss : 0.5620736262323905, test loss : 0.7870441183447838


epoch : 7, train accuracy : 0.8418400287628174, test accuracy : 0.7439000010490417

train loss : 0.46804

FileNotFoundError: ignored

In [None]:
# Baseline: the network exactly as ResNet() builds it with the same seeds that
# init_network_sylvester uses, i.e. its starting weights before the Sylvester
# re-initialization.
import os

# np.save does not create missing directories.
os.makedirs("./results3", exist_ok=True)

torch.manual_seed(0)
np.random.seed(0)
resnet = ResNet()
memory_running_loss,memory_train_accuracy,memory_test_accuracy,memory_train_loss,memory_test_loss = train_model(resnet,n_epoch=20,verbose=1)

# One .npy per curve, e.g. ./results3/test_accuracy_init.npy
results = {
    "running_loss": memory_running_loss,
    "train_accuracy": memory_train_accuracy,
    "test_accuracy": memory_test_accuracy,
    "train_loss": memory_train_loss,
    "test_loss": memory_test_loss,
}
for name, values in results.items():
  np.save(f"./results3/{name}_init", values)

In [None]:
import shutil
# Bundle the contents of ./results3 into one zip; the returned archive path is shown as the cell output.
shutil.make_archive(base_name="./outputs saves 3", format='zip', root_dir='./results3')

'/content/outputs saves 3.zip'