In [1]:
from typing import List
from Dataset.CustomDataset import AgeGroupAndAgeDataset, StandardDataset, AgeGroupAndAgeDatasetKL, AgeGroupKLAndAgeDatasetKL
from Dataset.CustomDataLoaders import CustomDataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from Utils import AAR, CSVUtils, AgeConversion
from Utils.Validator import Validator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#Caricamento del dataframe
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")

In [3]:
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)
#########################
import pandas as pd
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
aug = CSVUtils.get_df_from_csv("./augumentation_simple.csv", "./newAugmentationDataset/")
df_train = pd.concat([df_train, aug], ignore_index=True)
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)
#########################

from torchvision import transforms
import torch

transform_func = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.RandAugment(2, 9),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

transform_func_val = transforms.Compose([
    transforms.Resize(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

label_map_v = CSVUtils.get_label_map_vector()
cd_train = AgeGroupKLAndAgeDatasetKL(df_train, path_col="path", label_col="age", label_function="Linear", 
                                    label_map_vector=label_map_v, transform_func=transform_func)

cd_val = StandardDataset(df_val, path_col="path", label_col="age", label_function="CAE", transform_func=transform_func_val)
cd_val.set_n_classes(81)
cd_val.set_starting_class(1)
validator = Validator(cd_val, AgeConversion.EVAge, 32, num_workers=4, prefetch_factor=4)

dm_train = CustomDataLoader(cd_train)
dl_train = dm_train.get_unbalanced_dataloader(batch_size=128, shuffle=True, drop_last=True, num_workers=12, prefetch_factor=4, pin_memory=True)

In [4]:
from ResNetFilmed.resnet import ResNetFiLMed, BackBone
import torch
from torch import optim
import torch.nn.functional as F
from torch import nn

####################################################
EPOCHS = 12
####################################################

backbone = BackBone(pretrained=True, freeze=False)
backbone.train()
model_age_group = ResNetFiLMed(backbone, 8)
model_age_group.train()
model_age = ResNetFiLMed(backbone, 81)
model_age.train()
opt = optim.SGD(set([*model_age_group.parameters(), *model_age.parameters()]), lr=0.1, weight_decay=5e-4)
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=0.1, steps_per_epoch=len(dl_train), epochs=EPOCHS, three_phase=True)
kl = nn.KLDivLoss(reduction="batchmean")

In [5]:
best_val_aar = -1
knowledge_age_group = torch.tensor([[0.125]*8]*128, requires_grad=False).float().to("cuda")

for e in range(EPOCHS):
    with tqdm(dl_train, unit=" batch") as tepoch:
        for batch in tepoch:
            opt.zero_grad()
            x, y = batch
            x = x.to("cuda")
            y_age_group = y[0].to("cuda")
            y_age = y[1].to("cuda")
            y_age_kl: torch.Tensor = y[2].to("cuda")
            
            knowledge = model_age_group(x, knowledge_age_group)
            loss_age_group: torch.Tensor = kl(F.log_softmax(knowledge, dim=-1), y_age_group)

            knowledge = F.softmax(knowledge, dim=-1)
            out_age = model_age(x, knowledge)
            loss_age_kl: torch.Tensor = kl(F.log_softmax(out_age, dim=-1), y_age_kl)

            out = F.softmax(out_age, dim=-1)
            out = AgeConversion.EVAge(out).to("cuda")
            loss_age: torch.Tensor = torch.mean(torch.abs(out - y_age))

            loss = loss_age_group + loss_age_kl + loss_age
            loss.backward()
            opt.step()
            scheduler.step()

            tepoch.set_postfix(loss_age_group=loss_age_group.detach().cpu().numpy(), 
                                loss_age=loss_age.detach().cpu().numpy(), 
                                loss_age_kl=loss_age_kl.detach().cpu().numpy())

    def forward_function(x):
        knowledge_age_group = torch.tensor([[0.125]*8]*len(x), requires_grad=False).float().to("cuda")
        knowledge = model_age_group(x, knowledge_age_group)
        knowledge = F.softmax(knowledge, dim=-1)
        out = model_age(x, knowledge)
        out = F.softmax(out, dim=-1)
        return out

    mae, val_aar, val_aar_old = validator.validate_ext(forward_function)
    print(mae, val_aar, val_aar_old)

    if val_aar > best_val_aar:
        best_val_aar = val_aar
        torch.save(model_age_group.state_dict(), "./model_age_group_film_feature.pt")
        torch.save(model_age.state_dict(), "./model_age_film_feature.pt")
        print("Saved model")

100%|██████████| 3369/3369 [13:52<00:00,  4.05 batch/s, loss_age=4.418851, loss_age_group=0.8484421756816958, loss_age_kl=1.7428398030648684] 
100%|██████████| 4493/4493 [03:03<00:00, 24.49 batch/s]


tensor(4.3063, dtype=torch.float64) tensor(0., dtype=torch.float64) tensor(2.6937, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [15:13<00:00,  3.69 batch/s, loss_age=3.666069, loss_age_group=0.5901269066677528, loss_age_kl=1.527265556035743]  
100%|██████████| 4493/4493 [03:20<00:00, 22.42 batch/s]


tensor(3.9628, dtype=torch.float64) tensor(0., dtype=torch.float64) tensor(3.0372, dtype=torch.float64)


100%|██████████| 3369/3369 [15:19<00:00,  3.66 batch/s, loss_age=3.9765315, loss_age_group=0.6724678332292433, loss_age_kl=1.6035756311292857]
100%|██████████| 4493/4493 [03:09<00:00, 23.75 batch/s]


tensor(3.8126, dtype=torch.float64) tensor(0., dtype=torch.float64) tensor(3.1874, dtype=torch.float64)


100%|██████████| 3369/3369 [15:49<00:00,  3.55 batch/s, loss_age=3.2281678, loss_age_group=0.5578498772305798, loss_age_kl=1.4260700256196137]
100%|██████████| 4493/4493 [02:56<00:00, 25.45 batch/s]


tensor(3.7532, dtype=torch.float64) tensor(0.1427, dtype=torch.float64) tensor(3.2468, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [14:35<00:00,  3.85 batch/s, loss_age=3.5419397, loss_age_group=0.632945561686667, loss_age_kl=1.4863395216921458]  
100%|██████████| 4493/4493 [02:44<00:00, 27.36 batch/s]


tensor(3.5460, dtype=torch.float64) tensor(0.0161, dtype=torch.float64) tensor(3.4540, dtype=torch.float64)


100%|██████████| 3369/3369 [14:25<00:00,  3.89 batch/s, loss_age=3.2469814, loss_age_group=0.6381147592897489, loss_age_kl=1.4232626642567898] 
100%|██████████| 4493/4493 [02:44<00:00, 27.35 batch/s]


tensor(3.2905, dtype=torch.float64) tensor(0.1466, dtype=torch.float64) tensor(3.7095, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [14:28<00:00,  3.88 batch/s, loss_age=2.4589474, loss_age_group=0.4542229877122823, loss_age_kl=1.164990009580303]  
100%|██████████| 4493/4493 [02:48<00:00, 26.65 batch/s]


tensor(2.6871, dtype=torch.float64) tensor(2.1787, dtype=torch.float64) tensor(4.3129, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [14:43<00:00,  3.81 batch/s, loss_age=2.3552203, loss_age_group=0.4090158020003517, loss_age_kl=1.1136337974321417] 
100%|██████████| 4493/4493 [04:18<00:00, 17.35 batch/s]


tensor(2.5932, dtype=torch.float64) tensor(2.8997, dtype=torch.float64) tensor(4.4333, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [15:09<00:00,  3.70 batch/s, loss_age=2.283958, loss_age_group=0.37994872803532176, loss_age_kl=1.106674769961066]  
100%|██████████| 4493/4493 [04:19<00:00, 17.32 batch/s]


tensor(2.5063, dtype=torch.float64) tensor(2.3463, dtype=torch.float64) tensor(4.4937, dtype=torch.float64)


100%|██████████| 3369/3369 [14:48<00:00,  3.79 batch/s, loss_age=2.4236293, loss_age_group=0.43776229716048787, loss_age_kl=1.0823988946404945]
100%|██████████| 4493/4493 [02:59<00:00, 25.06 batch/s]


tensor(2.4407, dtype=torch.float64) tensor(2.8013, dtype=torch.float64) tensor(4.5593, dtype=torch.float64)


100%|██████████| 3369/3369 [14:06<00:00,  3.98 batch/s, loss_age=1.9568629, loss_age_group=0.3312729317419897, loss_age_kl=0.9398443738314677] 
100%|██████████| 4493/4493 [02:44<00:00, 27.32 batch/s]


tensor(2.2242, dtype=torch.float64) tensor(3.7748, dtype=torch.float64) tensor(5.0831, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [14:17<00:00,  3.93 batch/s, loss_age=2.3123775, loss_age_group=0.3824366608615769, loss_age_kl=1.1233796710168498] 
100%|██████████| 4493/4493 [02:45<00:00, 27.17 batch/s]


tensor(2.1877, dtype=torch.float64) tensor(4.1028, dtype=torch.float64) tensor(5.3027, dtype=torch.float64)
Saved model


In [5]:
####################################################
EPOCHS = 12
####################################################

dl_train = dm_train.get_balanced_class_dataloader(class_ranges=[(0, 11), (11, 21), (21, 31), (31, 41), (41, 51), (51, 61), (61, 71), (71, 91)], 
                                                  batch_size=128, num_workers=8, prefetch_factor=4, pin_memory=True)

model_age_group.load_state_dict(torch.load("./model_age_group_film_feature_mae2.pt"))
model_age_group.requires_grad_(False)
model_age_group.fc0.requires_grad_(True)
model_age.load_state_dict(torch.load("./model_age_film_feature_mae2.pt"))
opt = optim.SGD(set([*model_age_group.fc0.parameters(), *model_age.fc0.parameters()]), lr=0.1, weight_decay=5e-4)
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=0.1, steps_per_epoch=len(dl_train), epochs=EPOCHS, three_phase=True)

In [6]:
def forward_function(x):
    knowledge_age_group = torch.tensor([[0.125]*8]*len(x), requires_grad=False).float().to("cuda")
    knowledge = model_age_group(x, knowledge_age_group)
    knowledge = F.softmax(knowledge, dim=-1)
    out = model_age(x, knowledge)
    out = F.softmax(out, dim=-1)
    return out

mae, val_aar, val_aar_old = validator.validate_ext(forward_function)
print(mae, val_aar, val_aar_old)

100%|██████████| 4493/4493 [04:14<00:00, 17.67 batch/s]


tensor(2.1877, dtype=torch.float64) tensor(4.1033, dtype=torch.float64) tensor(5.2903, dtype=torch.float64)


In [7]:
best_val_aar = val_aar
knowledge_age_group = torch.tensor([[0.125]*8]*128, requires_grad=False).float().to("cuda")

for e in range(EPOCHS):
    with tqdm(dl_train, unit=" batch") as tepoch:
        for batch in tepoch:
            opt.zero_grad()
            x, y = batch
            x = x.to("cuda")
            y_age_group = y[0].to("cuda")
            y_age = y[1].to("cuda")
            y_age_kl: torch.Tensor = y[2].to("cuda")
            
            knowledge = model_age_group.forward_detach(x, knowledge_age_group)
            #loss_age_group: torch.Tensor = kl(F.log_softmax(knowledge, dim=-1), y_age_group)

            knowledge = F.softmax(knowledge, dim=-1)
            out_age = model_age(x, knowledge)
            loss_age_kl: torch.Tensor = kl(F.log_softmax(out_age, dim=-1), y_age_kl)

            out = F.softmax(out_age, dim=-1)
            out = AgeConversion.EVAge(out).to("cuda")
            loss_age: torch.Tensor = torch.mean(torch.abs(out - y_age))

            loss = loss_age_kl + (torch.square(loss_age - mae) if loss_age > mae else torch.tensor(0)) # + loss_age_group + loss_age_kl
            loss.backward()
            opt.step()
            scheduler.step()

            tepoch.set_postfix(#loss_age_group=loss_age_group.detach().cpu().numpy(), 
                                loss_age=loss_age.detach().cpu().numpy(),
                                loss_age_kl=loss_age_kl.detach().cpu().numpy())

    def forward_function(x):
        knowledge_age_group = torch.tensor([[0.125]*8]*len(x), requires_grad=False).float().to("cuda")
        knowledge = model_age_group(x, knowledge_age_group)
        knowledge = F.softmax(knowledge, dim=-1)
        out = model_age(x, knowledge)
        out = F.softmax(out, dim=-1)
        return out

    mae_, val_aar, val_aar_old = validator.validate_ext(forward_function)
    print(mae_, val_aar, val_aar_old)

    if val_aar > best_val_aar:
        best_val_aar = val_aar
        torch.save(model_age_group.state_dict(), "./model_age_group_film_classifier_unlocked.pt")
        torch.save(model_age.state_dict(), "./model_age_film_classifier_unlocked.pt")
        print("Saved model")

100%|██████████| 4165/4165 [16:34<00:00,  4.19 batch/s, loss_age=2.7784278, loss_age_kl=1.1639567308489056]
100%|██████████| 4493/4493 [03:40<00:00, 20.38 batch/s]


tensor(4.8587, dtype=torch.float64) tensor(4.5131, dtype=torch.float64) tensor(3.8831, dtype=torch.float64)
Saved model


100%|██████████| 4165/4165 [16:14<00:00,  4.27 batch/s, loss_age=3.0142322, loss_age_kl=1.1951743632948955]
100%|██████████| 4493/4493 [03:37<00:00, 20.64 batch/s]


tensor(4.4580, dtype=torch.float64) tensor(5.0917, dtype=torch.float64) tensor(4.6513, dtype=torch.float64)
Saved model


100%|██████████| 4165/4165 [16:36<00:00,  4.18 batch/s, loss_age=3.1093388, loss_age_kl=1.33460850475756]  
100%|██████████| 4493/4493 [03:54<00:00, 19.19 batch/s]


tensor(4.4482, dtype=torch.float64) tensor(4.4870, dtype=torch.float64) tensor(4.4022, dtype=torch.float64)


100%|██████████| 4165/4165 [16:21<00:00,  4.24 batch/s, loss_age=3.146628, loss_age_kl=1.5197401325056714] 
100%|██████████| 4493/4493 [03:37<00:00, 20.62 batch/s]


tensor(3.7605, dtype=torch.float64) tensor(5.9069, dtype=torch.float64) tensor(5.5584, dtype=torch.float64)
Saved model


100%|██████████| 4165/4165 [16:58<00:00,  4.09 batch/s, loss_age=4.1817546, loss_age_kl=1.414717732998918] 
100%|██████████| 4493/4493 [04:18<00:00, 17.38 batch/s]


tensor(3.2740, dtype=torch.float64) tensor(6.2252, dtype=torch.float64) tensor(6.1634, dtype=torch.float64)
Saved model


100%|██████████| 4165/4165 [17:19<00:00,  4.01 batch/s, loss_age=3.1126652, loss_age_kl=1.361400030182947] 
100%|██████████| 4493/4493 [04:09<00:00, 17.99 batch/s]


tensor(3.9385, dtype=torch.float64) tensor(5.6551, dtype=torch.float64) tensor(5.2921, dtype=torch.float64)


100%|██████████| 4165/4165 [16:29<00:00,  4.21 batch/s, loss_age=3.0246415, loss_age_kl=1.2842931053440938]
100%|██████████| 4493/4493 [03:38<00:00, 20.58 batch/s]


tensor(3.9459, dtype=torch.float64) tensor(5.6733, dtype=torch.float64) tensor(5.2993, dtype=torch.float64)


100%|██████████| 4165/4165 [16:14<00:00,  4.27 batch/s, loss_age=2.8431807, loss_age_kl=1.0935919150030564]
100%|██████████| 4493/4493 [03:37<00:00, 20.64 batch/s]


tensor(3.9288, dtype=torch.float64) tensor(5.7111, dtype=torch.float64) tensor(5.2981, dtype=torch.float64)


100%|██████████| 4165/4165 [16:33<00:00,  4.19 batch/s, loss_age=2.8940268, loss_age_kl=1.153948352196724] 
100%|██████████| 4493/4493 [03:51<00:00, 19.37 batch/s]


tensor(4.1742, dtype=torch.float64) tensor(5.4295, dtype=torch.float64) tensor(4.9148, dtype=torch.float64)


100%|██████████| 4165/4165 [16:17<00:00,  4.26 batch/s, loss_age=3.0213046, loss_age_kl=1.1927278011105962]
100%|██████████| 4493/4493 [03:39<00:00, 20.51 batch/s]


tensor(3.9208, dtype=torch.float64) tensor(5.7424, dtype=torch.float64) tensor(5.3575, dtype=torch.float64)


100%|██████████| 4165/4165 [16:08<00:00,  4.30 batch/s, loss_age=2.9844527, loss_age_kl=1.2379675153243381]
100%|██████████| 4493/4493 [03:38<00:00, 20.55 batch/s]


tensor(3.8072, dtype=torch.float64) tensor(5.7985, dtype=torch.float64) tensor(5.4537, dtype=torch.float64)


100%|██████████| 4165/4165 [16:29<00:00,  4.21 batch/s, loss_age=2.5633926, loss_age_kl=1.0643003455769706]
100%|██████████| 4493/4493 [03:50<00:00, 19.51 batch/s]


tensor(3.7964, dtype=torch.float64) tensor(5.8398, dtype=torch.float64) tensor(5.4833, dtype=torch.float64)
