# Efficient Cointegration Analysis for Statistical Arbitrage

This notebook demonstrates an optimized approach to identifying stock pairs for statistical arbitrage, with a focus on efficiently processing large-scale financial datasets.

## Problem Statement

In the context of financial markets, cointegration refers to a statistical relationship between two or more time series where the series move together over the long term, despite short-term deviations. This is particularly valuable in pairs trading strategies, where a stable, mean-reverting relationship between two stocks can be exploited for profit. However, analyzing cointegration on a massive dataset—such as 1 billion rows of 1-minute data across 6,000 stocks—presents significant computational challenges, especially when the tests span multiple years.
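One common way to make this definition precise is the Engle-Granger two-step formulation, which is the test that the `coint` function from statsmodels (used later in this notebook) implements. The notation below (y_t and x_t for the two price series, alpha, beta, gamma, phi_i, and the lag order p) is chosen here for illustration and does not appear in the notebook.

```latex
% Step 1: long-run regression between the two price series
y_t = \alpha + \beta x_t + \varepsilon_t

% Step 2: augmented Dickey-Fuller regression on the fitted residuals
\Delta \hat{\varepsilon}_t = \gamma \hat{\varepsilon}_{t-1}
    + \sum_{i=1}^{p} \phi_i \, \Delta \hat{\varepsilon}_{t-i} + u_t

% Hypotheses: the null is "no cointegration"
H_0 : \gamma = 0 \qquad \text{vs.} \qquad H_1 : \gamma < 0
```

The pair is treated as cointegrated when the t-statistic on gamma is more negative than the chosen critical value, that is, when the residual spread is mean-reverting rather than a random walk.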

## Solution

### Data Filtering

To address the computational demands, the dataset was first filtered to keep only stocks that meet specific criteria: a price range, a minimum volume, and sufficient historical data. A stock is dropped if any single bar in the test window falls outside the price range or below the volume threshold, so the filter is strict. Parallelized with multiprocessing, this step quickly reduced the universe (in the run below, from 6,116 symbols to 16), ensuring that only liquid stocks proceeded to the next stage of analysis.
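The filtering rule can be illustrated without a database. The sketch below is a minimal, standard-library version of the same idea; the `Bar` container and the `passes_filter` function are hypothetical names introduced for illustration, and the default thresholds simply mirror the constants used later in this notebook.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass(frozen=True)
class Bar:
    """One OHLCV bar (hypothetical container used only in this sketch)."""
    open: float
    high: float
    low: float
    close: float
    volume: float


def passes_filter(bars: Iterable[Bar], min_price: float = 10.0,
                  max_price: float = 300.0, min_volume: float = 2000,
                  min_rows: int = 100) -> bool:
    """Return True if a symbol's bars satisfy every filter criterion."""
    bars = list(bars)
    # Too little history: reject before looking at prices.
    if len(bars) < min_rows:
        return False
    for bar in bars:
        # Average of the four OHLC prices for this bar.
        avg_price = (bar.open + bar.high + bar.low + bar.close) / 4
        # A single bar outside the price range rejects the symbol.
        if not (min_price <= avg_price <= max_price):
            return False
        # A single bar below the volume threshold rejects the symbol.
        if bar.volume < min_volume:
            return False
    return True
```

Note that this sketch returns `True` for symbols that pass, whereas the notebook's own helper returns `True` for symbols that should be excluded.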

### Efficient Cointegration Testing

The Cointegrated Augmented Dickey-Fuller (CADF) test was applied to pairs of filtered stocks to identify those with a cointegrated relationship, which is the property pairs trading relies on. Instead of testing every possible window, the solution samples 10,000 randomly chosen windows per pair and runs the CADF test on each. The cointegration score for a pair is the fraction of sampled windows in which the test statistic falls below the 5% critical value, so it reflects how consistently the relationship holds over time rather than whether it holds in one particular period.
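The sampling and scoring logic can be separated from the statistical test itself. The following is a minimal sketch under stated assumptions: `sample_window_starts` and `cointegration_ratio` are hypothetical helper names, and `is_cointegrated` is a caller-supplied function standing in for the CADF decision (for example, a wrapper around statsmodels' `coint`). Unlike the notebook code, it does not skip windows in which a series is constant.

```python
import random
from typing import Callable, List, Optional, Sequence


def sample_window_starts(n_rows: int, lookback: int, n_periods: int,
                         seed: Optional[int] = None) -> List[int]:
    """Pick distinct random start offsets for windows of `lookback` rows."""
    # Number of positions at which a full window still fits in the data.
    n_windows = n_rows - lookback + 1
    if n_windows <= 0:
        return []
    rng = random.Random(seed)
    # Cap the sample size so sampling without replacement cannot fail.
    return rng.sample(range(n_windows), min(n_periods, n_windows))


def cointegration_ratio(y: Sequence[float], x: Sequence[float],
                        lookback: int, n_periods: int,
                        is_cointegrated: Callable[[Sequence[float],
                                                   Sequence[float]], bool],
                        seed: Optional[int] = None) -> float:
    """Fraction of sampled windows that the supplied test flags."""
    starts = sample_window_starts(len(y), lookback, n_periods, seed)
    if not starts:
        return float("nan")
    hits = sum(
        bool(is_cointegrated(y[s:s + lookback], x[s:s + lookback]))
        for s in starts
    )
    return hits / len(starts)
```

Passing a fixed `seed` makes the sampled windows reproducible, which helps when comparing scores between runs.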

Additionally, multiprocessing was employed to run the analysis concurrently across multiple pairs, significantly reducing the overall processing time.
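The fan-out over pairs follows a simple pattern: enumerate every unordered pair once, submit one task per pair, and collect the results. The sketch below shows that pattern with the standard library only; `all_pairs` and `run_over_pairs` are hypothetical names, and the `worker` argument is a placeholder for whatever per-pair analysis is run. The executor class is a parameter so that a thread pool can be substituted for quick checks.

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from itertools import combinations
from typing import Any, Callable, Dict, List, Sequence, Tuple

Pair = Tuple[str, str]


def all_pairs(symbols: Sequence[str]) -> List[Pair]:
    """Every unordered pair of symbols, each listed exactly once."""
    return list(combinations(symbols, 2))


def run_over_pairs(pairs: Sequence[Pair],
                   worker: Callable[[Pair], Any],
                   executor_cls=ProcessPoolExecutor,
                   max_workers: int = 4) -> Dict[Pair, Any]:
    """Run `worker` on each pair concurrently and map pair -> result."""
    results: Dict[Pair, Any] = {}
    with executor_cls(max_workers=max_workers) as executor:
        # Submit everything first so the pool stays busy.
        futures = {executor.submit(worker, pair): pair for pair in pairs}
        # Collect in submission order; result() blocks until each is done.
        for future, pair in futures.items():
            results[pair] = future.result()
    return results
```

With `ProcessPoolExecutor`, the worker must be a picklable top-level function. For n symbols there are n(n-1)/2 pairs, so the number of tasks grows quadratically with the size of the filtered universe.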

## Impact

The approach effectively manages large-scale datasets, achieving a balance between computational efficiency and rigorous analysis. This solution is particularly suitable for developing high-performance trading algorithms that require real-time processing and decision-making in a production environment.


## 1. Stock Pair Filtering

In [4]:
import pandas as pd
from concurrent.futures import ProcessPoolExecutor
from tqdm.notebook import tqdm
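import psycopg2
from psycopg2 import sql

def fetch_data_from_db(conn_string, symbol, start_date, end_date, table):
    """
    Fetches minute data for one symbol from a specified table and date range.

    Returns a DataFrame indexed by timestamp (empty if the query fails).
    """
    # Pass the table name as a quoted identifier instead of formatting it
    # into the SQL text; the filter values stay as bound parameters.
    query = sql.SQL("""
        SELECT symbol, timestamp, open, high, low, close, volume
        FROM {table}
        WHERE symbol = %s
        AND timestamp BETWEEN %s AND %s
        ORDER BY timestamp;
    """).format(table=sql.Identifier('public', table))
    columns = ['symbol', 'timestamp', 'open', 'high', 
               'low', 'close', 'volume']

    try:
        conn = psycopg2.connect(conn_string)
        try:
            with conn.cursor() as cur:
                cur.execute(query, (symbol, start_date, end_date))
                data = pd.DataFrame(cur.fetchall(), columns=columns)
        finally:
            # psycopg2's connection context manager does not close the
            # connection, so close it explicitly
            conn.close()

        data['timestamp'] = pd.to_datetime(data['timestamp'])
        data.set_index('timestamp', inplace=True)
        data = data.drop('symbol', axis=1)

        return data

    except Exception as e:
        print(f"Error fetching data: {e}")
        return pd.DataFrame()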

def has_price_out_of_range_or_low_volume(
    symbol, start_date, end_date, conn_string, table_name, 
    min_stock_price=10, max_stock_price=1000, min_volume=5000, 
    min_data_length=1000):
    """
    Returns True if the symbol should be excluded: it has too few rows, or
    any bar has an average price outside the range or volume below the minimum.
    """
    
    data = fetch_data_from_db(conn_string, symbol, start_date, 
                              end_date, table_name)
    if data.empty or len(data) < min_data_length:
        return True 

    data['avg_price'] = (data['open'] + data['high'] + data['low'] 
                         + data['close']) / 4
    is_price_out_of_range = (data['avg_price'] < min_stock_price).any() \
                            or (data['avg_price'] > max_stock_price).any()
    has_low_volume = (data['volume'] < min_volume).any()

    return is_price_out_of_range or has_low_volume

# Load symbols from a CSV file
symbols = pd.read_csv(
    '/home/jj/projects/algo_trading/chapter-strategy-optimisation/'
    'data/symbols_1m.csv'
)
symbols = list(symbols['Symbol'].values)

# Define constants
CONN_STRING = ("host='192.168.3.41' dbname='proxima' user='airflow' "
               "password='airflow' port='5432'")
TABLE = 'data_bars_1min_adj_splitdiv'
START_DATE = '2024-01-01'
END_DATE = '2024-01-15'
MIN_STOCK_PRICE = 10
MAX_STOCK_PRICE = 300
MIN_VOLUME = 2000  # Minimum volume to ensure liquidity
MIN_DATA_LENGTH = 100  # Minimum number of rows of data required

# Process the symbols in parallel; partial() binds the shared arguments
# so that only the symbol varies from task to task
from functools import partial

check_symbol = partial(
    has_price_out_of_range_or_low_volume,
    start_date=START_DATE, end_date=END_DATE,
    conn_string=CONN_STRING, table_name=TABLE,
    min_stock_price=MIN_STOCK_PRICE, max_stock_price=MAX_STOCK_PRICE,
    min_volume=MIN_VOLUME, min_data_length=MIN_DATA_LENGTH)

with ProcessPoolExecutor(max_workers=32) as executor:
    price_and_volume_flags = list(tqdm(
        executor.map(check_symbol, symbols), 
        total=len(symbols), 
        desc="Checking for price out of range or low volume"
    ))

# Create a DataFrame to store the results
results_df = pd.DataFrame({
    'Symbol': symbols, 
    'Has Price Out of Range or Low Volume': price_and_volume_flags
})

# Filter symbols based on the price range and volume flag
filtered_symbols = list(results_df[
    results_df['Has Price Out of Range or Low Volume'] == False]['Symbol']
)
print(len(filtered_symbols))
print(filtered_symbols)


Checking for price out of range or low volume:   0%|          | 0/6116 [00:00<?, ?it/s]

16
['AAPL', 'AMD', 'AMZN', 'BABA', 'BAC', 'F', 'GOOG', 'GOOGL', 'INTC', 'MARA', 'PFE', 'PLTR', 'PYPL', 'TSLA', 'UBER', 'XOM']


## 2. Cointegration Analysis of Stock Pairs


In [3]:
import pandas as pd
import numpy as np
from statsmodels.tsa.stattools import coint
from datetime import datetime as dt, timedelta
import matplotlib.pyplot as plt
import sys
import random
from tqdm.notebook import tqdm
from concurrent.futures import ProcessPoolExecutor
import ast
sys.path.append(
    '/home/jj/anaconda3/envs/stocks/Dropbox/Code/Notebooks/lib/'
)
from data_fetcher import fetch_data_from_db

def fetch_data(symbol, start_date, end_date, conn_string, table_name):
    """Fetch data from the database."""
    return fetch_data_from_db(
        conn_string, 
        symbol, 
        start_date, 
        end_date, 
        table_name
    )

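def cadf_test(y, x):
    """
    Perform the cointegration test (statsmodels `coint`).

    Returns the test statistic and the critical values at the 1%, 5%, 
    and 10% levels, in that order.
    """
    # Unpack directly instead of reusing the function's own name
    cadf_stat, _p_value, cadf_critical_values = coint(y, x)
    return cadf_stat, cadf_critical_values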

def analyze_pair(pair, start_date, end_date, lookback_period_rows, 
                 conn_string, table_name, n_periods):
    """Analyze a single pair of symbols for cointegration."""
    try:
        symbol1, symbol2 = pair
        stock_data1 = fetch_data(symbol1, start_date, end_date, 
                                 conn_string, table_name)
        stock_data2 = fetch_data(symbol2, start_date, end_date, 
                                 conn_string, table_name)

        # Make sure both indexes are datetimes so the series align on timestamp
        stock_data1.index = pd.to_datetime(stock_data1.index)
        stock_data2.index = pd.to_datetime(stock_data2.index)

        # Concatenate data for the pair into a single DataFrame
        df = pd.concat([stock_data1["close"].rename(symbol1), 
                        stock_data2["close"].rename(symbol2)], axis=1)

        # Handle missing data by forward-filling and then backward-filling
        df.ffill(inplace=True)
        df.bfill(inplace=True)

        # Limit the DataFrame to the specified date range
        df = df.loc[start_date:end_date]

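        report_list = []
        annual_scores = {}
        # Number of valid window start positions; cap the sample size at
        # that number so random.sample cannot raise ValueError on short data
        n_windows = max(len(df) - lookback_period_rows + 1, 0)
        start_indices = random.sample(range(n_windows), 
                                      min(n_periods, n_windows))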

        for start_index in start_indices:
            end_index = start_index + lookback_period_rows
            window_df = df.iloc[start_index:end_index]
            y = window_df[symbol1]
            x = window_df[symbol2]

            # Check if either series is constant
            if y.nunique() == 1 or x.nunique() == 1:
                continue  # Skip this sample if any series is constant
            
            # Perform the CADF test and compare the statistic with the 
            # 5% critical value (index 1 of the 1%/5%/10% values).
            cadf_stat, cadf_critical_values = cadf_test(y, x)
            is_cointegrated = cadf_stat < cadf_critical_values[1]
            
            # Use separate names so the function's start_date/end_date 
            # arguments are not overwritten inside the loop
            window_start = window_df.index[0]
            window_end = window_df.index[-1]

            report_list.append({
                'Pair': [symbol1, symbol2],
                'Start Date': window_start, 
                'End Date': window_end, 
                'CADF Statistic': cadf_stat,
                'Critical Value (5%)': cadf_critical_values[1],
                'Cointegrated': is_cointegrated
            })

            year = window_start.year
            if year not in annual_scores:
                annual_scores[year] = {'cointegrated': 0, 'total': 0}
            annual_scores[year]['total'] += 1
            if is_cointegrated:
                annual_scores[year]['cointegrated'] += 1

        annual_report_list = []
        for year in annual_scores:
            annual_score = annual_scores[year]['cointegrated'] / \
                           annual_scores[year]['total']
            annual_report_list.append({
                'Pair': [symbol1, symbol2],
                'Year': year,
                'Cointegration Ratio': annual_score
            })
        
        return annual_report_list
    except Exception as e:
        print(f"Error processing pair {pair}: {e}")
        return []

def analyze_cointegration(symbols, start_date, end_date, 
                          lookback_period_rows, conn_string, 
                          table_name, n_periods, max_workers=32):
    """Analyze cointegration for all pairs of symbols."""
    symbol_pairs = [(symbols[i], symbols[j]) for i in range(len(symbols)) 
                    for j in range(i + 1, len(symbols))]

    annual_report_list = []

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(analyze_pair, pair, start_date, end_date, 
                                   lookback_period_rows, conn_string, 
                                   table_name, n_periods) 
                   for pair in symbol_pairs]
        for future in tqdm(futures, total=len(symbol_pairs)):
            annual_report_list.extend(future.result())

    annual_report = pd.DataFrame(annual_report_list)

    return annual_report

if __name__ == "__main__":
    CONN_STRING = ("host='192.168.3.41' dbname='proxima' user='airflow' "
                   "password='airflow' port='5432'")
    TABLE1M = 'data_bars_1min_adj_splitdiv'
    START_DATE = dt(2023, 1, 1)
    END_DATE = dt(2023, 12, 1)
    LOOKBACK_PERIOD_ROWS = 25  # rows (1-minute bars) in each CADF test window
    N_PERIODS = 10000  # Number of random samples to analyze
    MAX_WORKERS = 32

    annual_report = analyze_cointegration(filtered_symbols[:], START_DATE, 
                                          END_DATE, LOOKBACK_PERIOD_ROWS, 
                                          CONN_STRING, TABLE1M, N_PERIODS, 
                                          MAX_WORKERS)
    annual_report = annual_report.sort_values(by='Cointegration Ratio', 
                                              ascending=False)
    
    # Convert the 'Pair' column to string format before saving to CSV
    annual_report['Pair'] = annual_report['Pair'].apply(str)
    annual_report.to_csv('./cointegrated_pairs_2.csv', index=False)
   
    # To read the DataFrame back with 'Pair' as list
    def read_cointegrated_pairs(file_path):
        df = pd.read_csv(file_path)
        df['Pair'] = df['Pair'].apply(ast.literal_eval)
        return df

    annual_report_read = read_cointegrated_pairs('./cointegrated_pairs_2.csv')
    display(annual_report_read)


  0%|          | 0/561 [00:00<?, ?it/s]

|     | Pair          | Year | Cointegration Ratio |
|-----|---------------|------|---------------------|
| 0   | [GOOG, GOOGL] | 2023 | 0.300600            |
| 1   | [T, VZ]       | 2023 | 0.141100            |
| 2   | [MARA, RIOT]  | 2023 | 0.132800            |
| 3   | [F, PLTR]     | 2023 | 0.126500            |
| 4   | [F, SQ]       | 2023 | 0.123500            |
| ... | ...           | ...  | ...                 |
| 556 | [MARA, XOM]   | 2023 | 0.077362            |
| 557 | [TSLA, WBA]   | 2023 | 0.076954            |
| 558 | [TSLA, XOM]   | 2023 | 0.076214            |
| 559 | [AMD, VZ]     | 2023 | 0.075000            |
