In [None]:
%matplotlib inline
import torch
from torch import nn
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from torchvision import transforms
from torchvision import datasets

### O código da célula abaixo contém funções para efetuar a carga dos dados e o treinamento e teste dos modelos

In [None]:
def get_loaders(
        batch_size,
        root='../data',
        download=True,
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2470, 0.2435, 0.2616),
    ):
    """Build the CIFAR-10 train and test DataLoaders.

    Args:
        batch_size: number of images per batch, used for both loaders.
        root: directory where torchvision stores / looks for CIFAR-10.
        download: if True, fetch the dataset into ``root`` when it is missing.
        mean, std: per-channel (RGB) normalization statistics. The defaults
            are the CIFAR-10 training-set statistics; when fine-tuning a
            pretrained network you may prefer the statistics it was trained
            with.

    Returns:
        ``(train_loader, test_loader)``. The training loader reshuffles every
        epoch; the test loader iterates in a fixed order.
    """
    # Normalize standardizes each RGB channel separately. The values
    # (0.1307, 0.3081) are MNIST's grayscale statistics and do not fit
    # 3-channel CIFAR-10 images.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    def _make_loader(train):
        # Builds the loader for one split; only the training split is shuffled,
        # since evaluation metrics do not depend on the iteration order.
        dataset = datasets.CIFAR10(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=train,
        )

    return _make_loader(True), _make_loader(False)

from collections import defaultdict

# Module-level training log shared by every call to train_epoch() / train():
# 'loss' receives one value per training batch and 'val_acc' one test
# accuracy per epoch. NOTE(review): it is never reset, so training a second
# model (or re-running a training cell) appends to the same lists.
history = defaultdict(list)

def train_epoch(
        model,
        device,
        train_loader,
        optimizer,
        criterion,
        epoch,
        log_interval
    ):
    """Run a single training epoch of ``model`` over ``train_loader``.

    For every batch: move it to ``device``, evaluate ``criterion`` on the model
    output, back-propagate and take one ``optimizer`` step. Each batch loss is
    appended to the module-level ``history['loss']`` list, and a progress line
    is printed on every ``log_interval``-th batch (starting with batch 0).
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        batch_loss.backward()
        loss_value = batch_loss.item()
        history['loss'].append(loss_value)
        optimizer.step()

        if step % log_interval != 0:
            continue
        seen = step * len(inputs)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), percent, loss_value))


def test(
        model, 
        device, 
        criterion, 
        test_loader
    ):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))
    return accuracy


def train(
        model,
        train_loader,
        test_loader,
        device,
        lr,
        nb_epochs=3,
        log_interval=100,
    ):
    """Train ``model`` with Adam and cross-entropy, evaluating after every epoch.

    Per-batch training losses (via ``train_epoch``) and per-epoch test
    accuracies are appended to the module-level ``history`` under the keys
    ``'loss'`` and ``'val_acc'``.

    Args:
        model: network to train (already on ``device``).
        train_loader / test_loader: training and evaluation DataLoaders.
        device: device batches are moved to.
        lr: Adam learning rate.
        nb_epochs: number of epochs to run; must be >= 1.
        log_interval: print training progress every this many batches.

    Returns:
        Test accuracy (percent) measured after the final epoch.

    Raises:
        ValueError: if ``nb_epochs`` < 1. No epoch would run, so there would be
            no accuracy to return.
    """
    if nb_epochs < 1:
        raise ValueError('nb_epochs must be at least 1, got {}'.format(nb_epochs))

    # Only hand trainable parameters to the optimizer, so that layers frozen
    # with requires_grad=False (e.g. a pretrained backbone) are skipped.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable, lr=lr)
    criterion = nn.CrossEntropyLoss().to(device)

    for epoch in range(1, nb_epochs + 1):
        print('\n* * * Training * * *')
        train_epoch(
            model=model,
            device=device,
            train_loader=train_loader,
            optimizer=optimizer,
            criterion=criterion,
            epoch=epoch,
            log_interval=log_interval
        )
        print('\n* * * Evaluating * * *')
        acc = test(model, device, criterion, test_loader)
        history['val_acc'].append(acc)

    return acc


### Hiperparâmetros que você pode definir

In [None]:
batch_size = 16
# Use the GPU when one is available; training (especially SqueezeNet) on CPU is slow.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
nb_epochs = 3
log_interval = 50
lr = 1e-3

# Fix the RNG so data shuffling and weight initialisation are reproducible
# under "Restart & Run All".
seed = 42
torch.manual_seed(seed)

In [None]:
# Device that the model and every batch are moved to during training and evaluation.
device = torch.device(device_name)

### Conferência dos dados

In [None]:
# Build the CIFAR-10 train/test DataLoaders (the dataset is downloaded into ../data if needed).
train_loader, test_loader = get_loaders(batch_size=batch_size)

In [None]:
# torchvision's CIFAR10 exposes each split's raw images as `.data` and its
# labels as `.targets`. The old train_data/train_labels/test_data/test_labels
# attributes were removed from current torchvision releases.
print(
    'Train size: ',
    train_loader.dataset.data.shape,
    len(train_loader.dataset.targets)
)
print(
    'Test size : ',
    test_loader.dataset.data.shape,
    len(test_loader.dataset.targets)
)

In [None]:
# Human-readable CIFAR-10 label names; the integer label is used as the index.
class_list = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

In [None]:
# Show the first five training images with their label id and name.
# Raw images/labels come from `.data` / `.targets`; the old
# `train_data` / `train_labels` attributes no longer exist in torchvision.
fig, axs = plt.subplots(1, 5)
for i, ax in enumerate(axs):
    label_id = train_loader.dataset.targets[i]
    ax.imshow(train_loader.dataset.data[i])
    ax.set_title('{}-{}'.format(label_id, class_list[label_id]))
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

In [None]:
# Pull a single (images, labels) batch from the training loader to inspect the
# tensor shapes produced after the transform and batching.
instance = next(iter(train_loader))
print('Instance Example: ', instance[0].shape, instance[1].shape)

## Seu trabalho começa aqui:

## 1. Implemente uma rede convolucional para classificar imagens do CIFAR10

#### Arquitetura:
* Input: (3, 32, 32)
* Conv(32, 3)
* MaxPool(2)
* Conv(64, 3)
* MaxPool(2)
* Flatten 
* Linear(10)
    

In [None]:
class ConvNet(nn.Module):
    """Small CNN for CIFAR-10 following the exercise specification.

    Conv(32, 3x3) -> ReLU -> MaxPool(2) -> Conv(64, 3x3) -> ReLU -> MaxPool(2)
    -> Flatten -> Linear(num_classes). The specification lists only the
    conv/pool/linear layers; a ReLU is added after each convolution as the
    usual non-linearity.

    Args:
        in_channels: channels of the input images (3 for RGB CIFAR-10).
        num_classes: number of output scores (10 for CIFAR-10).
        input_size: height/width of the square input (32 for CIFAR-10).

    ``forward`` maps ``(batch, in_channels, input_size, input_size)`` to
    ``(batch, num_classes)`` logits.
    """
    def __init__(self, in_channels=3, num_classes=10, input_size=32):
        super(ConvNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        # Each unpadded 3x3 conv removes 2 pixels, and each pool halves (floor):
        # 32 -> 30 -> 15 -> 13 -> 6.
        spatial = ((input_size - 2) // 2 - 2) // 2
        if spatial < 1:
            raise ValueError('input_size={} is too small for this network'.format(input_size))
        self.classifier = nn.Linear(64 * spatial * spatial, num_classes)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)  # flatten everything but the batch dim
        out = self.classifier(out)
        return out

### 1.1 Verifique se a saída do seu modelo está correta

In [None]:
model = ConvNet().to(device)

# Sanity check (section 1.1): one batch of CIFAR-10 images must yield one
# score per class for every image.
sample_images = instance[0].to(device)
with torch.no_grad():
    sample_logits = model(sample_images)
assert sample_logits.shape == (sample_images.size(0), len(class_list)), sample_logits.shape
print('Output shape:', tuple(sample_logits.shape))

### 1.2 Treine seu modelo por uma ou mais épocas. 

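Você deve conseguir uma acurácia bem abaixo de 99% na terceira época: o valor de ~99% vale para o MNIST. Para esta rede pequena no CIFAR-10, algo em torno de 60–70% é o esperado.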

In [None]:
# Train with the epoch count and logging interval set in the hyper-parameter cell.
acc = train(
    model, train_loader, test_loader, device, lr,
    nb_epochs=nb_epochs, log_interval=log_interval,
)
print('Final acc: {:.2f}%'.format(acc))

In [None]:
# plt.plot() has no `title` keyword (extra kwargs go to Line2D and raise), so
# the title and axis labels are set on an explicit Axes instead.
fig, ax = plt.subplots()
ax.plot(range(1, len(history['val_acc']) + 1), history['val_acc'], marker='o')
ax.set(title='Validation Acc', xlabel='Epoch', ylabel='Accuracy (%)')
plt.show()

In [None]:
# plt.plot() has no `title` keyword (extra kwargs go to Line2D and raise), so
# the title and axis labels are set on an explicit Axes instead.
fig, ax = plt.subplots()
ax.plot(history['loss'])
ax.set(title='Training loss', xlabel='Batch', ylabel='Cross-entropy loss')
plt.show()

## 2. Faça o fine-tuning de uma SqueezeNet no CIFAR-10


In [None]:
# Load SqueezeNet 1.0 with pretrained weights (downloaded on first use).
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; confirm against the installed version.
net = torchvision.models.squeezenet1_0(pretrained=True)

In [None]:
class SqueezeCifar(nn.Module):
    """SqueezeNet backbone whose final classifier conv is resized for CIFAR-10.

    Args:
        backbone: a SqueezeNet-style module exposing ``features`` and a
            ``classifier`` ``nn.Sequential`` whose index 1 is the final 1x1
            ``nn.Conv2d`` (the layout of torchvision's SqueezeNet). When None,
            a pretrained ``torchvision.models.squeezenet1_0`` is loaded.
        num_classes: number of output classes (10 for CIFAR-10).
        freeze_features: if True, the backbone's ``features`` get
            ``requires_grad=False`` so only the classifier head is trained.

    ``forward`` maps a batch of images to ``(batch, num_classes)`` scores.
    NOTE(review): on 32x32 inputs SqueezeNet's feature map becomes very small;
    resizing the images (e.g. ``transforms.Resize``) may be worth trying.
    """
    def __init__(self, backbone=None, num_classes=10, freeze_features=False):
        super().__init__()
        if backbone is None:
            backbone = torchvision.models.squeezenet1_0(pretrained=True)
        self.features = backbone.features
        self.classifier = backbone.classifier
        # Replace the pretrained final 1x1 conv with a fresh num_classes-way one,
        # keeping the remaining layers of the head untouched.
        old_head = self.classifier[1]
        self.classifier[1] = nn.Conv2d(old_head.in_channels, num_classes, kernel_size=1)
        if freeze_features:
            for param in self.features.parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        # Assumes the head ends with a (1, 1) global pooling, as in torchvision's
        # SqueezeNet; flattening then yields (batch, num_classes).
        return torch.flatten(x, 1)

### 2.1 Verifique se a saída do seu modelo está correta

In [None]:
model = SqueezeCifar().to(device)

# Sanity check (section 2.1): one batch of CIFAR-10 images must yield one
# score per class for every image.
sample_images = instance[0].to(device)
with torch.no_grad():
    sample_logits = model(sample_images)
assert sample_logits.shape == (sample_images.size(0), len(class_list)), sample_logits.shape
print('Output shape:', tuple(sample_logits.shape))


### 2.2 Treine seu modelo por uma ou mais épocas.


In [None]:
# Fine-tune with a learning rate ten times smaller than the one used to train
# ConvNet from scratch; a smaller step is the usual choice for pretrained weights.
acc = train(
    model,
    train_loader,
    test_loader,
    device,
    lr=lr / 10,
    nb_epochs=nb_epochs,
    log_interval=log_interval,
)
print('Final acc: {:.2f}%'.format(acc))

### 2.3 Tente descobrir hiperparâmetros mais adequados para esse treino (ex.: outras taxas de aprendizado, adicionar mais camadas no final, congelar os pesos da rede pré-treinada)


## 3. Escolha outra arquitetura pré-treinada e faça o fine-tuning dela no CIFAR-10

## 4. Treine a rede escolhida diretamente no CIFAR (sem utilizar pesos pré-treinados, i.e., pretrained=False) e veja a diferença de resultado
