In [1]:
from typing import List
from Dataset.CustomDataset import AgeGroupAndAgeDataset, StandardDataset
from Dataset.CustomDataLoaders import CustomDataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from Utils import AAR, CSVUtils, AgeConversion
from Utils.Validator import Validator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the dataframe from the contest CSV and its image directory.
df = CSVUtils.get_df_from_csv(
    "./training_caip_contest.csv",
    "./training_caip_contest/",
)

# Subdivide the dataframe into 3 age groups; only the label map (second
# returned value) is kept, the first returned value is discarded.
_, label_map = CSVUtils.get_df_with_age_subdivision(
    df,
    3,
)

In [3]:
# Hold out 25% of the rows for validation (seeded split), then renumber the
# index of each partition from zero.
df_train, df_val = (
    part.reset_index(drop=True)
    for part in train_test_split(df, test_size=0.25, random_state=42)
)

from PIL import Image
import numpy as np
# Per-channel statistics used for normalisation, shaped (3, 1, 1) so they
# broadcast over a channel-first image.
# From: https://github.com/pytorch/examples/blob/main/imagenet/main.py
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
_IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)


def transform_image(image: "Image.Image") -> np.ndarray:
    """Turn a PIL image into a normalised channel-first float32 array.

    The image is converted to RGB, resized to 224x224, scaled to [0, 1],
    transposed from (H, W, C) to (C, H, W) and normalised per channel.

    Args:
        image: PIL image in any mode. It is converted to RGB first so that
            grayscale, palette or RGBA inputs also yield exactly three
            channels (previously these either raised in ``transpose`` or
            produced an extra, unnormalised channel).

    Returns:
        ``np.float32`` array of shape ``(3, 224, 224)``.
    """
    resized = image.convert("RGB").resize((224, 224))
    # Same arithmetic as before: divide in float64, go channel-first, then
    # cast down to float32 before normalising.
    chw = (np.array(resized) / 255.0).transpose(2, 0, 1).astype(np.float32)
    return (chw - _IMAGENET_MEAN) / _IMAGENET_STD

# Dataset implementation using "CustomDataset" for the FiLM architecture.
cd_train = AgeGroupAndAgeDataset(
    df_train,
    path_col="path",
    label_col="age",
    label_function="CAE",
    label_map=label_map,
    label_map_n_classes=3,
    transform_func=transform_image,
)

# Validation dataset: same columns, label function and image transform,
# configured with 81 classes starting from class 1.
cd_val = StandardDataset(
    df_val,
    path_col="path",
    label_col="age",
    label_function="CAE",
    transform_func=transform_image,
)
cd_val.set_n_classes(81)
cd_val.set_starting_class(1)

# Shuffled training loader; the last incomplete batch is dropped.
dm_train = CustomDataLoader(cd_train)
dl_train = dm_train.get_unbalanced_dataloader(
    batch_size=256,
    shuffle=True,
    drop_last=True,
)

In [4]:
from ResNetFilmed.resnet import ResNetFiLMed, BackBone
import torch
from torch import optim
import torch.nn.functional as F
from torch import nn

# A single backbone instance is handed to both networks: one built with 3
# outputs (age group) and one with 81 outputs (age).
backbone = BackBone(pretrained=True)
model_age_group = ResNetFiLMed(backbone, 3)
model_age = ResNetFiLMed(backbone, 81)

# Parameters reachable from both networks must reach the optimizer only once
# (presumably the shared backbone weights -- that is what the original set()
# was removing). A set has no stable iteration order, and torch.optim requires
# parameter collections with a deterministic order that is consistent between
# runs, so de-duplicate by identity while keeping first-seen order instead.
unique_params = list(
    {
        id(p): p
        for p in [*model_age_group.parameters(), *model_age.parameters()]
    }.values()
)
opt = optim.Adam(unique_params, lr=1e-4)
cross_entropy = nn.CrossEntropyLoss()

In [5]:
# Validation helper built on the validation dataset, the AgeConversion.EVAge
# conversion function and the value 32 (presumably the validation batch
# size -- TODO confirm against Utils.Validator).
validator = Validator(
    cd_val,
    AgeConversion.EVAge,
    32,
)

In [7]:
epochs = 10
age_weight = 5  # the age loss is divided by this factor before back-propagation

best_val_aar = -1


def uniform_age_group_prior(batch_size, device):
    """Return the fixed age-group prior fed to the age-group network.

    Args:
        batch_size: number of rows, one per sample in the batch.
        device: device on which the tensor is allocated.

    Returns:
        Float tensor of shape ``(batch_size, 3)`` filled with 0.33, without
        gradient tracking.
    """
    return torch.full((batch_size, 3), 0.33, device=device)


def forward_function(x):
    """Inference pass handed to the validator.

    Runs the age-group network on the fixed prior, feeds its softmax output
    to the age network and returns the softmax over the age network's output.
    No gradients are recorded.
    """
    with torch.no_grad():
        knowledge = model_age_group(x, uniform_age_group_prior(len(x), "cuda"))
        knowledge = F.softmax(knowledge, dim=-1)
        out = model_age(x, knowledge)
        return F.softmax(out, dim=-1)


for epoch in range(epochs):
    # Validation below switches the networks to eval mode; switch back here.
    model_age_group.train()
    model_age.train()

    with tqdm(dl_train, unit=" batch") as tepoch:
        for batch in tepoch:
            opt.zero_grad()
            x, y = batch
            x = x.to("cuda")
            y_age_group = y[0].to("cuda")
            y_age = y[1].to("cuda")

            # Size the prior from the actual batch: it used to be hard-coded
            # to 32 rows, which does not match dl_train's batch_size=256.
            prior = uniform_age_group_prior(len(x), x.device)
            knowledge = model_age_group(x, prior)
            loss_age_group: torch.Tensor = cross_entropy(knowledge, y_age_group)

            # The age network receives the age-group probabilities.
            knowledge = F.softmax(knowledge, dim=-1)
            out = model_age(x, knowledge)
            out = F.softmax(out, dim=-1)
            out = AgeConversion.EVAge(out)
            y_age = AgeConversion.EVAge(y_age)
            loss_age: torch.Tensor = torch.mean(torch.abs(out - y_age)) / age_weight

            # One backward pass over the summed loss accumulates the same
            # gradients as two separate passes, without retain_graph=True.
            (loss_age_group + loss_age).backward()
            opt.step()

            # loss_age is reported un-scaled (multiplied back by age_weight).
            tepoch.set_postfix(
                loss_age_group=loss_age_group.item(),
                loss_age=loss_age.item() * age_weight,
            )

    # The validator only receives a callable, so eval mode has to be set here.
    model_age_group.eval()
    model_age.eval()
    val_aar, val_aar_old = validator.validate(forward_function)
    print(val_aar, val_aar_old)

    # NOTE(review): in the recorded run val_aar printed as 0 on every epoch
    # while val_aar_old kept rising, so this condition held only on the first
    # epoch and the checkpoint was written once -- confirm which metric
    # should drive model selection.
    if val_aar > best_val_aar:
        best_val_aar = val_aar
        torch.save(model_age_group.state_dict(), "./model_age_group.pt")
        torch.save(model_age.state_dict(), "./model_age.pt")
        print("Saved model")

100%|██████████| 13478/13478 [55:14<00:00,  4.07 batch/s, loss_age=6.74, loss_age_group=0.8131173274014145] 
100%|██████████| 4493/4493 [13:38<00:00,  5.49 batch/s]


tensor(0., dtype=torch.float64) tensor(0.6362, dtype=torch.float64)
Saved model


100%|██████████| 13478/13478 [27:29<00:00,  8.17 batch/s, loss_age=6.24, loss_age_group=0.5815560481860302] 
100%|██████████| 4493/4493 [07:36<00:00,  9.84 batch/s]


tensor(0., dtype=torch.float64) tensor(0.9328, dtype=torch.float64)


100%|██████████| 13478/13478 [28:13<00:00,  7.96 batch/s, loss_age=7.74, loss_age_group=0.8740766658447683] 
100%|██████████| 4493/4493 [07:38<00:00,  9.80 batch/s]


tensor(0., dtype=torch.float64) tensor(1.0922, dtype=torch.float64)


100%|██████████| 13478/13478 [28:14<00:00,  7.96 batch/s, loss_age=4.55, loss_age_group=0.5556758943712339] 
100%|██████████| 4493/4493 [07:33<00:00,  9.90 batch/s]


tensor(0., dtype=torch.float64) tensor(1.1966, dtype=torch.float64)


100%|██████████| 13478/13478 [28:14<00:00,  7.95 batch/s, loss_age=5.6, loss_age_group=0.715615825552959]   
100%|██████████| 4493/4493 [07:38<00:00,  9.79 batch/s]


tensor(0., dtype=torch.float64) tensor(1.3012, dtype=torch.float64)


100%|██████████| 13478/13478 [28:08<00:00,  7.98 batch/s, loss_age=5.65, loss_age_group=0.6398524850374088] 
100%|██████████| 4493/4493 [07:50<00:00,  9.56 batch/s]


tensor(0., dtype=torch.float64) tensor(1.3868, dtype=torch.float64)


 20%|██        | 2747/13478 [05:39<22:06,  8.09 batch/s, loss_age=5.15, loss_age_group=0.7133738825796172] 


KeyboardInterrupt: 

In [8]:
# model_age.load_state_dict(torch.load("./with_loss/model_age.pt"))
# model_age_group.load_state_dict(torch.load("./with_loss/model_age_group.pt"))