In [1]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision import datasets, transforms
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

# Fixed seed so weight initialization and DataLoader shuffling repeat across
# Restart & Run All. GPU kernels (e.g. cuDNN convolutions) may still be
# nondeterministic unless torch.use_deterministic_algorithms(True) is also set.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


class LetterCNN(nn.Module):
    """LeNet-style CNN that maps images to class logits.

    Two conv(5x5) + max-pool(2) stages reduce a 32x32 input to 5x5 feature
    maps, which is exactly what ``fc1``'s ``16 * 5 * 5`` input size assumes.
    Inputs must therefore be of shape (N, in_channels, 32, 32).

    Args:
        in_channels: number of input image channels (3 for RGB).
        num_classes: number of output logits produced by ``self.out``.
    """

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)

        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.out = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, t):
        """Return raw (unnormalized) logits of shape (N, num_classes)."""
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)
        # (N, 6, 14, 14)

        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)
        # (N, 16, 5, 5)

        # flatten(1) keeps the batch dimension, so a non-32x32 input fails loudly
        # in fc1; reshape(-1, 400) would instead silently change the batch size.
        t = t.flatten(start_dim=1)
        t = self.fc1(t)
        t = F.relu(t)
        # (N, 120)

        t = self.fc2(t)
        t = F.relu(t)
        # (N, 84)

        t = self.out(t)
        # (N, num_classes)
        return t

Device: cuda


In [2]:
# ToTensor yields values in [0, 1]; this maps each RGB channel to [-1, 1].
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

DATA_DIR = "data"
# LetterCNN's fc1 expects 16 * 5 * 5 features, which only holds for 32x32 inputs.
IMAGE_SIZE = (32, 32)
BATCH_SIZE = 10


def make_transform():
    """Build the resize -> tensor -> normalize pipeline shared by train and test."""
    return transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


# Separate instances so the train pipeline can later gain augmentation on its own.
train_transform = make_transform()
test_transform = make_transform()

train_set = datasets.ImageFolder(root=f"{DATA_DIR}/train", transform=train_transform)
test_set = datasets.ImageFolder(root=f"{DATA_DIR}/test", transform=test_transform)

# ImageFolder assigns label indices from the (sorted) class folder names; if the
# two splits differ, test labels would silently refer to different letters.
assert train_set.class_to_idx == test_set.class_to_idx, (
    f"train/test class folders differ: {train_set.classes} vs {test_set.classes}"
)

# The training/eval loops copy batches with non_blocking=True, which only
# overlaps the host->GPU transfer when batches come from pinned memory.
PIN_MEMORY = device.type == "cuda"

train_loader = DataLoader(
    dataset=train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=PIN_MEMORY,
)

test_loader = DataLoader(
    dataset=test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=PIN_MEMORY,
)

print(train_set.classes)
print(test_set.classes)


['p', 'sh', 't']
['p', 'sh', 't']


In [3]:
model = LetterCNN().to(device)
# The output width is fixed inside LetterCNN while the classes come from folder
# names; a mismatch would otherwise either waste logits or fail later with an
# out-of-range target error inside CrossEntropyLoss.
assert model.out.out_features == len(train_set.classes), (
    f"model predicts {model.out.out_features} classes but the dataset has "
    f"{len(train_set.classes)}: {train_set.classes}"
)
# Snapshot of the freshly initialized weights.
torch.save(model.state_dict(), "init_state.pth")

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

num_epochs = 20
# -inf (not 0.0) guarantees the first epoch always writes best.pth, even if its
# accuracy is 0, so the checkpoint exists once training has run.
best_acc = float("-inf")


@torch.no_grad()
def evaluate(model, data_loader, device, criterion=None):
    """Compute the average loss and accuracy of ``model`` over ``data_loader``.

    Switches the model to eval mode (and leaves it there) and runs without
    gradient tracking.

    Args:
        model: classifier whose output's dim 1 holds per-class scores.
        data_loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.
        criterion: loss module; defaults to ``nn.CrossEntropyLoss()`` so the
            function no longer depends on a global ``criterion``. Assumed to
            average over the batch (reduction="mean", the default), since the
            per-batch loss is re-weighted by batch size below.

    Returns:
        ``(avg_loss, acc)``: sample-weighted mean loss and the fraction of
        correctly classified samples.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0

    for images, labels in data_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        outputs = model(images)
        loss = criterion(outputs, labels)

        # Undo the batch mean so batches of different sizes are weighted per sample.
        running_loss += loss.item() * images.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError("evaluate() got an empty data_loader; cannot compute loss/accuracy")

    avg_loss = running_loss / total
    acc = correct / total
    return avg_loss, acc

In [4]:
def train_one_epoch(model, data_loader, optimizer, criterion, device):
    """Run one optimization pass over ``data_loader``.

    Switches the model to train mode, steps ``optimizer`` once per batch, and
    returns ``(avg_loss, acc)`` computed from the same forward passes used for
    the updates, so they reflect the model as it changed during the epoch.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0

    for images, labels in data_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Undo the batch mean so batches of different sizes are weighted per sample.
        running_loss += loss.item() * images.size(0)
        total += labels.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()

    if total == 0:
        raise ValueError("train_one_epoch() got an empty data_loader")

    return running_loss / total, correct / total


# NOTE(review): test_loader doubles as the validation set, so choosing best.pth
# by val_acc tunes on the test data; hold out a separate validation split if an
# unbiased final test score is needed.
for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, test_loader, device)

    print(
        f"Epoch [{epoch}/{num_epochs}] "
        f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
        f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} "
    )

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best.pth")
        print(f"  -> new best model saved (val_acc={best_acc:.4f})")

Epoch [1/20] train_loss=1.0906 train_acc=0.3731 val_loss=1.1064 val_acc=0.3372 
  -> new best model saved (val_acc=0.3372)
Epoch [2/20] train_loss=1.0424 train_acc=0.4506 val_loss=1.0703 val_acc=0.4729 
  -> new best model saved (val_acc=0.4729)
Epoch [3/20] train_loss=0.8714 train_acc=0.5862 val_loss=0.6889 val_acc=0.6822 
  -> new best model saved (val_acc=0.6822)
Epoch [4/20] train_loss=0.5477 train_acc=0.7800 val_loss=0.5269 val_acc=0.8062 
  -> new best model saved (val_acc=0.8062)
Epoch [5/20] train_loss=0.3735 train_acc=0.8605 val_loss=0.4108 val_acc=0.8217 
  -> new best model saved (val_acc=0.8217)
Epoch [6/20] train_loss=0.2997 train_acc=0.8798 val_loss=0.3710 val_acc=0.8411 
  -> new best model saved (val_acc=0.8411)
Epoch [7/20] train_loss=0.2493 train_acc=0.9021 val_loss=0.2888 val_acc=0.8527 
  -> new best model saved (val_acc=0.8527)
Epoch [8/20] train_loss=0.1750 train_acc=0.9293 val_loss=0.3475 val_acc=0.8527 
Epoch [9/20] train_loss=0.1774 train_acc=0.9322 val_loss=0.