In [33]:
# !pip install torch torchvision scikit-learn lgbt kagglehub numpy matplotlib tensorboard
# %load_ext tensorboard

In [34]:
import datetime

import kagglehub
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets, models
from sklearn.metrics import precision_score, f1_score, accuracy_score, classification_report
from lgbt import lgbt
import matplotlib.pyplot as plt
import numpy as np

In [35]:
path = kagglehub.dataset_download("bolg4rin/simpson-dataset-fixed")

print("Path to dataset files:", path)

Path to dataset files: /root/.cache/kagglehub/datasets/bolg4rin/simpson-dataset-fixed/versions/3


In [36]:
def model_init ():
    model = models.resnet50(pretrained = True)
    model.fc = nn.Linear(2048,42)
    model = model.cuda()
    loss = nn.CrossEntropyLoss().cuda()
    optim = torch.optim.Adam(model.parameters(),lr=0.0001)
    scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=4, gamma=0.1)

    return model, loss, optim, scheduler

In [37]:
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


original_train = datasets.ImageFolder(
    root=f'{path}/the_simpson_dataset/train',
    transform=train_transform
)

original_test = datasets.ImageFolder(
    root=f'{path}/the_simpson_dataset/test',
    transform=val_test_transform
)

train_size = int(0.8 * len(original_train))
val_size = len(original_train) - train_size
new_train, val_dataset = random_split(original_train, [train_size, val_size])

val_dataset.dataset.transform = val_test_transform

batch_size = 32
pin_memory = torch.cuda.is_available()  # Enable only if GPU available

train_loader = DataLoader(
    new_train, 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=4,
    pin_memory=pin_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=pin_memory
)

test_loader = DataLoader(
    original_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=pin_memory
)

class_names = original_train.classes
num_classes = len(class_names)

In [38]:
def valid_model(model, loss):
    model.eval()
    predictions = []
    targets = []

    validation_loss = 0

    val_bar = lgbt(val_loader, desc=f'Validation')

    with torch.no_grad():
        for inputs, labels in val_bar:
            inputs = inputs.cuda()
            labels = labels.cuda()
            outputs = model(inputs)
            loss_f = loss(outputs, labels)
            validation_loss += loss_f.item()
            predictions.extend(torch.argmax(outputs,dim=1).cpu().numpy())
            targets.extend(labels.cpu().numpy())
                
    macro_f1 = f1_score(targets, predictions, average='macro')

    return macro_f1, validation_loss

In [39]:
def train_model(model, loss, optim, scheduler):
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = f"logs/simpsons_{current_time}"
    writer = SummaryWriter(log_dir)

    num_epochs = 10

    for epoch in range(num_epochs):
        train_bar = lgbt(train_loader, desc=f'Train {epoch+1}/{num_epochs}', hero = "unicorn")

        running_loss = 0

        model.train()
        for inputs, labels in train_bar:
            inputs = inputs.cuda()
            labels = labels.cuda()
            # writer.add_graph(model, inputs)
            
            optim.zero_grad()
        
            outputs = model(inputs)
            loss_f = loss(outputs, labels)
            running_loss += loss_f.item()
            loss_f.backward()
            optim.step()
            
        torch.save(model.state_dict(), f'simpsons_model/simpsons_scheduler{epoch}.pth')

        macro_f1, validation_loss = valid_model(model, loss)
        
        running_loss /= len(train_loader)
        validation_loss /= len(val_loader)
        print(f'Epoch {epoch+1}:\tLoss {running_loss}\tValidation loss {validation_loss}\tValidation F1 {(macro_f1*100):.2f}')

        writer.add_scalar('Loss/train', running_loss, epoch)
        writer.add_scalar('Loss/validation', validation_loss, epoch)
        writer.add_scalar('F1/validation', macro_f1, epoch)

        writer.add_scalar('Learning Rate', optim.param_groups[0]['lr'], epoch)
        
        scheduler.step()

In [40]:
%tensorboard --logdir logs

Reusing TensorBoard on port 6006 (pid 2214), started 6:40:02 ago. (Use '!kill 2214' to kill it.)

In [41]:
simpsons_model, loss_func, optimizer, scheduler = model_init()
train_model(simpsons_model, loss_func, optimizer, scheduler)



🦄Train 1/10  :[35m100% [31m▋▋▋▋▋▋▋▋[38;5;214m▋▋▋▋▋▋▋▋[33m▋▋▋▋▋▋▋▋[32m▋▋▋▋▋▋▋▋[36m▋▋▋▋▋▋▋▋[34m▋▋▋▋▋▋▋▋[35m▋▋▋▋▋▋▋▋[35m[524/524] [123.35s, 4.25it/s]  [m8it/s]  [m
🌈Validation  :[35m100% [31m▋▋▋▋▋▋▋▋[38;5;214m▋▋▋▋▋▋▋▋[33m▋▋▋▋▋▋▋▋[32m▋▋▋▋▋▋▋▋[36m▋▋▋▋▋▋▋▋[34m▋▋▋▋▋▋▋▋[35m▋▋▋▋▋▋▋▋[35m[131/131] [9.72s, 13.47it/s]  [m9it/s]  [m
Epoch 1:	Loss 0.6025746708046218	Validation loss 0.19334294375046404	Validation F1 74.10
🦄Train 2/10  :[35m100% [31m▋▋▋▋▋▋▋▋[38;5;214m▋▋▋▋▋▋▋▋[33m▋▋▋▋▋▋▋▋[32m▋▋▋▋▋▋▋▋[36m▋▋▋▋▋▋▋▋[34m▋▋▋▋▋▋▋▋[35m▋▋▋▋▋▋▋▋[35m[524/524] [118.91s, 4.41it/s]  [m6it/s]  [m
🌈Validation  :[35m100% [31m▋▋▋▋▋▋▋▋[38;5;214m▋▋▋▋▋▋▋▋[33m▋▋▋▋▋▋▋▋[32m▋▋▋▋▋▋▋▋[36m▋▋▋▋▋▋▋▋[34m▋▋▋▋▋▋▋▋[35m▋▋▋▋▋▋▋▋[35m[131/131] [9.96s, 13.16it/s]  [m8it/s]  [m
Epoch 2:	Loss 0.1136788714637994	Validation loss 0.1705976187585647	Validation F1 90.93
🦄Train 3/10  :[35m100% [31m▋▋▋▋▋▋▋▋[38;5;214m▋▋▋▋▋▋▋▋[33m▋▋▋▋▋▋▋▋[32m▋▋▋▋▋▋▋▋[36m▋▋▋▋▋▋▋▋[34m▋▋▋▋▋▋▋▋[35m▋▋▋▋▋▋▋▋[35m[524/524] [

In [46]:
def test_model(model):
    preds = []
    targets = []
    class_preds = []
    class_targets = []
    for i in range(42):
        class_preds.append([])
        class_targets.append([])

    test_bar = lgbt(test_loader, desc='Test', hero = 'kitten')

    model.eval()
    with torch.no_grad():
        for inputs, target in test_bar:
            inputs = inputs.cuda()
            target = target.cuda()
            outputs = model(inputs)
            pred = torch.argmax(outputs, dim=1)
            for i in range (len(target)):
                class_code = target[i]
                class_targets[class_code].append(class_code.cpu().numpy())
                class_preds[class_code].append(pred[i].cpu().numpy())
            preds.extend(pred.cpu().numpy())
            targets.extend(target.cpu().numpy())

    macro_f1 = f1_score(targets, preds, average='macro')
    accuracy = accuracy_score(targets,preds)
    print(f'Macro F1: {(macro_f1*100):.2f}%')
    print(f'Accuracy: {(accuracy*100):.2f}%')

    for i in range(len(class_preds)):
        acc = accuracy_score(class_preds[i], class_targets[i])
        print(f'{original_test.classes[i]} accuracy: {(acc*100):.2f}%')

In [47]:
test_model(simpsons_model)

🐱Test        :[35m100% [31m▋▋▋▋▋▋▋▋[38;5;214m▋▋▋▋▋▋▋▋[33m▋▋▋▋▋▋▋▋[32m▋▋▋▋▋▋▋▋[36m▋▋▋▋▋▋▋▋[34m▋▋▋▋▋▋▋▋[35m▋▋▋▋▋▋▋▋[35m[14/14] [1.23s, 11.40it/s]  [mit/s]  [m
Macro F1: 91.63%
Accuracy: 92.05%
abraham_grampa_simpson accuracy: 100.00%
agnes_skinner accuracy: 40.00%
apu_nahasapeemapetilon accuracy: 100.00%
barney_gumble accuracy: 70.00%
bart_simpson accuracy: 100.00%
carl_carlson accuracy: 100.00%
charles_montgomery_burns accuracy: 100.00%
chief_wiggum accuracy: 90.91%
cletus_spuckler accuracy: 100.00%
comic_book_guy accuracy: 100.00%
disco_stu accuracy: 80.00%
edna_krabappel accuracy: 100.00%
fat_tony accuracy: 90.00%
gil accuracy: 90.00%
groundskeeper_willie accuracy: 80.00%
homer_simpson accuracy: 100.00%
kent_brockman accuracy: 100.00%
krusty_the_clown accuracy: 100.00%
lenny_leonard accuracy: 100.00%
lionel_hutz accuracy: 100.00%
lisa_simpson accuracy: 100.00%
maggie_simpson accuracy: 90.00%
marge_simpson accuracy: 100.00%
martin_prince accuracy: 90.00%
mayor_quimby accurac