In [None]:
!pip install torchvision

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np

# standard plotting library
import matplotlib.pyplot as plt

In [None]:
## Helper zum Plotten 

def show_image_grid(X, y=None, y_pred=None, title=None, nrow=6, ncol=4, **kwargs):
    """Plot a batch of images as a grid, optionally labelled with true/predicted classes.

    Note the naming: ``nrow`` is the number of images per grid row (i.e. the
    number of grid *columns*), and ``ncol`` is the number of grid rows.

    Args:
        X: images as a numpy array or torch tensor, 4-d (N, C, H, W) or 3-d
            (N, H, W); non-4-d input gets a singleton channel axis inserted.
        y: optional true labels, one per image; shown as the panel title.
        y_pred: optional predicted labels; shown after a "/" in the panel title.
        title: optional figure title. It is also passed to ``plt.figure`` as the
            figure label, so calling again with the same title reuses that figure.
        nrow: images per grid row.
        ncol: maximum number of grid rows; at most ``nrow * ncol`` images are shown.
        **kwargs: forwarded to ``plt.imshow`` (e.g. ``cmap=...``).
    """
    max_num = nrow * ncol
    X = X[:max_num]
    if len(X) < max_num:
        # Only as many rows as needed (ceil division), but always at least one.
        ncol = max(1, (len(X) + nrow - 1) // nrow)
    if isinstance(X, np.ndarray):
        X = torch.from_numpy(X)
    if X.dim() != 4:
        X = X[:, None]

    # Reserve one extra inch of height for the per-panel label titles.
    plt.figure(title, figsize=(2 * nrow, 2 * ncol + (0 if y is None else 1)))
    if title:
        # suptitle labels the whole figure; plt.title here would create an extra
        # full-figure axes underneath (or be removed by) the subplots below.
        plt.suptitle(title)

    if isinstance(y, torch.Tensor):
        y = y.detach().cpu().numpy()

    for i, Xi in enumerate(X):
        plt.subplot(ncol, nrow, i + 1)
        # detach()/cpu() so tensors that require grad or live on a GPU can be plotted.
        img = Xi.detach().cpu().numpy().transpose((1, 2, 0))
        if img.shape[2] == 1:
            img = img[..., 0]  # single channel -> 2-d array for imshow
        plt.imshow(img, **kwargs)

        panel_label = ""
        if y is not None:
            panel_label += str(int(y[i]))
        if y_pred is not None:
            panel_label += "/" + str(int(y_pred[i]))
        if panel_label:
            plt.title(panel_label)

        plt.axis('off')


# Unser erstes Netz

## Daten laden

In [None]:
# Load the MNIST training split from ./data (download=True fetches it when needed);
# ToTensor() turns every image into a torch tensor.
train_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
train_data

In [None]:

batch_size = 128

# Shuffled training batches. The explicitly seeded generator makes the shuffling
# order, and therefore the training run, reproducible across kernel restarts.
train_loader = torch.utils.data.DataLoader(
    train_data,
    shuffle=True,
    batch_size=batch_size,
    generator=torch.Generator().manual_seed(0),
)

# Held-out split (train=False), loaded with the same transform; its order does
# not matter for evaluation, so it is not shuffled.
valid_data = datasets.MNIST(root='./data', train=False,
                            download=True, transform=transforms.ToTensor())
valid_loader = torch.utils.data.DataLoader(valid_data, shuffle=False, batch_size=batch_size)

In [None]:
# Take the first mini-batch from the training loader and look at its first image and label.
x, y = next(iter(train_loader))

plt.gray()  # make grayscale the default colormap for subsequent imshow calls
plt.imshow(x[0][0])  # sample 0, channel 0
print("label", y[0])

## Das Netz trainieren

In [None]:
def accuracy(output, label):
    """Return the fraction of samples whose highest-scoring class equals ``label``.

    ``output`` holds one row of class scores per sample (argmax is taken over
    dim 1) and ``label`` the target class indices. The result is a 0-dim float
    tensor.
    """
    predicted_class = output.argmax(dim=1)
    is_correct = predicted_class == label
    return is_correct.float().mean()

# Learning rate for plain stochastic gradient descent, and number of passes over the data.
lr = 0.1
n_epochs = 1

# Our "network": a single linear layer mapping the 28*28 pixels of an image to
# one score (logit) per class, starting from all-zero parameters.
weight = torch.zeros((28*28, 10), requires_grad=True)
bias = torch.zeros(10, requires_grad=True)

losses = []  # loss of every training mini-batch, across all epochs (plotted below)
for epoch in range(n_epochs):
    epoch_losses = []

    # train on mini batches
    for X, y in train_loader:
        # forward pass: flatten each image into a row vector, then apply the linear layer
        output = X.view(-1, 28*28) @ weight + bias

        # compute the cross-entropy loss between the logits and the target labels
        loss = F.cross_entropy(output, y)
        losses.append(loss.item())
        epoch_losses.append(loss.item())

        # backward pass, then an SGD step that must not itself be tracked by autograd
        loss.backward()
        with torch.no_grad():
            weight -= lr * weight.grad
            bias -= lr * bias.grad
            # .grad accumulates over backward() calls, so reset it for the next batch
            weight.grad.zero_()
            bias.grad.zero_()

    # measure accuracy on the whole validation set, not just on a training batch
    with torch.no_grad():
        val_outputs, val_labels = [], []
        for X_val, y_val in valid_loader:
            val_outputs.append(X_val.view(-1, 28*28) @ weight + bias)
            val_labels.append(y_val)
        acc = accuracy(torch.cat(val_outputs), torch.cat(val_labels))

    # report the mean loss of this epoch only (not of all epochs so far)
    mean_loss = sum(epoch_losses) / len(epoch_losses)
    print("epoch % 5d: loss: % 8.5f accuracy: %4.2f" % (epoch, mean_loss, acc.item()))


In [None]:
# Training loss of every mini-batch, in the order the batches were processed.
plt.plot(losses)
plt.xlabel("mini-batch")
plt.ylabel("cross-entropy loss")
plt.title("training loss");

In [None]:
# Predictions of the trained model on the first validation batch;
# each panel is titled "true label/predicted label".
Xv, yv = next(iter(valid_loader))
with torch.no_grad():  # inference only: no autograd graph needed
    prediction = Xv.view(-1, 28*28) @ weight + bias
show_image_grid(Xv, y=yv, y_pred=prediction.argmax(dim=1))

In [None]:
# Detach the learned weights from autograd so they can be inspected and plotted.
# Expected shape (from the weight definition): (28*28, 10), i.e. one column of
# pixel weights per class.
W = weight.detach()
W.shape

In [None]:
# Weights of class 0: one value per input pixel (28*28 = 784 entries).
W[:, 0]

In [None]:
# Class-0 weights reshaped back into the 28x28 image layout; the colorbar shows
# which shade corresponds to which weight value.
plt.imshow(W[:, 0].reshape(28, 28))
plt.colorbar()
plt.title("weights for class 0");

In [None]:
# Every class's weight vector as a 28x28 image, each panel titled with its class index.
# W.T has one row of 28*28 pixel weights per class, so reshaping it yields one image per class.
show_image_grid(W.T.reshape(-1, 28, 28), y=torch.arange(W.shape[1]), nrow=4)