# Split the data in train and test set

In [8]:
from sklearn.model_selection import train_test_split
import pandas as pd

df = pd.read_csv("../data/interim/cleaned_yield_tables.csv")

df.head()

Unnamed: 0,tree_type,yield_class,age,average_height,dbh,taper,trees_per_ha,basal_area
0,coniferous,15.0,20.0,5.3,11.5,0.396,2585.0,26.8
1,coniferous,15.0,30.0,10.6,16.7,0.458,1708.0,37.5
2,coniferous,15.0,40.0,15.7,21.6,0.46,1266.0,46.3
3,coniferous,15.0,50.0,20.5,26.1,0.456,1003.0,53.5
4,coniferous,15.0,60.0,24.6,30.2,0.451,830.0,59.4


In [9]:
def create_stratified_splits(X, y, test_size=0.2, val_size=0.2, random_state=42):
    """Create stratified train/validation/test splits"""

    # First split: separate test set
    X_temp, X_test, y_temp, y_test = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=random_state
    )

    # Second split: separate train/validation from remaining data
    val_size_adjusted = val_size / (1 - test_size)  # Adjust for remaining data
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp,
        y_temp,
        test_size=val_size_adjusted,
        stratify=y_temp,
        random_state=random_state,
    )

    # Verify class distributions
    print("Class distributions:")
    print(f"Train: {y_train.value_counts(normalize=True)}")
    print(f"Val:   {y_val.value_counts(normalize=True)}")
    print(f"Test:  {y_test.value_counts(normalize=True)}")

    return X_train, X_val, X_test, y_train, y_val, y_test


X = df.drop("tree_type", axis=1)
y = df["tree_type"]
X_train, X_val, X_test, y_train, y_val, y_test = create_stratified_splits(X, y)

Class distributions:
Train: tree_type
coniferous    0.805383
deciduous     0.194617
Name: proportion, dtype: float64
Val:   tree_type
coniferous    0.806211
deciduous     0.193789
Name: proportion, dtype: float64
Test:  tree_type
coniferous    0.804969
deciduous     0.195031
Name: proportion, dtype: float64


In [10]:
# Save the data
X_train.to_csv("../data/processed/X_train.csv", index=False)
X_val.to_csv("../data/processed/X_val.csv", index=False)
X_test.to_csv("../data/processed/X_test.csv", index=False)
y_train.to_csv("../data/processed/y_train.csv", index=False)
y_val.to_csv("../data/processed/y_val.csv", index=False)
y_test.to_csv("../data/processed/y_test.csv", index=False)
