# Klassifizierung des CIFAR-10-Datensatzes

Der CIFAR-10-Datensatz besteht aus 60.000 Farbbildern (RGB) im Format 32 × 32 Pixel, aufgeteilt in 10 Klassen mit je 6.000 Bildern: airplane, automobile, bird, cat, deer, dog, frog, horse, ship und truck.

Davon sind 50.000 Bilder für das Training und 10.000 Bilder für den Test vorgesehen.

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [6]:
root = './data'

# Both CIFAR-10 splits use identical settings; only the `train` flag differs.
training_data, test_data = (
    datasets.CIFAR10(root=root, train=split_is_train, download=True, transform=ToTensor())
    for split_is_train in (True, False)
)

batch_size = 64
# Shuffle the training split only; the test split keeps a fixed order.
training_data_loader, test_data_loader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    for dataset, shuffle in ((training_data, True), (test_data, False))
)


## Das neuronale Netz

In [7]:
import torch.nn as nn
import torch.nn.functional as F

In [22]:
class SimpleNet(nn.Module):

    def __init__(self, n_features, n_classes):
        super(SimpleNet, self).__init__()
        self.n_features = n_features
        self.fc1 = nn.Linear(n_features, 800)
        self.fc2 = nn.Linear(800, 200)
        self.fc3 = nn.Linear(200,n_classes)
    
    def forward(self, x):

        x = x.view(-1, self.n_features)
        #print(len(x))        
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x



Wir erzeugen das Netzwerk und verschieben es auf die GPU, falls vorhanden:

In [23]:
# Use the GPU when CUDA is available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 3 * 32 * 32 input features, 10 output classes; Module.to returns the module itself.
simplenet = SimpleNet(n_features=3 * 32 * 32, n_classes=10).to(device)
simplenet

SimpleNet(
  (fc1): Linear(in_features=3072, out_features=800, bias=True)
  (fc2): Linear(in_features=800, out_features=200, bias=True)
  (fc3): Linear(in_features=200, out_features=10, bias=True)
)

In [24]:
import torch.optim as optim

# Adam over every parameter of the network, learning rate 1e-3.
optimizer = optim.Adam(params=simplenet.parameters(), lr=1e-3)

In [25]:
def _run_training_epoch(model, optimizer, loss_fn, loader, device):
    """Run one optimisation pass over ``loader``; return the mean per-sample loss (NaN if empty)."""
    model.train()
    total_loss = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Assumes loss_fn returns a batch mean: weighting by batch size makes the
        # epoch value a per-sample average even when the last batch is smaller.
        total_loss += loss.item() * inputs.size(0)
    n_samples = len(loader.dataset)
    return total_loss / n_samples if n_samples else float("nan")


def _evaluate(model, loss_fn, loader, device):
    """Return (mean per-sample loss, accuracy) on ``loader``; NaN for an empty loader."""
    model.eval()
    total_loss = 0.0
    num_correct = 0
    num_examples = 0
    # No gradients are needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            total_loss += loss_fn(output, targets).item() * inputs.size(0)
            # Softmax is monotonic, so argmax over raw logits picks the same class.
            num_correct += (output.argmax(dim=1) == targets).sum().item()
            num_examples += targets.size(0)
    n_samples = len(loader.dataset)
    valid_loss = total_loss / n_samples if n_samples else float("nan")
    accuracy = num_correct / num_examples if num_examples else float("nan")
    return valid_loss, accuracy


def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    """Train ``model`` for ``epochs`` epochs, printing one metrics line per epoch.

    Each epoch runs one optimisation pass over ``train_loader`` and then
    evaluates on ``val_loader``. The printed line contains the mean training
    loss, the mean validation loss (both per-sample averages over the whole
    dataset) and the validation accuracy. Empty loaders are reported as NaN
    instead of raising ZeroDivisionError.

    Args:
        model: Network to train. It must already be on ``device``, because
            only the batches are moved here.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Loss callable ``loss_fn(output, targets)`` returning a scalar tensor.
        train_loader: DataLoader yielding ``(inputs, targets)`` training batches.
        val_loader: DataLoader yielding ``(inputs, targets)`` validation batches.
        epochs: Number of epochs to run.
        device: Device the batches are moved to.
    """
    for epoch in range(1, epochs + 1):
        training_loss = _run_training_epoch(model, optimizer, loss_fn, train_loader, device)
        valid_loss, accuracy = _evaluate(model, loss_fn, val_loader, device)
        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2%}'.format(
            epoch, training_loss, valid_loss, accuracy))

In [27]:
# Set RETRAIN = True to train from scratch (and save the result) instead of
# loading the previously trained checkpoint.
RETRAIN = False

if RETRAIN:
    train(simplenet, optimizer, nn.CrossEntropyLoss(), training_data_loader,
          test_data_loader, epochs=10, device=device)
    torch.save(simplenet, "simplenet")
else:
    # NOTE(review): the training branch saves to "simplenet" but this loads
    # "simplenet_49", presumably a separately trained checkpoint. Confirm.
    # weights_only=False unpickles arbitrary Python objects: only load trusted files.
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    simplenet = torch.load("simplenet_49", map_location=device, weights_only=False)

In [37]:
labels = training_data.classes
simplenet.to("cpu")
# Switch to inference mode once instead of on every loop iteration.
simplenet.eval()

# Print the prediction next to the ground truth for the first few samples as
# "<index>: <predicted class> <predicted id> - <true id> <true class>".
# NOTE: these are training images; use test_data to inspect generalisation.
n_samples = 10
with torch.no_grad():  # inference only: no autograd graph needed
    for n in range(n_samples):
        image, label = training_data[n]
        # Argmax over the raw logits selects the same class as argmax over softmax.
        prediction = simplenet(image).argmax().item()
        print(f'{n}: {labels[prediction]} {prediction} - {label} {labels[label]}')

0: frog 6 - 6 frog
1: automobile 1 - 9 truck
2: truck 9 - 9 truck
3: deer 4 - 4 deer
4: ship 8 - 1 automobile
5: truck 9 - 1 automobile
6: bird 2 - 2 bird
7: horse 7 - 7 horse
8: ship 8 - 8 ship
9: truck 9 - 3 cat
