# Regresión Softmax concisa


Por ahora, la mayor parte del código que revisaremos es casi idéntico a lo que ya conocemos.

In [None]:
import torch
from torch import nn
import torchvision
from IPython import display
from torchvision import transforms
from torch.utils import data

In [None]:
#Ejemplo de dataloader para Fashion MNIST

def load_data_fashion_mnist(batch_size, resize=None):
    """Download Fashion-MNIST and build its training and test data loaders.

    Args:
        batch_size: Mini-batch size used by both loaders.
        resize: Optional size forwarded to ``transforms.Resize``. When falsy,
            images are only converted to tensors.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader shuffles
        its data; the test loader does not.
    """
    # Resize is placed ahead of ToTensor so it runs before tensor conversion.
    steps = [transforms.Resize(resize)] if resize else []
    steps.append(transforms.ToTensor())
    pipeline = transforms.Compose(steps)

    def _split(is_train):
        # Both splits share the same root directory and transform pipeline.
        return torchvision.datasets.FashionMNIST(
            root="../data", train=is_train, transform=pipeline,
            download=True)

    train_set = _split(True)
    test_set = _split(False)
    train_loader = data.DataLoader(train_set, batch_size, shuffle=True,
                                   num_workers=1)
    test_loader = data.DataLoader(test_set, batch_size, shuffle=False,
                                  num_workers=1)
    return (train_loader, test_loader)


In [None]:
# Mini-batch size; load_data_fashion_mnist passes it to both data loaders.
batch_size = 256
# No resize argument is given, so no Resize transform is applied to the images.
train_iter, test_iter = load_data_fashion_mnist(batch_size)

## Inicialización.


In [None]:
# Model: flatten each input into a vector, then a single linear layer that
# maps 784 input features to 10 outputs.
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    """Initialise the weight of a linear layer from a normal distribution.

    Meant to be used as a per-module callback (e.g. with ``Module.apply``).

    Args:
        m: A module. If it is an ``nn.Linear`` (or a subclass of it), its
            ``weight`` is re-initialised in place with standard deviation
            0.01. Any other module is left untouched. The bias is not
            modified.
    """
    # isinstance instead of an exact type comparison, so that subclasses of
    # nn.Linear are initialised as well rather than silently skipped.
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)

# Run init_weights over the modules of net. The trailing semicolon keeps the
# notebook from echoing the value returned by apply().
net.apply(init_weights);

¿Qué hace Flatten?

Convierte la matriz de píxeles a un vector de números.

## Softmax 

Dijimos anteriormente que nuestra primera implementación de Softmax era inestable computacionalmente. Por esta razón, los frameworks preexistentes hacen uso de otras implementaciones que evitan estas inestabilidades. Para más información, dejamos el siguiente link:

["LogSumExp trick"](https://en.wikipedia.org/wiki/LogSumExp#log-sum-exp_trick_for_log-domain_calculations).


In [None]:
# reduction='none' keeps one loss value per example instead of reducing them;
# the training loop does the reduction itself (mean for backward, sum for the
# running total it prints).
loss = nn.CrossEntropyLoss(reduction='none')

## Algoritmo de optimización.


In [None]:
# Stochastic gradient descent over all parameters of net, learning rate 0.1.
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

In [None]:
def accuracy(y_hat, y):
    """Count how many entries of ``y_hat`` agree with the labels ``y``.

    When ``y_hat`` has more than one dimension and more than one column, it
    is first reduced to class indices with an argmax along dimension 1;
    otherwise it is compared to ``y`` as given.

    Returns:
        The number of matching entries as a Python float (a count, not a
        ratio).
    """
    holds_class_scores = y_hat.ndim > 1 and y_hat.shape[1] > 1
    predicted = y_hat.argmax(dim=1) if holds_class_scores else y_hat
    # Cast to the label dtype before comparing and before summing.
    matches = predicted.to(y.dtype) == y
    return float(matches.to(y.dtype).sum())

# Entrenamiento

In [None]:
# Train for a fixed number of epochs and, after each one, print the mean
# training loss, the training accuracy and the accuracy on the test set.
num_epochs = 10
# NOTE(review): `lr` is not read anywhere in this cell; the optimizer
# `trainer` was built with its own learning rate. Kept so the name stays
# defined for any later cell.
lr = 0.01
for epoch in range(num_epochs):
    L = 0.0        # sum of per-example training losses
    N = 0          # number of training examples seen
    Acc = 0.0      # number of correctly classified training examples
    TestAcc = 0.0  # number of correctly classified test examples
    TestN = 0      # number of test examples seen
    for X, y in train_iter:
        y_hat = net(X)
        l = loss(y_hat, y)
        trainer.zero_grad()
        # `loss` yields one value per example; back-propagate their mean.
        l.mean().backward()
        trainer.step()
        # float(...) keeps the running total a plain number instead of a
        # tensor that stays attached to the autograd graph.
        L += float(l.sum())
        N += l.numel()
        # Reuse the forward pass that produced the loss rather than running
        # the network a second time on the same batch.
        Acc += accuracy(y_hat, y)
    # Evaluate on the held-out test iterator (not the training one); no
    # gradients are needed here.
    with torch.no_grad():
        for X, y in test_iter:
            TestN += y.numel()
            TestAcc += accuracy(net(X), y)
    print(f'epoch {epoch + 1}, loss {(L/N):f}\
          , train accuracy  {(Acc/N):f}, test accuracy {(TestAcc/TestN):f}')

epoch 1, loss 0.785882          , train accuracy  0.766000, test accuracy 0.805883
epoch 2, loss 0.568853          , train accuracy  0.821867, test accuracy 0.813600
epoch 3, loss 0.525430          , train accuracy  0.833950, test accuracy 0.819950
epoch 4, loss 0.500251          , train accuracy  0.839183, test accuracy 0.836267
epoch 5, loss 0.485062          , train accuracy  0.844033, test accuracy 0.831083
epoch 6, loss 0.473485          , train accuracy  0.846717, test accuracy 0.822250
epoch 7, loss 0.465627          , train accuracy  0.849767, test accuracy 0.842200
epoch 8, loss 0.458643          , train accuracy  0.852383, test accuracy 0.844733
epoch 9, loss 0.452703          , train accuracy  0.852667, test accuracy 0.847200
epoch 10, loss 0.446111          , train accuracy  0.854350, test accuracy 0.846467


Como vemos, no hay mucho más que discutir y analizar. 