In [1]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from pathlib import Path
from typing import Union, Optional
import zarr
import torch
from torch import nn
from tqdm import tqdm
import itertools

from histaug.data import SlideDataset
from histaug.utils import slide_loader
from histaug.augmentations import load_augmentations, Augmentations
from histaug.feature_extractors import load_feature_extractor
from histaug.extract_features.augmented_feature_extractor import AugmentedFeatureExtractor

In [2]:
# Configuration for this run: where the slides live and how they are batched/loaded.
SLIDE_DIR = "/data/shiprec/TCGA-BRCA-DX"
BATCH_SIZE = 256
NUM_WORKERS = 8

# Feature extractor (CTransPath weights) and the set of augmentations to compare.
model = load_feature_extractor("ctranspath")
augmentations = load_augmentations()

# SlideDataset receives the batch size itself (presumably yielding pre-batched items),
# so automatic batching in the DataLoader is disabled with batch_size=None.
ds = SlideDataset(SLIDE_DIR, batch_size=BATCH_SIZE)
patch_dataloader = DataLoader(
    ds,
    batch_size=None,
    shuffle=False,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

# Regroup the patch stream per slide; iterated below as (slide, patch_loader) pairs.
loader = slide_loader(iter(patch_dataloader))

[32m2023-09-12 13:52:22.761[0m | [1mINFO    [0m | [36mhistaug.feature_extractors.utils[0m:[36mdownload_file[0m:[36m20[0m - [1mSkipping download of https://drive.google.com/u/0/uc?id=1DoDx_70_TLj98gTf6YTXnu4tFhsFocDX&export=download to /app/weights/ctranspath.pth as file already exists[0m
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
[32m2023-09-12 13:52:23.975[0m | [1mINFO    [0m | [36mhistaug.augmentations[0m:[36m__init__[0m:[36m31[0m - [1mFitting Macenko normalizer to /app/normalization_template.jpg[0m


In [3]:
def _concat_batches(batches: list):
    """Merge per-batch metadata (labels or file names) into a single object.

    Returns ``None`` if every entry is ``None`` (or ``batches`` is empty),
    ``torch.cat(batches)`` if every entry is a tensor, and otherwise a flat list
    obtained by iterating each batch entry in order.
    """
    if all(batch is None for batch in batches):
        return None
    if all(isinstance(batch, torch.Tensor) for batch in batches):
        return torch.cat(batches)
    return [item for batch in batches for item in batch]


def process_dataset(
    loader,
    model: nn.Module,
    augmentations: Augmentations,
    device="cuda",
    n_batches: Optional[int] = None,
):
    """Run ``AugmentedFeatureExtractor`` over ``loader`` and collect all features on the CPU.

    Args:
        loader: Iterable of ``(imgs, labels, files)`` triples. ``imgs`` must support
            ``.to(device)``; ``labels``/``files`` may be ``None``, tensors, or per-item
            sequences.
        model: Feature extractor wrapped by ``AugmentedFeatureExtractor``.
        augmentations: Augmentations passed to ``AugmentedFeatureExtractor``; iterating
            it yields the names used as keys of the returned augmented-feature dict.
        device: Device the extractor and each image batch are moved to.
        n_batches: If given, only the first ``n_batches`` batches are processed.

    Returns:
        ``(feats, feats_augs, labels, files)``: ``feats`` concatenates the first output of
        the extractor over all batches; ``feats_augs`` maps each augmentation name to its
        concatenated features; ``labels``/``files`` are merged across all processed batches
        (see ``_concat_batches``) so they stay aligned with ``feats``.

    Raises:
        ValueError: If ``loader`` yields no batches.

    Note:
        The wrapper is switched to eval mode; if ``model`` is registered as a submodule of
        it (likely, but not visible here), ``model`` is left in eval mode as well.
    """
    augmented_feature_extractor = AugmentedFeatureExtractor(model, augmentations)
    augmented_feature_extractor.to(device)
    # torch.no_grad() alone does not switch dropout / batch-norm to inference behaviour.
    augmented_feature_extractor.eval()

    all_feats = []
    all_feats_augs = {aug_name: [] for aug_name in augmentations}
    all_labels = []
    all_files = []

    with torch.no_grad():
        for imgs, labels, files in tqdm(
            itertools.islice(loader, n_batches), desc="Processing dataset", total=n_batches
        ):
            feats, feats_augs = augmented_feature_extractor(imgs.to(device))

            # Move results to the CPU per batch so device memory does not grow with the dataset.
            all_feats.append(feats.detach().cpu())
            for aug_name, feats_aug in feats_augs.items():
                all_feats_augs[aug_name].append(feats_aug.detach().cpu())
            # Keep metadata for every batch (not just the last) so it aligns with all_feats.
            all_labels.append(labels)
            all_files.append(files)

    if not all_feats:
        raise ValueError("loader yielded no batches; nothing to process")

    return (
        torch.cat(all_feats),
        {aug_name: torch.cat(chunks) for aug_name, chunks in all_feats_augs.items()},
        _concat_batches(all_labels),
        _concat_batches(all_files),
    )


# Process only the first slide as a quick check. `loader` wraps a single iter(...) from the
# setup cell, so it is presumably consumed as it goes: re-running this cell likely moves on
# to the next slide rather than repeating the first.
first_slide = next(iter(loader), None)
if first_slide is None:
    raise RuntimeError("slide loader yielded no slides")
slide, patch_loader = first_slide

result = process_dataset(
    # patch_loader is a generator of (patch, slide, coords); only the patch is needed,
    # so labels and files are passed as None.
    loader=((x[0], None, None) for x in patch_loader),
    model=model,
    augmentations=augmentations,
    device="cuda",
)

Processing dataset: 2it [00:29, 14.62s/it]

In [None]:
# process_dataset returns (feats, feats_augs, labels, files). Show tensor shapes instead of
# the raw tensors so the saved notebook output stays small.
result_feats, result_feats_augs, *_ = result
{
    "feats": tuple(result_feats.shape),
    "feats_augs": {name: tuple(t.shape) for name, t in result_feats_augs.items()},
}