In [1]:
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn
import torch

from torchvision.models import feature_extraction
import torchvision

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from PIL import Image
import pandas as pd
import numpy as np
import os

In [2]:
# set seeds for reproducibility
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [3]:
# set hyperparameters
WORKDIR = "/kaggle/input/copy-of-simpsons-faces-but-3-app"
VAL_SIZE = 0.2
BATCH_SIZE = 128
N_EPOCHS = 30
LEARNING_RATE = 1e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BEST_MODEL_FILEPATH = 'best_model.pt'

# feature extractor params
WEIGHTS = torchvision.models.ViT_L_32_Weights.DEFAULT
FEAT_EXTRACTOR = torchvision.models.vit_l_32
FEAT_LAYER = 'getitem_5'
FEAT_DIMS = 1024

In [4]:
# set filepaths of train and test csv files
train_csv_filepath = os.path.join(WORKDIR, 'train.csv')
test_csv_filepath = os.path.join(WORKDIR, 'test.csv')

# read train and test csv's as dataframe
df_train = pd.read_csv(train_csv_filepath)
df_test = pd.read_csv(test_csv_filepath)

In [5]:
# dict for class names and their corresponding IDs in the dataset
id2name = {df_train.iloc[i]['classId']:df_train.iloc[i]['className'] for i in range(len(df_train))}
n_class = len(id2name)
print(f"Number of classes in the dataset = {n_class}")
print(id2name)

Number of classes in the dataset = 18
{2: 'bart_simpson', 15: 'nelson_muntz', 8: 'kent_brockman', 14: 'ned_flanders', 16: 'principal_skinner', 0: 'abraham_grampa_simpson', 11: 'marge_simpson', 3: 'charles_montgomery_burns', 5: 'comic_book_guy', 7: 'homer_simpson', 6: 'edna_krabappel', 4: 'chief_wiggum', 10: 'lisa_simpson', 9: 'krusty_the_clown', 13: 'moe_szyslak', 1: 'apu_nahasapeemapetilon', 12: 'milhouse_van_houten', 17: 'sideshow_bob'}


In [6]:
x_trainval = df_train['path'].values
y_trainval = df_train['classId'].values
x_train, x_val, y_train, y_val = train_test_split(x_trainval, y_trainval, test_size=VAL_SIZE, random_state=SEED)

x_test = df_test['path'].values

print(x_train.shape, y_train.shape)
print(x_val.shape, y_val.shape)
print(x_test.shape)

(4366,) (4366,)
(1092,) (1092,)
(596,)


In [7]:
x_train = [os.path.join(WORKDIR, 'simpsons', 'train', fn) for fn in x_train]
x_val = [os.path.join(WORKDIR, 'simpsons', 'train', fn) for fn in x_val]
x_test = [os.path.join(WORKDIR, 'simpsons', 'test', fn) for fn in x_test]

In [8]:
# initialize the feature extractor and set it to evaluation mode
feat_extractor = feature_extraction.create_feature_extractor(
    FEAT_EXTRACTOR(weights=WEIGHTS),
    [FEAT_LAYER],
).to(DEVICE)
feat_extractor.eval()

# initialize corresponding image transformations (at preprocessing step)
transforms = WEIGHTS.transforms()

Downloading: "https://download.pytorch.org/models/vit_l_32-c7638314.pth" to /root/.cache/torch/hub/checkpoints/vit_l_32-c7638314.pth
100%|██████████| 1.14G/1.14G [00:07<00:00, 163MB/s]
  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


In [9]:
class SimpsonsFacesDataset(Dataset):
    def __init__(self, filepaths, labels, feat_extractor, transforms):
        super().__init__()
        
        self.filepaths = filepaths
        self.labels = labels
        
        self.feat_extractor = feat_extractor
        self.transforms = transforms
        
    def load_image(self, fp):
        return Image.open(fp).convert("RGB")
    
    def extract_feats(self, img):
        # preprocess the image
        x = self.transforms(img).unsqueeze(0).to(DEVICE)
        with torch.inference_mode():
            feats = self.feat_extractor(x)[FEAT_LAYER].view(-1)
        return feats
        
    def __len__(self):
        return len(self.filepaths)
    
    def __getitem__(self, idx):
        filepath = self.filepaths[idx]
        img = self.load_image(filepath)
        feats = self.extract_feats(img)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return feats, label

In [10]:
# initialize tran and val datasets
train_dset = SimpsonsFacesDataset(x_train, y_train, feat_extractor, transforms)
val_dset = SimpsonsFacesDataset(x_val, y_val, feat_extractor, transforms)

# initialize tran and val iterators
train_iterator = DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_iterator = DataLoader(val_dset, batch_size=BATCH_SIZE, drop_last=True)

In [11]:
class SimpsonsImageClassifier(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        
        self.net = nn.Sequential(
            nn.Linear(input_size, 4*input_size),
            nn.BatchNorm1d(4*input_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(4*input_size, output_size),
        )
        
    def forward(self, x):
        logits = self.net(x)
        return logits

In [12]:
def evaluate(model, iterator):
    
    global DEVICE
    
    model.eval()
    
    loss_sum = 0.0
    y_preds, y_gts = [], []    
    
    with torch.inference_mode():
        
        for iter_idx, (x, y) in tqdm(enumerate(iterator), total=len(iterator)):
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            
            preds = torch.argmax(logits, dim=1)
            y_preds.append(preds.cpu())
            y_gts.append(y.cpu())
        
            loss = F.cross_entropy(logits, y)
            loss_sum += loss.item()
    
        y_preds, y_gts = torch.hstack(y_preds).numpy(), torch.hstack(y_gts).numpy()
        acc = accuracy_score(y_gts, y_preds)
        avg_loss = loss_sum / len(iterator)
    return avg_loss, acc

In [13]:
# initialize classifier
model = SimpsonsImageClassifier(input_size=FEAT_DIMS, output_size=n_class).to(DEVICE)
# initialize optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

best_val_loss = float('inf')
best_streak = 0

for epoch_idx in range(1, N_EPOCHS + 1):
    model.train()
    train_loss_sum = 0.0
    for iter_idx, (x, y) in tqdm(enumerate(train_iterator), total=len(train_iterator)):
        x, y = x.to(DEVICE), y.to(DEVICE)
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss_sum += loss.item()
        
    avg_train_loss = train_loss_sum / len(train_iterator)
    avg_val_loss, val_acc = evaluate(model, val_iterator)
    print(f"Epoch {epoch_idx} | train loss = {avg_train_loss:.3f}\tval loss = {avg_val_loss:.3f}\tval acc = {val_acc:.3f}")
    
    if avg_val_loss < best_val_loss:
        torch.save(model.state_dict(), BEST_MODEL_FILEPATH)
        best_val_loss = avg_val_loss
        best_streak = 0
        print(f'The best model is found and saved. Current best validation loss = {best_val_loss:.3f}')
    else:
        best_streak += 1
        
    if best_streak == 5:
        print('Best model has not been found in the last 5 epochs. Early stopping...')
        break

100%|██████████| 34/34 [01:32<00:00,  2.73s/it]
100%|██████████| 8/8 [00:21<00:00,  2.68s/it]


Epoch 1 | train loss = 1.653	val loss = 1.019	val acc = 0.812
The best model is found and saved. Current best validation loss = 1.019


100%|██████████| 34/34 [01:11<00:00,  2.10s/it]
100%|██████████| 8/8 [00:17<00:00,  2.18s/it]


Epoch 2 | train loss = 0.621	val loss = 0.563	val acc = 0.873
The best model is found and saved. Current best validation loss = 0.563


100%|██████████| 34/34 [01:12<00:00,  2.12s/it]
100%|██████████| 8/8 [00:16<00:00,  2.07s/it]


Epoch 3 | train loss = 0.368	val loss = 0.442	val acc = 0.889
The best model is found and saved. Current best validation loss = 0.442


100%|██████████| 34/34 [01:09<00:00,  2.05s/it]
100%|██████████| 8/8 [00:16<00:00,  2.05s/it]


Epoch 4 | train loss = 0.243	val loss = 0.387	val acc = 0.902
The best model is found and saved. Current best validation loss = 0.387


100%|██████████| 34/34 [01:09<00:00,  2.05s/it]
100%|██████████| 8/8 [00:16<00:00,  2.04s/it]


Epoch 5 | train loss = 0.174	val loss = 0.351	val acc = 0.908
The best model is found and saved. Current best validation loss = 0.351


100%|██████████| 34/34 [01:08<00:00,  2.02s/it]
100%|██████████| 8/8 [00:16<00:00,  2.01s/it]


Epoch 6 | train loss = 0.125	val loss = 0.327	val acc = 0.914
The best model is found and saved. Current best validation loss = 0.327


100%|██████████| 34/34 [01:10<00:00,  2.06s/it]
100%|██████████| 8/8 [00:16<00:00,  2.04s/it]


Epoch 7 | train loss = 0.092	val loss = 0.309	val acc = 0.915
The best model is found and saved. Current best validation loss = 0.309


100%|██████████| 34/34 [01:09<00:00,  2.06s/it]
100%|██████████| 8/8 [00:16<00:00,  2.07s/it]


Epoch 8 | train loss = 0.071	val loss = 0.297	val acc = 0.917
The best model is found and saved. Current best validation loss = 0.297


100%|██████████| 34/34 [01:12<00:00,  2.12s/it]
100%|██████████| 8/8 [00:16<00:00,  2.08s/it]


Epoch 9 | train loss = 0.056	val loss = 0.284	val acc = 0.921
The best model is found and saved. Current best validation loss = 0.284


100%|██████████| 34/34 [01:11<00:00,  2.09s/it]
100%|██████████| 8/8 [00:16<00:00,  2.05s/it]


Epoch 10 | train loss = 0.045	val loss = 0.278	val acc = 0.922
The best model is found and saved. Current best validation loss = 0.278


100%|██████████| 34/34 [01:10<00:00,  2.06s/it]
100%|██████████| 8/8 [00:16<00:00,  2.05s/it]


Epoch 11 | train loss = 0.036	val loss = 0.272	val acc = 0.923
The best model is found and saved. Current best validation loss = 0.272


100%|██████████| 34/34 [01:10<00:00,  2.06s/it]
100%|██████████| 8/8 [00:16<00:00,  2.04s/it]


Epoch 12 | train loss = 0.032	val loss = 0.267	val acc = 0.925
The best model is found and saved. Current best validation loss = 0.267


100%|██████████| 34/34 [01:10<00:00,  2.06s/it]
100%|██████████| 8/8 [00:16<00:00,  2.08s/it]


Epoch 13 | train loss = 0.026	val loss = 0.263	val acc = 0.926
The best model is found and saved. Current best validation loss = 0.263


100%|██████████| 34/34 [01:10<00:00,  2.08s/it]
100%|██████████| 8/8 [00:16<00:00,  2.07s/it]


Epoch 14 | train loss = 0.022	val loss = 0.259	val acc = 0.929
The best model is found and saved. Current best validation loss = 0.259


100%|██████████| 34/34 [01:09<00:00,  2.05s/it]
100%|██████████| 8/8 [00:16<00:00,  2.03s/it]


Epoch 15 | train loss = 0.020	val loss = 0.256	val acc = 0.927
The best model is found and saved. Current best validation loss = 0.256


100%|██████████| 34/34 [01:10<00:00,  2.08s/it]
100%|██████████| 8/8 [00:16<00:00,  2.04s/it]


Epoch 16 | train loss = 0.017	val loss = 0.254	val acc = 0.929
The best model is found and saved. Current best validation loss = 0.254


100%|██████████| 34/34 [01:10<00:00,  2.06s/it]
100%|██████████| 8/8 [00:16<00:00,  2.07s/it]


Epoch 17 | train loss = 0.016	val loss = 0.252	val acc = 0.928
The best model is found and saved. Current best validation loss = 0.252


100%|██████████| 34/34 [01:09<00:00,  2.05s/it]
100%|██████████| 8/8 [00:16<00:00,  2.04s/it]


Epoch 18 | train loss = 0.014	val loss = 0.249	val acc = 0.931
The best model is found and saved. Current best validation loss = 0.249


100%|██████████| 34/34 [01:10<00:00,  2.07s/it]
100%|██████████| 8/8 [00:16<00:00,  2.08s/it]


Epoch 19 | train loss = 0.013	val loss = 0.249	val acc = 0.930
The best model is found and saved. Current best validation loss = 0.249


100%|██████████| 34/34 [01:10<00:00,  2.07s/it]
100%|██████████| 8/8 [00:16<00:00,  2.03s/it]


Epoch 20 | train loss = 0.012	val loss = 0.247	val acc = 0.932
The best model is found and saved. Current best validation loss = 0.247


100%|██████████| 34/34 [01:09<00:00,  2.05s/it]
100%|██████████| 8/8 [00:16<00:00,  2.05s/it]


Epoch 21 | train loss = 0.011	val loss = 0.247	val acc = 0.932


100%|██████████| 34/34 [01:09<00:00,  2.04s/it]
100%|██████████| 8/8 [00:16<00:00,  2.05s/it]


Epoch 22 | train loss = 0.009	val loss = 0.247	val acc = 0.932
The best model is found and saved. Current best validation loss = 0.247


100%|██████████| 34/34 [01:10<00:00,  2.07s/it]
100%|██████████| 8/8 [00:16<00:00,  2.08s/it]


Epoch 23 | train loss = 0.009	val loss = 0.248	val acc = 0.931


100%|██████████| 34/34 [01:10<00:00,  2.06s/it]
100%|██████████| 8/8 [00:16<00:00,  2.04s/it]


Epoch 24 | train loss = 0.008	val loss = 0.244	val acc = 0.934
The best model is found and saved. Current best validation loss = 0.244


100%|██████████| 34/34 [01:09<00:00,  2.05s/it]
100%|██████████| 8/8 [00:16<00:00,  2.09s/it]


Epoch 25 | train loss = 0.007	val loss = 0.243	val acc = 0.935
The best model is found and saved. Current best validation loss = 0.243


100%|██████████| 34/34 [01:09<00:00,  2.04s/it]
100%|██████████| 8/8 [00:16<00:00,  2.02s/it]


Epoch 26 | train loss = 0.007	val loss = 0.244	val acc = 0.932


100%|██████████| 34/34 [01:09<00:00,  2.05s/it]
100%|██████████| 8/8 [00:16<00:00,  2.07s/it]


Epoch 27 | train loss = 0.006	val loss = 0.243	val acc = 0.933


100%|██████████| 34/34 [01:09<00:00,  2.06s/it]
100%|██████████| 8/8 [00:16<00:00,  2.09s/it]


Epoch 28 | train loss = 0.006	val loss = 0.242	val acc = 0.934
The best model is found and saved. Current best validation loss = 0.242


100%|██████████| 34/34 [01:09<00:00,  2.06s/it]
100%|██████████| 8/8 [00:16<00:00,  2.04s/it]


Epoch 29 | train loss = 0.005	val loss = 0.242	val acc = 0.935
The best model is found and saved. Current best validation loss = 0.242


100%|██████████| 34/34 [01:10<00:00,  2.07s/it]
100%|██████████| 8/8 [00:16<00:00,  2.09s/it]

Epoch 30 | train loss = 0.005	val loss = 0.244	val acc = 0.933





In [15]:
# reinitialize the model
best_model = SimpsonsImageClassifier(input_size=FEAT_DIMS, output_size=n_class)
# load best model's parameters
best_model.load_state_dict(torch.load(BEST_MODEL_FILEPATH))
best_model = best_model.to(DEVICE)

In [28]:
def predict(img, feat_extractor, model, transforms):
    
    global DEVICE
    
    model.eval()
    
    # preprocess the image
    x = transforms(img).unsqueeze(0).to(DEVICE)
    
    with torch.inference_mode():
        feats = feat_extractor(x)[FEAT_LAYER]
        logits = model(feats)
        probs = F.softmax(logits, dim=1)
        pred = torch.argmax(probs).item()

    return pred

In [29]:
path_list, classId_list = [], []

for fp_test in tqdm(x_test):
    img_test = Image.open(fp_test).convert("RGB")
    pred = predict(img_test, feat_extractor, best_model, transforms)
    # append
    path_list.append(os.path.basename(fp_test))    
    classId_list.append(pred)

100%|██████████| 596/596 [00:11<00:00, 52.56it/s]


In [31]:
columns = ['path', 'classId']
df_submission = pd.DataFrame(
    {
        'path': path_list,
        'classId': classId_list,
    }
)
df_submission.to_csv('submission.csv', index=False)