In [None]:
import yaml
from pathlib import Path

from counterfactuals.datasets.base import DatasetParameters
from counterfactuals.datasets.moons import MoonsDataset

# Smoke check for the moons dataset: load its YAML configuration, build the
# dataset object, and print shapes/samples for a train/test split and for
# cross-validation folds.

# 1. Load YAML config
# NOTE: the path is relative, so it resolves against the current working
# directory of the process running this script/cell.
config_path = Path("config/datasets/moons.yaml")

# Decode explicitly as UTF-8 instead of relying on the platform's locale
# default encoding, which can differ between machines and mis-decode
# non-ASCII content.
with config_path.open("r", encoding="utf-8") as f:
    cfg_dict = yaml.safe_load(f)

# An empty file makes safe_load return None, and a top-level list/scalar is
# not a mapping either; neither can be unpacked with ** below. Fail here with
# a message that names the file rather than with an opaque unpacking error.
if not isinstance(cfg_dict, dict):
    raise TypeError(
        f"Expected a YAML mapping in {config_path}, "
        f"got {type(cfg_dict).__name__}"
    )

# Top-level YAML keys are passed as keyword arguments to DatasetParameters.
config = DatasetParameters(**cfg_dict)

# 2. Initialize dataset
dataset = MoonsDataset(config=config)

print("✅ Dataset initialized")
print(f"Shape X: {dataset.X.shape}, Shape y: {dataset.y.shape}")

# 3. Train/test split
# split_data is unpacked as (X_train, X_test, y_train, y_test).
X_train, X_test, y_train, y_test = dataset.split_data(dataset.X, dataset.y)

print(f"Train size: {X_train.shape[0]}, Test size: {X_test.shape[0]}")

# 4. First few samples
print("\nSample features (X_train):")
print(X_train[:5])
print("\nSample targets (y_train):")
print(y_train[:5])

# 5. Cross-validation check
# Each item yielded by get_cv_splits is unpacked as
# (X_train_fold, X_test_fold, y_train_fold, y_test_fold); only the feature
# shapes are reported. Folds are numbered from 1 in the output.
for i, (X_tr, X_te, y_tr, y_te) in enumerate(dataset.get_cv_splits(n_splits=3)):
    print(f"\nFold {i+1}: train={X_tr.shape}, test={X_te.shape}")


: 