In [None]:
# setup
import os
import sys

import matplotlib.pyplot as plt
import statsmodels.api as sm
from statsmodels.stats.stattools import durbin_watson

# Add project root to path so the `config` and `regression` imports below
# resolve (presumably both packages live under this directory).
# NOTE(review): this is a machine-specific absolute path — consider deriving
# it from an environment variable or the repository location instead.
SRC_DIRECTORY = "/Users/gilanorup/Desktop/Studium/MSc/MA/code/masters_thesis_gn/src"
if SRC_DIRECTORY not in sys.path:
    # Guard so that re-running this cell does not pile up duplicate entries.
    sys.path.append(SRC_DIRECTORY)

from config.constants import GIT_DIRECTORY
from regression.multiple_regression import run_multiple_regression

# Set task name (used to pick the feature file and passed on to the regression)
task_name = "cookieTheft"

# Run regression and get all relevant variables
model, X_scaled, y, X_train, X_test, y_train, y_test = run_multiple_regression(
    features_path=os.path.join(GIT_DIRECTORY, f"results/features/{task_name}.csv"),
    scores_path=os.path.join(GIT_DIRECTORY, "resources/language_scores_all_subjects.csv"),
    target="PhonemicFluencyScore",
    output_dir=os.path.join(GIT_DIRECTORY, "results/regression"),
    task_name=task_name,
    save_outputs=False
)


def _save_and_close(fig, path):
    """Apply a tight layout to *fig*, write it to *path* at 300 dpi, and close it."""
    fig.tight_layout()
    fig.savefig(path, dpi=300)
    # Close this specific figure so no figures accumulate across calls.
    plt.close(fig)


def check_regression_assumptions(model, X_train, y_train, output_dir):
    """
    Generate diagnostic plots and statistics for linear regression assumptions:
    - Residuals vs. Fitted
    - Histogram of residuals
    - Q-Q Plot
    - Durbin-Watson test for autocorrelation

    Saves all plots and a text file summarizing the assumption checks.

    Args:
        model: Fitted regression results object; must expose ``fittedvalues``
            and ``resid`` (presumably a statsmodels OLS results instance —
            confirm against ``run_multiple_regression``).
        X_train: Currently unused; kept so existing callers keep working.
        y_train: Currently unused; kept so existing callers keep working.
        output_dir: Directory the plots and the summary file are written to.
            It is created if it does not exist.

    Returns:
        None. Outputs are written to ``output_dir``:
        ``residuals_vs_fitted.png``, ``residuals_histogram.png``,
        ``qq_plot.png`` and ``assumption_checks_summary.txt``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Residuals and fitted values come straight from the fitted model.
    fitted_vals = model.fittedvalues
    residuals = model.resid

    # Plot: Residuals vs Fitted
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.scatter(fitted_vals, residuals, alpha=0.7, edgecolor='k')
    ax.axhline(0, color='gray', linestyle='--')
    ax.set_xlabel("Fitted Values")
    ax.set_ylabel("Residuals")
    ax.set_title("Residuals vs. Fitted Values")
    _save_and_close(fig, os.path.join(output_dir, "residuals_vs_fitted.png"))

    # Plot: Histogram of residuals
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(residuals, bins=30, edgecolor='black', alpha=0.75)
    ax.set_title("Histogram of Residuals")
    ax.set_xlabel("Residual")
    ax.set_ylabel("Frequency")
    _save_and_close(fig, os.path.join(output_dir, "residuals_histogram.png"))

    # Plot: Q-Q Plot
    # Pass `ax` explicitly: without it, qqplot creates its own new figure,
    # which would ignore the figsize here and leave this figure open and empty.
    fig, ax = plt.subplots(figsize=(6, 4))
    sm.qqplot(residuals, line='45', fit=True, ax=ax)
    ax.set_title("Q-Q Plot of Residuals")
    _save_and_close(fig, os.path.join(output_dir, "qq_plot.png"))

    # Durbin-Watson test
    dw_stat = durbin_watson(residuals)

    # save summary (single write; no trailing newline, as before)
    summary_path = os.path.join(output_dir, "assumption_checks_summary.txt")
    summary_lines = [
        "Regression Assumption Checks",
        "==============================",
        f"Durbin-Watson Statistic: {dw_stat:.4f}",
        "",
        "See generated plots for:",
        " - Residuals vs. Fitted",
        " - Histogram of Residuals",
        " - Q-Q Plot",
    ]
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("\n".join(summary_lines))

    print(f"Assumption checks saved to: {output_dir}")
