In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def conv_index_to_bins(index) -> pd.Index:
    """Calculate bin boundaries that contain the index values.

    The start and end bin boundaries are linearly extrapolated from
    the two first and last values. The middle bin boundaries are
    midpoints between neighbouring values.

    Example 1: [0, 1] -> [-0.5, 0.5, 1.5]
    Example 2: [0, 1, 4] -> [-0.5, 0.5, 2.5, 5.5]
    Example 3: [4, 1, 0] -> [5.5, 2.5, 0.5, -0.5]

    Args:
        index: A monotonic (increasing or decreasing) ``pd.Index`` with at
            least two values. A ``pd.DatetimeIndex`` yields datetime
            boundaries; any other index is converted to float64 boundaries.

    Returns:
        The bin boundaries, sorted in the same direction as ``index``.

    Raises:
        ValueError: If ``index`` has fewer than two values or is not
            monotonic.
    """
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if len(index) < 2:
        raise ValueError(
            "at least two index values are required to extrapolate the outer bins"
        )
    if not (index.is_monotonic_increasing or index.is_monotonic_decreasing):
        raise ValueError("index must be monotonic increasing or decreasing")

    # the beginning and end values are guessed from first and last two
    start = index[0] - (index[1] - index[0]) / 2
    end = index[-1] + (index[-1] - index[-2]) / 2

    # the middle values are the midpoints of each neighbouring pair
    lower, upper = index[:-1], index[1:]
    middle = lower + (upper - lower) / 2

    if isinstance(index, pd.DatetimeIndex):
        bins = pd.DatetimeIndex(middle)
    else:
        # pd.Float64Index / pd.Int64Index no longer exist in pandas >= 2.0;
        # a float64-typed pd.Index is the equivalent on every version.
        bins = pd.Index(middle, dtype="float64")

    bins = bins.union([start, end])
    return bins.sort_values(ascending=index.is_monotonic_increasing)


def calc_df_mesh(df) -> list:
    """Build the two-dimensional grid of bin boundaries for a dataframe.

    The column labels and the index labels are each converted to bin
    boundaries with ``conv_index_to_bins``; the result of passing them
    to ``np.meshgrid`` (column boundaries first, index boundaries
    second) is returned.
    """
    column_edges = conv_index_to_bins(df.columns)
    row_edges = conv_index_to_bins(df.index)
    return np.meshgrid(column_edges, row_edges)


def heatmap(
    df, ax: "mpl.axes.Axes | None" = None, cmap=None, vmin=None, vmax=None, norm=None
):  # -> mpl.collections.QuadMesh | None
    """Plot a heatmap of the dataframe values using the index and
    columns as the cell coordinates.

    Args:
        df: Dataframe whose values are drawn; its columns give the x bins
            and its index the y bins (both via ``calc_df_mesh``).
        ax: Axes to draw on. When omitted, the plot goes through
            ``plt.pcolormesh`` and a colorbar is added with ``plt.colorbar``.
        cmap, vmin, vmax, norm: Forwarded unchanged to ``pcolormesh``.

    Returns:
        The object returned by ``ax.pcolormesh`` when ``ax`` is given;
        ``None`` otherwise.
    """
    X, Y = calc_df_mesh(df)

    # Compare against None explicitly rather than relying on the
    # truthiness of the axes object.
    if ax is not None:
        return ax.pcolormesh(
            X, Y, df.values, norm=norm, cmap=cmap, vmin=vmin, vmax=vmax
        )

    # ``norm`` is forwarded here as well so both code paths honour the
    # same arguments (it used to be silently ignored without ``ax``).
    mesh = plt.pcolormesh(X, Y, df.values, norm=norm, cmap=cmap, vmin=vmin, vmax=vmax)
    plt.colorbar(mesh)
    return None