In [None]:
from __future__ import annotations

import argparse
import csv
import json
import random
from datetime import datetime
from pathlib import Path
from typing import List, Sequence

import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split

In [None]:
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse the training options for the Random Forest cell-counting run.

    Args:
        argv: Explicit list of command-line tokens to parse (without the
            program name). When ``None`` the tokens are taken from
            ``sys.argv`` as usual, except inside a Jupyter/Colab kernel,
            where an empty list is used so that every option falls back to
            its default.

    Returns:
        The populated ``argparse.Namespace`` with attributes ``metadata``,
        ``root``, ``experiment_root``, ``run_name``, ``val_split``,
        ``test_split``, ``n_estimators``, ``max_depth`` and ``seed``.

    Raises:
        SystemExit: Raised by argparse on invalid or unknown arguments (and
            on ``--help``).
    """
    # Local import: keeps this block self-contained without touching the
    # notebook's import cell.
    import sys

    parser = argparse.ArgumentParser(
        description="Train a Random Forest regressor for cell counting."
    )
    parser.add_argument(
        "--metadata",
        type=Path,
        default=Path("/content/drive/MyDrive/CellCounting/processed/metadata.csv"),
        help="CSV generated by data_preprocess.py.",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=Path("/content/drive/MyDrive/CellCounting/processed"),
        help="Root directory containing sample NPZ files.",
    )
    parser.add_argument(
        "--experiment-root",
        type=Path,
        default=Path("/content/drive/MyDrive/CellCounting/experiments"),
        help="Directory where run artifacts will be stored.",
    )
    parser.add_argument(
        "--run-name",
        type=str,
        default=None,
        help="Optional identifier for this training run (defaults to timestamp).",
    )
    parser.add_argument(
        "--val-split",
        type=float,
        default=0.1,
        help="Validation split fraction.",
    )
    parser.add_argument(
        "--test-split",
        type=float,
        default=0.1,
        help="Test split fraction.",
    )
    parser.add_argument(
        "--n-estimators",
        type=int,
        default=200,
        help="Number of trees in the forest.",
    )
    parser.add_argument(
        "--max-depth",
        type=int,
        default=None,
        help="Maximum tree depth (None implies unlimited).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed.",
    )

    if argv is None and "ipykernel" in sys.modules:
        # Inside a notebook kernel, sys.argv carries the kernel launcher's own
        # flags (e.g. "-f <connection-file>"). Letting argparse read them
        # aborts the cell with "unrecognized arguments", so use defaults.
        argv = []
    return parser.parse_args(argv)

In [None]:
def load_metadata(metadata_path: Path, root: Path) -> List[Path]:
    """Read the metadata CSV and return the path of every listed sample.

    Each row's ``sample_npz`` value is turned into a path as follows:

    * an absolute path is kept exactly as written;
    * a relative path whose first component equals ``root.name`` is joined
      onto ``root.parent`` and resolved, so the root folder is not repeated;
    * any other relative path is joined onto ``root`` and resolved.

    Args:
        metadata_path: CSV file with a ``sample_npz`` column.
        root: Directory that relative sample paths are anchored to.

    Returns:
        Sample paths in the order the rows appear in the CSV.

    Raises:
        SystemExit: If the CSV contains no data rows.
    """

    def _locate(entry: str) -> Path:
        # Map one CSV cell to its sample path using the rules above.
        candidate = Path(entry)
        if candidate.is_absolute():
            return candidate
        # parts[:1] is an empty tuple for an empty path, so this comparison
        # is safe and simply falls through to `root` in that case.
        anchor = root.parent if candidate.parts[:1] == (root.name,) else root
        return (anchor / candidate).resolve()

    with metadata_path.open("r", newline="", encoding="utf-8") as handle:
        samples = [_locate(record["sample_npz"]) for record in csv.DictReader(handle)]

    if samples:
        return samples
    raise SystemExit(f"No samples found via {metadata_path}.")