In [None]:
%pip install pytorch-nlp

Collecting pytorch-nlp
  Downloading pytorch_nlp-0.5.0-py3-none-any.whl.metadata (9.0 kB)
Downloading pytorch_nlp-0.5.0-py3-none-any.whl (90 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.1/90.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pytorch-nlp
Successfully installed pytorch-nlp-0.5.0
Note: you may need to restart the kernel to use updated packages.


In [None]:
import torch.nn as nn
import torchvision
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from random import choices, sample
import numpy as np
import random
from sklearn.model_selection import train_test_split
import warnings
from torchnlp.nn.weight_drop import WeightDropLinear
from torch.optim import Adam

In [None]:
# Notebook-wide configuration.
seed = 42  # value handed to every RNG seeding call in the notebook
warnings.filterwarnings("ignore")  # suppress all warnings in cell output

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [None]:
# Seed every random number generator used by the notebook with the same value:
# Python's `random`, NumPy's global generator, and PyTorch (CPU and all CUDA devices).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [None]:
# Fetch (if not already present under "./") and open both MNIST splits:
# the training split first, then the held-out test split.
train_data, test_data = (
  torchvision.datasets.MNIST("./", train=is_train_split, download=True)
  for is_train_split in (True, False)
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 41.7MB/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 1.14MB/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 8.99MB/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 5.59MB/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw






In [None]:
# Pull the raw tensors out of the dataset objects.
# Images get a singleton axis inserted at position 1 (the channel axis that
# Conv2d with in_channels=1 expects) and are cast to float; labels are used as-is.
train_images, test_images = (
  split.data[:, None].float() for split in (train_data, test_data)
)
train_labels, test_labels = train_data.targets, test_data.targets

In [None]:
# One-hot encode the digit labels as float rows of length 10.
# The vectorised library call replaces a nested Python comprehension that
# compared every label against every class index one tensor element at a time.
# Assumes the label tensors hold integer (int64) class indices in 0..9.
y_train = nn.functional.one_hot(train_labels, num_classes=10).float()
y_test = nn.functional.one_hot(test_labels, num_classes=10).float()

In [None]:
# Hold out 20% of the training images for validation.
# `random_state=seed` makes the split depend only on `seed`, not on whatever
# has already consumed NumPy's global RNG in this kernel session.
# NOTE: this cell overwrites `y_train`; re-run the one-hot cell before re-running it.
X_train, X_val, y_train, y_val = train_test_split(
  train_images, y_train, test_size=0.2, random_state=seed
)
# Training keeps the one-hot rows as targets; validation targets are converted
# back to class indices so they can be compared with predicted digits.
y_val = torch.argmax(y_val, dim=1)

In [None]:
class CNN(nn.Module):
  """Six-layer convolutional feature extractor.

  Each layer applies a 3x3 convolution, batch normalisation and dropout
  (p=0.35). A 2x2 max-pool is applied to the batch-norm output of layers 3
  and 6, before that layer's dropout. Channel widths grow
  1 -> 64 -> 128 -> 256 -> 512 -> 1024 -> 2000, and the output is the
  unflattened feature-map tensor of the last layer.
  """

  # (in_channels, out_channels, padding) for conv1..conv6.
  _LAYERS = (
    (1, 64, 2),
    (64, 128, 0),
    (128, 256, 1),
    (256, 512, 1),
    (512, 1024, 0),
    (1024, 2000, 1),
  )
  # 1-based indices of the layers that are followed by max-pooling.
  _POOL_AFTER = (3, 6)

  def __init__(self):
    super().__init__()
    self.dropout = nn.Dropout(0.35)

    # All convolutions are registered before all batch-norms so that the
    # attribute names (conv1..conv6, bn1..bn6) and the parameter order stay
    # the same as when every layer is spelled out by hand.
    for idx, (c_in, c_out, pad) in enumerate(self._LAYERS, start=1):
      conv = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=pad)
      setattr(self, f"conv{idx}", conv)
    for idx, (_, c_out, _) in enumerate(self._LAYERS, start=1):
      setattr(self, f"bn{idx}", nn.BatchNorm2d(num_features=c_out))

    self.maxpool = nn.MaxPool2d(kernel_size=2)

  def forward(self, x):
    """Run ``x`` through the six conv/batch-norm/dropout layers and return the feature maps."""
    for idx in range(1, len(self._LAYERS) + 1):
      x = getattr(self, f"conv{idx}")(x)
      x = getattr(self, f"bn{idx}")(x)
      if idx in self._POOL_AFTER:
        x = self.maxpool(x)
      x = self.dropout(x)
    return x


class SubNetwork(nn.Module):
  """Classification head: flatten -> linear(512) -> batch-norm -> dropout(0.5) -> linear(10).

  Returns raw logits for 10 classes; no activation or softmax is applied.
  """

  def __init__(self, in_dim):
    """Build the head for inputs that flatten to ``in_dim`` features per sample."""
    super().__init__()
    self.fc1 = nn.Linear(in_dim, 512)
    self.bn = nn.BatchNorm1d(512)
    self.dropout = nn.Dropout(p=0.5)
    self.fc2 = nn.Linear(in_features=512, out_features=10)

  def forward(self, x):
    """Flatten everything after the batch axis and return the 10 class logits."""
    hidden = x.flatten(start_dim=1)
    for layer in (self.fc1, self.bn, self.dropout):
      hidden = layer(hidden)
    return self.fc2(hidden)


In [None]:
def switch_model_params_status(model):
  """Flip the ``requires_grad`` flag of every parameter of ``model`` in place.

  Each parameter is toggled independently: trainable ones become frozen and
  frozen ones become trainable, so calling this twice restores the original flags.
  """
  for tensor in model.parameters():
    tensor.requires_grad_(not tensor.requires_grad)

In [None]:
# Model/optimizer setup. Construction order (CNN, the small heads, then the
# full-width head) is unchanged, so seeded weight initialisation is the same.
LEARNING_RATE = 0.001
NUM_HEADS = 10            # number of small sub-networks in the ensemble
HEAD_IN_FEATURES = 7200   # one 200-channel slice of the CNN output, flattened
                          # (200 * 6 * 6, assuming 28x28 inputs reduce to 6x6 maps)
FULL_IN_FEATURES = 72000  # the whole 2000-channel CNN output, flattened

# Shared convolutional feature extractor.
cnn = CNN().to(device)
cnn_optimizer = Adam(cnn.parameters(), lr=LEARNING_RATE)

# One small classification head (with its own optimizer) per channel slice.
small_subnetworks = [SubNetwork(HEAD_IN_FEATURES).to(device) for _ in range(NUM_HEADS)]
small_optimizers = [Adam(head.parameters(), lr=LEARNING_RATE) for head in small_subnetworks]

# Head that consumes the full CNN output; its loss is what trains the CNN.
cnn_fc = SubNetwork(FULL_IN_FEATURES).to(device)
fc_optimizer = Adam(cnn_fc.parameters(), lr=LEARNING_RATE)

In [None]:
EPOCHS = 50
batch_size = 100
train_batches = DataLoader([*zip(X_train, y_train)], batch_size=batch_size, shuffle=True)
# Accuracy does not depend on evaluation order, so the validation set is not shuffled.
val_batches = DataLoader([*zip(X_val, y_val)], batch_size=batch_size, shuffle=False)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
  cnn.train()
  cnn_fc.train()
  for small in small_subnetworks:
    small.train()

  for features, target in train_batches:
    features = features.to(device)
    # Training targets are one-hot float rows; CrossEntropyLoss treats them
    # as class probabilities.
    target = target.to(device)

    # Stage 1: train the CNN together with the full-width head.
    cnn_optimizer.zero_grad()
    fc_optimizer.zero_grad()
    cnn_out = cnn(features)
    loss = loss_fn(cnn_fc(cnn_out), target)
    loss.backward()
    cnn_optimizer.step()
    # Previously fc_optimizer was zeroed but never stepped, which left cnn_fc
    # at its random initialisation for the whole run.
    fc_optimizer.step()

    # Stage 2: train each small head on its own slice of the CNN features.
    # detach() cuts the graph, so no gradient can reach the CNN here and no
    # requires_grad toggling (switch_model_params_status) is needed. The XOR
    # toggles could also leave flags inverted if the cell was interrupted.
    frozen_features = cnn_out.detach()
    channels_per_head = frozen_features.shape[1] // len(small_subnetworks)
    for i, (head, head_optimizer) in enumerate(zip(small_subnetworks, small_optimizers)):
      head_optimizer.zero_grad()
      head_in = frozen_features[:, i*channels_per_head:(i+1)*channels_per_head]
      head_loss = loss_fn(head(head_in), target)
      head_loss.backward()
      head_optimizer.step()

  # Validation: every small head votes for a digit; the most frequent vote wins.
  cnn.eval()
  cnn_fc.eval()
  for small in small_subnetworks:
    small.eval()
  correct = 0
  with torch.no_grad():
    for features, target in val_batches:
      features = features.to(device)
      target = target.to(device)  # class indices (y_val was argmax-ed)
      val_out = cnn(features)
      channels_per_head = val_out.shape[1] // len(small_subnetworks)
      # argmax over raw logits: softmax is monotonic, so it would not change the vote.
      votes = torch.stack([
        head(val_out[:, i*channels_per_head:(i+1)*channels_per_head]).argmax(dim=1)
        for i, head in enumerate(small_subnetworks)
      ])
      preds = torch.mode(votes, dim=0).values
      correct += (preds == target).sum()
  print(f"Epoch {epoch}: {(correct/y_val.shape[0])*100}%")


Epoch 0: 97.2249984741211%
Epoch 1: 98.19166564941406%
Epoch 2: 98.5250015258789%
Epoch 3: 98.5250015258789%
Epoch 4: 98.55833435058594%
Epoch 5: 98.8083267211914%
Epoch 6: 98.61666870117188%
Epoch 7: 98.7249984741211%
Epoch 8: 98.81666564941406%
Epoch 9: 98.8083267211914%
Epoch 10: 98.8499984741211%
Epoch 11: 98.89999389648438%
Epoch 12: 98.98332977294922%
Epoch 13: 98.91666412353516%
Epoch 14: 98.875%
Epoch 15: 98.90833282470703%
Epoch 16: 99.01666259765625%
Epoch 17: 98.89999389648438%
Epoch 18: 99.03333282470703%
Epoch 19: 99.15833282470703%
Epoch 20: 98.94999694824219%
Epoch 21: 98.95832824707031%
Epoch 22: 98.94166564941406%
Epoch 23: 98.9749984741211%
Epoch 24: 99.14999389648438%
Epoch 25: 99.03333282470703%
Epoch 26: 99.06666564941406%
Epoch 27: 99.14166259765625%
Epoch 28: 99.17500305175781%
Epoch 29: 99.25%
Epoch 30: 99.19166564941406%
Epoch 31: 99.01666259765625%
Epoch 32: 99.13333129882812%
Epoch 33: 99.16666412353516%
Epoch 34: 99.06666564941406%
Epoch 35: 99.2333297729492

In [None]:
# Final evaluation on the held-out test split.
# Targets come from `test_labels` (integer class indices). The one-hot `y_test`
# only worked here after a since-commented-out in-place argmax had been run once,
# so on a fresh kernel `preds == target` compared (batch,) against (batch, 10) and failed.
test_batches = DataLoader([*zip(test_images, test_labels)], batch_size=batch_size, shuffle=False)

cnn.eval()
cnn_fc.eval()
for small in small_subnetworks:
  small.eval()

correct = 0
with torch.no_grad():
  for features, target in test_batches:
    features = features.to(device)
    target = target.to(device)
    test_out = cnn(features)
    channels_per_head = test_out.shape[1] // len(small_subnetworks)
    # Each small head votes on its own channel slice; argmax over raw logits
    # gives the same vote as argmax over softmax probabilities.
    votes = torch.stack([
      head(test_out[:, i*channels_per_head:(i+1)*channels_per_head]).argmax(dim=1)
      for i, head in enumerate(small_subnetworks)
    ])
    # Majority vote across heads.
    preds = torch.mode(votes, dim=0).values
    correct += (preds == target).sum()
print(f"Test result: {(correct/test_labels.shape[0])*100}%")

Test result: 99.30999755859375%
