In [None]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

In [None]:
# Quick sanity check that a CUDA device is reachable: send a small random tensor to the GPU.
torch.randn(5).to('cuda')

tensor([-0.6028,  0.6574, -0.4390, -0.9853, -0.9587], device='cuda:0')

In [None]:
# Train Val split
train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
# Hold out n_val images for validation. A dedicated, seeded generator makes the split
# reproducible across kernel restarts without touching the global RNG.
n_val = 5000
train, val = random_split(train_data, [len(train_data) - n_val, n_val],
                          generator=torch.Generator().manual_seed(0))
# Reshuffle the training set every epoch; validation order does not affect the metrics.
train_loader = DataLoader(train, batch_size=32, shuffle=True)
val_loader = DataLoader(val, batch_size=32)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [None]:
# Define my model
# Baseline: a Linear/ReLU stack mapping a flattened 28*28 image to 10 logits.
# NOTE: this `model` is replaced by the ResNet model instantiated later in the notebook,
# and it is never moved to the GPU itself.
model = nn.Sequential(
    nn.Linear(28 * 28, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 10),
)

In [None]:
# define a more flexible model
class ResNet(nn.Module):
  """MLP classifier with one residual (skip) connection around its second hidden layer.

  Computes h1 = relu(l1(x)), h2 = relu(l2(h1)), logits = l3(dropout(h1 + h2)).
  The defaults reproduce the original MNIST setup (784 -> 64 -> 64 -> 10, p=0.1).

  Args:
    in_features: size of the last dimension of the input (28 * 28 for flattened MNIST).
    hidden: width of both hidden layers; they must match for the h1 + h2 skip connection.
    num_classes: number of output logits.
    dropout: dropout probability applied to the summed hidden activations.
  """

  def __init__(self, in_features=28 * 28, hidden=64, num_classes=10, dropout=0.1):
    super().__init__()
    self.l1 = nn.Linear(in_features, hidden)
    self.l2 = nn.Linear(hidden, hidden)
    self.l3 = nn.Linear(hidden, num_classes)
    self.do = nn.Dropout(dropout)

  def forward(self, x):
    """Return unnormalised class scores with last dimension num_classes.

    x: tensor whose last dimension is in_features (the training loop flattens
    each image batch to (batch, 784) before calling the model).
    """
    h1 = nn.functional.relu(self.l1(x))
    h2 = nn.functional.relu(self.l2(h1))
    # Skip connection: l2 refines h1 rather than replacing it.
    do = self.do(h2 + h1)
    return self.l3(do)

# Instantiate the residual MLP and move its parameters to the GPU; this replaces the Sequential `model`.
model = ResNet().to('cuda')





In [None]:
# Define my optimizer
# Vanilla SGD (only the learning rate is set) over all parameters of the `model` that exists
# when this cell runs -- re-run this cell whenever `model` is re-created.
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

In [None]:
# Define my loss
# Cross-entropy on raw logits vs. integer class labels, averaged over the batch.
loss = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# My training and validation loops
def _run_epoch(loader, training):
  """Make one pass over `loader` and return (mean loss, accuracy) over its samples.

  training=True: puts `model` in train mode and takes one `optimizer` step per batch.
  training=False: puts `model` in eval mode and disables autograd (validation).
  Uses the notebook-level `model`, `loss` (the criterion) and `optimizer`.
  Both metrics are per-sample averages; returns (nan, nan) for an empty loader.
  """
  model.train(training)  # model.train(False) is equivalent to model.eval()
  total_loss, total_correct, total_seen = 0.0, 0, 0
  for x, y in loader:
    b = x.size(0)
    # Flatten each (1, 28, 28) image to a 784-vector for the first Linear layer.
    x, y = x.view(b, -1).cuda(), y.cuda()

    # Forward pass; only record the autograd graph when we will backpropagate.
    with torch.set_grad_enabled(training):
      logits = model(x)
      J = loss(logits, y)

    if training:
      optimizer.zero_grad()  # clear gradients left from the previous step
      J.backward()           # accumulate dJ/dparams into each parameter's .grad
      optimizer.step()       # step the parameters against the gradient

    # Weight by batch size: the final batch can be smaller than batch_size, and a plain
    # mean of per-batch means would over-count it.
    total_loss += J.item() * b
    total_correct += logits.argmax(dim=1).eq(y).sum().item()
    total_seen += b

  if total_seen == 0:
    return float('nan'), float('nan')
  return total_loss / total_seen, total_correct / total_seen


nb_epochs = 5
for epoch in range(nb_epochs):
  train_loss, train_acc = _run_epoch(train_loader, training=True)
  print(f'Epoch {epoch+1}', end=', ')
  print(f'train loss: {train_loss:.2f}', end=', ')
  print(f'train accuracy: {train_acc:.2f}')

  val_loss, val_acc = _run_epoch(val_loader, training=False)
  print(f'Epoch {epoch+1}', end=', ')
  print(f'validation loss: {val_loss:.2f}', end=', ')
  print(f'validation accuracy: {val_acc:.2f}')