In [None]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.datasets import MNIST
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np 

In [None]:
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz

train_set = MNIST(
    './', 
    download=True,
    transform=transforms.ToTensor(), 
    train=True
)

test_set = MNIST(
    './', 
    download=True,
    transform=transforms.ToTensor(), 
    train=False
)

--2021-08-18 07:49:35--  http://www.di.ens.fr/~lelarge/MNIST.tar.gz
Resolving www.di.ens.fr (www.di.ens.fr)... 129.199.99.14
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.di.ens.fr/~lelarge/MNIST.tar.gz [following]
--2021-08-18 07:49:35--  https://www.di.ens.fr/~lelarge/MNIST.tar.gz
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/x-gzip]
Saving to: ‘MNIST.tar.gz’

MNIST.tar.gz            [            <=>     ]  33.20M  6.30MB/s    in 5.5s    

2021-08-18 07:49:42 (6.01 MB/s) - ‘MNIST.tar.gz’ saved [34813078]

MNIST/
MNIST/raw/
MNIST/raw/train-labels-idx1-ubyte
MNIST/raw/t10k-labels-idx1-ubyte.gz
MNIST/raw/t10k-labels-idx1-ubyte
MNIST/raw/t10k-images-idx3-ubyte.gz
MNIST/raw/train-images-idx3-ubyte
MNIST/raw/train-labels-idx1-ubyte.gz
MNIST/raw/t10k-images-idx3-ubyte
MNIST/raw/tra

In [None]:
# Use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Mini-batches of 128 images; only the training split is shuffled so the
# evaluation order stays fixed between runs.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

In [None]:
# A simple convolutional classifier
class ConvClassifier(nn.Module):
  """Two conv -> batchnorm -> ReLU -> 2x2 max-pool stages followed by a single
  linear layer that produces 10 logits.

  Args:
    in_channels: number of channels in the input images.
    out_channels: number of feature maps produced by each conv layer.
    input_shape: spatial size (height, width) of the inputs; only used to size
      the final linear layer. Defaults to (28, 28).
  """
  def __init__(self, in_channels, out_channels, input_shape=(28, 28)):
    super().__init__()
    # Kept as an ndarray (as before) for any caller that reads it; a tuple
    # default avoids sharing one mutable array across all instances.
    self.input_shape = np.asarray(input_shape)
    self.cnn = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )
    # The 3x3 convs (stride 1, padding 1) keep the spatial size; each 2x2,
    # stride-2 max-pool floors it by half. Computing each side as d // 2 // 2
    # (rather than prod(shape) // 16) stays correct when a side is not a
    # multiple of 4, e.g. 30 -> 15 -> 7.
    pooled_h, pooled_w = (int(d) // 2 // 2 for d in self.input_shape)
    self.linear = nn.Sequential(
        nn.Linear(out_channels * pooled_h * pooled_w, 10)
    )

  def forward(self, x):
    """Return logits of shape (batch, 10) for the image batch `x`."""
    x = self.cnn(x)
    x = x.view(x.size(0), -1)  # flatten everything except the batch dimension
    return self.linear(x)

In [None]:
# HYPERPARAMETERS
epochs = 5
lr = 0.001
epsilon = 0.1  # not used in this cell; presumably consumed by a later (adversarial) cell

# Standard (non-adversarial) training loop
def train(device, train_loader, model, num_epochs=None, learning_rate=None):
  """Train `model` in place with Adam and cross-entropy loss.

  Args:
    device: device each (inputs, targets) batch is moved to before the forward pass.
    train_loader: iterable of (inputs, targets) batches that supports len().
    model: module returning class logits; its parameters are updated in place.
    num_epochs: number of passes over `train_loader`; defaults to the global
      `epochs`, read at call time.
    learning_rate: Adam step size; defaults to the global `lr`, read at call time.
  """
  if num_epochs is None:
    num_epochs = epochs
  if learning_rate is None:
    learning_rate = lr

  loss_fn = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  # Make sure BatchNorm (and any Dropout) use training behaviour even if the
  # model was previously switched to eval mode.
  model.train()

  for epoch in range(num_epochs):
    for i, (x, y) in enumerate(train_loader):
      x, y = x.to(device), y.to(device)
      logits = model(x)
      loss = loss_fn(logits, y)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Log every 100th step (steps are reported 1-based).
      if i % 100 == 0:
        print('Epoch [{}/{}], Step [{}/{}], Loss: {}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))


In [None]:
# 1 input channel (grayscale MNIST), 4 feature maps per conv layer. Placed on
# the `device` chosen earlier so the notebook also runs on CPU-only machines.
model = ConvClassifier(1, 4).to(device)

In [None]:
train(device=device, train_loader=train_loader, model=model)  # standard (non-adversarial) training run

Epoch [1/5], Step [1/469], Loss: 2.479057788848877
Epoch [1/5], Step [101/469], Loss: 0.5927906632423401
Epoch [1/5], Step [201/469], Loss: 0.3305518925189972
Epoch [1/5], Step [301/469], Loss: 0.25751420855522156
Epoch [1/5], Step [401/469], Loss: 0.15906080603599548
Epoch [2/5], Step [1/469], Loss: 0.20467111468315125
Epoch [2/5], Step [101/469], Loss: 0.18018808960914612
Epoch [2/5], Step [201/469], Loss: 0.160672128200531
Epoch [2/5], Step [301/469], Loss: 0.1707994043827057
Epoch [2/5], Step [401/469], Loss: 0.13455817103385925
Epoch [3/5], Step [1/469], Loss: 0.07427234202623367
Epoch [3/5], Step [101/469], Loss: 0.17444077134132385
Epoch [3/5], Step [201/469], Loss: 0.1609479933977127
Epoch [3/5], Step [301/469], Loss: 0.12243125587701797
Epoch [3/5], Step [401/469], Loss: 0.0548778772354126
Epoch [4/5], Step [1/469], Loss: 0.1775771528482437
Epoch [4/5], Step [101/469], Loss: 0.05574285611510277
Epoch [4/5], Step [201/469], Loss: 0.09769746661186218
Epoch [4/5], Step [301/469],