# Heatmap in Python using pandas and seaborn

Install requirements:

In [None]:
%pip install ipympl pandas seaborn matplotlib

Generate some data: an ID for each row, a month picked at random from a short date range, and a randomly chosen Chinese zodiac sign, because I'm still a Sinologist at heart. This boilerplate code was generated with GenAI.

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import random

# Column 1: Sequential enumeration of ints
ids = list(range(1, 101))

# Column 2: Date range with repeated entries in "MM-YYYY" format
dates = ["01-2023", "02-2023", "03-2023", "04-2023", "05-2023", "06-2023", "07-2023"]
date_range = [random.choice(dates) for _ in range(100)]

# Column 3: Zodiac sign picked at random from a fixed list
animals = ["snake", "dragon", "unknown", "tiger", "monkey", "rabbit"]
animal_values = [random.choice(animals) for _ in range(100)]

# Create DataFrame
df = pd.DataFrame({
    'ID': ids,
    'Date': date_range,
    'Zodiac': animal_values
})
df
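Because the cell above uses `random.choice` without a seed, the table changes on every run. If reproducible numbers are wanted, one option is a seeded NumPy generator. The following is only a sketch: the name `df_seeded` and the seed value are assumptions chosen for illustration, not part of the original notebook.

```python
import numpy as np
import pandas as pd

# Seeded generator: the same seed yields the same table on every run (the value 42 is arbitrary)
rng = np.random.default_rng(seed=42)

dates = ["01-2023", "02-2023", "03-2023", "04-2023", "05-2023", "06-2023", "07-2023"]
animals = ["snake", "dragon", "unknown", "tiger", "monkey", "rabbit"]

# Hypothetical variable name, kept separate so the original df is not overwritten
df_seeded = pd.DataFrame({
    "ID": np.arange(1, 101),
    "Date": rng.choice(dates, size=100),
    "Zodiac": rng.choice(animals, size=100),
})
df_seeded.head()
```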

Create a pivot table that counts how often each zodiac sign occurs in each month.

In [None]:
animal_per_month = df.pivot_table(index="Date", columns="Zodiac", aggfunc="size", fill_value=0)
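# Column order for the heatmap; as written this keeps the default (alphabetical) order.
# Edit the list to show the signs in a different order.
groups = animal_per_month.columns.tolist()
animal_per_month = animal_per_month.reindex(columns=groups)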
animal_per_month
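For pure counting, `pd.crosstab` should give the same table as `pivot_table` with `aggfunc="size"`. The sketch below compares the two on a tiny hand-made sample (the values are made up for illustration) so the counts are easy to check by eye.

```python
import pandas as pd

# Tiny hypothetical sample: five rows, two months, three signs
sample = pd.DataFrame({
    "ID": [1, 2, 3, 4, 5],
    "Date": ["01-2023", "01-2023", "02-2023", "02-2023", "02-2023"],
    "Zodiac": ["snake", "tiger", "snake", "snake", "dragon"],
})

# Same call as in the notebook cell above
via_pivot = sample.pivot_table(index="Date", columns="Zodiac", aggfunc="size", fill_value=0)

# crosstab counts co-occurrences directly and fills missing combinations with 0
via_crosstab = pd.crosstab(sample["Date"], sample["Zodiac"])

via_crosstab
```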

Create grand totals for rows and columns:

In [None]:
grand_total = animal_per_month.copy()
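grand_total.loc["Grand Total"] = animal_per_month.sum(axis=0)
# Sum the rows of grand_total (not animal_per_month) so that the bottom-right
# corner holds the overall total instead of NaN
grand_total["Grand Total"] = grand_total.sum(axis=1)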
grand_total
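As an alternative sketch, `pd.crosstab` can append the totals itself through its `margins` and `margins_name` parameters. The sample data below is hypothetical and only meant to make the totals easy to verify.

```python
import pandas as pd

# Tiny hypothetical sample: five rows in total
sample = pd.DataFrame({
    "Date": ["01-2023", "01-2023", "02-2023", "02-2023", "02-2023"],
    "Zodiac": ["snake", "tiger", "snake", "snake", "dragon"],
})

# margins=True adds a totals row and a totals column, labeled via margins_name
with_totals = pd.crosstab(
    sample["Date"],
    sample["Zodiac"],
    margins=True,
    margins_name="Grand Total",
)
with_totals
```

The bottom-right cell of `with_totals` is the overall row count of the sample.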

Now comes the tricky part: create a mask so that the grand total row and column are left uncolored and do not distort the color scale of the heatmap.

In [None]:
mask = np.zeros(grand_total.shape, dtype=bool)  # False = cell is drawn, True = cell is hidden
mask[-1, :] = True  # mask the Grand Total row
mask[:, -1] = True  # mask the Grand Total column
mask
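The positional mask above relies on the totals being the last row and the last column. A label-based variant, sketched below on a small hypothetical table, selects them by name instead. seaborn also accepts a DataFrame as `mask`, provided its index and columns match the data.

```python
import pandas as pd

# Small hypothetical table with a totals row and a totals column
table = pd.DataFrame(
    {"dragon": [0, 1, 1], "snake": [1, 2, 3], "Grand Total": [1, 3, 4]},
    index=["01-2023", "02-2023", "Grand Total"],
)

# Start with nothing masked, then mask the totals by label rather than by position
label_mask = pd.DataFrame(False, index=table.index, columns=table.columns)
label_mask.loc["Grand Total", :] = True
label_mask.loc[:, "Grand Total"] = True
label_mask
```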

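Generate the heatmap by drawing two heatmaps on the same axes: the first colors only the data cells (the grand totals are masked out), the second is fully transparent and only adds the numeric annotations to every cell, including the totals.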

In [None]:
# solution based on Khalil Al Hooti here: https://stackoverflow.com/questions/53606027/exclude-a-column-from-seaborn-heatmap-formatting-but-keep-in-the-map
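# First pass: colored cells for the data only; the masked totals stay blank
sns.heatmap(grand_total, mask=mask, cmap="RdYlGn_r", cbar=False)
# Second pass: transparent cells (alpha=0) that only add the numbers;
# fmt="g" prints totals of 100 or more as plain numbers instead of scientific notation
sns.heatmap(grand_total, alpha=0, cbar=False, annot=True, fmt="g", cmap="RdYlGn_r", annot_kws={"size": 10, "color": "k"})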
plt.tight_layout()
plt.show()
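A note on the annotation text: seaborn renders each number with a format spec passed as `fmt`, and its default is `".2g"`. The sketch below uses plain Python string formatting (no plotting) to show what that spec does to a three-digit total, such as the overall total of the 100 generated rows, compared with the general `"g"` spec.

```python
# Plain-Python illustration of the two format specs; no plotting required
grand_total_value = 100  # the 100 generated rows add up to this overall total

default_label = format(grand_total_value, ".2g")  # seaborn's default annotation format
general_label = format(grand_total_value, "g")    # general format without a precision of 2

print(default_label)  # 1e+02
print(general_label)  # 100
```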
