## Fashion MNIST mit Conv-Net

In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np


In [2]:
train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

test_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=False
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
test_loader = torch.utils.data.DataLoader(test_set,batch_size=100)


In [4]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, padding=2)
        self.conv2 = nn.Conv2d(16, 32, 5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(7*7*32, 10)
        
    def forward(self, t):
        t = self.conv1(t)
        t = F.relu(t)
        t = self.pool(t)
           
        t = self.conv2(t)
        t = F.relu(t)
        t = self.pool(t)
        
        t = t.reshape(-1, 7*7*32)
        t = self.fc(t)

        return t


In [5]:
%%time
def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

net = Net()
optimizer = optim.Adam(net.parameters(), lr=0.01)

for epoch in range(2):

    total_loss = 0
    total_correct = 0

    for batch in train_loader: # Get Batch
        images, labels = batch 

        preds = net(images) # Pass Batch
        loss = F.cross_entropy(preds, labels) # Calculate Loss

        optimizer.zero_grad()
        loss.backward() # Calculate Gradients
        optimizer.step() # Update Weights

        total_loss += loss.item()
        total_correct += get_num_correct(preds, labels)

    print(
        "epoch", epoch, 
        "total_correct:", total_correct, 
        "loss:", total_loss
    )
    


epoch 0 total_correct: 50532 loss: 261.1365255266428
epoch 1 total_correct: 53088 loss: 190.78211621940136
Wall time: 3min 13s


In [6]:
correct = 0
for batch in test_loader:
    images, labels = batch 
    preds = net(images)
    
    images.shape, labels
    correct += get_num_correct(preds,labels)
    
print("correct:",correct/len(test_set))

correct: 0.8778
