# Correlation Testing

In [8]:

import os
from pricepredict import PricePredict
from datetime import datetime, timedelta
import pandas as pd

model_dir = '../models/'
chart_dir = '../charts/'
preds_dir = '../predictions/'
ppo_dir = '../ppo/'
ppo_save_dir = '../ppo_save/'

def read_ppos(symbols: [str] = None) -> dict:
    """
    Read the PPO objects from the ppo_dir
    """
    ret_ppos = {}
    for file in os.listdir(ppo_dir):
        # Check if filename has _D_ in it and ends with .dill
        if '_D_' in file and file.endswith('.dill'):
            # Get the symbol name from the file name (first chars before _D_)
            symbol = file.split('_D_')[0]
            if symbols is not None:
                if symbol not in symbols:
                    continue
            # Load the PPO object from the file
            with open(f'{ppo_dir}/{file}', 'rb') as f:
                pp_obj = f.read()
                # unserialize the PPO object
                ret_ppos[symbol] = PricePredict.unserialize(pp_obj)

    return ret_ppos

dally_ppos = read_ppos()
pass

In [None]:

def create_ppos(symbols: [str]):
    # Create a PricePredict object for each symbol
    ppos = {}
    for sym in symbols:
        ppo = PricePredict(sym, period=PricePredict.PeriodDaily,
                           model_dir=model_dir,
                           chart_dir=chart_dir,
                           preds_dir=preds_dir,)
        end_dt = datetime.now()
        # Load up over 5 years of data
        start_dt = end_dt - timedelta(days=365 * 5)
        end_date = end_dt.strftime('%Y-%m-%d')
        start_date = start_dt.strftime('%Y-%m-%d')

        # Fetch data for the ppo
        try:
            ppo.fetch_data_yahoo(ppo.ticker, start_date, end_date)
        except Exception as e:
            print(f'Error fetching data for {sym}')
            continue

        ppos[sym] = ppo

    return ppos


In [None]:

dally_ppos = read_ppos()

print(f'Loaded {len(dally_ppos)} daily PPO objects')
print(f'Daily Symbols: {dally_ppos.keys()}')



In [None]:
from tqdm import tqdm

# symbols = ['AAPL', '000001.SS', 'EURUSD=X', 'IBM', 'TSLA', 'SYK', 'RTX', 'QCOM', 'PACB', 'MDLZ']
# ppos = create_ppos(symbols)

all_symbols = sorted(dally_ppos.keys())
print(f'Loaded {len(dally_ppos)} daily PPO objects')
all_ptp = None
# Loop through the ppos sorted by key (symbol)
for pc_period in tqdm(range(7, 271, 2), "Pair Trading Period"):
    # print(f'.', end='', flush=True)
    sym1_pb = tqdm(all_symbols, f'Corr Period: {pc_period}', leave=False)
    for sym1 in all_symbols:
        # print(f'===== {ppos[symbol].ticker} =====')
        # Generate correlations between each symbol and all other symbols
        sym1_pb.update(1)
        sym2_pb = tqdm(all_symbols, f'{sym1}: Corr Period: {pc_period}', leave=False)
        for sym2 in all_symbols:
            if sym1 != sym2:
                # Get the corr between the two symbols
                    try:
                        corr = all_symbols[sym1].periodic_correlation(all_symbols[sym2], pc_period_len=pc_period)
                    except Exception as e:
                        print(f'Error calculating correlation between {sym1} and {sym2}\n{e}')
                        continue
    
                    if corr['coint_stationary']:
                        corr_dict = {'potential_pair': f'{sym1}:{sym2}',
                                     'corr_start_date': corr['start_date'], 'corr_end_date': corr['end_date'],
                                     'period_days': corr['corr_period_len'],
                                     'coint_stasn': corr['coint_stationary'],
                                     'coint_pval':  corr['coint_test']['p_val'],
                                     'adf_pval': corr['adf_test']['p_val']}
                        ptp = pd.DataFrame(corr_dict, index=[0])
                        if all_ptp is None:
                            all_ptp = ptp
                        else:
                            all_ptp = pd.concat([all_ptp, ptp])
                        print(corr_dict)
                    
            sym2_pb.update(1)
            
all_ptp


In [None]:
import inspect
dally_ppos['SEDG'].orig_downloaded_data.__len__()
dally_ppos['SEDG'].date_start, dally_ppos['SEDG'].date_end
# dally_ppos['AAPL'].fetch_data_yahoo('SEDG', '2020-12-31', '2021-01-01')


In [7]:
from pricepredict import PricePredict

sym1, sym2 = ('AMZN', 'XAF=F')
all_symbols[sym1].periodic_correlation(all_symbols[sym2], pc_period_len=pc_period)


TypeError: list indices must be integers or slices, not str