In [None]:
import os, re, math
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Folder holding the per-temperature heatmap PNGs (the original note said 17
# files; that count is not enforced here).
desktop = os.path.join(os.path.expanduser("~"), "Desktop")
heatmap_dir = os.path.join(desktop, "AI_AI_heatmaps")

# Expected filename shape: cosine_5x5_temp_0.38.png (temperature with exactly
# two decimals). fullmatch (not match) so the whole name must fit the pattern;
# names such as "cosine_5x5_temp_0.38.png.bak" are rejected.
pat = re.compile(r"cosine_5x5_temp_(\d+\.\d{2})\.png")
pngs = []
for f in os.listdir(heatmap_dir):
    m = pat.fullmatch(f)
    if m:
        pngs.append((float(m.group(1)), os.path.join(heatmap_dir, f)))

# Fail early with a clear message instead of building an empty, zero-height
# figure further down.
if not pngs:
    raise FileNotFoundError(
        f"No files matching {pat.pattern!r} found in {heatmap_dir}"
    )

# Sort by temperature; the path breaks ties deterministically, since
# os.listdir returns entries in arbitrary order.
pngs.sort()
paths = [p for _, p in pngs]
titles = [f"T = {t:.2f}" for t, _ in pngs]

n = len(paths)
if n == 0:
    # Without this guard nrows would be 0 and a zero-height figure would be built.
    raise ValueError("No heatmap images to plot: `paths` is empty.")

# At most 6 panels per row; use fewer columns when there are fewer images so
# the figure is not padded with blank columns.
ncols = min(6, n)
nrows = math.ceil(n / ncols)

# Size per panel, in inches
panel_w, panel_h = 3.0, 3.0
# squeeze=False guarantees a 2-D array of Axes even for a single row/column.
fig, axes = plt.subplots(
    nrows, ncols, figsize=(ncols * panel_w, nrows * panel_h), squeeze=False
)
flat_axes = axes.ravel()

# Fill cells row-major, in the same order as `paths` / `titles`.
for ax, p, title in zip(flat_axes, paths, titles):
    img = mpimg.imread(p)
    ax.imshow(img)
    ax.set_title(title, fontsize=10)
    ax.axis("off")

# subplots() creates every cell, so hide the unused ones in the last row.
for ax in flat_axes[n:]:
    ax.axis("off")

plt.tight_layout()

# Persist the stitched grid on the Desktop (300 dpi, bounding box trimmed
# to the drawn content), then display it.
grid_filename = "AI_AI_heatmaps_grid.png"
out_path = os.path.join(desktop, grid_filename)
plt.savefig(out_path, bbox_inches="tight", dpi=300)
plt.show()

print(f"Combined figure saved to: {out_path}")