# In [None]:
from __future__ import annotations
import random
from pathlib import Path
import pyarrow as pa
import pyarrow.parquet as pq

# ----------------- CONFIG -----------------
# Directories containing your input Parquet files.
# Each directory is scanned non-recursively for "*.parquet"; a directory that
# does not exist is skipped with a warning instead of aborting the run.
INPUT_DIRS = [
    "/scratch/users/luigi.silva/speczs-catalogs/processed",
    "/scratch/users/luigi.silva/speczs-catalogs/johns-catalogs",
]

# Directory where the sampled files will be saved.
# It is created (including parents) if missing; each output file is named
# "<input file stem>_random_sample.parquet".
OUTPUT_DIR = Path("/scratch/users/luigi.silva/pzserver_pipelines/combine_redshift_dedup/test_data")

# Maximum number of rows per sample.
# Input files with at most this many rows are written out in full.
SAMPLE_MAX = 1000

# Random seed for reproducibility.
# The same seed is passed to the sampler for every input file.
SEED = 42
# ------------------------------------------


def ensure_outdir(path: Path) -> None:
    """Create ``path`` together with any missing parent directories.

    Calling this on a directory that already exists is a no-op.
    """
    mkdir_options = {"exist_ok": True, "parents": True}
    path.mkdir(**mkdir_options)


def _reservoir_row_indices(total_rows: int, k: int, rng: random.Random) -> list[int]:
    """Run reservoir sampling (Algorithm R) over row numbers ``0..total_rows-1``.

    Returns a list of length ``k`` whose entry ``s`` is the row number that ends
    up in reservoir slot ``s``. Expects ``0 <= k <= total_rows``.
    """
    # The first k rows fill the reservoir in order.
    reservoir = list(range(k))
    for row in range(k, total_rows):
        # Row `row` replaces a random slot with probability k / (row + 1).
        slot = rng.randint(0, row)
        if slot < k:
            reservoir[slot] = row
    return reservoir


def sample_parquet_reservoir(in_path: Path, k: int, seed: int = 42) -> pa.Table:
    """
    Perform random sampling (without replacement) from a Parquet file.

    - If the file has <= k rows, return all rows.
    - Otherwise, pick k rows with reservoir sampling and read the file in
      streaming mode (batch by batch), so the whole file is never held in
      memory at once.

    The reservoir is computed on row numbers only, using the row count stored
    in the Parquet metadata. The selected rows are then copied batch by batch
    with Arrow's ``take``, so values never round-trip through Python objects
    and the table keeps the file's Arrow schema.

    Args:
        in_path: Path to the input Parquet file.
        k: Maximum number of rows to return; must be non-negative.
        seed: Seed for the private random generator. The same file, ``k`` and
            seed always give the same rows in the same order.

    Returns:
        A ``pyarrow.Table`` with ``min(k, number of rows in the file)`` rows.

    Raises:
        ValueError: If ``k`` is negative.
        RuntimeError: If the file yields fewer rows than its metadata declares.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    pf = pq.ParquetFile(str(in_path))
    total_rows = pf.metadata.num_rows

    # Case 1: small file -> just read all rows
    if total_rows <= k:
        return pf.read()

    # Case 2: large file -> reservoir sampling.
    # A private generator keeps the process-wide `random` state untouched while
    # producing the same number stream that `random.seed(seed)` would.
    rng = random.Random(seed)
    reservoir = _reservoir_row_indices(total_rows, k, rng)

    # Visit the chosen rows in file order so a single forward pass suffices.
    slots_in_file_order = sorted(range(k), key=reservoir.__getitem__)
    wanted_rows = [reservoir[slot] for slot in slots_in_file_order]

    picked = []   # one small RecordBatch per input batch that holds chosen rows
    cursor = 0    # next position in wanted_rows still to be collected
    offset = 0    # global row number of the current batch's first row
    for batch in pf.iter_batches():
        end = offset + batch.num_rows
        local = []
        while cursor < k and wanted_rows[cursor] < end:
            local.append(wanted_rows[cursor] - offset)
            cursor += 1
        if local:
            picked.append(batch.take(pa.array(local, type=pa.int64())))
        offset = end
        if cursor == k:
            # Every chosen row has been collected; no need to read further.
            break

    if cursor != k:
        raise RuntimeError(
            f"{in_path}: metadata declares {total_rows} rows but only {offset} were read"
        )

    in_file_order = pa.Table.from_batches(picked, schema=pf.schema_arrow)

    # Restore reservoir-slot order: output row s is the row held by slot s.
    position_of_slot = [0] * k
    for position, slot in enumerate(slots_in_file_order):
        position_of_slot[slot] = position
    return in_file_order.take(pa.array(position_of_slot, type=pa.int64()))


# ----------------- MAIN EXECUTION -----------------
# For every "*.parquet" file directly inside each input directory, draw a
# sample of at most SAMPLE_MAX rows and write it to OUTPUT_DIR as
# "<stem>_random_sample.parquet".
ensure_outdir(OUTPUT_DIR)
count = 0

# Output names are derived from the file stem only, so two input directories
# holding a file with the same name map to the same output path. Remember which
# source produced each output so that an overwrite is reported, not silent.
written = {}

for d in INPUT_DIRS:
    in_dir = Path(d)
    if not in_dir.is_dir():
        print(f"[WARN] Directory not found: {in_dir}")
        continue

    # Sorted so the processing order is the same on every run.
    for p in sorted(in_dir.glob("*.parquet")):
        survey_name = p.stem  # file name without extension
        out_path = OUTPUT_DIR / f"{survey_name}_random_sample.parquet"

        if out_path in written:
            print(
                f"[WARN] {out_path} was already written from {written[out_path]}; "
                f"overwriting with sample of {p}"
            )
        written[out_path] = p

        print(f"[INFO] Sampling {p} -> {out_path}")
        tbl = sample_parquet_reservoir(p, SAMPLE_MAX, seed=SEED)

        pq.write_table(tbl, out_path)
        count += 1

print(f"[DONE] {count} files processed. Output in: {OUTPUT_DIR}")