In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
# Widen tensor reprs beyond the 80-column default so longer tensors print on a single line.
PRINT_LINEWIDTH = 120
torch.set_printoptions(linewidth=PRINT_LINEWIDTH)



In [2]:
def _load_fashion_mnist(train):
    """Build the FashionMNIST split selected by ``train`` under ./data.

    The files are downloaded on first use. Every image is converted to a tensor via
    ``transforms.ToTensor()``.
    """
    return torchvision.datasets.FashionMNIST(
        root='./data',
        train=train,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    )


train_set = _load_fashion_mnist(train=True)
test_set = _load_fashion_mnist(train=False)

In [6]:
# Split sizes followed by the per-class label counts of each split.
sizes = (len(train_set), len(test_set))
class_counts = (train_set.targets.bincount(), test_set.targets.bincount())
sizes + class_counts

(60000,
 10000,
 tensor([6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000]),
 tensor([1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]))

In [7]:
# Mini-batch size shared by both loaders.
BATCH_SIZE = 100

# Reshuffle the training set every epoch so each epoch sees different mini-batches
# (DataLoader defaults to shuffle=False). The test loader stays sequential, because
# the evaluation order does not affect the accuracy count.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE)

In [9]:
# Pull a single mini-batch from the training loader to inspect tensor shapes.
first_batch = next(iter(train_loader))
images, labels = first_batch
(images.shape, labels.shape)

(torch.Size([100, 1, 28, 28]), torch.Size([100]))

In [13]:
# Flatten each 1x28x28 image into one row. The batch dimension is taken from the tensor
# instead of being hard-coded to 100, so this still works if the batch size changes or
# the batch is the smaller final one.
images1 = images.reshape(images.shape[0], -1)
images1.shape

torch.Size([100, 784])

In [16]:
def get_num_correct(preds, labels):
    """Count the rows of ``preds`` whose arg-max along dim 1 equals the matching label.

    Args:
        preds: score tensor; assumes shape (batch, num_classes). TODO confirm for other callers.
        labels: integer class indices, one per row of ``preds``.

    Returns:
        The number of correct predictions as a Python int.
    """
    predicted_classes = preds.argmax(dim=1)
    hits = predicted_classes == labels
    return hits.sum().item()

class LinNet(nn.Module):
    """Three-layer fully connected classifier: in_features -> hidden1 -> hidden2 -> num_classes.

    The defaults (784 -> 120 -> 60 -> 10) reproduce the original fixed architecture,
    sized for flattened 28x28 single-channel images and 10 classes. Attribute names
    (fc1, fc2, out) and the layer creation order are unchanged, so existing state_dicts
    still load and seeded initialisation yields the same weights.

    Args:
        in_features: length of each flattened input sample.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=784, hidden1=120, hidden2=60, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features=in_features, out_features=hidden1)
        self.fc2 = nn.Linear(in_features=hidden1, out_features=hidden2)
        self.out = nn.Linear(in_features=hidden2, out_features=num_classes)

    def forward(self, t):
        """Return raw class logits of shape (N, num_classes) for input ``t``.

        ``t`` is reshaped to (N, in_features), where N is inferred from its total
        element count.
        """
        # Flatten to (N, in_features); -1 lets reshape infer the batch size N.
        t = t.reshape(-1, self.fc1.in_features)

        # first hidden linear layer
        t = F.relu(self.fc1(t))

        # second hidden linear layer
        t = F.relu(self.fc2(t))

        # Output layer: return unnormalised logits. Softmax is deliberately omitted
        # because F.cross_entropy applies log-softmax internally; argmax on the logits
        # gives the same predicted class as argmax on probabilities.
        return self.out(t)

In [20]:
# Seed the global torch RNG before building the model so weight initialisation, and
# hence the logged numbers, is reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

net = LinNet()
optimizer = optim.Adam(net.parameters(), lr=0.01)

for epoch in range(10):
    # Explicitly select training mode. This is a no-op for this Linear/ReLU-only net,
    # but keeps the loop correct if dropout/batchnorm are added later or an earlier
    # cell left the model in eval mode.
    net.train()

    # total_loss sums the per-batch mean losses; total_correct counts correct
    # predictions over the epoch, using the predictions made before each update step.
    total_loss = 0
    total_correct = 0

    for images, labels in train_loader:
        preds = net(images)                    # forward pass -> logits
        loss = F.cross_entropy(preds, labels)  # mean loss over the batch

        optimizer.zero_grad()
        loss.backward()   # compute gradients
        optimizer.step()  # update weights

        total_loss += loss.item()
        total_correct += get_num_correct(preds, labels)

    print(
        "epoch", epoch,
        "total_correct:", total_correct,
        "loss:", total_loss
    )

epoch 0 total_correct: 48688 loss: 309.192642390728
epoch 1 total_correct: 51400 loss: 237.35867756605148
epoch 2 total_correct: 51910 loss: 223.205138489604
epoch 3 total_correct: 52256 loss: 212.5849279910326
epoch 4 total_correct: 52577 loss: 203.09006017446518
epoch 5 total_correct: 52728 loss: 200.6128739863634
epoch 6 total_correct: 52864 loss: 195.5978786945343
epoch 7 total_correct: 53092 loss: 191.25434762239456
epoch 8 total_correct: 53202 loss: 188.83281981945038
epoch 9 total_correct: 53256 loss: 184.19952076673508


```
epoch 0 total_correct: 48688 loss: 309.192642390728
epoch 1 total_correct: 51400 loss: 237.35867756605148
epoch 2 total_correct: 51910 loss: 223.205138489604
epoch 3 total_correct: 52256 loss: 212.5849279910326
epoch 4 total_correct: 52577 loss: 203.09006017446518
epoch 5 total_correct: 52728 loss: 200.6128739863634
epoch 6 total_correct: 52864 loss: 195.5978786945343
epoch 7 total_correct: 53092 loss: 191.25434762239456
epoch 8 total_correct: 53202 loss: 188.83281981945038
epoch 9 total_correct: 53256 loss: 184.19952076673508
Testdaten: correct: 0.8618
```

#### Anwenden auf Testdaten

In [21]:
# Evaluate the trained network on the held-out test set. eval() selects inference
# mode, and no_grad() skips building autograd graphs that would only cost memory and
# time, since no backward pass happens here.
net.eval()
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        preds = net(images)
        correct += get_num_correct(preds, labels)

# Fraction of test samples classified correctly.
print("correct:", correct/len(test_set))

correct: 0.8618
