In [12]:
import torch 
import torch.nn as nn 
from tqdm import tqdm 
import torchvision 

from torch.utils.data import DataLoader, SubsetRandomSampler

In [29]:
# Adapted from https://github.com/rasbt/stat453-deep-learning-ss21/blob/2202699c5fd38af398e2682f289a0868b1b91f0e/L13/code/helper_evaluation.py

def compute_history(model, data_loader, device, loss):
    """Evaluate ``model`` over every batch of ``data_loader`` with autograd disabled.

    Args:
        model: callable mapping a batch of features to per-class logits of
            shape (batch, num_classes).
        data_loader: iterable of ``(features, targets)`` batches, where
            ``targets`` holds integer class indices.
        device: device each batch is moved to before the forward pass.
        loss: criterion called as ``loss(logits, targets)``. Assumed to use mean
            reduction (the default for ``torch.nn.functional.cross_entropy``);
            each batch value is multiplied back by the batch size so that the
            final average is per example, even when the last batch is short.

    Returns:
        ``(accuracy, mean_loss)``: accuracy in percent as a 0-dim float tensor
        (callers use ``.item()`` on it), and the mean per-example loss as a
        Python float.

    Raises:
        ValueError: if ``data_loader`` yields no examples.
    """
    correct_pred, num_examples, loss_sum = 0, 0, 0.0

    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)

            logits = model(features)
            n_batch = targets.size(0)

            # Undo the mean reduction so batches of different sizes are
            # weighted by how many examples they contain.
            loss_sum += loss(logits, targets).item() * n_batch

            _, predicted_labels = torch.max(logits, dim=1)
            # Labels are integer class indices; compare them without a float cast.
            correct_pred += (predicted_labels == targets).sum().item()
            num_examples += n_batch

    if num_examples == 0:
        raise ValueError("data_loader yielded no examples")

    accuracy = torch.tensor(correct_pred / num_examples * 100)
    return accuracy, loss_sum / num_examples

## Definição de Hiperparâmetros

In [30]:
epochs = 15
batch_size = 256
val_split = 0.2  # fraction of the official training split held out for validation
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG so weight initialisation and the random samplers'
# shuffling repeat across Restart & Run All. (Some CUDA kernels may still be
# non-deterministic; this only fixes the RNG state.)
random_seed = 123
torch.manual_seed(random_seed)

## Dataset MNIST

In [31]:
# Preprocessing shared by every split:
#   1. resize to 32x32, the spatial size LeNet5's 16*5*5 classifier input is built for;
#   2. convert the image to a tensor;
#   3. normalise each pixel as (x - 0.5) / 0.5.
_mnist_steps = [
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
]
resize_transform = torchvision.transforms.Compose(_mnist_steps)


In [32]:
# All three datasets share the root directory and the preprocessing pipeline.
# Training and validation both wrap the official training split (train=True);
# the index samplers built in the next cell decide which images each one sees.
# Only the first constructor downloads; torchvision fetches the train and test
# files together, so the later constructors read what is already on disk.
_mnist_kwargs = dict(root='data', transform=resize_transform)

train_dataset = torchvision.datasets.MNIST(train=True, download=True, **_mnist_kwargs)
valid_dataset = torchvision.datasets.MNIST(train=True, **_mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **_mnist_kwargs)

In [33]:
# Hold out the last `val_split` fraction of the official training split for
# validation. train_dataset and valid_dataset wrap the same images
# (both train=True), so the two index ranges must not overlap.
n_total = len(train_dataset)
num = int(val_split * n_total)
train_indices = torch.arange(0, n_total - num)
valid_indices = torch.arange(n_total - num, n_total)
assert len(valid_dataset) == n_total, "train/valid datasets must wrap the same split"

# Each sampler draws a fresh random permutation of its indices every epoch.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)

# drop_last=True keeps every optimisation step at a full `batch_size`. Because
# the sampler reshuffles, a different remainder is dropped each epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, drop_last=True, sampler=train_sampler)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, sampler=valid_sampler)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


## Criação da LeNet5

In [34]:
class LeNet5(torch.nn.Module):
    """LeNet-5 style classifier with tanh activations.

    Two conv(5x5) -> tanh -> max-pool(2) stages feed a three-layer tanh MLP.
    The flattened size 16*5*5 assumes 32x32 inputs:
    32 -conv-> 28 -pool-> 14 -conv-> 10 -pool-> 5.

    Args:
        num_classes: number of output logits.
        n_channels: number of input image channels.
    """

    def __init__(self, num_classes = 10, n_channels=1):
        super().__init__()

        # Layer order (and therefore state_dict keys such as conv.0 / conv.3)
        # matches the classic layout; weights are created in this order.
        feature_layers = [
            nn.Conv2d(n_channels, 6, kernel_size=5), nn.Tanh(), nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(6, 16, kernel_size=5), nn.Tanh(), nn.MaxPool2d(kernel_size=2),
        ]
        self.conv = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Linear(16 * 5 * 5, 120), nn.Tanh(),
            nn.Linear(120, 84), nn.Tanh(),
            nn.Linear(84, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a (batch, n_channels, 32, 32) image batch to (batch, num_classes) logits."""
        feature_maps = self.conv(x)
        return self.classifier(feature_maps.flatten(start_dim=1))

## Treinamento

In [35]:
# Build the network directly on the selected device and optimise all of its
# parameters with plain SGD (no momentum, no weight decay) at lr=0.1.
model = LeNet5(num_classes=10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

In [37]:
criterion = torch.nn.functional.cross_entropy

# History buffers: one entry per optimisation step for the minibatch loss,
# one entry per epoch for each evaluation metric.
minibatch_loss_list = []
train_acc_list, valid_acc_list = [], []
train_loss_list, valid_loss_list = [], []

for epoch in range(epochs):
    # Optimisation pass over the shuffled training subset.
    model.train()
    for X_train, y_train in tqdm(train_loader):
        X_train, y_train = X_train.to(device), y_train.to(device)

        optimizer.zero_grad()
        output = model(X_train)
        loss = criterion(output, y_train)
        loss.backward()
        optimizer.step()

        minibatch_loss_list.append(loss.item())

    # Evaluation pass (validation). compute_history runs under torch.no_grad()
    # itself, so no extra context manager is needed here.
    model.eval()
    train_acc, train_loss = compute_history(model, train_loader, device, criterion)
    valid_acc, valid_loss = compute_history(model, valid_loader, device, criterion)

    train_acc_list.append(train_acc.item())
    valid_acc_list.append(valid_acc.item())
    train_loss_list.append(train_loss)
    valid_loss_list.append(valid_loss)


100%|██████████| 187/187 [00:10<00:00, 18.49it/s]
100%|██████████| 187/187 [00:09<00:00, 18.87it/s]
100%|██████████| 187/187 [00:09<00:00, 18.73it/s]
100%|██████████| 187/187 [00:09<00:00, 18.87it/s]
100%|██████████| 187/187 [00:09<00:00, 19.44it/s]
100%|██████████| 187/187 [00:09<00:00, 19.20it/s]
100%|██████████| 187/187 [00:09<00:00, 19.06it/s]
100%|██████████| 187/187 [00:09<00:00, 19.15it/s]
100%|██████████| 187/187 [00:09<00:00, 19.36it/s]
100%|██████████| 187/187 [00:09<00:00, 19.23it/s]
100%|██████████| 187/187 [00:09<00:00, 19.34it/s]
100%|██████████| 187/187 [00:10<00:00, 18.59it/s]
100%|██████████| 187/187 [00:09<00:00, 19.18it/s]
100%|██████████| 187/187 [00:09<00:00, 19.27it/s]
100%|██████████| 187/187 [00:09<00:00, 19.49it/s]
