In [64]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
import timm

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

import torchvision.transforms as transforms
from backend.util.dataloader import get_data_loaders
from backend.model.train import train_and_validate

# Fix the seed for PyTorch's RNGs so torch-driven randomness (e.g. initialization of
# the new classifier head, and DataLoader shuffling if the loaders enable it) repeats
# across Restart & Run All. torch.manual_seed seeds the generators on all devices.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device: {}".format(device))

Using device: cuda


In [58]:
# Load data: one directory per split, relative to the notebook's working directory.
train_dir, valid_dir, test_dir = './data/train', './data/valid', './data/test'
batch_size = 32

# Every split is preprocessed identically: resize to a fixed 128x128, then convert to a tensor.
# NOTE(review): no mean/std normalization is applied even though the backbone is
# pretrained; confirm this matches the preprocessing used wherever the saved weights
# are loaded for inference before changing it.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)

train_loader, valid_loader, test_loader, target_to_class = get_data_loaders(
    train_dir,
    valid_dir,
    test_dir,
    batch_size,
    transform,
)

In [59]:
class CardClassifier(nn.Module):
    """Card image classifier: a timm ``efficientnet_b0`` backbone plus a new linear head.

    The backbone's last child module (its original classification head) is dropped;
    the remaining children are run in sequence to produce a feature vector, which a
    fresh ``nn.Linear`` maps to ``num_classes`` logits.

    The backbone is kept as ``base_model`` and its children are also registered under
    ``features`` (shared modules), so ``state_dict`` keys match earlier checkpoints.

    Args:
        num_classes: Number of output classes (default 53).
        pretrained: Forwarded to ``timm.create_model``. Defaults to ``True`` (the
            original behavior). Pass ``False`` when the weights will be replaced by
            ``load_state_dict`` anyway, e.g. when restoring a saved checkpoint, to
            skip fetching pretrained weights.
    """

    def __init__(self, num_classes=53, pretrained=True):
        super().__init__()
        self.base_model = timm.create_model('efficientnet_b0', pretrained=pretrained)
        # Keep every child except the last one (the backbone's own classifier).
        # NOTE(review): this assumes the remaining children end in a pooling layer that
        # already flattens to (batch, features); confirm if the backbone is ever swapped.
        self.features = nn.Sequential(*list(self.base_model.children())[:-1])
        # Width of the pooled feature vector. Read it from the backbone rather than
        # hard-coding it (1280 for efficientnet_b0) so the head cannot drift out of sync.
        enet_out_size = getattr(self.base_model, 'num_features', 1280)
        self.classifier = nn.Linear(enet_out_size, num_classes)

    def forward(self, x):
        """Return class logits for the image batch ``x`` (one row of ``num_classes`` per image)."""
        x = self.features(x)
        return self.classifier(x)
        

In [63]:
# Loss: standard multi-class cross-entropy over the classifier's logits.
criterion = nn.CrossEntropyLoss()

# Move the model to `device` first, then build the optimizer over its parameters
# (the order the torch.optim docs recommend). Plain SGD with momentum.
model = CardClassifier().to(device)
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

In [65]:
# Train and validate through the backend helper for 30 epochs.
# NOTE(review): the recorded run of this cell was interrupted (KeyboardInterrupt), and
# the manual training cell that follows trains this same `model` again; under
# Restart & Run All both execute, so the model is trained twice. Keep only one.
results = train_and_validate(
    model,
    criterion,
    optimizer,
    train_loader,
    valid_loader,
    device,
    num_epochs=30,
)

Epoch 1/30
Training Loss: 3.6291, Training Accuracy: 12.15%
Validation Loss: 3.1160, Validation Accuracy: 23.77%


KeyboardInterrupt: 

In [61]:
# Manual train/validate loop. Each epoch does one optimization pass over train_loader,
# then one evaluation pass over valid_loader, recording loss and accuracy for both.
# NOTE(review): this repeats what the train_and_validate(...) cell does; if both cells
# run, `model` is trained for 2 * num_epochs epochs in total.


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    When ``optimizer`` is given, the model is switched to training mode and updated
    batch by batch. Otherwise it is switched to eval mode and run with gradients
    disabled. Loss is averaged over the samples actually seen, the same denominator
    used for accuracy.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            n_batch = labels.size(0)
            # criterion returns a per-batch mean; re-weight by batch size to get a per-sample mean.
            running_loss += loss.item() * n_batch
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += n_batch
    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute loss/accuracy")
    return running_loss / total, 100 * correct / total


num_epochs = 30
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

for epoch in range(num_epochs):
    # Training phase: parameters are updated.
    train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, device, optimizer)
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    # Validation phase: no gradients, no parameter updates.
    val_loss, val_accuracy = _run_epoch(model, valid_loader, criterion, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%')
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

Epoch 1/30
Training Loss: 3.6171, Training Accuracy: 12.99%
Validation Loss: 3.0752, Validation Accuracy: 23.77%
Epoch 2/30
Training Loss: 2.6111, Training Accuracy: 36.75%
Validation Loss: 1.9848, Validation Accuracy: 51.32%
Epoch 3/30
Training Loss: 1.9082, Training Accuracy: 53.20%
Validation Loss: 1.3863, Validation Accuracy: 64.53%
Epoch 4/30
Training Loss: 1.4746, Training Accuracy: 62.87%
Validation Loss: 1.0187, Validation Accuracy: 73.96%
Epoch 5/30
Training Loss: 1.1815, Training Accuracy: 69.75%
Validation Loss: 0.8073, Validation Accuracy: 77.74%
Epoch 6/30
Training Loss: 0.9439, Training Accuracy: 75.45%
Validation Loss: 0.6486, Validation Accuracy: 80.38%
Epoch 7/30
Training Loss: 0.7594, Training Accuracy: 80.29%
Validation Loss: 0.5152, Validation Accuracy: 85.66%
Epoch 8/30
Training Loss: 0.6302, Training Accuracy: 83.66%
Validation Loss: 0.4706, Validation Accuracy: 84.15%
Epoch 9/30
Training Loss: 0.5196, Training Accuracy: 86.44%
Validation Loss: 0.4122, Validation 

In [62]:
# Persist only the learned parameters (the state_dict), not the pickled module object.
# CardClassifier registers its backbone under both `base_model` and `features`, so the
# shared tensors appear under both key prefixes in the saved dict.
torch.save(obj=model.state_dict(), f='model_card_classifier.pth')