In [26]:
# Import modular components
import os
# Set before any project import: the embedding code presumably loads a HuggingFace
# tokenizer, and this silences the tokenizers parallelism warning that fires when
# multiprocessing DataLoader workers are forked.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import logging
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch

# Add src directory to path for absolute imports.
# Go up 3 levels from the notebook's working directory:
#   notebooks/ -> dataloader/ -> ml_workflow/ -> src/
src_path = Path().absolute().parent.parent.parent
# Guard the insert so re-running this cell does not stack duplicate entries on sys.path.
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

from ml_workflow.constants import (GCS_BUCKET_NAME, GCS_METADATA_PATH, GCS_IMAGE_PREFIX, 
                                   LOCAL_METADATA_PATH, LOCAL_DATA_DIR, IMG_ID_COL)
from ml_workflow.utils import load_metadata, logger, stratified_split
from ml_workflow.dataloader import (create_dataloaders, ImageDataset, 
                                    get_basic_transform, get_train_transform, get_test_valid_transform)
from ml_workflow.dataloader.embedding_utils import load_or_compute_embeddings

# Set logging level on the project logger imported from ml_workflow.utils
logger.setLevel(logging.INFO)
print("✓ All modules imported successfully!")

✓ All modules imported successfully!


In [27]:
# Use a very small sample for testing (1% of data) to speed up embedding computation.
# NOTE(review): this cell relies on `metadata_gcs` and `img_prefix` being defined by
# earlier cells (not shown here); it will fail on a fresh kernel unless those run first.
SEED = 42           # shared by the per-class sampling below and the train/test split config
sample_frac = 0.01  # 1% of data for quick testing
metadata_gcs_small = (metadata_gcs.groupby('label', group_keys=False)
                      .sample(frac=sample_frac, random_state=SEED)
                      .reset_index(drop=True))
logger.info(f"After sampling: {len(metadata_gcs_small):,} samples ({sample_frac*100:.1f}% of {len(metadata_gcs):,} total)")

# Filter out classes with too few samples for stratified splitting.
# For test_size=0.2, we need at least 5 samples per class (1 in test, 4 in train).
min_samples_per_class = 5
label_counts = metadata_gcs_small['label'].value_counts()
valid_labels = label_counts[label_counts >= min_samples_per_class].index
metadata_gcs_small = metadata_gcs_small[metadata_gcs_small['label'].isin(valid_labels)].reset_index(drop=True)
dropped_classes = len(label_counts) - len(valid_labels)
logger.info(f"Filtered out {dropped_classes} classes with < {min_samples_per_class} samples")
logger.info(f"Final dataset: {len(metadata_gcs_small):,} samples across {len(valid_labels)} classes")
# Fail here with a clear message rather than deep inside the split / DataLoader code.
assert len(valid_labels) > 0, (
    f"No class has >= {min_samples_per_class} samples at sample_frac={sample_frac}; "
    "increase sample_frac or lower min_samples_per_class"
)

# Compute/load embeddings (REQUIRED for multi-modal model).
# Prefer CUDA, then Apple MPS, then CPU. `torch.backends.mps.is_available()` is the
# documented MPS availability check and works across PyTorch versions.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
logger.info(f"Computing/loading embeddings on device: {device}")

# Embedding configuration (adjust as needed)
# Use a local path for small test embeddings to avoid GCS overhead
embedding_path = "embeddings_test_small.parquet"  # Local file for quick testing
embeddings = load_or_compute_embeddings(
    data=metadata_gcs_small[[IMG_ID_COL, 'text_desc']],
    path=embedding_path,
    model_name='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext',  # pubmedbert
    batch_size=32,
    max_length=512,
    device=device,
    pooling_strategy='mean',
    qwen_instr=""
)

# Merge embeddings with metadata (required for dataloader).
# `indicator=True` lets us verify that every sampled image received an embedding: a plain
# left merge silently leaves NaN embeddings for any image missing from `embeddings`
# (e.g. if the file at `embedding_path` was cached from a different sample —
# presumably load_or_compute_embeddings reuses an existing file; TODO confirm).
metadata_gcs_small = metadata_gcs_small.merge(
    embeddings, how="left", on=IMG_ID_COL, validate="1:1", indicator=True
)
missing_embeddings = int((metadata_gcs_small["_merge"] != "both").sum())
assert missing_embeddings == 0, (
    f"{missing_embeddings:,} of {len(metadata_gcs_small):,} samples have no embedding; "
    f"the cached file {embedding_path!r} may be stale — delete it and re-run"
)
metadata_gcs_small = metadata_gcs_small.drop(columns="_merge")
logger.info(f"Embeddings merged. Metadata shape: {metadata_gcs_small.shape}")

# Create config dictionaries for GCS dataloader
# Note: create_dataloaders expects data_config, training_config, and augmentation_config
data_config = {
    'use_local': False, 
    'img_prefix': '',
    'test_size': 0.2,
    'val_size': None,
    'seed': SEED,
    'img_size': [224, 224]
}

training_config = {
    'batch_size': 32, 
    'num_workers': 8, 
    'prefetch_factor': 2,
    'compute_stats': True,
    'weighted_sampling': True
}

augmentation_config = {
    'brightness_jitter': 0.1,
    'contrast_jitter': 0.1,
    'saturation_jitter': 0.1,
    'hue_jitter': 0.05,
    'rotation_degrees': 20,
    'translate': [0.1, 0.1],
    'scale': [0.9, 1.1],
    'grayscale_prob': 0.1,
    'horizontal_flip_prob': 0.5,
    'vertical_flip_prob': 0.5
}

# Create dataloaders (works with both GCS and local data)
# NOTE(review): `img_prefix` comes from an earlier cell, not from data_config['img_prefix'] ('').
train_loader_gcs, val_loader_gcs, test_loader_gcs, info_gcs = create_dataloaders(
    metadata_df=metadata_gcs_small,
    img_prefix=img_prefix,  # Uses GCS or local path based on use_local flag
    data_config=data_config,
    training_config=training_config,
    augmentation_config=augmentation_config
)

print(f"\n{'='*50}")
print("GCS DataLoader (sampled) Created!")
print(f"{'='*50}")
print(f"Classes: {info_gcs['num_classes']}")
print(f"Training samples: {info_gcs['train_size']:,}")
print(f"Test samples: {info_gcs['test_size']:,}")
print(f"Train batches: {len(train_loader_gcs)}")
print(f"Test batches: {len(test_loader_gcs)}")
print(f"Mean: {[f'{m:.4f}' for m in info_gcs['mean']]}")
print(f"Std: {[f'{s:.4f}' for s in info_gcs['std']]}")
print(f"{'='*50}")

2025-11-11 15:57:03,483 - ml_workflow.utils - INFO - After sampling: 1,918 samples (1.0% of 191,523 total)
2025-11-11 15:57:03,485 - ml_workflow.utils - INFO - Filtered out 32 classes with < 5 samples
2025-11-11 15:57:03,485 - ml_workflow.utils - INFO - Final dataset: 1,849 samples across 49 classes
2025-11-11 15:57:03,486 - ml_workflow.utils - INFO - Computing/loading embeddings on device: mps
2025-11-11 15:57:25,071 - ml_workflow.utils - INFO - Embeddings merged. Metadata shape: (1849, 23)
2025-11-11 15:57:25,072 - ml_workflow.utils - INFO - Creating DataLoaders
2025-11-11 15:57:25,075 - ml_workflow.utils - INFO - Train samples: 1,479, Test samples: 370
2025-11-11 15:57:25,077 - ml_workflow.utils - INFO - Total classes: 49
2025-11-11 15:57:25,077 - ml_workflow.utils - INFO - Computing dataset statistics from all training data
2025-11-11 15:57:25,105 - ml_workflow.utils - INFO - Computing dataset statistics...
Computing stats: 100%|██████████| 47/47 [01:07<00:00,  1.43s/it]
2025-11-11


GCS DataLoader (sampled) Created!
Classes: 49
Training samples: 1,479
Test samples: 370
Train batches: 46
Test batches: 12
Mean: ['0.5920', '0.4656', '0.4251']
Std: ['0.1773', '0.1639', '0.1615']
