In [1]:
import pandas as pd
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from MongoDB.client import SyncDB

### Set Configs

In [2]:
# Please change the parameters here.
#
# corr_method selects the method passed to DataFrame.corr:
#   'pearson'  : standard correlation coefficient
#   'kendall'  : Kendall Tau correlation coefficient
#   'spearman' : Spearman rank correlation
# Only one method can be used per run (the chart shows a single value per series).
#
# stock_* settings describe the series drawn in the top panel, and data_* settings
# describe the series correlated against it (bottom panel):
#   *_collection : collection name passed to SyncDB
#   *_match      : base query filter (the date range is added automatically)
#   *_column     : document field to read
#   *_label      : display name (stock_label falls back to stock_column when empty)
# SmoothPeriod is the rolling-mean window (in rows) used for the data series' SMA.
# The SMA line is only drawn when SmoothPeriod > 1.

startDate = datetime(2019, 1, 1)
endDate = datetime(2020, 1, 1)
stock_collection = 'stock_daily'
stock_column = 'Close'
stock_label = 'SPY'
stock_match = {'symbol': 'SPY'}
data_collection = 'TA_Daily'
data_column = 'RSI'
data_label = 'RSI'
data_match = {'symbol': 'SPY'}

corr_method = 'pearson'

SmoothPeriod = 5

# Fail fast on invalid settings instead of deep inside the data/plot cells.
if corr_method not in ('pearson', 'kendall', 'spearman'):
    raise ValueError("corr_method must be 'pearson', 'kendall' or 'spearman', got %r" % (corr_method,))
if startDate >= endDate:
    raise ValueError('startDate must be earlier than endDate')
if SmoothPeriod < 1:
    raise ValueError('SmoothPeriod must be >= 1')

### Prepare data from stock_collection

In [3]:
# The query starts earlier than startDate so that the SmoothPeriod rolling mean
# is already populated on startDate. NOTE(review): the 1.4 factor presumably
# converts trading days to calendar days (weekends/holidays) — confirm.
dateRange = (endDate - startDate).days
modifiedStart = startDate - timedelta(days=int(SmoothPeriod*1.4)+1)

# Collections store the trading date under 'date' or 'Date', either as a
# datetime or as a 'YYYY-MM-DD' string. Try each layout in turn until one
# returns documents.
_start_str = modifiedStart.strftime('%Y-%m-%d')
_end_str = endDate.strftime('%Y-%m-%d')
data_list_1 = []
for _date_field, _lo, _hi in [('date', modifiedStart, endDate),
                              ('date', _start_str, _end_str),
                              ('Date', modifiedStart, endDate),
                              ('Date', _start_str, _end_str)]:
    query_1 = {**stock_match, _date_field: {'$gte': _lo, '$lte': _hi}}
    data_list_1 = list(SyncDB.find(stock_collection, query_1))
    if data_list_1:
        break

if not data_list_1:
    # Fallback for collections keyed by 'TradeTime': collapse the records to one
    # row per calendar day by summing stock_column within each day.
    # NOTE(review): summing suits volume-like fields; for a price such as
    # 'Close' this is probably not the intended daily value — confirm.
    query_1 = {**stock_match, 'TradeTime': {'$gte': modifiedStart, '$lte': endDate}}
    _total_field = 'total%s' % stock_column
    pipeline = [
        {'$match': query_1},
        {'$group': {
            '_id': {
                "year": {"$year": "$TradeTime"},
                "month": {"$month": "$TradeTime"},
                "day": {"$dayOfMonth": "$TradeTime"},
            },
            _total_field: {'$sum': '$%s' % stock_column},
        }},
    ]
    data_list_1 = [
        {'date': datetime(_rec['_id']['year'], _rec['_id']['month'], _rec['_id']['day']),
         stock_column: _rec[_total_field]}
        for _rec in SyncDB.aggregate(stock_collection, pipeline)
    ]

# Every layout came back empty: stop here with a clear message instead of an
# IndexError on data_list_1[0] below.
if not data_list_1:
    raise ValueError('No %r data in %r matching %r between %s and %s'
                     % (stock_column, stock_collection, stock_match,
                        modifiedStart.date(), endDate.date()))

index_1 = 'date' if 'date' in data_list_1[0] else 'Date'
cols = [index_1, stock_column]
df = pd.DataFrame(data_list_1)[cols].copy()
# The date field may arrive as datetimes or as strings; normalise it to datetime64.
if not pd.api.types.is_datetime64_any_dtype(df[index_1]):
    df[index_1] = pd.to_datetime(df[index_1])
df = df.set_index(index_1).sort_index()

### Prepare data from data_collection

In [4]:
def add_collection_to_df(collection, query, column, label, df):
    """Fetch `column` from `collection` and left-join it, plus its SMA, onto `df`.

    The date range is [modifiedStart, endDate] (notebook globals). The date
    layouts are tried in this order, stopping at the first that returns documents:
    'date' as datetime, 'date' as 'YYYY-MM-DD' string, 'Date' as datetime,
    'Date' as string.

    Parameters
    ----------
    collection : str
        Collection name passed to SyncDB.find.
    query : dict
        Base filter (e.g. {'symbol': 'SPY'}). It is copied, never modified.
    column : str
        Document field to extract.
    label : str
        Column name used instead of `column` when `column` (or its SMA column)
        already exists in `df`.
    df : pandas.DataFrame
        Date-indexed frame to join onto. It is not modified.

    Returns
    -------
    (str, pandas.DataFrame)
        The column name actually used, and a new frame with that column and
        '<name>SMA' (a SmoothPeriod-row rolling mean) left-joined on the index.

    Raises
    ------
    ValueError
        If no date layout returns any documents.
    """
    start_str = modifiedStart.strftime('%Y-%m-%d')
    end_str = endDate.strftime('%Y-%m-%d')
    attempts = [('date', modifiedStart, endDate),
                ('date', start_str, end_str),
                ('Date', modifiedStart, endDate),
                ('Date', start_str, end_str)]
    data_list = []
    for date_field, lo, hi in attempts:
        # Build a fresh filter for each attempt. Mutating `query` would leave the
        # previous attempt's date key in place, so the 'Date' queries would also
        # require 'date' and never match. It would also alter the caller's dict.
        data_list = list(SyncDB.find(collection, {**query, date_field: {'$gte': lo, '$lte': hi}}))
        if data_list:
            break
    if not data_list:
        raise ValueError('No %r data in %r matching %r between %s and %s'
                         % (column, collection, query, modifiedStart.date(), endDate.date()))

    index = 'date' if 'date' in data_list[0] else 'Date'
    df_n = pd.DataFrame(data_list)[[index, column]].copy()
    # The date field may arrive as datetimes or as strings; normalise it to datetime64.
    if not pd.api.types.is_datetime64_any_dtype(df_n[index]):
        df_n[index] = pd.to_datetime(df_n[index])
    df_n = df_n.set_index(index).sort_index()

    # Avoid clobbering an existing column of `df` (e.g. when the data field has
    # the same name as the stock field) by falling back to the caller's label.
    if column in df.columns or (column + 'SMA') in df.columns:
        col_new = label
    else:
        col_new = column
    df_n.columns = [col_new]
    # The SMA is computed on the data collection's own dates before joining.
    sma = df_n[col_new].rolling(SmoothPeriod).mean().rename(col_new + 'SMA')
    return col_new, df.join(df_n[col_new], how='left').join(sma, how='left')
    
# Join the data series (and its SMA) onto the stock frame. The result gets its
# own name so `df` stays the stock-only frame, and re-running this cell does not
# join the data column a second time.
data_column, df_merged = add_collection_to_df(data_collection, data_match, data_column, data_label, df)
# Keep only the requested window and drop rows with any missing value
# (e.g. SMA warm-up rows, or dates missing from one of the series).
df_to_plot = df_merged[(df_merged.index <= endDate) & (df_merged.index >= startDate)].dropna(how='any')
if df_to_plot.empty:
    raise ValueError('No overlapping %s / %s rows between %s and %s'
                     % (stock_column, data_column, startDate.date(), endDate.date()))

In [5]:
# Correlate the stock series with (1) the raw data series and (2) its SMA.
# Each coefficient uses only the rows where both values are present.
corr_values = []
for _other_col in (data_column, data_column + 'SMA'):
    _pair = df_to_plot[[stock_column, _other_col]].dropna(how='any')
    corr_values.append(_pair.corr(method=corr_method).iat[0, 1])

In [None]:
# Plot results
# Two stacked panels (26 x 15 inches) sharing the date axis: the stock series
# on top (ax[0]) and the data series below (ax[1]).
fig, ax = plt.subplots(2, 1, figsize=(26, 15), sharex='col')
fig.subplots_adjust(hspace=0)  # no vertical gap between the panels

# Top panel: y-axis label plus semi-transparent watermark text drawn behind the data (zorder=0).
ax[0].set_ylabel(stock_column)
ax[0].text(0.02, 0.78, 'megapro.com', horizontalalignment='left', color='gray', alpha=0.4,
           verticalalignment='center', rotation=0, fontsize=25, transform=ax[0].transAxes, zorder=0)

ax[0].text(0.02, 0.62, 'Join Discord: mCmMjSRuBn', horizontalalignment='left', color='gray', alpha=0.4,
           verticalalignment='center', rotation=0, fontsize=25, transform=ax[0].transAxes, zorder=0)

# Fall back to the column name when no display label was configured.
if not stock_label:
    stock_label = stock_column

ax[0].plot(df_to_plot.index, df_to_plot[stock_column], color='blue', label=stock_label)
# Annotate the most recent value of the stock series.
ax[0].text(0.08, 0.9, '%s: %.4f' % (stock_label, df_to_plot[stock_column].iat[-1]),
           horizontalalignment='left', color='blue', verticalalignment='center', fontsize=10, transform=ax[0].transAxes)
        
def plot_data_panel(inx, label, col, corr_values):
    """Draw the data series `col` (and its SMA) with annotations on panel ax[inx].

    Reads the notebook globals `ax`, `df_to_plot` and `SmoothPeriod`.

    Parameters
    ----------
    inx : int
        Index of the subplot in `ax` to draw on.
    label : str
        Display name for the series; `col` is used when it is empty.
    col : str
        Column of `df_to_plot` to plot; the SMA is read from `col + 'SMA'`.
    corr_values : sequence of float
        [correlation with the raw series, correlation with its SMA].
    """
    panel = ax[inx]
    display_name = label or col
    series_label = '[DATA]' + display_name
    panel.set_ylabel(display_name)

    panel.plot(df_to_plot.index, df_to_plot[col], color='orange', label=series_label)
    panel.text(0.2, 0.9, '%s: %.4f' % (series_label, df_to_plot[col].iat[-1]),
               horizontalalignment='left', color='black', verticalalignment='center',
               fontsize=10, transform=panel.transAxes)
    panel.text(0.9, 0.9, 'Correlation: %.4f' % corr_values[0],
               horizontalalignment='right', color='orange', verticalalignment='center',
               fontsize=15, transform=panel.transAxes)

    # The SMA line and its correlation are only drawn when smoothing is enabled.
    if SmoothPeriod <= 1:
        return
    sma_col = col + 'SMA'
    sma_label = series_label + ' SMA %d' % SmoothPeriod
    panel.plot(df_to_plot.index, df_to_plot[sma_col], color='magenta', label=sma_label)
    panel.text(0.4, 0.9, '%s: %.4f' % (sma_label, df_to_plot[sma_col].iat[-1]),
               horizontalalignment='left', color='black', verticalalignment='center',
               fontsize=10, transform=panel.transAxes)
    panel.text(0.9, 0.8, '%.4f' % corr_values[1],
               horizontalalignment='right', color='magenta', verticalalignment='center',
               fontsize=15, transform=panel.transAxes)
      
# Bottom panel: the data series, its SMA and the correlation annotations.
inx = 1
plot_data_panel(inx, data_label, data_column, corr_values)

# Shared styling for both panels: dotted vertical grid, tight x margins, a legend
# in the upper-left corner, and date ticks thinned to roughly 80 over the range.
# DayLocator rejects an interval below 1, so windows shorter than 80 days fall
# back to one tick per day instead of raising.
tick_interval = max(1, int(dateRange / 80))
for axi in ax:
    axi.grid(axis="x", color='grey', linestyle=':', linewidth=0.75)
    axi.set_xmargin(0.005)
    axi.legend(loc=[0.002, 0.85], prop={'size': 10})
    axi.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    axi.xaxis.set_major_locator(mdates.DayLocator(interval=tick_interval))

fig.autofmt_xdate(rotation=90)

# Title, with room reserved at the top of the figure for it.
pic_title = 'Megapro Chart %s-%s Correlation Study\n%s-%s' % (stock_label, data_label, startDate.strftime('%Y%m%d'), endDate.strftime('%Y%m%d'))
fig.suptitle(pic_title, fontsize=30, y=0.98)
fig.subplots_adjust(top=0.88)

# Save this figure explicitly rather than whatever pyplot considers "current".
filename = ('%s_%s_Corr_%s-%s.png' % (stock_label, data_label, startDate.strftime('%y%m%d'), endDate.strftime('%y%m%d'))).replace(' ', '')
fig.savefig(filename)

# NOTE(review): 'yourusername' looks like a placeholder — replace it with the real account name.
url = 'https://jbook123456.megagurus.net/user/yourusername/view/%s' % filename
print(url)