In [1]:
import os
import glob
import copy
from tqdm import tqdm
from typing import List

import torch
import torch.nn as nn
import torchvision as tv
from torch.utils.data import Subset, Dataset, DataLoader

from source.constants import RESULTS_PATH
from source.data.medical_imaging import get_chexpert, TransformWrapper

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
method_seeds = [42, 142, 242, 342, 442]
dseed = 42

model = ["resnet18", "resnet34", "resnet50"][1]

device = "cuda:7"
batch_size = 2048 # 1024 ~ 10GB VRAM / 2048 ~ 15GB VRAM for resnet18 ~ 1GB more for resnet34, 1024 ~ 17GB for resnet50

In [3]:
train_ds, r_val_ds, r_test_ds = get_chexpert()

run_path = os.path.join(RESULTS_PATH, f"chexpert_{model}_mseed{method_seeds[0]}_dseed{dseed}")
fair_inds = torch.load(os.path.join(run_path, "fair_inds.pt"))
val_inds = torch.load(os.path.join(run_path, "val_inds.pt"))

fair_ds = TransformWrapper(Subset(train_ds, indices=fair_inds))
val_ds = TransformWrapper(Subset(train_ds, indices=val_inds))

# patients general 65401
# patients with race 58010
loading images to RAM
images loaded to RAM
loading images to RAM
images loaded to RAM
loading images to RAM
images loaded to RAM


In [4]:
@torch.no_grad()
def evaluate(networks: List, ds: Dataset):
    probits = list()
    for network in tqdm(networks):
        network.eval()
        probits_ = list()
        for x, _, _ in DataLoader(ds, batch_size = batch_size, shuffle=False, drop_last=False):
            x = x.to(device)

            probits_.append(torch.softmax(network.forward(x), dim=1).cpu())
        probits_ = torch.concat(probits_, dim=0)
        probits.append(probits_)
    return torch.stack(probits, dim=0)


In [5]:
for mseed in method_seeds:

    path = os.path.join(RESULTS_PATH, f"chexpert_{model}_mseed{mseed}_dseed{dseed}")

    # load networks
    networks = list()
    model_files = glob.glob(os.path.join(path, "models", "*.pt"))
    for model_file in sorted(model_files):
        if model == "resnet18":
            network = tv.models.resnet18(weights=None) 
            network.fc = nn.Linear(in_features=512, out_features=2)
        elif model == "resnet34":
            network = tv.models.resnet34(weights=None) 
            network.fc = nn.Linear(in_features=512, out_features=2)
        elif model == "resnet50":
            network = tv.models.resnet50(weights=None) 
            network.fc = nn.Linear(in_features=2048, out_features=2)

        network.load_state_dict(torch.load(model_file, map_location=device))
        network.to(device)
        networks.append(copy.deepcopy(network))

    # evaluate
    torch.save(evaluate(networks, fair_ds), os.path.join(path, f"fair_probits.pt"))
    torch.save(evaluate(networks, val_ds), os.path.join(path, f"val_probits.pt"))
    torch.save(evaluate(networks, r_val_ds), os.path.join(path, f"r_val_probits.pt"))
    torch.save(evaluate(networks, r_test_ds), os.path.join(path, f"r_test_probits.pt"))
    print("Evaluated method seed", mseed)

100%|██████████| 10/10 [06:50<00:00, 41.09s/it]
100%|██████████| 10/10 [06:30<00:00, 39.04s/it]
100%|██████████| 10/10 [00:03<00:00,  3.25it/s]
100%|██████████| 10/10 [00:08<00:00,  1.16it/s]


Evaluated method seed 42


100%|██████████| 10/10 [06:33<00:00, 39.34s/it]
100%|██████████| 10/10 [09:02<00:00, 54.21s/it]
100%|██████████| 10/10 [00:03<00:00,  3.18it/s]
100%|██████████| 10/10 [00:10<00:00,  1.05s/it]


Evaluated method seed 142


100%|██████████| 10/10 [07:30<00:00, 45.07s/it]
100%|██████████| 10/10 [07:48<00:00, 46.87s/it]
100%|██████████| 10/10 [00:03<00:00,  3.05it/s]
100%|██████████| 10/10 [00:09<00:00,  1.05it/s]


Evaluated method seed 242


100%|██████████| 10/10 [08:08<00:00, 48.90s/it]
100%|██████████| 10/10 [07:35<00:00, 45.50s/it]
100%|██████████| 10/10 [00:03<00:00,  3.23it/s]
100%|██████████| 10/10 [00:09<00:00,  1.01it/s]


Evaluated method seed 342


100%|██████████| 10/10 [06:54<00:00, 41.50s/it]
100%|██████████| 10/10 [07:17<00:00, 43.71s/it]
100%|██████████| 10/10 [00:03<00:00,  3.30it/s]
100%|██████████| 10/10 [00:09<00:00,  1.09it/s]

Evaluated method seed 442



