# In [None]:
import numpy as np
import pandas as pd
import re
import os
import glob
from cdt.metrics import SHD
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix

def extract_vars_and_lag(filename):
    """Parse the variable count and lag encoded in a dataset file name.

    Looks for the ``vars<N>`` and ``lag<K>`` tokens anywhere in *filename*.
    Returns ``(N, K)`` as ints when both tokens are present, otherwise
    ``(None, None)``.
    """
    found = [re.search(pattern, filename) for pattern in (r'vars(\d+)', r'lag(\d+)')]
    if not all(found):
        return None, None
    return tuple(int(match.group(1)) for match in found)

def get_truth_matrices(n_vars, lag):
    """Return the ground-truth adjacency and lag matrices for a TimeGraph configuration.

    Parameters
    ----------
    n_vars : int or None
        Number of variables in the dataset (4, 6 or 8 are defined).
    lag : int or None
        Maximum lag of the dataset (2, 3 or 4 are defined).

    Returns
    -------
    (numpy.ndarray, numpy.ndarray) or (None, None)
        ``(truth_matrix, truth_lag_matrix)`` of shape ``(n_vars, n_vars)``.
        Returns ``(None, None)`` for any (n_vars, lag) pair without a defined
        ground truth. Note that the 8-variable ``truth_matrix`` is float
        (built with ``np.zeros``) while the 4/6-variable ones are int.
    """
    # All truth matrices from the timegraph study are included here, see readme for reference

    # Default to "no ground truth" so an unsupported lag for a known n_vars
    # returns (None, None) instead of raising UnboundLocalError.
    truth_matrix = None
    truth_lag_matrix = None

    if n_vars == 4:
        if lag == 2:
            truth_matrix = np.array([
                [0, 0, 0, 1],
                [1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0]
            ])
            truth_lag_matrix = np.array([
                [0, 0, 0, 2],
                [0, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 0]
            ])
        elif lag == 3:
            truth_matrix = np.array([
                [0, 0, 0, 1],
                [1, 0, 1, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0]
            ])
            truth_lag_matrix = np.array([
                [0, 0, 0, 2],
                [0, 0, 3, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 0]
            ])
        elif lag == 4:
            truth_matrix = np.array([
                [0, 0, 0, 1],
                [1, 0, 1, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0]
            ])
            truth_lag_matrix = np.array([
                [0, 0, 0, 4],
                [0, 0, 3, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 0]
            ])

    elif n_vars == 6:
        if lag == 2:
            truth_matrix = np.array([
                [0, 0, 0, 1, 0, 0],
                [1, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0],
                [0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 1, 0]
            ])
            # NOTE(review): no-op — row 4 above already has a 1 in column 3.
            truth_matrix[4, 3] = 1

            truth_lag_matrix = np.array([
                [0, 0, 0, 2, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 0, 0]
            ])
            # NOTE(review): no-op — row 4 above already has a 1 in column 3.
            truth_lag_matrix[4, 3] = 1

        elif lag == 3:
            truth_matrix = np.array([
                [0, 0, 0, 1, 0, 0],
                [1, 0, 1, 0, 0, 0],
                [0, 1, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0],
                [0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 1, 0]
            ])

            truth_lag_matrix = np.array([
                [0, 0, 0, 2, 0, 0],
                [0, 0, 3, 0, 0, 0],
                [0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 0, 0]
            ])

        elif lag == 4:
            truth_matrix = np.array([
                [0, 0, 0, 1, 0, 0],
                [1, 0, 1, 0, 0, 0],
                [0, 1, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0],
                [0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 1, 0]
            ])

            truth_lag_matrix = np.array([
                [0, 0, 0, 4, 0, 0],
                [0, 0, 3, 0, 0, 0],
                [0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 0, 0]
            ])

    # 8 variables
    elif n_vars == 8:
        if lag == 2:
            truth_matrix = np.zeros((8, 8))

            truth_matrix[0, 3] = 1
            truth_matrix[3, 2] = 1
            truth_matrix[2, 1] = 1
            truth_matrix[1, 0] = 1

            truth_matrix[3, 4] = 1
            truth_matrix[4, 5] = 1
            truth_matrix[5, 6] = 1
            truth_matrix[6, 7] = 1

            truth_matrix[4, 3] = 1

            truth_lag_matrix = np.zeros((8, 8), dtype=int)
            truth_lag_matrix[0, 3] = 2
            truth_lag_matrix[2, 1] = 1
            truth_lag_matrix[3, 4] = 1
            truth_lag_matrix[5, 6] = 1
            truth_lag_matrix[4, 3] = 1

        elif lag == 3:
            truth_matrix = np.zeros((8, 8))
            truth_matrix[0, 3] = 1
            truth_matrix[3, 2] = 1
            truth_matrix[2, 1] = 1
            truth_matrix[1, 0] = 1
            truth_matrix[1, 2] = 1
            truth_matrix[3, 4] = 1
            truth_matrix[4, 5] = 1
            truth_matrix[5, 6] = 1
            truth_matrix[6, 7] = 1

            truth_lag_matrix = np.zeros((8, 8), dtype=int)
            truth_lag_matrix[0, 3] = 2
            truth_lag_matrix[2, 1] = 1
            truth_lag_matrix[1, 2] = 3
            truth_lag_matrix[3, 4] = 1
            truth_lag_matrix[5, 6] = 1

        elif lag == 4:
            truth_matrix = np.zeros((8, 8))
            truth_matrix[0, 3] = 1
            truth_matrix[3, 2] = 1
            truth_matrix[2, 1] = 1
            truth_matrix[1, 0] = 1
            truth_matrix[1, 2] = 1
            truth_matrix[3, 4] = 1
            truth_matrix[4, 5] = 1
            truth_matrix[5, 6] = 1
            truth_matrix[6, 7] = 1

            truth_lag_matrix = np.zeros((8, 8), dtype=int)
            truth_lag_matrix[0, 3] = 4
            truth_lag_matrix[2, 1] = 1
            truth_lag_matrix[1, 2] = 3
            truth_lag_matrix[3, 4] = 1
            truth_lag_matrix[5, 6] = 1

    # Unknown n_vars or lag: both names are still None here.
    return truth_matrix, truth_lag_matrix

def determine_data_type(file):
    """Classify a dataset path as ``'nonlinear'`` or ``'linear'``.

    A path counts as nonlinear when it contains ``nonlinear`` (case-insensitive)
    or the case-sensitive marker ``C1``; every other path is ``'linear'``.
    """
    is_nonlinear = 'nonlinear' in file.lower() or 'C1' in file
    return 'nonlinear' if is_nonlinear else 'linear'

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0.0


def calculate_metrics(pred_matrix, pred_lag_matrix, truth_matrix, truth_lag_matrix):
    """Score a predicted causal graph and its lag structure against the ground truth.

    Parameters
    ----------
    pred_matrix : array-like, shape (N, N)
        Predicted edge scores. Used as ROC scores for the edge metrics; entries
        ``> 0`` count as "edge present" for the lag metrics.
    pred_lag_matrix : array-like, shape (N, N)
        Predicted lag per edge; entries ``> 0`` count as lagged edges.
    truth_matrix : array-like, shape (N, N)
        Ground-truth adjacency (cast to int).
    truth_lag_matrix : array-like, shape (N, N)
        Ground-truth lag per edge.

    Returns
    -------
    dict
        Edge metrics (diagonal excluded): ``auc``, ``shd``, ``tpr``, ``fpr``,
        ``fdr``, ``precision``, ``recall``, ``sensitivity``, ``specificity``,
        ``f1_score``. TPR/FPR and the confusion-matrix metrics use the threshold
        that maximises ``tpr - fpr``.
        Lag metrics: ``lag_auc``, ``lag_tpr``, ``lag_fpr``, ``lag_shd``,
        ``avg_lag_error``, restricted to edges present in the truth with a
        positive predicted score.
    """
    pred_matrix = np.asarray(pred_matrix)
    pred_lag_matrix = np.asarray(pred_lag_matrix)
    truth_lag_matrix = np.asarray(truth_lag_matrix)
    truth_binary = np.asarray(truth_matrix).astype(int)

    N = truth_binary.shape[0]
    # Self-loops are excluded from the edge-level scores.
    select_off_diagonal = (np.identity(N) == 0)
    y_true = truth_binary[select_off_diagonal]
    y_scores = pred_matrix[select_off_diagonal]

    metrics = {}

    metrics['auc'] = roc_auc_score(y_true, y_scores)

    # NOTE(review): SHD receives the raw prediction matrix (diagonal included).
    # If the model returns continuous scores rather than a 0/1 adjacency,
    # confirm the SHD value is meaningful for that input or binarize first.
    metrics['shd'] = SHD(truth_binary, pred_matrix)

    # Operating point: the ROC threshold maximising TPR - FPR (Youden's J).
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    optimal_idx = np.argmax(tpr - fpr)
    optimal_threshold = thresholds[optimal_idx]
    metrics['tpr'] = tpr[optimal_idx]
    metrics['fpr'] = fpr[optimal_idx]

    y_pred_binary = (y_scores >= optimal_threshold).astype(int)

    # labels=[0, 1] always yields a 2x2 matrix, even when every prediction
    # (or every truth value) falls in a single class, so the unpack is safe.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred_binary, labels=[0, 1]).ravel()

    # False Discovery Rate as specified in the paper
    metrics['fdr'] = _safe_ratio(fp, fp + tp)
    metrics['precision'] = _safe_ratio(tp, tp + fp)
    metrics['recall'] = _safe_ratio(tp, tp + fn)
    # Sensitivity and recall are the same quantity: TP / (TP + FN).
    metrics['sensitivity'] = metrics['recall']
    metrics['specificity'] = _safe_ratio(tn, tn + fp)
    metrics['f1_score'] = _safe_ratio(
        2 * metrics['precision'] * metrics['recall'],
        metrics['precision'] + metrics['recall'],
    )

    # Lag metrics only consider edges that exist in the truth AND have a positive score.
    detected_edges = (pred_matrix > 0) & (truth_binary > 0)

    true_lag_binary = (truth_lag_matrix > 0).astype(int)
    pred_lag_binary = (pred_lag_matrix > 0).astype(int)

    masked_true_lag = true_lag_binary * detected_edges
    masked_pred_lag = pred_lag_binary * detected_edges

    if np.sum(detected_edges) > 0:
        lag_true_flat = masked_true_lag[select_off_diagonal]
        lag_scores_flat = pred_lag_matrix[select_off_diagonal] * detected_edges[select_off_diagonal]

        # ROC metrics need both classes present in the lag labels.
        if len(np.unique(lag_true_flat)) > 1:
            metrics['lag_auc'] = roc_auc_score(lag_true_flat, lag_scores_flat)
            lag_fpr, lag_tpr, _ = roc_curve(lag_true_flat, lag_scores_flat)
            lag_optimal_idx = np.argmax(lag_tpr - lag_fpr)
            metrics['lag_tpr'] = lag_tpr[lag_optimal_idx]
            metrics['lag_fpr'] = lag_fpr[lag_optimal_idx]
        else:
            metrics['lag_auc'] = 0.5
            metrics['lag_tpr'] = 0.0
            metrics['lag_fpr'] = 0.0

        metrics['lag_shd'] = np.sum(masked_true_lag != masked_pred_lag)

        # Mean absolute difference between predicted and true lag over the
        # detected off-diagonal edges (cast to float to avoid unsigned wrap-around).
        detected_off_diag = detected_edges & select_off_diagonal
        lag_diff = np.abs(
            pred_lag_matrix[detected_off_diag].astype(float) - truth_lag_matrix[detected_off_diag]
        )
        metrics['avg_lag_error'] = float(np.mean(lag_diff)) if lag_diff.size else np.nan
    else:
        metrics['lag_auc'] = 0
        metrics['lag_tpr'] = 0
        metrics['lag_fpr'] = 0
        metrics['lag_shd'] = 0
        # No detected edges: a lag error is undefined, not zero.
        metrics['avg_lag_error'] = np.nan

    return metrics


def find_all_datasets(base_path='/content/TimeGraph/Datasets', sample_size=500):
    """Recursively collect CSV dataset paths under *base_path* for one sample size.

    A file is kept when its name contains ``n{sample_size}`` but not
    ``n{sample_size}0`` (so ``n500`` does not also match ``n5000``), and its
    full path contains none of the markers ``D1``, ``D2`` or ``D3``. Those are
    the missing-sample variants, which this study leaves out; remove that check
    to include them.

    Returns the matching paths sorted lexicographically.
    """
    wanted_tag = f'n{sample_size}'
    longer_tag = f'n{sample_size}0'
    missing_sample_markers = ('D1', 'D2', 'D3')
    search_pattern = os.path.join(base_path, '**/*.csv')

    def _is_selected(path):
        name = os.path.basename(path)
        if wanted_tag not in name or longer_tag in name:
            return False
        return not any(marker in path for marker in missing_sample_markers)

    return sorted(
        path for path in glob.glob(search_pattern, recursive=True) if _is_selected(path)
    )

def extract_test_name(filepath):
    """Build a human-readable label for a '/'-separated dataset path.

    The first path component that is a known dataset-type directory
    (``A1``, ``A1C``, ... ``C2C``) becomes the label prefix. The directories
    between it and the file name are joined with spaces (underscores also turned
    into spaces) to form the description. Variable count, lag and sample size
    come from the file name. If no known type directory is found, or nothing
    lies between it and the file, the bare file name is returned.
    """
    known_types = {'A1', 'A1C', 'A2', 'A2C', 'B1', 'B1C', 'B2', 'B2C',
                   'C1', 'C1C', 'C2', 'C2C'}
    segments = filepath.split('/')
    filename = segments[-1]

    type_index = next(
        (idx for idx, segment in enumerate(segments) if segment in known_types), None
    )
    if type_index is None:
        dataset_type, details = None, []
    else:
        dataset_type, details = segments[type_index], segments[type_index + 1:-1]

    n_vars, lag = extract_vars_and_lag(filename)

    size_match = re.search(r'n(\d+)', filename)
    sample_size = size_match.group(1) if size_match else 'unknown'

    if not (dataset_type and details):
        return filename

    detail_str = ' '.join(details).replace('_', ' ')
    return f"{dataset_type} with {detail_str} for {sample_size} rows with {n_vars} var and lag {lag}"

# Metric columns of the results DataFrame, in output order, paired with the
# key each one is read from in the dict returned by calculate_metrics.
_METRIC_COLUMNS = (
    ('AUC', 'auc'),
    ('TPR', 'tpr'),
    ('FPR', 'fpr'),
    ('SHD', 'shd'),
    ('Precision', 'precision'),
    ('FDR', 'fdr'),
    ('Specificity', 'specificity'),
    ('Lag_AUC', 'lag_auc'),
    ('Lag_TPR', 'lag_tpr'),
    ('Lag_FPR', 'lag_fpr'),
    ('Lag_SHD', 'lag_shd'),
    ('Sensitivity', 'sensitivity'),
    ('Recall', 'recall'),
    ('F1_Score', 'f1_score'),
    ('Avg_Lag_Error', 'avg_lag_error'),
)


def _build_result_row(model_name, filename, test_name, data_type, n_vars, lag, metrics=None):
    """Assemble one results row; metrics absent from *metrics* (all of them if None) become NaN."""
    metrics = metrics or {}
    row = {
        'Model': model_name,
        'Dataset': filename,
        'Test_Name': test_name,
        'Type': data_type,
        'Vars': n_vars,
        'Lag': lag,
    }
    for column, key in _METRIC_COLUMNS:
        row[column] = metrics.get(key, np.nan)
    return row


def evaluate_causal_models(model_functions, base_path='/content/TimeGraph/Datasets', sample_size=500):
    """Run every model on every discovered dataset and collect the metrics in a DataFrame.

    Parameters
    ----------
    model_functions : dict
        Maps a model name to a callable invoked as
        ``model_func(dataset_path, max_lag=lag)`` that returns
        ``(pred_matrix, pred_lag_matrix)``.
    base_path : str, optional
        Root directory searched recursively for dataset CSV files.
    sample_size : int, optional
        Sample-size tag (``n{sample_size}``) used to select dataset files.

    Returns
    -------
    pandas.DataFrame
        One row per (dataset, model). NaN is recorded for:
        - runs that raised;
        - datasets without a defined ground truth;
        - individual metrics that ``calculate_metrics`` did not return.
    """
    results = []

    print(f"Auto-finding all n{sample_size} datasets in {base_path}...")
    datasets = find_all_datasets(base_path=base_path, sample_size=sample_size)
    print(f"Found {len(datasets)} datasets with n{sample_size}")
    for ds in datasets[:5]:
        print(f"  - {ds}")
    if len(datasets) > 5:
        print(f"  ... and {len(datasets) - 5} more")

    for dataset_path in datasets:
        filename = os.path.basename(dataset_path)
        n_vars, lag = extract_vars_and_lag(filename)
        truth_matrix, truth_lag_matrix = get_truth_matrices(n_vars, lag)
        data_type = determine_data_type(dataset_path)
        test_name = extract_test_name(dataset_path)

        if truth_matrix is None:
            # Nothing to score against: skip the (possibly expensive) model runs
            # but keep one NaN row per model so the output shape is unchanged.
            print(f"No ground truth for {filename} (vars={n_vars}, lag={lag}); recording NaN metrics")
            for model_name in model_functions:
                results.append(
                    _build_result_row(model_name, filename, test_name, data_type, n_vars, lag)
                )
            continue

        for model_name, model_func in model_functions.items():
            try:
                pred_matrix, pred_lag_matrix = model_func(dataset_path, max_lag=lag)
                metrics = calculate_metrics(pred_matrix, pred_lag_matrix, truth_matrix, truth_lag_matrix)
            except Exception as e:
                # Evaluation boundary: one failing model/dataset pair must not abort the sweep.
                print(f"Model {model_name} failed on {filename}: {str(e)}")
                metrics = None

            results.append(
                _build_result_row(model_name, filename, test_name, data_type, n_vars, lag, metrics)
            )

    return pd.DataFrame(results)


# Entry point: register the models to benchmark, mapping a display name to a
# callable invoked as ``model_func(dataset_path, max_lag=lag)`` that returns
# ``(pred_matrix, pred_lag_matrix)``. ``choose_tested_function`` is a
# placeholder — replace it with the model under test before running.
# The __main__ guard still runs in a notebook, but lets the helpers above be
# imported without triggering the full evaluation.
if __name__ == "__main__":
    model_functions = {
        'include_function_here': choose_tested_function
    }
    results_df = evaluate_causal_models(model_functions)
