# Demo for using Dataloaders

In [None]:
# local imports
from load_dataset import MOVi_Dataset, MOVi_ImageDataset

In [None]:
from torch.utils.data import DataLoader

## Load dataset

In [None]:
# Root path of the data; passed to MOVi_Dataset as `root` below.
# Replace with the location of your own copy.
ROOT_PATH = "/p/lustre2/marcou1/dsc2025/dsc25_data"


In [None]:
# Video dataset: every sample is one object, with all of its frames
movi_ds = MOVi_Dataset(
    root=ROOT_PATH,
    split='test',
    n_frames=8,
    n_samples=30,
)

In [None]:
# Pull one sample and print the shape (and, for masks/content, the dtype) of
# each tensor it holds. The trailing comments give the expected layout.
sample = next(iter(movi_ds))
print('frame', sample['frames'].shape)  # 3 chans, 8 frames, 256x256 pixels
print('depth', sample['depths'].shape)  # 1 chan (float), 8 frames, 256x256 pixels
print('mmasks', sample['modal_masks'].shape, sample['modal_masks'].dtype)  # 1 chan (int), 8 frames, 256x256 pixels
# The dtype is read from amodal_masks itself, not from modal_masks.
print('amasks', sample['amodal_masks'].shape, sample['amodal_masks'].dtype)  # 1 chan (int), 8 frames, 256x256 pixels
print('acontent', sample['amodal_content'].shape, sample['amodal_content'].dtype)  # 3 chans (float), 8 frames, 256x256 pixels

# Kept as the cell's last expression so the notebook displays the whole sample.
sample

## Batch and pass to model

In [None]:
# Wrap the dataset in a torch DataLoader to get mini-batches
dataloader = DataLoader(
    dataset=movi_ds,
    batch_size=4,   # pick whatever batch size suits you
    shuffle=True,   # shuffled order, as wanted for training
    num_workers=1,  # values >0 speed up loading when CPU cores are available
)

In [None]:
# Build a fresh iterator over the loader and pull its first batch.
# Because the loader was created with shuffle=True, the batch contents vary between runs.
next(iter(dataloader)) # this gives one batch

In [None]:
# Number of batches per pass over the data.
# 8 = ceil(30 / 4), assuming the dataset holds n_samples=30 items (batch_size=4; the last batch is smaller) - TODO confirm
len(dataloader) # 8 batches

### Iterate over batches

In [None]:
# Iterate over batches
# EXAMPLE: RGB content (frames) and modal masks are the model input,
# the amodal mask is the target.
# This is a reduced form of Task 2.1 - not using the amodal content
#
# NOTE: `model` and `loss_fn` are placeholders; define them before running
# this cell.
import torch  # required for torch.cat below; no earlier cell imports it

for batch in dataloader:
    # Select features from the batch (B)
    # Inputs
    frames = batch['frames']         # [B, 3, n_frames, H, W]
    modal_masks = batch['modal_masks']  # [B, 1, n_frames, H, W]

    # Output (target)
    amodal_masks = batch['amodal_masks']  # [B, 1, n_frames, H, W]

    # Combine inputs if needed (e.g., concatenate along the channel dimension).
    # Whether to concatenate at all depends on the model architecture!
    # Example: stack frames and modal_masks into one input; the mask is cast
    # to float so both tensors share a dtype for torch.cat.
    inputs = torch.cat([frames, modal_masks.float()], dim=1)  # [B, 4, n_frames, H, W]

    # Now you can pass `inputs` to the model and use `amodal_masks` as the target
    # NOTE(review): amodal_masks is not cast here; if it is an integer tensor
    # and loss_fn expects a float target, cast it first - confirm for your loss.
    output = model(inputs)
    loss = loss_fn(output, amodal_masks)
    loss.backward()
    # The rest of the training step (e.g. zeroing gradients and the optimizer
    # update) is left out of this example.
    ...