Main Imports

In [1]:
from typing import List
from Dataset.CustomDataset import AgeGroupAndAgeDataset, StandardDataset, AgeGroupAndAgeDatasetKL
from Dataset.CustomDataLoaders import CustomDataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from Utils import AAR, CSVUtils, AgeConversion
from Utils.Validator import Validator
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#Caricamento del dataframe
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")

#Suddivisione del dataframe in 3 age groups
_, d = CSVUtils.get_df_with_age_subdivision(df, 3)

In [3]:
#Splitting tra Train e Validation set
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
#Aggiornamento degli indici per pandas
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)

from PIL import Image
import numpy as np
from torchvision import transforms
import torch

transform_func = transforms.Compose([
    transforms.Resize(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.RandomGrayscale(),
])

transform_func_val = transforms.Compose([
    transforms.Resize(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

#Implementazione di un Dataset utilizzando "CustomDataset" per l'architettura con Film
cd_train = AgeGroupAndAgeDatasetKL(df_train, path_col="path", label_col="age", label_function="CAE", 
                                 label_map=d, label_map_n_classes=3, transform_func=transform_func)

#Implementazione di un Dataset che adatta le label all'utilizzo che vogliamo farne (in questo caso CAE = Cathegorical)
cd_val = StandardDataset(df_val, path_col="path", label_col="age", label_function="CAE", transform_func=transform_func_val)

#Dato che lo split potrebbe non prendere sample di determinate classi facciamo il set del numero di classi
cd_train.set_n_classes(101)

#Loader che, conoscendo la grandezza del dataset, farà lo shuffle dei campioni e crea i batch
dm_train, dm_val = CustomDataLoader(cd_train), CustomDataLoader(cd_val)
#Utilizziamo il dataloader che crea batch bilanciati implementato in CustomDataLoaders
# In generale non usare questo ma l'unbalanced dato che non è usato in nessun doc
dl_train = dm_train.get_unbalanced_dataloader(batch_size=128 ,shuffle=True, num_workers=6, prefetch_factor=4)


In [4]:
import torchvision.models as models

# Initialize the ResNet18 model with pre-trained parameters on ImageNet
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Change the dimension of the fully connected layer to K
K = 101
model.fc = nn.Linear(model.fc.in_features, K)

In [5]:
#Validator che si occuperà di tenere in considerazione le metriche da massimizzare per il contest
validator = Validator(cd_val, AgeConversion.ArgMaxAge, 128)

In [6]:
import torchvision.transforms as transforms

# Define the transformation that resizes the image and flips it horizontally with a probability of 0.5
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5)
])

In [7]:
# Define the KL divergence loss
def kl_divergence_loss(pred, target):
    # Calculate the KL divergence between the label distribution and the predicted age distribution
    loss = nn.KLDivLoss(reduction='batchmean')(pred.log(), target)
    return loss


# Define the L1 loss
def l1_loss(pred, target):
    # Calculate the L1 loss between the predicted age and the ground truth label
    loss = nn.L1Loss()(pred, target)
    return loss

# Combine the KL divergence loss and L1 loss
def combined_loss(pred, target):
    return kl_divergence_loss(pred, target) + l1_loss(pred, target)


In [8]:
import torch.optim as optim

best_val_aar = -1
# Define the optimization algorithm
optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

# Define the learning rate scheduler
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40, 60], gamma=0.1)

model.to("cuda")
# Train the model for a total of 75 epochs
for epoch in range(75):
    # Train the model on the training data
    model.train()
    with tqdm(dl_train, unit=" batch") as tepoch:
        for input, target in tepoch:
            input = input.to("cuda")
            target = target[-1].to("cuda")
            # Clear the gradients
            optimizer.zero_grad()

            # Forward pass
            output = F.softmax(model(input), dim=-1)

            # Calculate the loss
            loss = combined_loss(output, target)

            # Backward pass
            loss.backward()

            # Update the model parameters
            optimizer.step()

            tepoch.set_postfix(loss=loss.detach().cpu().numpy()) 
        
    # Decrease the learning rate
    scheduler.step()

    def forward_function(x):
        return F.softmax(model(x), dim=-1)

    val_aar, val_aar_old = validator.validate(forward_function)
    print(val_aar, val_aar_old)

    if val_aar > best_val_aar:
        best_val_aar = val_aar
        torch.save(model.state_dict(), "./model_age_baseline.pt")
        print("Model saved")


100%|██████████| 3370/3370 [29:37<00:00,  1.90 batch/s, loss=1.550241681114393] 
100%|██████████| 1124/1124 [11:26<00:00,  1.64 batch/s]


tensor(0.7332) tensor(3.2873)
Model saved


100%|██████████| 3370/3370 [27:02<00:00,  2.08 batch/s, loss=1.5913941467935568]
100%|██████████| 1124/1124 [06:03<00:00,  3.09 batch/s]


tensor(1.8389) tensor(3.5043)
Model saved


100%|██████████| 3370/3370 [26:10<00:00,  2.15 batch/s, loss=1.3799243464590971]
100%|██████████| 1124/1124 [05:51<00:00,  3.20 batch/s]


tensor(2.8280) tensor(3.9404)
Model saved


100%|██████████| 3370/3370 [25:55<00:00,  2.17 batch/s, loss=1.2814616227887252]
100%|██████████| 1124/1124 [05:47<00:00,  3.23 batch/s]


tensor(3.3062) tensor(4.2741)
Model saved


100%|██████████| 3370/3370 [26:05<00:00,  2.15 batch/s, loss=1.4452313129793148]
100%|██████████| 1124/1124 [05:49<00:00,  3.22 batch/s]


tensor(3.3052) tensor(4.2903)


100%|██████████| 3370/3370 [26:19<00:00,  2.13 batch/s, loss=1.1852652929740393]
100%|██████████| 1124/1124 [05:51<00:00,  3.20 batch/s]


tensor(3.1249) tensor(4.1482)


100%|██████████| 3370/3370 [25:29<00:00,  2.20 batch/s, loss=1.380918563752509] 
100%|██████████| 1124/1124 [05:58<00:00,  3.14 batch/s]


tensor(3.8820) tensor(4.5617)
Model saved


100%|██████████| 3370/3370 [26:17<00:00,  2.14 batch/s, loss=1.2774677652363395]
100%|██████████| 1124/1124 [05:49<00:00,  3.22 batch/s]


tensor(3.9958) tensor(4.7356)
Model saved


100%|██████████| 3370/3370 [26:08<00:00,  2.15 batch/s, loss=1.320111941851504] 
100%|██████████| 1124/1124 [05:47<00:00,  3.24 batch/s]


tensor(3.9632) tensor(4.7728)


100%|██████████| 3370/3370 [24:26<00:00,  2.30 batch/s, loss=1.37745122997859]  
100%|██████████| 1124/1124 [05:43<00:00,  3.27 batch/s]


tensor(4.2664) tensor(4.9828)
Model saved


100%|██████████| 3370/3370 [24:52<00:00,  2.26 batch/s, loss=1.3061048887902151]
100%|██████████| 1124/1124 [06:00<00:00,  3.12 batch/s]


tensor(3.9983) tensor(4.7992)


100%|██████████| 3370/3370 [24:24<00:00,  2.30 batch/s, loss=1.0694381833196762]
100%|██████████| 1124/1124 [05:47<00:00,  3.23 batch/s]


tensor(3.9797) tensor(4.7391)


100%|██████████| 3370/3370 [25:35<00:00,  2.19 batch/s, loss=1.1617215321009176]
100%|██████████| 1124/1124 [05:42<00:00,  3.28 batch/s]


tensor(4.9322) tensor(5.4152)
Model saved


100%|██████████| 3370/3370 [24:24<00:00,  2.30 batch/s, loss=1.1989155125754247]
100%|██████████| 1124/1124 [05:40<00:00,  3.30 batch/s]


tensor(4.6992) tensor(5.2669)


100%|██████████| 3370/3370 [27:05<00:00,  2.07 batch/s, loss=1.0663645169928682]
100%|██████████| 1124/1124 [08:45<00:00,  2.14 batch/s]


tensor(4.0498) tensor(4.8838)


100%|██████████| 3370/3370 [29:38<00:00,  1.89 batch/s, loss=1.0724023157030758]
100%|██████████| 1124/1124 [05:48<00:00,  3.23 batch/s]


tensor(4.5113) tensor(5.2133)


100%|██████████| 3370/3370 [27:32<00:00,  2.04 batch/s, loss=1.1447826310447502]
100%|██████████| 1124/1124 [05:56<00:00,  3.15 batch/s]


tensor(4.8884) tensor(5.4632)


 13%|█▎        | 447/3370 [03:51<25:12,  1.93 batch/s, loss=1.1788868882977528] 


KeyboardInterrupt: 