In [1]:
from typing import List
from Dataset.CustomDataset import AgeGroupAndAgeDataset, StandardDataset
from Dataset.CustomDataLoaders import CustomDataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from Utils import AAR, CSVUtils, AgeConversion
from Utils.Validator import Validator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#Caricamento del dataframe
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")

#Suddivisione del dataframe in 3 age groups
_, label_map = CSVUtils.get_df_with_age_subdivision(df, 3)

In [3]:
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)

from PIL import Image
import numpy as np
def transform_image(image: Image):
    _image = (np.array(image.resize((224, 224))) / 255.0).transpose(2, 0, 1).astype(np.float32)
    # From: https://github.com/pytorch/examples/blob/main/imagenet/main.py
    _image[0] = (_image[0] - 0.485)/0.229
    _image[1] = (_image[1] - 0.456)/0.224
    _image[2] = (_image[2] - 0.406)/0.225
    return _image

#Implementazione di un Dataset utilizzando "CustomDataset" per l'architettura con Film
cd_train = AgeGroupAndAgeDataset(df_train, path_col="path", label_col="age", label_function="CAE", 
                                 label_map=label_map, label_map_n_classes=3, transform_func=transform_image)

cd_val = StandardDataset(df_val, path_col="path", label_col="age", label_function="CAE", transform_func=transform_image)
cd_val.set_n_classes(81)
cd_val.set_starting_class(1)

dm_train = CustomDataLoader(cd_train)
dl_train = dm_train.get_unbalanced_dataloader(batch_size=128, shuffle=True, drop_last=True)

In [4]:
from ResNetFilmed.resnet import ResNetFiLMed, BackBone
import torch
from torch import optim
import torch.nn.functional as F
from torch import nn

####################################################
EPOCHS = 10
####################################################

backbone = BackBone(pretrained=True)
model_age_group = ResNetFiLMed(backbone, 3)
model_age = ResNetFiLMed(backbone, 81)
# opt = optim.Adam(set([*model_age_group.parameters(), *model_age.parameters()]), lr=1e-4)
opt = optim.SGD(set([*model_age_group.parameters(), *model_age.parameters()]), lr=1e-1, weight_decay=1e-4)
scheduler = optim.lr_scheduler.OneCycleLR(opt, 1e-2, steps_per_epoch=len(dl_train), epochs=EPOCHS)
cross_entropy = nn.CrossEntropyLoss()

In [5]:
validator = Validator(cd_val, AgeConversion.EVAge, 32)

In [6]:
age_weight = 5

best_val_aar = -1
knowledge_age_group = torch.tensor([[0.33, 0.33, 0.33]]*128, requires_grad=False).float().to("cuda")
for e in range(EPOCHS):
    with tqdm(dl_train, unit=" batch") as tepoch:
        for batch in tepoch:
            opt.zero_grad()
            x, y = batch
            x = x.to("cuda")
            y_age_group = y[0].to("cuda")
            y_age = y[1].to("cuda")
            
            knowledge = model_age_group(x, knowledge_age_group)
            loss_age_group: torch.Tensor = cross_entropy(knowledge, y_age_group)

            knowledge = F.softmax(knowledge, dim=-1)
            out = model_age(x, knowledge)
            out = F.softmax(out, dim=-1)
            out = AgeConversion.EVAge(out)
            y_age = AgeConversion.EVAge(y_age)
            loss_age: torch.Tensor = torch.mean(torch.abs(out - y_age)) / age_weight
            
            loss_age_group.backward(retain_graph=True)
            loss_age.backward()
            opt.step()
            scheduler.step()

            tepoch.set_postfix(loss_age_group=loss_age_group.detach().cpu().numpy(), loss_age=loss_age.detach().cpu().numpy()*age_weight)

    def forward_function(x):
        knowledge_age_group = torch.tensor([[0.33, 0.33, 0.33]]*len(x), requires_grad=False).float().to("cuda")
        knowledge = model_age_group(x, knowledge_age_group)
        knowledge = F.softmax(knowledge, dim=-1)
        out = model_age(x, knowledge)
        out = F.softmax(out, dim=-1)
        return out

    val_aar, val_aar_old = validator.validate(forward_function)
    print(val_aar, val_aar_old)

    if val_aar > best_val_aar:
        best_val_aar = val_aar
        torch.save(model_age_group.state_dict(), "./model_age_group.pt")
        torch.save(model_age.state_dict(), "./model_age.pt")
        print("Saved model")

100%|██████████| 3369/3369 [22:38<00:00,  2.48 batch/s, loss_age=6.25, loss_age_group=0.7837388041079976]
100%|██████████| 4493/4493 [07:11<00:00, 10.40 batch/s]


tensor(0., dtype=torch.float64) tensor(0.1802, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [22:42<00:00,  2.47 batch/s, loss_age=5.81, loss_age_group=0.6936754291818943]
100%|██████████| 4493/4493 [07:38<00:00,  9.79 batch/s]


tensor(0., dtype=torch.float64) tensor(0.9474, dtype=torch.float64)


100%|██████████| 3369/3369 [22:17<00:00,  2.52 batch/s, loss_age=5.49, loss_age_group=0.798656390297765]  
100%|██████████| 4493/4493 [07:14<00:00, 10.33 batch/s]


tensor(0., dtype=torch.float64) tensor(1.1337, dtype=torch.float64)


100%|██████████| 3369/3369 [22:14<00:00,  2.52 batch/s, loss_age=5.36, loss_age_group=0.7224537784142626] 
100%|██████████| 4493/4493 [07:06<00:00, 10.53 batch/s]


tensor(0., dtype=torch.float64) tensor(1.3614, dtype=torch.float64)


100%|██████████| 3369/3369 [21:51<00:00,  2.57 batch/s, loss_age=5.55, loss_age_group=0.6098237217540827] 
100%|██████████| 4493/4493 [06:55<00:00, 10.82 batch/s]


tensor(0., dtype=torch.float64) tensor(1.3778, dtype=torch.float64)


100%|██████████| 3369/3369 [21:59<00:00,  2.55 batch/s, loss_age=4.87, loss_age_group=0.5902453234220957] 
100%|██████████| 4493/4493 [06:59<00:00, 10.71 batch/s]


tensor(0., dtype=torch.float64) tensor(1.5133, dtype=torch.float64)


100%|██████████| 3369/3369 [21:57<00:00,  2.56 batch/s, loss_age=5.14, loss_age_group=0.6454454584145424] 
100%|██████████| 4493/4493 [07:25<00:00, 10.09 batch/s]


tensor(0., dtype=torch.float64) tensor(1.6733, dtype=torch.float64)


100%|██████████| 3369/3369 [21:54<00:00,  2.56 batch/s, loss_age=4.7, loss_age_group=0.551346996511711]   
100%|██████████| 4493/4493 [07:02<00:00, 10.64 batch/s]


tensor(0., dtype=torch.float64) tensor(1.7119, dtype=torch.float64)


100%|██████████| 3369/3369 [23:16<00:00,  2.41 batch/s, loss_age=4.59, loss_age_group=0.6067066646346575] 
100%|██████████| 4493/4493 [07:06<00:00, 10.53 batch/s]


tensor(0., dtype=torch.float64) tensor(1.8123, dtype=torch.float64)


100%|██████████| 3369/3369 [24:21<00:00,  2.30 batch/s, loss_age=5, loss_age_group=0.5735915286222735]    
100%|██████████| 4493/4493 [08:59<00:00,  8.33 batch/s]


tensor(0., dtype=torch.float64) tensor(1.8427, dtype=torch.float64)


In [9]:
# model_age.load_state_dict(torch.load("./with_loss/model_age.pt"))
# model_age_group.load_state_dict(torch.load("./with_loss/model_age_group.pt"))
torch.save(model_age_group.state_dict(), "./model_age_group_unbalanced.pt")
torch.save(model_age.state_dict(), "./model_age_unbalanced.pt")