In [None]:
from clickhouse_driver import Client
import pandas as pd
import numpy as np
from statsmodels.tsa.stattools import coint
import seaborn as sns
import matplotlib.pyplot as plt
from datetime import datetime, timedelta

In [None]:
# Client for the ClickHouse server at localhost:9000.
# NOTE(review): credentials are hard-coded here; consider loading them from
# the environment or a config file before sharing this notebook.
client = Client(
    host='localhost',
    port=9000,
    user='default',
    password='password',
)

# Analysis window: begins at the earliest date present in stock_data and
# extends 48 * 30 days from there.
first_row = client.execute('SELECT MIN(date) FROM stock_data')[0]
start_date = first_row[0]
end_date = start_date + timedelta(days=48 * 30)

def get_sector_pairs(sector=None):
    """Load stock pairs from the ``stock_pairs`` table.

    Args:
        sector: Optional sector name. When truthy, only rows whose
            ``sector`` column equals this value are returned; when falsy
            (``None``, ``''``), every row is returned.

    Returns:
        A DataFrame with columns ``pair_key``, ``sector``, ``symbol1`` and
        ``symbol2``, one row per pair returned by the query.
    """
    query = '''
    SELECT pair_key, sector, symbol1, symbol2
    FROM stock_pairs
    '''
    params = None
    if sector:
        # Bind the value as a query parameter instead of splicing it into
        # the SQL text, so quotes or other special characters in `sector`
        # cannot alter the statement. str() keeps the comparison a string
        # comparison, as the original quoted literal did.
        query += " WHERE sector = %(sector)s"
        params = {'sector': str(sector)}
    rows = client.execute(query, params)
    return pd.DataFrame(
        rows,
        columns=['pair_key', 'sector', 'symbol1', 'symbol2'],
    )

def get_stock_data(symbols):
    """Fetch closing prices for ``symbols`` as a date-by-symbol table.

    Reads ``symbol``, ``date`` and ``close`` from ``stock_data`` for rows
    whose date lies between the module-level ``start_date`` and
    ``end_date`` (inclusive, per SQL ``BETWEEN``).

    Args:
        symbols: Iterable of ticker symbols. Each element is converted with
            ``str()`` before being sent to the server.

    Returns:
        A DataFrame indexed by ``date`` with one column per symbol holding
        the ``close`` values. An empty DataFrame (no rows, no columns) is
        returned when ``symbols`` is empty.

    Raises:
        ValueError: Propagated from ``DataFrame.pivot`` if the query returns
            more than one row for the same (date, symbol) combination.
    """
    # Materialize once: supports one-shot iterables and numpy arrays (whose
    # truth value is ambiguous), and yields the tuple form that the driver
    # renders as a parenthesized list suitable for IN.
    symbol_list = tuple(str(s) for s in symbols)
    if not symbol_list:
        # "IN ()" is not a meaningful query; skip the round-trip entirely.
        return pd.DataFrame(
            index=pd.Index([], name='date'),
            columns=pd.Index([], name='symbol'),
        )

    # All values are bound as parameters rather than formatted into the SQL
    # text, so symbols containing quotes cannot break or alter the query.
    # Dates are passed in their str() form, matching the literal the
    # string-built query used to send.
    query = '''
    SELECT symbol, date, close
    FROM stock_data
    WHERE symbol IN %(symbols)s
    AND date BETWEEN %(start)s AND %(end)s
    ORDER BY symbol, date
    '''
    params = {
        'symbols': symbol_list,
        'start': str(start_date),
        'end': str(end_date),
    }
    df = pd.DataFrame(
        client.execute(query, params),
        columns=['symbol', 'date', 'close'],
    )
    return df.pivot(columns='symbol', values='close', index='date')

# Load every pair (no sector filter) and collect the distinct tickers that
# appear on either side of a pair.
pairs_df = get_sector_pairs()
all_pair_symbols = pd.concat([pairs_df['symbol1'], pairs_df['symbol2']])
unique_symbols = all_pair_symbols.unique()
data = get_stock_data(unique_symbols)

# NOTE(review): find_cointegrated_pairs is not defined in the visible part
# of this notebook; it must be defined before this cell runs. Presumably it
# returns score and p-value matrices aligned with data.columns plus a list
# of symbol pairs -- confirm against its definition.
scores, pvalues, pairs = find_cointegrated_pairs(data)

# Heatmap of the p-values; cells with a p-value of at least 0.98 are masked
# out so they are not drawn.
plt.figure(figsize=(12, 8))
mask = pvalues >= 0.98
sns.heatmap(
    pvalues,
    mask=mask,
    cmap='RdYlGn_r',
    xticklabels=data.columns,
    yticklabels=data.columns,
)
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.title('Cointegration p-values heatmap')
plt.tight_layout()
plt.show()

# List the pairs reported by find_cointegrated_pairs, one per line.
print("\nCointegrated Pairs (p-value < 0.05):")
for pair in pairs:
    first, second = pair[0], pair[1]
    print(f"{first} - {second}")