In [1]:
import os
import numpy as np
import pandas as pd

import plotly.graph_objects as go

In [2]:
# Switch plotly into offline notebook mode before any figure is rendered.
# NOTE(review): `iplot` is imported here but not used in the cells visible in
# this notebook — confirm whether it is still needed.
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True)

In [3]:
# Load the correlation matrix; the first CSV column becomes the row index.
data = pd.read_csv(
    './data/correlation_matrix.csv',
    index_col=0,
)

In [4]:
# Maps each feature name to its type tag ('numeric' or 'category').
# Insertion order is significant: later cells build the numeric / categorical
# column lists by iterating over this mapping.
feature_type_names = dict(
    age='numeric',
    job='category',
    marital='category',
    education='category',
    default='category',
    balance='numeric',
    housing='category',
    loan='category',
    contact='category',
    day='numeric',
    month='category',
    duration='numeric',
    campaign='numeric',
    pdays='numeric',
    previous='numeric',
    poutcome='category',
)

In [5]:
def plot_correlation_table(data, fig=None, width=500, height=500, colorscale='YlGnBu', reverse_font_color=False):
    """Draw ``data`` as a heatmap with one text annotation per cell.

    Parameters
    ----------
    data : pandas.DataFrame
        Table to plot. Column labels go on the x axis, index labels on the
        y axis, and the values (rounded to 2 decimals) are the cell colours
        and annotation texts.
    fig : plotly figure, optional
        Figure to draw into; a new ``go.Figure()`` is created when omitted.
    width, height : int
        Figure size forwarded to the layout.
    colorscale : str
        Colour scale forwarded to the heatmap trace.
    reverse_font_color : bool
        By default cells below 0.5 get black text and all others white;
        a truthy value swaps the two colours.

    Returns
    -------
    tuple
        ``(fig, annotations)`` — the figure and the list of annotation
        objects that was installed in its layout.
    """
    figure = go.Figure() if fig is None else fig

    col_labels = data.columns.values.tolist()
    row_labels = data.index.values.tolist()
    rounded = np.round(data.values, 2).tolist()

    # The colour range is pinned to [0, 1] rather than derived from the data.
    heatmap = go.Heatmap(
        x=col_labels,
        y=row_labels,
        z=rounded,
        colorscale=colorscale,
        showscale=True,
        reversescale=False,
        zmin=0,
        zmax=1,
        hovertemplate="Corr ( %{x}, %{y} ) = %{z} <extra></extra>",
    )
    figure.add_trace(heatmap)

    # One annotation per cell, row by row; the font colour is chosen per cell
    # from the rounded value so the text contrasts with the cell colour.
    annotations = [
        go.layout.Annotation(
            text=str(value),
            x=col_label,
            y=row_label,
            xref="x1",
            yref="y1",
            font={'color': "#000000" if (value < 0.5) ^ reverse_font_color else "#FFFFFF"},
            showarrow=False,
        )
        for row_label, row in zip(row_labels, rounded)
        for col_label, value in zip(col_labels, row)
    ]

    figure.update_layout(
        title_text='Correlation',
        annotations=annotations,
        xaxis={'scaleanchor': 'y', 'constrain': 'domain'},
        width=width,
        height=height,
        template='plotly_white',
    )

    # First row at the top (matrix orientation) and slanted column labels.
    figure.update_yaxes(autorange='reversed')
    figure.update_xaxes(tickangle=-45)

    return figure, annotations

In [6]:
def plot_correlation_matrix(data, num_cols=None, cat_cols=None, fig=None, width=500, height=500, colorscale='YlGnBu', reverse_font_color=False):
    """Render a correlation heatmap with highlight and show/hide-value buttons.

    The heatmap itself is drawn by ``plot_correlation_table``; this function
    adds the surrounding layout, the row labels tagged with the feature type,
    and two button groups.

    Parameters
    ----------
    data : pandas.DataFrame
        Table of values to plot. It is copied, not modified.
        NOTE(review): presumably a square matrix with the same labels on both
        axes, ordered numeric features first and categorical features after —
        the highlight rectangles below rely on that order; confirm against
        whatever produces the input.
    num_cols, cat_cols : sequence, optional
        Names of the numeric / categorical features. ``None`` (the default)
        is treated as an empty list.
    fig : plotly figure, optional
        Figure to draw into; forwarded to ``plot_correlation_table``.
    width, height, colorscale, reverse_font_color
        Forwarded unchanged to ``plot_correlation_table``.

    Returns
    -------
    The figure returned by ``plot_correlation_table``, with layout, buttons
    and axis settings applied.
    """
    # None sentinels instead of mutable list defaults.
    num_cols = [] if num_cols is None else list(num_cols)
    cat_cols = [] if cat_cols is None else list(cat_cols)
    df = data.copy(deep=True)

    n_num = len(num_cols)
    n_cat = len(cat_cols)
    buttons_list = []

    # Label each row from the frame's own index so a label can never be
    # attached to the wrong row, whatever order the rows arrive in.
    num_names = set(num_cols)
    cat_names = set(cat_cols)
    ticktext = []
    for label in df.index:
        if label in num_names:
            ticktext.append(str(label) + ' (N)')
        elif label in cat_names:
            ticktext.append(str(label) + ' (C)')
        else:
            ticktext.append(str(label))

    fig, annotations = plot_correlation_table(df, fig=fig, width=width, height=height,
                                              colorscale=colorscale,
                                              reverse_font_color=reverse_font_color)

    if n_num > 0 and n_cat > 0:
        line_config = dict(color="blue", width=4)
        # Cell i spans [i - 0.5, i + 0.5], so these are the outer edges of the
        # numeric block followed by the categorical block.
        coor = [-0.5, n_num - 0.5, n_num + n_cat - 0.5]
        # Labels read "<row type>-<column type>"; i walks the x (column)
        # blocks and j the y (row) blocks.
        button_label = ['num-num', 'cat-num', 'num-cat', 'cat-cat']

        buttons_list = [dict(label="None", method="relayout", args=["shapes", []])]
        idx = 0
        for i in range(2):
            for j in range(2):
                shape_ = dict(type="rect", x0=coor[i], y0=coor[j], x1=coor[i+1], y1=coor[j+1], line=line_config)
                buttons_list.append(dict(label=button_label[idx], method="relayout", args=["shapes", [shape_]]))
                idx += 1

    fig.update_layout(
        margin=dict(t=150, b=0, l=270, r=280),
        title={
            'y': 0.99,
            'x': 0,
            'xanchor': 'left',
            'yanchor': 'top',
            'font': {'size': 22},
        },
        # Only takes effect for traces bound to the shared layout colour axis.
        # The former ticksuffix=" bills" / dtick=5 were leftovers from an
        # unrelated example and made no sense on a 0-1 scale, so they are gone.
        coloraxis_colorbar=dict(
            thicknessmode="pixels", thickness=50,
            lenmode="pixels", len=200,
            yanchor="top", y=1,
            ticks="outside",
        ),
        hoverlabel=dict(
            bgcolor="white",
            font_size=12,
            font_family="Arial",
        ),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(238,238,238,1)',
        modebar=dict(
            bgcolor='rgba(0,0,0,0)', activecolor='rgba(68,68,68, 0.7)', color='rgba(68, 68, 68, 0.3)',)
    )

    fig.update_layout(
        updatemenus=[
            # Row 1: choose which type-pair block to outline (or none).
            dict(
                type="buttons",
                direction='right',
                x=0.1,
                xanchor="left",
                y=1,
                yanchor="top",
                pad={"r": 0, "t": -100, 'l': -60},
                buttons=buttons_list,
                bgcolor='rgba(255,255,255,1)',
            ),
            # Row 2: toggle the per-cell value annotations.
            # NOTE(review): active=1 points past the only button, presumably so
            # it starts unpressed and the first click hides the values — confirm.
            dict(
                type="buttons",
                direction="right",
                buttons=[
                    dict(
                        args=[{"annotations": []}],
                        args2=[{"annotations": annotations}],
                        label="Show value",
                        method="relayout"
                    )
                ],
                active=1,
                x=0.1,
                xanchor="left",
                y=1,
                yanchor="top",
                pad={"r": 0, "t": -50, 'l': -60},
                bgcolor='rgba(255,255,255,1)',
            ),
        ]
    )

    fig.update_yaxes(
        tickmode='array',
        # One tick per row: the y axis shows rows, so use the row count.
        tickvals=np.arange(df.shape[0]),
        ticktext=ticktext,
        fixedrange=True
    )
    fig.update_xaxes(fixedrange=True)

    return fig

In [7]:
# Split the feature names by their declared type, preserving mapping order.
num_cols = [name for name, kind in feature_type_names.items() if kind == 'numeric']
cat_cols = [name for name, kind in feature_type_names.items() if kind == 'category']

# Build the interactive correlation figure and display it without the plotly logo.
fig = plot_correlation_matrix(
    data,
    num_cols=num_cols,
    cat_cols=cat_cols,
    fig=go.Figure(),
    width=1118,
    height=766,
    colorscale='YlGnBu',
    reverse_font_color=False,
)
fig.show(config={'displaylogo': False})
# fig.write_html('./automl_plot/corr.html', config={'displaylogo':False}, include_plotlyjs='cdn', full_html=False)