# Pokemon Classification

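Train a classification model on blurred Pokémon images. Augmentations (crop, horizontal flip, rotation) are applied to each blurred image, and the fourth image of each Pokémon is always held out of training; constructing the dataset with `test_mode=True` returns only those held-out images as the test set. Accuracy on the test set is over 90%.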

In [None]:
import os
from PIL import Image, ImageFilter
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np

# Define a custom dataset class for loading images with specific test set logic
class PokemonDataset(Dataset):
    """In-memory dataset of Gaussian-blurred Pokémon images with a fixed hold-out.

    Every ``.png`` file directly inside ``folder_path`` is opened, converted to
    RGB and blurred.  The label is the lower-cased part of the filename before
    the first ``.``.  With ``augment=True`` three more images are derived from
    each blurred image, in this order: a crop of the box (10, 10, 110, 110)
    resized back to the source size, a horizontal flip, and a 15-degree
    rotation.

    For each label, the image at position 3 (the fourth one) is the held-out
    test image.  Labels with fewer than four images are skipped entirely, in
    both modes, so with ``augment=False`` and one file per label the dataset
    is empty.

    Args:
        folder_path: Directory containing the ``.png`` files.
        transform: Optional callable applied to each image in ``__getitem__``.
        augment: When true, add the cropped, flipped and rotated variants.
        test_mode: When true, expose only the held-out image of each label;
            otherwise expose every other image.

    Attributes:
        images: Selected PIL images, parallel to ``labels``.
        labels: String label of each entry in ``images``.
        label_to_index: Label -> integer class index, assigned in sorted label
            order so the numbering is reproducible.
        index_to_label: Inverse of ``label_to_index``.
    """

    # Position, within each label's image list, of the image reserved for testing.
    _TEST_POSITION = 3

    def __init__(self, folder_path, transform=None, augment=False, test_mode=False):
        self.folder_path = folder_path
        self.transform = transform
        self.augment = augment
        self.test_mode = test_mode
        self.images = []
        self.labels = []

        # label -> list of images, in the order they were generated
        pokemon_images = {}

        # os.listdir() returns entries in arbitrary order; sort so that the
        # dataset contents come out in the same order on every run.
        for filename in sorted(os.listdir(folder_path)):
            if not filename.endswith('.png'):
                continue

            image_path = os.path.join(folder_path, filename)
            # convert() returns a new, fully loaded image, so the file handle
            # can be released as soon as the conversion is done.
            with Image.open(image_path) as source:
                image = source.convert('RGB')  # guarantee 3 channels

            # Label is everything before the first dot (e.g. "abra.png" -> "abra")
            label = filename.split('.')[0].lower()
            variants = pokemon_images.setdefault(label, [])

            blurred_image = image.filter(ImageFilter.GaussianBlur(radius=2))
            variants.append(blurred_image)

            if augment:
                # Order matters: the rotated image lands at _TEST_POSITION when
                # a label has a single source file.
                variants.append(
                    blurred_image.crop((10, 10, 110, 110)).resize(image.size)
                )
                variants.append(blurred_image.transpose(Image.FLIP_LEFT_RIGHT))
                variants.append(blurred_image.rotate(15))

        # Split each label's images into the held-out one and the rest
        held_out = self._TEST_POSITION
        for label, images in pokemon_images.items():
            if len(images) <= held_out:
                # Not enough images to hold one out; skip the label in both
                # modes so train and test cover the same set of labels.
                continue
            if test_mode:
                self.images.append(images[held_out])
                self.labels.append(label)
            else:
                kept = images[:held_out] + images[held_out + 1:]
                self.images.extend(kept)
                self.labels.extend([label] * len(kept))

        self.label_to_index = self._build_label_index(self.labels)
        self.index_to_label = {idx: label for label, idx in self.label_to_index.items()}

    @staticmethod
    def _build_label_index(labels):
        """Return a label -> class-index mapping numbered in sorted label order.

        Iterating a plain ``set`` of strings yields an order that can change
        between interpreter runs; sorting makes the numbering depend only on
        which labels are present.
        """
        return {label: idx for idx, label in enumerate(sorted(set(labels)))}

    def __len__(self):
        """Return the number of images selected for this mode."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for position ``idx``.

        The image is passed through ``transform`` when one was supplied.
        """
        image = self.images[idx]
        label = self.labels[idx]

        # Compare against None: a container-like transform may be falsy
        # without being absent.
        if self.transform is not None:
            image = self.transform(image)

        return image, self.label_to_index[label]

# Preprocessing shared by both splits: scale every image to 224x224 (the size
# used for compatibility with pre-trained models), then convert it to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Both splits read the same folder with augmentation enabled; only `test_mode`
# differs, which selects the held-out image of each Pokémon.
dataset_path = 'dataset'

train_dataset = PokemonDataset(
    folder_path=dataset_path,
    transform=transform,
    augment=True,
    test_mode=False,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)

test_dataset = PokemonDataset(
    folder_path=dataset_path,
    transform=transform,
    augment=True,
    test_mode=True,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False)

# Define the model (ResNet)
class PokemonClassifier(nn.Module):
    """ResNet-18 with its final fully connected layer resized to ``num_classes``.

    The backbone is loaded with pre-trained weights; only the ``fc`` layer is
    replaced, so it starts from a fresh initialisation.

    Args:
        num_classes: Number of output classes (size of the new ``fc`` layer).
    """

    def __init__(self, num_classes):
        super().__init__()
        # Newer torchvision releases select weights through an enum and warn
        # about the legacy `pretrained=True` flag; older releases only know the
        # flag. IMAGENET1K_V1 is the weight set `pretrained=True` refers to.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            self.model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.model = models.resnet18(pretrained=True)
        # Swap the classification head for one sized to this task.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet and return its output."""
        return self.model(x)

# Initialize model, criterion, and optimizer.
# Run on a GPU when one is available; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(train_dataset.label_to_index)
model = PokemonClassifier(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop: one pass over train_loader per epoch, reporting the mean
# per-batch loss.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        # Batches must live on the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")

# Evaluation on the test set: count predictions that match the label.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# An empty test loader would otherwise surface as a bare ZeroDivisionError.
if total == 0:
    raise RuntimeError("Test set is empty; cannot compute accuracy.")

print(f'Accuracy on test set: {100 * correct / total}%')

Epoch 1/10, Loss: 4.298060512542724
Epoch 2/10, Loss: 2.6347580035527547
Epoch 3/10, Loss: 1.7910629947980246
Epoch 4/10, Loss: 1.2924952149391173
Epoch 5/10, Loss: 0.9369195600350698
Epoch 6/10, Loss: 0.7014775663614273
Epoch 7/10, Loss: 0.6871185819307963
Epoch 8/10, Loss: 0.5396935532490412
Epoch 9/10, Loss: 0.36071471323569615
Epoch 10/10, Loss: 0.33459382504224777
Accuracy on test set: 95.0%