In [1]:
import os
import numpy as np
import pandas as pd

import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [2]:
# Initialise Plotly's offline notebook mode before any figure is displayed.
# NOTE(review): `iplot` is imported here but is not called in the visible cells.
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True)

In [3]:
# Load the dataset from a CSV path that is relative to the notebook's working directory.
data = pd.read_csv('./data/bank.csv')

In [4]:
# Lookup table: column name -> plotting type ('numeric' or 'category').
# Entries are listed in the same order as before so iteration order is unchanged.
feature_type_names = dict(
    age='numeric',
    job='category',
    marital='category',
    education='category',
    default='category',
    balance='numeric',
    housing='category',
    loan='category',
    contact='category',
    day='numeric',
    month='category',
    duration='numeric',
    campaign='numeric',
    pdays='numeric',
    previous='numeric',
    poutcome='category',
)

In [5]:
def plot_scatter_matrix(data, cols, types, height=800, width=800):
    """Build a pair-plot grid of Plotly subplots for the selected columns.

    The subplot at grid row ``i`` / grid column ``j`` always places ``cols[j]``
    on its x axis and ``cols[i]`` on its y axis:

    * diagonal, 'numeric' column      -> histogram of the column
    * diagonal, any other type        -> bar chart of the column's value counts
    * 'numeric'  x / 'numeric'  y     -> scatter plot
    * 'numeric'  x / 'category' y     -> horizontal box plots
    * 'category' x / 'numeric'  y     -> vertical box plots
    * 'category' x / 'category' y     -> heatmap of cross-tabulated counts

    Off-diagonal pairs where either type is neither 'numeric' nor 'category'
    are left empty.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame containing every column named in ``cols``.
    cols : sequence
        Column labels to plot; the grid is ``len(cols)`` x ``len(cols)``.
    types : sequence of str or dict
        Either one type string per entry of ``cols`` (same order), or a dict
        mapping column label -> type string (e.g. ``feature_type_names``).
    height, width : int
        Figure size passed to the layout.

    Returns
    -------
    plotly.graph_objects.Figure
        The assembled figure; it is not shown by this function.

    Raises
    ------
    IndexError
        If ``types`` is a sequence shorter than ``cols``.
    KeyError
        If ``types`` is a dict that lacks one of ``cols``, or ``data`` lacks
        a requested column.
    """
    # Accept a column -> type mapping as well as the positional list.
    if isinstance(types, dict):
        types = [types[col] for col in cols]

    ncol = len(cols)
    diag_color = 'rgb(61,133,198)'
    pair_color = 'rgba(64,87,210,0.5)'

    fig = make_subplots(rows=ncol, cols=ncol, shared_yaxes=False, shared_xaxes=False,
                        horizontal_spacing=0.01,
                        vertical_spacing=0.01,
                        )

    for i in range(ncol):
        for j in range(ncol):
            x_col, y_col = cols[j], cols[i]
            x_type, y_type = types[j], types[i]
            trace = None

            if i == j:
                if x_type == 'numeric':
                    trace = go.Histogram(x=data[x_col], marker=dict(color=diag_color))
                else:
                    # Count once and reuse for both bar axes.
                    counts = data[x_col].value_counts()
                    trace = go.Bar(x=counts.index, y=counts.values, marker=dict(color=diag_color))
            elif x_type == 'numeric' and y_type == 'numeric':
                trace = go.Scatter(x=data[x_col], y=data[y_col], mode='markers',
                                   marker=dict(color=pair_color, size=3))
            elif x_type == 'numeric' and y_type == 'category':
                trace = go.Box(x=data[x_col], y=data[y_col], orientation='h',
                               marker=dict(color=pair_color))
            elif x_type == 'category' and y_type == 'numeric':
                trace = go.Box(x=data[x_col], y=data[y_col], orientation='v',
                               marker=dict(color=pair_color))
            elif x_type == 'category' and y_type == 'category':
                # Index (rows of z) must be the y variable and columns the x
                # variable, so the heatmap matches the axis convention used by
                # every other cell and by the outer axis titles.
                cross_tab = pd.crosstab(data[y_col], data[x_col])
                trace = go.Heatmap(x=cross_tab.columns, y=cross_tab.index, z=cross_tab.values,
                                   colorscale='YlGnBu', showscale=False)

            if trace is not None:
                fig.add_trace(trace, row=i+1, col=j+1)

    # Hide tick labels everywhere first; the outer edge is re-enabled below.
    fig.update_xaxes(showticklabels=False, fixedrange=False)
    fig.update_yaxes(showticklabels=False, fixedrange=False)

    fig.update_layout(height=height, width=width,
                      showlegend=False,
                      hovermode=False,
                      margin={'t': 50, 'l':100, 'r':0, 'b':100},
                      title=dict(
                        text='Pair Plot',
                        font=dict(size=22),
                        pad=dict(b=0, l=0, r=0, t=0),
                        xanchor='left',
                        yanchor='top',
                        y=0.99,
                        x=0,),
                      paper_bgcolor='rgba(0,0,0,0)',
                      plot_bgcolor='rgba(238,238,238,1)',
                      modebar=dict(
                          bgcolor='rgba(0,0,0,0)', activecolor='rgba(68,68,68, 0.7)', color='rgba(68, 68, 68, 0.3)',
                          remove=['zoomin', 'zoomout', 'lasso', 'select'],
                      ),
                      )

    # Title and tick-label only the outer edge: x axes of the bottom row and
    # y axes of the first column.
    for i, col in enumerate(cols):
        fig.update_xaxes(title_text=str(col), title_standoff=5000, row=ncol, col=i+1, automargin=False, showticklabels=True)
        fig.update_yaxes(title_text=str(col), title_standoff=10000, row=i+1, col=1, automargin=False, showticklabels=True)

    return fig

In [6]:
# Columns shown in the pair plot; their types are looked up in
# `feature_type_names` so the two lists cannot drift apart.
pair_cols = ['age', 'balance', 'duration', 'job']

# Fixed seed so a fresh "Restart & Run All" reproduces the same figure; the
# sample size is capped at the frame length because sampling more rows than
# exist (without replacement) raises.
pair_sample = data.sample(n=min(3000, len(data)), random_state=42)

fig = plot_scatter_matrix(pair_sample, cols=pair_cols, types=[feature_type_names[col] for col in pair_cols], height=800, width=800)
fig.show(config={'displaylogo': False})
# fig.write_html('./pair_plot.html', config={'displaylogo': False}, include_plotlyjs='cdn', full_html=False)