In [30]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import wandb
from loguru import logger
from torchmetrics.classification import MulticlassAUROC
from tqdm import tqdm

# W&B API client; used below to list the runs of the "histaug" project.
api = wandb.Api()

# Root of the histaug data tree. Override with the HISTAUG_DATA_DIR environment
# variable instead of editing this absolute path.
DATA_DIR = Path(os.environ.get("HISTAUG_DATA_DIR", "/data/histaug"))
TRAIN_DIR = DATA_DIR / "train"  # per-run outputs, searched as TRAIN_DIR/*/<run id>
BOOTSTRAPS_DIR = DATA_DIR / "bootstraps"  # cached bootstrap patient resamples
N_BOOTSTRAPS = 1000  # number of bootstrap resamples drawn per (dataset, target)

In [39]:
def filter_runs(runs, filters: dict):
    """Keep only the runs whose attributes equal every value in ``filters``.

    Each key of ``filters`` is read from the run with ``getattr``; a missing
    attribute is treated as ``None``. A run is kept only when every attribute
    compares equal to its expected value. Input order is preserved, and a new
    list is returned.
    """

    def _matches(candidate) -> bool:
        # Stop at the first attribute that does not match, like all() does.
        for attr_name, expected in filters.items():
            if not (getattr(candidate, attr_name, None) == expected):
                return False
        return True

    return list(filter(_matches, runs))


# All runs of the W&B "histaug" project that reached the "finished" state.
runs = filter_runs(list(api.runs("histaug")), {"state": "finished"})
if not runs:
    raise RuntimeError('No finished runs found in W&B project "histaug"')
# NOTE(review): only the first finished run is analysed; the order is whatever
# the W&B API returns — confirm this is the intended run.
run = runs[0]

# Dataset name plus the column/classes of the run's first configured target.
target_cfg = run.config["dataset"]["targets"][0]
dataset = run.config["dataset"]["name"]
column = target_cfg["column"]
classes = target_cfg["classes"]

In [33]:
# Locate this run's output directory (TRAIN_DIR/*/<run id>) and load its
# per-patient test predictions.
run_dirs = sorted(TRAIN_DIR.glob(f"*/{run.id}"))
if not run_dirs:
    raise FileNotFoundError(f"No output directory for run {run.id} under {TRAIN_DIR}")
# NOTE(review): if several directories match, the lexicographically first is
# used (deterministic, unlike raw glob order) — confirm duplicates are expected.
preds_csv = run_dirs[0] / "test-patient-preds.csv"
df = pd.read_csv(preds_csv)

In [35]:
# Bootstrap resamples of test patients, cached on disk. The cache key is only
# (dataset, target column), so every run with the same dataset/target reuses them.
BOOTSTRAP_SEED = 0  # makes the resamples reproducible whenever the cache is (re)created
bootstrap_csv = BOOTSTRAPS_DIR / f"{dataset}_{column}.csv"

if not bootstrap_csv.exists():
    logger.debug(f"Caching bootstraps for {dataset} {column} at {bootstrap_csv}")
    patients = df.PATIENT.unique()
    rng = np.random.default_rng(BOOTSTRAP_SEED)
    # One row per resample; each row draws len(patients) patient IDs with replacement.
    bootstraps = rng.choice(patients, size=(N_BOOTSTRAPS, len(patients)), replace=True)
    BOOTSTRAPS_DIR.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(bootstraps).to_csv(bootstrap_csv, index=False, header=False)

# Parse IDs with the same dtype as df.PATIENT. Otherwise each CSV column is
# type-inferred on its own and IDs may no longer match the predictions table.
bootstraps = pd.read_csv(bootstrap_csv, header=None, dtype=df.PATIENT.dtype).values

# Guard against a stale cache built from a different patient set than this run's.
unknown_patients = set(bootstraps.flat) - set(df.PATIENT)
assert not unknown_patients, (
    f"{bootstrap_csv} contains {len(unknown_patients)} patients "
    "not present in this run's test predictions"
)
bootstraps

[32m2023-10-23 15:18:29.258[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [34m[1mCaching bootstraps for tcga_brca_CDH1 CDH1 at /data/histaug/bootstraps/tcga_brca_CDH1_CDH1.csv[0m


array([['11BR072', '18BR009', '11BR047', ..., '11BR030', '01BR008',
        '15BR003'],
       ['11BR017', '01BR030', '11BR074', ..., '18BR006', '03BR004',
        '11BR047'],
       ['11BR080', '11BR010', '15BR003', ..., '20BR008', '11BR030',
        '11BR014'],
       ...,
       ['11BR004', '01BR031', '01BR027', ..., '11BR017', '11BR055',
        '11BR017'],
       ['18BR007', '18BR004', '11BR031', ..., '21BR001', '01BR018',
        '11BR025'],
       ['09BR005', '05BR029', '05BR043', ..., '11BR011', '20BR001',
        '01BR033']], dtype=object)

In [44]:
# AUROC metric over all target classes; the class count comes from the run config.
auroc = MulticlassAUROC(num_classes=len(classes))

# Per-class prediction scores, one column per class in `classes` order.
# NOTE(review): in the recorded output each row sums to ~1, so these look like
# probabilities rather than raw logits — confirm before treating them as logits.
logits = df[[f"{column}_{c}" for c in classes]].values
logits

array([[0.9506246 , 0.04937536],
       [0.9032602 , 0.09673987],
       [0.87391293, 0.12608702],
       [0.8999622 , 0.10003781],
       [0.89624465, 0.10375532],
       [0.86639905, 0.13360101],
       [0.9312646 , 0.06873545],
       [0.9017733 , 0.09822673],
       [0.9122821 , 0.08771793],
       [0.9316866 , 0.06831342],
       [0.87694997, 0.12305004],
       [0.9310366 , 0.06896345],
       [0.9086642 , 0.09133577],
       [0.9278037 , 0.07219631],
       [0.94013137, 0.05986866],
       [0.8417314 , 0.15826859],
       [0.93442774, 0.06557228],
       [0.96265984, 0.03734019],
       [0.9346734 , 0.06532662],
       [0.93234724, 0.06765275],
       [0.96529526, 0.0347047 ],
       [0.8705813 , 0.12941866],
       [0.9141491 , 0.08585089],
       [0.82569885, 0.17430112],
       [0.92543775, 0.07456227],
       [0.87243253, 0.1275675 ],
       [0.94585985, 0.05414008],
       [0.9375587 , 0.06244127],
       [0.9177427 , 0.0822573 ],
       [0.9545906 , 0.04540936],
       [0.

In [42]:
# Inspect the loaded test predictions. The recorded output shows the columns
# PATIENT, <column>, the per-class scores <column>_<class>, and <column>_pred.
df

Unnamed: 0,PATIENT,CDH1,CDH1_0,CDH1_1,CDH1_pred
0,01BR001,0,0.950625,0.049375,0
1,01BR008,0,0.903260,0.096740,0
2,01BR009,0,0.873913,0.126087,0
3,01BR010,0,0.899962,0.100038,0
4,01BR015,0,0.896245,0.103755,0
...,...,...,...,...,...
115,21BR001,0,0.971703,0.028297,0
116,21BR002,0,0.966159,0.033841,0
117,21BR010,0,0.931586,0.068414,0
118,22BR005,0,0.946921,0.053079,0
