In [1]:
import os
import numpy as np
import pandas as pd

import plotly.graph_objects as go
from plotly.subplots import make_subplots

from utils import AutoBins

In [2]:
# NOTE(review): `iplot` is imported here but not used in the cells visible in this notebook.
from plotly.offline import init_notebook_mode, iplot
# Set up plotly's offline notebook mode before any figure is displayed.
init_notebook_mode(connected=True)

In [3]:
# Load the dataset from a CSV file; the path is relative to the working directory.
data = pd.read_csv('./data/bank.csv')

### Pair plot
A pair plot presents the relationship between any two features in a dataset.

### This plot shows
- Multiple subplots are arranged in a single figure as an n-by-n square matrix.
- Non-diagonal:
    - A box plot presents the relationship between a categorical feature and a numeric feature.
    - A scatter plot presents the relationship between two numeric features.
    - A heatmap presents the relationship between two categorical features.
- Diagonal:
    - A bar plot visualizes a categorical feature.
    - A histogram visualizes the distribution of a numeric feature.
- To keep rendering time and the size of the output plot manageable, we sample the input data.

#### Main trace
- Subplots
```python
from plotly.subplots import make_subplots
fig = make_subplots(
    rows=ncol, cols=ncol, 
    shared_yaxes=False, shared_xaxes=False, 
    horizontal_spacing=0.01,
    vertical_spacing=0.01,
)
```
When adding a subplot, we have to specify the position of this subplot.<br>
For example,
```python
fig.add_trace(trace, row=1, col=2)
```
This adds the trace at the `(first row, second column)` position of the plot; subplot row and column indices start at 1.


- Box plot
```python
trace = go.Box(
    x=data[cols[j]], 
    y=data[cols[i]], 
    orientation='h', 
    marker=dict(color='rgba(64,87,210,0.5)')
)
```

#### Layout: axes
To keep tick labels and axis titles from overlapping, we use `title_standoff` and `automargin` to control the distance between the ticks and the axis title. First, we set `margin={'b': n}` (in pixels) in the layout to create a margin at the bottom. If `automargin` is False and `title_standoff` exceeds the margin, the title text is shown at the bottom of the plot.

In [4]:
# Maps each dataset column name to the kind of feature it holds
# ('numeric' or 'category'); the plotting code chooses a trace type per kind.
feature_type_names = dict(
    age='numeric',
    job='category',
    marital='category',
    education='category',
    default='category',
    balance='numeric',
    housing='category',
    loan='category',
    contact='category',
    day='numeric',
    month='category',
    duration='numeric',
    campaign='numeric',
    pdays='numeric',
    previous='numeric',
    poutcome='category',
)

In [5]:
def get_num_data(data, col):
    """Return the non-missing values of a numeric column plus histogram binning.

    The bin count is the smallest (NaN-ignoring) of the three suggestions made
    by ``AutoBins``: ``get_len_step``, ``get_mean_diff_step`` and
    ``get_power_step``.

    Args:
        data: DataFrame containing ``col``.
        col: Name of the column to summarise.

    Returns:
        A tuple ``(arr, bin_width, nbins)`` where ``arr`` is the column with
        missing values removed, ``bin_width`` is ``(max - min) / nbins`` and
        ``nbins`` is a plain ``int`` of at least 1.
    """
    # Single .loc selection instead of chained boolean-then-column indexing.
    arr = data.loc[~data[col].isna(), col]

    autobins = AutoBins()
    suggested = np.nanmin([
        autobins.get_len_step(arr),
        autobins.get_mean_diff_step(arr),
        autobins.get_power_step(arr),
    ])
    # np.nanmin may yield a float (e.g. when any suggestion is a float/NaN);
    # np.histogram only accepts an integer bin count, so coerce once here and
    # use the same value for both the width and the returned count.
    # The floor of 1 avoids a zero divisor below.
    nbins = max(int(suggested), 1)

    min_x = min(arr)
    max_x = max(arr)
    bin_width = (max_x - min_x) / nbins

    return arr, bin_width, nbins

In [6]:
def _diagonal_trace(data, col, type_):
    """Build the single-feature trace for a diagonal cell of the pair plot.

    A 'numeric' feature gets a pre-binned histogram; any other type gets a bar
    chart of value counts.
    """
    fill = 'rgb(61,133,198)'
    if type_ == 'numeric':
        values, bin_width, nbins = get_num_data(data, col)
        # Aggregate up front so the figure carries one point per bin rather
        # than one per row. int() because np.histogram rejects float counts.
        counts, bin_edges = np.histogram(values, bins=int(nbins))
        # Use bin centres, not edges: they match `counts` in length and cannot
        # slip into a neighbouring bin through floating-point rounding.
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        return go.Histogram(
            x=bin_centers,
            y=counts,
            histfunc='sum',
            xbins=go.histogram.XBins(start=min(values),
                                     end=max(values) + bin_width,
                                     size=bin_width),
            marker=dict(color=fill),
        )

    freq = data[col].value_counts()
    return go.Bar(x=freq.index, y=freq.values, marker=dict(color=fill))


def _pair_trace(data, x_col, y_col, x_type, y_type):
    """Build the two-feature trace for an off-diagonal cell.

    ``x_col`` is drawn on the x axis and ``y_col`` on the y axis. Returns None
    when the type combination is not one of numeric/category pairs.
    """
    tint = 'rgba(64,87,210,0.5)'
    if x_type == 'numeric' and y_type == 'numeric':
        return go.Scatter(x=data[x_col], y=data[y_col], mode='markers',
                          marker=dict(color=tint, size=3))
    if x_type == 'numeric' and y_type == 'category':
        return go.Box(x=data[x_col], y=data[y_col], orientation='h',
                      marker=dict(color=tint))
    if x_type == 'category' and y_type == 'numeric':
        return go.Box(x=data[x_col], y=data[y_col], orientation='v',
                      marker=dict(color=tint))
    if x_type == 'category' and y_type == 'category':
        # Rows of the crosstab (its index) become y and its columns become x,
        # so the y feature must be passed first to keep the heatmap aligned
        # with the axis titles and with the other trace types.
        cross_tab = pd.crosstab(data[y_col], data[x_col])
        return go.Heatmap(x=cross_tab.columns, y=cross_tab.index, z=cross_tab.values,
                          colorscale='YlGnBu', showscale=False)
    return None


def plot_scatter_matrix(data, cols, types, height=800, width=800):
    """Draw a pair plot: an n-by-n grid of subplots, one per feature pair.

    The subplot at (row i, column j) shows ``cols[j]`` on the x axis and
    ``cols[i]`` on the y axis. Diagonal cells show a histogram (numeric) or a
    bar chart (otherwise); off-diagonal cells show a scatter plot, a box plot
    or a heatmap depending on the two feature types.

    Args:
        data: DataFrame holding every column named in ``cols``.
        cols: Column names to plot, in grid order.
        types: Feature type for each entry of ``cols``; 'numeric' or 'category'.
        height: Figure height passed to the layout.
        width: Figure width passed to the layout.

    Returns:
        The figure created by ``make_subplots`` with all traces and layout set.
    """
    ncol = len(cols)

    fig = make_subplots(rows=ncol, cols=ncol, shared_yaxes=False, shared_xaxes=False,
                        horizontal_spacing=0.01,
                        vertical_spacing=0.01,
                       )

    for i in range(ncol):
        for j in range(ncol):
            if i == j:
                trace = _diagonal_trace(data, cols[j], types[j])
            else:
                trace = _pair_trace(data, cols[j], cols[i], types[j], types[i])
            if trace is not None:
                # Subplot positions are 1-based.
                fig.add_trace(trace, row=i+1, col=j+1)

    # Hide tick labels everywhere first; the outer axes are re-enabled below.
    fig.update_xaxes(showticklabels=False, fixedrange=False)
    fig.update_yaxes(showticklabels=False, fixedrange=False)

    fig.update_layout(height=height, width=width,
                      showlegend=False,
                      hovermode=False,
                      margin={'t': 50, 'l':100, 'r':0, 'b':100},
                      title=dict(
                        text='Pair Plot',
                        font=dict(size=22),
                        pad=dict(b=0, l=0, r=0, t=0),
                        xanchor='left',
                        yanchor='top',
                        y=0.99,
                        x=0,),
                      paper_bgcolor='rgba(0,0,0,0)',
                      plot_bgcolor='rgba(238,238,238,1)',
                      modebar=dict(
                          bgcolor='rgba(0,0,0,0)', activecolor='rgba(68,68,68, 0.7)', color='rgba(68, 68, 68, 0.3)',
                          remove=['zoomin', 'zoomout', 'lasso', 'select'],
                      ),
                      )

    # Label only the bottom row (x) and the left column (y). The very large
    # title_standoff with automargin off is intentional: it is meant to push
    # each title past the tick labels into the margin reserved above.
    for i, (type_, col) in enumerate(zip(types, cols)):
        tickangle = -90 if type_ == 'category' else 45
        fig.update_xaxes(title_text=str(col), title_standoff=5000, tickangle=tickangle,
                         row=ncol, col=i+1, automargin=False, showticklabels=True)
        fig.update_yaxes(title_text=str(col), title_standoff=10000, tickangle=0,
                         row=i+1, col=1, automargin=False, showticklabels=True)

    return fig

In [7]:
# Columns shown in the demo pair plot, the row budget for sampling, and the
# seed that makes the sample repeatable.
PAIR_COLS = ['age', 'balance', 'job', 'education']
N_SAMPLES = 3000
SEED = 42

# A fixed random_state makes Restart & Run All draw the same rows every time;
# capping at len(data) keeps sample() from raising on a smaller frame.
data_sample = data.sample(min(N_SAMPLES, len(data)), random_state=SEED)

fig = plot_scatter_matrix(data_sample,
                          cols=PAIR_COLS,
                          # Look the types up instead of repeating them by hand.
                          types=[feature_type_names[col] for col in PAIR_COLS],
                          height=800, width=800)
fig.show(config={'displaylogo': False})
# Optional: export the figure as an embeddable HTML fragment.
# fig.write_html('./example_plots/pair_plot.html', config={'displaylogo':False}, include_plotlyjs='cdn', full_html=False)