In [1]:
from typing import List
from Dataset.CustomDataset import AgeGroupAndAgeDataset, StandardDataset, AgeDatasetKL
from Dataset.CustomDataLoaders import CustomDataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from Utils import AAR, CSVUtils, AgeConversion
from Utils.Validator import Validator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#########################
import pandas as pd
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
# aug = CSVUtils.get_df_from_csv("./augumentation_balanced_remove.csv", "./newAugmentationDataset/")
# df_train_aug = pd.concat([df_train, aug], ignore_index=True)
# df_train_aug = df_train_aug.reset_index(drop=True)
df_train_aug = df_train.reset_index(drop=True)
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)
#########################

from torchvision import transforms
import torch
import numpy as np

transform_func = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.RandAugment(2, 9),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

transform_func_val = transforms.Compose([
    transforms.Resize(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# Implementazione di un Dataset utilizzando "CustomDataset" per l'architettura con Film
cd_train = AgeDatasetKL(df_train, path_col="path", label_col="age", label_function="Linear", 
                        transform_func=transform_func)
# Implementazione di un Dataset utilizzando "CustomDataset" per l'architettura con Film
cd_train_balanced = AgeDatasetKL(df_train_aug, path_col="path", label_col="age", label_function="Linear", 
                                transform_func=transform_func)
cd_train.set_n_classes(81)
cd_train.set_starting_class(1)
cd_train_balanced.set_n_classes(81)
cd_train_balanced.set_starting_class(1)
dm_train = CustomDataLoader(cd_train)
dm_train_balanced = CustomDataLoader(cd_train_balanced)
dl_train = dm_train.get_unbalanced_dataloader(batch_size=64, shuffle=True, drop_last=True, num_workers=6, prefetch_factor=2, pin_memory=True)
dl_train_balanced, sampler = dm_train_balanced.get_balanced_age_dataloader2(batch_size=64, num_workers=6, prefetch_factor=2, pin_memory=True)
sampler.n_batches = len(dl_train)
# sampler.p = np.array([0.4, 0.3, 0.01, 0.01, 0.01, 0.02, 0.10, 0.15])

cd_val = StandardDataset(df_val, path_col="path", label_col="age", label_function="CAE", transform_func=transform_func_val)
cd_val.set_n_classes(81)
cd_val.set_starting_class(1)
validator = Validator(cd_val, AgeConversion.ArgMaxAge, 32, num_workers=6, prefetch_factor=2)

In [3]:
from ResNetFilmed.resnet import ResNetFiLMed, BackBone, ResNetNotFiLMed, DoNothingLayer
from torchvision.models import resnet18, ResNet18_Weights, efficientnet_b0, EfficientNet_B0_Weights
import torch
from torch import optim
import torch.nn.functional as F
from torch import nn

####################################################
EPOCHS = 12
####################################################

backbone = resnet18(ResNet18_Weights.IMAGENET1K_V1)
backbone.fc = DoNothingLayer()
backbone.train()
backbone.requires_grad_(True)
backbone.to("cuda")
model_age = ResNetNotFiLMed(backbone, 81)
model_age.load_state_dict(torch.load("./model_age_classification_simple.pt", map_location="cuda:0"))
opt = optim.SGD(set([*backbone.parameters(), *model_age.fc0.parameters()]), lr=2.5e-4, weight_decay=2.5e-5)
kl = nn.KLDivLoss(reduction="batchmean")



In [4]:
def forward_function(x):
    out = model_age(x)
    out = F.softmax(out, dim=-1)
    return out

sig, ae_age, ae, mae, val_aar, val_aar_old = validator.validate_ext4(forward_function)
print(sig, ae_age, ae, mae, val_aar, val_aar_old)

n_classes = 81
n_worst = 16

worst = sorted([(i, float(torch.nan_to_num(ae_age[i]))) for i in range(n_classes)], reverse=True, key=lambda x: x[1])[:n_worst]
sampler.p = np.array([0.2 / (n_classes - n_worst)]*n_classes)
for w in worst:
    index, _ = w
    sampler.p[index] = 0.8 / n_worst

100%|██████████| 4493/4493 [02:10<00:00, 34.54 batch/s]


tensor(0.7722) {0: tensor(4.), 1: tensor(6.5000), 2: tensor(4.4615), 3: tensor(6.8621), 4: tensor(5.2589), 5: tensor(4.9143), 6: tensor(4.0449), 7: tensor(4.0172), 8: tensor(3.7368), 9: tensor(2.9574), 10: tensor(3.5600), 11: tensor(6.1212), 12: tensor(4.0468), 13: tensor(3.9221), 14: tensor(3.6035), 15: tensor(2.8089), 16: tensor(2.5170), 17: tensor(1.7592), 18: tensor(1.5699), 19: tensor(1.5963), 20: tensor(1.9288), 21: tensor(2.1778), 22: tensor(2.2322), 23: tensor(2.2070), 24: tensor(2.0424), 25: tensor(2.2497), 26: tensor(2.3530), 27: tensor(2.4868), 28: tensor(2.5064), 29: tensor(2.3503), 30: tensor(2.5637), 31: tensor(2.6058), 32: tensor(2.7249), 33: tensor(2.8690), 34: tensor(2.8755), 35: tensor(2.9689), 36: tensor(3.1319), 37: tensor(3.1465), 38: tensor(2.9903), 39: tensor(2.8146), 40: tensor(2.7323), 41: tensor(2.6366), 42: tensor(2.6925), 43: tensor(2.8147), 44: tensor(2.7393), 45: tensor(2.7857), 46: tensor(2.8415), 47: tensor(2.8646), 48: tensor(2.8352), 49: tensor(2.3849)

In [5]:
import numpy as np
def get_centers(outs, ys, old):
    centers_for_age = {x: [] for x in range(81)}

    for out, y in zip(outs, ys):
        for _out, _y in zip(out, y):
            centers_for_age[int(_y)].append(_out.detach().cpu().numpy())

    for i in centers_for_age:
        if len(centers_for_age[i]) > 0:
            centers_for_age[i] = torch.tensor(np.array(centers_for_age[i])).mean(dim=0).to("cuda")
        else:
            centers_for_age[i] = old[i]
    return centers_for_age

def update_centers(old, new, alpha=0.5):
    for i in new:
        new[i] = new[i] - alpha*(new[i] - old[i])
    return new

def get_centers_loss(out, y, centers_for_age):
    loss = None
    for _out, _y in zip(out, y):
        if loss is None:
            loss = torch.mean(torch.square(_out - centers_for_age[int(_y)]))
        else:    
            loss += torch.mean(torch.square(_out - centers_for_age[int(_y)]))
    return loss.to("cuda")

def get_new_p(worst_index):
    if worst_index == 0:
        return np.array([0.68, 0.2, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02])
    if worst_index == 7:
        return np.array([0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.2, 0.68])
    p = np.array([0.02]*8)
    p[worst_index-1:worst_index+2] = np.array([0.2, 0.5, 0.2])
    return p

In [6]:
centers_for_age = {x: torch.zeros(size=(512,), device="cuda") for x in range(81)}
best_val_aar = 6.86

for e in range(EPOCHS):
    with tqdm(zip(dl_train, dl_train_balanced), unit=" batch") as tepoch:
        for batch, batch_balanced in tepoch:
            opt.zero_grad()
            x, y = batch
            x_bal, y_bal = batch_balanced

            x = x.to("cuda")
            y_age: torch.Tensor = y[0].to("cuda")
            y_age_kl: torch.Tensor = y[1].to("cuda")

            x_bal = x_bal.to("cuda")
            y_age_bal: torch.Tensor = y_bal[0].to("cuda")
            y_age_kl_bal: torch.Tensor = y_bal[1].to("cuda")

            out_rep, out_age = model_age.forward_with_repr(x)
            loss_age_kl: torch.Tensor = kl(F.log_softmax(out_age, dim=-1), y_age_kl)
            out_age = F.softmax(out_age, dim=-1)
            out = AgeConversion.EVAge(out_age).to("cuda")
            loss_age = torch.mean(torch.abs(y_age - out))
            loss = loss_age_kl + loss_age

            out_rep_bal, out_age_bal = model_age.forward_with_repr(x_bal)
            loss_age_kl_bal: torch.Tensor = kl(F.log_softmax(out_age_bal, dim=-1), y_age_kl_bal)
            out_age_bal = F.softmax(out_age_bal, dim=-1)
            out_bal = AgeConversion.EVAge(out_age_bal).to("cuda")
            loss_age_bal = torch.mean(torch.abs(y_age_bal - out_bal))
            loss_bal = loss_age_kl_bal + torch.square(loss_age_bal - 2.0) * sig

            loss_repr = get_centers_loss(out_rep, y_age, centers_for_age)
            loss_repr_bal = get_centers_loss(out_rep_bal, y_age_bal, centers_for_age)

            total_loss = loss + loss_bal + loss_repr + loss_repr_bal

            total_loss.backward()
            opt.step()

            centers_for_age = update_centers(centers_for_age, get_centers((out_rep, out_rep_bal), (y_age, y_age_bal), centers_for_age), alpha=0.5)

            tepoch.set_postfix(loss_age_kl=loss_age_kl.detach().cpu().numpy(), loss_age=loss_age.detach().cpu().numpy(),
                                loss_age_kl_bal=loss_age_kl_bal.detach().cpu().numpy(), loss_age_bal=loss_age_bal.detach().cpu().numpy(),
                                loss_repr=loss_repr.detach().cpu().numpy(), loss_repr_bal=loss_repr_bal.detach().cpu().numpy(),
                                total_loss=total_loss.detach().cpu().numpy())

    def forward_function(x):
        out = model_age(x)
        out = F.softmax(out, dim=-1)
        return out

    sig, ae_age, ae, mae_, val_aar, val_aar_old = validator.validate_ext4(forward_function)
    print(sig, ae_age, ae, mae_, val_aar, val_aar_old)

    # sampler.p = (sampler.p + (np.array([1 + float(torch.nan_to_num(x)) for x in ae_age.values()]) / np.sum(np.array([1 + float(torch.nan_to_num(x)) for x in ae_age.values()]))))/2
    # sampler.p = sampler.p / np.sum(sampler.p)

    worst = sorted([(i, float(torch.nan_to_num(ae_age[i]))) for i in range(n_classes)], reverse=True, key=lambda x: x[1])[:n_worst]
    sampler.p = np.array([0.2 / (n_classes - n_worst)]*n_classes)
    for w in worst:
        index, _ = w
        sampler.p[index] = 0.8 / n_worst

    if best_val_aar < val_aar:
        best_val_aar = val_aar
        torch.save(model_age.state_dict(), "./model_age_feature_simple_loss.pt")
        print("Saved model")

372 batch [01:12,  5.35 batch/s, loss_age=5.089105, loss_age_bal=3.7084184, loss_age_kl=1.6604416970127815, loss_age_kl_bal=1.755192442467182, loss_repr=1.7412118, loss_repr_bal=1.3342013, total_loss=13.833966358603803]  