In [1]:
import os
import glob
import copy
from tqdm import tqdm
from typing import List

import torch
import torch.nn as nn
import torchvision as tv
from torch.utils.data import Subset, Dataset, DataLoader

from source.constants import RESULTS_PATH
from source.data.face_detection import get_fair_face, get_utk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
method_seeds = [42, 142, 242, 342, 442]
dseed = 42

model = ["resnet18", "resnet34", "resnet50"][0]

target = 3 # 0, 1, 2, 3

device = "cuda:7"
batch_size = 2048 # 1024 ~ 10GB VRAM / 2048 ~ 15GB VRAM for resnet18 ~ 1GB more for resnet34, 1024 ~ 17GB for resnet50

In [3]:
ff_train_ds, ff_test_ds = get_fair_face(target=target, binarize=True, augment=False)
utk_test_ds = get_utk(target=target, binarize=True)

run_path = os.path.join(RESULTS_PATH, f"fairface_target{target}_{model}_mseed{method_seeds[0]}_dseed{dseed}")
fair_inds = torch.load(os.path.join(run_path, "fair_inds.pt"))
val_inds = torch.load(os.path.join(run_path, "val_inds.pt"))

fair_ds = Subset(ff_train_ds, indices=fair_inds)
val_ds = Subset(ff_train_ds, indices=val_inds)

In [4]:
@torch.no_grad()
def evaluate(networks: List, ds: Dataset):
    probits = list()
    for network in tqdm(networks):
        network.eval()
        probits_ = list()
        for x, _, _ in DataLoader(ds, batch_size = batch_size, shuffle=False, drop_last=False):
            x = x.to(device)

            probits_.append(torch.softmax(network.forward(x), dim=1).cpu())
        probits_ = torch.concat(probits_, dim=0)
        probits.append(probits_)
    return torch.stack(probits, dim=0)


In [5]:
for mseed in method_seeds:

    path = os.path.join(RESULTS_PATH, f"fairface_target{target}_{model}_mseed{mseed}_dseed{dseed}")

    # load networks
    networks = list()
    model_files = glob.glob(os.path.join(path, "models", "*.pt"))
    for model_file in sorted(model_files):
        if model == "resnet18":
            network = tv.models.resnet18(weights=None) 
            network.fc = nn.Linear(in_features=512, out_features=2)
        elif model == "resnet34":
            network = tv.models.resnet34(weights=None) 
            network.fc = nn.Linear(in_features=512, out_features=2)
        elif model == "resnet50":
            network = tv.models.resnet50(weights=None) 
            network.fc = nn.Linear(in_features=2048, out_features=2)

        network.load_state_dict(torch.load(model_file, map_location=device))
        network.to(device)
        networks.append(copy.deepcopy(network))

    # evaluate
    torch.save(evaluate(networks, fair_ds), os.path.join(path, f"fair_probits_t{target}.pt"))
    torch.save(evaluate(networks, val_ds), os.path.join(path, f"val_probits_t{target}.pt"))
    torch.save(evaluate(networks, ff_test_ds), os.path.join(path, f"ff_test_probits_t{target}.pt"))
    torch.save(evaluate(networks, utk_test_ds), os.path.join(path, f"utk_test_probits_t{target}.pt"))
    print("Evaluated method seed", mseed)

100%|██████████| 10/10 [05:04<00:00, 30.43s/it]
100%|██████████| 10/10 [04:42<00:00, 28.28s/it]
100%|██████████| 10/10 [04:16<00:00, 25.66s/it]
100%|██████████| 10/10 [13:08<00:00, 78.82s/it]


Evaluated method seed 42


100%|██████████| 10/10 [03:49<00:00, 22.99s/it]
100%|██████████| 10/10 [04:29<00:00, 26.99s/it]
100%|██████████| 10/10 [04:15<00:00, 25.55s/it]
100%|██████████| 10/10 [11:41<00:00, 70.13s/it]


Evaluated method seed 142


100%|██████████| 10/10 [03:47<00:00, 22.78s/it]
100%|██████████| 10/10 [04:21<00:00, 26.10s/it]
100%|██████████| 10/10 [03:52<00:00, 23.21s/it]
100%|██████████| 10/10 [12:21<00:00, 74.12s/it]


Evaluated method seed 242


100%|██████████| 10/10 [04:27<00:00, 26.71s/it]
100%|██████████| 10/10 [04:09<00:00, 24.97s/it]
100%|██████████| 10/10 [04:33<00:00, 27.36s/it]
100%|██████████| 10/10 [14:19<00:00, 85.92s/it]


Evaluated method seed 342


100%|██████████| 10/10 [05:13<00:00, 31.34s/it]
100%|██████████| 10/10 [04:19<00:00, 25.94s/it]
100%|██████████| 10/10 [04:20<00:00, 26.01s/it]
100%|██████████| 10/10 [14:36<00:00, 87.61s/it]

Evaluated method seed 442



