In [5]:
# First download packages required
pip install torch torchvision matplotlib

Note: you may need to restart the kernel to use updated packages.


In [29]:
import os
import torch
import torchvision
from torch import nn
from torchvision import datasets, transforms
from torchvision.models import resnet18, ResNet18_Weights
import torch.optim as optim
from torch.utils.data import DataLoader
from pathlib import Path
from tqdm import tqdm


device = "cuda" if torch.cuda.is_available() else "cpu"

# Find cat breed dataset
cat_dataset_train = "./cat_dataset/train"
cat_dataset_test = "./cat_dataset/test"


# Create a transforms pipeline manually (required for torchvision < 0.13)
manual_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Create Datasets
train_data = datasets.ImageFolder(cat_dataset_train, transform=manual_transforms)
test_data = datasets.ImageFolder(cat_dataset_test, transform=manual_transforms)
breed_names = train_data.classes

# Create Train DataLoader
train_dataloader = DataLoader(
    train_data, 
    batch_size=32, 
    shuffle=True,
    num_workers=os.cpu_count(),
    pin_memory=True
)
# Create Test DataLoader
test_dataloader = DataLoader(
    test_data, 
    batch_size=32, 
    shuffle=True,
    num_workers=os.cpu_count(),
    pin_memory=True
)

# Create ResNet model
weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights)
model.fc = nn.Linear(model.fc.in_features, len(breed_names))
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [32]:
def train_cat_breed_detector_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    best_val_accuracy = 0
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        description = f"Epoch {epoch+1}/{num_epochs}"
        for images, labels in tqdm(train_loader, desc=description):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = correct / total
        
        # Validate model and save the one with best accuracy
        model.eval()
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)
        
        val_accuracy = val_correct / val_total
        
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_acc:.4f}, Validation Accuracy: {val_accuracy:.4f}")
        # Save the model if it performs better
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), "cat_breed_detector_model.pth")

In [34]:
# Train the model
train_cat_breed_detector_model(model, train_dataloader, test_dataloader, criterion, optimizer, num_epochs=5)

Epoch 1/5: 100%|██████████| 226/226 [22:07<00:00,  5.87s/it]


Epoch [1/5], Loss: 0.4918, Train Accuracy: 0.8391, Validation Accuracy: 0.3811


Epoch 2/5: 100%|██████████| 226/226 [22:42<00:00,  6.03s/it]


Epoch [2/5], Loss: 0.2985, Train Accuracy: 0.9050, Validation Accuracy: 0.4599


Epoch 3/5: 100%|██████████| 226/226 [18:25<00:00,  4.89s/it]


Epoch [3/5], Loss: 0.1514, Train Accuracy: 0.9573, Validation Accuracy: 0.4616


Epoch 4/5: 100%|██████████| 226/226 [18:45<00:00,  4.98s/it]


Epoch [4/5], Loss: 0.1697, Train Accuracy: 0.9465, Validation Accuracy: 0.4205


Epoch 5/5: 100%|██████████| 226/226 [20:19<00:00,  5.39s/it]


Epoch [5/5], Loss: 0.2689, Train Accuracy: 0.9129, Validation Accuracy: 0.3956
