# In [None]:
import polars as pl
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from datetime import datetime, timedelta, date
import pickle

# Ticker-style identifiers for the country universe (presumably country-index
# ETF symbols — TODO confirm).
# NOTE(review): nothing visible in this section reads this list; the pairwise
# run at the bottom of the file is invoked with ("EUSA", "EWC") only.
Countries_indexes = ["EUSA", "EWC", "EWU", "EWG", "EWQ", "EWJ", "MCHI", "INDA", "EWA", "EWY", "EWW", "EWL", "EWT", "EWH", "EWS", "EWI", "EWP", "EWN", "EWD", "EWO", "EWK", "EDEN", "EFNL", "EIS", "EWZ"]
# Compute returns function in Polars
def compute_returns(df, price_col="weighted-avg-price", lag=1):
    """
    Append a simple-return column derived from a price column.

    Parameters:
    - df: Polars DataFrame holding the price series
    - price_col: Name of the price column
    - lag: Number of rows separating the two prices being compared

    Returns:
    Polars DataFrame with an extra "return" column equal to
    (price[t] - price[t - lag]) / price[t - lag]. Rows whose return is null
    (e.g. the leading rows that have no lagged price) are dropped.
    """
    price = pl.col(price_col)
    simple_return = (price.diff(lag) / price.shift(lag)).alias("return")
    with_returns = df.with_columns(simple_return)
    return with_returns.drop_nulls(subset=["return"])


def rolling_cross_correlation(df1, df2, column1, column2, window_length, delta_step=5):
    """
    Compute rolling cross-correlation and covariance between two DataFrames.

    Rows are paired by position (row k of df1 with row k of df2), so both
    frames must already be aligned on the same observations. When the two
    columns differ in length, only the first min(len, len) rows are used.

    Parameters:
    - df1, df2: Polars DataFrames
    - column1, column2: Column names to correlate
    - window_length: Size of the rolling window in rows (must be >= 2)
    - delta_step: Sliding window step size in rows (must be >= 1)

    Returns:
    Polars DataFrame with one row per window and the columns
    "window_start" (row offset of the window), "correlation" and
    "covariance". The frame is empty, but keeps the same columns and
    dtypes, when fewer than window_length paired rows are available.
    NaN values produced by NumPy (e.g. a zero-variance window) are kept.

    Raises:
    ValueError: if window_length < 2 or delta_step < 1.
    """
    # Validate before touching the data so bad arguments fail loudly instead
    # of yielding NaN-only or silently empty results.
    if window_length < 2:
        raise ValueError(
            f"window_length must be at least 2, got {window_length}"
        )
    if delta_step < 1:
        raise ValueError(f"delta_step must be at least 1, got {delta_step}")

    # Convert to NumPy for correlation computation
    arr1 = df1[column1].to_numpy()
    arr2 = df2[column2].to_numpy()
    n_common = min(len(arr1), len(arr2))

    # Slide window with specified delta step
    window_starts = list(range(0, n_common - window_length + 1, delta_step))
    correlations = []
    covariances = []
    for start in window_starts:
        end = start + window_length
        window1 = arr1[start:end]
        window2 = arr2[start:end]

        # Off-diagonal entry of the 2x2 matrix is the cross term.
        correlations.append(np.corrcoef(window1, window2)[0, 1])
        covariances.append(np.cov(window1, window2)[0, 1])

    # Pin the dtypes: with no windows the lists are empty and polars could
    # not infer numeric column types from them.
    return pl.DataFrame(
        {
            "window_start": window_starts,
            "correlation": correlations,
            "covariance": covariances,
        },
        schema={
            "window_start": pl.Int64,
            "correlation": pl.Float64,
            "covariance": pl.Float64,
        },
    )

#  Function to compute rolling correlation for all pairs
def compute_all_pairs_rolling_correlation(countries, window_length=60, delta_step=5):
    """
    Compute the rolling correlation table for every unordered pair of countries.

    Parameters:
    - countries: Sequence (list or tuple) of identifiers; each one is read
      from "Data/clean/<identifier>.parquet"
    - window_length: Size of the rolling window, forwarded to
      rolling_cross_correlation
    - delta_step: Sliding window step size, forwarded to
      rolling_cross_correlation

    Returns:
    Dict mapping (country1, country2) tuples, with country1 appearing before
    country2 in `countries`, to the DataFrame returned by
    rolling_cross_correlation for that pair. NaN values are kept as-is.
    """
    # Each country takes part in many pairs; read its file and compute its
    # returns only once (on first use) instead of once per pair. The cached
    # frames stay in memory until this function returns.
    returns_by_country = {}

    def load_returns(country):
        if country not in returns_by_country:
            prices = pl.read_parquet(f"Data/clean/{country}.parquet")
            returns_by_country[country] = compute_returns(prices)
        return returns_by_country[country]

    results = {}
    for i, country1 in enumerate(tqdm(countries)):
        for country2 in countries[i + 1:]:
            print(f"Computing rolling correlation for {country1} and {country2}...")
            results[(country1, country2)] = rolling_cross_correlation(
                load_returns(country1),
                load_returns(country2),
                column1="return",
                column2="return",
                window_length=window_length,
                delta_step=delta_step,
            )

    return results

# Compute rolling correlation (60-row windows, step 5) for the single
# ("EUSA", "EWC") pair.
# NOTE(review): only these two identifiers are passed, not Countries_indexes —
# confirm whether the full list was intended.

all_pairs_rolling_corr = compute_all_pairs_rolling_correlation(("EUSA","EWC"), window_length=60, delta_step=5)
print(all_pairs_rolling_corr)
# Persist the {(country1, country2): DataFrame} mapping to a pickle file in
# the current working directory.
with open('all_pairs_rolling_corr.pkl', 'wb') as f:
    pickle.dump(all_pairs_rolling_corr, f)
