In [1]:
%load_ext autoreload
%autoreload 2

In [341]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import dartwork_mpl as dm

from mpl_toolkits.axes_grid1 import Divider, Size
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

In [342]:
# Demo data: a labelled 5 x 5 frame of uniform random values in [-1, 1).
df = pd.DataFrame(
    data=np.random.rand(5, 5) * 2 - 1,
    index=list('abcde'),
    columns=list('ABCDE'),
)

df

Unnamed: 0,A,B,C,D,E
a,0.168169,0.053587,-0.120382,0.044093,0.177809
b,-0.09418,-0.832469,0.62122,-0.048882,0.734618
c,0.509682,-0.527377,-0.515614,0.794818,-0.72928
d,0.302155,-0.756303,-0.856243,-0.999432,0.914754
e,0.520771,-0.923783,-0.4422,-0.982405,0.893554


In [385]:
# Colorbar tick positions, shared by the 'positions' and 'labels' entries
# below. np.linspace names the inclusive end point directly, so no padded
# stop value is needed to make a float-step np.arange reach 1.0.
cbar_ticks = np.linspace(-1.0, 1.0, 5)

# Plot settings, grouped by the component they are applied to.
options = {
    'figure': {
        # Passed to plt.figure as figsize, dpi and facecolor.
        'size': (dm.cm2in(12), dm.cm2in(7)),
        'dpi': 200,
        'facecolor': 'dm.gray4',
        # Extra keyword arguments forwarded to plt.figure.
        'other_kwargs': {},
    },
    'heatmap': {
        'cmap': 'magma',
        'annot': True,
        'fmt': '.2f',
        'square': True,
        # Colour and width of the lines between cells; the same values are
        # reused for the axes spines further down.
        'linecolor': 'green',
        'linewidths': 20,
        'xtick': {
            'length': 2,
            'width': 0.5,
        },
        'ytick': {
            'length': 2,
            'width': 0.5,
        },
        # None, 'l', 'u', or 2D np.array.
        # 'l' masks the cells above the diagonal, 'u' the cells below it.
        'mask': 'l',
        # Extra keyword arguments forwarded to sns.heatmap.
        'other_kwargs': {
            'annot_kws': {
                'fontsize': 7,
            },
        },
    },
    'cbar': {
        # Passed to append_axes as size / pad of the colorbar axes.
        'size': '20%',
        'pad': '20%',
        # Colour limits, padded by 1e-3 on each side of [-1, 1].
        'clim': (-1-1e-3, 1+1e-3),
        # 'adjust_ticks': True,
        'tick': {
            'length': 0.0,
            'width': 0.0,
            'positions': cbar_ticks,
            'labels': [f'{v:.1f}' for v in cbar_ticks],
        }
    },
}



# Select the dartwork_mpl style before any artist is created.
dm.use_style('dmpl_light')

# Build the figure from the 'figure' options and give it a 1 x 1 grid.
figure_options = options['figure']
fig = plt.figure(
    figsize=figure_options['size'],
    dpi=figure_options['dpi'],
    facecolor=figure_options['facecolor'],
    **figure_options['other_kwargs'],
)
gs = fig.add_gridspec(nrows=1, ncols=1)

# Cell-border line width; reused below for the spines and for nudging the
# colorbar rectangle.
lw = options['heatmap']['linewidths']

# Resolve the 'mask' option into the array handed to sns.heatmap
# (cells where the mask is True are not drawn).
mask_spec = options['heatmap']['mask']
if mask_spec is None:
    # Nothing hidden.
    mask = np.zeros(df.values.shape, dtype=bool)
elif isinstance(mask_spec, np.ndarray):
    # A ready-made array is used unchanged.
    mask = mask_spec
elif isinstance(mask_spec, str) and mask_spec == 'l':
    # Keep the lower triangle (diagonal included); hide everything above it.
    mask = ~np.tril(np.ones(df.values.shape, dtype=bool))
elif isinstance(mask_spec, str) and mask_spec == 'u':
    # Keep the upper triangle (diagonal included); hide everything below it.
    mask = ~np.triu(np.ones(df.values.shape, dtype=bool))
else:
    # Fail here with a clear message instead of passing an unsupported
    # value on to seaborn.
    raise ValueError(
        "options['heatmap']['mask'] must be None, 'l', 'u', or a 2D "
        f"np.ndarray; got {mask_spec!r}"
    )

# Draw the heatmap into the single grid cell. seaborn's own colorbar is
# switched off because a separate colorbar axes is attached afterwards.
heatmap_options = options['heatmap']
ax = fig.add_subplot(gs[0, 0])
sns.heatmap(
    df,
    ax=ax,
    cbar=False,
    mask=mask,
    cmap=heatmap_options['cmap'],
    annot=heatmap_options['annot'],
    fmt=heatmap_options['fmt'],
    square=heatmap_options['square'],
    linecolor=heatmap_options['linecolor'],
    linewidths=heatmap_options['linewidths'],
    **heatmap_options['other_kwargs'],
)

# Apply the per-axis tick length / width settings.
for axis_name, tick_key in (('x', 'xtick'), ('y', 'ytick')):
    tick_options = heatmap_options[tick_key]
    ax.tick_params(
        axis=axis_name,
        length=tick_options['length'],
        width=tick_options['width'],
    )

# Carve a colorbar axes off the right-hand side of the heatmap axes.
cbar_options = options['cbar']
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size=cbar_options['size'], pad=cbar_options['pad'])

# The first collection on the heatmap axes serves as the colorbar's mappable.
heatmap_mesh = ax.collections[0]
cbar = fig.colorbar(heatmap_mesh, cax=cax, orientation='vertical')

# Zero-width outline (the colour is irrelevant at this width), then apply
# the configured colour limits to the mappable.
cbar.outline.set_color('red')
cbar.outline.set_linewidth(0)
heatmap_mesh.set_clim(*cbar_options['clim'])

# Lay out the figure, then recompute the colorbar rectangle by hand: its
# vertical extent is taken from the heatmap axes (trimmed by the line width)
# and its horizontal extent from the current colorbar axes (shifted/widened
# by a line-width dependent amount).
#
# dm.simple_layout(
#     fig, use_all_axes=True, verbose=True,
#     importance_weights=(0, 0, 1, 1),
# )
fig.tight_layout()

# Get axes bounding box.
# `.bounds` is (x0, y0, width, height) in display (pixel) coordinates.
ax_bounds = ax.get_window_extent().bounds

# Get colorbar bounding box.
cax_bounds = cax.get_window_extent().bounds


# Convert lw pt to pixel.
# 1 pt = 1/72 inch, so pt * dpi / 72 gives pixels.
lw_pixel = lw * fig.dpi / 72

# NOTE(review): ax_bounds[2] is the heatmap axes *width*, compared here with
# the colorbar's left edge cax_bounds[0]. If the intent was "the axes end to
# the left of the colorbar", the right edge would be
# ax_bounds[0] + ax_bounds[2] -- confirm.
if ax_bounds[2] < cax_bounds[0]:
    # Shift the colorbar right by two line widths and widen it by one.
    cax_offset = 2 * lw_pixel
    cax_last = cax_offset / 2
else:
    # Shift the colorbar left by lw_pixel / 4 and widen it by lw_pixel / 16.
    cax_offset = -lw_pixel / 4
    cax_last = -cax_offset / 4

# New colorbar rectangle in pixels: (x0, y0, width, height).
cax_bounds = (
    cax_bounds[0] + cax_offset,
    # cax_bounds[0],
    # Bottom raised by half a line width, height reduced by a full one.
    ax_bounds[1] + lw_pixel / 2,
    cax_bounds[2] + cax_last,
    ax_bounds[3] - lw_pixel,
)

# Convert to figure coordinates.
# Pixels divided by (figure size in inches * dpi) give figure fractions.
cax_bounds = (
    cax_bounds[0] / fig.get_figwidth() / fig.dpi,
    cax_bounds[1] / fig.get_figheight() / fig.dpi,
    cax_bounds[2] / fig.get_figwidth() / fig.dpi,
    cax_bounds[3] / fig.get_figheight() / fig.dpi,
)

# Create new divider for colorbar.
# A 1 x 1 divider over exactly that rectangle; its single cell is used as
# the colorbar axes' locator below.
new_divider = Divider(
    fig, cax_bounds,
    horizontal=[Size.Scaled(1.0)],
    vertical=[Size.Scaled(1.0)],
)

cax.set_axes_locator(new_divider.new_locator(nx=0, ny=0))
# Outline styling with the same values as applied when the colorbar was
# created: zero width, so the outline colour has no visible effect.
cbar.outline.set_color('red')
cbar.outline.set_linewidth(0)

# Frame the heatmap: show all four spines with the cell-border width/colour.
for side in ('top', 'bottom', 'left', 'right'):
    spine = ax.spines[side]
    spine.set_linewidth(lw)
    spine.set_color(options['heatmap']['linecolor'])
    spine.set_visible(True)

# Colorbar ticks: optional fixed positions, tick mark styling, then the
# optional labels. The labels are set once, as the last tick operation.
cbar_tick = options['cbar']['tick']
if cbar_tick['positions'] is not None:
    cax.yaxis.set_ticks(cbar_tick['positions'])

cbar.ax.yaxis.set_tick_params(
    length=cbar_tick['length'],
    width=cbar_tick['width'],
    color=options['heatmap']['linecolor'],
)

if cbar_tick['labels'] is not None:
    cax.yaxis.set_ticklabels(cbar_tick['labels'])

# Display and export through the dartwork_mpl helpers (their behaviour is
# defined in dm, not in this notebook). The figure's own facecolor is
# passed along explicitly.
dm.save_and_show(fig, facecolor=fig.get_facecolor(), size=700)
dm.save_formats(fig, 'test_heatmap', formats=['png', 'svg'], dpi=300, facecolor=fig.get_facecolor())

# Production level test

In [353]:
# Second 5 x 5 random frame in [-1, 1); row and column labels vary in length.
df = pd.DataFrame(
    data=np.random.rand(5, 5) * 2 - 1,
    index=['aaa', 'bb', 'ccccc', 'dd', 'eeeee'],
    columns=['AAAAA', 'BBB', 'C', 'DDDDDD', 'EE'],
)

In [388]:
# Colorbar tick positions, shared by the 'positions' and 'labels' entries
# below. np.linspace names the inclusive end point directly, so no padded
# stop value is needed to make a float-step np.arange reach 1.0.
cbar_ticks = np.linspace(-1.0, 1.0, 5)

# Plot settings, grouped by the component they are applied to.
options = {
    'figure': {
        # The horizontal margin needs to be wider.
        'size': (dm.cm2in(10), dm.cm2in(7)),
        'dpi': 300,
        'facecolor': 'white',
        # Extra keyword arguments forwarded to plt.figure.
        'other_kwargs': {},
    },
    'heatmap': {
        'cmap': sns.diverging_palette(0, 255, sep=77, as_cmap=True),
        'annot': True,
        'fmt': '.2f',
        'square': True,
        # Colour and width of the lines between cells; the same values are
        # reused for the axes spines further down.
        'linecolor': 'white',
        'linewidths': 1.5,
        'xtick': {
            'length': 2,
            'width': 0.5,
        },
        'ytick': {
            'length': 2,
            'width': 0.5,
        },
        # None, 'l', 'u', or 2D np.array.
        # 'l' masks the cells above the diagonal, 'u' the cells below it.
        'mask': 'l',
        # Extra keyword arguments forwarded to sns.heatmap.
        'other_kwargs': {
            'annot_kws': {
                'fontsize': 6,
            },
        },
    },
    'cbar': {
        # Passed to append_axes as size / pad of the colorbar axes.
        'size': '5%',
        'pad': '5%',
        # Colour limits, padded by 1e-3 on each side of [-1, 1].
        'clim': (-1-1e-3, 1+1e-3),
        # 'adjust_ticks': True,
        'tick': {
            'length': 0.0,
            'width': 0.0,
            'positions': cbar_ticks,
            'labels': [f'{v:.1f}' for v in cbar_ticks],
        }
    },
}



# Select the dartwork_mpl style before any artist is created.
dm.use_style('dmpl_light')

# Build the figure from the 'figure' options and give it a 1 x 1 grid.
figure_options = options['figure']
fig = plt.figure(
    figsize=figure_options['size'],
    dpi=figure_options['dpi'],
    facecolor=figure_options['facecolor'],
    **figure_options['other_kwargs'],
)
gs = fig.add_gridspec(nrows=1, ncols=1)

# Cell-border line width; reused below for the spines and for nudging the
# colorbar rectangle.
lw = options['heatmap']['linewidths']

# Resolve the 'mask' option into the array handed to sns.heatmap
# (cells where the mask is True are not drawn).
mask_spec = options['heatmap']['mask']
if mask_spec is None:
    # Nothing hidden.
    mask = np.zeros(df.values.shape, dtype=bool)
elif isinstance(mask_spec, np.ndarray):
    # A ready-made array is used unchanged.
    mask = mask_spec
elif isinstance(mask_spec, str) and mask_spec == 'l':
    # Keep the lower triangle (diagonal included); hide everything above it.
    mask = ~np.tril(np.ones(df.values.shape, dtype=bool))
elif isinstance(mask_spec, str) and mask_spec == 'u':
    # Keep the upper triangle (diagonal included); hide everything below it.
    mask = ~np.triu(np.ones(df.values.shape, dtype=bool))
else:
    # Fail here with a clear message instead of passing an unsupported
    # value on to seaborn.
    raise ValueError(
        "options['heatmap']['mask'] must be None, 'l', 'u', or a 2D "
        f"np.ndarray; got {mask_spec!r}"
    )

# Draw the heatmap into the single grid cell. seaborn's own colorbar is
# switched off because a separate colorbar axes is attached afterwards.
heatmap_options = options['heatmap']
ax = fig.add_subplot(gs[0, 0])
sns.heatmap(
    df,
    ax=ax,
    cbar=False,
    mask=mask,
    cmap=heatmap_options['cmap'],
    annot=heatmap_options['annot'],
    fmt=heatmap_options['fmt'],
    square=heatmap_options['square'],
    linecolor=heatmap_options['linecolor'],
    linewidths=heatmap_options['linewidths'],
    **heatmap_options['other_kwargs'],
)

# Apply the per-axis tick length / width settings.
for axis_name, tick_key in (('x', 'xtick'), ('y', 'ytick')):
    tick_options = heatmap_options[tick_key]
    ax.tick_params(
        axis=axis_name,
        length=tick_options['length'],
        width=tick_options['width'],
    )

# Carve a colorbar axes off the right-hand side of the heatmap axes.
cbar_options = options['cbar']
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size=cbar_options['size'], pad=cbar_options['pad'])

# The first collection on the heatmap axes serves as the colorbar's mappable.
heatmap_mesh = ax.collections[0]
cbar = fig.colorbar(heatmap_mesh, cax=cax, orientation='vertical')

# Zero-width outline (the colour is irrelevant at this width), then apply
# the configured colour limits to the mappable.
cbar.outline.set_color('red')
cbar.outline.set_linewidth(0)
heatmap_mesh.set_clim(*cbar_options['clim'])

# Lay out the figure, then recompute the colorbar rectangle by hand: its
# vertical extent is taken from the heatmap axes (trimmed by the line width)
# and its horizontal extent from the current colorbar axes (shifted/widened
# by a line-width dependent amount).
#
# dm.simple_layout(
#     fig, use_all_axes=True, verbose=True,
#     importance_weights=(0, 0, 1, 1),
# )
fig.tight_layout()

# Get axes bounding box.
# `.bounds` is (x0, y0, width, height) in display (pixel) coordinates.
ax_bounds = ax.get_window_extent().bounds

# Get colorbar bounding box.
cax_bounds = cax.get_window_extent().bounds


# Convert lw pt to pixel.
# 1 pt = 1/72 inch, so pt * dpi / 72 gives pixels.
lw_pixel = lw * fig.dpi / 72

# NOTE(review): ax_bounds[2] is the heatmap axes *width*, compared here with
# the colorbar's left edge cax_bounds[0]. If the intent was "the axes end to
# the left of the colorbar", the right edge would be
# ax_bounds[0] + ax_bounds[2] -- confirm.
if ax_bounds[2] < cax_bounds[0]:
    # Shift the colorbar right by two line widths and widen it by one.
    cax_offset = 2 * lw_pixel
    cax_last = cax_offset / 2
else:
    # Shift the colorbar left by lw_pixel / 4 and widen it by lw_pixel / 16.
    cax_offset = -lw_pixel / 4
    cax_last = -cax_offset / 4

# New colorbar rectangle in pixels: (x0, y0, width, height).
cax_bounds = (
    cax_bounds[0] + cax_offset,
    # cax_bounds[0],
    # Bottom raised by half a line width, height reduced by a full one.
    ax_bounds[1] + lw_pixel / 2,
    cax_bounds[2] + cax_last,
    ax_bounds[3] - lw_pixel,
)

# Convert to figure coordinates.
# Pixels divided by (figure size in inches * dpi) give figure fractions.
cax_bounds = (
    cax_bounds[0] / fig.get_figwidth() / fig.dpi,
    cax_bounds[1] / fig.get_figheight() / fig.dpi,
    cax_bounds[2] / fig.get_figwidth() / fig.dpi,
    cax_bounds[3] / fig.get_figheight() / fig.dpi,
)

# Create new divider for colorbar.
# A 1 x 1 divider over exactly that rectangle; its single cell is used as
# the colorbar axes' locator below.
new_divider = Divider(
    fig, cax_bounds,
    horizontal=[Size.Scaled(1.0)],
    vertical=[Size.Scaled(1.0)],
)

cax.set_axes_locator(new_divider.new_locator(nx=0, ny=0))
# Outline styling with the same values as applied when the colorbar was
# created: zero width, so the outline colour has no visible effect.
cbar.outline.set_color('red')
cbar.outline.set_linewidth(0)

# Frame the heatmap: show all four spines with the cell-border width/colour.
for side in ('top', 'bottom', 'left', 'right'):
    spine = ax.spines[side]
    spine.set_linewidth(lw)
    spine.set_color(options['heatmap']['linecolor'])
    spine.set_visible(True)

# Colorbar ticks: optional fixed positions, tick mark styling, then the
# optional labels. The labels are set once, as the last tick operation.
cbar_tick = options['cbar']['tick']
if cbar_tick['positions'] is not None:
    cax.yaxis.set_ticks(cbar_tick['positions'])

cbar.ax.yaxis.set_tick_params(
    length=cbar_tick['length'],
    width=cbar_tick['width'],
    color=options['heatmap']['linecolor'],
)

if cbar_tick['labels'] is not None:
    cax.yaxis.set_ticklabels(cbar_tick['labels'])

# Display and export through the dartwork_mpl helpers (their behaviour is
# defined in dm, not in this notebook). The figure's own facecolor is
# passed along explicitly.
# NOTE(review): the earlier plotting cell exports under the same base name
# 'test_heatmap', so the two exports presumably overwrite each other -- confirm.
dm.save_and_show(fig, facecolor=fig.get_facecolor(), size=700)
dm.save_formats(fig, 'test_heatmap', formats=['png', 'svg'], dpi=1000, facecolor=fig.get_facecolor())