In [None]:
# Render matplotlib figures inline, directly in the notebook output.
%matplotlib inline

### Atenção: Rode a célula abaixo apenas se estiver usando o Google Colab

In [None]:
# http://pytorch.org/
# Colab-only setup: installs the PyTorch 0.4.1 wheel that matches this
# interpreter and the available accelerator, plus torchvision.
from os.path import exists
# NOTE(review): `wheel.pep425tags` appears to be unavailable in recent `wheel`
# releases -- confirm this import still works before relying on this cell.
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
# Wheel tag assembled from the interpreter implementation, its version and its ABI tag.
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
# Rewrites the trailing ".X.Y" of the cudart library listed by ldconfig into a "cuXY" tag.
cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
# Use the CUDA tag only when an NVIDIA device node exists; otherwise install the CPU build.
accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

# The wheel URL is built from the accelerator and platform tags computed above.
!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision
# Import the package that was just installed.
import torch

In [None]:
import torch
from torch import nn
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from torchvision import transforms
from torchvision import datasets

### O código da célula abaixo contém funções para efetuar a carga dos dados, o treinamento e o teste dos modelos

In [None]:
def get_loaders(batch_size):
    """Build the MNIST training and test data loaders.

    Both splits are stored under ``../data`` (and downloaded there when
    missing). Every image is converted with ``ToTensor`` and then normalized
    with mean 0.1307 and standard deviation 0.3081.

    Args:
        batch_size: Number of samples per batch, used for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    def make_loader(train):
        """Create the loader for the training (``True``) or test (``False``) split."""
        return torch.utils.data.DataLoader(
            dataset=datasets.MNIST(
                root='../data',
                train=train,
                download=True,
                transform=transform,
            ),
            batch_size=batch_size,
            # Only the training split is shuffled: sample order has no effect
            # on the evaluation metrics, so the test split keeps a fixed,
            # reproducible order.
            shuffle=train,
        )

    return make_loader(train=True), make_loader(train=False)

def train_epoch(
        model,
        device,
        train_loader,
        optimizer,
        criterion,
        epoch,
        log_interval
    ):
    """Train ``model`` for one full pass over ``train_loader``.

    Args:
        model: Module to optimize; it is switched to training mode.
        device: Device every batch is moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` (only used for the progress message).
        optimizer: Optimizer that updates the parameters of ``model``.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        epoch: Epoch number, only used in the progress message.
        log_interval: Print the progress message every ``log_interval``
            batches. ``0`` or ``None`` disables the messages.

    Returns:
        list of float: the loss of every batch, in iteration order.
    """
    model.train()
    history = []
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Keep the scalar loss so callers can inspect or plot the training curve.
        batch_loss = loss.item()
        history.append(batch_loss)

        # A falsy interval disables logging instead of dividing by zero.
        if log_interval and batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), batch_loss))

    return history


def test(
        model, 
        device, 
        criterion, 
        test_loader
    ):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))
    return accuracy


def train(
        model,
        train_loader,
        test_loader,
        device,
        lr,
        nb_epochs=3,
        log_interval=100,
    ):
    """Train ``model`` with Adam and cross-entropy, evaluating after each epoch.

    Args:
        model: Module to train.
        train_loader: Loader passed to ``train_epoch`` on every epoch.
        test_loader: Loader passed to ``test`` after every epoch.
        device: Device used for the loss module and for every batch.
        lr: Learning rate given to the Adam optimizer.
        nb_epochs: Number of training epochs.
        log_interval: Forwarded to ``train_epoch`` to control how often the
            training loss is printed.

    Returns:
        The accuracy returned by ``test`` after the last epoch. When
        ``nb_epochs`` is smaller than 1 no training happens and the accuracy
        of the model as given is returned.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss().to(device)

    acc = None
    for epoch in range(1, nb_epochs + 1):
        print('\n* * * Training * * *')
        train_epoch(
            model=model,
            device=device,
            train_loader=train_loader,
            optimizer=optimizer,
            criterion=criterion,
            epoch=epoch,
            log_interval=log_interval
        )
        print('\n* * * Evaluating * * *')
        acc = test(model, device, criterion, test_loader)

    if acc is None:
        # No epoch ran (nb_epochs < 1). Evaluate once so a valid accuracy is
        # still returned instead of failing on an unbound variable.
        print('\n* * * Evaluating * * *')
        acc = test(model, device, criterion, test_loader)

    return acc

def check_input(model, device):
    """Run ``model`` on a dummy batch and verify the shape of its output.

    A zero tensor of shape ``(5, 1, 28, 28)`` is moved to ``device`` and fed
    to ``model``; the result must have shape ``(5, 10)``.

    Args:
        model: Callable applied to the dummy batch.
        device: Device the dummy batch is moved to.

    Returns:
        The output produced by ``model`` for the dummy batch.

    Raises:
        AssertionError: If the output shape is not ``(5, 10)``.
    """
    expected_shape = (5, 10)
    dummy_pred = model(torch.zeros(5, 1, 28, 28).to(device))
    found_shape = dummy_pred.shape
    assert found_shape == expected_shape, (
        '\nOutput expected: (batch_size, 10) \nOutput found   : {}'.format(found_shape)
    )
    print('Passed')
    return dummy_pred

### Hyper-parâmetros que você pode definir

In [None]:
# Number of samples per batch, used for both the training and the test loader.
batch_size = 16
# Name handed to torch.device() in the next cell.
device_name = 'cpu'
# Number of training epochs passed to train().
nb_epochs = 3
# The training loss is printed every `log_interval` batches.
log_interval = 500
# Learning rate passed to train(), which uses it for the Adam optimizer.
lr = 1e-3

In [None]:
# Device object shared by the models and the training/evaluation helpers.
device = torch.device(device_name)

### Conferência dos dados

In [None]:
# Build both MNIST loaders; the data is downloaded into ../data when missing.
train_loader, test_loader = get_loaders(batch_size=batch_size)

In [None]:
# Show the shapes of the stored images and labels for both splits.
# NOTE(review): `train_data`/`train_labels`/`test_data`/`test_labels` look like
# legacy attribute names that newer torchvision versions may flag as
# deprecated -- confirm against the installed torchvision.
print(
    'Train size: ', 
    train_loader.dataset.train_data.shape, 
    train_loader.dataset.train_labels.shape
)
print(
    'Test size : ', 
    test_loader.dataset.test_data.shape, 
    test_loader.dataset.test_labels.shape
)

In [None]:
# Fetch a single batch from the training loader and show the shapes of its
# two elements (the data tensor and the target tensor).
instance = next(iter(train_loader))
print('Instance Example: ', instance[0].shape, instance[1].shape)

In [None]:
# Display the first five training images side by side, each titled with its label.
fig, axs = plt.subplots(1, 5)
for i, ax in enumerate(axs):
    ax.imshow(train_loader.dataset.train_data[i], cmap='gray')
    ax.set_title(train_loader.dataset.train_labels[i].item())
    # Hide the axis ticks; they carry no information for an image.
    ax.set_xticks([])
    ax.set_yticks([])

## Seu trabalho começa aqui:

## 1. Implemente aqui sua primeira arquitetura com `nn.LSTM()` 

Sua LSTM deve ser capaz de classificar as imagens do MNIST processando de forma recorrente as linhas ou colunas. Lembre-se de que as imagens do MNIST têm apenas 1 canal, isto é, elas são em escala de cinza (e não RGB!).
* Spoiler: LSTM com 32 neurônios, atinge ~96% de acurácia em 3 épocas. 

In [None]:
class DigitsLSTM(nn.Module):
    """Exercise stub: MNIST digit classifier to be built around ``nn.LSTM``.

    The model is expected to read each image recurrently, row by row or
    column by column. The notebook's ``check_input`` helper feeds it a
    ``(batch_size, 1, 28, 28)`` tensor and requires a ``(batch_size, 10)``
    output.
    """

    def __init__(self):
        super(DigitsLSTM, self).__init__()
        # TODO: create the layers here (for example an ``nn.LSTM`` followed by
        # a linear layer that produces the 10 class scores).

    def forward(self, x):
        """Map a batch of images to class scores.

        Args:
            x: Batch of images; ``check_input`` uses shape
                ``(batch_size, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch_size, 10)`` once implemented.

        Raises:
            NotImplementedError: Until the exercise is implemented.
        """
        # The stub previously returned an undefined name, which surfaced as a
        # confusing NameError; fail with an explicit message instead.
        raise NotImplementedError(
            'DigitsLSTM.forward is not implemented yet: it must map a '
            '(batch_size, 1, 28, 28) tensor to a (batch_size, 10) tensor.'
        )

### 1.1 Verifique se a saída do seu modelo está correta

In [None]:
# Instantiate the nn.LSTM-based model on the chosen device and verify that a
# dummy batch yields an output of shape (5, 10).
model = DigitsLSTM().to(device)
dummy_pred = check_input(model, device)

### 1.2 Treine seu modelo por 3 épocas
Valores de acc esperados por época: 
1. 93.5%
2. 94.5%
3. 96.8%

In [None]:
# Train the current `model` and report the accuracy measured after the last epoch.
acc = train(model, train_loader, test_loader, device, lr, nb_epochs, log_interval)
print('Final acc: {:.2f}%'.format(acc))

## 2. Implemente aqui sua arquitetura com `nn.LSTMCell()` 

Semelhante à arquitetura anterior, sua LSTM deve processar imagens do MNIST iterando sobre as linhas ou as colunas das imagens. A diferença é que agora você deve implementar utilizando uma `nn.LSTMCell()`. Para isso, você deverá utilizar um laço de repetição `for`.

In [None]:
class DigitsCellLSTM(nn.Module):
    """Exercise stub: MNIST digit classifier to be built around ``nn.LSTMCell``.

    The model is expected to iterate explicitly, with a ``for`` loop, over
    the rows or the columns of each image. The notebook's ``check_input``
    helper feeds it a ``(batch_size, 1, 28, 28)`` tensor and requires a
    ``(batch_size, 10)`` output.

    Args:
        device: Optional device, stored as ``self.device`` for the
            implementation to use (for instance when creating the initial
            hidden and cell states). The notebook instantiates the class with
            ``device=device``, so the constructor has to accept it.
    """

    def __init__(self, device=None):
        super(DigitsCellLSTM, self).__init__()
        self.device = device
        # TODO: create the layers here (for example an ``nn.LSTMCell``
        # followed by a linear layer that produces the 10 class scores).

    def forward(self, x):
        """Map a batch of images to class scores.

        Args:
            x: Batch of images; ``check_input`` uses shape
                ``(batch_size, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch_size, 10)`` once implemented.

        Raises:
            NotImplementedError: Until the exercise is implemented.
        """
        # The stub previously returned an undefined name, which surfaced as a
        # confusing NameError; fail with an explicit message instead.
        raise NotImplementedError(
            'DigitsCellLSTM.forward is not implemented yet: it must map a '
            '(batch_size, 1, 28, 28) tensor to a (batch_size, 10) tensor.'
        )

In [None]:
# Instantiate the nn.LSTMCell-based model and verify its output shape.
# `device` is forwarded to the constructor, so `DigitsCellLSTM.__init__` must
# accept a `device` keyword argument.
model = DigitsCellLSTM(device=device).to(device)
dummy_pred = check_input(model, device)

In [None]:
# Train the current `model` and report the accuracy measured after the last epoch.
acc = train(model, train_loader, test_loader, device, lr, nb_epochs, log_interval)
print('Final acc: {:.2f}%'.format(acc))

## 3. Fique à vontade para testar outras variações das suas arquiteturas para conseguir resultados melhores

Ideias: 
* Aumente o número de camadas na LSTM
* Teste GRU
* Teste arquiteturas bidirecionais

In [None]:
class YourLSTM(nn.Module):
    """Exercise stub for a free-form recurrent MNIST classifier.

    Intended for experimenting with variations such as more recurrent layers,
    GRU units or bidirectional architectures.
    """

    def __init__(self):
        # `self` was missing from the signature, so the class could not be
        # instantiated at all (TypeError).
        super(YourLSTM, self).__init__()
        # TODO: create the layers of your architecture here.

    def forward(self, x):
        """Map a batch of images to class scores.

        Args:
            x: Batch of input images.

        Returns:
            The model output once implemented.

        Raises:
            NotImplementedError: Until the exercise is implemented.
        """
        # The stub previously returned an undefined name, which surfaced as a
        # confusing NameError; fail with an explicit message instead.
        raise NotImplementedError('YourLSTM.forward is not implemented yet.')