# In [None]:
import toad
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.api as sm
from statsmodels.tsa.stattools import grangercausalitytests
from statsmodels.tsa.stattools import ccf
import io
from PIL import Image
import seaborn as sns

# Load the quarterly dataset and replace every missing value with 0
df = pd.read_csv("/Users/lokheilee/python/fyp/6. Stepwise/For SEM dataset quarterly.csv").fillna(0)

# Function to calculate lead-lag relationships (default: at most 2 periods)
def calculate_lead_lag(data, target_col, max_lags=2):
    """Score lead/lag relationships between ``target_col`` and every other column.

    For each other *numeric* column, the Pearson correlation with the target is
    computed for every shift in ``[-max_lags, max_lags]`` using
    ``column.shift(lag)`` (positive ``lag`` aligns the column's past values with
    the target, negative ``lag`` its future values). A Granger-causality test of
    "column helps predict target" is then run for lags ``1..max_lags``.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame containing ``target_col`` and the candidate regressors.
    target_col : str
        Name of the dependent-variable column.
    max_lags : int, default 2
        Largest shift tried in either direction; also the Granger ``maxlag``.

    Returns
    -------
    dict
        ``{column: {'max_correlation_lag': int, 'correlation_value': float,
        'granger_min_pvalue': float}}`` for every numeric non-target column.
        Non-numeric columns (e.g. date labels) are skipped.
    """
    target = data[target_col]
    lead_lag_results = {}

    for column in data.columns:
        # Non-numeric columns (e.g. a date/quarter label) cannot be correlated
        # or Granger-tested; Series.corr would raise on them.
        if column == target_col or not pd.api.types.is_numeric_dtype(data[column]):
            continue

        series = data[column]
        # shift(0) is an unshifted copy, so lag 0 needs no special case.
        correlations = {lag: target.corr(series.shift(lag))
                        for lag in range(-max_lags, max_lags + 1)}

        # Rank by |r|. A NaN correlation (e.g. from a constant column) must never
        # be selected, but max() would keep a leading NaN because every
        # comparison against NaN is False -> rank NaN as -inf instead.
        max_corr_lag = max(
            correlations,
            key=lambda lag: -np.inf if np.isnan(correlations[lag]) else abs(correlations[lag]),
        )

        # grangercausalitytests tests whether the 2nd column Granger-causes the 1st.
        granger_test = grangercausalitytests(pd.concat([target, series], axis=1),
                                             maxlag=max_lags, verbose=False)

        lead_lag_results[column] = {
            'max_correlation_lag': max_corr_lag,
            'correlation_value': correlations[max_corr_lag],
            # Smallest SSR chi-squared p-value across the tested lags.
            'granger_min_pvalue': min(granger_test[i][0]['ssr_chi2test'][1]
                                      for i in range(1, max_lags + 1)),
        }

    return lead_lag_results

# Render a fitted model's text summary to an image file
def create_model_summary_plot(model, title, filename):
    """Draw ``model.summary()`` as monospace text on an axis-free figure and save it.

    The image is written to ``filename`` at 300 dpi with a tight bounding box;
    the figure is closed afterwards so repeated calls do not pile up open figures.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.axis('off')
    summary_text = model.summary().as_text()
    ax.text(0.1, 0.95, summary_text,
            fontsize=10, family='monospace', verticalalignment='top')
    ax.set_title(title)
    fig.savefig(filename, bbox_inches='tight', dpi=300)
    plt.close(fig)

# Series to model; a filesystem-safe version of its name prefixes every output file
target_y = 'Innoviva Inc'
target_filename = target_y.translate(str.maketrans({' ': '_', '/': '_', '\\': '_'}))
output_dir = f'analysis_results_{target_filename}'
os.makedirs(output_dir, exist_ok=True)

# Score every candidate regressor against the target for up to 2 periods of lead/lag
lead_lag_results = calculate_lead_lag(df, target_y, max_lags=2)

# Initial model preparation: the target plus one (possibly shifted) copy of each
# candidate regressor scored by calculate_lead_lag.
df_analysis = pd.DataFrame()
df_analysis[target_y] = df[target_y]
used_base_columns = set()

for column, results in lead_lag_results.items():
    # Strip any "_lag_"/"_lead_" suffix so at most one variant of each
    # underlying series enters the model.
    base_column = column.split('_lag_')[0].split('_lead_')[0]
    if base_column in used_base_columns:
        continue

    lag = results['max_correlation_lag']
    p_value = results['granger_min_pvalue']

    # Shift only when the Granger test supports a relationship; otherwise (or
    # when the best alignment is lag 0) keep the series as-is. A NaN p-value
    # compares False and therefore also keeps the series unshifted.
    if p_value < 0.05 and lag > 0:
        df_analysis[f'{column}_lag_{lag}'] = df[column].shift(lag)
    elif p_value < 0.05 and lag < 0:
        # shift(lag) with a negative lag pulls future values back (a lead),
        # matching the alignment scored in calculate_lead_lag. shift(-lag)
        # would produce a lag stored under a "_lead_" name.
        df_analysis[f'{column}_lead_{abs(lag)}'] = df[column].shift(lag)
    else:
        df_analysis[column] = df[column]

    used_base_columns.add(base_column)

# Shifting leaves NaNs at the series edges; fill them with 0 (same convention as
# the load step). NOTE(review): zero-filling fabricates edge observations —
# dropping those rows may be statistically preferable.
df_analysis = df_analysis.fillna(0)

# Iterative model refinement: AIC-based bidirectional stepwise selection, then
# OLS on the selected variables; the least significant regressor is dropped and
# the process repeats until every remaining p-value is <= 0.05.
iteration = 1
current_data = df_analysis.copy()

while True:
    # Perform stepwise regression (result holds the target plus kept columns)
    both_aic_data = toad.selection.stepwise(current_data,
                                            target=target_y,
                                            estimator='ols',
                                            criterion='aic',
                                            direction='both')

    # Put the target first; the final heatmap uses this column order.
    both_aic_data = both_aic_data[[target_y] +
        [col for col in both_aic_data.columns if col != target_y]]

    # Fit model
    y = both_aic_data[target_y]
    X = sm.add_constant(both_aic_data.drop(columns=[target_y]))
    model = sm.OLS(y, X).fit()

    # Save current iteration summary
    create_model_summary_plot(model,
                              f'Model Summary - Iteration {iteration}',
                              os.path.join(output_dir, f'{target_filename}_model_summary_iteration_{iteration}.png'))

    # Drop the intercept by label: add_constant skips adding 'const' when a
    # regressor is already constant, so slicing off position 0 could discard a
    # real variable. A NaN p-value (degenerate fit) counts as 1.0 so that
    # variable is pruned instead of being silently ignored by max()/idxmax().
    p_values = model.pvalues.drop('const', errors='ignore').fillna(1.0)
    max_p_value = p_values.max()

    if p_values.empty:
        # Stepwise kept no regressors: nothing left to prune (idxmax would raise).
        print(f"\nNo explanatory variables selected after {iteration} iterations.")
        final_model = model
        break

    if max_p_value <= 0.05:
        print(f"\nAll p-values are below 0.05 after {iteration} iterations.")
        final_model = model
        break

    # Remove variable with highest p-value
    var_to_remove = p_values.idxmax()
    current_data = current_data.drop(columns=[var_to_remove])
    print(f"\nIteration {iteration}: Removed {var_to_remove} (p-value: {max_p_value:.4f})")

    iteration += 1

# Final visualizations
# 1) Correlation heatmap of the variables retained in the final model
heatmap_fig, heatmap_ax = plt.subplots(figsize=(12, 8))
sns.heatmap(both_aic_data.corr(), annot=True, cmap='coolwarm', center=0, ax=heatmap_ax)
heatmap_ax.set_title(f'Final Correlation Heatmap - {target_y}')
heatmap_fig.tight_layout()
heatmap_fig.savefig(os.path.join(output_dir, f'{target_filename}_final_correlation_heatmap.png'), dpi=300)
plt.close(heatmap_fig)

# 2) Fit quality: actual vs fitted on top, residuals below
perf_fig, (fit_ax, resid_ax) = plt.subplots(2, 1, figsize=(15, 10))
fit_ax.set_title(f'Original vs Predicted Values - {target_y}')
fit_ax.plot(y, label='Actual')
fit_ax.plot(final_model.fittedvalues, label='Predicted')
fit_ax.legend()

resid_ax.set_title('Residuals')
resid_ax.plot(final_model.resid)
resid_ax.axhline(y=0, color='r', linestyle='--')
perf_fig.tight_layout()
perf_fig.savefig(os.path.join(output_dir, f'{target_filename}_final_model_visualization.png'), dpi=300)
plt.close(perf_fig)

# Save final results to text file (explicit UTF-8 so non-ASCII target names
# do not depend on the platform's default encoding)
with open(os.path.join(output_dir, f'{target_filename}_final_analysis_results.txt'),
          'w', encoding='utf-8') as f:
    f.write(f"Final Analysis Results for {target_y}\n")
    f.write("=" * 50 + "\n\n")
    f.write(f"Number of iterations required: {iteration}\n\n")
    f.write("Final Model Summary:\n")
    f.write(final_model.summary().as_text())
    f.write("\n\nFinal Variables:\n")
    # Skip the intercept by name rather than assuming it is the first column
    # (add_constant does not add 'const' when a regressor is already constant).
    for var, var_p_value in final_model.pvalues.drop('const', errors='ignore').items():
        f.write(f"{var}: p-value = {var_p_value:.4f}\n")

print(f"\nFinal analysis results have been saved in directory: {output_dir}")
print("Files saved:")
print("1. Model summary for each iteration (PNG files)")
print(f"2. {target_filename}_final_correlation_heatmap.png")
print(f"3. {target_filename}_final_model_visualization.png")
print(f"4. {target_filename}_final_analysis_results.txt")