In [1]:
import torch
import torchvision
from torch import nn


In [2]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,kernel_size)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU6(inplace=True)
        self.block = nn.Sequential(
            self.conv, self.bn, self.relu
        )
    def forward(self, x):
        out = self.block(x)
        return out
def testConvBlock():
  blk = ConvBlock(1, 32, 3)
  x = torch.zeros(1,1,28,28)
  out = blk(x)
  print(out.shape)
testConvBlock()

torch.Size([1, 32, 26, 26])


In [3]:
class MnistCNN(nn.Module):
    def __init__(self, channels, kernel_size, fc_in):
        super().__init__()
        layer = [
            ConvBlock(in_channels= 1 if i == 0 else channels[i - 1],
                      out_channels = c , kernel_size=kernel_size)
            for i , c in enumerate(channels)
        ]
        layer.append(nn.Flatten())
        layer.append(nn.Linear(fc_in, 10))
        layer.append(nn.LazyBatchNorm1d(10))
        self.net = nn.Sequential(*layer)
    def forward(self, x):
        out = self.net(x)
        return out

In [4]:
def testMnistCNN():
  channels = [16*(i+1) for i in range(10)]
  kernel_size = 3
  fc_in = 10240 # in paper M3 is 11264, M5 is 10240 (detect paper error)
  net = MnistCNN(channels, kernel_size, fc_in)
  x = torch.zeros(10,1,28,28)
  out = net(x)
  print(out.shape)

testMnistCNN()

torch.Size([10, 10])




In [5]:
def get_datasets():
    from torchvision import transforms as T
    train_transform = T.Compose([
        T.RandomAffine(20, translate=(0.2, 0.2), scale=(0.9, 1.1), shear=10),
        T.ToTensor()
    ])
    test_transform = T.Compose([
        T.ToTensor()
    ])
    train_ds = torchvision.datasets.MNIST("./data", train=True, download= True, transform= train_transform)
    test_ds = torchvision.datasets.MNIST("./data", train=False, download=True, transform=test_transform)
    return dict(train = train_ds, test = test_ds)
ds = get_datasets()

In [6]:
def get_dataloaders(datasets):
    train_loader = torch.utils.data.DataLoader(
        dataset=datasets["train"], batch_size=64, shuffle=True, num_workers=2 , drop_last=True
    )
    test_loader = torch.utils.data.DataLoader(
        dataset = datasets["test"], batch_size=64, shuffle=False, num_workers=2, drop_last=False
    )
    return dict(train=train_loader, test=test_loader)
loaders = get_dataloaders(ds)

In [7]:
for x, y in loaders['train']:
  print(x.shape, y.shape)
  break

torch.Size([64, 1, 28, 28]) torch.Size([64])


In [8]:
from tqdm import tqdm

In [10]:
## training & evaluation code
def train_model_and_evaluate(n_epoch=1):
  device = "cuda"

  channels = [16*(i+1) for i in range(10)]
  kernel_size = 3
  fc_in = 10240 # in paper M3 is 11264, M5 is 10240 (detect paper error)
  net = MnistCNN(channels, kernel_size, fc_in)
  net = net.to(device)

  learning_rate = 1e-3
  weight_decay = 1e-4
  loss_func = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(
      net.parameters(), lr=learning_rate, weight_decay=weight_decay)
  lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

  def train_epoch():
    net.train()
    total = 0
    total_correct = 0
    for step, (images, labels) in enumerate(loaders['train']):
      images = images.to(device)
      labels = labels.to(device)
      out = net(images)
      loss = loss_func(out, labels)

      ypred = torch.argmax(out, dim=1)
      batch_correct = torch.sum(ypred==labels)
      total += len(labels)
      total_correct += batch_correct.item()

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      pbar.update()
      pbar.set_description(f'{epoch}-train-{step} loss {loss.item():.4f} acc train {total_correct/total:.4f} valid {acc:.4f}')
    return total_correct/total

  def valid_epoch():
    with torch.no_grad():
      net.eval()
      total = 0
      total_correct = 0
      for step, (images, labels) in enumerate(loaders['test']):
        images = images.to(device)
        labels = labels.to(device)
        out = net(images)
        ypred = torch.argmax(out, dim=1)
        batch_correct = torch.sum(ypred==labels)
        total += len(labels)
        total_correct += batch_correct.item()
      
        pbar.update()
        pbar.set_description(f'{epoch}-test-{step} acc train {train_acc:.4f} valid {total_correct/total:.4f}')
    return total_correct/total

  total = n_epoch*(len(loaders['train'])+len(loaders['test']))
  acc = 0
  with tqdm(total=total, position=0, leave=True) as pbar:
    for epoch in range(n_epoch):
      train_acc = train_epoch()
      acc = valid_epoch()

      lr_scheduler.step()
  print()

  return acc

accuracy = train_model_and_evaluate(n_epoch=10)
print(f"{accuracy*100:.2f}")

9-test-156 acc train 0.9775 valid 0.9878: 100%|██████████| 10940/10940 [04:45<00:00, 38.30it/s]             


98.78



