In [1]:
import os, copy, argparse
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# 현재 스크립트 파일 위치

import sys

from utils.affinity import BasisSlotAffinityGAT
from models.TabularFLM_S import Model
from dataset.data_dataloaders import prepare_embedding_dataloaders
from utils.util import fix_seed


  from .autonotebook import tqdm as notebook_tqdm


In [9]:
class MVisualizer:
    """Restore a trained ``Model`` from a checkpoint and plot its ``U_param`` tensor.

    The constructor rebuilds the model from the ``args`` stored in the
    checkpoint, loads the weights non-strictly, and keeps a CPU copy of the
    first state-dict entry whose key contains ``"U_param"``.  The plotting
    methods apply softplus to that tensor and draw histograms / a heatmap.

    Every plotting method takes an optional ``save_path``:

    * when given, figures are written to disk and nothing is displayed;
    * when omitted, figures are shown with ``plt.show()`` so that an
      interactive call actually renders something, then closed.
    """

    def __init__(self, ckpt_path: str, device='cuda', auto_del_feat=None):
        """Load the checkpoint at ``ckpt_path`` and rebuild the model.

        Args:
            ckpt_path: Path to a checkpoint holding ``'args'`` and
                ``'model_state_dict'`` entries.
            device: Requested device; CPU is used whenever CUDA is unavailable.
            auto_del_feat: If not ``None``, overrides ``args.del_feat`` before
                the model is constructed.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        # NOTE(security): torch.load unpickles the file (the stored ``args``
        # object is arbitrary Python), so only open checkpoints you trust.
        ckpt = torch.load(ckpt_path, map_location=self.device)
        self.args = ckpt['args']

        if auto_del_feat is not None:
            self.args.del_feat = auto_del_feat
            print(f"[INFO] Applied del_feat from filename: {auto_del_feat}")

        # Back-fill attributes that some checkpoints' args do not carry.
        if not hasattr(self.args, 'n_heads'):
            self.args.n_heads = 8
        if not hasattr(self.args, 'k_basis'):
            self.args.k_basis = 8
        if not hasattr(self.args, 'slot_kernel_rank'):
            self.args.slot_kernel_rank = self.args.k_basis

        # Build the model from the checkpoint's own hyper-parameters.
        self.model = Model(
            self.args,
            self.args.input_dim,
            self.args.hidden_dim,
            self.args.output_dim,
            self.args.dropout_rate,
            self.args.llm_model,
            "viz",
            "viz"
        ).to(self.device)

        # Load weights, skipping every key that contains 'alpha_ema'.
        sd = ckpt['model_state_dict']
        sd = {k: v for k, v in sd.items() if 'alpha_ema' not in k}
        missing, unexpected = self.model.load_state_dict(sd, strict=False)
        if missing:
            print("[INFO] Missing keys:", missing)
        if unexpected:
            print("[INFO] Unexpected keys:", unexpected)

        self.model.eval()
        self.num_layers = int(getattr(self.args, 'num_basis_layers', 3))
        self.num_heads  = int(self.args.n_heads)
        self.num_slots  = int(self.args.k_basis)

        # Keep U_param straight from the checkpoint (first matching key wins).
        self.U_param = None
        U_keys = [k for k in sd.keys() if "U_param" in k]
        if U_keys:
            self.U_param = sd[U_keys[0]].detach().cpu()
            print(f"[INFO] Found U_param in state_dict: {U_keys[0]}, shape={self.U_param.shape}")
        else:
            print("[WARN] No U_param found in checkpoint.")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _suffixed_path(save_path, suffix):
        """Return ``save_path`` with ``suffix`` inserted before its extension.

        Only the file name is touched (directory components are left alone),
        and a path without an extension gets ``.png``.  This guarantees that
        different suffixes always map to different files.
        """
        path = Path(save_path)
        ext = path.suffix or ".png"
        return str(path.with_name(f"{path.stem}{suffix}{ext}"))

    @staticmethod
    def _save_or_show(fig, out_path=None):
        """Save ``fig`` to ``out_path`` if given, otherwise display it; then close it."""
        import matplotlib.pyplot as plt

        if out_path:
            fig.savefig(out_path, dpi=250, bbox_inches="tight")
        else:
            # Without this the figure would be closed before it is ever rendered.
            plt.show()
        plt.close(fig)

    def _softplus_U(self):
        """Return softplus(U_param) as a 3-D tensor, or ``None`` (with a warning) if absent."""
        if self.U_param is None:
            print("[WARN] U_param not available.")
            return None
        U = torch.nn.functional.softplus(self.U_param)
        if U.dim() != 3:
            raise ValueError(
                f"U_param must be 3-dimensional [M, K, R], got shape {tuple(U.shape)}"
            )
        return U

    @staticmethod
    def _slot_histograms(values, bins, alpha=None, title_fontsize=None, suptitle=None):
        """Draw one histogram per leading-axis slice of ``values`` and return the figure.

        ``values`` is a NumPy array of shape [M, K, R]; one subplot is created
        per slot ``m`` on a single row with a shared y-axis.
        """
        import matplotlib.pyplot as plt

        n_slots = values.shape[0]
        # squeeze=False keeps ``axes`` 2-D even when there is a single slot.
        fig, axes = plt.subplots(1, n_slots, figsize=(3 * n_slots, 3),
                                 sharey=True, squeeze=False)
        for m, ax in enumerate(axes[0]):
            ax.hist(values[m].ravel(), bins=bins, color="steelblue", alpha=alpha)
            ax.set_title(f"Slot {m}", fontsize=title_fontsize)
            ax.set_xlabel("U values")
        axes[0][0].set_ylabel("Frequency")

        if suptitle is not None:
            fig.suptitle(suptitle, fontsize=14)
            # Leave head-room for the suptitle.
            fig.tight_layout(rect=[0, 0, 1, 0.95])
        else:
            fig.tight_layout()
        return fig

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------
    def plot_U_heatmap(self, save_path=None):
        """Visualize softplus(U_param) in three figures.

        1. Histogram of all values.
        2. Heatmap with one row per slot and the K*R values flattened on x.
        3. One histogram per slot.

        Args:
            save_path: Base output path.  The figures are written next to it
                with ``_global_hist``, ``_heatmap`` and ``_slotwise`` inserted
                before the extension.  If falsy, the figures are shown instead.
        """
        U = self._softplus_U()
        if U is None:
            return

        import matplotlib.pyplot as plt
        import seaborn as sns

        M, K, R = U.shape
        values = U.numpy()  # [M, K, R]

        def target(suffix):
            return self._suffixed_path(save_path, suffix) if save_path else None

        # (1) global distribution
        fig, ax = plt.subplots(figsize=(5, 3))
        ax.hist(values.ravel(), bins=50, color="steelblue")
        ax.set_title("Distribution of U values (all slots)")
        self._save_or_show(fig, target("_global_hist"))

        # (2) slot x (K*R) heatmap, one row per slot
        fig, ax = plt.subplots(figsize=(max(8, K * R / 2), max(4, M)))
        sns.heatmap(values.reshape(M, K * R), cmap="viridis", cbar=True, ax=ax)
        ax.set_title("U values per Slot (flattened K×R)")
        ax.set_xlabel("K×R dimension")
        ax.set_ylabel("Slot m")
        self._save_or_show(fig, target("_heatmap"))

        # (3) per-slot histograms
        fig = self._slot_histograms(values, bins=30)
        self._save_or_show(fig, target("_slotwise"))

    def plot_U_per_slot(self, save_path=None):
        """Plot the distribution of softplus(U_param) for each slot as a histogram.

        Args:
            save_path: Output file for the figure.  If falsy, the figure is
                shown instead of being saved.
        """
        U = self._softplus_U()
        if U is None:
            return

        fig = self._slot_histograms(
            U.numpy(),
            bins=20,
            alpha=0.8,
            title_fontsize=10,
            suptitle="U_param distribution per Slot",
        )
        self._save_or_show(fig, save_path if save_path else None)



In [10]:
# Build the visualizer from the pre-training checkpoint.  The path literal is
# split over several adjacent strings purely for readability.
viz = MVisualizer(
    ckpt_path=(
        "/storage/personal/eungyeop/experiments/checkpoints/gpt2_mean/"
        "heart_target_1+heart_target_2+heart_target_3+heart_target_4/Pre/"
        "TabularFLM_attn-gat_v1_num_basis_2_num_shared_layers_2_num_basis_layers_2"
        "_scorer_slot_no_self_loop_False/42/best_20250923_012453.pt"
    )
)

# Quick check: run once without a save path, so nothing is written to disk.
viz.plot_U_heatmap(save_path=None)

# Write the figures to files.
viz.plot_U_heatmap(save_path="Uparam_heatmap.png")
viz.plot_U_per_slot(save_path="U_param.png")

[INFO] Found U_param in state_dict: basis_affinity.U_param, shape=torch.Size([8, 8, 8])


In [17]:
# Toy sizes used by the scratch cells below.
N, M, K = 13, 6, 8

# Random [N, M] matrix, softmax-normalized so that every row sums to 1.
S = torch.randn(N, M).softmax(dim=-1)

# Random stack of M matrices, each K x K.
G = torch.randn(M, K, K)

In [22]:
# Append a trailing singleton axis to S (a view, no copy) and display the shape.
S1 = S[..., None]
S1.size()

torch.Size([13, 6, 1])

In [24]:
# Reorder the axes of S1 as (1, 2, 0): a [13, 6, 1] tensor becomes [6, 1, 13].
S2 = S1.permute(1,2,0)

In [34]:
import torch

torch.manual_seed(0)

N, M, K = 13, 6, 8

S = torch.randn(N, M).softmax(dim=-1)          # [N, M], rows sum to 1
G = torch.randn(M, K, K)                       # [M, K, K]

# Step 1: put the slot axis first -> [M, N].
S_mn = S.t()

# Step 2: for every slot m, broadcast the column S[:, m] along a new K axis,
# giving an N x K matrix whose K columns are all identical -> [M, N, K].
S_mnk = S_mn[:, :, None].expand(-1, -1, K)

# Step 3: batched product A_slot[m] = S^(m) @ G^(m) @ S^(m).T -> [M, N, N].
left = S_mnk @ G                               # [M, N, K]
A_slot = left @ S_mnk.transpose(-2, -1)        # [M, N, N]

print(A_slot.shape)  # torch.Size([6, 13, 13])


torch.Size([6, 13, 13])


In [35]:
import torch

torch.manual_seed(0)

N, M, K = 13, 6, 8

S = torch.randn(N, M).softmax(dim=-1)          # [N, M], rows sum to 1
G = torch.randn(M, K, K)                       # [M, K, K]

# Slot-first layout of S -> [M, N].
S_mn = S.t()

# Broadcast each slot column S[:, m] along a new K axis (all K columns equal)
# -> [M, N, K].
S_mnk = S_mn[:, :, None].expand(-1, -1, K)

# Route A: batched matmul, A[m] = S^(m) @ G^(m) @ S^(m).T -> [M, N, N].
left = S_mnk @ G                                         # [M, N, K]
A_slot_matmul = left @ S_mnk.transpose(-2, -1)           # [M, N, N]

# Route B: the same contraction written as one einsum over S, G, S.
A_slot_einsum = torch.einsum("nm,mkl,jm->mnj", S, G, S)

# Compare the two routes.
all_equal = torch.allclose(A_slot_matmul, A_slot_einsum, atol=1e-6)
max_abs_diff = torch.max(torch.abs(A_slot_matmul - A_slot_einsum)).item()

all_equal, max_abs_diff


(True, 4.76837158203125e-07)

In [36]:
import torch

B, H, M, N, K = 1, 2, 3, 4, 5
Dx = torch.randn(B, H, N, N)
Dy = torch.randn(B, M, K, K)
Pi = torch.rand(B, H, M, N, K)

# Vectorized contraction under test.
cross_e = torch.einsum("bhij,bmkl,bhmjl->bhmik", Dx, Dy, Pi)

# Slow reference: explicit nested loops over every output index, with the
# contracted indices (j, l) accumulated term by term.
cross_lo = torch.zeros(B, H, M, N, K)
for b in range(B):
    for h in range(H):
        for m in range(M):
            for i in range(N):
                for k in range(K):
                    cross_lo[b, h, m, i, k] = sum(
                        Dx[b, h, i, j] * Dy[b, m, k, l] * Pi[b, h, m, j, l]
                        for j in range(N)
                        for l in range(K)
                    )

print(torch.allclose(cross_e, cross_lo, atol=1e-5))  # expected: True


True


: 