In [1]:
from typing import List
from Dataset.CustomDataset import AgeGroupAndAgeDataset, StandardDataset, AgeDatasetKL
from Dataset.CustomDataLoaders import CustomDataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from Utils import AAR, CSVUtils, AgeConversion
from Utils.Validator import Validator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#########################
import pandas as pd
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
aug = CSVUtils.get_df_from_csv("./augumentation_balanced_remove.csv", "./newAugmentationDataset/")
df_train_aug = pd.concat([df_train, aug], ignore_index=True)
df_train_aug = df_train_aug.reset_index(drop=True)
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)
#########################

from torchvision import transforms
import torch
import numpy as np

transform_func = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.RandAugment(2, 9),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

transform_func_val = transforms.Compose([
    transforms.Resize(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# Implementazione di un Dataset utilizzando "CustomDataset" per l'architettura con Film
cd_train = AgeDatasetKL(df_train, path_col="path", label_col="age", label_function="Linear", 
                        transform_func=transform_func)
# Implementazione di un Dataset utilizzando "CustomDataset" per l'architettura con Film
cd_train_balanced = AgeDatasetKL(df_train_aug, path_col="path", label_col="age", label_function="Linear", 
                                transform_func=transform_func)
cd_train.set_n_classes(81)
cd_train.set_starting_class(1)
cd_train_balanced.set_n_classes(81)
cd_train_balanced.set_starting_class(1)
dm_train = CustomDataLoader(cd_train)
dm_train_balanced = CustomDataLoader(cd_train_balanced)
dl_train = dm_train.get_unbalanced_dataloader(batch_size=64, shuffle=True, drop_last=True, num_workers=8, prefetch_factor=2)
dl_train_balanced, sampler = dm_train_balanced.get_balanced_class_dataloader2(class_ranges=[(0, 11), (11, 21), (21, 31), (31, 41), (41, 51), (51, 61), (61, 71), (71, 91)], 
                                                                            batch_size=64, num_workers=8, prefetch_factor=2)
sampler.n_batches = len(dl_train)
sampler.p = np.array([0.23, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11])

cd_val = StandardDataset(df_val, path_col="path", label_col="age", label_function="CAE", transform_func=transform_func_val)
cd_val.set_n_classes(81)
cd_val.set_starting_class(1)
validator = Validator(cd_val, AgeConversion.EVAge, 32, num_workers=4, prefetch_factor=2)

In [3]:
from ResNetFilmed.resnet import ResNetFiLMed, BackBone, ResNetNotFiLMed, DoNothingLayer
from torchvision.models import resnet18, ResNet18_Weights, efficientnet_b0, EfficientNet_B0_Weights, resnet34, ResNet34_Weights
import torch
from torch import optim
import torch.nn.functional as F
from torch import nn

####################################################
EPOCHS = 24
####################################################

backbone = resnet34(ResNet34_Weights.IMAGENET1K_V1)
backbone.fc = DoNothingLayer()
backbone.train()
backbone.requires_grad_(True)
backbone.to("cuda")
model_age = ResNetNotFiLMed(backbone, 81)
# model_age.load_state_dict(torch.load("./model_age_classification_simple.pt", map_location="cuda:0"))
opt = optim.SGD(set([*backbone.parameters(), *model_age.fc0.parameters()]), lr=1e-2, weight_decay=5e-5)
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-2, steps_per_epoch=len(dl_train), epochs=EPOCHS, three_phase=True)
kl = nn.KLDivLoss(reduction="batchmean")



In [4]:
import numpy as np
def get_centers(outs, ys, old):
    centers_for_age = {x: [] for x in range(81)}

    for out, y in zip(outs, ys):
        for _out, _y in zip(out, y):
            centers_for_age[int(_y)].append(_out.detach().cpu().numpy())

    for i in centers_for_age:
        if len(centers_for_age[i]) > 0:
            centers_for_age[i] = torch.tensor(np.array(centers_for_age[i])).mean(dim=0).to("cuda")
        else:
            centers_for_age[i] = old[i]
    return centers_for_age

def update_centers(old, new, alpha=0.5):
    for i in new:
        new[i] = new[i] - alpha*(new[i] - old[i])
    return new

def get_centers_loss(out, y, centers_for_age):
    loss = None
    for _out, _y in zip(out, y):
        if loss is None:
            loss = torch.mean(torch.square(_out - centers_for_age[int(_y)]))
        else:    
            loss += torch.mean(torch.square(_out - centers_for_age[int(_y)]))
    return loss.to("cuda")

In [6]:
centers_for_age = {x: torch.zeros(size=(512,), device="cuda") for x in range(81)}
best_val_aar = -1

for e in range(EPOCHS):
    with tqdm(zip(dl_train, dl_train_balanced), unit=" batch") as tepoch:
        for batch, batch_balanced in tepoch:
            opt.zero_grad()
            x, y = batch
            x_bal, y_bal = batch_balanced

            x = x.to("cuda")
            y_age: torch.Tensor = y[0].to("cuda")
            y_age_kl: torch.Tensor = y[1].to("cuda")

            x_bal = x_bal.to("cuda")
            y_age_bal: torch.Tensor = y_bal[0].to("cuda")
            y_age_kl_bal: torch.Tensor = y_bal[1].to("cuda")

            out_rep, out_age = model_age.forward_with_repr(x)
            loss_age_kl: torch.Tensor = kl(F.log_softmax(out_age, dim=-1), y_age_kl)
            out_age = F.softmax(out_age, dim=-1)
            out = AgeConversion.EVAge(out_age).to("cuda")
            loss_age = torch.mean(torch.abs(y_age - out))
            loss = loss_age_kl + loss_age

            out_rep_bal, out_age_bal = model_age.forward_with_repr(x_bal)
            loss_age_kl_bal: torch.Tensor = kl(F.log_softmax(out_age_bal, dim=-1), y_age_kl_bal)
            out_age_bal = F.softmax(out_age_bal, dim=-1)
            out_bal = AgeConversion.EVAge(out_age_bal).to("cuda")
            loss_age_bal = torch.mean(torch.abs(y_age_bal - out_bal))
            loss_bal = loss_age_kl_bal + torch.square(loss_age_bal - 2.0)

            loss_repr = get_centers_loss(out_rep, y_age, centers_for_age)
            loss_repr_bal = get_centers_loss(out_rep_bal, y_age_bal, centers_for_age)

            total_loss = loss + loss_bal + loss_repr + loss_repr_bal

            total_loss.backward()
            opt.step()
            scheduler.step()

            centers_for_age = update_centers(centers_for_age, get_centers((out_rep, out_rep_bal), (y_age, y_age_bal), centers_for_age), alpha=0.5)

            tepoch.set_postfix(loss_age_kl=loss_age_kl.detach().cpu().numpy(), loss_age=loss_age.detach().cpu().numpy(),
                                loss_age_kl_bal=loss_age_kl_bal.detach().cpu().numpy(), loss_age_bal=loss_age_bal.detach().cpu().numpy(),
                                loss_repr=loss_repr.detach().cpu().numpy(), loss_repr_bal=loss_repr_bal.detach().cpu().numpy(),
                                total_loss=total_loss.detach().cpu().numpy())

    def forward_function(x):
        out = model_age(x)
        out = F.softmax(out, dim=-1)
        return out

    ae, mae_, val_aar, val_aar_old = validator.validate_ext2(forward_function)
    print(ae, mae_, val_aar, val_aar_old)

    if best_val_aar < val_aar:
        best_val_aar = val_aar
        torch.save(model_age.state_dict(), "./model_age_feature_simple_no_loss_34.pt")
        print("Saved model")

6739 batch [35:34,  3.16 batch/s, loss_age=5.8269243, loss_age_bal=4.8324165, loss_age_kl=2.033347208678843, loss_age_kl_bal=1.831698433383939, loss_repr=0.25322247, loss_repr_bal=0.25659937, total_loss=18.22437480609095]   
100%|██████████| 4493/4493 [03:59<00:00, 18.72 batch/s]


{0: tensor(23.0485, dtype=torch.float64), 1: tensor(9.1328, dtype=torch.float64), 2: tensor(4.6676, dtype=torch.float64), 3: tensor(4.3379, dtype=torch.float64), 4: tensor(5.3151, dtype=torch.float64), 5: tensor(6.2021, dtype=torch.float64), 6: tensor(9.0642, dtype=torch.float64), 7: tensor(15.2392, dtype=torch.float64)} tensor(5.5072, dtype=torch.float64) tensor(0., dtype=torch.float64) tensor(1.4928, dtype=torch.float64)
Saved model


6739 batch [36:57,  3.04 batch/s, loss_age=4.3261447, loss_age_bal=4.041786, loss_age_kl=1.838188225176517, loss_age_kl_bal=1.6042386686973893, loss_repr=0.17612745, loss_repr_bal=0.15544678, total_loss=12.26903677360604]    
100%|██████████| 4493/4493 [02:29<00:00, 30.13 batch/s]


{0: tensor(17.3041, dtype=torch.float64), 1: tensor(5.5825, dtype=torch.float64), 2: tensor(3.5899, dtype=torch.float64), 3: tensor(3.8535, dtype=torch.float64), 4: tensor(3.9024, dtype=torch.float64), 5: tensor(4.0656, dtype=torch.float64), 6: tensor(5.0341, dtype=torch.float64), 7: tensor(9.3567, dtype=torch.float64)} tensor(4.0428, dtype=torch.float64) tensor(0., dtype=torch.float64) tensor(2.9572, dtype=torch.float64)


6739 batch [36:51,  3.05 batch/s, loss_age=3.5731835, loss_age_bal=3.1613517, loss_age_kl=1.5571467164914237, loss_age_kl_bal=1.3161975115620546, loss_repr=0.10796403, loss_repr_bal=0.111631945, total_loss=8.014861457431703]  
100%|██████████| 4493/4493 [02:31<00:00, 29.59 batch/s]


{0: tensor(14.0691, dtype=torch.float64), 1: tensor(4.1666, dtype=torch.float64), 2: tensor(2.7120, dtype=torch.float64), 3: tensor(3.7175, dtype=torch.float64), 4: tensor(3.8486, dtype=torch.float64), 5: tensor(3.4373, dtype=torch.float64), 6: tensor(4.5835, dtype=torch.float64), 7: tensor(9.3111, dtype=torch.float64)} tensor(3.5689, dtype=torch.float64) tensor(0.7358, dtype=torch.float64) tensor(3.4311, dtype=torch.float64)
Saved model


6739 batch [36:52,  3.05 batch/s, loss_age=2.9846108, loss_age_bal=3.197947, loss_age_kl=1.3854908616657855, loss_age_kl_bal=1.3384172249854966, loss_repr=0.08481025, loss_repr_bal=0.10005864, total_loss=7.328464840882735]    
100%|██████████| 4493/4493 [02:31<00:00, 29.62 batch/s]


{0: tensor(14.0161, dtype=torch.float64), 1: tensor(4.4862, dtype=torch.float64), 2: tensor(2.6072, dtype=torch.float64), 3: tensor(3.1161, dtype=torch.float64), 4: tensor(3.2846, dtype=torch.float64), 5: tensor(2.9587, dtype=torch.float64), 6: tensor(3.6273, dtype=torch.float64), 7: tensor(6.8156, dtype=torch.float64)} tensor(3.1454, dtype=torch.float64) tensor(0.9071, dtype=torch.float64) tensor(3.8546, dtype=torch.float64)
Saved model


6739 batch [36:52,  3.05 batch/s, loss_age=3.1365905, loss_age_bal=2.9918766, loss_age_kl=1.3634712690773216, loss_age_kl_bal=1.2546286200519376, loss_repr=0.07802414, loss_repr_bal=0.086901985, total_loss=6.903435683220485]  
100%|██████████| 4493/4493 [02:16<00:00, 32.81 batch/s]


{0: tensor(10.6661, dtype=torch.float64), 1: tensor(3.6024, dtype=torch.float64), 2: tensor(2.4098, dtype=torch.float64), 3: tensor(2.9217, dtype=torch.float64), 4: tensor(2.9318, dtype=torch.float64), 5: tensor(2.8071, dtype=torch.float64), 6: tensor(3.5617, dtype=torch.float64), 7: tensor(6.0707, dtype=torch.float64)} tensor(2.8799, dtype=torch.float64) tensor(2.6281, dtype=torch.float64) tensor(4.1201, dtype=torch.float64)
Saved model


6739 batch [36:45,  3.06 batch/s, loss_age=3.7563164, loss_age_bal=2.9239368, loss_age_kl=1.475838604645926, loss_age_kl_bal=1.240027864002203, loss_repr=0.08884904, loss_repr_bal=0.08120018, total_loss=7.495891384581737]     
100%|██████████| 4493/4493 [02:38<00:00, 28.42 batch/s]


{0: tensor(12.0295, dtype=torch.float64), 1: tensor(3.8359, dtype=torch.float64), 2: tensor(2.1426, dtype=torch.float64), 3: tensor(2.6703, dtype=torch.float64), 4: tensor(2.8596, dtype=torch.float64), 5: tensor(2.9017, dtype=torch.float64), 6: tensor(3.5476, dtype=torch.float64), 7: tensor(4.2599, dtype=torch.float64)} tensor(2.7529, dtype=torch.float64) tensor(2.3549, dtype=torch.float64) tensor(4.2471, dtype=torch.float64)


6739 batch [36:48,  3.05 batch/s, loss_age=2.673304, loss_age_bal=2.4442065, loss_age_kl=1.1978703492807607, loss_age_kl_bal=1.0994876487322385, loss_repr=0.06326068, loss_repr_bal=0.06382529, total_loss=5.2950674362177645]   
100%|██████████| 4493/4493 [02:32<00:00, 29.51 batch/s]


{0: tensor(9.0461, dtype=torch.float64), 1: tensor(2.6943, dtype=torch.float64), 2: tensor(2.2096, dtype=torch.float64), 3: tensor(2.7340, dtype=torch.float64), 4: tensor(2.7755, dtype=torch.float64), 5: tensor(2.5736, dtype=torch.float64), 6: tensor(3.8250, dtype=torch.float64), 7: tensor(6.5458, dtype=torch.float64)} tensor(2.6724, dtype=torch.float64) tensor(3.2757, dtype=torch.float64) tensor(4.6538, dtype=torch.float64)
Saved model


6739 batch [36:52,  3.05 batch/s, loss_age=3.0325847, loss_age_bal=2.5506783, loss_age_kl=1.2690034582283998, loss_age_kl_bal=1.0762888584928372, loss_repr=0.06761605, loss_repr_bal=0.056166984, total_loss=5.804906548903739]  
100%|██████████| 4493/4493 [02:17<00:00, 32.79 batch/s]


{0: tensor(8.6662, dtype=torch.float64), 1: tensor(3.4887, dtype=torch.float64), 2: tensor(2.3194, dtype=torch.float64), 3: tensor(2.6196, dtype=torch.float64), 4: tensor(2.6440, dtype=torch.float64), 5: tensor(2.2166, dtype=torch.float64), 6: tensor(2.8828, dtype=torch.float64), 7: tensor(4.4775, dtype=torch.float64)} tensor(2.5708, dtype=torch.float64) tensor(4.0464, dtype=torch.float64) tensor(5.1399, dtype=torch.float64)
Saved model


6739 batch [36:59,  3.04 batch/s, loss_age=2.1805434, loss_age_bal=2.0071466, loss_age_kl=1.097612367517082, loss_age_kl_bal=0.9356257033385401, loss_repr=0.04603666, loss_repr_bal=0.040672753, total_loss=4.3005419810784575]  
100%|██████████| 4493/4493 [03:57<00:00, 18.88 batch/s]


{0: tensor(9.4458, dtype=torch.float64), 1: tensor(3.5114, dtype=torch.float64), 2: tensor(2.0544, dtype=torch.float64), 3: tensor(2.3639, dtype=torch.float64), 4: tensor(2.5166, dtype=torch.float64), 5: tensor(2.2459, dtype=torch.float64), 6: tensor(2.7152, dtype=torch.float64), 7: tensor(3.9050, dtype=torch.float64)} tensor(2.4059, dtype=torch.float64) tensor(3.8244, dtype=torch.float64) tensor(5.0132, dtype=torch.float64)


6739 batch [37:09,  3.02 batch/s, loss_age=2.4854977, loss_age_bal=2.724047, loss_age_kl=1.143778819621547, loss_age_kl_bal=1.171642062218905, loss_repr=0.055577155, loss_repr_bal=0.059440594, total_loss=5.4401802951423965]   
100%|██████████| 4493/4493 [03:42<00:00, 20.19 batch/s]


{0: tensor(6.3328, dtype=torch.float64), 1: tensor(2.9450, dtype=torch.float64), 2: tensor(1.9813, dtype=torch.float64), 3: tensor(2.4448, dtype=torch.float64), 4: tensor(2.4223, dtype=torch.float64), 5: tensor(2.1700, dtype=torch.float64), 6: tensor(2.6539, dtype=torch.float64), 7: tensor(3.8458, dtype=torch.float64)} tensor(2.3329, dtype=torch.float64) tensor(5.3620, dtype=torch.float64) tensor(6.1286, dtype=torch.float64)
Saved model


6739 batch [36:48,  3.05 batch/s, loss_age=2.695856, loss_age_bal=2.2298265, loss_age_kl=1.1484576750597926, loss_age_kl_bal=1.072251952620886, loss_repr=0.052481122, loss_repr_bal=0.05968844, total_loss=5.08155548404922]     
100%|██████████| 4493/4493 [02:16<00:00, 32.88 batch/s]


{0: tensor(7.1073, dtype=torch.float64), 1: tensor(2.7924, dtype=torch.float64), 2: tensor(1.9453, dtype=torch.float64), 3: tensor(2.4083, dtype=torch.float64), 4: tensor(2.4708, dtype=torch.float64), 5: tensor(2.0464, dtype=torch.float64), 6: tensor(2.5788, dtype=torch.float64), 7: tensor(3.9649, dtype=torch.float64)} tensor(2.2979, dtype=torch.float64) tensor(5.0170, dtype=torch.float64) tensor(5.8834, dtype=torch.float64)


6739 batch [36:44,  3.06 batch/s, loss_age=1.912674, loss_age_bal=2.2444153, loss_age_kl=0.9704191152301509, loss_age_kl_bal=1.023974583756345, loss_repr=0.03992828, loss_repr_bal=0.051875245, total_loss=4.058610003556932]    
100%|██████████| 4493/4493 [03:32<00:00, 21.12 batch/s]


{0: tensor(7.0811, dtype=torch.float64), 1: tensor(2.6998, dtype=torch.float64), 2: tensor(1.8617, dtype=torch.float64), 3: tensor(2.2937, dtype=torch.float64), 4: tensor(2.3502, dtype=torch.float64), 5: tensor(2.0155, dtype=torch.float64), 6: tensor(2.5058, dtype=torch.float64), 7: tensor(4.2860, dtype=torch.float64)} tensor(2.2080, dtype=torch.float64) tensor(4.9731, dtype=torch.float64) tensor(5.9019, dtype=torch.float64)


6739 batch [35:22,  3.18 batch/s, loss_age=1.9347908, loss_age_bal=2.05008, loss_age_kl=0.9275684536163277, loss_age_kl_bal=0.9507330970373014, loss_repr=0.041958243, loss_repr_bal=0.045384362, total_loss=3.902943017911963]   
100%|██████████| 4493/4493 [03:48<00:00, 19.70 batch/s]


{0: tensor(6.4570, dtype=torch.float64), 1: tensor(2.4691, dtype=torch.float64), 2: tensor(1.7700, dtype=torch.float64), 3: tensor(2.1501, dtype=torch.float64), 4: tensor(2.3619, dtype=torch.float64), 5: tensor(2.1675, dtype=torch.float64), 6: tensor(2.4871, dtype=torch.float64), 7: tensor(3.8111, dtype=torch.float64)} tensor(2.1566, dtype=torch.float64) tensor(5.3965, dtype=torch.float64) tensor(6.1991, dtype=torch.float64)
Saved model


814 batch [04:16,  3.11 batch/s, loss_age=2.1201398, loss_age_bal=1.8519177, loss_age_kl=1.0633291326048717, loss_age_kl_bal=0.8991121059369301, loss_repr=0.053668194, loss_repr_bal=0.043339718, total_loss=4.2015173425569134]