In [1]:
from typing import List
from Dataset.CustomDataset import AgeGroupAndAgeDataset, StandardDataset, AgeGroupAndAgeDatasetKL, AgeGroupKLAndAgeDatasetKL
from Dataset.CustomDataLoaders import CustomDataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from Utils import AAR, CSVUtils, AgeConversion
from Utils.Validator import Validator

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
#Caricamento del dataframe
import pandas as pd
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
aug = CSVUtils.get_df_from_csv("./augumentation.csv", "./newAugmentationDataset/")
df_train = pd.concat([df_train, aug], ignore_index=True)
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)

431304


In [3]:
from torchvision import transforms
import torch

transform_func = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.RandAugment(2, 9),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

transform_func_val = transforms.Compose([
    transforms.Resize(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

label_map_v = CSVUtils.get_label_map_vector()
cd_train = AgeGroupKLAndAgeDatasetKL(df_train, path_col="path", label_col="age", label_function="Linear", 
                                    label_map_vector=label_map_v, transform_func=transform_func)

cd_val = StandardDataset(df_val, path_col="path", label_col="age", label_function="CAE", transform_func=transform_func_val)
cd_val.set_n_classes(81)
cd_val.set_starting_class(1)
validator = Validator(cd_val, AgeConversion.EVAge, 32, num_workers=6, prefetch_factor=4)

dm_train = CustomDataLoader(cd_train)
dl_train = dm_train.get_unbalanced_dataloader(batch_size=128, shuffle=True, drop_last=True, num_workers=16, prefetch_factor=4, pin_memory=True)

In [4]:
from ResNetFilmed.resnet import ResNetFiLMed, BackBone
import torch
from torch import optim
import torch.nn.functional as F
from torch import nn

####################################################
EPOCHS = 12
####################################################

backbone = BackBone(pretrained=True, freeze=False)
backbone.train()
model_age_group = ResNetFiLMed(backbone, 8)
model_age_group.train()
model_age = ResNetFiLMed(backbone, 81)
model_age.train()
opt = optim.SGD(set([*model_age_group.parameters(), *model_age.parameters()]), lr=0.1, weight_decay=5e-4)
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=0.1, steps_per_epoch=len(dl_train), epochs=EPOCHS, three_phase=True)
kl = nn.KLDivLoss(reduction="batchmean")

In [5]:
best_mae = 20
knowledge_age_group = torch.tensor([[0.125]*8]*128, requires_grad=False).float().to("cuda")

for e in range(EPOCHS):
    with tqdm(dl_train, unit=" batch") as tepoch:
        for batch in tepoch:
            opt.zero_grad()
            x, y = batch
            x = x.to("cuda")
            y_age_group = y[0].to("cuda")
            y_age = y[1].to("cuda")
            y_age_kl: torch.Tensor = y[2].to("cuda")
            
            knowledge = model_age_group(x, knowledge_age_group)
            loss_age_group: torch.Tensor = kl(F.log_softmax(knowledge, dim=-1), y_age_group)

            knowledge = F.softmax(knowledge, dim=-1)
            out_age = model_age(x, knowledge)
            loss_age_kl: torch.Tensor = kl(F.log_softmax(out_age, dim=-1), y_age_kl)

            out = F.softmax(out_age, dim=-1)
            out = AgeConversion.EVAge(out).to("cuda")
            loss_age: torch.Tensor = torch.mean(torch.abs(out - y_age))

            loss = loss_age_group + loss_age_kl + loss_age
            loss.backward()
            opt.step()
            scheduler.step()

            tepoch.set_postfix(loss_age_group=loss_age_group.detach().cpu().numpy(), 
                                loss_age=loss_age.detach().cpu().numpy(), 
                                loss_age_kl=loss_age_kl.detach().cpu().numpy())

    def forward_function(x):
        knowledge_age_group = torch.tensor([[0.125]*8]*len(x), requires_grad=False).float().to("cuda")
        knowledge = model_age_group(x, knowledge_age_group)
        knowledge = F.softmax(knowledge, dim=-1)
        out = model_age(x, knowledge)
        out = F.softmax(out, dim=-1)
        return out

    mae, val_aar, val_aar_old = validator.validate_ext(forward_function)
    print(mae, val_aar, val_aar_old)

    if mae < best_mae:
        best_mae = mae
        torch.save(model_age_group.state_dict(), "./model_age_group_film_feature_aug.pt")
        torch.save(model_age.state_dict(), "./model_age_film_feature_aug.pt")
        print("Saved model")

100%|██████████| 7337/7337 [30:45<00:00,  3.98 batch/s, loss_age=3.7175488, loss_age_group=0.5760222452261261, loss_age_kl=1.4907589687445415] 
100%|██████████| 4493/4493 [03:20<00:00, 22.40 batch/s]


tensor(11.5719, dtype=torch.float64) tensor(0.9765, dtype=torch.float64) tensor(0., dtype=torch.float64)
Saved model


100%|██████████| 7337/7337 [30:16<00:00,  4.04 batch/s, loss_age=3.1208599, loss_age_group=0.5049353053944864, loss_age_kl=1.3215067822917441] 
100%|██████████| 4493/4493 [03:06<00:00, 24.11 batch/s]


tensor(11.2076, dtype=torch.float64) tensor(1.1922, dtype=torch.float64) tensor(0., dtype=torch.float64)
Saved model


100%|██████████| 7337/7337 [30:19<00:00,  4.03 batch/s, loss_age=3.7261834, loss_age_group=0.6533227406456097, loss_age_kl=1.586908446732442]  
100%|██████████| 4493/4493 [03:12<00:00, 23.30 batch/s]


tensor(11.3803, dtype=torch.float64) tensor(0.9113, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 7337/7337 [30:25<00:00,  4.02 batch/s, loss_age=5.6297636, loss_age_group=0.8131323572870419, loss_age_kl=1.7567391851602612] 
100%|██████████| 4493/4493 [03:08<00:00, 23.80 batch/s]


tensor(14.5492, dtype=torch.float64) tensor(1.7733, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 7337/7337 [30:28<00:00,  4.01 batch/s, loss_age=3.563517, loss_age_group=0.49091431368487454, loss_age_kl=1.3545819367379326] 
100%|██████████| 4493/4493 [03:12<00:00, 23.34 batch/s]


tensor(9.5842, dtype=torch.float64) tensor(2.2173, dtype=torch.float64) tensor(0.2173, dtype=torch.float64)
Saved model


100%|██████████| 7337/7337 [30:25<00:00,  4.02 batch/s, loss_age=3.1511202, loss_age_group=0.4539154732683436, loss_age_kl=1.2580512130716426] 
100%|██████████| 4493/4493 [03:07<00:00, 23.93 batch/s]


tensor(9.8637, dtype=torch.float64) tensor(2.4778, dtype=torch.float64) tensor(0.4778, dtype=torch.float64)


100%|██████████| 7337/7337 [30:10<00:00,  4.05 batch/s, loss_age=2.4683309, loss_age_group=0.327595647304296, loss_age_kl=1.095674029179962]   
100%|██████████| 4493/4493 [03:18<00:00, 22.67 batch/s]


tensor(7.9512, dtype=torch.float64) tensor(2.0609, dtype=torch.float64) tensor(0.0609, dtype=torch.float64)
Saved model


100%|██████████| 7337/7337 [30:09<00:00,  4.06 batch/s, loss_age=2.504515, loss_age_group=0.34031464128771266, loss_age_kl=1.1098505984359501] 
100%|██████████| 4493/4493 [03:07<00:00, 24.01 batch/s]


tensor(8.1020, dtype=torch.float64) tensor(1.9961, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 7337/7337 [30:22<00:00,  4.03 batch/s, loss_age=2.4113493, loss_age_group=0.35190215949385945, loss_age_kl=1.156787107611037] 
100%|██████████| 4493/4493 [03:16<00:00, 22.81 batch/s]


tensor(7.1410, dtype=torch.float64) tensor(2.2402, dtype=torch.float64) tensor(0.2402, dtype=torch.float64)
Saved model


100%|██████████| 7337/7337 [30:13<00:00,  4.05 batch/s, loss_age=1.9035091, loss_age_group=0.2720005845285527, loss_age_kl=0.9377836913175814] 
100%|██████████| 4493/4493 [03:04<00:00, 24.36 batch/s]


tensor(6.0377, dtype=torch.float64) tensor(3.0397, dtype=torch.float64) tensor(1.8978, dtype=torch.float64)
Saved model


100%|██████████| 7337/7337 [30:30<00:00,  4.01 batch/s, loss_age=2.3554192, loss_age_group=0.29531590259595786, loss_age_kl=1.0343260415806181]
100%|██████████| 4493/4493 [03:12<00:00, 23.38 batch/s]


tensor(5.9405, dtype=torch.float64) tensor(2.9879, dtype=torch.float64) tensor(1.7532, dtype=torch.float64)
Saved model


100%|██████████| 7337/7337 [30:39<00:00,  3.99 batch/s, loss_age=1.6728907, loss_age_group=0.2032273909968672, loss_age_kl=0.8227785198793046] 
100%|██████████| 4493/4493 [03:06<00:00, 24.10 batch/s]


tensor(5.5719, dtype=torch.float64) tensor(3.4946, dtype=torch.float64) tensor(2.4327, dtype=torch.float64)
Saved model


In [5]:
####################################################
EPOCHS = 12
####################################################

dl_train = dm_train.get_balanced_class_random_dataset_dataloader(df_base_len=431304, class_ranges=[(0, 11), (11, 21), (21, 31), (31, 41), (41, 51), (51, 61), (61, 71), (71, 91)], 
                                                                batch_size=128, num_workers=16, prefetch_factor=4, pin_memory=True)

model_age_group.load_state_dict(torch.load("./model_age_group_film_feature.pt"))
model_age.load_state_dict(torch.load("./model_age_film_feature.pt"))
opt = optim.SGD(set([*model_age.fc0.parameters()]), lr=0.1, weight_decay=5e-4)
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=0.1, steps_per_epoch=len(dl_train), epochs=EPOCHS, three_phase=True)

In [6]:
def forward_function(x):
    knowledge_age_group = torch.tensor([[0.125]*8]*len(x), requires_grad=False).float().to("cuda")
    knowledge = model_age_group(x, knowledge_age_group)
    knowledge = F.softmax(knowledge, dim=-1)
    out = model_age(x, knowledge)
    out = F.softmax(out, dim=-1)
    return out

mae, val_aar, val_aar_old = validator.validate_ext(forward_function)
print(mae, val_aar, val_aar_old)

100%|██████████| 4493/4493 [03:11<00:00, 23.47 batch/s]


tensor(5.5719, dtype=torch.float64) tensor(3.4946, dtype=torch.float64) tensor(2.4327, dtype=torch.float64)


In [8]:
best_val_aar = val_aar
knowledge_age_group = torch.tensor([[0.125]*8]*128, requires_grad=False).float().to("cuda")

for e in range(EPOCHS):
    with tqdm(dl_train, unit=" batch") as tepoch:
        for batch in tepoch:
            opt.zero_grad()
            x, y = batch
            x = x.to("cuda")
            y_age_group = y[0].to("cuda")
            y_age = y[1].to("cuda")
            y_age_kl: torch.Tensor = y[2].to("cuda")
            
            knowledge = model_age_group(x, knowledge_age_group)
            # loss_age_group: torch.Tensor = kl(F.log_softmax(knowledge, dim=-1), y_age_group)

            knowledge = F.softmax(knowledge, dim=-1)
            out_age = model_age(x, knowledge)
            loss_age_kl: torch.Tensor = kl(F.log_softmax(out_age, dim=-1), y_age_kl)

            out = F.softmax(out_age, dim=-1)
            out = AgeConversion.EVAge(out).to("cuda")
            loss_age: torch.Tensor = torch.mean(torch.abs(out - y_age))

            loss = loss_age_kl + torch.square(loss_age - mae) # + loss_age_group
            loss.backward()
            opt.step()
            scheduler.step()

            tepoch.set_postfix(loss_age=loss_age.detach().cpu().numpy(), 
                                loss_age_kl=loss_age_kl.detach().cpu().numpy())

    def forward_function(x):
        knowledge_age_group = torch.tensor([[0.125]*8]*len(x), requires_grad=False).float().to("cuda")
        knowledge = model_age_group(x, knowledge_age_group)
        knowledge = F.softmax(knowledge, dim=-1)
        out = model_age(x, knowledge)
        out = F.softmax(out, dim=-1)
        return out

    mae_, val_aar, val_aar_old = validator.validate_ext(forward_function)
    print(mae_, val_aar, val_aar_old)

    if val_aar > best_val_aar:
        best_val_aar = val_aar
        torch.save(model_age_group.state_dict(), "./model_age_group_film_classifier_aug.pt")
        torch.save(model_age.state_dict(), "./model_age_film_classifier_aug.pt")
        print("Saved model")

100%|██████████| 7337/7337 [32:45<00:00,  3.73 batch/s, loss_age=5.837307, loss_age_kl=1.16501252020349]   
100%|██████████| 4493/4493 [03:15<00:00, 23.03 batch/s]


tensor(7.2114, dtype=torch.float64) tensor(0., dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 7337/7337 [32:38<00:00,  3.75 batch/s, loss_age=5.048768, loss_age_kl=1.121530937141424]  
100%|██████████| 4493/4493 [03:03<00:00, 24.43 batch/s]


tensor(6.8076, dtype=torch.float64) tensor(1.6955, dtype=torch.float64) tensor(0.1924, dtype=torch.float64)


100%|██████████| 7337/7337 [31:36<00:00,  3.87 batch/s, loss_age=6.24458, loss_age_kl=1.1532620911395344]  
100%|██████████| 4493/4493 [03:09<00:00, 23.68 batch/s]


tensor(7.1169, dtype=torch.float64) tensor(1.5196, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 7337/7337 [30:17<00:00,  4.04 batch/s, loss_age=5.74687, loss_age_kl=1.1554092594150667]  
100%|██████████| 4493/4493 [03:01<00:00, 24.72 batch/s]


tensor(7.8608, dtype=torch.float64) tensor(0.0946, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 7337/7337 [30:10<00:00,  4.05 batch/s, loss_age=5.1323853, loss_age_kl=1.0459457817365472]
100%|██████████| 4493/4493 [03:09<00:00, 23.76 batch/s]


tensor(7.6111, dtype=torch.float64) tensor(1.0376, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 7337/7337 [30:02<00:00,  4.07 batch/s, loss_age=6.011825, loss_age_kl=1.045396620486668]  
100%|██████████| 4493/4493 [03:02<00:00, 24.68 batch/s]


tensor(8.0177, dtype=torch.float64) tensor(1.1456, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|██████████| 7337/7337 [29:58<00:00,  4.08 batch/s, loss_age=5.330656, loss_age_kl=1.0032771077029927] 
100%|██████████| 4493/4493 [03:09<00:00, 23.72 batch/s]


tensor(7.9992, dtype=torch.float64) tensor(1.1891, dtype=torch.float64) tensor(0., dtype=torch.float64)


100%|█████████▉| 7336/7337 [31:14<00:00,  3.91 batch/s, loss_age=5.247093, loss_age_kl=0.9819015667338922] 


ValueError: Tried to step 58698 times. The specified number of total steps is 58696

: 