In [1]:
import os
import numpy as np
import pandas as pd

import plotly.graph_objects as go

In [2]:
# Initialise plotly's notebook integration before any figure is shown.
# NOTE(review): `iplot` is imported but not used in the cells shown here, and
# `connected=True` presumably loads plotly.js from the CDN instead of embedding
# it in the notebook - confirm against the plotly.offline documentation.
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True)

In [3]:
# Load the correlation matrix from disk; the first CSV column is used as the row index.
# NOTE(review): the plotting helpers below presumably expect a square frame whose
# row and column labels are the feature names - confirm against the CSV.
data = pd.read_csv('./data/correlation_matrix.csv', index_col=0)

In [4]:
# Declared type ('numeric' or 'category') of every feature, keyed by feature name.
# Insertion order matters: the numeric / categorical column lists built further
# down are derived by iterating over this mapping.
feature_type_names = dict(
    age='numeric',
    job='category',
    marital='category',
    education='category',
    default='category',
    balance='numeric',
    housing='category',
    loan='category',
    contact='category',
    day='numeric',
    month='category',
    duration='numeric',
    campaign='numeric',
    pdays='numeric',
    previous='numeric',
    poutcome='category',
)

In [5]:
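# Reference only (not executed): sketch of how a mixed correlation matrix can be assembled.
# NumericCorr, CramersVCorr and CorrelationRatio are not defined in the cells shown here.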
# cols = num_cols + cat_cols
# _data = {c: {c: np.nan for c in cols} for c in cols}
# df = pd.DataFrame(_data)

# if n_num > 0:
#     num_corr = NumericCorr(method='pearson')(data, num_cols).abs()
#     df.loc[num_cols, num_cols] = num_corr
# if n_cat > 0:
#     cat_corr = CramersVCorr()(data, cat_cols)
#     df.loc[cat_cols, cat_cols] = cat_corr
# if n_num > 0 and n_cat > 0:
#     mix_corr = CorrelationRatio()(data, cat_cols, num_cols)
#     df.loc[num_cols, cat_cols] = mix_corr
#     df.loc[cat_cols, num_cols] = mix_corr.T

### Correlation matrix

### What this plot demonstrates
- Use a heatmap to show the correlation matrix.
- Use two controls: a toggle button that shows or hides the annotations, and a dropdown that draws a rectangle around the block of the matrix belonging to a specific pair of feature types.
- Fix the axis ranges, so that the user cannot zoom in, zoom out, or select a region of the plot.

#### Main trace
- Heatmap
```python
trace = go.Heatmap(
    x=np.arange(len(data.columns)),
    y=np.arange(len(data.index)),
    z=data.values,
    zmin=0,
    zmax=1,
    hovertemplate="Corr ( %{x}, %{y} ) = %{z} <extra></extra>",
    ...
)
```

#### Layout: shapes
```python
shape = dict(
    type="rect", 
    x0=x0, y0=y0, x1=x1, y1=y1,
    line=dict(color="blue", width = 4), 
    layer='above'
)
```
- `layer`: either `below` or `above`; it controls whether the shapes are drawn below or above the traces.

#### Layout: annotation
We need to create a list of dictionaries (or instances of `go.layout.Annotation`), one for each element of the correlation matrix.<br>
For each annotation, the font color is controlled by the value of the element in the matrix (by default, black for values below 0.5 and white otherwise).

In [6]:
def gen_annotations(data, reverse_font_color=False):
    r'''
    Build one text annotation per cell of ``data``.

    Each cell value is rounded to two decimals and placed at the integer
    (column position, row position) of the cell on axes ``x1`` / ``y1``.
    Values below 0.5 are written in black and all others in white;
    ``reverse_font_color=True`` swaps the two colors.

    Returns a list of ``go.layout.Annotation`` in row-major order.
    '''
    col_positions = np.arange(data.shape[1])
    row_positions = np.arange(data.shape[0])
    rounded_rows = np.round(data.values, 2).tolist()

    def _cell_annotation(row_idx, col_idx, value):
        # XOR flips the black/white choice when the caller asks for reversed colors.
        use_black = (value < 0.5) ^ reverse_font_color
        return go.layout.Annotation(
            text=str(value),
            x=col_positions[col_idx],
            y=row_positions[row_idx],
            xref="x1",
            yref="y1",
            font=dict(color="#000000" if use_black else "#FFFFFF"),
            showarrow=False,
        )

    return [
        _cell_annotation(row_idx, col_idx, value)
        for row_idx, cells in enumerate(rounded_rows)
        for col_idx, value in enumerate(cells)
    ]

In [7]:
def plot_correlation_table(data, fig=None, width=500, height=500, colorscale='YlGnBu', reverse_font_color=False):
    r'''
    Draw ``data`` as a heatmap and return the figure.

    The heatmap is positioned on integer coordinates (0..n-1) and the row /
    column labels of ``data`` are attached as tick text. Cell values are
    rounded to two decimals and the color range is fixed to [0, 1].

    data: DataFrame to draw.
    fig: figure to add the trace to; a new ``go.Figure`` is created when None.
    width, height: figure size passed to the layout.
    colorscale: colorscale passed to the heatmap trace.
    reverse_font_color: accepted for signature compatibility; not used here.
    '''
    figure = go.Figure() if fig is None else fig

    col_labels = data.columns.values.tolist()
    row_labels = data.index.values.tolist()
    n_cols = len(col_labels)
    n_rows = len(row_labels)

    heatmap = go.Heatmap(
        x=np.arange(n_cols),
        y=np.arange(n_rows),
        z=np.round(data.values, 2).tolist(),
        colorscale=colorscale,
        showscale=True,
        reversescale=False,
        zmin=0,
        zmax=1,
        hovertemplate="Corr ( %{x}, %{y} ) = %{z} <extra></extra>"
    )
    figure.add_trace(heatmap)

    # Anchor the x scale to y so the cells stay square.
    figure.update_layout(
        xaxis=dict(scaleanchor='y', constrain='domain'),
        width=width,
        height=height,
    )

    # Reversed y axis puts the first row at the top, like a printed matrix.
    figure.update_yaxes(autorange='reversed', tickvals=np.arange(n_rows), ticktext=row_labels)
    figure.update_xaxes(tickangle=-45, tickvals=np.arange(n_cols), ticktext=col_labels)

    return figure

In [8]:
def plot_correlation_matrix(data, num_cols=None, cat_cols=None, fig=None, width=500, height=500, colorscale='YlGnBu', reverse_font_color=False):
    r'''
    Plot a correlation matrix as a heatmap with two controls:

    - a dropdown that draws a rectangle around one of the four feature-type
      blocks (only built when both ``num_cols`` and ``cat_cols`` are non-empty);
    - a "Show value" toggle that shows or hides the cell annotations.

    data: square correlation DataFrame. The highlight rectangles are computed
        purely from ``len(num_cols)`` and ``len(cat_cols)``, so they assume the
        rows/columns of ``data`` are ordered as ``num_cols`` followed by
        ``cat_cols``.
    num_cols, cat_cols: names of the numeric / categorical features
        (None is treated as an empty list).
    fig: figure to draw on; forwarded to ``plot_correlation_table``.
    width, height, colorscale: forwarded to ``plot_correlation_table``.
    reverse_font_color: forwarded to ``gen_annotations`` to swap the
        black/white annotation colors.

    Returns the figure.
    '''
    num_cols = [] if num_cols is None else list(num_cols)
    cat_cols = [] if cat_cols is None else list(cat_cols)
    df = data.copy(deep=True)

    n_num = len(num_cols)
    n_cat = len(cat_cols)
    buttons_list = []

    # Label every row by looking its name up in the declared types, so the
    # labels always follow the actual row order of the matrix (and are still
    # present when no column lists are supplied).
    numeric_names = set(num_cols)
    categorical_names = set(cat_cols)
    ticktext = []
    for name in df.index:
        if name in numeric_names:
            ticktext.append(str(name) + ' (N)  ')
        elif name in categorical_names:
            ticktext.append(str(name) + ' (C)  ')
        else:
            ticktext.append(str(name) + '  ')

    fig = plot_correlation_table(df, fig=fig, width=width, height=height, colorscale=colorscale)
    annotations = gen_annotations(df, reverse_font_color=reverse_font_color)

    if n_num > 0 and n_cat > 0:
        line_config = dict(color="blue", width=4)
        # Cell i spans [i - 0.5, i + 0.5]; these are the block boundaries.
        coor = [-0.5, n_num - 0.5, n_num + n_cat - 0.5]
        # Labels read "<row type>-<column type>", in the order the loops below visit the blocks.
        button_label = ['numeric-numeric', 'categorical-numeric', 'numeric-categorical', 'categorical-categorical']

        buttons_list.append(dict(label="None", method="relayout", args=[{"shapes": []}]))
        for i in range(2):        # column block: 0 = numeric, 1 = categorical
            for j in range(2):    # row block: 0 = numeric, 1 = categorical
                rect = dict(
                    type="rect",
                    x0=coor[i], y0=coor[j], x1=coor[i + 1], y1=coor[j + 1],
                    line=line_config,
                    layer='above',
                )
                buttons_list.append(
                    dict(label=button_label[2 * i + j], method="relayout", args=[{"shapes": [rect]}])
                )

    fig.update_layout(
        title={
            'text': 'Correlation',
            'y': 0.99,
            'x': 0,
            'xanchor': 'left',
            'yanchor': 'top',
            'font': {'size': 22},
        },
        margin=dict(t=150, b=0, l=270, r=280),
        annotations=annotations,
        # NOTE(review): the heatmap trace created by plot_correlation_table is not
        # attached to a shared coloraxis, so these colorbar settings probably have
        # no visible effect - confirm in the rendered figure.
        coloraxis_colorbar=dict(
            thicknessmode="pixels", thickness=50,
            lenmode="pixels",  # len=200,
            yanchor="top", y=1,
            ticks="outside",
            dtick=5
        ),
        hoverlabel=dict(
            bgcolor="white",
            font_size=12,
            font_family="Arial",
        ),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(238,238,238,1)',
        modebar=dict(bgcolor='rgba(0,0,0,0)', activecolor='rgba(68,68,68,0.7)', color='rgba(68,68,68,0.3)'),
    )

    fig.update_layout(
        updatemenus=[
            # Dropdown: highlight one feature-type block.
            dict(
                type="dropdown",
                direction='down',
                buttons=buttons_list,
                active=0,
                x=0.1, xanchor="left",
                y=1, yanchor="top",
                pad={"r": 0, "t": -50, 'l': -60},
                bgcolor='rgba(255,255,255,1)',
            ),
            # Toggle: args restores the annotations, args2 clears them.
            dict(
                type="buttons",
                direction="right",
                buttons=[
                    dict(
                        args=[{"annotations": annotations}],
                        args2=[{"annotations": []}],
                        label="Show value",
                        method="relayout"
                    )
                ],
                active=0,
                x=0.1, xanchor="left",
                y=1, yanchor="top",
                pad={"r": 0, "t": -100, 'l': -60},
                bgcolor='rgba(255,255,255,1)',
            ),
        ]
    )

    # The y axis has one tick per row of the matrix.
    fig.update_yaxes(
        tickmode='array',
        tickvals=np.arange(df.shape[0]),
        ticktext=ticktext,
        fixedrange=True
    )
    fig.update_xaxes(fixedrange=True)

    return fig

In [9]:
# Split the declared features into numeric and categorical names, then draw
# the annotated correlation matrix with the highlight dropdown.
num_cols = [name for name, kind in feature_type_names.items() if kind == 'numeric']
cat_cols = [name for name, kind in feature_type_names.items() if kind == 'category']

fig = plot_correlation_matrix(
    data,
    num_cols=num_cols,
    cat_cols=cat_cols,
    fig=go.Figure(),
    width=1118,
    height=766,
    colorscale='YlGnBu',
    reverse_font_color=False,
)
fig.show(config={'displaylogo': False})
# fig.write_html('./automl_plot/corr.html', config={'displaylogo':False}, include_plotlyjs='cdn', full_html=False)

### Correlation matrix with filter

#### Layout: buttons
One button decides whether to show the annotations; the other is used to switch between the different types of correlation matrix.<br>

- Dropdown button: filters the different types of correlation matrix<br>
In a mixed correlation matrix, we combine four different correlation matrices.<br>

    - Numeric-Numeric: Pearson correlation.
    - Categorical-Categorical: Cramér's V.
    - Categorical-Numeric: correlation ratio.
    - Numeric-Categorical: correlation ratio.
    
    
- Toggle button: shows or hides the annotation values<br>
Each correlation matrix needs its own list of annotation properties. Hence, we have to update the `Show value` button whenever a different correlation matrix is selected.<br>
In plotly, buttons are defined in `updatemenus`, and `updatemenus` itself can be updated by a button defined in `updatemenus`.
```python
dict(args=[{'x': [np.arange(len(df.columns))], 
            'y': [np.arange(len(df.index))], 
            'z': [df.values]}, 
           {'annotations': gen_annotations(df, reverse_font_color=False),
            'updatemenus': [button_annotation(df),],  # update button through updatemenus
            'xaxis': {'fixedrange': True, 'tickvals': np.arange(len(df.columns)), 'ticktext': df.columns},
            'yaxis': {'autorange':'reversed', 'fixedrange': True, 
                      'tickvals': np.arange(len(df.index)), 'ticktext': df.index},
           }],
     label="All", 
     method="update")
```

In [10]:
def button_annotation(data, reverse_font_color=False):
    r'''
    Create a "Show value" toggle button for a given data set.

    The returned dict is one ``updatemenus`` entry of type "buttons" with a
    single button: its ``args`` set the layout annotations to the ones
    generated from ``data`` and its ``args2`` clear them.

    data: DataFrame whose cells are annotated.
    reverse_font_color: forwarded to ``gen_annotations`` to swap the
        black/white annotation colors (defaults to the previous fixed value).
    '''
    annotations = gen_annotations(data, reverse_font_color=reverse_font_color)
    return dict(type="buttons",
                direction="right",
                buttons=[
                    dict(
                        args=[{"annotations": annotations}],
                        args2=[{"annotations": []}],
                        label="Show value",
                        method="relayout"
                    )
                ],
                active=0,
                x=0.1, xanchor="left",
                y=1, yanchor="top",
                pad={"r": 0, "t": -250, 'l': -387},
                bgcolor='rgba(255,255,255,1)')

In [11]:
def tick_color(color, text):
    r'''
    Wrap ``text`` in a LaTeX ``\color`` command so it can be used to change
    tick text color.

    color: color name (converted with ``str``).
    text: label to color (converted with ``str``).

    Returns the string ``$\color{<color>}{<text>}$``.
    '''
    # Raw string: '\c' is not a valid escape sequence in a normal literal.
    return r'$\color{' + str(color) + '}{' + str(text) + '}$'

In [12]:
def _filter_button(view, label, reverse_font_color=False):
    r'''
    Build one dropdown entry that redraws the heatmap from the sub-matrix ``view``.
    '''
    n_rows, n_cols = view.shape
    annotations = gen_annotations(view, reverse_font_color=reverse_font_color)

    # Reuse the shared toggle definition, then point it at the annotations built
    # above so that it honours reverse_font_color as well.
    toggle = button_annotation(view)
    toggle['buttons'][0]['args'] = [{"annotations": annotations}]

    trace_update = {
        'x': [np.arange(n_cols)],
        'y': [np.arange(n_rows)],
        # Rounded like the initial heatmap so hover values stay consistent.
        'z': [np.round(view.values, 2)],
    }
    layout_update = {
        'annotations': annotations,
        # NOTE(review): this assigns a one-element list to layout.updatemenus;
        # verify in the rendered figure that the dropdown itself is still shown
        # after a selection.
        'updatemenus': [toggle],
        'xaxis': {'fixedrange': True,
                  'tickvals': np.arange(n_cols), 'ticktext': view.columns.tolist()},
        'yaxis': {'autorange': 'reversed', 'fixedrange': True,
                  'tickvals': np.arange(n_rows), 'ticktext': view.index.tolist()},
    }
    return dict(args=[trace_update, layout_update], label=label, method="update")


def plot_correlation_matrix_filter(data, num_cols=None, cat_cols=None, fig=None, width=500, height=500, colorscale='YlGnBu', reverse_font_color=False):
    r'''
    Plot a correlation matrix as a heatmap with two controls:

    - a dropdown that replaces the heatmap with one of five views: the full
      matrix, numeric-numeric, categorical-categorical, categorical-numeric
      and numeric-categorical (only built when both ``num_cols`` and
      ``cat_cols`` are non-empty);
    - a "Show value" toggle that shows or hides the cell annotations.

    data: correlation DataFrame whose row and column labels include every
        name in ``num_cols`` and ``cat_cols`` (the views are taken with ``.loc``).
    num_cols, cat_cols: names of the numeric / categorical features
        (None is treated as an empty list).
    fig: figure to draw on; forwarded to ``plot_correlation_table``.
    width, height, colorscale: forwarded to ``plot_correlation_table``.
    reverse_font_color: forwarded to ``gen_annotations`` to swap the
        black/white annotation colors, for the initial figure and every view.

    Returns the figure.
    '''
    num_cols = [] if num_cols is None else list(num_cols)
    cat_cols = [] if cat_cols is None else list(cat_cols)
    df = data.copy(deep=True)

    fig = plot_correlation_table(df, fig=fig, width=width, height=height, colorscale=colorscale)
    annotations = gen_annotations(df, reverse_font_color=reverse_font_color)

    buttons_list = []
    if num_cols and cat_cols:
        # Rows = categorical, columns = numeric; its transpose is the mirrored block.
        eta_df = df.loc[cat_cols, num_cols]
        views = [
            ("All", df),
            ("Numeric-Numeric", df.loc[num_cols, num_cols]),
            ("Categorical-Categorical", df.loc[cat_cols, cat_cols]),
            ("Categorical-Numeric", eta_df),
            ("Numeric-Categorical", eta_df.T),
        ]
        buttons_list = [
            _filter_button(view, label, reverse_font_color=reverse_font_color)
            for label, view in views
        ]

    fig.update_layout(
        margin=dict(t=150, b=0, l=270, r=280),
        title={
            'text': 'Correlation',
            'y': 0.99,
            'x': 0,
            'xanchor': 'left',
            'yanchor': 'top',
            'font': {'size': 22},
        },
        annotations=annotations,
        # NOTE(review): the heatmap trace created by plot_correlation_table is not
        # attached to a shared coloraxis, so these colorbar settings probably have
        # no visible effect - confirm in the rendered figure.
        coloraxis_colorbar=dict(
            thicknessmode="pixels", thickness=50,
            lenmode="pixels", len=200,
            yanchor="top", y=1,
            ticks="outside",
            dtick=5
        ),
        hoverlabel=dict(
            bgcolor="white",
            font_size=12,
            font_family="Arial",
        ),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(238,238,238,0)',
        modebar=dict(bgcolor='rgba(0,0,0,0)', activecolor='rgba(68,68,68,0.7)', color='rgba(68,68,68,0.3)'),
    )

    fig.update_layout(
        updatemenus=[
            # Dropdown: choose which sub-matrix to display.
            dict(
                type="dropdown",
                direction='down',
                buttons=buttons_list,
                active=0,
                x=0.1, xanchor="left",
                y=1, yanchor="top",
                pad={"r": 0, "t": -50, 'l': -60},
                bgcolor='rgba(255,255,255,1)',
            ),
            # Toggle for the initial (full) matrix: args restores the
            # annotations, args2 clears them.
            dict(
                type="buttons",
                direction="right",
                buttons=[
                    dict(
                        args=[{"annotations": annotations}],
                        args2=[{"annotations": []}],
                        label="Show value",
                        method="relayout"
                    )
                ],
                active=0,
                x=0.1, xanchor="left",
                y=1, yanchor="top",
                pad={"r": 0, "t": -100, 'l': -60},
                bgcolor='rgba(255,255,255,1)',
            ),
        ]
    )

    fig.update_yaxes(
        tickmode='array',
        ticksuffix='   ',
        fixedrange=True
    )
    fig.update_xaxes(fixedrange=True)

    return fig

In [13]:
# Split the declared features into numeric and categorical names, then draw
# the correlation matrix with the sub-matrix filter dropdown.
num_cols = [name for name, kind in feature_type_names.items() if kind == 'numeric']
cat_cols = [name for name, kind in feature_type_names.items() if kind == 'category']

fig = plot_correlation_matrix_filter(
    data,
    num_cols=num_cols,
    cat_cols=cat_cols,
    fig=go.Figure(),
    width=1118,
    height=766,
    colorscale='YlGnBu',
    reverse_font_color=False,
)
fig.show(config={'displaylogo': False})
# fig.write_html('./automl_plot/corr.html', config={'displaylogo':False}, include_plotlyjs='cdn', full_html=False)