In [1]:
from typing import List
from Dataset.CustomDataset import AgeGroupAndAgeDataset, StandardDataset, AgeGroupAndAgeDatasetKL, AgeGroupKLAndAgeDatasetKL
from Dataset.CustomDataLoaders import CustomDataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from Utils import AAR, CSVUtils, AgeConversion
from Utils.Validator import Validator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pandas as pd
import torch
from torchvision import transforms

# --- Data split --------------------------------------------------------------
# Read the annotated contest images and hold out 25% of the rows for
# validation; the fixed random_state keeps the split identical between runs.
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")
df_train, df_val = (
    part.reset_index(drop=True)
    for part in train_test_split(df, test_size=0.25, random_state=42)
)


# --- Image pipelines ---------------------------------------------------------
def _tensor_and_normalize():
    """Return fresh copies of the steps shared by the train and val pipelines.

    PIL image -> uint8 tensor -> float tensor -> per-channel normalisation.
    The mean/std values are presumably the ImageNet statistics expected by the
    pretrained backbone -- TODO confirm.
    """
    return [
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ]


# Training: resize, random 224 crop, random flip and RandAugment before the
# shared tensor conversion.
transform_func = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(2, 9),
    ]
    + _tensor_and_normalize()
)

# Validation: deterministic resize only, no augmentation.
transform_func_val = transforms.Compose(
    [transforms.Resize(224)] + _tensor_and_normalize()
)

# --- Datasets ----------------------------------------------------------------
label_map_v = CSVUtils.get_label_map_vector()
cd_train = AgeGroupKLAndAgeDatasetKL(
    df_train,
    path_col="path",
    label_col="age",
    label_function="Linear",
    label_map_vector=label_map_v,
    transform_func=transform_func,
)

cd_val = StandardDataset(
    df_val,
    path_col="path",
    label_col="age",
    label_function="CAE",
    transform_func=transform_func_val,
)
cd_val.set_n_classes(81)
cd_val.set_starting_class(1)

# Validator is given the validation dataset, the output->age conversion and
# 32 as its third positional argument (presumably the batch size -- TODO confirm).
validator = Validator(cd_val, AgeConversion.EVAge, 32, num_workers=6, prefetch_factor=4)

# --- Training loader ---------------------------------------------------------
# drop_last=True guarantees every training batch holds exactly 128 samples.
dm_train = CustomDataLoader(cd_train)
dl_train = dm_train.get_unbalanced_dataloader(
    batch_size=128,
    shuffle=True,
    drop_last=True,
    num_workers=16,
    prefetch_factor=4,
    pin_memory=True,
)

In [3]:
from ResNetFilmed.resnet import ResNetFiLMed, BackBone
import torch
from torch import optim
import torch.nn.functional as F
from torch import nn

####################################################
EPOCHS = 24
####################################################

# One backbone instance is handed to both heads: an 8-output age-group model
# and an 81-output age model.
backbone = BackBone(pretrained=True, freeze=False)
backbone.train()
model_age_group = ResNetFiLMed(backbone, 8)
model_age_group.train()
model_age = ResNetFiLMed(backbone, 81)
model_age.train()

# The two models receive the same backbone object, so their parameter lists
# may contain the same tensors twice. De-duplicate by object identity while
# keeping first-seen order: torch.optim requires a collection with a
# deterministic ordering, which a plain set() does not provide (it would also
# make optimizer state dicts non-portable between runs).
trainable_params = list(
    {
        id(param): param
        for param in (*model_age_group.parameters(), *model_age.parameters())
    }.values()
)
opt = optim.SGD(trainable_params, lr=0.1, weight_decay=5e-4)

# One-cycle schedule stepped once per batch over the whole run.
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=0.1, steps_per_epoch=len(dl_train), epochs=EPOCHS, three_phase=True)

# KL divergence averaged over the batch; expects log-probabilities as input.
kl = nn.KLDivLoss(reduction="batchmean")

In [4]:
def forward_function(x):
    """Run the two-model cascade on a batch and return age probabilities.

    The age-group model is fed a uniform 1/8 prior; its softmaxed output is
    then passed to the age model as conditioning, and the age model's output
    is softmaxed before being returned.
    """
    uniform_prior = torch.full((len(x), 8), 0.125, dtype=torch.float32).to("cuda")
    group_probs = F.softmax(model_age_group(x, uniform_prior), dim=-1)
    return F.softmax(model_age(x, group_probs), dim=-1)


best_val_aar = -1
# Uniform prior for training batches; 128 rows matches the fixed batch size
# of the training loader.
knowledge_age_group = torch.full((128, 8), 0.125, dtype=torch.float32).to("cuda")

for epoch in range(EPOCHS):
    with tqdm(dl_train, unit=" batch") as progress:
        for x, y in progress:
            opt.zero_grad()

            # Targets: age-group distribution, scalar age, age distribution.
            x = x.to("cuda")
            y_age_group = y[0].to("cuda")
            y_age = y[1].to("cuda")
            y_age_kl: torch.Tensor = y[2].to("cuda")

            # Stage A: age-group prediction, trained with KL on its log-softmax.
            knowledge = model_age_group(x, knowledge_age_group)
            loss_age_group: torch.Tensor = kl(F.log_softmax(knowledge, dim=-1), y_age_group)

            # Stage B: age prediction conditioned on the group probabilities.
            knowledge = F.softmax(knowledge, dim=-1)
            out_age = model_age(x, knowledge)
            loss_age_kl: torch.Tensor = kl(F.log_softmax(out_age, dim=-1), y_age_kl)

            # Mean absolute error between the converted age and the target age.
            out = AgeConversion.EVAge(F.softmax(out_age, dim=-1)).to("cuda")
            loss_age: torch.Tensor = (out - y_age).abs().mean()

            loss = loss_age_group + loss_age_kl + loss_age
            loss.backward()
            opt.step()
            scheduler.step()

            progress.set_postfix(
                loss_age_group=loss_age_group.detach().cpu().numpy(),
                loss_age=loss_age.detach().cpu().numpy(),
                loss_age_kl=loss_age_kl.detach().cpu().numpy(),
            )

    # Validate once per epoch and checkpoint both models on a new best AAR.
    mae, val_aar, val_aar_old = validator.validate_ext(forward_function)
    print(mae, val_aar, val_aar_old)

    if val_aar > best_val_aar:
        best_val_aar = val_aar
        torch.save(model_age_group.state_dict(), "./model_age_group_film_feature.pt")
        torch.save(model_age.state_dict(), "./model_age_film_feature.pt")
        print("Saved model")

100%|██████████| 3369/3369 [16:06<00:00,  3.49 batch/s, loss_age=4.1225286, loss_age_group=0.8409992113719177, loss_age_kl=1.644089189164096] 
100%|██████████| 4493/4493 [02:50<00:00, 26.31 batch/s]


tensor(4.3023, dtype=torch.float64) tensor(0., dtype=torch.float64) tensor(2.6977, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [14:57<00:00,  3.75 batch/s, loss_age=4.087906, loss_age_group=0.7079038489606979, loss_age_kl=1.6315398155898817] 
100%|██████████| 4493/4493 [02:56<00:00, 25.41 batch/s]


tensor(4.0169, dtype=torch.float64) tensor(0., dtype=torch.float64) tensor(2.9831, dtype=torch.float64)


100%|██████████| 3369/3369 [14:50<00:00,  3.78 batch/s, loss_age=3.5028963, loss_age_group=0.6470631385072011, loss_age_kl=1.5101720884514476] 
100%|██████████| 4493/4493 [02:52<00:00, 26.08 batch/s]


tensor(4.0790, dtype=torch.float64) tensor(0., dtype=torch.float64) tensor(2.9210, dtype=torch.float64)


100%|██████████| 3369/3369 [15:12<00:00,  3.69 batch/s, loss_age=4.242238, loss_age_group=0.69957255253239, loss_age_kl=1.6982743737867865]    
100%|██████████| 4493/4493 [02:47<00:00, 26.77 batch/s]


tensor(3.8755, dtype=torch.float64) tensor(0.2590, dtype=torch.float64) tensor(3.1245, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [15:23<00:00,  3.65 batch/s, loss_age=3.6837072, loss_age_group=0.6784702101753103, loss_age_kl=1.5806730910984392]
100%|██████████| 4493/4493 [02:50<00:00, 26.33 batch/s]


tensor(3.8550, dtype=torch.float64) tensor(0.0663, dtype=torch.float64) tensor(3.1450, dtype=torch.float64)


100%|██████████| 3369/3369 [15:16<00:00,  3.67 batch/s, loss_age=3.000021, loss_age_group=0.6307333173830566, loss_age_kl=1.4267006080022582] 
100%|██████████| 4493/4493 [02:50<00:00, 26.29 batch/s]


tensor(3.8304, dtype=torch.float64) tensor(0., dtype=torch.float64) tensor(3.1696, dtype=torch.float64)


100%|██████████| 3369/3369 [14:37<00:00,  3.84 batch/s, loss_age=2.967871, loss_age_group=0.6173589501647027, loss_age_kl=1.3431954748833768]  
100%|██████████| 4493/4493 [02:46<00:00, 27.00 batch/s]


tensor(3.8357, dtype=torch.float64) tensor(0.3571, dtype=torch.float64) tensor(3.1643, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [14:05<00:00,  3.99 batch/s, loss_age=3.920508, loss_age_group=0.7198921237830948, loss_age_kl=1.568295526443661]  
100%|██████████| 4493/4493 [02:42<00:00, 27.64 batch/s]


tensor(3.8058, dtype=torch.float64) tensor(0.4395, dtype=torch.float64) tensor(3.1942, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [13:54<00:00,  4.04 batch/s, loss_age=3.9172468, loss_age_group=0.6811447372082948, loss_age_kl=1.5307950513210977] 
100%|██████████| 4493/4493 [02:37<00:00, 28.44 batch/s]


tensor(3.7836, dtype=torch.float64) tensor(0.1419, dtype=torch.float64) tensor(3.2164, dtype=torch.float64)


100%|██████████| 3369/3369 [13:59<00:00,  4.01 batch/s, loss_age=3.8942373, loss_age_group=0.6926710724984875, loss_age_kl=1.5614178987449847] 
100%|██████████| 4493/4493 [02:42<00:00, 27.66 batch/s]


tensor(3.7698, dtype=torch.float64) tensor(0., dtype=torch.float64) tensor(3.2302, dtype=torch.float64)


100%|██████████| 3369/3369 [14:18<00:00,  3.92 batch/s, loss_age=2.9378834, loss_age_group=0.6069531303178753, loss_age_kl=1.3756416783662304] 
100%|██████████| 4493/4493 [02:40<00:00, 28.05 batch/s]


tensor(3.5788, dtype=torch.float64) tensor(0.0182, dtype=torch.float64) tensor(3.4212, dtype=torch.float64)


100%|██████████| 3369/3369 [14:03<00:00,  3.99 batch/s, loss_age=2.9381008, loss_age_group=0.5229229438908762, loss_age_kl=1.3445473710091462] 
100%|██████████| 4493/4493 [02:42<00:00, 27.72 batch/s]


tensor(3.3131, dtype=torch.float64) tensor(0.5953, dtype=torch.float64) tensor(3.6869, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [14:09<00:00,  3.97 batch/s, loss_age=3.618049, loss_age_group=0.5592546575840189, loss_age_kl=1.4151100221493556]  
100%|██████████| 4493/4493 [02:39<00:00, 28.23 batch/s]


tensor(3.1567, dtype=torch.float64) tensor(1.1241, dtype=torch.float64) tensor(3.8433, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [13:54<00:00,  4.04 batch/s, loss_age=2.9230523, loss_age_group=0.4775690204886981, loss_age_kl=1.26395688219803]   
100%|██████████| 4493/4493 [02:42<00:00, 27.72 batch/s]


tensor(2.8435, dtype=torch.float64) tensor(1.4805, dtype=torch.float64) tensor(4.1565, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [14:02<00:00,  4.00 batch/s, loss_age=2.6927166, loss_age_group=0.43904586449318295, loss_age_kl=1.2371398903825395]
100%|██████████| 4493/4493 [02:40<00:00, 27.92 batch/s]


tensor(2.7705, dtype=torch.float64) tensor(3.2809, dtype=torch.float64) tensor(4.5908, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [14:01<00:00,  4.00 batch/s, loss_age=2.5156503, loss_age_group=0.4365818533242489, loss_age_kl=1.151235849977433]  
100%|██████████| 4493/4493 [02:43<00:00, 27.43 batch/s]


tensor(2.7501, dtype=torch.float64) tensor(2.5628, dtype=torch.float64) tensor(4.2499, dtype=torch.float64)


100%|██████████| 3369/3369 [14:23<00:00,  3.90 batch/s, loss_age=2.2466235, loss_age_group=0.37897717600940894, loss_age_kl=1.105387573051918] 
100%|██████████| 4493/4493 [02:42<00:00, 27.67 batch/s]


tensor(2.6655, dtype=torch.float64) tensor(3.2214, dtype=torch.float64) tensor(4.6065, dtype=torch.float64)


100%|██████████| 3369/3369 [14:20<00:00,  3.91 batch/s, loss_age=3.0767837, loss_age_group=0.5402176190202812, loss_age_kl=1.316104232414391]  
100%|██████████| 4493/4493 [02:43<00:00, 27.46 batch/s]


tensor(2.6531, dtype=torch.float64) tensor(3.0934, dtype=torch.float64) tensor(4.5061, dtype=torch.float64)


100%|██████████| 3369/3369 [14:13<00:00,  3.95 batch/s, loss_age=2.5697117, loss_age_group=0.39265989563863457, loss_age_kl=1.1554645665990142]
100%|██████████| 4493/4493 [02:42<00:00, 27.58 batch/s]


tensor(2.6164, dtype=torch.float64) tensor(2.7873, dtype=torch.float64) tensor(4.3836, dtype=torch.float64)


100%|██████████| 3369/3369 [14:13<00:00,  3.95 batch/s, loss_age=2.567331, loss_age_group=0.42029057867508507, loss_age_kl=1.1247398699448805] 
100%|██████████| 4493/4493 [02:40<00:00, 28.06 batch/s]


tensor(2.4879, dtype=torch.float64) tensor(3.6091, dtype=torch.float64) tensor(4.8477, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [13:40<00:00,  4.11 batch/s, loss_age=2.228742, loss_age_group=0.32643292303768523, loss_age_kl=1.0008824148745645] 
100%|██████████| 4493/4493 [02:43<00:00, 27.56 batch/s]


tensor(2.4999, dtype=torch.float64) tensor(3.9076, dtype=torch.float64) tensor(5.1366, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [14:46<00:00,  3.80 batch/s, loss_age=2.0654926, loss_age_group=0.3440079900443026, loss_age_kl=0.990021445372302]  
100%|██████████| 4493/4493 [02:35<00:00, 28.88 batch/s]


tensor(2.3515, dtype=torch.float64) tensor(3.7933, dtype=torch.float64) tensor(5.1091, dtype=torch.float64)


100%|██████████| 3369/3369 [14:10<00:00,  3.96 batch/s, loss_age=1.5980439, loss_age_group=0.23236247612995625, loss_age_kl=0.814830174363359] 
100%|██████████| 4493/4493 [02:43<00:00, 27.43 batch/s]


tensor(2.3348, dtype=torch.float64) tensor(4.0360, dtype=torch.float64) tensor(5.2931, dtype=torch.float64)
Saved model


100%|██████████| 3369/3369 [14:23<00:00,  3.90 batch/s, loss_age=1.6790231, loss_age_group=0.28259544935052927, loss_age_kl=0.8802935455864149]
100%|██████████| 4493/4493 [02:42<00:00, 27.73 batch/s]


tensor(2.3195, dtype=torch.float64) tensor(4.1414, dtype=torch.float64) tensor(5.3608, dtype=torch.float64)
Saved model


In [4]:
####################################################
EPOCHS = 12
####################################################

# Second stage: rebuild the data pipeline with extra augmented samples added
# to the training split, reload the stage-1 checkpoints and fine-tune only the
# fc0 layers. Note that EPOCHS is rebound here for the second stage.

#########################
import pandas as pd
df = CSVUtils.get_df_from_csv("./training_caip_contest.csv", "./training_caip_contest/")
# Same test_size/random_state as the first stage, so the validation rows are
# the same ones used there.
df_train, df_val = train_test_split(df, test_size=0.25, random_state=42)
aug = CSVUtils.get_df_from_csv("./augumentation_simple.csv", "./newAugmentationDataset/")
# The augmented rows are appended to the training split only.
df_train = pd.concat([df_train, aug], ignore_index=True)
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)
#########################

from torchvision import transforms
import torch

# Training pipeline: resize, random crop/flip, RandAugment, then tensor
# conversion and per-channel normalisation.
transform_func = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(2, 9),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# Validation pipeline: deterministic resize only.
transform_func_val = transforms.Compose([
    transforms.Resize(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

label_map_v = CSVUtils.get_label_map_vector()
cd_train = AgeGroupKLAndAgeDatasetKL(df_train, path_col="path", label_col="age", label_function="Linear", 
                                    label_map_vector=label_map_v, transform_func=transform_func)

cd_val = StandardDataset(df_val, path_col="path", label_col="age", label_function="CAE", transform_func=transform_func_val)
cd_val.set_n_classes(81)
cd_val.set_starting_class(1)
validator = Validator(cd_val, AgeConversion.EVAge, 32, num_workers=6, prefetch_factor=4)

# Unlike the first stage, batches come from get_balanced_class_dataloader,
# configured with eight age ranges.
dm_train = CustomDataLoader(cd_train)
dl_train = dm_train.get_balanced_class_dataloader(class_ranges=[(0, 11), (11, 21), (21, 31), (31, 41), (41, 51), (51, 61), (61, 71), (71, 91)], 
                                                  batch_size=128, num_workers=16, prefetch_factor=4, pin_memory=True)

# Restore the best stage-1 weights; the age-group model is frozen entirely.
model_age_group.load_state_dict(torch.load("./model_age_group_film_feature.pt"))
model_age_group.requires_grad_(False)
model_age.load_state_dict(torch.load("./model_age_film_feature.pt"))

# Hand the optimizer only fc0 parameters that are still trainable, in a
# deterministic, de-duplicated order (torch.optim requires an ordered
# collection, which set() is not). Frozen parameters are excluded on purpose:
# they may still carry a stale .grad from the first stage, and on PyTorch
# versions where zero_grad() zeroes instead of clearing gradients, SGD's
# weight decay and momentum would keep modifying those "frozen" weights.
fc0_params = list(
    {
        id(param): param
        for param in (*model_age_group.fc0.parameters(), *model_age.fc0.parameters())
        if param.requires_grad
    }.values()
)
opt = optim.SGD(fc0_params, lr=0.1, weight_decay=5e-4)
scheduler = optim.lr_scheduler.OneCycleLR(opt, max_lr=0.1, steps_per_epoch=len(dl_train), epochs=EPOCHS, three_phase=True)

In [5]:
def forward_function(x):
    """Return softmaxed age-model outputs for the batch ``x``.

    A uniform 1/8 prior (one row per sample) is fed to the age-group model;
    the softmax of its output conditions the age model, whose output is
    softmaxed and returned.
    """
    uniform_prior = torch.full((len(x), 8), 0.125, dtype=torch.float32).to("cuda")
    group_logits = model_age_group(x, uniform_prior)
    age_logits = model_age(x, F.softmax(group_logits, dim=-1))
    return F.softmax(age_logits, dim=-1)


# Baseline scores of the reloaded checkpoints before the second stage starts;
# `mae` and `val_aar` are reused by the training cell below.
mae, val_aar, val_aar_old = validator.validate_ext_fixed(forward_function)
print(mae, val_aar, val_aar_old)

100%|██████████| 4493/4493 [02:49<00:00, 26.56 batch/s]


tensor(2.3195, dtype=torch.float64) tensor(4.1794, dtype=torch.float64) tensor(5.3762, dtype=torch.float64)


In [6]:
def forward_function(x):
    """Return softmaxed age-model outputs for the batch ``x``.

    The age-group model receives a uniform 1/8 prior; the softmax of its
    output conditions the age model, whose output is softmaxed and returned.
    """
    knowledge_age_group = torch.tensor([[0.125]*8]*len(x), requires_grad=False).float().to("cuda")
    knowledge = model_age_group(x, knowledge_age_group)
    knowledge = F.softmax(knowledge, dim=-1)
    out = model_age(x, knowledge)
    out = F.softmax(out, dim=-1)
    return out


# Start from the baseline AAR measured just before this stage, so a checkpoint
# is written only when fine-tuning actually beats the stage-1 model.
best_val_aar = val_aar
knowledge_age_group = torch.tensor([[0.125]*8]*128, requires_grad=False).float().to("cuda")

for e in range(EPOCHS):
    with tqdm(dl_train, unit=" batch") as tepoch:
        for batch in tepoch:
            opt.zero_grad()
            x, y = batch
            x = x.to("cuda")
            y_age_group = y[0].to("cuda")
            y_age = y[1].to("cuda")
            y_age_kl: torch.Tensor = y[2].to("cuda")

            # Slice the prior to the real batch size so a short final batch
            # cannot cause a shape mismatch; full batches are unaffected.
            prior = knowledge_age_group[: x.shape[0]]
            knowledge = model_age_group.forward_detach(x, prior)

            knowledge = F.softmax(knowledge, dim=-1)
            out_age = model_age(x, knowledge)
            # Computed for monitoring only; it is not part of the loss below.
            loss_age_kl: torch.Tensor = kl(F.log_softmax(out_age, dim=-1), y_age_kl)

            out = F.softmax(out_age, dim=-1)
            out = AgeConversion.EVAge(out).to("cuda")
            loss_age: torch.Tensor = torch.mean(torch.abs(out - y_age))

            # Penalise the batch only when its MAE exceeds the baseline
            # validation MAE. clamp() keeps the loss attached to the graph
            # when the margin is non-positive; the previous
            # `torch.tensor(0)` fallback had no grad_fn, so loss.backward()
            # raised as soon as a batch did better than the baseline.
            loss = torch.square(torch.clamp(loss_age - mae, min=0)) # + loss_age_kl
            loss.backward()
            opt.step()
            scheduler.step()

            tepoch.set_postfix(loss_age=loss_age.detach().cpu().numpy(),
                                loss_age_kl=loss_age_kl.detach().cpu().numpy())

    # Score with the same validator method that produced the baseline
    # `best_val_aar`; validate_ext reports a different AAR for identical
    # weights, which made the "is this better?" comparison inconsistent.
    mae_, val_aar, val_aar_old = validator.validate_ext_fixed(forward_function)
    print(mae_, val_aar, val_aar_old)

    if val_aar > best_val_aar:
        best_val_aar = val_aar
        torch.save(model_age_group.state_dict(), "./model_age_group_film_classifier.pt")
        torch.save(model_age.state_dict(), "./model_age_film_classifier.pt")
        print("Saved model")

 34%|███▎      | 1401/4165 [02:35<04:28, 10.31 batch/s, loss_age=3.2480974, loss_age_kl=4.147272398326127] 