In [1]:
from typing import List
from Dataset.CustomDataset import AgeGroupAndAgeDataset, StandardDataset, AgeGroupAndAgeDatasetKL
from Dataset.CustomDataLoaders import CustomDataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from Utils import AAR, CSVUtils, AgeConversion
from Utils.Validator import Validator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the training dataframe (presumably the annotation CSV plus the image
# folder that its paths refer to).
df = CSVUtils.get_df_from_csv(
    "./training_caip_contest.csv",
    "./training_caip_contest/",
)

# Subdivide the ages into 3 age groups. Only the label map is kept; the
# returned (subdivided) dataframe is discarded.
_, label_map = CSVUtils.get_df_with_age_subdivision(df, 3)

In [3]:
# Hold out 25% of the rows for validation (fixed seed so the split is
# reproducible), then renumber both splits from 0.
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
df_train, df_val = df_train.reset_index(drop=True), df_val.reset_index(drop=True)

from torchvision import transforms
import torch

# Training-time preprocessing + augmentation.
#
# The random augmentations run on the float tensor in [0, 1] produced by
# ConvertImageDtype and *before* Normalize. torchvision's ColorJitter
# (brightness/contrast/saturation) blends float tensors and clamps the result
# to [0, 1]. RandomGrayscale mixes channels with RGB luma weights.
# Applied after Normalize (values roughly in [-2.1, 2.6]), ColorJitter would
# zero out every negative value and RandomGrayscale would mix standardized
# channels. Training inputs would then no longer match the validation
# pipeline below.
#
# NOTE(review): Resize(224) only fixes the shorter edge; this assumes the
# input images are square so batches collate to 224x224 — TODO confirm.
transform_func = transforms.Compose([
    transforms.Resize(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.RandomGrayscale(),
    # Normalize last, so all augmentations see un-normalized [0, 1] data.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# Validation preprocessing: same resize / tensor conversion / ImageNet
# normalization as training, with no random augmentation.
transform_func_val = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Training dataset built on "CustomDataset" for the FiLM architecture: uses the
# "age" column together with the 3-group label_map computed earlier.
cd_train = AgeGroupAndAgeDatasetKL(
    df_train,
    path_col="path",
    label_col="age",
    label_function="CAE",
    label_map=label_map,
    label_map_n_classes=3,
    transform_func=transform_func,
)

# Validation dataset over the "age" column with the non-augmenting transform.
# 81 classes whose first class is 1 (presumably ages 1..81 — confirm in
# StandardDataset).
cd_val = StandardDataset(
    df_val,
    path_col="path",
    label_col="age",
    label_function="CAE",
    transform_func=transform_func_val,
)
cd_val.set_n_classes(81)
cd_val.set_starting_class(1)

# Class-balanced training loader over 8 age buckets. Each (lo, hi) pair is
# presumably a half-open range [lo, hi) — TODO confirm in CustomDataLoader.
dm_train = CustomDataLoader(cd_train)
dl_train = dm_train.get_balanced_class_dataloader(
    class_ranges=[
        (0, 11), (11, 21), (21, 31), (31, 41),
        (41, 51), (51, 61), (61, 71), (71, 91),
    ],
    batch_size=128,
    num_workers=20,
    prefetch_factor=4,
)

In [4]:
from ResNetFilmed.resnet import ResNetFiLMed, BackBone, ResNetNotFiLMed, DoNothingLayer
from torchvision.models import resnet18, ResNet18_Weights
import torch
from torch import optim
import torch.nn.functional as F
from torch import nn

####################################################
EPOCHS = 8
####################################################

# ImageNet-pretrained ResNet-18 used as a frozen feature extractor. Its `fc`
# head is replaced by DoNothingLayer (presumably an identity) and gradients
# are disabled for every backbone parameter.
# `weights` is passed by keyword: passing it positionally is deprecated since
# torchvision 0.13 and triggers a deprecation warning.
backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
backbone.fc = DoNothingLayer()
# NOTE(review): the frozen backbone is put in train() mode. Unless something
# later switches it to eval(), its BatchNorm layers normalize with batch
# statistics and keep updating their running stats even though no weights
# are trained. Use backbone.eval() if a fully frozen extractor is intended.
backbone.train()
backbone.requires_grad_(False)
backbone.to("cuda")

# Age head with 81 outputs on top of the backbone.
# NOTE(review): only `backbone` is moved to CUDA here — confirm that
# ResNetNotFiLMed places its own layers (fc0) on the same device.
model_age = ResNetNotFiLMed(backbone, 81)

# Only the head (fc0) is optimized. OneCycleLR overwrites the SGD lr (5e-3)
# with its own schedule peaking at max_lr=1e-2. With the default
# cycle_momentum=True it also sets and cycles the SGD momentum.
opt = optim.SGD(model_age.fc0.parameters(), lr=5e-3, weight_decay=5e-4)
scheduler = optim.lr_scheduler.OneCycleLR(opt, 1e-2, steps_per_epoch=len(dl_train), epochs=EPOCHS)
kl = nn.KLDivLoss(reduction="batchmean")



In [5]:
# Validator over the held-out split; the positional 32 is presumably the
# batch size — TODO confirm against Validator's signature.
validator = Validator(
    cd_val,
    AgeConversion.EVAge,
    32,
    num_workers=8,
    prefetch_factor=4,
)

In [6]:
# Restore previously trained weights into model_age. The returned key report
# is the cell's displayed output.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model_age.load_state_dict(
    torch.load(f="./model_age_balanced_simple_5__5.87.pt")
)

<All keys matched successfully>

In [7]:
def forward_function(x):
    """Run ``model_age`` on ``x`` and turn its outputs into probabilities.

    A softmax is taken over the last dimension, so each output row sums to 1.
    """
    return F.softmax(model_age(x), dim=-1)


# Evaluate on the validation split. validate() returns two scores (presumably
# the current and a legacy AAR variant, judging by the names).
val_aar, val_aar_old = validator.validate(forward_function)
print(val_aar, val_aar_old)

100%|██████████| 3369/3369 [17:14<00:00,  3.26 batch/s, loss_age_kl=1.5370684995596045] 
100%|██████████| 4493/4493 [03:34<00:00, 20.93 batch/s]


tensor(4.4744, dtype=torch.float64) tensor(3.9651, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [17:05<00:00,  3.29 batch/s, loss_age_kl=1.6840404195254868]
100%|██████████| 4493/4493 [01:51<00:00, 40.19 batch/s]


tensor(4.1667, dtype=torch.float64) tensor(3.4907, dtype=torch.float64)


100%|██████████| 3369/3369 [15:36<00:00,  3.60 batch/s, loss_age_kl=1.6396318918468586]
100%|██████████| 4493/4493 [01:47<00:00, 41.93 batch/s]


tensor(4.1366, dtype=torch.float64) tensor(3.5109, dtype=torch.float64)


 27%|██▋       | 913/3369 [04:42<06:39,  6.15 batch/s, loss_age_kl=1.753517192095317]  