In [1]:
import itertools
import math

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly import subplots

# Seed NumPy's global random number generator with a fixed value so that
# subsequent np.random draws repeat on every fresh run of the notebook.
np.random.seed(1337)

In [2]:
# Synthetic demo data: k columns named column_name_x_00, column_name_x_01, ...
# each holding n samples drawn with np.random.randn.
n = 50
k = 10
d_plot = pd.DataFrame(
    {f"column_name_x_{col_idx:02d}": np.random.randn(n) for col_idx in range(k)}
)

In [3]:
# Annotated heatmap: tick labels on both axes and the value of every cell
# printed inside it (two decimals). The color bar is switched off here via
# showscale=False.
# Intended mainly for displaying a correlation matrix.

# Columns of d_plot to correlate with each other.
col_names = [f"column_name_x_{idx:02d}" for idx in range(10)]

# Row order of the matrix is flipped and the y labels are reversed to match,
# so every row stays aligned with its label.
corr = np.corrcoef(d_plot[col_names], rowvar=False)[::-1]

# Square figure whose side grows with the number of variables.
side_px = 200 + 50 * corr.shape[0]

heatmap = go.Heatmap(
    z=corr,
    x=col_names,
    y=list(reversed(col_names)),
    # Fixed color range covering every possible correlation value.
    zmin=-1,
    zmax=1,
    zauto=False,
    colorscale="RdBu_r",
    showscale=False,
    name="",
    # Cell annotations: the same matrix, formatted to two decimals.
    text=corr,
    texttemplate="%{text:0.2f}",
    textfont=dict(size=12),
    hovertemplate="x: %{x}<br>y: %{y}<br>value: %{z:0.3f}",
)

fig = go.Figure()
fig.add_trace(heatmap)
fig.update_layout(
    title="Correlations",
    template="plotly_white",
    width=side_px,
    height=side_px,
)
fig.update_xaxes(tickangle=270)
fig.show()

In [4]:
# Grid of scatter plots: one panel for every unordered pair of the selected
# columns, each panel titled with that pair's correlation coefficient.

# Columns of d_plot to compare pairwise.
col_names = [f"column_name_x_{idx:02d}" for idx in range(6)]
pairs = list(itertools.combinations(col_names, r=2))
n_plots = len(pairs)
n_plot_cols = 4
n_plot_rows = math.ceil(n_plots / n_plot_cols)

# One title per panel, in the same order as `pairs`.
panel_titles = []
for x_name, y_name in pairs:
    pair_corr = np.corrcoef(d_plot[x_name], d_plot[y_name])[0, 1]
    panel_titles.append(f"correlation: {pair_corr:0.2f}")

fig = subplots.make_subplots(
    rows=n_plot_rows,
    cols=n_plot_cols,
    subplot_titles=panel_titles,
    horizontal_spacing=0.09,
)

marker_style = dict(color="cornflowerblue", size=4, symbol="circle")
for idx, (x_name, y_name) in enumerate(pairs):
    # Panels are filled row by row; subplot rows/cols are addressed from 1.
    grid_row, grid_col = divmod(idx, n_plot_cols)
    cell = dict(row=grid_row + 1, col=grid_col + 1)
    fig.add_trace(
        go.Scatter(
            x=d_plot[x_name],
            y=d_plot[y_name],
            mode="markers",
            name="",
            marker=marker_style,
            hovertemplate="x: %{x:0.3f}<br>y: %{y:0.3f}",
        ),
        **cell,
    )
    fig.update_xaxes(title_text=x_name, **cell)
    fig.update_yaxes(title_text=y_name, **cell)

# Kept as the cell's final expression: its return value is what the notebook
# displays.
fig.update_layout(
    title="Correlations",
    template="plotly_white",
    width=1200,
    height=80 + 280 * n_plot_rows,
    showlegend=False,
    legend_title="Legend",
)

In [5]:
# Grid of scatter plots relating one variable of interest to several others:
# every panel plots main_col (y axis) against one entry of col_names (x axis)
# and is titled with the correlation coefficient of the two.

# Variable shown on the y axis of every panel.
main_col = "column_name_x_00"
# Variables shown on the x axis, one panel each.
col_names = [
    "column_name_x_01",
    "column_name_x_02",
    "column_name_x_03",
    "column_name_x_04",
    "column_name_x_05",
    "column_name_x_06",
    "column_name_x_07",
    "column_name_x_08",
]
n_plots = len(col_names)
n_plot_cols = 4
n_plot_rows = math.ceil(n_plots / n_plot_cols)

fig = subplots.make_subplots(
    rows=n_plot_rows,
    cols=n_plot_cols,
    subplot_titles=[
        f"correlation: {np.corrcoef(d_plot[main_col], d_plot[col])[0, 1]:0.2f}"
        for col in col_names
    ],
    horizontal_spacing=0.09,
)
for i, col in enumerate(col_names):
    # Panels are filled row by row; subplot rows/cols are addressed from 1.
    row_offset, col_offset = divmod(i, n_plot_cols)
    cell = dict(row=1 + row_offset, col=1 + col_offset)
    fig.add_trace(
        go.Scatter(
            x=d_plot[col],
            y=d_plot[main_col],
            name="",
            mode="markers",
            marker=dict(color="cornflowerblue", size=4, symbol="circle"),
        ),
        **cell,
    )
    fig.update_xaxes(title_text=col, **cell)
    fig.update_yaxes(title_text=main_col, **cell)
fig.update_traces(hovertemplate="x: %{x:0.3f}" + "<br>y: %{y:0.3f}")
# Kept as the cell's final expression: its return value is what the notebook
# displays.
fig.update_layout(
    template="plotly_white",
    width=1200,
    height=80 + 280 * n_plot_rows,
    title="Correlations",
    legend_title="Legend",
    showlegend=False,
)