In [27]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

from tqdm.notebook import tqdm

In [20]:
# model
class Model(nn.Module):
    """Fully-connected classifier: in_features -> hidden -> hidden -> n_classes.

    The layer stack lives in ``self.model`` with the same layer order as before,
    so ``state_dict`` keys (``model.0.weight`` ... ``model.4.bias``) are unchanged.

    Args:
        in_features: size of one flattened input sample (28*28 for MNIST images).
        hidden: width of both hidden layers.
        n_classes: number of output logits.
    """

    def __init__(self, in_features: int = 28 * 28, hidden: int = 64, n_classes: int = 10):
        super().__init__()
        self.in_features = in_features
        self.model = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for ``x``.

        ``x`` may already be flattened (last dim == ``in_features``) or be
        image-shaped, e.g. ``(batch, 1, 28, 28)``; the latter is flattened from
        dim 1 onward before the linear stack.
        """
        # Only reshape inputs the linear stack could not accept as-is, so every
        # input that worked before is passed through untouched.
        if x.dim() > 2 and x.shape[-1] != self.in_features:
            x = x.flatten(1)
        return self.model(x)


# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model().to(device)

In [21]:
# Optimizer: Adam over every trainable weight of the model; lr=1e-3 is Adam's
# default, written out so the value is visible to readers.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Objective: cross-entropy computed directly on the model's raw logits.
loss_fn = nn.CrossEntropyLoss()

In [22]:
# train, val
batch_size = 32
n_val = 5000    # samples held out for validation
split_seed = 0  # fixed so the train/val split is identical across kernel restarts

train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
# Derive the train size from the dataset instead of hard-coding 55000.
train, val = random_split(
    train_data,
    [len(train_data) - n_val, n_val],
    generator=torch.Generator().manual_seed(split_seed),
)
# Reshuffle training batches every epoch; validation order does not affect the metric.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val, batch_size=batch_size)

In [31]:
# training loop
n_epochs = 5


def run_epoch(loader: DataLoader, training: bool) -> float:
    """Make one pass over ``loader`` and return the sample-weighted mean loss.

    When ``training`` is True the model is put in train mode and ``optimizer``
    steps after every batch; otherwise the model is put in eval mode and
    gradients are disabled. Returns NaN for an empty loader.
    """
    # Send batches to whichever device the model lives on (CPU or CUDA).
    device = next(model.parameters()).device
    model.train(training)
    total_loss, n_seen = 0.0, 0
    with torch.set_grad_enabled(training):
        for x, y in loader:
            n = y.size(0)
            # Flatten each image to a vector, keeping the batch dimension.
            x = x.view(n, -1).to(device)
            y = y.to(device)
            loss = loss_fn(model(x), y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # loss_fn averages over the batch; weight by batch size so a short
            # final batch does not skew the epoch mean.
            total_loss += loss.item() * n
            n_seen += n
    return total_loss / n_seen if n_seen else float('nan')


for epoch in tqdm(range(n_epochs)):
    train_loss = run_epoch(train_loader, training=True)
    print(f'Epoch : {epoch+1}, train_loss : {train_loss:.3f}')

    val_loss = run_epoch(val_loader, training=False)
    print(f'Epoch : {epoch+1}, val_loss : {val_loss:.3f}')

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

Epoch : 1, train_loss : 0.107
Epoch : 1, val_loss : 0.113
Epoch : 2, train_loss : 0.082
Epoch : 2, val_loss : 0.107
Epoch : 3, train_loss : 0.065
Epoch : 3, val_loss : 0.109
Epoch : 4, train_loss : 0.052
Epoch : 4, val_loss : 0.117
Epoch : 5, train_loss : 0.043
Epoch : 5, val_loss : 0.122

