In [1]:
import pickle
import re

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics
import torch
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

from model.model import Model

  from .autonotebook import tqdm as notebook_tqdm
2025-11-17 10:15:59.703533: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-17 10:15:59.819600: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-17 10:16:00.350423: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2025-11-17 10:16:00.350483: W tensorflow/compiler/xla/s

In [2]:
# Restore the fitted tokenizer object from disk.
# NOTE: unpickling can execute arbitrary code -- only load this file if it
# comes from a trusted source.
with open('./utils/tokenizer.pickle', 'rb') as tokenizer_file:
    tokenizer = pickle.load(tokenizer_file)

In [3]:
# Location of the checkpoint whose 'state_dict' entry holds the weights.
model_path = './utils/model.ckpt'

# Settings passed verbatim to Model; the meaning of each key is defined by
# the project's Model class (not visible in this notebook).
config = dict(
    ah=2,
    dr=0.1,
    beta=0.59,
    output_dims=[7, 72, 268, 4255],
)

model = Model(config)

# Map every tensor to the CPU while loading, then copy the weights in.
model.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu'))['state_dict']
)
# Switch to evaluation mode; the trailing ';' hides the module repr.
model.eval();

In [4]:
# Example input string used for the manual encoding / forward-pass check.
sequence = 'AAAAA'

In [5]:
# Encode the string, prepend token id 22, right-pad with zeros up to 1024
# entries and wrap the result as a (1, 1024) int32 tensor.
# NOTE(review): `sequence` changes from str to tensor here, so this cell
# cannot be re-run without first re-running the cell that sets the string.
encoded = tokenizer.texts_to_sequences([sequence])[0]
padded = [22, *encoded]
padded.extend([0] * (1024 - len(padded)))
sequence = torch.tensor([padded], dtype=torch.int32)

In [6]:
# Forward pass in 'infer' mode on the padded example; the returned tensor is
# displayed as the cell output.
model(sequence, mode='infer')

tensor([[ -1.5120,  -8.0787, -13.7206,  ..., -27.9469, -22.9021, -27.6669]],
       grad_fn=<AddBackward0>)

In [7]:
def avg_heads(cam, grad):
    """Combine attention maps with their gradients into one averaged map.

    Each input is flattened over its leading dimensions to
    ``(-1, rows, cols)``. The two are multiplied element-wise, negative
    products are clipped to zero, and the result is averaged over the
    flattened leading dimension.

    Args:
        cam: tensor whose last two dimensions form the attention map.
        grad: tensor of gradients, reshaped the same way as ``cam``.

    Returns:
        A tensor with the leading dimension averaged away.
    """
    flat_cam = cam.reshape(-1, *cam.shape[-2:])
    flat_grad = grad.reshape(-1, *grad.shape[-2:])
    weighted = flat_grad * flat_cam
    # Keep only positive contributions before averaging.
    positive_only = weighted.clamp(min=0)
    return positive_only.mean(dim=0)

def apply_self_attention_rules(R_ss, cam_ss):
    """Return the relevance update ``cam_ss @ R_ss``.

    Args:
        R_ss: current relevance matrix (tensor).
        cam_ss: attention map to propagate through (tensor).

    Returns:
        The matrix product of ``cam_ss`` and ``R_ss``; the caller adds it to
        the running relevance.
    """
    return cam_ss @ R_ss

def generate_relevance(sequence, model, level_sizes=(7, 72, 268, 4255), labels=None, target_ec=None, force_chain=True, max_len=1024, bos_id=22, thresh=0.4):
    """Compute a per-position relevance vector via gradient-weighted attention rollout.

    The sequence is encoded with the notebook-level ``tokenizer``, prefixed
    with ``bos_id`` and right-padded with zeros to ``max_len``. For each
    chosen output index the model is run in ``'inter'`` mode and the selected
    output is back-propagated. The attention maps of ``enc_1`` .. ``enc_4``
    are then combined with their gradients and accumulated into a relevance
    matrix. Row 0 of that matrix (the position holding ``bos_id``), with
    column 0 dropped, is the relevance vector for one target.

    Target selection:
      * ``target_ec`` given, ``force_chain`` true and ``labels`` per-level:
        every dotted prefix of ``target_ec`` ("a", "a.b", ...) is looked up in
        the matching level and the per-target vectors are averaged.
      * ``target_ec`` given otherwise: the single matching index is used.
      * ``target_ec`` is None: the arg-max of every level of the ``'infer'``
        output is used and the vectors are averaged. Last-level entries whose
        sigmoid exceeds ``thresh`` are also printed.

    Args:
        sequence: input string to explain.
        model: network exposing ``model(seq, mode=...)`` and
            ``model.model.enc_1`` .. ``enc_4`` with
            ``inter_attention.get_attn()`` / ``get_attn_gradients()``.
        level_sizes: number of output units per hierarchy level, in order.
        labels: optional label names, either one list per level or a single
            flat list indexed by global output index.
        target_ec: optional dotted label to explain instead of the prediction
            (presumably an EC number -- confirm against the label files).
        force_chain: explain every prefix level of ``target_ec`` when
            ``labels`` is per-level.
        max_len: padded input length, BOS token included.
        bos_id: token id placed in front of the encoded sequence.
        thresh: sigmoid threshold for the printed last-level candidate list.

    Returns:
        A 1-D tensor of relevance scores (mean over the explained targets).
        It may still be attached to the autograd graph; callers in this
        notebook call ``.detach()`` on it.

    Raises:
        ValueError: if the encoded sequence does not fit in ``max_len``, if
            ``level_sizes`` is empty, if ``target_ec`` is given without
            usable ``labels``, or if ``target_ec`` cannot be found.
    """
    level_sizes = list(level_sizes)
    n_levels = len(level_sizes)
    if n_levels == 0:
        raise ValueError("level_sizes must contain at least one level")

    token_ids = [bos_id] + list(tokenizer.texts_to_sequences([sequence])[0])
    if len(token_ids) > max_len:
        # Without this check the input would be left unpadded and longer than
        # the max_len x max_len identity used in the rollout.
        raise ValueError(
            f"sequence needs {len(token_ids)} tokens (BOS included) but max_len is {max_len}"
        )
    token_ids += [0] * (max_len - len(token_ids))

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    seq = torch.tensor([token_ids], dtype=torch.int32, device=device)

    # starts[i] is the global output index of the first unit of level i.
    starts = [0]
    for size in level_sizes[:-1]:
        starts.append(starts[-1] + size)

    def rollout(global_idx):
        """Relevance of every non-BOS position for one global output index."""
        model.eval()
        out_inter = model(seq, mode='inter')
        one_hot = torch.zeros((1, out_inter.size(-1)), dtype=torch.float32, device=out_inter.device)
        # The backward target is 100 x the selected output, as in the original.
        one_hot[0, global_idx] = 100.0
        t = torch.sum(one_hot * out_inter)
        model.zero_grad()
        t.backward()
        R = torch.eye(max_len, max_len, device=out_inter.device)
        for blk in [model.model.enc_1, model.model.enc_2, model.model.enc_3, model.model.enc_4]:
            cam = blk.inter_attention.get_attn()
            grad = blk.inter_attention.get_attn_gradients()
            cam = avg_heads(cam, grad)
            R = R + apply_self_attention_rules(R, cam)
        # Row 0 belongs to the BOS position; column 0 (BOS itself) is dropped.
        return R[0, 1:]

    def label_for(i_level, local_idx, global_idx):
        """Best-effort display name for an output unit; falls back to 'idx<n>'."""
        try:
            if isinstance(labels[0], (list, tuple)):
                return labels[i_level][local_idx]
            return labels[global_idx]
        except (LookupError, TypeError):
            return f"idx{local_idx}"

    if target_ec is not None:
        if labels is None or len(labels) == 0:
            raise ValueError("labels (per-level or flat) required to locate target_ec")
        per_level = isinstance(labels[0], (list, tuple))

        if force_chain and per_level:
            parts = str(target_ec).split(".")
            depth = min(len(parts), n_levels)
            chain = [".".join(parts[:k]) for k in range(1, depth + 1)]
            targets = []
            for i, ec_str in enumerate(chain):
                # A level without a label list is treated as "not found".
                level_labels = labels[i] if i < len(labels) else ()
                try:
                    li = level_labels.index(ec_str)
                except ValueError:
                    print(f"Forced L{i+1}: {ec_str} not found, skipping")
                    continue
                gi = starts[i] + li
                targets.append(gi)
                print(f"Forced L{i+1}: {ec_str} | gidx={gi}")
            if not targets:
                raise ValueError("No levels found for target_ec in labels")
            return torch.stack([rollout(gi) for gi in targets], dim=0).mean(dim=0)

        if per_level:
            gi = None
            for start, level_labels in zip(starts, labels):
                try:
                    gi = start + level_labels.index(target_ec)
                    break
                except ValueError:
                    continue
            if gi is None:
                raise ValueError("target_ec not found in labels")
        else:
            try:
                gi = int(labels.index(target_ec))
            except ValueError:
                raise ValueError("target_ec not found in labels") from None
        print(f"Forced EC: {target_ec} | gidx={gi}")
        return rollout(gi)

    model.eval()
    out_infer = model(seq, mode='infer')
    Rs = []
    for i in range(n_levels):
        s, e = starts[i], starts[i] + level_sizes[i]
        logits = out_infer[0, s:e]
        li = int(torch.argmax(logits).item())
        gi = s + li
        # The score printed here is the raw model output; the candidate list
        # further down prints sigmoid values instead.
        sc = float(logits[li].item())
        if labels is not None:
            print(f"L{i+1}: {label_for(i, li, gi)} | score={sc:.6f} | gidx={gi}")
        else:
            print(f"L{i+1}: idx={li} | score={sc:.6f} | gidx={gi}")
        Rs.append(rollout(gi))

    # List every last-level unit whose sigmoid exceeds the threshold.
    last = n_levels - 1
    s_last = starts[last]
    probs_last = torch.sigmoid(out_infer[0, s_last:s_last + level_sizes[last]])
    idxs = (probs_last > thresh).nonzero(as_tuple=False).flatten().tolist()
    if idxs:
        cands = sorted(((li, float(probs_last[li].item())) for li in idxs), key=lambda x: -x[1])
        print(f"--- L{n_levels} candidates with score > {thresh} ---")
        for li, sc in cands:
            gi = s_last + li
            print(f"L{n_levels} candidate: {label_for(last, li, gi)} | score={sc:.6f} | gidx={gi}")

    return torch.stack(Rs, dim=0).mean(dim=0)

In [8]:
# Input string to explain; this replaces the earlier value of `sequence`.
sequence = 'AAA'

In [11]:
# Relevance vector for the model's own predictions (no labels / target given);
# detached so it can be handed to NumPy in the following cell.
exp = generate_relevance(sequence, model).detach()

L1: idx=0 | score=-4.133714 | gidx=0
L2: idx=18 | score=-3.383185 | gidx=25
L3: idx=173 | score=-2.259731 | gidx=252
L4: idx=4014 | score=-6.605191 | gidx=4361


In [12]:
# Smooth the relevance vector with a length-6 moving average, then rescale it
# linearly to the [0, 1] range.
# NOTE(review): `exp` is overwritten from itself, so re-running this cell
# smooths the already-smoothed values again.
# NOTE(review): the vector still covers every position returned by
# generate_relevance, not only the residues of `sequence` -- confirm that
# normalising over the padded tail is intended.
kernel_size = 6
kernel = np.full(kernel_size, 1.0 / kernel_size)
smoothed = np.convolve(exp, kernel, mode='same')

shifted = smoothed - smoothed.min()
exp = shifted / shifted.max()

In [15]:
# Display the smoothed, normalised relevance array.
exp

array([0.97563864, 0.9851105 , 0.99296846, ..., 0.00512913, 0.0038235 ,
       0.00244637])

In [33]:
def plotting_alpha(value, text, color):
    """Render one tile with ``plot`` and draw ``text`` centred on top of it.

    The largest DejaVuSans size between 30 and 6 whose rendered text fits in
    the tile is used. When the TrueType font cannot be loaded, Pillow's
    built-in default font is used instead.

    Args:
        value: value forwarded to ``plot``.
        text: object whose ``str()`` is drawn on the tile.
        color: colour forwarded to ``plot``.

    Returns:
        The annotated tile as a ``uint8`` NumPy array.
    """
    # NOTE(review): `plot` is not defined in any cell visible in this
    # notebook; it must exist in the kernel before this function is called.
    tile = np.asarray(plot(value, color), dtype=np.uint8)
    pil_img = Image.fromarray(tile)
    draw = ImageDraw.Draw(pil_img)
    label = str(text)

    margin = 0
    for size in range(30, 5, -1):
        try:
            font = ImageFont.truetype("DejaVuSans.ttf", size=size)
            scalable = True
        except (OSError, ImportError):
            # Font file missing/unreadable or FreeType support unavailable.
            font = ImageFont.load_default()
            scalable = False

        bbox = draw.textbbox((0, 0), label, font=font)
        w = bbox[2] - bbox[0]
        try:
            ascent, descent = font.getmetrics()
        except AttributeError:
            # Bitmap fallback fonts may not expose metrics; use the measured
            # text height instead.
            ascent, descent = bbox[3] - bbox[1], 0
        text_h = ascent + descent

        fits = w <= pil_img.width - margin and text_h <= pil_img.height - margin
        # The fallback font ignores `size`, so trying smaller sizes is pointless.
        if fits or not scalable:
            break

    x = (pil_img.width - w) // 2
    # Nudge the text down by a tenth of the ascent, as in the original layout.
    y = (pil_img.height - text_h) // 2 + ascent // 10

    draw.text((x, y), label, fill=(10, 10, 10), font=font)
    return np.array(pil_img, dtype=np.uint8)

In [34]:
# Render one coloured, labelled tile per residue and save the strip as a PNG.
# matplotlib (plt, mcolors) and PIL (Image) come from the notebook's import
# cell at the top.

# Side length, in pixels, of the square tile drawn for each residue.
TILE = 30

# Convert once instead of re-copying the array for every residue.
relevance = np.asarray(exp)
r_img = np.zeros((TILE, TILE * len(sequence), 3), dtype=np.uint8)

cmap = plt.get_cmap('Reds')
# NOTE(review): the colour scale spans the whole relevance array, including
# positions beyond len(sequence) -- confirm this is intended.
norm = mcolors.Normalize(vmin=relevance.min(), vmax=relevance.max())
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

for pos, residue in enumerate(sequence):
    r, g, b, _alpha = sm.to_rgba(relevance[pos])
    color = np.array([r, g, b]) * 255
    r_img[:, pos * TILE:(pos + 1) * TILE] = plotting_alpha(relevance[pos], residue, color)

Image.fromarray(r_img).save("interpretation.png")