In [1]:
from typing import List
from Dataset.CustomDataset import AgeGroupAndAgeDataset, StandardDataset
from Dataset.CustomDataLoaders import CustomDataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from Utils import AAR, CSVUtils, AgeConversion
from Utils.Validator import Validator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#Caricamento del dataframe
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")

#Suddivisione del dataframe in 3 age groups
_, label_map = CSVUtils.get_df_with_age_subdivision(df, 3)

In [3]:
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)

from PIL import Image
import numpy as np
import random
def transform_image(image: Image):
    _image = (np.array(image.resize((224, 224))) / 255.0).transpose(2, 0, 1).astype(np.float32)
    # From: https://github.com/pytorch/examples/blob/main/imagenet/main.py
    _image[0] = (_image[0] - 0.485)/0.229
    _image[1] = (_image[1] - 0.456)/0.224
    _image[2] = (_image[2] - 0.406)/0.225
    if random.random() > 0.5:
        _image = _image[:, :, ::-1].copy()
    return _image

#Implementazione di un Dataset utilizzando "CustomDataset" per l'architettura con Film
cd_train = AgeGroupAndAgeDataset(df_train, path_col="path", label_col="age", label_function="CAE", 
                                 label_map=label_map, label_map_n_classes=3, transform_func=transform_image)

cd_val = StandardDataset(df_val, path_col="path", label_col="age", label_function="CAE", transform_func=transform_image)
cd_val.set_n_classes(81)
cd_val.set_starting_class(1)

dm_train = CustomDataLoader(cd_train)
dl_train = dm_train.get_balanced_dataloader(class_ranges=[(0, 11), (11, 21), (21, 31), (31, 41), (41, 51), (51, 61), (61, 71), (71, 91)], samples_per_class=16)

In [4]:
from ResNetFilmed.resnet import ResNetFiLMed, BackBone
import torch
from torch import optim
import torch.nn.functional as F
from torch import nn

####################################################
EPOCHS = 24
####################################################

backbone = BackBone(pretrained=True)
model_age_group = ResNetFiLMed(backbone, 3)
model_age = ResNetFiLMed(backbone, 81)
# opt = optim.Adam(set([*model_age_group.parameters(), *model_age.parameters()]), lr=1e-4)
opt = optim.SGD(set([*model_age_group.parameters(), *model_age.parameters()]), lr=1e-1, weight_decay=1e-4)
scheduler = optim.lr_scheduler.OneCycleLR(opt, 1e-2, steps_per_epoch=len(dl_train), epochs=EPOCHS)
cross_entropy = nn.CrossEntropyLoss()

In [5]:
validator = Validator(cd_val, AgeConversion.EVAge, 32)

In [6]:
age_weight = 5

best_val_aar = -1
knowledge_age_group = torch.tensor([[0.33, 0.33, 0.33]]*128, requires_grad=False).float().to("cuda")
for e in range(EPOCHS):
    with tqdm(dl_train, unit=" batch") as tepoch:
        for batch in tepoch:
            opt.zero_grad()
            x, y = batch
            x = x.to("cuda")
            y_age_group = y[0].to("cuda")
            y_age = y[1].to("cuda")
            
            knowledge = model_age_group(x, knowledge_age_group)
            loss_age_group: torch.Tensor = cross_entropy(knowledge, y_age_group)

            knowledge = F.softmax(knowledge, dim=-1)
            out = model_age(x, knowledge)
            out = F.softmax(out, dim=-1)
            out = AgeConversion.EVAge(out)
            y_age = AgeConversion.EVAge(y_age)
            loss_age: torch.Tensor = torch.mean(torch.abs(out - y_age)) / age_weight
            
            loss_age_group.backward(retain_graph=True)
            loss_age.backward()
            opt.step()
            scheduler.step()

            tepoch.set_postfix(loss_age_group=loss_age_group.detach().cpu().numpy(), loss_age=loss_age.detach().cpu().numpy()*age_weight)

    def forward_function(x):
        knowledge_age_group = torch.tensor([[0.33, 0.33, 0.33]]*len(x), requires_grad=False).float().to("cuda")
        knowledge = model_age_group(x, knowledge_age_group)
        knowledge = F.softmax(knowledge, dim=-1)
        out = model_age(x, knowledge)
        out = F.softmax(out, dim=-1)
        return out

    val_aar, val_aar_old = validator.validate(forward_function)
    print(val_aar, val_aar_old)

    if val_aar > best_val_aar or (best_val_aar==0 and val_aar==0):
        best_val_aar = val_aar
        torch.save(model_age_group.state_dict(), "./model_age_group_balanced.pt")
        torch.save(model_age.state_dict(), "./model_age_balanced.pt")
        print("Saved model")

100%|██████████| 3369/3369 [26:05<00:00,  2.15 batch/s, loss_age=8.39, loss_age_group=0.49941195620340295]
100%|██████████| 4493/4493 [07:30<00:00,  9.98 batch/s]


tensor(0.6276, dtype=torch.float64) tensor(0., dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [22:35<00:00,  2.49 batch/s, loss_age=7.86, loss_age_group=0.5390158374539169] 
100%|██████████| 4493/4493 [07:22<00:00, 10.14 batch/s]


tensor(1.3097, dtype=torch.float64) tensor(0., dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [21:47<00:00,  2.58 batch/s, loss_age=7.13, loss_age_group=0.5067024628715444] 
100%|██████████| 4493/4493 [08:02<00:00,  9.31 batch/s]


tensor(0.8565, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 3369/3369 [24:51<00:00,  2.26 batch/s, loss_age=6.63, loss_age_group=0.4983575397527602] 
100%|██████████| 4493/4493 [08:09<00:00,  9.17 batch/s]


tensor(1.6427, dtype=torch.float64) tensor(0., dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [24:42<00:00,  2.27 batch/s, loss_age=5.92, loss_age_group=0.4418032834580572] 
100%|██████████| 4493/4493 [08:12<00:00,  9.13 batch/s]


tensor(2.1429, dtype=torch.float64) tensor(0.1429, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [24:43<00:00,  2.27 batch/s, loss_age=5.36, loss_age_group=0.4101467243665411] 
100%|██████████| 4493/4493 [08:09<00:00,  9.19 batch/s]


tensor(2.3035, dtype=torch.float64) tensor(0.3035, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [24:42<00:00,  2.27 batch/s, loss_age=5.43, loss_age_group=0.4863959180229358] 
100%|██████████| 4493/4493 [08:13<00:00,  9.10 batch/s]


tensor(1.9552, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 3369/3369 [24:49<00:00,  2.26 batch/s, loss_age=5.5, loss_age_group=0.41166584171651266] 
100%|██████████| 4493/4493 [08:10<00:00,  9.16 batch/s]


tensor(2.0999, dtype=torch.float64) tensor(0.0999, dtype=torch.float64)


100%|██████████| 3369/3369 [24:48<00:00,  2.26 batch/s, loss_age=5.7, loss_age_group=0.42825020237819444] 
100%|██████████| 4493/4493 [08:13<00:00,  9.11 batch/s]


tensor(1.6114, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 3369/3369 [24:50<00:00,  2.26 batch/s, loss_age=5.64, loss_age_group=0.3937569219156103] 
100%|██████████| 4493/4493 [08:06<00:00,  9.24 batch/s]


tensor(1.7287, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 3369/3369 [25:01<00:00,  2.24 batch/s, loss_age=5.35, loss_age_group=0.4026523222778451] 
100%|██████████| 4493/4493 [08:09<00:00,  9.18 batch/s]


tensor(2.3461, dtype=torch.float64) tensor(0.3461, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [32:34<00:00,  1.72 batch/s, loss_age=5.34, loss_age_group=0.3927220993048195] 
100%|██████████| 4493/4493 [18:29<00:00,  4.05 batch/s]


tensor(2.0657, dtype=torch.float64) tensor(0.0657, dtype=torch.float64)


100%|██████████| 3369/3369 [27:40<00:00,  2.03 batch/s, loss_age=5.28, loss_age_group=0.3271855615149093] 
100%|██████████| 4493/4493 [08:08<00:00,  9.19 batch/s]


tensor(1.9008, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 3369/3369 [22:01<00:00,  2.55 batch/s, loss_age=5.98, loss_age_group=0.4610352670524662] 
100%|██████████| 4493/4493 [07:25<00:00, 10.07 batch/s]


tensor(2.3216, dtype=torch.float64) tensor(0.3216, dtype=torch.float64)


100%|██████████| 3369/3369 [21:52<00:00,  2.57 batch/s, loss_age=5.39, loss_age_group=0.3697303320419678] 
100%|██████████| 4493/4493 [07:28<00:00, 10.01 batch/s]


tensor(1.9044, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 3369/3369 [21:45<00:00,  2.58 batch/s, loss_age=5.9, loss_age_group=0.3772147858126118]  
100%|██████████| 4493/4493 [07:26<00:00, 10.07 batch/s]


tensor(1.9815, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 3369/3369 [21:51<00:00,  2.57 batch/s, loss_age=4.76, loss_age_group=0.3937355206901145] 
100%|██████████| 4493/4493 [07:24<00:00, 10.10 batch/s]


tensor(1.9216, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 3369/3369 [22:12<00:00,  2.53 batch/s, loss_age=6.37, loss_age_group=0.4974377700047512] 
100%|██████████| 4493/4493 [07:27<00:00, 10.04 batch/s]


tensor(1.9282, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 3369/3369 [22:11<00:00,  2.53 batch/s, loss_age=5.32, loss_age_group=0.3937688707783309] 
100%|██████████| 4493/4493 [07:42<00:00,  9.71 batch/s]


tensor(1.7925, dtype=torch.float64) tensor(0., dtype=torch.float64)


 12%|█▏        | 392/3369 [02:31<19:23,  2.56 batch/s, loss_age=5.05, loss_age_group=0.37848595215473324]

In [None]:
# model_age.load_state_dict(torch.load("./with_loss/model_age.pt"))
# model_age_group.load_state_dict(torch.load("./with_loss/model_age_group.pt"))