In [1]:
"""
MABe Challenge - Enhanced with Comprehensive Visualizations (Fault-Tolerant)
============================================================================
Complete implementation with detailed analytics and robust error handling
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import itertools
import warnings
import json
import os
import gc
from collections import defaultdict, Counter
from datetime import datetime
import time

# Import visualization libraries with error handling
try:
    import plotly.graph_objects as go
    import plotly.express as px
    from plotly.subplots import make_subplots
    PLOTLY_AVAILABLE = True
except ImportError:
    PLOTLY_AVAILABLE = False
    print("⚠️ Plotly not available - interactive visualizations disabled")

import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle
import matplotlib.patches as mpatches

from sklearn.base import ClassifierMixin, BaseEstimator, clone
from sklearn.model_selection import cross_val_predict, GroupKFold
from sklearn.pipeline import make_pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import f1_score, confusion_matrix, precision_recall_curve
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import polars as pl

import pandas as pd
from pathlib import Path
from typing import Union, List, Optional, Dict, Iterator
from datasets import Dataset, DatasetDict
import pyarrow as pa
import pyarrow.parquet as pq

warnings.filterwarnings('ignore')

# Enhanced visualization settings
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (14, 8)
plt.rcParams['font.size'] = 11
plt.rcParams['axes.titlesize'] = 13
plt.rcParams['axes.labelsize'] = 11
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['ytick.labelsize'] = 10
plt.rcParams['legend.fontsize'] = 10
plt.rcParams['figure.titlesize'] = 15

# Configuration
validate_or_submit = 'submit'
verbose = True
create_visualizations = True
save_plots = True
plot_dir = 'figures'

if save_plots:
    os.makedirs(plot_dir, exist_ok=True)
    print(f"📁 Created visualization directory: {plot_dir}")

# Performance tracking
start_time = time.time()
performance_metrics = {
    'configurations_processed': 0,
    'single_mouse_batches': 0,
    'pair_batches': 0,
    'features_extracted': 0,
    'predictions_made': 0,
    'models_trained': 0,
    'actions_processed': 0
}

print("\n" + "="*80)
print(" "*15 + "🐭 MABe CHALLENGE - ENHANCED VISUALIZATION VERSION 🐭")
print("="*80)
print(f"📅 Start Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"🔧 Mode: {validate_or_submit.upper()}")
print(f"📊 Visualizations: {'ENABLED' if create_visualizations else 'DISABLED'}")
print(f"💾 Save Plots: {'YES' if save_plots else 'NO'}")
print("="*80 + "\n")

  from .autonotebook import tqdm as notebook_tqdm


📁 Created visualization directory: figures

               🐭 MABe CHALLENGE - ENHANCED VISUALIZATION VERSION 🐭
📅 Start Time: 2025-10-02 11:39:01
🔧 Mode: SUBMIT
📊 Visualizations: ENABLED
💾 Save Plots: YES



In [2]:
train = pd.read_csv('data/train.csv')
test = pd.read_csv('data/test.csv')

In [3]:
train.head()

Unnamed: 0,lab_id,video_id,mouse1_strain,mouse1_color,mouse1_sex,mouse1_id,mouse1_age,mouse1_condition,mouse2_strain,mouse2_color,...,pix_per_cm_approx,video_width_pix,video_height_pix,arena_width_cm,arena_height_cm,arena_shape,arena_type,body_parts_tracked,behaviors_labeled,tracking_method
0,AdaptableSnail,44566106,CD-1 (ICR),white,male,10.0,8-12 weeks,wireless device,CD-1 (ICR),white,...,16.0,1228,1068,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""head...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut
1,AdaptableSnail,143861384,CD-1 (ICR),white,male,3.0,8-12 weeks,,CD-1 (ICR),white,...,9.7,968,608,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""late...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut
2,AdaptableSnail,209576908,CD-1 (ICR),white,male,7.0,8-12 weeks,,CD-1 (ICR),white,...,16.0,1266,1100,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""late...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut
3,AdaptableSnail,278643799,CD-1 (ICR),white,male,11.0,8-12 weeks,wireless device,CD-1 (ICR),white,...,16.0,1224,1100,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""head...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut
4,AdaptableSnail,351967631,CD-1 (ICR),white,male,14.0,8-12 weeks,,CD-1 (ICR),white,...,16.0,1204,1068,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""late...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut


In [4]:
test.head()

Unnamed: 0,lab_id,video_id,mouse1_strain,mouse1_color,mouse1_sex,mouse1_id,mouse1_age,mouse1_condition,mouse2_strain,mouse2_color,...,pix_per_cm_approx,video_width_pix,video_height_pix,arena_width_cm,arena_height_cm,arena_shape,arena_type,body_parts_tracked,behaviors_labeled,tracking_method
0,AdaptableSnail,438887472,CD-1 (ICR),white,male,13.0,8-12 weeks,wireless device,CD-1 (ICR),white,...,16.0,1214,1090,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""head...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut


In [5]:
train.shape, test.shape

((8790, 38), (1, 38))

In [6]:
x = pd.read_parquet('data/train/annotation/AdaptableSnail/44566106.parquet')
x.head()

Unnamed: 0,agent_id,target_id,action,start_frame,stop_frame
0,2,2,rear,4,139
1,4,2,avoid,13,52
2,4,4,rear,121,172
3,3,3,rear,156,213
4,4,4,rear,208,261


In [7]:
y = pd.read_parquet('data/train/annotation/AdaptableSnail/44566106.parquet')
y.head(), y.shape

(   agent_id  target_id action  start_frame  stop_frame
 0         2          2   rear            4         139
 1         4          2  avoid           13          52
 2         4          4   rear          121         172
 3         3          3   rear          156         213
 4         4          4   rear          208         261,
 (342, 5))

In [None]:
import pandas as pd
from pathlib import Path
from typing import Union, List, Optional

def read_parquet_data_recursive(
    base_path: Union[str, Path],
    file_pattern: str = "*.parquet",
    columns: Optional[List[str]] = None,
    add_metadata: bool = True
) -> pd.DataFrame:
    """
    Recursively read all parquet files from a directory structure.
    
    Parameters:
    -----------
    base_path : str or Path
        Base directory to start recursive search
    file_pattern : str, default "*.parquet"
        Pattern to match files (e.g., "*.parquet", "*.csv")
    columns : list, optional
        Specific columns to read from parquet files
    add_metadata : bool, default True
        Whether to add metadata columns (directory, subdirectory, filename)
        
    Returns:
    --------
    pd.DataFrame
        Combined dataframe with all parquet data and metadata columns
    """
    base_path = Path(base_path)
    
    if not base_path.exists():
        raise ValueError(f"Path does not exist: {base_path}")
    
    # Recursively find all parquet files
    parquet_files = list(base_path.rglob(file_pattern))
    
    if not parquet_files:
        raise ValueError(f"No files matching '{file_pattern}' found in {base_path}")
    
    print(f"Found {len(parquet_files)} parquet files to process...")
    
    # Read and combine all parquet files
    dfs = []
    for parquet_file in parquet_files:
        try:
            # Read the parquet file
            df = pd.read_parquet(parquet_file, columns=columns)
            
            if add_metadata:
                # Get relative path from base directory
                relative_path = parquet_file.relative_to(base_path)
                
                # Extract directory information
                parts = relative_path.parts
                
                # Add metadata columns
                df['source_filename'] = parquet_file.name
                df['source_directory'] = parts[0] if len(parts) > 1 else ''
                df['source_subdirectory'] = parts[1] if len(parts) > 2 else ''
                df['source_full_path'] = str(relative_path)
                
            dfs.append(df)
            
        except Exception as e:
            print(f"Error reading {parquet_file}: {e}")
            continue
    
    if not dfs:
        raise ValueError("No parquet files could be successfully read")
    
    # Combine all dataframes
    combined_df = pd.concat(dfs, ignore_index=True)
    
    print(f"Successfully combined {len(dfs)} files into dataframe with {len(combined_df)} rows")
    
    return combined_df


def read_parquet_with_custom_metadata(
    base_path: Union[str, Path],
    max_depth: Optional[int] = None
) -> pd.DataFrame:
    """
    Advanced version that handles arbitrary directory depth.
    
    Parameters:
    -----------
    base_path : str or Path
        Base directory to start recursive search
    max_depth : int, optional
        Maximum directory depth to traverse (None for unlimited)
        
    Returns:
    --------
    pd.DataFrame
        Combined dataframe with flexible metadata columns
    """
    base_path = Path(base_path)
    parquet_files = list(base_path.rglob("*.parquet"))
    
    dfs = []
    for parquet_file in parquet_files:
        try:
            # Check depth limit
            relative_path = parquet_file.relative_to(base_path)
            depth = len(relative_path.parts) - 1  # Subtract filename
            
            if max_depth is not None and depth > max_depth:
                continue
            
            df = pd.read_parquet(parquet_file)
            
            # Add comprehensive metadata
            df['filename'] = parquet_file.name
            df['file_stem'] = parquet_file.stem  # Filename without extension
            
            # Add each directory level as separate column
            parts = relative_path.parent.parts
            for i, part in enumerate(parts):
                df[f'dir_level_{i}'] = part
            
            # Add full relative path
            df['full_relative_path'] = str(relative_path)
            df['depth_level'] = depth
            
            dfs.append(df)
            
        except Exception as e:
            print(f"Skipping {parquet_file}: {e}")
            continue
    
    return pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()


# Example usage for your data structure:
"""
# Basic usage - reads all parquet files recursively
df = read_parquet_data_recursive('_provided/test_tracking')

# Access metadata
print(df[['source_directory', 'source_subdirectory', 'source_filename']].head())

# Filter by specific category
adaptable_snail_data = df[df['source_directory'] == 'AdaptableSnail']

# Advanced usage with dynamic depth handling
df_advanced = read_parquet_with_custom_metadata('_provided/test_tracking', max_depth=2)

# See all unique directories
print(df['source_directory'].unique())
"""


"\n# Basic usage - reads all parquet files recursively\ndf = read_parquet_data_recursive('_provided/test_tracking')\n\n# Access metadata\nprint(df[['source_directory', 'source_subdirectory', 'source_filename']].head())\n\n# Filter by specific category\nadaptable_snail_data = df[df['source_directory'] == 'AdaptableSnail']\n\n# Advanced usage with dynamic depth handling\ndf_advanced = read_parquet_with_custom_metadata('_provided/test_tracking', max_depth=2)\n\n# See all unique directories\nprint(df['source_directory'].unique())\n"

In [9]:
# # Basic usage - reads all parquet files recursively
# df_ta = read_parquet_data_recursive('data/train_annotation')
# df_ta.head()

In [10]:
# df_tt = read_parquet_data_recursive('data/train_tracking')
# df_tt

In [11]:
# df_test_tracking = read_parquet_data_recursive('data/test_tracking')
# df_test_tracking.head()

In [12]:
def parquet_generator(
    base_path: Union[str, Path],
    file_pattern: str = "*.parquet",
    columns: Optional[List[str]] = None,
    add_metadata: bool = True,
    batch_size: int = 1000
) -> Iterator[Dict]:
    """
    Memory-efficient generator that yields rows from parquet files.
    
    Parameters:
    -----------
    base_path : str or Path
        Base directory to start recursive search
    file_pattern : str
        Pattern to match files
    columns : list, optional
        Specific columns to read
    add_metadata : bool
        Whether to add metadata columns
    batch_size : int
        Number of rows to read at once from each file
        
    Yields:
    -------
    dict
        Individual row as dictionary
    """
    base_path = Path(base_path)
    parquet_files = list(base_path.rglob(file_pattern))
    
    print(f"Found {len(parquet_files)} parquet files to process...")
    
    for parquet_file in parquet_files:
        try:
            # Use PyArrow to read in batches for memory efficiency
            parquet_reader = pq.ParquetFile(parquet_file)
            
            # Get relative path info
            relative_path = parquet_file.relative_to(base_path)
            parts = relative_path.parts
            
            # Read file in batches
            for batch in parquet_reader.iter_batches(batch_size=batch_size, columns=columns):
                # Convert batch to pandas for easier manipulation
                df_batch = batch.to_pandas()
                
                if add_metadata:
                    df_batch['source_filename'] = parquet_file.name
                    df_batch['source_directory'] = parts[0] if len(parts) > 1 else ''
                    df_batch['source_subdirectory'] = parts[1] if len(parts) > 2 else ''
                    df_batch['source_full_path'] = str(relative_path)
                
                # Yield each row individually
                for _, row in df_batch.iterrows():
                    yield row.to_dict()
                    
        except Exception as e:
            print(f"Error reading {parquet_file}: {e}")
            continue

In [13]:
def create_unified_dataset_dict_efficient(
    base_path: Union[str, Path],
    columns: Optional[List[str]] = None,
    writer_batch_size: int = 1000,
    batch_size: int = 1000
) -> DatasetDict:
    """
    Memory-efficient version using generators to create DatasetDict.
    
    Parameters:
    -----------
    base_path : str or Path
        Base directory (should contain 'train' and 'test' folders)
    columns : list, optional
        Specific columns to read from parquet files
    writer_batch_size : int
        Number of rows to write at once (controls memory usage)
    batch_size : int
        Number of rows to read at once from parquet files
        
    Returns:
    --------
    DatasetDict
        Single DatasetDict with all splits
    """
    base_path = Path(base_path)
    dataset_dict = {}
    
    structure = {
        'train': ['annotation', 'tracking'],
        'test': ['tracking']
    }
    
    for split, subsets in structure.items():
        for subset in subsets:
            split_name = f"{split}_{subset}"
            subset_path = base_path / split / subset
            
            if not subset_path.exists():
                print(f"Warning: Path does not exist: {subset_path}")
                continue
            
            print(f"\nProcessing {split_name} with generator...")
            
            # Create generator for this split/subset
            def gen():
                for row in parquet_generator(
                    subset_path,
                    columns=columns,
                    add_metadata=True,
                    batch_size=batch_size
                ):
                    row['split'] = split
                    row['subset'] = subset
                    yield row
            
            # Use from_generator for memory-efficient loading
            dataset_dict[split_name] = Dataset.from_generator(
                gen,
                writer_batch_size=writer_batch_size
            )
    
    return DatasetDict(dataset_dict)

In [14]:

def upload_to_huggingface(
    dataset_dict: DatasetDict,
    repo_name: str,
    private: bool = False,
    token: Optional[str] = None,
    max_shard_size: str = "500MB"
):
    """
    Upload DatasetDict to Hugging Face Hub with memory-efficient settings.
    
    Parameters:
    -----------
    dataset_dict : DatasetDict
        The dataset dictionary to upload
    repo_name : str
        Repository name in format "username/dataset-name"
    private : bool
        Whether to make the dataset private
    token : str, optional
        Hugging Face authentication token
    max_shard_size : str
        Maximum size of each shard file (e.g., "500MB", "1GB")
    """
    print(f"\nUploading dataset to {repo_name}...")
    print(f"Dataset splits: {list(dataset_dict.keys())}")
    
    for split_name, dataset in dataset_dict.items():
        print(f"  - {split_name}: {len(dataset)} rows")
    
    # Push to hub with sharding for large datasets
    dataset_dict.push_to_hub(
        repo_name,
        private=private,
        token=token,
        max_shard_size=max_shard_size
    )
    
    print(f"\n✓ Successfully uploaded to: https://huggingface.co/datasets/{repo_name}")

In [15]:
def create_dataset_from_pyarrow_tables(
    base_path: Union[str, Path],
    columns: Optional[List[str]] = None
) -> DatasetDict:
    """
    Most memory-efficient version using PyArrow tables directly.
    Avoids pandas entirely and works with Arrow format throughout.
    
    Parameters:
    -----------
    base_path : str or Path
        Base directory
    columns : list, optional
        Specific columns to read
        
    Returns:
    --------
    DatasetDict
        Single DatasetDict with all splits
    """
    base_path = Path(base_path)
    dataset_dict = {}
    
    structure = {
        'train': ['annotation', 'tracking'],
        'test': ['tracking']
    }
    
    for split, subsets in structure.items():
        for subset in subsets:
            split_name = f"{split}_{subset}"
            subset_path = base_path / split / subset
            
            if not subset_path.exists():
                print(f"Warning: Path does not exist: {subset_path}")
                continue
            
            print(f"\nProcessing {split_name} with PyArrow...")
            
            # Find all parquet files
            parquet_files = list(subset_path.rglob("*.parquet"))
            tables = []
            
            for parquet_file in parquet_files:
                try:
                    # Read parquet file as PyArrow table
                    table = pq.read_table(parquet_file, columns=columns)
                    
                    # Add metadata columns
                    relative_path = parquet_file.relative_to(subset_path)
                    parts = relative_path.parts
                    
                    # Add metadata as new columns to the table
                    table = table.append_column(
                        'source_filename',
                        pa.array([parquet_file.name] * len(table))
                    )
                    table = table.append_column(
                        'source_directory',
                        pa.array([parts[0] if len(parts) > 1 else ''] * len(table))
                    )
                    table = table.append_column(
                        'source_subdirectory',
                        pa.array([parts[1] if len(parts) > 2 else ''] * len(table))
                    )
                    table = table.append_column(
                        'source_full_path',
                        pa.array([str(relative_path)] * len(table))
                    )
                    table = table.append_column(
                        'split',
                        pa.array([split] * len(table))
                    )
                    table = table.append_column(
                        'subset',
                        pa.array([subset] * len(table))
                    )
                    
                    tables.append(table)
                    
                except Exception as e:
                    print(f"Error reading {parquet_file}: {e}")
                    continue
            
            if tables:
                # Concatenate all tables
                combined_table = pa.concat_tables(tables)
                # Create dataset directly from PyArrow table
                dataset_dict[split_name] = Dataset(combined_table)
            
    return DatasetDict(dataset_dict)



In [16]:
import pandas as pd
from pathlib import Path
from typing import Union, List, Optional, Dict, Iterator, Tuple
from datasets import Dataset, DatasetDict
import pyarrow as pa
import pyarrow.parquet as pq
from collections import defaultdict


def scan_schemas(
    base_path: Union[str, Path],
    file_pattern: str = "*.parquet"
) -> Dict[str, List[Tuple[Path, pa.Schema]]]:
    """
    Scan all parquet files and collect their schemas for validation.
    
    Parameters:
    -----------
    base_path : str or Path
        Base directory to scan
    file_pattern : str
        Pattern to match files
        
    Returns:
    --------
    Dict[str, List[Tuple[Path, pa.Schema]]]
        Dictionary mapping split names to list of (file_path, schema) tuples
    """
    base_path = Path(base_path)
    schema_map = defaultdict(list)
    
    structure = {
        'train': ['annotation', 'tracking'],
        'test': ['tracking']
    }
    
    for split, subsets in structure.items():
        for subset in subsets:
            split_name = f"{split}_{subset}"
            subset_path = base_path / split / subset
            
            if not subset_path.exists():
                continue
            
            parquet_files = list(subset_path.rglob(file_pattern))
            print(f"Scanning {len(parquet_files)} files in {split_name}...")
            
            for parquet_file in parquet_files:
                try:
                    parquet_meta = pq.read_metadata(parquet_file)
                    schema = parquet_meta.schema.to_arrow_schema()
                    schema_map[split_name].append((parquet_file, schema))
                except Exception as e:
                    print(f"Error reading schema from {parquet_file}: {e}")
    
    return dict(schema_map)


def validate_and_unify_schemas(
    schema_map: Dict[str, List[Tuple[Path, pa.Schema]]],
    promote_options: str = "permissive"
) -> Dict[str, pa.Schema]:
    """
    Validate schemas and create unified schemas for each split.
    
    Parameters:
    -----------
    schema_map : Dict[str, List[Tuple[Path, pa.Schema]]]
        Output from scan_schemas()
    promote_options : str
        Either 'default' or 'permissive'
        - 'default': Only null can be unified with another type
        - 'permissive': Types are promoted to the greater common denominator
        
    Returns:
    --------
    Dict[str, pa.Schema]
        Unified schema for each split
    """
    unified_schemas = {}
    
    for split_name, file_schemas in schema_map.items():
        print(f"\n{'='*60}")
        print(f"Validating schemas for {split_name}")
        print(f"{'='*60}")
        
        if not file_schemas:
            continue
        
        # Extract just the schemas
        schemas = [schema for _, schema in file_schemas]
        
        # Check for schema inconsistencies
        print(f"Found {len(schemas)} files")
        
        # Group files by schema signature
        schema_groups = defaultdict(list)
        for file_path, schema in file_schemas:
            # Create a signature based on field names and types
            signature = tuple((field.name, str(field.type)) for field in schema)
            schema_groups[signature].append(file_path)
        
        if len(schema_groups) > 1:
            print(f"\n⚠️  WARNING: Found {len(schema_groups)} different schemas!")
            for i, (signature, files) in enumerate(schema_groups.items(), 1):
                print(f"\nSchema variant {i}: ({len(files)} files)")
                # Show just the schema fields
                for field_name, field_type in signature:
                    print(f"  - {field_name}: {field_type}")
                print(f"  Example file: {files[0]}")
        else:
            print("✓ All schemas are identical")
        
        # Attempt to unify schemas
        try:
            unified_schema = pa.unify_schemas(schemas, promote_options=promote_options)
            unified_schemas[split_name] = unified_schema
            
            print(f"\n✓ Successfully unified schema:")
            for field in unified_schema:
                print(f"  - {field.name}: {field.type}")
                
        except pa.ArrowInvalid as e:
            print(f"\n❌ ERROR: Cannot unify schemas for {split_name}")
            print(f"Error: {e}")
            print("\nTrying with 'permissive' mode...")
            
            try:
                unified_schema = pa.unify_schemas(schemas, promote_options="permissive")
                unified_schemas[split_name] = unified_schema
                print(f"✓ Successfully unified with permissive mode")
                for field in unified_schema:
                    print(f"  - {field.name}: {field.type}")
            except pa.ArrowInvalid as e2:
                print(f"❌ Still failed: {e2}")
                raise
    
    return unified_schemas


def create_dataset_from_pyarrow_tables_with_schema(
    base_path: Union[str, Path],
    unified_schemas: Dict[str, pa.Schema],
    columns: Optional[List[str]] = None
) -> DatasetDict:
    """
    Memory-efficient version using PyArrow tables with unified schema enforcement.
    
    Parameters:
    -----------
    base_path : str or Path
        Base directory
    unified_schemas : Dict[str, pa.Schema]
        Unified schemas from validate_and_unify_schemas()
    columns : list, optional
        Specific columns to read
        
    Returns:
    --------
    DatasetDict
        Single DatasetDict with all splits
    """
    base_path = Path(base_path)
    dataset_dict = {}
    
    structure = {
        'train': ['annotation', 'tracking'],
        'test': ['tracking']
    }
    
    for split, subsets in structure.items():
        for subset in subsets:
            split_name = f"{split}_{subset}"
            subset_path = base_path / split / subset
            
            if not subset_path.exists():
                print(f"Warning: Path does not exist: {subset_path}")
                continue
            
            if split_name not in unified_schemas:
                print(f"Warning: No unified schema for {split_name}")
                continue
            
            print(f"\nProcessing {split_name} with unified schema...")
            
            # Get the base unified schema (without metadata columns)
            base_schema = unified_schemas[split_name]
            
            # Find all parquet files
            parquet_files = list(subset_path.rglob("*.parquet"))
            tables = []
            
            for parquet_file in parquet_files:
                try:
                    # Read parquet file as PyArrow table
                    table = pq.read_table(parquet_file, columns=columns)
                    
                    # Cast to unified schema to ensure consistency
                    # Only cast columns that are in the base schema
                    cols_to_cast = [field.name for field in base_schema if field.name in table.column_names]
                    if cols_to_cast:
                        cast_schema = pa.schema([
                            base_schema.field(name) for name in cols_to_cast
                        ])
                        # Cast only the existing columns
                        table = table.cast(cast_schema)
                    
                    # Add metadata columns
                    relative_path = parquet_file.relative_to(subset_path)
                    parts = relative_path.parts
                    
                    table = table.append_column(
                        'source_filename',
                        pa.array([parquet_file.name] * len(table))
                    )
                    table = table.append_column(
                        'source_directory',
                        pa.array([parts[0] if len(parts) > 1 else ''] * len(table))
                    )
                    table = table.append_column(
                        'source_subdirectory',
                        pa.array([parts[1] if len(parts) > 2 else ''] * len(table))
                    )
                    table = table.append_column(
                        'source_full_path',
                        pa.array([str(relative_path)] * len(table))
                    )
                    table = table.append_column(
                        'split',
                        pa.array([split] * len(table))
                    )
                    table = table.append_column(
                        'subset',
                        pa.array([subset] * len(table))
                    )
                    
                    tables.append(table)
                    
                except Exception as e:
                    print(f"Error reading {parquet_file}: {e}")
                    continue
            
            if tables:
                # Concatenate all tables
                combined_table = pa.concat_tables(tables)
                # Create dataset directly from PyArrow table
                dataset_dict[split_name] = Dataset(combined_table)
            
    return DatasetDict(dataset_dict)


def validate_schemas_and_create_dataset(
    base_path: Union[str, Path],
    columns: Optional[List[str]] = None,
    promote_options: str = "permissive"
) -> DatasetDict:
    """
    Complete workflow: scan schemas, validate, unify, and create dataset.
    
    Parameters:
    -----------
    base_path : str or Path
        Base directory
    columns : list, optional
        Specific columns to read
    promote_options : str
        'default' or 'permissive' for schema unification
        
    Returns:
    --------
    DatasetDict
        Complete dataset dictionary
    """
    print("="*60)
    print("STEP 1: Scanning schemas")
    print("="*60)
    schema_map = scan_schemas(base_path)
    
    print("\n" + "="*60)
    print("STEP 2: Validating and unifying schemas")
    print("="*60)
    unified_schemas = validate_and_unify_schemas(schema_map, promote_options=promote_options)
    
    print("\n" + "="*60)
    print("STEP 3: Creating dataset with unified schemas")
    print("="*60)
    dataset_dict = create_dataset_from_pyarrow_tables_with_schema(
        base_path,
        unified_schemas,
        columns=columns
    )
    
    return dataset_dict


In [17]:
# ============================================================================
# MAIN USAGE - CHOOSE ONE APPROACH
# ============================================================================

if __name__ == "__main__":
    
    # OPTION 1: Generator-based (MOST MEMORY EFFICIENT)
    # Best for extremely large datasets (100GB+)
    # Processes data in small batches and writes to disk incrementally
    print("="*60)
    print("OPTION 1: Generator-based approach (most memory efficient)")
    print("="*60)
    
    dataset_dict = create_unified_dataset_dict_efficient(
        base_path="data",
        writer_batch_size=500,  # Write to disk every 500 rows
        batch_size=500  # Read 500 rows at a time from each parquet file
    )
    
    # OPTION 2: PyArrow tables (FASTER, moderately memory efficient)
    # Best for large datasets (10-100GB) where speed matters
    # Avoids pandas overhead but loads more data into memory
    # print("="*60)
    # print("OPTION 2: PyArrow table approach (fast and efficient)")
    # print("="*60)
    # 
    # dataset_dict = create_dataset_from_pyarrow_tables(
    #     base_path="data"
    # )
    
    # Inspect the dataset
    print("\n" + "="*60)
    print("Dataset Structure:")
    print("="*60)
    print(dataset_dict)
    
    # Preview each split
    for split_name in dataset_dict.keys():
        print(f"\n{split_name}:")
        print(f"  Columns: {dataset_dict[split_name].column_names}")
        print(f"  Rows: {len(dataset_dict[split_name])}")
        print(f"  First row: {dataset_dict[split_name][0]}")
    
    # Upload to Hugging Face with memory-efficient sharding
    upload_to_huggingface(
        dataset_dict=dataset_dict,
        repo_name="mxngjxa/MABe-2025",
        private=True,
        max_shard_size="500MB"  # Split large files into 500MB chunks
    )


OPTION 1: Generator-based approach (most memory efficient)

Processing train_annotation with generator...

Processing train_tracking with generator...


Generating train split: 1548 examples [00:00, 15399.86 examples/s]

Found 8790 parquet files to process...


Generating train split: 288187614 examples [3:06:03, 25814.04 examples/s]


KeyboardInterrupt: 