### Testing VisionPatchDataModule.py

#### Test with Raw Data

In [1]:
# Automatically reload imported modules (e.g. the `radio` package under test)
# before each cell runs, so edits to the data module code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
# Default matplotlib canvas size as (width, height) in inches; made very tall,
# presumably so the multi-sample grid drawn by plot_batch stays legible.
plt.rcParams.update({'figure.figsize': [20, 40]})

In [3]:
import os

import radio as rio
from radio.data.datamodules import BrainAgingPredictionDataModule
from radio.data.datautils import get_first_batch, plot_batch

#### Create the Patch DataModule on Raw Data (no preprocessing, no augmentation)

In [4]:
# Root directory of the study data. The default is an absolute path on the
# author's workstation; set the RADIO_DATA_ROOT environment variable to run
# this notebook on another machine without editing the cell.
DATA_ROOT = os.environ.get(
    'RADIO_DATA_ROOT',
    '/media/cerebro/Workspaces/Students/Eduardo_Diniz/Studies',
)

# Raw-data configuration: preprocessing and augmentation are both disabled.
data = rio.data.BrainAgingPredictionPatchDataModule(
    root=DATA_ROOT,
    data_dir='processed_data',
    step='step01_structural_processing',
    use_augmentation=False,
    use_preprocessing=False,
    batch_size=32,
    intensities=["T1", "FLAIR"],  # image modalities; each appears as a key in the batch dict
    labels=[],                    # no label maps requested
    patch_size=(96, 96, 1),       # one-voxel-thick, i.e. single-slice 2-D patches
    queue_max_length=300,         # presumably the patch queue capacity — confirm in the datamodule
    samples_per_volume=10,        # presumably patches drawn per volume — confirm in the datamodule
)

DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7.
DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7.
DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7.
DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.


In [5]:
# DataModule lifecycle: one-off data preparation first, then build the
# training/validation splits needed for the "fit" stage.
fit_stage = "fit"
data.prepare_data()
data.setup(stage=fit_stage)

DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7.
DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7.


In [6]:
# Report how many patches each split exposes after setup().
split_sizes = {"Training": data.size_train, "Validation": data.size_val}
for split_name, n_patches in split_sizes.items():
    print(f"{split_name}: {n_patches} patches.")

Training: 4960 patches.
Validation: 1230 patches.


#### Get a Train and a Validation Batch

In [7]:
# Build one loader per split from the configured DataModule.
train_dataloader, val_dataloader = data.train_dataloader(), data.val_dataloader()

In [8]:
# Pull the first batch from each loader. This is slow when volumes have to be
# resampled on the fly, so preprocessing beforehand is recommended to speed up
# training.
train_batch, val_batch = (
    get_first_batch(loader) for loader in (train_dataloader, val_dataloader)
)

In [11]:
# Summarise the structure of one training batch, then check its invariants
# so the notebook fails loudly instead of relying on eyeballing the printout.
t1_patches = train_batch['T1']['data']
print(f"Batch keys: {train_batch.keys()}")
print(f"Sample shape: {t1_patches.shape}")
print(f"Image keys: {train_batch['T1'].keys()}")
print(f"Subject IDs: {train_batch['subj_id']}")
print(f"Scan IDs: {train_batch['scan_id']}")
print(f"Location Shape: {train_batch['location'].shape}")

# Every modality is cropped with the same fixed patch_size, so all image
# tensors should share one shape; the per-sample metadata must have one entry
# per patch along the batch dimension.
n_patches = t1_patches.shape[0]
for modality in ("T1", "FLAIR"):
    assert train_batch[modality]['data'].shape == t1_patches.shape, (
        f"{modality} patch shape differs from T1"
    )
for meta_key in ("subj_id", "scan_id"):
    assert len(train_batch[meta_key]) == n_patches, f"{meta_key} length != batch size"
assert train_batch['location'].shape[0] == n_patches, "location rows != batch size"

Batch keys: dict_keys(['subj_id', 'scan_id', 'T1', 'FLAIR', 'location'])
Sample shape: torch.Size([32, 1, 96, 96, 1])
Image keys: dict_keys(['data', 'affine', 'path', 'stem', 'type'])
Subject IDs: ['ABD_SS_0078', 'ABD_AJ_0144', 'ABD_AJ_0128', 'ABD_SS_0054', 'ABD_BS_0009', 'ABD_AJ_0202', 'ABD_SS_0054', 'ABD_AJ_0128', 'ABD_SW_0034', 'ABD_SS_0076', 'ABD_SW_0040', 'ABD_SS_0054', 'ABD_AJ_0192', 'ABD_AJ_0029', 'ABD_AJ_0128', 'ABD_AJ_0192', 'ABD_SS_0060', 'ABD_AJ_0141', 'ABD_AJ_0192', 'ABD_SS_0078', 'ABD_SS_0064', 'ABD_AJ_0030', 'ABD_AJ_0029', 'ABD_SS_0060', 'ABD_SS_0064', 'ABD_SW_0069', 'ABD_SS_0068', 'ABD_AJ_0144', 'ABD_AJ_0116', 'ABD_AJ_0141', 'ABD_AJ_0116', 'ABD_AJ_0029']
Scan IDs: ['scan001', 'scan002', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan002', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan001', 'scan002', 'scan002', 'scan001', 'scan001', 'scan002', 'scan001', 'scan002', 'sc

#### Plot Example Samples from the Training Batch

In [None]:
# Visual check of the training patches fetched above.
batch_to_plot = train_batch
plot_batch(batch_to_plot)