In [1]:
from typing import List
from Dataset.CustomDataset import AgeGroupAndAgeDataset, StandardDataset, AgeDatasetKL
from Dataset.CustomDataLoaders import CustomDataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from Utils import AAR, CSVUtils, AgeConversion
from Utils.Validator import Validator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#########################
import pandas as pd
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
# aug = CSVUtils.get_df_from_csv("./augumentation_balanced_remove.csv", "./newAugmentationDataset/")
# df_train_aug = pd.concat([df_train, aug], ignore_index=True)
# df_train_aug = df_train_aug.reset_index(drop=True)
df_train_aug = df_train.reset_index(drop=True)
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)
#########################

from torchvision import transforms
import torch
import numpy as np

transform_func = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.RandAugment(2, 9),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

transform_func_val = transforms.Compose([
    transforms.Resize(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# Implementazione di un Dataset utilizzando "CustomDataset" per l'architettura con Film
cd_train = AgeDatasetKL(df_train, path_col="path", label_col="age", label_function="Linear", 
                        transform_func=transform_func)
# Implementazione di un Dataset utilizzando "CustomDataset" per l'architettura con Film
cd_train_balanced = AgeDatasetKL(df_train_aug, path_col="path", label_col="age", label_function="Linear", 
                                transform_func=transform_func)
cd_train.set_n_classes(81)
cd_train.set_starting_class(1)
cd_train_balanced.set_n_classes(81)
cd_train_balanced.set_starting_class(1)
dm_train = CustomDataLoader(cd_train)
dm_train_balanced = CustomDataLoader(cd_train_balanced)
dl_train = dm_train.get_unbalanced_dataloader(batch_size=64, shuffle=True, drop_last=True, num_workers=6, prefetch_factor=2, pin_memory=True)
dl_train_balanced, sampler = dm_train_balanced.get_balanced_class_dataloader2(class_ranges=[(0, 11), (11, 21), (21, 31), (31, 41), (41, 51), (51, 61), (61, 71), (71, 91)], 
                                                                            batch_size=64, num_workers=6, prefetch_factor=2, pin_memory=True)
sampler.n_batches = len(dl_train)
sampler.p = np.array([0.4, 0.3, 0.03, 0.03, 0.03, 0.03, 0.03, 0.15])

cd_val = StandardDataset(df_val, path_col="path", label_col="age", label_function="CAE", transform_func=transform_func_val)
cd_val.set_n_classes(81)
cd_val.set_starting_class(1)
validator = Validator(cd_val, AgeConversion.ArgMaxAge, 32, num_workers=6, prefetch_factor=2)

In [3]:
from ResNetFilmed.resnet import ResNetFiLMed, BackBone, ResNetNotFiLMed, DoNothingLayer
from torchvision.models import resnet18, ResNet18_Weights, efficientnet_b0, EfficientNet_B0_Weights
import torch
from torch import optim
import torch.nn.functional as F
from torch import nn

####################################################
EPOCHS = 12
####################################################

backbone = resnet18(ResNet18_Weights.IMAGENET1K_V1)
backbone.fc = DoNothingLayer()
backbone.train()
backbone.requires_grad_(True)
backbone.to("cuda")
model_age = ResNetNotFiLMed(backbone, 81)
model_age.load_state_dict(torch.load("./model_age_classification_simple.pt", map_location="cuda:0"))
opt = optim.SGD(set([*backbone.parameters(), *model_age.fc0.parameters()]), lr=1e-4, weight_decay=1e-5)
kl = nn.KLDivLoss(reduction="batchmean")



In [4]:
def forward_function(x):
    out = model_age(x)
    out = F.softmax(out, dim=-1)
    return out

ae, mae, val_aar, val_aar_old = validator.validate_ext2(forward_function)
print(ae, mae, val_aar, val_aar_old)

100%|██████████| 4493/4493 [02:11<00:00, 34.13 batch/s]


{0: tensor(4.5905), 1: tensor(2.0977), 2: tensor(2.2591), 3: tensor(2.8694), 4: tensor(2.7316), 5: tensor(2.3014), 6: tensor(2.3476), 7: tensor(2.3447)} tensor(2.5343) tensor(6.5351) tensor(6.6935)


In [5]:
import numpy as np
def get_centers(outs, ys, old):
    centers_for_age = {x: [] for x in range(81)}

    for out, y in zip(outs, ys):
        for _out, _y in zip(out, y):
            centers_for_age[int(_y)].append(_out.detach().cpu().numpy())

    for i in centers_for_age:
        if len(centers_for_age[i]) > 0:
            centers_for_age[i] = torch.tensor(np.array(centers_for_age[i])).mean(dim=0).to("cuda")
        else:
            centers_for_age[i] = old[i]
    return centers_for_age

def update_centers(old, new, alpha=0.5):
    for i in new:
        new[i] = new[i] - alpha*(new[i] - old[i])
    return new

def get_centers_loss(out, y, centers_for_age):
    loss = None
    for _out, _y in zip(out, y):
        if loss is None:
            loss = torch.mean(torch.square(_out - centers_for_age[int(_y)]))
        else:    
            loss += torch.mean(torch.square(_out - centers_for_age[int(_y)]))
    return loss.to("cuda")

In [6]:
centers_for_age = {x: torch.zeros(size=(512,), device="cuda") for x in range(81)}
best_val_aar = val_aar

for e in range(EPOCHS):
    bb = 0
    with tqdm(zip(dl_train, dl_train_balanced), unit=" batch") as tepoch:
        for batch, batch_balanced in tepoch:
            if bb>2000:
                break
            bb+=1
            opt.zero_grad()
            x, y = batch
            x_bal, y_bal = batch_balanced

            x = x.to("cuda")
            y_age: torch.Tensor = y[0].to("cuda")
            y_age_kl: torch.Tensor = y[1].to("cuda")

            x_bal = x_bal.to("cuda")
            y_age_bal: torch.Tensor = y_bal[0].to("cuda")
            y_age_kl_bal: torch.Tensor = y_bal[1].to("cuda")

            out_rep, out_age = model_age.forward_with_repr(x)
            loss_age_kl: torch.Tensor = kl(F.log_softmax(out_age, dim=-1), y_age_kl)
            out_age = F.softmax(out_age, dim=-1)
            out = AgeConversion.EVAge(out_age).to("cuda")
            loss_age = torch.mean(torch.abs(y_age - out))
            loss = loss_age_kl + loss_age

            out_rep_bal, out_age_bal = model_age.forward_with_repr(x_bal)
            loss_age_kl_bal: torch.Tensor = kl(F.log_softmax(out_age_bal, dim=-1), y_age_kl_bal)
            out_age_bal = F.softmax(out_age_bal, dim=-1)
            out_bal = AgeConversion.EVAge(out_age_bal).to("cuda")
            loss_age_bal = torch.mean(torch.abs(y_age_bal - out_bal))
            loss_bal = loss_age_kl_bal + torch.square(loss_age_bal - mae) # (torch.square(loss_age_bal - mae) if loss_age_bal > mae else torch.tensor(0))#  + loss_age_bal

            loss_repr = get_centers_loss(out_rep, y_age, centers_for_age)
            loss_repr_bal = get_centers_loss(out_rep_bal, y_age_bal, centers_for_age)

            total_loss = loss + loss_bal + loss_repr + loss_repr_bal

            total_loss.backward()
            opt.step()

            centers_for_age = update_centers(centers_for_age, get_centers((out_rep, out_rep_bal), (y_age, y_age_bal), centers_for_age), alpha=0.5)

            tepoch.set_postfix(loss_age_kl=loss_age_kl.detach().cpu().numpy(), loss_age=loss_age.detach().cpu().numpy(),
                                loss_age_kl_bal=loss_age_kl_bal.detach().cpu().numpy(), loss_age_bal=loss_age_bal.detach().cpu().numpy(),
                                loss_repr=loss_repr.detach().cpu().numpy(), loss_repr_bal=loss_repr_bal.detach().cpu().numpy(),
                                total_loss=total_loss.detach().cpu().numpy())

    def forward_function(x):
        out = model_age(x)
        out = F.softmax(out, dim=-1)
        return out

    ae_age, ae, mae_, val_aar, val_aar_old = validator.validate_ext3(forward_function)
    print(ae_age, ae, mae_, val_aar, val_aar_old)

    if best_val_aar < val_aar:
        best_val_aar = val_aar
        torch.save(model_age.state_dict(), "./model_age_feature_simple_loss.pt")
        print("Saved model")

6739 batch [21:32,  5.21 batch/s, loss_age=2.3730526, loss_age_bal=3.25409, loss_age_kl=1.3114452053283996, loss_age_kl_bal=1.3591435548809339, loss_repr=0.998779, loss_repr_bal=0.69928926, total_loss=7.259785263371717]     
100%|██████████| 4493/4493 [02:04<00:00, 36.01 batch/s]


{0: tensor(3.), 1: tensor(2.), 2: tensor(1.), 3: tensor(1.7586), 4: tensor(1.1161), 5: tensor(2.5714), 6: tensor(3.2135), 7: tensor(5.3276), 8: tensor(5.9474), 9: tensor(6.3404), 10: tensor(7.1200), 11: tensor(8.8182), 12: tensor(8.2924), 13: tensor(8.7056), 14: tensor(8.2157), 15: tensor(7.0342), 16: tensor(5.2951), 17: tensor(3.7951), 18: tensor(2.6712), 19: tensor(2.1156), 20: tensor(2.8402), 21: tensor(3.1573), 22: tensor(3.0026), 23: tensor(2.7429), 24: tensor(2.4416), 25: tensor(2.4947), 26: tensor(2.5629), 27: tensor(2.6365), 28: tensor(2.6388), 29: tensor(2.7054), 30: tensor(2.9911), 31: tensor(2.9692), 32: tensor(3.1220), 33: tensor(3.3962), 34: tensor(3.3819), 35: tensor(3.3017), 36: tensor(3.5417), 37: tensor(3.4446), 38: tensor(3.2381), 39: tensor(3.1480), 40: tensor(2.9727), 41: tensor(3.0831), 42: tensor(2.9343), 43: tensor(3.1543), 44: tensor(2.9276), 45: tensor(3.1696), 46: tensor(3.1904), 47: tensor(3.1854), 48: tensor(3.0438), 49: tensor(2.5885), 50: tensor(2.2970), 5

6739 batch [20:20,  5.52 batch/s, loss_age=2.8311567, loss_age_bal=2.6526256, loss_age_kl=1.308581882088821, loss_age_kl_bal=1.200372493855891, loss_repr=0.76686484, loss_repr_bal=0.60118216, total_loss=6.722155411690154]   
100%|██████████| 4493/4493 [02:05<00:00, 35.76 batch/s]


{0: tensor(3.5000), 1: tensor(5.1000), 2: tensor(3.3846), 3: tensor(3.2759), 4: tensor(0.9286), 5: tensor(2.3524), 6: tensor(3.0449), 7: tensor(4.3966), 8: tensor(5.2632), 9: tensor(5.3830), 10: tensor(6.0400), 11: tensor(8.5253), 12: tensor(7.6140), 13: tensor(7.6234), 14: tensor(6.9708), 15: tensor(5.8068), 16: tensor(4.5145), 17: tensor(3.1851), 18: tensor(2.0572), 19: tensor(1.6303), 20: tensor(2.3812), 21: tensor(2.9457), 22: tensor(2.8537), 23: tensor(2.5510), 24: tensor(2.2081), 25: tensor(2.2570), 26: tensor(2.2117), 27: tensor(2.2790), 28: tensor(2.2256), 29: tensor(2.2749), 30: tensor(2.5561), 31: tensor(2.5855), 32: tensor(2.7321), 33: tensor(2.9173), 34: tensor(2.9216), 35: tensor(2.9070), 36: tensor(3.0725), 37: tensor(2.9801), 38: tensor(2.8354), 39: tensor(2.7509), 40: tensor(2.5986), 41: tensor(2.6892), 42: tensor(2.6083), 43: tensor(2.7993), 44: tensor(2.6647), 45: tensor(2.7848), 46: tensor(2.7992), 47: tensor(2.8274), 48: tensor(2.8272), 49: tensor(2.4111), 50: tenso

6739 batch [20:15,  5.54 batch/s, loss_age=2.0380297, loss_age_bal=2.3747823, loss_age_kl=1.0932159163885142, loss_age_kl_bal=1.1053335219269291, loss_repr=0.48131672, loss_repr_bal=0.35553455, total_loss=5.0988811646131795]
 68%|██████▊   | 3062/4493 [01:27<00:36, 38.85 batch/s]