# M5 Prime From-Scratch Implementation (PyTorch)

This notebook contains a from-scratch implementation of the M5 Prime (M5P) model tree algorithm, written with PyTorch tensors so that it can run on a GPU. It includes:
1. The helper classes `LinearRegression` and `_Node`.
2. The main `M5P` class, which inherits from `torch.nn.Module` so that it can be moved to a device and called like any other PyTorch module.
3. A hyperparameter tuning and plotting function that evaluates the model on the MEFAR datasets.
4. A main script that runs the tuning function on every dataset file.

## 1. Importing Libraries

In [None]:
import pandas as pd
import time
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import os

## 2. Helper and Model Classes

In [None]:
# Helper class for Linear Regression at the leaves
class LinearRegression:
    """Least-squares linear model with an intercept, fitted in closed form.

    After ``fit`` the coefficients are stored in ``theta``: element 0 is the
    intercept and the remaining elements are the per-feature weights.

    ``theta`` is deliberately NOT created before ``fit`` is called: the tree
    code in this file uses ``hasattr(model, 'theta')`` to tell a fitted leaf
    model from an unfitted one.
    """

    @staticmethod
    def _with_bias(X):
        """Return ``X`` with a leading column of ones (the intercept term)."""
        # Build the column with X's own dtype/device so the concatenation
        # never has to rely on implicit type promotion.
        ones = torch.ones((X.shape[0], 1), dtype=X.dtype, device=X.device)
        return torch.cat([ones, X], dim=1)

    def fit(self, X, y):
        """Estimate ``theta`` from training data.

        Args:
            X: 2-D tensor with one row per sample and one column per feature.
            y: Targets with one entry per row of ``X``.

        Returns:
            ``self``, so calls can be chained.
        """
        X_b = self._with_bias(X)
        # pinv(X_b) @ y is mathematically the same minimum-norm least-squares
        # solution as pinv(X_b.T @ X_b) @ X_b.T @ y, but it avoids forming
        # X_b.T @ X_b, whose condition number is the square of X_b's.
        self.theta = torch.linalg.pinv(X_b) @ y
        return self

    def predict(self, X):
        """Return the fitted linear prediction for every row of ``X``.

        Raises:
            AttributeError: if called before ``fit`` (``theta`` is missing).
        """
        return self._with_bias(X) @ self.theta

# Helper class for a single node in the tree
class _Node:
    def __init__(self, feature=None, threshold=None, left=None, right=None, *, value=None, model=None):
        self.feature = feature
        self.threshold = threshold
        self.left = left
        self.right = right
        self.value = value
        self.model = model

    def is_leaf_node(self):
        return self.value is not None

# Main M5P Model Class (now a torch.nn.Module)
class M5P(nn.Module):
    """Model tree regressor grown by standard-deviation reduction (SDR).

    Internal nodes route a sample by comparing one feature with a threshold
    (``<=`` goes left, ``>`` goes right). Leaves predict with a linear model
    when more than ``n_features + 1`` samples reached them, and with the mean
    target otherwise. Only tree growing is implemented in this class; there
    is no pruning or smoothing step.

    The fitted tree is stored as plain Python ``_Node`` objects holding
    tensors, not as registered parameters or buffers, so the tree lives on
    whatever device the training tensors passed to ``fit`` were on.

    Args:
        min_samples_split: A node holding fewer samples than this becomes a leaf.
        max_depth: A node at this depth (root is depth 0) becomes a leaf.
    """

    def __init__(self, min_samples_split=2, max_depth=100):
        super().__init__()
        self.min_samples_split = min_samples_split
        self.max_depth = max_depth
        self.root = None

    def _std_dev(self, y):
        """Return the population standard deviation of ``y``.

        The biased (population) estimator is used on purpose: the
        Bessel-corrected default of ``torch.std`` is NaN for a one-element
        tensor, which made every split with a single-sample child score NaN
        and be rejected by the ``sdr > best_sdr`` comparison.
        """
        return torch.std(y, unbiased=False)

    def _best_split(self, X, y):
        """Find the (feature index, threshold) pair with the largest SDR.

        Every distinct value of every feature is tried as a threshold.

        Returns:
            ``(feature_index, threshold)`` where ``threshold`` is a 0-d tensor,
            or ``(None, None)`` when no threshold separates the samples into
            two non-empty groups.
        """
        best_sdr, best_idx, best_thresh = -1.0, None, None
        n_samples, n_features = X.shape
        parent_std = self._std_dev(y)

        for feat_idx in range(n_features):
            column = X[:, feat_idx]  # hoisted: reused for every threshold
            for thresh in torch.unique(column):
                y_left = y[column <= thresh]
                y_right = y[column > thresh]
                n_left, n_right = y_left.shape[0], y_right.shape[0]

                # A threshold that leaves one side empty is not a split.
                if n_left == 0 or n_right == 0:
                    continue

                weighted_child_std = (n_left / n_samples * self._std_dev(y_left)
                                      + n_right / n_samples * self._std_dev(y_right))
                sdr = float(parent_std - weighted_child_std)

                # Strict '>' keeps the first best candidate on ties.
                if sdr > best_sdr:
                    best_sdr, best_idx, best_thresh = sdr, feat_idx, thresh

        return best_idx, best_thresh

    def _make_leaf(self, X, y):
        """Build a leaf holding the mean target and, when possible, a linear model."""
        model = LinearRegression()
        # Fit only when there are more samples than coefficients
        # (n_features weights + 1 intercept); otherwise the model stays
        # unfitted and prediction falls back to the stored mean.
        if X.shape[0] > X.shape[1] + 1:
            model.fit(X, y)
        return _Node(value=torch.mean(y), model=model)

    def _grow_tree(self, X, y, depth=0):
        """Recursively grow the subtree for the samples ``(X, y)``."""
        n_samples = X.shape[0]
        must_stop = (depth >= self.max_depth
                     or n_samples < self.min_samples_split
                     or len(torch.unique(y)) == 1)
        if must_stop:
            return self._make_leaf(X, y)

        feat_idx, thresh = self._best_split(X, y)
        if feat_idx is None:
            # No feature can separate the samples any further.
            return self._make_leaf(X, y)

        goes_left = X[:, feat_idx] <= thresh
        goes_right = X[:, feat_idx] > thresh

        left = self._grow_tree(X[goes_left], y[goes_left], depth + 1)
        right = self._grow_tree(X[goes_right], y[goes_right], depth + 1)
        return _Node(feat_idx, thresh, left, right)

    def fit(self, X, y):
        """Grow the tree from training data.

        Args:
            X: 2-D tensor, one row per sample.
            y: Targets, one entry per row of ``X``.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: if the training set is empty or ``X`` and ``y``
                disagree on the number of samples.
        """
        if X.shape[0] == 0:
            raise ValueError("Cannot fit M5P on an empty training set.")
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"X has {X.shape[0]} samples but y has {y.shape[0]}."
            )
        self.root = self._grow_tree(X, y)
        return self

    def _traverse_tree(self, x, node):
        """Route the single sample ``x`` down from ``node`` and return its prediction."""
        if node.is_leaf_node():
            # A leaf model without 'theta' was never fitted (too few samples).
            if node.model is not None and hasattr(node.model, 'theta'):
                return node.model.predict(x.unsqueeze(0))[0]
            return node.value

        if x[node.feature] <= node.threshold:
            return self._traverse_tree(x, node.left)
        return self._traverse_tree(x, node.right)

    def forward(self, X):
        """Predict one value per row of ``X``.

        Returns:
            1-D tensor of predictions on ``X``'s device.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.root is None:
            raise RuntimeError("M5P model is not fitted; call fit() first.")
        if X.shape[0] == 0:
            return torch.empty(0, dtype=X.dtype, device=X.device)
        # Stack the per-sample scalar tensors directly instead of rebuilding
        # a tensor from a Python list of tensors.
        predictions = [self._traverse_tree(x, self.root).reshape(()) for x in X]
        return torch.stack(predictions).to(X.device)

## 3. Tuning and Plotting Function

In [None]:
def tune_and_plot_m5p_pytorch(file_path, min_samples_range, max_depth_range, device, threshold=0.5):
    """Grid-search M5P hyperparameters on one CSV dataset and plot the results.

    The last column of the CSV is used as the target and all other columns as
    features. The data is split 80/20 into train/test sets; for every
    (min_samples_split, max_depth) pair a model is trained, its continuous
    predictions are binarised with ``threshold`` and scored with accuracy.
    Accuracy and training time are then plotted against max depth.

    Args:
        file_path: Path of the CSV file to load.
        min_samples_range: Iterable of ``min_samples_split`` values to try.
        max_depth_range: Iterable of ``max_depth`` values to try.
        device: Device (string or ``torch.device``) the tensors are created on.
        threshold: Predictions strictly above this value are mapped to class 1,
            all others to class 0.

    Returns:
        ``pandas.DataFrame`` with one row per combination and the columns
        ``min_samples_split``, ``max_depth``, ``accuracy``, ``training_time``.
    """
    filename = os.path.basename(file_path)
    print(f"--- Starting M5P PyTorch Tuning for {filename} ---")

    # Materialise the ranges: they are iterated several times below, which
    # would silently yield nothing on a second pass over a one-shot iterator.
    min_samples_values = list(min_samples_range)
    max_depth_values = list(max_depth_range)

    df = pd.read_csv(file_path)
    if df.isnull().values.any():
        # numeric_only=True: taking the mean of a non-numeric column raises
        # in recent pandas versions.
        df = df.fillna(df.mean(numeric_only=True))

    X = df.iloc[:, :-1].values
    y = df.iloc[:, -1].values
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    X_train_pt = torch.tensor(X_train, dtype=torch.float32, device=device)
    y_train_pt = torch.tensor(y_train, dtype=torch.float32, device=device)
    X_test_pt = torch.tensor(X_test, dtype=torch.float32, device=device)

    # CUDA kernels are launched asynchronously; synchronising around the timed
    # region makes sure the measurement covers only the training work.
    on_cuda = torch.device(device).type == 'cuda'

    results = []
    for min_samples in min_samples_values:
        for max_d in max_depth_values:
            model = M5P(min_samples_split=min_samples, max_depth=max_d)
            model.to(device)

            if on_cuda:
                torch.cuda.synchronize()
            start_time = time.perf_counter()  # monotonic, high-resolution clock
            model.fit(X_train_pt, y_train_pt)
            if on_cuda:
                torch.cuda.synchronize()
            training_time = time.perf_counter() - start_time

            continuous_predictions = model(X_test_pt).cpu().numpy()
            binary_predictions = (continuous_predictions > threshold).astype(int)
            accuracy = accuracy_score(y_test, binary_predictions)

            results.append({
                'min_samples_split': min_samples,
                'max_depth': max_d,
                'accuracy': accuracy,
                'training_time': training_time
            })
            print(f"  Min Samples: {min_samples}, Max Depth: {max_d}, Accuracy: {accuracy*100:6.2f}%, Time: {training_time:.2f}s")

    df_results = pd.DataFrame(
        results,
        columns=['min_samples_split', 'max_depth', 'accuracy', 'training_time'],
    )
    if df_results.empty:
        # Nothing was evaluated (an empty range was passed): nothing to plot.
        return df_results

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))
    fig.suptitle(f'M5P PyTorch Tuning for {filename}', fontsize=16)

    # Both panels share the same layout and differ only in the plotted column.
    panels = (
        (ax1, 'accuracy', 'Accuracy vs. Max Depth', 'Accuracy'),
        (ax2, 'training_time', 'Training Time vs. Max Depth', 'Training Time (seconds)'),
    )
    for ax, column, title, ylabel in panels:
        for min_val in min_samples_values:
            subset = df_results[df_results['min_samples_split'] == min_val]
            ax.plot(subset['max_depth'], subset[column], marker='o', label=f'min_samples={min_val}')
        ax.set_title(title)
        ax.set_xlabel('Max Depth')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()
    return df_results

## 4. Main Script Execution

In [None]:
DATASET_PATH = '../Datasets'

# Keep only regular CSV files: every path is loaded with pd.read_csv by the
# tuning function, so sub-directories and other files in the folder would
# make it fail. Sorting gives a reproducible processing order, which
# os.listdir does not guarantee.
dataset_files = sorted(
    path
    for path in (os.path.join(DATASET_PATH, f) for f in os.listdir(DATASET_PATH))
    if path.lower().endswith('.csv') and os.path.isfile(path)
)
if not dataset_files:
    print(f"No CSV files found in {DATASET_PATH}")

# Hyperparameter grid. A split needs at least 2 samples, so 2 is the
# smallest min_samples_split value that is tried.
min_samples_values = [2, 10, 50, 100]
max_depth_values = [5, 10, 15]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

for file in dataset_files:
    tune_and_plot_m5p_pytorch(
        file_path=file,
        min_samples_range=min_samples_values,
        max_depth_range=max_depth_values,
        device=device
    )