In [None]:
import torchvision
import torch
import numpy as np
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import time
import pandas as pd
import torchvision.datasets as datasets
from IPython import display
import matplotlib.pyplot as plt

In [None]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 64

# Preprocessing applied to every image: resize, convert to a float tensor,
# then normalise each of the three channels.
# NOTE(review): these mean/std values look like the commonly used ImageNet
# statistics rather than CIFAR-10's own - confirm this is intentional.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# NOTE: despite the `_ds` suffix, both names hold DataLoaders wrapping the
# CIFAR-10 datasets (downloaded into ./data on first use).
# The training loader reshuffles every epoch so the optimiser does not see the
# exact same batches in the exact same order each time; the validation loader
# keeps a fixed order because order does not affect the evaluation metrics.
train_ds = DataLoader(
    datasets.CIFAR10("data", train=True, transform=transform, download=True),
    batch_size=batch_size,
    shuffle=True,
)
valid_ds = DataLoader(
    datasets.CIFAR10("data", train=False, transform=transform, download=True),
    batch_size=batch_size,
)

In [None]:
class CNNModel(nn.Module):
    """Convolutional image classifier.

    The network is made of three stages, applied in this order:

    * ``features``   - four convolutions with ReLU activations and three
      max-pooling layers, taking 3 input channels to 256 feature maps.
    * ``avgpool``    - adaptive average pooling to a fixed 6x6 spatial size.
    * ``classifier`` - dropout-regularised fully connected head producing
      ``num_classes`` raw scores (logits) per sample.

    Args:
        num_classes: Number of output units of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Convolutional feature extractor; list order is application order.
        conv_stack = [
            nn.Conv2d(3, 64, kernel_size=3, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        # Forces a 6x6 spatial map so the head's input width is fixed.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        flat_width = 256 * 6 * 6  # channels * pooled height * pooled width
        hidden_width = 4096
        head = [
            nn.Dropout(),
            nn.Linear(flat_width, hidden_width),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the class scores for a batch of images ``x``."""
        pooled = self.avgpool(self.features(x))
        # Keep the batch dimension, collapse everything else into one vector.
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)

In [None]:
# Instantiate the classifier with 10 output classes.
model = CNNModel(num_classes=10)

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the parameters to the chosen device. Left as the final bare expression
# so the cell still displays the module summary.
model.to(device)

In [None]:
def display_message_and_metrics(message, metrics):
    """Refresh the cell output with the metric curves and a status line.

    The current cell output is cleared first. If at least one value has been
    recorded under ``metrics["loss"]``, the whole ``metrics`` mapping is turned
    into a DataFrame and plotted. ``message`` is always printed last.

    Args:
        message: Text printed below the (optional) plot.
        metrics: Mapping of metric name to a list of recorded values; must
            contain the key ``"loss"``.
    """
    display.clear_output(wait=False)

    has_history = len(metrics["loss"]) != 0
    if has_history:
        history = pd.DataFrame(metrics)
        history.plot()
        plt.show()

    print(message)

In [None]:
# --- Training configuration --------------------------------------------------
epochs = 5
train_steps = len(train_ds)  # batches per training pass (used for the ETA)
valid_steps = len(valid_ds)  # batches per validation pass (used for the ETA)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
# Per-epoch history: one value is appended to every list at the end of an epoch.
metrics = {"loss": [], "val_loss": [], "val_accuracy": []}

for epoch in range(epochs):
    train_losses = []
    valid_losses = []

    # ---- Training pass ----
    model.train()
    begin = time.time()
    for inputs, targets in train_ds:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, targets)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
        # Refresh the progress display every 50 batches with an ETA
        # extrapolated from the average time per batch so far.
        if len(train_losses) % 50 == 0:
            elapsed = time.time() - begin
            display_message_and_metrics(
                "Epoch %d: [Training] %.2fs/%.2fs"%(epoch + 1, elapsed, elapsed / float(len(train_losses)) * train_steps),
                metrics
            )

    # ---- Validation pass ----
    model.eval()
    num_correct = 0
    num_samples = 0
    begin = time.time()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch and so saves memory and time.
    with torch.no_grad():
        for inputs, targets in valid_ds:
            inputs = inputs.to(device)
            targets = targets.to(device)

            output = model(inputs)
            valid_losses.append(loss_fn(output, targets).item())

            # Softmax is monotonic, so the arg-max of the raw logits is the
            # predicted class; no explicit softmax is required.
            predicted = output.argmax(dim=1)
            num_correct += (predicted == targets).sum().item()
            num_samples += targets.shape[0]

            if len(valid_losses) % 50 == 0:
                elapsed = time.time() - begin
                display_message_and_metrics(
                    "Epoch %d: [Validation] %.2fs/%.2fs"%(epoch + 1, elapsed, elapsed / float(len(valid_losses)) * valid_steps),
                    metrics
                )

    # ---- Epoch summary ----
    # Losses are averaged over batches (not weighted by batch size).
    train_loss = torch.tensor(train_losses).mean().item()
    valid_loss = torch.tensor(valid_losses).mean().item()
    accuracy = num_correct / num_samples if num_samples > 0 else 0
    metrics["loss"].append(train_loss)
    metrics["val_loss"].append(valid_loss)
    metrics["val_accuracy"].append(accuracy)
    # display_message_and_metrics already clears the output and redraws the
    # curves, so a separate clear/plot before it would only be wiped again.
    display_message_and_metrics(
        "Training Loss: %.2f Validation Loss: %.2f accuracy: %.2f" %(train_loss, valid_loss, accuracy),
        metrics
    )

In [None]:
# Destination file for the checkpoint; the next cell reads from the same path.
path = "model"
# Pickle the entire module object (architecture and weights), not just a
# state_dict.
torch.save(obj=model, f=path)

In [None]:
# Reload the checkpoint written by the save cell. The file is a fully pickled
# module, so unpickling can execute arbitrary code: only load files you created
# yourself. `weights_only=False` is passed explicitly because recent PyTorch
# releases default to weights-only loading, which rejects pickled modules, and
# `map_location` places the parameters on the device used by this notebook.
model = torch.load(path, map_location=device, weights_only=False)
model.eval()  # make sure dropout is disabled for inference

# Class names indexed by label id, used to print human-readable results.
labels = np.array(train_ds.dataset.classes)

# Sanity check: run a single validation batch through the reloaded model.
batch = next(iter(valid_ds))
images, targets = batch
images = images.to(device)
targets = targets.to(device)
with torch.no_grad():  # inference only, no autograd graph needed
    prediction = model(images).argmax(dim=1)

# Map the label ids to class names (numpy indexing needs CPU data).
print("Prediction:", labels[prediction.cpu().numpy()])
print("Actual result:", labels[targets.cpu().numpy()])