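### Testing VisionPatchDataModule.py

Smoke test for `rio.data.VisionPatchDataModule`. The notebook builds the data module from raw data and from preprocessed data, and checks the sizes of the training and validation splits. It then fetches a first training and validation batch and inspects the training batch.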

#### Setup

In [1]:
# Automatically reload imported modules (such as `radio`, imported below) before each
# cell executes, so edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt

# Very tall default canvas, presumably sized for the multi-panel output of `plot_batch`.
plt.rcParams.update({'figure.figsize': [20, 40]})

In [3]:
import os

import radio as rio
from radio.data.datautils import get_first_batch, plot_batch

#### Test with Raw Data

In [37]:
data = rio.data.VisionPatchDataModule(
    # Images to load: two intensity modalities and no label maps.
    intensities=["T1", "FLAIR"],
    labels=[],
    # Raw data: resample while loading, no augmentation.
    resample=True,
    use_augmentation=False,
    # Patch sampling: 96x96x1 patches, (presumably) 10 drawn per volume,
    # with the patch queue capped at 300 entries.
    patch_size=(96, 96, 1),
    samples_per_volume=10,
    queue_max_length=300,
    batch_size=32,
)

In [38]:
# Run the data-module lifecycle for the "fit" stage; afterwards the split sizes
# (`size_train` / `size_val`) can be queried.
data.prepare_data()
data.setup(stage="fit")

In [39]:
print(f"Training: {data.size_train} patches.")
print(f"Validation: {data.size_val} patches.")

# Sanity check: both splits must contain patches before batches are drawn from them.
assert data.size_train > 0, "training split is empty"
assert data.size_val > 0, "validation split is empty"

Training: 4952 patches.
Validation: 1238 patches.


#### Get a Training and a Validation Batch

In [7]:
# Each call returns an indexable collection of dataloaders; only the first entry is used below.
train_dataloaders, val_dataloaders = data.train_dataloader(), data.val_dataloader()

In [8]:
# Drawing the first batch is slow when resampling is required (resample=True here).
# Preprocessing the data ahead of time is therefore recommended to speed up training.
train_batch, val_batch = (
    get_first_batch(loaders[0]) for loaders in (train_dataloaders, val_dataloaders)
)

In [9]:
print(f"Batch keys: {train_batch.keys()}")
print(f"Sample shape: {train_batch['T1']['data'].shape}")
print(f"Image keys: {train_batch['T1'].keys()}")
print(f"Subject IDs: {train_batch['subj_id']}")
print(f"Scan IDs: {train_batch['scan_id']}")

# Sanity checks on the batch layout; the printed shape is presumably (batch, channel, *patch_size).
t1_shape = train_batch['T1']['data'].shape
# Both requested intensity modalities should share the same patch shape.
assert train_batch['FLAIR']['data'].shape == t1_shape, "T1 and FLAIR patch shapes differ"
# Spatial dims must match the patch_size=(96, 96, 1) the data module was built with.
assert tuple(t1_shape[-3:]) == (96, 96, 1), f"unexpected patch shape {tuple(t1_shape)}"
# Exactly one subject ID and one scan ID per sample in the batch.
assert len(train_batch['subj_id']) == len(train_batch['scan_id']) == t1_shape[0]

Batch keys: dict_keys(['subj_id', 'scan_id', 'T1', 'FLAIR', 'location'])
Sample shape: torch.Size([32, 1, 96, 96, 1])
Image keys: dict_keys(['data', 'affine', 'path', 'stem', 'type'])
Subject IDs: ['ABD_IH_0066', 'ABD_AJ_0052', 'ABD_SS_0021', 'ABD_AJ_0079', 'ABD_IH_0100', 'ABD_AJ_0099', 'ABD_AJ_0155', 'ABD_IH_0066', 'ABD_AJ_0105', 'ABD_AJ_0234', 'ABD_SW_0052', 'ABD_AJ_0022', 'ABD_IH_0066', 'ABD_SS_0021', 'ABD_SS_0090', 'ABD_AJ_0022', 'ABD_SW_0084', 'ABD_AJ_0052', 'ABD_SW_0175', 'ABD_AJ_0054', 'ABD_AJ_0105', 'ABD_SW_0150', 'ABD_AJ_0054', 'ABD_IH_0103', 'ABD_SW_0150', 'ABD_IH_0100', 'ABD_IH_0100', 'ABD_AJ_0079', 'ABD_BS_0009', 'ABD_AJ_0052', 'ABD_AJ_0138', 'ABD_SW_0062']
Scan IDs: ['scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan002', 'scan001', 'scan001', 'scan001', 'scan002', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'sc

#### Plot Example Samples from the Training Batch

In [None]:
# Visual spot-check of the first training batch. The trailing `;` suppresses the repr
# of whatever `plot_batch` returns, so only the rendered figure is shown.
plot_batch(train_batch);

#### Test with Preprocessed Data

In [22]:
# Data module over data that was preprocessed ahead of time (use_preprocessing=False).
# The data root can be overridden with the RADIO_DATA_ROOT environment variable, so the
# notebook also runs on machines without ~/LocalCerebro/Studies; the default is unchanged.
data = rio.data.VisionPatchDataModule(
    root=os.environ.get('RADIO_DATA_ROOT', '~/LocalCerebro/Studies'),
    data_dir='Public/preprocessed_data',
    use_augmentation=False,
    use_preprocessing=False,
    batch_size=32,
    intensities=["T1", "FLAIR"],
    labels=[],
    patch_size=(96, 96, 1),
    queue_max_length=300,
    samples_per_volume=10,
)

DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7.
DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7.
DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7.
DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.


In [23]:
# Run the data-module lifecycle for the "fit" stage; afterwards the split sizes
# (`size_train` / `size_val`) can be queried.
data.prepare_data()
data.setup(stage="fit")

DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7.
DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7.


In [24]:
print(f"Training: {data.size_train} patches.")
print(f"Validation: {data.size_val} patches.")

# Sanity check: both splits must contain patches before batches are drawn from them.
assert data.size_train > 0, "training split is empty"
assert data.size_val > 0, "validation split is empty"

Training: 4960 patches.
Validation: 1230 patches.


#### Get a Training and a Validation Batch

In [25]:
# Each call returns an indexable collection of dataloaders; only the first entry is used below.
train_dataloaders, val_dataloaders = data.train_dataloader(), data.val_dataloader()

In [26]:
# This data was preprocessed ahead of time (data_dir='Public/preprocessed_data',
# use_preprocessing=False), so drawing the first batches should be much faster than
# in the raw-data section above.
train_batch = get_first_batch(train_dataloaders[0])
val_batch = get_first_batch(val_dataloaders[0])

In [27]:
print(f"Batch keys: {train_batch.keys()}")
print(f"Sample shape: {train_batch['T1']['data'].shape}")
print(f"Image keys: {train_batch['T1'].keys()}")
print(f"Subject IDs: {train_batch['subj_id']}")
print(f"Scan IDs: {train_batch['scan_id']}")

# Sanity checks on the batch layout; the printed shape is presumably (batch, channel, *patch_size).
t1_shape = train_batch['T1']['data'].shape
# Both requested intensity modalities should share the same patch shape.
assert train_batch['FLAIR']['data'].shape == t1_shape, "T1 and FLAIR patch shapes differ"
# Spatial dims must match the patch_size=(96, 96, 1) the data module was built with.
assert tuple(t1_shape[-3:]) == (96, 96, 1), f"unexpected patch shape {tuple(t1_shape)}"
# Exactly one subject ID and one scan ID per sample in the batch.
assert len(train_batch['subj_id']) == len(train_batch['scan_id']) == t1_shape[0]

Batch keys: dict_keys(['subj_id', 'scan_id', 'T1', 'FLAIR', 'location'])
Sample shape: torch.Size([32, 1, 96, 96, 1])
Image keys: dict_keys(['data', 'affine', 'path', 'stem', 'type'])
Subject IDs: ['ABD_SW_0088', 'ABD_SW_0034', 'ABD_AJ_0112', 'ABD_AJ_0112', 'ABD_SW_0088', 'ABD_SW_0052', 'ABD_IH_0111', 'ABD_SW_0010', 'ABD_AJ_0009', 'ABD_AJ_0003', 'ABD_AJ_0230', 'ABD_SW_0069', 'ABD_IH_0088', 'ABD_IH_0007', 'ABD_AJ_0003', 'ABD_SW_0037', 'ABD_AJ_0230', 'ABD_AJ_0134', 'ABD_GJ_0146', 'ABD_BS_0026', 'ABD_IH_0007', 'ABD_SW_0111', 'ABD_AJ_0003', 'ABD_IH_0004', 'ABD_SW_0052', 'ABD_SW_0052', 'ABD_AJ_0048', 'ABD_AJ_0139', 'ABD_IH_0007', 'ABD_SS_0069', 'ABD_GJ_0146', 'ABD_GJ_0182']
Scan IDs: ['scan001', 'scan001', 'scan002', 'scan002', 'scan001', 'scan001', 'scan001', 'scan002', 'scan001', 'scan001', 'scan001', 'scan002', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan002', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'sc