# MFHT Grid Plot
# Import pyplot here as well: this cell runs before the import cell, and
# `plt` would otherwise be undefined on a fresh kernel.
import matplotlib.pyplot as plt

# Start from a clean slate: close any figures left over from a previous run.
plt.close('all')

In [None]:
import sqlite3

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from tqdm import tqdm

from scipy.optimize import curve_fit

from stabilvol.utility import functions as f

import os

DATABASE = '../data/processed/trapezoidal_selection/stabilvol.sqlite'

# sqlite3.connect() silently creates an empty database file when the path does
# not exist. Fail loudly here instead of later with "no such table" errors.
if not os.path.isfile(DATABASE):
    raise FileNotFoundError(
        f"SQLite database not found at {DATABASE!r} (cwd: {os.getcwd()})"
    )

# Connect to the SQLite database
conn = sqlite3.connect(DATABASE)
cur = conn.cursor()

In [None]:
import os
# Sanity check: DATABASE is a relative path, so show the working directory it
# is resolved against and whether the file is there. In a notebook, the bare
# expression on the last line is echoed as the cell output.
# NOTE: sqlite3.connect creates a missing database file, so once a connection
# has been opened this check reports True even for a mistyped path.
print(os.getcwd())
os.path.exists(DATABASE)

In [None]:
start_date = '2010-01-01'
end_date = '2022-07-01'
t1_string = "m0p5"
t2_string = "m1p5"
vol_limit = 0.5
market = "UN"
# Write the SQL query.
# A table name cannot be bound as a parameter, so it is formatted in. The
# values are bound with "?" placeholders instead. In SQL, "..." denotes an
# identifier; SQLite treats a double-quoted token as a string literal only as
# a legacy fallback when no column of that name exists.
query = f'''
SELECT *
FROM stabilvol_{t1_string}_{t2_string}
WHERE Volatility < ?
AND Market = ?
AND start > ?
AND end < ?
'''
df = pd.read_sql_query(query, conn, params=(vol_limit, market, start_date, end_date))
df

In [None]:
def select_bins(df, max_n=1000, start_nbins=50, step=50, max_nbins=1000):
    """Bin ``df['Volatility']`` into equal-frequency bins and aggregate ``FHT`` per bin.

    The search starts from ``start_nbins`` quantile bins. It adds ``step`` bins
    for as long as every bin still holds at least ``max_n`` samples. It stops at
    the first bin count whose smallest bin has fewer than ``max_n`` samples, or
    once the bin count exceeds ``max_nbins``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``Volatility`` and ``FHT`` columns. It is not modified.
    max_n : int
        Minimum-bin-size threshold that stops the refinement.
    start_nbins, step, max_nbins : int
        Initial number of bins, increment, and cap of the search.

    Returns
    -------
    grouped : pandas.DataFrame
        Indexed by the volatility bins (index name ``'Bins'``), with columns
        ``mean``, ``error_on_the_mean`` and ``size`` of ``FHT``.
    nbins : int
        Number of quantile bins requested for ``grouped``. This can exceed
        ``len(grouped)`` because duplicate bin edges are dropped.
    """
    nbins = start_nbins
    while True:
        # Equal-frequency binning of volatility; duplicate edges (ties) are dropped.
        # Bin into a separate Series so the caller's DataFrame is not mutated.
        bins = pd.qcut(df['Volatility'], nbins, duplicates='drop').rename('Bins')

        # Group by the bins and compute the mean, standard error and size of FHT.
        # observed=False keeps the historical pandas behaviour (every bin category
        # is reported) and avoids the FutureWarning about the changing default.
        grouped = df['FHT'].groupby(bins, observed=False).agg(
            ['mean', error_on_the_mean, 'size']
        )

        if grouped['size'].min() < max_n or nbins > max_nbins:
            return grouped, nbins
        nbins += step


def error_on_the_mean(values):
    """Return the standard error of the mean of *values*.

    Computed as ``np.std(values)`` (population std, default ``ddof=0``) divided
    by ``sqrt(len(values))``. The function name becomes the column label in
    :func:`select_bins`' output, so it must not be renamed.
    """
    return np.std(values) / np.sqrt(len(values))

In [None]:
MARKETS = ["UN", "UW", "LN", "JT"]
START_LEVELS = [-2.0, -1.0, -0.5, -0.2, -0.1, 0.1, 0.2, 0.5, 1.0, 2.0]
DELTAS = [2.0, 1.0, 0.5, 0.2, 0.1, -0.1, -0.2, -0.5, -1.0, -2.0]

# Every (start threshold, end threshold) pair; the end threshold is the start
# shifted by each delta. Rounding to two decimals keeps float sums clean
# (e.g. 0.1 + 0.2 -> 0.3). The set removes duplicates, and sorting yields an
# ascending list ordered by start threshold, then end threshold.
LEVELS = sorted({
    (start, round(start + delta, 2))
    for start in START_LEVELS
    for delta in DELTAS
})

# Changing this changes the content of every cached MFHT pickle;
# regenerate them afterwards.
VOL_LIMIT = 0.5
LEVELS

In [None]:
def query_binned_data(market: str, t1_string: str, t2_string: str, vol_limit: float):
    """Load the FHTs of one market/threshold table and bin them by volatility.

    Parameters
    ----------
    market : str
        Value matched against the ``Market`` column.
    t1_string, t2_string : str
        Stringified start/end thresholds. They select the table
        ``stabilvol_{t1_string}_{t2_string}``.
    vol_limit : float
        Only rows with ``Volatility < vol_limit`` are used.

    Returns
    -------
    tuple
        ``(grouped_data, nbins)`` as produced by ``select_bins``, or
        ``(None, 0)`` when the table cannot be queried or the query returns
        no rows.
    """
    # Write the SQL query. The table name cannot be bound as a parameter, so it
    # is formatted in; the values are bound with "?" placeholders. In SQL,
    # "..." denotes an identifier and is treated as a string only by a legacy
    # SQLite fallback.
    query = f'''
    SELECT *
    FROM stabilvol_{t1_string}_{t2_string}
    WHERE Volatility < ? AND Market = ?
    '''
    try:
        # Load the FHT data from the database
        df = pd.read_sql_query(query, conn, params=(vol_limit, market))
    except pd.errors.DatabaseError:
        print(f'No data for market {market} with thresholds {t1_string}-{t2_string}')
        return None, 0

    if df.empty:
        # An empty selection has nothing to bin; report it like a missing table.
        print(f'No data for market {market} with thresholds {t1_string}-{t2_string}')
        return None, 0

    return select_bins(df)

def save_all_mfhts(market, save=True):
    """Bin and pickle the MFHT data of *market* for every threshold pair in LEVELS.

    For each ``(t1, t2)`` in ``LEVELS``, the binned FHT table from
    ``query_binned_data`` is written to
    ``../data/processed/trapezoidal_selection/mfht_{market}_{t1}_{t2}.pkl``.
    Files that already exist are left untouched (cached results are never
    recomputed).

    Parameters
    ----------
    market : str
        Market code passed through to ``query_binned_data``.
    save : bool
        When False, nothing is queried or written, and every pair reports 0.

    Returns
    -------
    dict
        Maps each ``(t1, t2)`` pair to the number of bins computed for it. The
        value is 0 when nothing was computed in this call: the file was already
        cached, there was no data, or ``save`` is False.
    """
    bins_dict = {}
    for t1, t2 in tqdm(LEVELS):
        # Create the strings for the threshold values
        t1_string = f.stringify_threshold(t1)
        t2_string = f.stringify_threshold(t2)
        # Filename for the MFHT data
        filename = f'../data/processed/trapezoidal_selection/mfht_{market}_{t1_string}_{t2_string}.pkl'

        if not save:
            # Dry run: nothing is queried or written.
            nbins = 0
        elif os.path.exists(filename):
            print(f"File '{filename}' already exists")
            nbins = 0
        else:
            # Load the dataframe from the database if it exists
            grouped_data, nbins = query_binned_data(market, t1_string, t2_string, VOL_LIMIT)
            # query_binned_data returns None when the table is missing or empty;
            # there is nothing to pickle in that case.
            if grouped_data is not None:
                grouped_data.to_pickle(filename)
        bins_dict[(t1, t2)] = nbins

    return bins_dict

In [None]:
# Build any missing MFHT pickles for this market and keep the per-pair bin counts.
market = "UN"
nbins_un = save_all_mfhts(market, save=True)

In [None]:
# Build any missing MFHT pickles for this market and keep the per-pair bin counts.
market = "UW"
nbins_uw = save_all_mfhts(market, save=True)

In [None]:
# Build any missing MFHT pickles for this market and keep the per-pair bin counts.
market = "LN"
nbins_ln = save_all_mfhts(market, save=True)

In [None]:
# Build any missing MFHT pickles for this market and keep the per-pair bin counts.
market = "JT"
nbins_jt = save_all_mfhts(market, save=True)

In [None]:
def get_thresholds(market):
    """Yield the (t1, t2) threshold pairs of every cached MFHT pickle of *market*.

    File names are expected to look like ``mfht_{market}_{t1}_{t2}.pkl``. The
    two threshold fields are converted back with ``f.numerify_threshold``.
    """
    prefix = f'mfht_{market}_'
    for name in os.listdir('../data/processed/trapezoidal_selection/'):
        if not name.startswith(prefix):
            continue
        fields = name.replace(".pkl", "").split('_')
        start_str, end_str = fields[2:4]
        start = f.numerify_threshold(start_str)
        end = f.numerify_threshold(end_str)
        yield (start, end)

In [None]:
from IPython.display import display, Markdown

# Group the cached UN thresholds by start threshold: {t1: [t2, ...]}
table_dict = {}

thresholds_table = [[t1, t2] for t1, t2 in sorted(get_thresholds("UN"), key=lambda x: float(x[0]), reverse=False)]
for start, end in thresholds_table:
    table_dict.setdefault(start, []).append(end)

# Create the markdown table. Markdown needs the same number of cells in the
# header, the separator line and every row, so size all of them to the
# longest list of end thresholds.
n_end_cols = max((len(ends) for ends in table_dict.values()), default=0)
header_cells = ["Start Threshold", "End"] + [""] * max(n_end_cols - 1, 0)
markdown_table = "| " + " | ".join(header_cells) + " |\n"
markdown_table += "|" + ":-------:|" * len(header_cells) + "\n"
for key, values in table_dict.items():
    values = sorted(values, key=lambda x: float(x))
    row_cells = [str(key)] + [str(s) for s in values] + [""] * (n_end_cols - len(values))
    markdown_table += "| " + " | ".join(row_cells) + " |\n"

# Display the markdown table
display(Markdown(markdown_table))

In [None]:
import matplotlib.colors as mcolors

def desaturate_color(color, factor=0.5):
    """Return *color* as an RGB array with its HSV saturation scaled by *factor*.

    Parameters
    ----------
    color : matplotlib color specification
        Converted with ``matplotlib.colors.to_rgb`` (any alpha is discarded).
    factor : float
        Multiplier applied to the saturation. The default 0.5 halves it. The
        result is clipped to [0, 1], the valid HSV range.
    """
    # Convert RGB to HSV so only the saturation changes; hue and value are kept.
    h, s, v = mcolors.rgb_to_hsv(mcolors.to_rgb(color))
    return mcolors.hsv_to_rgb((h, np.clip(s * factor, 0.0, 1.0), v))

def plot_mfht_grid(markets, plotsscale='', show=False):
    """Plot the binned MFHT curves of one or more markets on a 10x10 threshold grid.

    Subplots are filled in ``LEVELS`` order, one per ``(t1, t2)`` pair. With the
    default ``LEVELS`` (10 end thresholds per start threshold), each row shares
    the start threshold ``t1``. For every market whose pickle exists, a subplot
    shows the ``mean`` column against the left edge of each volatility bin, with
    a band of +/- ``error_on_the_mean``. The figure is saved as PNG and EPS under
    ``../visualization/mfhts/``.

    Parameters
    ----------
    markets : str or list of str
        Market code(s) to overlay. A single string is wrapped in a list.
    plotsscale : {'', 'log', 'logx', 'logy'}
        Axis scaling applied to every subplot.
    show : bool
        Call ``plt.show()`` before saving.

    Raises
    ------
    ValueError
        If *plotsscale* is not one of the accepted values.
    """
    if not isinstance(markets, list):
        markets = [markets]
    valid_scales = ('', 'log', 'logx', 'logy')
    if plotsscale not in valid_scales:
        raise ValueError(f"plotsscale must be one of {valid_scales}, got {plotsscale!r}")

    n_rows = 10
    n_cols = 10

    # Create a grid of subplots and flatten it so it can be zipped with LEVELS
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(16, 18))
    axs = axs.flatten()

    # Markets that already have a legend entry. Labelling the first curve drawn
    # for each market (not only the curves of the first subplot) keeps a market
    # in the legend even when its first pickle is missing.
    labelled = set()

    for i, ((t1, t2), ax) in enumerate(zip(LEVELS, axs)):
        # Create the strings for the threshold values
        t1_string = f.stringify_threshold(t1)
        t2_string = f.stringify_threshold(t2)

        for market in markets:
            pickle_name = f'mfht_{market}_{t1_string}_{t2_string}.pkl'
            try:
                df = pd.read_pickle(f'../data/processed/trapezoidal_selection/{pickle_name}')
            except FileNotFoundError:
                print(f"File '{pickle_name}' not found")
                continue

            # x: left edge of each volatility bin (the index is categorical)
            x = df.index.categories.left.values
            y = df['mean'].values
            y_err = df['error_on_the_mean'].values

            # An empty label keeps repeated curves out of the figure legend
            label = "" if market in labelled else market
            labelled.add(market)
            line, = ax.plot(x, y, label=label)
            ax.fill_between(x, y - y_err, y + y_err, color=desaturate_color(line.get_color()))

        # Axis formatting depends only on the subplot, so apply it once per subplot
        if plotsscale in ('logx', 'log'):
            ax.set_xscale('log')
        if plotsscale in ('logy', 'log'):
            ax.set_yscale('log')

        # First column: label the row with its start threshold
        if i % n_cols == 0:
            ax.set_ylabel(f"$\\theta_i = {t1}$", fontsize=16)
        # Always set the title with the final threshold value
        ax.set_title(f"$\\theta_f = {t2}$", fontsize=16)

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.5)

    # Place a legend just below the subplots ('upper center' anchored at y < 0)
    legend = fig.legend(loc='upper center', bbox_to_anchor=(0.5, -0.008), ncol=4)
    # Larger legend text and line samples for readability
    plt.setp(legend.get_texts(), fontsize=18)
    plt.setp(legend.get_lines(), linewidth=3)

    if show:
        plt.show()

    out_dir = '../visualization/mfhts'
    os.makedirs(out_dir, exist_ok=True)
    marketsname = ''.join(markets)
    fig.savefig(f'{out_dir}/{marketsname}_FHT_threshold_{plotsscale}grid.png', bbox_inches='tight')
    fig.savefig(f'{out_dir}/{marketsname}_FHT_threshold_{plotsscale}grid.eps', bbox_inches='tight')

In [None]:
plt.close('all')
# Linear axes; pass plotsscale='logx' (or 'log' / 'logy') for logarithmic ones.
plot_mfht_grid(MARKETS, plotsscale='')

In [None]:
n_rows, n_cols = 10, 10

# Grid of text panels, one per (t1, t2) pair. Each panel lists the number of
# bins reported for each market by save_all_mfhts in this session (0 when no
# binning was run, e.g. because the pickle was already cached).
fig, axs = plt.subplots(n_rows, n_cols, figsize=(16, 16))
axs = axs.flatten()

nbins_by_market = {"UN": nbins_un, "UW": nbins_uw, "LN": nbins_ln, "JT": nbins_jt}
for (t1, t2), ax in zip(LEVELS, axs):
    summary = "\n".join(f"{name}:{counts[(t1, t2)]}" for name, counts in nbins_by_market.items())
    ax.text(0.5, 0.5, summary, ha='center', va='center')

# One single-market MFHT grid per market
for market in ["UN", "UW", "LN", "JT"]:
    plot_mfht_grid(market)

In [None]:
def get_max_values(market):
    """Return the peak MFHT of *market* for every threshold pair, as a 2-D grid.

    For each ``(t1, t2)`` in ``LEVELS``, the cached pickle is loaded and the
    maximum of its ``mean`` column is taken. The flat results are reshaped to
    ``(len(START_LEVELS), len(DELTAS))``. Because ``LEVELS`` is sorted, that
    gives one row per start threshold and one column per end threshold, in
    ascending order.

    Pairs whose pickle is missing are NaN rather than 0. A 0 would be an
    invalid value for a logarithmic colour scale and would look like a real
    measurement.
    """
    maxs = np.full(len(LEVELS), np.nan)
    for i, (t1, t2) in enumerate(LEVELS):
        # Create the strings for the threshold values
        t1_string = f.stringify_threshold(t1)
        t2_string = f.stringify_threshold(t2)

        # Load the dataframe from the cached pickle if it exists
        try:
            df = pd.read_pickle(f'../data/processed/trapezoidal_selection/mfht_{market}_{t1_string}_{t2_string}.pkl')
        except FileNotFoundError:
            print(f"File 'mfht_{market}_{t1_string}_{t2_string}.pkl' not found")
            continue
        maxs[i] = df['mean'].max()

    return maxs.reshape((len(START_LEVELS), len(DELTAS)))


def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a new colormap that covers only the [minval, maxval] part of *cmap*.

    *cmap* is sampled at *n* evenly spaced points between *minval* and *maxval*.
    The samples become a LinearSegmentedColormap whose name records the source
    colormap and the range.
    """
    name = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    samples = cmap(np.linspace(minval, maxval, n))
    return mcolors.LinearSegmentedColormap.from_list(name, samples)


def add_annotations(ax, peaks):
    """Write every value of the 2-D array *peaks*, rounded to one decimal, on *ax*.

    The value at ``peaks[i, j]`` is drawn at data coordinates ``(j, i)``, i.e.
    on the matching ``imshow`` cell, as centred white 9 pt text. The grid size
    is taken from ``peaks.shape`` instead of being fixed at 10x10.

    Returns
    -------
    The last text artist created, or None if *peaks* is empty.
    """
    text = None
    n_rows, n_cols = peaks.shape
    for i in range(n_rows):
        for j in range(n_cols):
            text = ax.text(j, i, round(peaks[i, j], 1),
                           ha="center", va="center", color="white", fontsize=9)
    return text


def plot_mfht_peaks(market, peaks, min_value, max_value, ax=None, annotations=True):
    """Draw the peak-MFHT grid of *market* as a log-normalised heatmap.

    Parameters
    ----------
    market : str
        Used as the title, wrapped in LaTeX ``\\emph{}``; this presumably
        requires ``text.usetex`` to be enabled.
    peaks : 2-D array
        Values to draw. Rows are labelled with START_LEVELS and columns with
        DELTAS in ascending order.
    min_value, max_value : float
        Limits of the logarithmic colour normalisation, shared across calls.
    ax : matplotlib Axes, optional
        Axes to draw on. When None, a new figure is created and shown.
    annotations : bool
        Whether to write the rounded values into the cells.

    Returns
    -------
    (ax, im)
        The axes and the image returned by ``imshow`` (e.g. for a colorbar).
    """
    # Remember whether this call owns the figure. `ax` is reassigned below, so
    # testing `ax is None` at the end could never be true.
    created_figure = ax is None
    if created_figure:
        fig, ax = plt.subplots(figsize=(16, 18))

    ax.set_title(r"\emph{" + market + r"}", fontsize=18)

    im = ax.imshow(peaks, cmap=truncate_colormap(plt.get_cmap('inferno'), 0.1, 1.0),
                   norm=mcolors.LogNorm(vmin=min_value, vmax=max_value))

    n_rows, n_cols = np.shape(peaks)
    # Label the cells: DELTAS is stored in descending order, so reverse it to
    # match the ascending column order; rows follow START_LEVELS.
    ax.set_xticks(np.arange(n_cols), labels=DELTAS[::-1], fontsize=11, rotation=45,
                  ha='right', rotation_mode='anchor')
    ax.set_yticks(np.arange(n_rows), labels=START_LEVELS, fontsize=11)

    # Minor ticks on the cell borders carry the white separating grid
    ax.set_xticks(np.arange(-.5, n_cols, 1), minor=True)
    ax.set_yticks(np.arange(-.5, n_rows, 1), minor=True)
    # Remove the grid lines in the middle of the cells
    ax.grid(False)
    ax.grid(which='minor', color='w', linestyle='-', linewidth=1)

    # Write the values into the cells
    if annotations:
        add_annotations(ax, peaks)

    if created_figure:
        plt.show()
    return ax, im

In [None]:
# Use LaTeX for text rendering
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'serif'

peaks = {market: get_max_values(market) for market in ["UN", "UW", "LN", "JT"]}
max_value, min_value = max(v.max() for v in peaks.values()), min(v.min() for v in peaks.values())
print(max_value, min_value)

fig, axs = plt.subplots(1, 4, figsize=(16, 3.5))

for market, ax in zip(MARKETS, axs.flatten()):
    ax, im = plot_mfht_peaks(market, peaks[market], min_value, max_value, ax, annotations=False)

# Add the y-axis and x-axis labels
axs[0].set_ylabel(r"$\theta_i$", fontsize=16)
axs[0].set_xlabel(r"$d$", fontsize=16, x=-0.2, labelpad=-20)

# Add a colorbar
cbar = fig.colorbar(im, ax=axs.ravel().tolist(), pad=0.01)
cbar.set_label('Maximum MFHT', rotation=270, labelpad=15)

# fig.tight_layout()

plt.show()

In [None]:
# Export the max-MFHT comparison figure in raster (PNG) and vector (EPS, PDF) formats
for extension in ('png', 'eps', 'pdf'):
    fig.savefig(f'../visualization/mfhts/max_MFHT_comparison.{extension}', bbox_inches='tight')