In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.colors as mcolors
import torch
import numpy as np
import matplotlib.pyplot as plt

# Tag names of the plant sensors/actuators, one per row and per column of the
# dependency matrices. The order matters: the plotting code uses this list
# positionally as the x/y tick labels of the heatmap.
# NOTE(review): these look like SWaT testbed tags; the grouping below follows
# the stage digit in each tag name (1xx ... 6xx) — confirm against the dataset.
features = [
    # stage 1
    "FIT101", "LIT101", "MV101", "P101", "P102",
    # stage 2
    "AIT201", "AIT202", "AIT203", "FIT201", "MV201",
    "P201", "P202", "P203", "P204", "P205", "P206",
    # stage 3
    "DPIT301", "FIT301", "LIT301",
    "MV301", "MV302", "MV303", "MV304",
    "P301", "P302",
    # stage 4
    "AIT401", "AIT402", "FIT401", "LIT401",
    "P401", "P402", "P403", "P404", "UV401",
    # stage 5
    "AIT501", "AIT502", "AIT503", "AIT504",
    "FIT501", "FIT502", "FIT503", "FIT504",
    "P501", "P502",
    "PIT501", "PIT502", "PIT503",
    # stage 6
    "FIT601", "P601", "P602", "P603",
]

# Load the original and reconstructed dependency matrices. `map_location`
# remaps every stored tensor onto the CPU, regardless of the device it was
# saved from.
# NOTE(review): torch.load unpickles the file — only load trusted artifacts.
data_input = torch.load("dependency_matrix.pt", map_location=torch.device("cpu"))
data_recon = torch.load("dependency_recon.pt", map_location=torch.device("cpu"))

# Index along the first axis of the loaded data selecting the sample to plot.
time_point = 10
input_ = data_input[time_point]
recon_ = data_recon[time_point]

# Element-wise subtraction would silently broadcast mismatched shapes and
# yield a misleading heatmap, so require identical shapes up front.
if input_.shape != recon_.shape:
    raise ValueError(
        "input and reconstruction shapes differ: "
        f"{tuple(input_.shape)} vs {tuple(recon_.shape)}"
    )

# Absolute reconstruction error per entry. detach() drops any autograd
# history carried by the saved tensors; otherwise matplotlib's conversion
# of the result to a NumPy array would raise for tensors requiring grad.
diff = (input_ - recon_).abs().detach()

# Min-max scale the error into [0, 1]. The clamp keeps the denominator
# positive when all entries are equal, in which case the result is all zeros.
diff_min = diff.min()
diff_max = diff.max()
denom = (diff_max - diff_min).clamp(min=1e-12)
X_scaled = (diff - diff_min) / denom  # same shape as diff, values in [0, 1]


# Plot the min-max scaled reconstruction error as a heatmap with one row and
# one column per feature, plus a colorbar on the right.
n_features = len(features)
# The tick labels below assume a square matrix matching `features`; fail
# loudly instead of silently mislabelling the axes.
if tuple(X_scaled.shape) != (n_features, n_features):
    raise ValueError(
        f"expected a ({n_features}, {n_features}) matrix to match the feature "
        f"labels, got {tuple(X_scaled.shape)}"
    )

fig, ax = plt.subplots(figsize=(10, 8), constrained_layout=True)

# White at zero error, fading to light blue at 85% of the range, then a steep
# ramp to dark blue over the top 15% so the largest errors stand out.
new_cmap = mcolors.LinearSegmentedColormap.from_list(
    "custom_blue",
    [(0.0, "white"),
     (0.85, "lightblue"),
     (1.0, "darkblue")])

# Pin the colour scale to [0, 1], the range produced by the min-max scaling,
# rather than relying on autoscaling to the data's min/max.
im1 = ax.imshow(
    X_scaled,
    cmap=new_cmap,
    vmin=0.0,
    vmax=1.0,
    aspect="auto",
)

ticks = np.arange(n_features)
ax.set_yticks(ticks)
ax.set_yticklabels(features, fontsize=10)
ax.set_xticks(ticks)
ax.set_xticklabels(features, rotation=90, fontsize=10)

# Attach a dedicated colorbar axis to the right of the heatmap. size/pad can
# be tuned: "3%" is relative to the heatmap axes, while a plain float pad is
# interpreted by axes_grid1 as a fixed size in inches.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="3%", pad=0.25)
cbar = fig.colorbar(im1, cax=cax)
# The figure was created with constrained_layout=True above, so no extra
# layout call is needed here.