In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
from torchmetrics.functional.classification import accuracy
from torchmetrics.aggregation import MeanMetric
from src.model import CIFAR10

In [2]:
# Use the GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device = {device}")

Device = cuda


In [3]:
# Load (downloading on first run) the CIFAR-10 train and test splits from the
# same local directory, converting every image with torchvision's ToTensor.
cifar10_root = "./data/cifar10/"
to_tensor = transforms.ToTensor()

# Built in order: train split first, then test split.
train_data, test_data = (
    datasets.CIFAR10(root=cifar10_root, train=is_train, download=True, transform=to_tensor)
    for is_train in (True, False)
)

print(len(train_data), len(test_data))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar10/cifar-10-python.tar.gz


100.0%


Extracting ./data/cifar10/cifar-10-python.tar.gz to ./data/cifar10/
Files already downloaded and verified
50000 10000


In [4]:
# Training hyperparameters.
batch_size = 8
learning_rate = 0.0001
epoch_num = 10

# Seed torch's global RNG before the DataLoaders and model are created so that
# weight initialisation and training-set shuffling repeat across reruns
# (GPU kernels may still introduce small nondeterminism).
seed = 0
torch.manual_seed(seed);

In [5]:
# Shuffle only the training split. Evaluation accuracy does not depend on
# batch order, so the test loader iterates deterministically.
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

In [6]:
# Instantiate the project network and move it to the selected device, then set
# up the loss, the per-batch metric function, and an Adam optimiser over the
# model's parameters.
model = CIFAR10.Net()
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
metric_fn = accuracy
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [7]:
# Train `model` for `epoch_num` epochs, printing the mean batch accuracy and
# mean loss of each epoch.
model.train()

loss_mean = MeanMetric()
metric_mean = MeanMetric()

for epoch in range(epoch_num):
    # Reset the accumulators so each printed value covers only this epoch's
    # batches rather than a running average over every epoch so far.
    loss_mean.reset()
    metric_mean.reset()
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        # NOTE(review): newer torchmetrics releases require
        # `task="multiclass", num_classes=10` for `accuracy` — confirm the installed version.
        metric = metric_fn(output, target)
        loss.backward()
        optimizer.step()

        # Detach before accumulating so the aggregators never hold a reference
        # to this step's autograd graph.
        loss_mean.update(loss.detach().cpu())
        metric_mean.update(metric.detach().cpu())
    print(f"Epoch: {epoch+1}, Accuracy: {metric_mean.compute():.4f}, Loss: {loss_mean.compute():.4f}")
        


In [None]:
# Evaluate top-1 accuracy of `model` on the whole test split.
model.eval()

correct = 0
# Inference only: disabling autograd avoids building and holding a graph for
# every test batch.
with torch.no_grad():
    for data, target in test_loader:
        data = data.to(device)
        target = target.to(device)
        output = model(data)
        # Predicted class = index of the largest logit along dim 1.
        prediction = output.argmax(dim=1)
        # `.item()` keeps `correct` a plain Python int instead of a device tensor.
        correct += prediction.eq(target).sum().item()
print('Test set Accuracy : {:.2f}%'.format(100. * correct / len(test_loader.dataset)))

In [None]:
# Persist model and optimizer state so training can be resumed or the model reloaded.
state_dict = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
checkpoint_path = './checkpoints/cifar10_net.pth'

# torch.save does not create missing parent directories; ensure the target exists.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(state_dict, checkpoint_path)