In [None]:
import polars as pl
import torch

from pathlib import Path

In [None]:
# Interim parquet inputs, resolved relative to the notebook's working directory.
INTERIM_DIR = Path("../temp/data/interim")
DATASET_NAMES = ("card_amr", "fine_tuning")
paths = [INTERIM_DIR / f"{name}.parquet" for name in DATASET_NAMES]

In [None]:
# Eagerly load each interim table; list order matches `paths`.
datasets = list(map(pl.read_parquet, paths))

In [None]:
def split_datasets(
    datasets: list[pl.DataFrame], ratios=(0.8, 0.1, 0.1), seed=None
) -> tuple[
    list[pl.DataFrame | None], list[pl.DataFrame | None], list[pl.DataFrame | None]
]:
    """
    Split several datasets into train/val/test as if they were one table.

    All rows of ``datasets`` are numbered globally (dataset 0 first, then
    dataset 1, ...), one random permutation of those global row numbers is
    drawn, and it is cut into train / val / test by ``ratios``. Each split is
    then mapped back to per-dataset tables, so the datasets never need to be
    concatenated physically and may have different schemas.

    Args:
        datasets: Tables to split. Each output table keeps its source's columns.
        ratios: ``(train, val, test)`` fractions. Only the first two set the
            cut points (``int(train * N)`` and ``int((train + val) * N)`` for
            ``N`` total rows); the test split receives every remaining row.
            When a third entry is given, all entries must sum to 1 so that it
            cannot silently disagree with the implied test share.
        seed: If not ``None``, the permutation is drawn from a private
            ``torch.Generator`` seeded with this value, making the split
            reproducible without modifying the global torch RNG state. If
            ``None``, the global torch RNG is used.

    Returns:
        ``(train_splits, val_splits, test_splits)``. Each is a list with one
        entry per input dataset, in input order: a DataFrame holding that
        dataset's rows assigned to the split (in original row order), or
        ``None`` if the dataset contributes no rows to that split.

    Raises:
        ValueError: If a ratio is negative, train + val exceeds 1, or three
            ratios are given that do not sum to 1.
    """
    if any(r < 0 for r in ratios) or ratios[0] + ratios[1] > 1 + 1e-9:
        raise ValueError(
            f"ratios must be non-negative with train + val <= 1, got {ratios!r}"
        )
    if len(ratios) > 2 and abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError(f"ratios must sum to 1, got {ratios!r}")

    if not datasets:
        # Nothing to split; also avoids indexing an empty cumulative-height tensor.
        return [], [], []

    generator = None
    if seed is not None:
        # Private generator: reproducible without reseeding torch globally.
        generator = torch.Generator().manual_seed(seed)

    heights = torch.tensor([ds.height for ds in datasets], dtype=torch.long)
    cumulative_heights = torch.cumsum(heights, dim=0)
    total_n = int(cumulative_heights[-1].item())
    # Global row number of each dataset's first row: [0, h0, h0 + h1, ...].
    row_offsets = cumulative_heights - heights

    permutation = torch.randperm(total_n, generator=generator)

    # Map every permuted global row g to (owning dataset, row within it).
    # With right=True, bucketize returns i such that
    # cumulative_heights[i-1] <= g < cumulative_heights[i]; zero-height
    # datasets therefore never receive rows.
    source_ids = torch.bucketize(permutation, cumulative_heights, right=True)
    local_ids = permutation - row_offsets[source_ids]

    train_end = int(ratios[0] * total_n)
    val_end = int((ratios[0] + ratios[1]) * total_n)

    def gather_split(start: int, stop: int | None) -> list[pl.DataFrame | None]:
        """Return one table (or None if empty) per dataset for permutation[start:stop]."""
        split_sources = source_ids[start:stop]
        split_locals = local_ids[start:stop]
        tables = []
        for dataset_idx, ds in enumerate(datasets):
            locs = split_locals[split_sources == dataset_idx]
            if locs.numel() == 0:
                tables.append(None)
            else:
                # Boolean row mask keeps the dataset's original row order.
                row_number = pl.int_range(0, ds.height)
                tables.append(ds.filter(row_number.is_in(locs.tolist())))
        return tables

    return (
        gather_split(0, train_end),
        gather_split(train_end, val_end),
        gather_split(val_end, None),
    )


SPLIT_SEED = 69420  # fixed seed so the train/val/test assignment is reproducible
train_splits, val_splits, test_splits = split_datasets(datasets, seed=SPLIT_SEED)

In [None]:
# Report each dataset's (rows, cols) per split; None marks a dataset with no rows in it.
split_report = {
    "Train splits:": train_splits,
    "Validation splits:": val_splits,
    "Test splits:": test_splits,
}
for label, splits in split_report.items():
    shapes = [None if ds is None else ds.shape for ds in splits]
    print(label, shapes)

In [None]:
# Save each split as <source stem>.{train,valid,test}.parquet.
# A split that is empty (None) for a dataset gets no file, and any file of that
# name left over from an earlier run is removed so downstream loaders never pick
# up stale rows.
processed_dir = Path("../temp/data/processed/")
processed_dir.mkdir(parents=True, exist_ok=True)  # once, not per dataset

for source_path, split_tables in zip(
    paths, zip(train_splits, val_splits, test_splits), strict=True
):
    for suffix, table in zip(("train", "valid", "test"), split_tables):
        out_path = processed_dir / f"{source_path.stem}.{suffix}.parquet"
        if table is None:
            out_path.unlink(missing_ok=True)
        else:
            table.write_parquet(out_path)
