In [1]:
from typing import List
from Dataset.CustomDataset import AgeGroupAndAgeDataset, StandardDataset, AgeDatasetKL
from Dataset.CustomDataLoaders import CustomDataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from Utils import AAR, CSVUtils, AgeConversion
from Utils.Validator import Validator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#Caricamento del dataframe
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")

In [3]:
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)
#########################
import pandas as pd
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
aug = CSVUtils.get_df_from_csv("./augumentation_two_age.csv", "./newAugmentationDataset/")
df_train = pd.concat([df_train, aug], ignore_index=True)
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)
#########################

from torchvision import transforms
import torch

transform_func = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(2, 9),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

transform_func_val = transforms.Compose([
    transforms.Resize(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# Implementazione di un Dataset utilizzando "CustomDataset" per l'architettura con Film
cd_train = AgeDatasetKL(df_train, path_col="path", label_col="age", label_function="Linear", 
                        transform_func=transform_func)
cd_train.set_n_classes(81)
cd_train.set_starting_class(1)
dm_train = CustomDataLoader(cd_train)
# dl_train = dm_train.get_unbalanced_dataloader(batch_size=64, shuffle=True, drop_last=True, num_workers=16, prefetch_factor=4, pin_memory=True)
dl_train, sampler = dm_train.get_balanced_class_dataloader2(class_ranges=[(0, 11), (11, 21), (21, 31), (31, 41), (41, 51), (51, 61), (61, 71), (71, 91)], 
                                                            batch_size=64, num_workers=12, prefetch_factor=4, pin_memory=True)
sampler.n_batches = 2000
cd_val = StandardDataset(df_val, path_col="path", label_col="age", label_function="CAE", transform_func=transform_func_val)
cd_val.set_n_classes(81)
cd_val.set_starting_class(1)
validator = Validator(cd_val, AgeConversion.EVAge, 64, num_workers=8, prefetch_factor=4)

In [4]:
from ResNetFilmed.resnet import ResNetFiLMed, BackBone, ResNetNotFiLMed, DoNothingLayer
from torchvision.models import resnet18, ResNet18_Weights, efficientnet_b0, EfficientNet_B0_Weights
import torch
from torch import optim
import torch.nn.functional as F
from torch import nn

####################################################
EPOCHS = 24
####################################################

backbone = resnet18(ResNet18_Weights.IMAGENET1K_V1)
backbone.fc = DoNothingLayer()
backbone.train()
backbone.requires_grad_(True)
backbone.to("cuda")
model_age = ResNetNotFiLMed(backbone, 81)
opt = optim.SGD(set([*backbone.parameters(), *model_age.fc0.parameters()]), lr=0.1, weight_decay=5e-4)
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=0.1, steps_per_epoch=len(dl_train), epochs=EPOCHS, three_phase=True)
kl = nn.KLDivLoss(reduction="batchmean")



In [5]:
best_val_aar = -1
import numpy as np

for e in range(EPOCHS):
    with tqdm(dl_train, unit=" batch") as tepoch:
        for batch in tepoch:
            opt.zero_grad()
            x, y = batch
            x = x.to("cuda")
            y_age: torch.Tensor = y[0].to("cuda")
            y_age_kl: torch.Tensor = y[1].to("cuda")

            out_age = model_age(x)
            loss_age_kl: torch.Tensor = kl(F.log_softmax(out_age, dim=-1), y_age_kl)

            out = F.softmax(out_age, dim=-1)
            out = AgeConversion.EVAge(out).to("cuda")
            loss_age = torch.mean(torch.abs(y_age - out))

            loss = loss_age_kl + loss_age
            loss.backward()
            opt.step()
            scheduler.step()

            tepoch.set_postfix(loss_age_kl=loss_age_kl.detach().cpu().numpy(), loss_age=loss_age.detach().cpu().numpy())

    def forward_function(x):
        out = model_age(x)
        out = F.softmax(out, dim=-1)
        return out

    ae, mae, val_aar, val_aar_old = validator.validate_ext2(forward_function)
    print(ae)
    print(mae, val_aar, val_aar_old)

    aes = np.array([float(x) for x in list(ae.values())])
    sampler.p = aes/aes.sum()

    if best_val_aar < val_aar:
        best_val_aar = val_aar
        torch.save(model_age.state_dict(), "./model_age_feature_simple.pt")
        print("Saved model")

100%|██████████| 2000/2000 [02:49<00:00, 11.82 batch/s, loss_age=4.308605, loss_age_kl=1.7389091248414785] 
100%|██████████| 2247/2247 [01:38<00:00, 22.77 batch/s]


{0: tensor(7.1873, dtype=torch.float64), 1: tensor(7.1145, dtype=torch.float64), 2: tensor(8.8504, dtype=torch.float64), 3: tensor(11.7164, dtype=torch.float64), 4: tensor(12.7811, dtype=torch.float64), 5: tensor(11.2289, dtype=torch.float64), 6: tensor(6.7233, dtype=torch.float64), 7: tensor(4.5071, dtype=torch.float64)}
tensor(10.5586, dtype=torch.float64) tensor(1.7557, dtype=torch.float64) tensor(0., dtype=torch.float64)
Saved model


100%|██████████| 2000/2000 [02:44<00:00, 12.13 batch/s, loss_age=4.1534133, loss_age_kl=1.5962547355641443]
100%|██████████| 2247/2247 [01:38<00:00, 22.77 batch/s]


{0: tensor(8.8012, dtype=torch.float64), 1: tensor(7.4721, dtype=torch.float64), 2: tensor(9.4692, dtype=torch.float64), 3: tensor(11.4635, dtype=torch.float64), 4: tensor(10.4078, dtype=torch.float64), 5: tensor(7.8305, dtype=torch.float64), 6: tensor(5.7266, dtype=torch.float64), 7: tensor(4.8704, dtype=torch.float64)}
tensor(9.6417, dtype=torch.float64) tensor(2.4820, dtype=torch.float64) tensor(0.4820, dtype=torch.float64)
Saved model


100%|██████████| 2000/2000 [02:45<00:00, 12.07 batch/s, loss_age=3.9747274, loss_age_kl=1.545139884005709] 
100%|██████████| 2247/2247 [01:38<00:00, 22.86 batch/s]


{0: tensor(6.7777, dtype=torch.float64), 1: tensor(6.5989, dtype=torch.float64), 2: tensor(7.6273, dtype=torch.float64), 3: tensor(9.8452, dtype=torch.float64), 4: tensor(11.0744, dtype=torch.float64), 5: tensor(8.3795, dtype=torch.float64), 6: tensor(6.1585, dtype=torch.float64), 7: tensor(7.1439, dtype=torch.float64)}
tensor(8.9190, dtype=torch.float64) tensor(3.1224, dtype=torch.float64) tensor(1.1224, dtype=torch.float64)
Saved model


100%|██████████| 2000/2000 [02:45<00:00, 12.09 batch/s, loss_age=3.654993, loss_age_kl=1.4575348802995964] 
100%|██████████| 2247/2247 [01:37<00:00, 23.05 batch/s]


{0: tensor(9.7915, dtype=torch.float64), 1: tensor(5.4399, dtype=torch.float64), 2: tensor(7.7750, dtype=torch.float64), 3: tensor(11.7380, dtype=torch.float64), 4: tensor(12.0390, dtype=torch.float64), 5: tensor(9.1348, dtype=torch.float64), 6: tensor(6.0670, dtype=torch.float64), 7: tensor(4.9002, dtype=torch.float64)}
tensor(9.6837, dtype=torch.float64) tensor(2.0911, dtype=torch.float64) tensor(0.0911, dtype=torch.float64)


100%|██████████| 2000/2000 [02:52<00:00, 11.62 batch/s, loss_age=4.5084405, loss_age_kl=1.6817809053710775]
100%|██████████| 2247/2247 [01:43<00:00, 21.64 batch/s]


{0: tensor(10.2971, dtype=torch.float64), 1: tensor(6.1207, dtype=torch.float64), 2: tensor(7.2868, dtype=torch.float64), 3: tensor(9.8977, dtype=torch.float64), 4: tensor(10.9652, dtype=torch.float64), 5: tensor(8.5409, dtype=torch.float64), 6: tensor(6.2098, dtype=torch.float64), 7: tensor(6.9668, dtype=torch.float64)}
tensor(8.8234, dtype=torch.float64) tensor(3.1313, dtype=torch.float64) tensor(1.1313, dtype=torch.float64)
Saved model


100%|██████████| 2000/2000 [03:07<00:00, 10.68 batch/s, loss_age=4.0491276, loss_age_kl=1.5881522633229248]
100%|██████████| 2247/2247 [01:43<00:00, 21.62 batch/s]


{0: tensor(7.8260, dtype=torch.float64), 1: tensor(9.4768, dtype=torch.float64), 2: tensor(10.4811, dtype=torch.float64), 3: tensor(11.3538, dtype=torch.float64), 4: tensor(11.1481, dtype=torch.float64), 5: tensor(8.4907, dtype=torch.float64), 6: tensor(6.1835, dtype=torch.float64), 7: tensor(5.4009, dtype=torch.float64)}
tensor(10.2935, dtype=torch.float64) tensor(2.4328, dtype=torch.float64) tensor(0.4328, dtype=torch.float64)


100%|██████████| 2000/2000 [03:07<00:00, 10.66 batch/s, loss_age=4.4794335, loss_age_kl=1.6720389058740015]
100%|██████████| 2247/2247 [02:21<00:00, 15.85 batch/s]


{0: tensor(8.0666, dtype=torch.float64), 1: tensor(4.3222, dtype=torch.float64), 2: tensor(6.9724, dtype=torch.float64), 3: tensor(11.3980, dtype=torch.float64), 4: tensor(12.7451, dtype=torch.float64), 5: tensor(9.8168, dtype=torch.float64), 6: tensor(5.9259, dtype=torch.float64), 7: tensor(6.1929, dtype=torch.float64)}
tensor(9.5490, dtype=torch.float64) tensor(1.9534, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 2000/2000 [03:09<00:00, 10.57 batch/s, loss_age=4.6147776, loss_age_kl=1.7872101466538273]
100%|██████████| 2247/2247 [01:43<00:00, 21.66 batch/s]


{0: tensor(8.3782, dtype=torch.float64), 1: tensor(6.3561, dtype=torch.float64), 2: tensor(8.2172, dtype=torch.float64), 3: tensor(9.5409, dtype=torch.float64), 4: tensor(9.3982, dtype=torch.float64), 5: tensor(7.4673, dtype=torch.float64), 6: tensor(5.4495, dtype=torch.float64), 7: tensor(5.8110, dtype=torch.float64)}
tensor(8.4558, dtype=torch.float64) tensor(3.2844, dtype=torch.float64) tensor(1.2844, dtype=torch.float64)
Saved model


100%|██████████| 2000/2000 [03:11<00:00, 10.47 batch/s, loss_age=3.9071777, loss_age_kl=1.5288628275209897]
100%|██████████| 2247/2247 [01:49<00:00, 20.55 batch/s]


{0: tensor(7.4412, dtype=torch.float64), 1: tensor(5.2565, dtype=torch.float64), 2: tensor(7.5189, dtype=torch.float64), 3: tensor(11.1834, dtype=torch.float64), 4: tensor(11.2210, dtype=torch.float64), 5: tensor(7.8619, dtype=torch.float64), 6: tensor(6.6503, dtype=torch.float64), 7: tensor(5.3879, dtype=torch.float64)}
tensor(9.1152, dtype=torch.float64) tensor(2.4877, dtype=torch.float64) tensor(0.4877, dtype=torch.float64)


100%|██████████| 2000/2000 [03:08<00:00, 10.59 batch/s, loss_age=4.6111803, loss_age_kl=1.6607237019777998]
100%|██████████| 2247/2247 [01:41<00:00, 22.22 batch/s]


{0: tensor(9.6147, dtype=torch.float64), 1: tensor(4.4478, dtype=torch.float64), 2: tensor(7.4400, dtype=torch.float64), 3: tensor(10.7854, dtype=torch.float64), 4: tensor(9.6300, dtype=torch.float64), 5: tensor(6.8042, dtype=torch.float64), 6: tensor(5.6855, dtype=torch.float64), 7: tensor(4.6604, dtype=torch.float64)}
tensor(8.4032, dtype=torch.float64) tensor(2.5208, dtype=torch.float64) tensor(0.5208, dtype=torch.float64)


100%|██████████| 2000/2000 [02:43<00:00, 12.23 batch/s, loss_age=4.6265364, loss_age_kl=1.6365593635753939]
100%|██████████| 2247/2247 [01:38<00:00, 22.80 batch/s]


{0: tensor(7.3669, dtype=torch.float64), 1: tensor(5.1283, dtype=torch.float64), 2: tensor(6.6393, dtype=torch.float64), 3: tensor(9.3234, dtype=torch.float64), 4: tensor(10.4882, dtype=torch.float64), 5: tensor(7.8116, dtype=torch.float64), 6: tensor(5.3410, dtype=torch.float64), 7: tensor(6.7171, dtype=torch.float64)}
tensor(8.1792, dtype=torch.float64) tensor(3.0860, dtype=torch.float64) tensor(1.0860, dtype=torch.float64)


100%|██████████| 2000/2000 [02:49<00:00, 11.80 batch/s, loss_age=3.868451, loss_age_kl=1.5796086966805905] 
100%|██████████| 2247/2247 [01:42<00:00, 21.97 batch/s]


{0: tensor(7.9268, dtype=torch.float64), 1: tensor(4.5871, dtype=torch.float64), 2: tensor(6.2864, dtype=torch.float64), 3: tensor(9.5833, dtype=torch.float64), 4: tensor(10.1818, dtype=torch.float64), 5: tensor(7.3620, dtype=torch.float64), 6: tensor(4.4210, dtype=torch.float64), 7: tensor(6.2950, dtype=torch.float64)}
tensor(7.9475, dtype=torch.float64) tensor(2.8428, dtype=torch.float64) tensor(0.8428, dtype=torch.float64)


100%|██████████| 2000/2000 [02:39<00:00, 12.51 batch/s, loss_age=3.7206872, loss_age_kl=1.5264693294082359]
100%|██████████| 2247/2247 [01:39<00:00, 22.57 batch/s]


{0: tensor(6.8902, dtype=torch.float64), 1: tensor(5.6494, dtype=torch.float64), 2: tensor(7.1952, dtype=torch.float64), 3: tensor(7.7851, dtype=torch.float64), 4: tensor(7.5401, dtype=torch.float64), 5: tensor(5.0599, dtype=torch.float64), 6: tensor(3.8384, dtype=torch.float64), 7: tensor(7.5951, dtype=torch.float64)}
tensor(6.8673, dtype=torch.float64) tensor(3.5921, dtype=torch.float64) tensor(1.7248, dtype=torch.float64)
Saved model


100%|██████████| 2000/2000 [02:46<00:00, 12.04 batch/s, loss_age=2.8834293, loss_age_kl=1.2504706789020599]
100%|██████████| 2247/2247 [01:39<00:00, 22.52 batch/s]


{0: tensor(5.9887, dtype=torch.float64), 1: tensor(4.9653, dtype=torch.float64), 2: tensor(5.8852, dtype=torch.float64), 3: tensor(6.6566, dtype=torch.float64), 4: tensor(6.3449, dtype=torch.float64), 5: tensor(4.2883, dtype=torch.float64), 6: tensor(3.8050, dtype=torch.float64), 7: tensor(6.0909, dtype=torch.float64)}
tensor(5.8043, dtype=torch.float64) tensor(3.9912, dtype=torch.float64) tensor(3.1869, dtype=torch.float64)
Saved model


100%|██████████| 2000/2000 [02:44<00:00, 12.12 batch/s, loss_age=3.254168, loss_age_kl=1.308487664166813]  
100%|██████████| 2247/2247 [01:39<00:00, 22.54 batch/s]


{0: tensor(6.2008, dtype=torch.float64), 1: tensor(4.1306, dtype=torch.float64), 2: tensor(5.2015, dtype=torch.float64), 3: tensor(6.8863, dtype=torch.float64), 4: tensor(7.4300, dtype=torch.float64), 5: tensor(5.2430, dtype=torch.float64), 6: tensor(4.2950, dtype=torch.float64), 7: tensor(4.3728, dtype=torch.float64)}
tensor(6.0141, dtype=torch.float64) tensor(3.7129, dtype=torch.float64) tensor(2.6988, dtype=torch.float64)


100%|██████████| 2000/2000 [02:44<00:00, 12.14 batch/s, loss_age=2.8351786, loss_age_kl=1.224105919664252] 
100%|██████████| 2247/2247 [01:39<00:00, 22.56 batch/s]


{0: tensor(7.2949, dtype=torch.float64), 1: tensor(3.7520, dtype=torch.float64), 2: tensor(5.2236, dtype=torch.float64), 3: tensor(6.2767, dtype=torch.float64), 4: tensor(6.1417, dtype=torch.float64), 5: tensor(4.5805, dtype=torch.float64), 6: tensor(3.2703, dtype=torch.float64), 7: tensor(5.1620, dtype=torch.float64)}
tensor(5.4251, dtype=torch.float64) tensor(3.7252, dtype=torch.float64) tensor(3.3001, dtype=torch.float64)


 15%|█▍        | 293/2000 [00:42<02:01, 14.10 batch/s, loss_age=2.9201164, loss_age_kl=1.2730490926565992] 

In [None]:
####################################################
EPOCHS = 12
####################################################

dl_train = dm_train.get_balanced_class_dataloader(class_ranges=[(0, 11), (11, 21), (21, 31), (31, 41), (41, 51), (51, 61), (61, 71), (71, 91)], 
                                                  batch_size=64, num_workers=12, prefetch_factor=4, pin_memory=True)

model_age.load_state_dict(torch.load("./model_age_feature_simple.pt"))
opt = optim.SGD(set([*model_age.fc0.parameters()]), lr=0.1, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CyclicLR(opt, base_lr=0.1/25, max_lr=0.1, step_size_up=int(len(dl_train)))

In [None]:
def forward_function(x):
    out = model_age(x)
    out = F.softmax(out, dim=-1)
    return out

ae, mae, val_aar, val_aar_old = validator.validate_ext2(forward_function)
print(ae, mae, val_aar, val_aar_old)

100%|██████████| 4493/4493 [02:21<00:00, 31.72 batch/s]


tensor(2.2409, dtype=torch.float64) tensor(4.9902, dtype=torch.float64) tensor(5.8763, dtype=torch.float64)


In [None]:
best_val_aar = val_aar

for e in range(EPOCHS):
    with tqdm(dl_train, unit=" batch") as tepoch:
        for batch in tepoch:
            opt.zero_grad()
            x, y = batch
            x = x.to("cuda")
            y_age: torch.Tensor = y[0].to("cuda")
            y_age_kl: torch.Tensor = y[1].to("cuda")

            out_age = model_age(x)
            loss_age_kl: torch.Tensor = kl(F.log_softmax(out_age, dim=-1), y_age_kl)

            out = F.softmax(out_age, dim=-1)
            out = AgeConversion.EVAge(out).to("cuda")
            loss_age = torch.mean(torch.abs(y_age - out))

            loss = loss_age_kl + torch.square(loss_age - mae)
            loss.backward()
            opt.step()
            scheduler.step()

            tepoch.set_postfix(loss_age_kl=loss_age_kl.detach().cpu().numpy(), loss_age=loss_age.detach().cpu().numpy())

    def forward_function(x):
        out = model_age(x)
        out = F.softmax(out, dim=-1)
        return out

    ae, mae, val_aar, val_aar_old = validator.validate_ext2(forward_function)
    print(ae)
    print(mae, val_aar, val_aar_old)

    if val_aar > best_val_aar:
        best_val_aar = val_aar
        torch.save(model_age.state_dict(), "./model_age_classification_simple.pt")
        print("Saved model")

100%|██████████| 6739/6739 [10:16<00:00, 10.93 batch/s, loss_age=4.2348323, loss_age_kl=15.457227201259457]
100%|██████████| 4493/4493 [01:58<00:00, 38.06 batch/s]


tensor(4.3330, dtype=torch.float64) tensor(3.5900, dtype=torch.float64)


100%|██████████| 6739/6739 [09:35<00:00, 11.70 batch/s, loss_age=3.1603274, loss_age_kl=14.250266355119173]
100%|██████████| 4493/4493 [02:03<00:00, 36.30 batch/s]


tensor(3.4237, dtype=torch.float64) tensor(2.7052, dtype=torch.float64)


100%|██████████| 6739/6739 [10:49<00:00, 10.37 batch/s, loss_age=3.3633087, loss_age_kl=18.20725307154277] 
100%|██████████| 4493/4493 [02:19<00:00, 32.31 batch/s]


tensor(4.1990, dtype=torch.float64) tensor(3.6775, dtype=torch.float64)


100%|██████████| 6739/6739 [11:14<00:00,  9.99 batch/s, loss_age=4.8973255, loss_age_kl=8.180094884813293] 
100%|██████████| 4493/4493 [01:56<00:00, 38.51 batch/s]


tensor(4.4910, dtype=torch.float64) tensor(3.9189, dtype=torch.float64)


100%|██████████| 6739/6739 [08:36<00:00, 13.04 batch/s, loss_age=3.3888707, loss_age_kl=1.6559491197087888]
100%|██████████| 4493/4493 [01:52<00:00, 39.99 batch/s]


tensor(5.3885, dtype=torch.float64) tensor(4.8540, dtype=torch.float64)
Saved model


100%|██████████| 6739/6739 [08:22<00:00, 13.41 batch/s, loss_age=2.4683042, loss_age_kl=1.159581840728757] 
100%|██████████| 4493/4493 [01:52<00:00, 40.07 batch/s]


tensor(5.0994, dtype=torch.float64) tensor(4.4671, dtype=torch.float64)


100%|██████████| 6739/6739 [08:29<00:00, 13.24 batch/s, loss_age=4.376457, loss_age_kl=1.2589574761360263] 
100%|██████████| 4493/4493 [01:53<00:00, 39.44 batch/s]


tensor(5.2628, dtype=torch.float64) tensor(4.7096, dtype=torch.float64)


100%|██████████| 6739/6739 [08:21<00:00, 13.44 batch/s, loss_age=2.7681565, loss_age_kl=1.0991218766440194]
100%|██████████| 4493/4493 [01:51<00:00, 40.46 batch/s]


tensor(5.3480, dtype=torch.float64) tensor(4.7429, dtype=torch.float64)


: 