# Testing Dataset Utilities

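This notebook tests the updated `dataset_utils.py` implementation (imported as `src.transforms.dataset_utils`) on a small synthetic RGB dataset. It covers the module's two entry points:
1. `process_dataset_to_streams`: converts the whole dataset up front
2. `create_dataloader_with_streams`: converts each batch on the fly inside a `DataLoader`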

In [None]:
# Import necessary libraries (stdlib first, then third-party)
import pickle
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset

# Add the project root to sys.path so the `src` package below is importable.
# The kernel's working directory is not necessarily the project root (e.g. when
# this notebook is opened from a subdirectory), so walk up from the working
# directory to the first ancestor that contains `src/transforms`, falling back
# to the working directory itself if none does.
start_dir = Path('.').resolve()
project_root = next(
    (candidate for candidate in (start_dir, *start_dir.parents)
     if (candidate / 'src' / 'transforms').is_dir()),
    start_dir,
)
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Import our utilities
from src.transforms.dataset_utils import process_dataset_to_streams, create_dataloader_with_streams
from src.transforms.rgb_to_rgbl import RGBtoRGBL  # NOTE(review): not referenced by any cell below

print("✅ Imports successful")

In [None]:
# Create a small synthetic RGB dataset for testing.
# Seed torch's default generator so a Restart & Run All reproduces the same
# images and labels; later shuffling that relies on the default generator is
# then reproducible too.
SEED = 42
torch.manual_seed(SEED)

num_samples = 100
num_classes = 10
input_shape = (3, 32, 32)  # (channels, height, width) RGB images

# Random RGB tensors with values in [0, 1) and random integer class labels
rgb_data = torch.rand(num_samples, *input_shape)
labels = torch.randint(0, num_classes, (num_samples,))

# Wrap as a dataset whose items are (rgb_image, label) tuples
test_dataset = TensorDataset(rgb_data, labels)

print(f"✅ Created synthetic test dataset with {len(test_dataset)} samples")
print(f"   RGB data shape: {rgb_data.shape}")
print(f"   Labels shape: {labels.shape}")

# Show a sample image; imshow expects (H, W, C), so move channels last
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(rgb_data[0].permute(1, 2, 0).numpy())
ax.set_title("Sample Synthetic RGB Image")
ax.axis('off')
plt.show()

In [None]:
# Test process_dataset_to_streams: convert the whole dataset up front into an
# RGB stream, a brightness stream and a labels tensor.
print("Testing process_dataset_to_streams...")
rgb_stream, brightness_stream, labels_tensor = process_dataset_to_streams(
    test_dataset,
    batch_size=20,
    desc="Processing test data"
)

print(f"RGB stream shape: {rgb_stream.shape}")
print(f"Brightness stream shape: {brightness_stream.shape}")
print(f"Labels shape: {labels_tensor.shape}")

# Assertions make this cell fail loudly on wrong output instead of only
# printing shapes.
assert rgb_stream.shape == (num_samples, *input_shape), \
    f"unexpected RGB stream shape {tuple(rgb_stream.shape)}"
# The brightness stream is indexed below as [sample][channel] -> 2-D image, so
# it needs one entry per sample and the same spatial size as the RGB stream.
assert brightness_stream.shape[0] == num_samples, \
    f"expected {num_samples} brightness samples, got {brightness_stream.shape[0]}"
assert brightness_stream.shape[-2:] == rgb_stream.shape[-2:], \
    "brightness and RGB streams have different spatial sizes"
assert len(labels_tensor) == num_samples, \
    f"expected {num_samples} labels, got {len(labels_tensor)}"
assert torch.isfinite(brightness_stream).all(), "brightness stream contains NaN/inf"

# Visualize the first RGB image next to its brightness channel.
# .cpu() guards against the utilities returning tensors on an accelerator,
# where .numpy() would raise.
fig, (ax_rgb, ax_brightness) = plt.subplots(1, 2, figsize=(12, 5))

ax_rgb.imshow(rgb_stream[0].cpu().permute(1, 2, 0).numpy())
ax_rgb.set_title("RGB Image")
ax_rgb.axis('off')

ax_brightness.imshow(brightness_stream[0, 0].cpu().numpy(), cmap='gray')
ax_brightness.set_title("Brightness Channel")
ax_brightness.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Test create_dataloader_with_streams: the RGB -> (RGB, brightness) split is
# done on the fly, batch by batch, while iterating the DataLoader.
print("Testing create_dataloader_with_streams...")
batch_size = 16
dataloader = create_dataloader_with_streams(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,  # Use 0 workers for this test
    pin_memory=False
)

# With num_workers=0 no worker process is started, so the collate function is
# never pickled here. Multi-worker loading under the 'spawn' start method
# (default on macOS/Windows) does pickle it, so check that directly. This
# raises if collate_fn is still a nested/local function.
pickle.dumps(dataloader.collate_fn)

# Get a batch
print("Getting a batch from the dataloader...")
rgb_batch, brightness_batch, labels_batch = next(iter(dataloader))

print(f"RGB batch shape: {rgb_batch.shape}")
print(f"Brightness batch shape: {brightness_batch.shape}")
print(f"Labels batch shape: {labels_batch.shape}")

# The dataset has more than `batch_size` samples, so the first batch is full.
assert rgb_batch.shape == (batch_size, *input_shape), \
    f"unexpected RGB batch shape {tuple(rgb_batch.shape)}"
assert brightness_batch.shape[0] == batch_size, \
    f"expected {batch_size} brightness images, got {brightness_batch.shape[0]}"
assert brightness_batch.shape[-2:] == rgb_batch.shape[-2:], \
    "brightness and RGB batches have different spatial sizes"
assert labels_batch.shape == (batch_size,), \
    f"unexpected labels batch shape {tuple(labels_batch.shape)}"

# Visualize up to the first 4 images of the batch (RGB on top, brightness
# below). squeeze=False keeps `axes` 2-D even when only one column is shown.
n_show = min(4, len(rgb_batch))
fig, axes = plt.subplots(2, n_show, figsize=(15, 8), squeeze=False)

for i in range(n_show):
    class_id = labels_batch[i].item()

    axes[0, i].imshow(rgb_batch[i].cpu().permute(1, 2, 0).numpy())
    axes[0, i].set_title(f"RGB - Class: {class_id}")
    axes[0, i].axis('off')

    axes[1, i].imshow(brightness_batch[i, 0].cpu().numpy(), cmap='gray')
    axes[1, i].set_title(f"Brightness - Class: {class_id}")
    axes[1, i].axis('off')

plt.tight_layout()
plt.show()

print("\n✅ All tests completed successfully!")

## Conclusion

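The updated implementation of `dataset_utils.py` runs end to end on the synthetic data above:

1. **Fixed Pickling Issue**: `collate_with_streams` has been moved out of `create_dataloader_with_streams`, so it is no longer a nested function. Nested functions cannot be pickled, and pickling is needed when the DataLoader sends work to multiple worker processes. Note that the DataLoader test above uses `num_workers=0`, which starts no worker processes. Rerun it with `num_workers > 0` to exercise this fix end to end.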

2. **Two Processing Options**:
   - `process_dataset_to_streams`: Processes the entire dataset upfront (good for smaller datasets)
   - `create_dataloader_with_streams`: Processes data on-the-fly (best for large datasets)

3. **Memory Efficiency**: The on-the-fly processing with `create_dataloader_with_streams` is more memory-efficient for large datasets.