In [36]:
import torch
import numpy as np
import pytorch_lightning as pl
import hydra
from hydra.core.global_hydra import GlobalHydra
import os

os.environ["HYDRA_FULL_ERROR"] = "1"

GlobalHydra().clear()
hydra.initialize(config_path="../conf", version_base="1.3")
cfg = hydra.compose(
    "config.yaml",
    overrides=[
        "+experiment=brca_CDH1",
        "+feature_extractor=swav",
        # "+feature_extractor=dino_p16",
        "augmentations@dataset.augmentations=none",
        "model=attmil",
        "seed=1",
        "+magnification=high",
        "dataset.num_workers=20",
    ],
)

In [37]:
from histaug.data import FeatureDataset
from histaug.train import load_dataset_df, get_folds, TargetEncoder

from torch.utils.data import DataLoader

dataset_df = load_dataset_df(cfg.dataset)
folds = get_folds(cfg, dataset_df)
dataset_df

Unnamed: 0_level_0,CDH1,path
PATIENT,Unnamed: 1_level_1,Unnamed: 2_level_1
TCGA-3C-AALI,0,[/raid/histaug/features/tcga_brca_mpp0.5/swav/...
TCGA-3C-AALJ,0,[/raid/histaug/features/tcga_brca_mpp0.5/swav/...
TCGA-3C-AALK,0,[/raid/histaug/features/tcga_brca_mpp0.5/swav/...
TCGA-4H-AAAK,1,[/raid/histaug/features/tcga_brca_mpp0.5/swav/...
TCGA-5L-AAT0,1,[/raid/histaug/features/tcga_brca_mpp0.5/swav/...
...,...,...
TCGA-WT-AB44,1,[/raid/histaug/features/tcga_brca_mpp0.5/swav/...
TCGA-XX-A899,1,[/raid/histaug/features/tcga_brca_mpp0.5/swav/...
TCGA-XX-A89A,1,[/raid/histaug/features/tcga_brca_mpp0.5/swav/...
TCGA-Z7-A8R5,1,[/raid/histaug/features/tcga_brca_mpp0.5/swav/...


In [38]:
crossval_fold = 1

valid_mask = folds == crossval_fold
train_items, valid_items = folds.index[~valid_mask], folds.index[valid_mask]
train_df, valid_df = dataset_df.loc[train_items], dataset_df.loc[valid_items]
assert not (
    overlap := set(train_df.index) & set(valid_df.index)
), f"unexpected overlap between training and testing set: {overlap}"

encoders = {target.column: TargetEncoder.for_target(target) for target in cfg.dataset.targets}
train_targets = {t: encoder.fit(train_df) for t, encoder in encoders.items()}
valid_targets = {t: encoder(valid_df) for t, encoder in encoders.items()}

train_ds = FeatureDataset(
    patient_ids=train_df.index,
    bags=train_df.path.values,
    targets=train_targets,
    instances_per_bag=cfg.dataset.instances_per_bag,
    augmentations=cfg.dataset.augmentations.train,
)
train_dl = DataLoader(
    train_ds,
    batch_size=cfg.dataset.batch_size,
    num_workers=cfg.dataset.num_workers,
    shuffle=True,
    pin_memory=True,
    collate_fn=train_ds.collate_fn,
)

In [33]:
from time import time

print("Sequential loading using ds[index]")

stard_idx = 30
indices = range(stard_idx, stard_idx + 10)
times = []
for i, index in enumerate(indices):
    start = time()
    train_ds[index]
    elapsed = time() - start
    times.append(elapsed)
    print(f"{i+1:02d}/{len(indices)}: {elapsed:.2f}s")
times = np.array(times)
print(f"took {times.mean():.2f} +- {times.std():.2f}s")

Sequential loading using ds[index]
01/10: 0.40s
02/10: 0.71s
03/10: 0.60s
04/10: 0.43s
05/10: 0.56s
06/10: 0.30s
07/10: 0.64s
08/10: 0.45s
09/10: 1.30s
10/10: 0.62s
took 0.60 +- 0.26s


In [43]:
print("Parallel loading using DataLoader")

times = []
start = time()
for i, batch in enumerate(train_dl):
    if i < 30:
        continue
    elapsed = time() - start
    times.append(elapsed)
    print(f"{i+1:02d}/{len(train_dl)}: {elapsed:.2f}s; shape: {batch[0].shape}")
    start = time()
    if i == 60:
        break

times = np.array(times)
print(f"took {times.mean():.2f} +- {times.std():.2f}s")

Parallel loading using DataLoader
31/754: 5.29s; shape: torch.Size([1, 32768, 2048])
32/754: 0.00s; shape: torch.Size([1, 5673, 2048])
33/754: 0.00s; shape: torch.Size([1, 28741, 2048])
34/754: 0.00s; shape: torch.Size([1, 30070, 2048])
35/754: 0.00s; shape: torch.Size([1, 12585, 2048])
36/754: 0.00s; shape: torch.Size([1, 26808, 2048])
37/754: 0.00s; shape: torch.Size([1, 21309, 2048])
38/754: 0.00s; shape: torch.Size([1, 21490, 2048])
39/754: 0.00s; shape: torch.Size([1, 29897, 2048])
40/754: 0.00s; shape: torch.Size([1, 2013, 2048])
41/754: 0.00s; shape: torch.Size([1, 2223, 2048])
42/754: 0.20s; shape: torch.Size([1, 6990, 2048])
43/754: 0.00s; shape: torch.Size([1, 13929, 2048])
44/754: 1.95s; shape: torch.Size([1, 32768, 2048])
45/754: 0.00s; shape: torch.Size([1, 29133, 2048])
46/754: 0.00s; shape: torch.Size([1, 29140, 2048])
47/754: 0.00s; shape: torch.Size([1, 13647, 2048])
48/754: 0.00s; shape: torch.Size([1, 7736, 2048])
49/754: 0.00s; shape: torch.Size([1, 10924, 2048])
50

In [42]:
batch

[tensor([[[0.0008, 0.0688, 0.3683,  ..., 0.0000, 0.0170, 0.0000],
          [0.0000, 0.1337, 0.4470,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.1109, 0.1455,  ..., 0.0000, 0.0000, 0.0079],
          ...,
          [0.0392, 0.0000, 0.0351,  ..., 0.0058, 0.0000, 0.0052],
          [0.0000, 0.1143, 0.1806,  ..., 0.0136, 0.0000, 0.0000],
          [0.0000, 0.0219, 0.2623,  ..., 0.0045, 0.0000, 0.0000]]]),
 tensor([[[59136, 38976],
          [63168, 32928],
          [14112, 10528],
          ...,
          [70112, 22176],
          [19040, 13888],
          [18592, 35392]]], dtype=torch.int32),
 tensor([[True, True, True,  ..., True, True, True]]),
 {'CDH1': tensor([1])},
 ['TCGA-BH-A209']]