In [None]:
import pandas as pd
import ipywidgets as widgets
from ipywidgets import VBox, HBox
import plotly.graph_objs as go
import numpy as np



class MarketArbitrageReport:
    """Interactive ipywidgets/plotly report of arbitrage checks on option smiles.

    ``model_arbitrage_df`` must provide at least the columns read by this class:
    ``DATE``, ``CURRENCY``, ``TENOR``, ``EXPIRY``, ``MONEYNESS``, ``STRIKE``,
    ``Call_difference``, ``Put_difference``, ``Call_butterfly``,
    ``Put_butterfly`` and ``Multiplier``. Each (DATE, CURRENCY, TENOR, EXPIRY)
    group is treated as one smile.

    A smile is flagged as "Arbitrage" on the heatmap when any of its aggregated
    minima falls below ``-tolerance`` (call/put differences) or
    ``-tolerance / sqrt(2)`` (multiplier-adjusted butterflies).

    Layout of each currency tab built by :meth:`display_plots`
    (``self.tabs.children[i].children``):
        [0] date SelectionSlider
        [1] success/failure heatmap (FigureWidget)
        [2] HBox(tenor Dropdown, expiry Dropdown)
        [3] call/put difference plot of the selected smile
        [4] adjusted call/put butterfly plot of the selected smile
    """

    # Display order for TENOR / EXPIRY labels. Labels missing from these lists
    # are appended after them (in order of first appearance) instead of being
    # turned into NaN.
    _TENOR_ORDER = [f'{i}Y' for i in range(1, 50)]
    _EXPIRY_ORDER = ([f'{i}M' for i in range(1, 12)] + ['1Y']
                     + [f'{i}M' for i in range(13, 24)]
                     + [f'{i}Y' for i in range(2, 50)])

    def __init__(self, model_arbitrage_df):
        """Store the input data and derive the multiplier-adjusted butterfly columns.

        Note: ``Adjusted_call_butterfly`` / ``Adjusted_put_butterfly`` are added
        to ``model_arbitrage_df`` in place (no copy is taken), so the caller's
        DataFrame is modified.
        """
        self.df = model_arbitrage_df
        self.df['Adjusted_call_butterfly'] = self.df['Call_butterfly'] * self.df['Multiplier']
        self.df['Adjusted_put_butterfly'] = self.df['Put_butterfly'] * self.df['Multiplier']
        # Accepted numerical error; changed at runtime by the tolerance slider.
        self.tolerance = 1.0e-10
        self.tolerance_criteria = VBox([widgets.HTMLMath(value=r'Tolerance level = accepted numerical error')])

    @staticmethod
    def _ordered_categorical(values, order):
        """Return ``values`` as a Categorical ordered by ``order``.

        Labels not present in ``order`` are appended as extra categories in
        order of first appearance, so no non-null label is lost to NaN.
        """
        known = set(order)
        extra = [v for v in pd.unique(values) if pd.notna(v) and v not in known]
        return pd.Categorical(values, categories=list(order) + extra)

    def aggregated_results(self):
        """Aggregate every smile to the minimum of each arbitrage metric.

        Returns (and stores in ``self.agg_df``) one row per
        (DATE, CURRENCY, TENOR, EXPIRY) with columns ``min_call_difference``
        (minimum of ``-Call_difference``), ``min_put_difference``,
        ``min_call_butterfly`` and ``min_put_butterfly`` (minima of the
        multiplier-adjusted butterflies). TENOR/EXPIRY are categoricals in
        display order, and rows are sorted by the four keys.
        """
        keys = ['DATE', 'CURRENCY', 'TENOR', 'EXPIRY']
        # Call differences are negated so that, for every metric, a value
        # below -tolerance is what the heatmap reports as arbitrage.
        per_row = self.df[keys].assign(
            min_call_difference=-self.df['Call_difference'],
            min_put_difference=self.df['Put_difference'],
            min_call_butterfly=self.df['Adjusted_call_butterfly'],
            min_put_butterfly=self.df['Adjusted_put_butterfly'],
        )
        agg = per_row.groupby(keys).min().reset_index(drop=False)
        agg['TENOR'] = self._ordered_categorical(agg['TENOR'], self._TENOR_ORDER)
        agg['EXPIRY'] = self._ordered_categorical(agg['EXPIRY'], self._EXPIRY_ORDER)
        self.agg_df = agg.sort_values(by=keys)
        return self.agg_df

    def _failure_grid(self, currency, date):
        """Return ``(expiries, tenors, z)`` for one currency/date.

        ``z`` is a float array of shape (len(tenors), len(expiries)). A cell is
        1.0 if the smile breaches the tolerance, 0.0 if it does not, and NaN
        when there is no smile (or no call-difference minimum) for that cell.
        """
        subset = self.agg_df.loc[(self.agg_df['DATE'] == date) & (self.agg_df['CURRENCY'] == currency)]
        tol = self.tolerance
        butterfly_tol = tol / (2 ** 0.5)
        failed = ((subset['min_call_difference'] < -tol)
                  | (subset['min_put_difference'] < -tol)
                  | (subset['min_call_butterfly'] < -butterfly_tol)
                  | (subset['min_put_butterfly'] < -butterfly_tol))
        flag = failed.astype(float).where(subset['min_call_difference'].notna())
        grid = subset.assign(_flag=flag).set_index(['TENOR', 'EXPIRY'])['_flag'].unstack()
        # Sort both axes in categorical (display) order.
        grid = grid.reindex(grid.columns.sort_values(), axis=1)
        grid = grid.reindex(grid.index.sort_values(), axis=0)
        return grid.columns.tolist(), grid.index.tolist(), grid.to_numpy(dtype=float)

    def _heatmap_trace(self, currency, date):
        """Build the success/failure ``go.Heatmap`` trace for one currency/date."""
        x, y, z = self._failure_grid(currency, date)
        smiles_total = int(np.count_nonzero(~np.isnan(z)))
        smiles_failed = int(np.nansum(z))
        colorbar = dict(nticks=2,
                        tickmode="array",
                        tickvals=[0, 1],
                        ticktext=[f'No arbitrage: {(smiles_total - smiles_failed)}/{smiles_total}',
                                  f'Arbitrage: {smiles_failed}/{smiles_total}'],
                        tickfont={"size": 15})
        return go.Heatmap(z=z.tolist(), x=x, y=y, name='', xgap=1, ygap=1, zmin=0, zmax=1,
                          colorbar=colorbar,
                          colorscale=[[0.0, "rgb(0,0,200)"], [1.0, "rgb(200,0,0)"]])

    def aggregated_results_heatmap(self, currency, date):
        """Return a heatmap FigureWidget of flagged smiles (tenor x expiry).

        Clicking a cell selects that tenor/expiry in the active tab. The figure
        is also stored in ``self.fig``.
        """
        self.fig = go.FigureWidget([self._heatmap_trace(currency, date)])
        self.fig['layout'].update(height=600, width=750, title='Success/failure map')
        self.fig.update_yaxes(title_text='Tenor')
        self.fig.update_xaxes(title_text='Expiry')
        self.fig.data[0].on_click(self.select_smile_on_heatmap)
        return self.fig

    def _smile_traces(self, currency, date, tenor, expiry, value):
        """Return ``([call, put, threshold] traces, title)`` for one smile.

        ``value`` selects the metric: ``'difference'`` or ``'butterfly'``.

        Raises:
            ValueError: if ``value`` is neither ``'difference'`` nor ``'butterfly'``.
        """
        df = self.df
        selected_smile = df.loc[(df['CURRENCY'] == currency) & (df['DATE'] == date)
                                & (df['TENOR'] == tenor) & (df['EXPIRY'] == expiry)].sort_values(by=['MONEYNESS'])
        strike = selected_smile['STRIKE'].tolist()
        if value == 'difference':
            name1 = r'$C(K_{i})-C(K_{i+1})$'
            name2 = r'$P(K_{i+1})-P(K_{i})$'
            mult = 1.
            title = r'$C(K_{i})-C(K_{i+1}), \, P(K_{i+1})-P(K_{i})$'
            y1 = (-selected_smile['Call_difference']).tolist()
            y2 = selected_smile['Put_difference'].tolist()
        elif value == 'butterfly':
            name1 = 'Call butterfly'
            name2 = 'Put butterfly'
            # Same -tolerance/sqrt(2) threshold that _failure_grid applies to butterflies.
            mult = 1 / (2 ** 0.5)
            title = 'Adjusted call/put butterflies'
            y1 = selected_smile['Adjusted_call_butterfly'].tolist()
            y2 = selected_smile['Adjusted_put_butterfly'].tolist()
        else:
            raise ValueError(f"value must be 'difference' or 'butterfly', got {value!r}")

        traces = [go.Scatter(x=strike, y=y1, name=name1, mode='lines', line=dict(color='blue', width=5)),
                  go.Scatter(x=strike, y=y2, name=name2, mode='lines', line=dict(color='orange', width=5)),
                  go.Scatter(x=strike, y=[-mult * self.tolerance] * len(strike), name='Threshold', mode='lines',
                             line=dict(color='red', width=5, dash='dash'))]
        return traces, title

    def plot_selected_smile(self, currency, date, tenor, expiry, value):
        """Return a FigureWidget plotting one smile's metric against strike.

        ``value`` is ``'difference'`` or ``'butterfly'``. Any other value
        raises ``ValueError``.
        """
        traces, title = self._smile_traces(currency, date, tenor, expiry, value)
        fig = go.FigureWidget(data=traces)
        fig['layout'].update(height=300, width=750, title=title)
        fig.update_xaxes(title_text='Strike, %')
        return fig

    @staticmethod
    def _observed_labels(labels):
        """Return the distinct values of a categorical Series in category order."""
        return labels.cat.remove_unused_categories().cat.categories.tolist()

    def display_plots(self):
        """Build and return the report widget: currency tabs plus tolerance slider."""
        self.agg_df = self.aggregated_results()
        self.currencies = self.agg_df['CURRENCY'].unique().tolist()
        # Preferred initial smile. Each tab falls back to its first option
        # when the currency does not quote these labels.
        self.active_tenor = '5Y'
        self.active_expiry = '5Y'

        # Tolerance slider (log scale, 1e-20 .. 1e-4)
        self.slider_wdg_tolerance = widgets.FloatLogSlider(value=self.tolerance, base=10, min=-20, max=-4, step=1,
                                                           description='Tolerance', disabled=False,
                                                           continuous_update=False, orientation='horizontal',
                                                           readout=True)
        self.slider_wdg_tolerance.observe(lambda change: self.update_tolerance_level(change['new']), names='value')

        # Currency tabs
        tab_list = []
        for currency in self.currencies:
            ccy_df = self.agg_df.loc[self.agg_df['CURRENCY'] == currency]
            dates = pd.to_datetime(ccy_df['DATE'].unique().tolist())
            tenors = self._observed_labels(ccy_df['TENOR'])
            expiries = self._observed_labels(ccy_df['EXPIRY'])
            # A Dropdown rejects a value that is not among its options.
            tenor = self.active_tenor if self.active_tenor in tenors else tenors[0]
            expiry = self.active_expiry if self.active_expiry in expiries else expiries[0]
            tab_list.append(VBox([
                widgets.SelectionSlider(description='Date', options=dates, continuous_update=False,
                                        layout={'width': '90%'}),
                self.aggregated_results_heatmap(currency, dates[0]),
                HBox([widgets.Dropdown(description='Tenor', options=tenors, value=tenor),
                      widgets.Dropdown(description='Expiry', options=expiries, value=expiry)]),
                self.plot_selected_smile(currency, dates[0], tenor, expiry, 'difference'),
                self.plot_selected_smile(currency, dates[0], tenor, expiry, 'butterfly'),
            ]))

        self.tabs = widgets.Tab(children=tab_list)

        for tab_index, currency in enumerate(self.currencies):
            tab = self.tabs.children[tab_index]
            tenor_dropdown, expiry_dropdown = tab.children[2].children
            self.tabs.set_title(tab_index, currency)
            # ``currency=currency`` binds the loop value now (avoids late binding).
            tab.children[0].observe(lambda change, currency=currency: self.update_date(currency, change['new']),
                                    names='value')
            tenor_dropdown.observe(lambda change, currency=currency: self.update_tenor(currency, change['new']),
                                   names='value')
            expiry_dropdown.observe(lambda change, currency=currency: self.update_expiry(currency, change['new']),
                                    names='value')

        self.report_page = HBox([self.tabs, VBox([self.slider_wdg_tolerance, self.tolerance_criteria])])
        return self.report_page

    def _current_selection(self, currency):
        """Return ``(date, tenor, expiry)`` currently selected in ``currency``'s tab."""
        tab = self.tabs.children[self.currencies.index(currency)]
        tenor_dropdown, expiry_dropdown = tab.children[2].children
        return tab.children[0].value, tenor_dropdown.value, expiry_dropdown.value

    def update_date(self, currency, value):
        """Date-slider callback: redraw heatmap and smile plots for the new date."""
        _, tenor, expiry = self._current_selection(currency)
        self.update_plots(currency, value, tenor, expiry, 1)

    def update_tolerance_level(self, value):
        """Tolerance-slider callback: store the new tolerance and refresh every tab."""
        self.tolerance = value
        for currency in self.currencies:
            date, tenor, expiry = self._current_selection(currency)
            self.update_plots(currency, date, tenor, expiry, 0)

    def update_tenor(self, currency, value):
        """Tenor-dropdown callback: redraw the smile plots."""
        date, _, expiry = self._current_selection(currency)
        self.update_plots(currency, date, value, expiry, 2)

    def update_expiry(self, currency, value):
        """Expiry-dropdown callback: redraw the smile plots."""
        date, tenor, _ = self._current_selection(currency)
        self.update_plots(currency, date, tenor, value, 2)

    def select_smile_on_heatmap(self, trace, points, state):
        """Heatmap click callback: select the clicked tenor/expiry in the active tab."""
        if not points.xs:
            # Callback fired without a clicked point on this trace; nothing to select.
            return
        expiry = points.xs[0]
        tenor = points.ys[0]
        tenor_dropdown, expiry_dropdown = self.tabs.children[self.tabs.selected_index].children[2].children
        tenor_dropdown.value, expiry_dropdown.value = tenor, expiry

    def update_plots(self, currency, date, tenor, expiry, flag):
        """Refresh the figures in ``currency``'s tab.

        flag == 0: tolerance changed, so recolour the heatmap and move the threshold lines.
        flag == 1: date changed, so rebuild the heatmap grid and both smile plots.
        flag == 2: tenor/expiry changed, so rebuild both smile plots.
        Any other flag is ignored.
        """
        if flag not in (0, 1, 2):
            return
        tab = self.tabs.children[self.currencies.index(currency)]
        heatmap = tab.children[1]

        if flag in (0, 1):
            trace = self._heatmap_trace(currency, date)
            # One batched message, so x, y and z never reach the front end
            # with mismatched shapes.
            with heatmap.batch_update():
                if flag == 1:
                    heatmap.data[0].x = trace.x
                    heatmap.data[0].y = trace.y
                heatmap.data[0].z = trace.z
                heatmap.data[0].colorbar = trace.colorbar

        for smile_fig, value in ((tab.children[3], 'difference'), (tab.children[4], 'butterfly')):
            traces, _ = self._smile_traces(currency, date, tenor, expiry, value)
            if flag == 0:
                # Only the threshold line (trace 2) depends on the tolerance.
                smile_fig.data[2].y = traces[2].y
            else:
                smile_fig.data = []
                smile_fig.add_traces(traces)

    