In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
import sys

# Put the RSNA dataset directory on the import path (the `rsna_datasets` import
# in the next cell presumably resolves from here). The path is relative, so it
# is resolved against the kernel's current working directory.
rsna_dataset_dir = os.path.abspath("../../datasets/rsna-2023-abdominal-trauma-detection")

# Guard against duplicates: re-running this cell must not keep growing sys.path.
if rsna_dataset_dir not in sys.path:
    sys.path.append(rsna_dataset_dir)

In [36]:
import matplotlib.pyplot as plt
import torch
from collections import Counter

from rsna_datasets import (
    Classification2DDataset, 
    Classification3DDataset, 
    MaskedClassification2DDataset, 
    MaskedClassification3DDataset, 
    Segmentation2DDataset, 
    Segmentation3DDataset
)

## Classification

Streaming dataset
- Data samples are downloaded on-demand and not saved to disk.
- Set `num_workers` > 0 to use multiprocessing, so that data samples are prefetched and preprocessed in the background and the GPU stays as busy as possible during training.

In [39]:
# Streaming mode: samples are fetched on demand instead of being cached on disk.
cls2d_dataset = Classification2DDataset(split="train", streaming=True)
cls2d_dataloader = torch.utils.data.DataLoader(
    cls2d_dataset,
    batch_size=512,
    num_workers=4,  # > 0 so worker processes prefetch batches in the background
    pin_memory=True,
)

# Peek at the first 20 batches: print tensor shapes and the per-series sample counts.
# NOTE(review): in the recorded output each batch contains a single series_id,
# i.e. samples arrive unshuffled — confirm shuffling is handled before training.
for batch_idx, batch in enumerate(cls2d_dataloader):
    if batch_idx == 20:
        break

    img = batch["img"]
    bowel = batch["bowel"]
    extravasation = batch["extravasation"]
    kidney = batch["kidney"]
    liver = batch["liver"]
    spleen = batch["spleen"]
    any_injury = batch["any_injury"]

    # .tolist() yields plain Python ints, so the dict keys print as `21872`
    # rather than `np.int64(21872)` under NumPy >= 2.
    series_counts = dict(Counter(batch["series_id"].tolist()))
    print(f"batch_idx={batch_idx}")
    print(f"img={list(img.size())}, bowel={list(bowel.size())}, extravasation={list(extravasation.size())}, kidney={list(kidney.size())}, liver={list(liver.size())}, spleen={list(spleen.size())}, any_injury={list(any_injury.size())}")
    print(f"series_counts={series_counts}")
    print()

batch_idx=0
img=[512, 1, 224, 224], bowel=[512], extravasation=[512], kidney=[512], liver=[512], spleen=[512], any_injury=[512]
series_counts={21872: 512}

batch_idx=1
img=[512, 1, 224, 224], bowel=[512], extravasation=[512], kidney=[512], liver=[512], spleen=[512], any_injury=[512]
series_counts={13664: 512}

batch_idx=2
img=[512, 1, 224, 224], bowel=[512], extravasation=[512], kidney=[512], liver=[512], spleen=[512], any_injury=[512]
series_counts={23622: 512}

batch_idx=3
img=[512, 1, 224, 224], bowel=[512], extravasation=[512], kidney=[512], liver=[512], spleen=[512], any_injury=[512]
series_counts={62741: 512}

batch_idx=4
img=[512, 1, 224, 224], bowel=[512], extravasation=[512], kidney=[512], liver=[512], spleen=[512], any_injury=[512]
series_counts={21872: 512}

batch_idx=5
img=[512, 1, 224, 224], bowel=[512], extravasation=[512], kidney=[512], liver=[512], spleen=[512], any_injury=[512]
series_counts={13664: 512}

batch_idx=6
img=[512, 1, 224, 224], bowel=[512], extravasation=[

Unstreamed dataset
- The entire dataset is downloaded to the local cache the first time the code runs.
- Subsequent runs do not re-download the dataset; they read it directly from the cache directory.
- Use this if your machine has enough disk space. (size of classification dataset: 90.09 GiB)

In [40]:
# Non-streaming mode: the data is downloaded to the local cache dir first
# (only needs to happen once); later runs read from that cache.
cls2d_dataset = Classification2DDataset(split="train", streaming=False)
cls2d_dataloader = torch.utils.data.DataLoader(
    cls2d_dataset,
    batch_size=512,
    num_workers=4,  # > 0 so worker processes prefetch batches in the background
    pin_memory=True,
)

# Peek at the first 20 batches: print tensor shapes and the per-series sample counts.
# NOTE(review): in the recorded output each batch contains a single series_id,
# i.e. samples arrive unshuffled — confirm shuffling is handled before training.
for batch_idx, batch in enumerate(cls2d_dataloader):
    if batch_idx == 20:
        break

    img = batch["img"]
    bowel = batch["bowel"]
    extravasation = batch["extravasation"]
    kidney = batch["kidney"]
    liver = batch["liver"]
    spleen = batch["spleen"]
    any_injury = batch["any_injury"]

    # .tolist() yields plain Python ints, so the dict keys print as `21872`
    # rather than `np.int64(21872)` under NumPy >= 2.
    series_counts = dict(Counter(batch["series_id"].tolist()))
    print(f"batch_idx={batch_idx}")
    print(f"img={list(img.size())}, bowel={list(bowel.size())}, extravasation={list(extravasation.size())}, kidney={list(kidney.size())}, liver={list(liver.size())}, spleen={list(spleen.size())}, any_injury={list(any_injury.size())}")
    print(f"series_counts={series_counts}")
    print()

batch_idx=0
img=[512, 1, 224, 224], bowel=[512], extravasation=[512], kidney=[512], liver=[512], spleen=[512], any_injury=[512]
series_counts={21872: 512}

batch_idx=1
img=[512, 1, 224, 224], bowel=[512], extravasation=[512], kidney=[512], liver=[512], spleen=[512], any_injury=[512]
series_counts={13664: 512}

batch_idx=2
img=[512, 1, 224, 224], bowel=[512], extravasation=[512], kidney=[512], liver=[512], spleen=[512], any_injury=[512]
series_counts={23622: 512}

batch_idx=3
img=[512, 1, 224, 224], bowel=[512], extravasation=[512], kidney=[512], liver=[512], spleen=[512], any_injury=[512]
series_counts={62741: 512}

batch_idx=4
img=[512, 1, 224, 224], bowel=[512], extravasation=[512], kidney=[512], liver=[512], spleen=[512], any_injury=[512]
series_counts={21872: 512}

batch_idx=5
img=[512, 1, 224, 224], bowel=[512], extravasation=[512], kidney=[512], liver=[512], spleen=[512], any_injury=[512]
series_counts={13664: 512}

batch_idx=6
img=[512, 1, 224, 224], bowel=[512], extravasation=[

In [None]:
# Slice-level transform options; not yet consumed by any cell visible in this notebook.
slice_transform_configs = dict(
    crop_strategy="random",
    shorter_edge_length=256,
)

## Classification with segmentation masks

## Segmentation