# Explaining Medical Coding

Notebook này có sử dụng một số phần trong mã nguồn của công trình Explainable Medical Coding (Joakim Edin et al. - https://github.com/JoakimEdin/explainable-medical-coding)

Notebook này thực hiện:
- Đánh giá khả năng giải thích của mô hình PLM-ICD trên nhiệm vụ gán nhãn ICD tự động
- So sánh phương pháp giải thích ILCA với các phương pháp giải thích khác
- Nhóm tác giả có thay đổi một số mục trong mã nguồn của Edin:
  - Bổ sung phương pháp giải thích Invert Label Cross Attention
  - Thay đổi một số config training
  - Mã nguồn đã qua chỉnh sửa https://github.com/chancholat/explain-icd.git
  - Nếu sử dụng Colab hoặc Kaggle, vui lòng khởi động lại (restart) kernel của notebook sau khi cài đặt lại các package ở cell đầu tiên

In [None]:
# Clone the modified explain-icd repository and install pinned versions of the
# packages this notebook was run with. The --force-reinstall lines can make pip
# report dependency-resolver conflicts (as in the output of this cell); restart
# the kernel after this cell, as the note at the top of the notebook says.
!git clone https://github.com/chancholat/explain-icd.git
!pip install -q python-dotenv==1.0.0
!pip install -q datasets==3.4.1
!pip install -q omegaconf==2.3.0
!pip install -q captum==0.7.0
!pip install -q --force-reinstall transformers==4.38.1
!pip install -q --force-reinstall numpy==2.2.0
# NOTE(review): hydra-core is the only unpinned package here.
!pip install -q hydra-core

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/487.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m481.3/487.4 kB[0m [31m14.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/183.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m17.3 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.6.0+cu124 requires nvidia-cublas-cu12==12.4.5.8; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cublas-cu12 12.5.3.2 which is incompatible.
torch 2.6.0+cu124 requires n

In [None]:
import os

# Run the rest of the notebook from the cloned repository's root: later cells
# use paths relative to it (e.g. explainable_medical_coding/config/...).
REPO_DIR = "/content/explain-icd"
os.chdir(REPO_DIR)

In [None]:
import kagglehub

# Kaggle dataset holding the processed data; fetch its latest version and show
# the local directory it was extracted to.
DATASET_HANDLE = 'chanhainguyen/thesis-data-process'
path = kagglehub.dataset_download(DATASET_HANDLE)
print(path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/chanhainguyen/thesis-data-process?dataset_version_number=1...


100%|██████████| 2.95G/2.95G [00:18<00:00, 168MB/s]

Extracting files...





/root/.cache/kagglehub/datasets/chanhainguyen/thesis-data-process/versions/1


In [None]:
!mv /root/.cache/kagglehub/datasets/chanhainguyen/thesis-data-process/versions/1/processed /content/explain-icd/data
# !mv /kaggle/input/thesis-data-process/processed /content/explain-icd/data

## Evaluate Plausibility

In [None]:
from pathlib import Path

from omegaconf import OmegaConf

# Base explainability experiment configuration shipped with the repository;
# later cells override its data / model / explainer entries.
experiment_config_path = Path("explainable_medical_coding", "config", "explainability.yaml")
exp_config = OmegaConf.load(experiment_config_path)

In [None]:
# Evaluate on the MDACE inpatient ICD-9 dataset, predicting both diagnosis and
# procedure codes.
exp_config.data = OmegaConf.create(
    {
        "dataset_path": "explainable_medical_coding/datasets/mdace_inpatient_icd9.py",
        "target_columns": ["diagnosis_codes", "procedure_codes"],
        "max_length": 6000,
    }
)

In [None]:
# Trained runs available for evaluation. Every run id has the form
# "<model_name>/<run>", so the model name is derived from the selected id
# instead of being edited by hand next to it.
KNOWN_RUN_IDS = (
    "igr/1p0vue7o",
    "igr/kbs093u4",
    "pgd/06mt02mq",
    "pgd/9hpw0up3",
    "pgd/b213y2m6",
    "supervised/4wj6cabu",
    "supervised/v5vsimfr",
    "tm/3hvfq75j",
    "tm/a6ulgei0",
    "unsupervised/0fom6iwn",
    "unsupervised/ov55kelz",
)

# Pick the run to evaluate (normally one of KNOWN_RUN_IDS).
selected_run_id = "supervised/v5vsimfr"
exp_config.model_name = selected_run_id.split("/", 1)[0]
exp_config.run_id = selected_run_id

# Explanation methods to score; "invert_label_att" is presumably the Invert
# Label Cross Attention method added by this fork.
exp_config.explainers = ["laat", "deeplift", "gradient_x_input", "grad_attention", "atgrad_attention", "invert_label_att"]
# This notebook only evaluates plausibility.
exp_config.evaluate_faithfulness = False

In [None]:
import torch
import os
import sys
from pathlib import Path

# Make the cloned repository importable. Guarded so re-running this cell does
# not keep appending duplicate entries to sys.path.
REPO_ROOT = '/content/explain-icd'
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

from transformers import AutoTokenizer
from explainable_medical_coding.utils.tokenizer import TargetTokenizer
from explainable_medical_coding.utils.loaders import (
    load_and_prepare_dataset,
    load_trained_model,
)

  diet_gradient_scaler = torch.cuda.amp.GradScaler()
  advesarial_noise_gradient_scaler = torch.cuda.amp.GradScaler()


In [None]:
# CUDA_VISIBLE_DEVICES is read when CUDA is first initialised, so set it before
# torch.cuda.is_available() queries the devices. If an earlier import already
# initialised CUDA, this assignment has no effect on the current process.
if "CUDA_VISIBLE_DEVICES" not in os.environ:
    print("CUDA_VISIBLE_DEVICES not set")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

target_columns = list(exp_config.data.target_columns)
dataset_path = Path(exp_config.data.dataset_path)
model_folder_path = Path(exp_config.model_folder_path)
run_id = exp_config.run_id
model_path = model_folder_path / run_id

# The config saved with the trained run provides the text-encoder checkpoint
# (used for the tokenizer) and the maximum input length.
saved_config = OmegaConf.load(model_path / "config.yaml")
text_tokenizer_path = saved_config.model.configs.model_path

# Target (code) vocabulary saved alongside the run.
target_tokenizer = TargetTokenizer(autoregressive=False)
target_tokenizer.load(model_path / "target_tokenizer.json")

text_tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_path)
max_input_length = int(saved_config.data.max_length)

CUDA_VISIBLE_DEVICES not set


In [None]:
# Tokenise the dataset splits, then keep only discharge-summary notes.
dataset = load_and_prepare_dataset(
    dataset_path, text_tokenizer, target_tokenizer, max_input_length, target_columns
).filter(
    lambda example: example["note_type"] == "Discharge summary",
    desc="Filtering all notes that are not discharge summaries",
)

# Restore the trained model and its decision boundary (the threshold later
# compared against predicted probabilities).
model, decision_boundary = load_trained_model(
    model_path,
    saved_config,
    pad_token_id=text_tokenizer.pad_token_id,
    device=device,
)

# Per-run output folder for the evaluation results.
results_dir = Path("reports", "explainability_results", run_id)
results_dir.mkdir(parents=True, exist_ok=True)

The repository for mdace_inpatient_icd9 contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/mdace_inpatient_icd9.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Tokenizing text (num_proc=8):   0%|          | 0/355 [00:00<?, ? examples/s]

Tokenizing text (num_proc=8):   0%|          | 0/122 [00:00<?, ? examples/s]

Tokenizing text (num_proc=8):   0%|          | 0/127 [00:00<?, ? examples/s]

Creating targets column:   0%|          | 0/355 [00:00<?, ? examples/s]

Creating targets column:   0%|          | 0/122 [00:00<?, ? examples/s]

Creating targets column:   0%|          | 0/127 [00:00<?, ? examples/s]

Number of test targets before filtering: 702 


Filter unknown targets:   0%|          | 0/355 [00:00<?, ? examples/s]

Filter unknown targets:   0%|          | 0/122 [00:00<?, ? examples/s]

Filter unknown targets:   0%|          | 0/127 [00:00<?, ? examples/s]

Number of test targets after filtering: 702 


Filtering empty targets:   0%|          | 0/355 [00:00<?, ? examples/s]

Filtering empty targets:   0%|          | 0/122 [00:00<?, ? examples/s]

Filtering empty targets:   0%|          | 0/127 [00:00<?, ? examples/s]

Converting targets to IDs:   0%|          | 0/355 [00:00<?, ? examples/s]

Converting targets to IDs:   0%|          | 0/122 [00:00<?, ? examples/s]

Converting targets to IDs:   0%|          | 0/127 [00:00<?, ? examples/s]

Map:   0%|          | 0/355 [00:00<?, ? examples/s]

Map:   0%|          | 0/122 [00:00<?, ? examples/s]

Map:   0%|          | 0/127 [00:00<?, ? examples/s]

Filtering all notes that are not discharge summaries:   0%|          | 0/355 [00:00<?, ? examples/s]

Filtering all notes that are not discharge summaries:   0%|          | 0/122 [00:00<?, ? examples/s]

Filtering all notes that are not discharge summaries:   0%|          | 0/127 [00:00<?, ? examples/s]

{'name': 'PLMICD', 'autoregressive': False, 'configs': {'model_path': 'models/roberta-base-pm-m3-voc-hf', 'chunk_size': 128, 'cross_attention': True, 'loss': 'binary_cross_entropy', 'lambda_1': 0.0, 'scale': 1, 'mask_input': False}}


In [None]:
from explainable_medical_coding.eval.plausibility_metrics import evaluate_plausibility_and_sparsity

# Score every configured explainer; results are saved to a CSV in results_dir.
plausibility_csv = results_dir / "plausibility_and_sparsity.csv"
evaluate_plausibility_and_sparsity(
    model=model,
    model_path=model_path,
    datasets=dataset,
    text_tokenizer=text_tokenizer,
    target_tokenizer=target_tokenizer,
    decision_boundary=decision_boundary,
    explainability_methods=exp_config.explainers,
    cache_explanations=exp_config.cache_explanations,
    save_path=plausibility_csv,
)

Output()

  df = pl.DataFrame(schema=schema, data=rows)


Output()

  df = pl.DataFrame(schema=schema, data=rows)






Output()

  df = pl.DataFrame(schema=schema, data=rows)


Output()

  df = pl.DataFrame(schema=schema, data=rows)






Output()

  df = pl.DataFrame(schema=schema, data=rows)


Output()

  df = pl.DataFrame(schema=schema, data=rows)






Output()

  df = pl.DataFrame(schema=schema, data=rows)


Output()

  df = pl.DataFrame(schema=schema, data=rows)






Output()

  df = pl.DataFrame(schema=schema, data=rows)


Output()

  df = pl.DataFrame(schema=schema, data=rows)






Output()

  df = pl.DataFrame(schema=schema, data=rows)


Output()

  df = pl.DataFrame(schema=schema, data=rows)






In [None]:
import shutil
from google.colab import files

# Bundle this run's results folder into "<model>_<run>.zip" and download it.
archive_stem = run_id.replace('/', '_')
shutil.make_archive(archive_stem, 'zip', results_dir)
files.download(f"{archive_stem}.zip")

'/content/explain-icd/supervised_v5vsimfr.zip'

## Testing

### Show explanations

In [None]:
# List the cached explanation files (parquet) in the local .cache/ folder.
!ls -a .cache/

.
..
d3aa98ab6d411894_3bf94d6802fd709e_9ab1f4ba369b5b47_5a13c741343b6cac.parquet
d3aa98ab6d411894_46626e54961c0fe9_92297ef836b1e4c5_5a13c741343b6cac.parquet
d3aa98ab6d411894_46626e54961c0fe9_9ab1f4ba369b5b47_5a13c741343b6cac.parquet
d3aa98ab6d411894_55c64f476e83010f_92297ef836b1e4c5_5a13c741343b6cac.parquet
d3aa98ab6d411894_5f729e7f31df513f_92297ef836b1e4c5_5a13c741343b6cac.parquet
d3aa98ab6d411894_5f729e7f31df513f_9ab1f4ba369b5b47_5a13c741343b6cac.parquet
d3aa98ab6d411894_8d8d47e5a419e9aa_92297ef836b1e4c5_5a13c741343b6cac.parquet
d3aa98ab6d411894_99adcd9e1660b5a7_92297ef836b1e4c5_5a13c741343b6cac.parquet
d3aa98ab6d411894_99adcd9e1660b5a7_9ab1f4ba369b5b47_5a13c741343b6cac.parquet
d3aa98ab6d411894_afdfdd4ac638be04_92297ef836b1e4c5_5a13c741343b6cac.parquet
d3aa98ab6d411894_afdfdd4ac638be04_9ab1f4ba369b5b47_5a13c741343b6cac.parquet
d3aa98ab6d411894_cfd9c4c4bd681e39_9ab1f4ba369b5b47_5a13c741343b6cac.parquet


In [None]:
# Disabled clean-up helper: for every file name present in folder1, delete the
# file with the same name from folder2 and report whether it existed.
# import os

# folder1 = "/content/explain-icd/reports/cache/tm/3hvfq75j"
# folder2 = "/content/explain-icd/reports/cache/unsupervised/0fom6iwn"

# for filename in os.listdir(folder1):
#     file_path = os.path.join(folder2, filename)  # Construct path in folder2
#     if os.path.exists(file_path):  # Check if file exists
#         os.remove(file_path)  # Delete file
#         print(f"Deleted: {file_path}")
#     else:
#         print(f"Not found: {file_path}")


In [None]:
import shutil

# Mirror every file of the local explanation cache into a per-run folder under
# reports/cache/.
cache_path = f"/content/explain-icd/reports/cache/{run_id}"
src_folder = Path(".cache/")
dst_folder = Path(cache_path)
dst_folder.mkdir(parents=True, exist_ok=True)

cached_files = [entry for entry in src_folder.iterdir() if entry.is_file()]
for cached_file in cached_files:
    shutil.copy2(cached_file, dst_folder / cached_file.name)

In [None]:
# Zip the whole reports/cache/ folder, which may also hold other runs' caches.
# The archive name is built from this run's id.
cache_archive_stem = "cache_" + run_id.replace('/', '_')
shutil.make_archive(cache_archive_stem, 'zip', '/content/explain-icd/reports/cache/')

'/content/explain-icd/cache_supervised_v5vsimfr.zip'

In [None]:
# Download the cached validation and test explanation tables, i.e. the parquet
# files read in the next cell (file names must end in ".parquet").
for cached_explanations_file in (
    ".cache/993fdbf01d5d514f_02684f16e21b842b_bf3a5555623e3e78_9c04c226392f283b.parquet",  # validation
    ".cache/993fdbf01d5d514f_c97cb0a2a071cda6_66a672f034179133_9c04c226392f283b.parquet",  # test
):
    files.download(cached_explanations_file)

In [None]:
import polars as pl

# Cached explanation tables for this run (file names are presumably cache keys).
VAL_EXPLANATIONS_FILE = ".cache/993fdbf01d5d514f_02684f16e21b842b_bf3a5555623e3e78_9c04c226392f283b.parquet"
TEST_EXPLANATIONS_FILE = ".cache/993fdbf01d5d514f_c97cb0a2a071cda6_66a672f034179133_9c04c226392f283b.parquet"
explanations_val_df = pl.read_parquet(VAL_EXPLANATIONS_FILE)  # validation
explanations_test_df = pl.read_parquet(TEST_EXPLANATIONS_FILE)  # test

In [None]:
# NOTE(review): no earlier cell defines `df`. Its printed schema matches the
# single-example table built in "Process the attribution" at the end of the
# notebook, so this cell has to run after that one. It replaces the explanations
# loaded from parquet with a split of that small table.
split_idx = 5

# Rows [0, split_idx) stand in for the validation explanations ...
df1 = df.head(split_idx)
# ... and the remaining rows for the test explanations.
df2 = df.slice(split_idx, len(df) - split_idx)
print(df1.head())

explanations_val_df, explanations_test_df = df1, df2

shape: (5, 5)
┌─────────┬───────────┬──────────┬─────────────────────────────────┬────────────────────┐
│ note_id ┆ target_id ┆ y_prob   ┆ attributions                    ┆ evidence_token_ids │
│ ---     ┆ ---       ┆ ---      ┆ ---                             ┆ ---                │
│ str     ┆ i64       ┆ f64      ┆ list[f64]                       ┆ list[i64]          │
╞═════════╪═══════════╪══════════╪═════════════════════════════════╪════════════════════╡
│ 23168   ┆ 3493      ┆ 0.80501  ┆ [0.0017, 0.000012, … 0.000073]  ┆ [625]              │
│ 23168   ┆ 4264      ┆ 0.912789 ┆ [0.000302, 0.000028, … 0.00010… ┆ [144, 145, … 636]  │
│ 23168   ┆ 8104      ┆ 0.960634 ┆ [0.000187, 0.000005, … 0.00036… ┆ [2372, 2373, 2374] │
│ 23168   ┆ 2827      ┆ 0.788849 ┆ [0.000059, 0.000003, … 0.00040… ┆ [669]              │
│ 23168   ┆ 4751      ┆ 0.049201 ┆ [0.002332, 0.000015, … 0.0004]  ┆ [640]              │
└─────────┴───────────┴──────────┴─────────────────────────────────┴──────────────────

In [None]:
# Peek at the first rows of the test explanations.
print(explanations_test_df.head())

shape: (5, 5)
┌─────────┬───────────┬──────────┬─────────────────────────────────┬────────────────────┐
│ note_id ┆ target_id ┆ y_prob   ┆ attributions                    ┆ evidence_token_ids │
│ ---     ┆ ---       ┆ ---      ┆ ---                             ┆ ---                │
│ str     ┆ i64       ┆ f64      ┆ list[f64]                       ┆ list[i64]          │
╞═════════╪═══════════╪══════════╪═════════════════════════════════╪════════════════════╡
│ 23168   ┆ 3493      ┆ 0.80501  ┆ [2.6812e-9, 4.5563e-11, … 1.19… ┆ [625]              │
│ 23168   ┆ 4264      ┆ 0.912789 ┆ [2.1894e-10, 3.8074e-11, … 3.3… ┆ [144, 145, … 636]  │
│ 23168   ┆ 8104      ┆ 0.960634 ┆ [2.6349e-11, 1.7133e-12, … 1.0… ┆ [2372, 2373, 2374] │
│ 23168   ┆ 2827      ┆ 0.788849 ┆ [3.4599e-10, 3.8024e-11, … 1.7… ┆ [669]              │
│ 23168   ┆ 4751      ┆ 0.049201 ┆ [0.000001, 2.3029e-8, … 2.8749… ┆ [640]              │
└─────────┴───────────┴──────────┴─────────────────────────────────┴──────────────────

In [None]:
# Attribution threshold tuned on the validation explanations; it is applied to
# the test explanations when attributions are turned into token ids.
from explainable_medical_coding.eval.plausibility_metrics import find_explanation_decision_boundary

explanation_decision_boundary = find_explanation_decision_boundary(explanations_val_df)  # use validation set to find the decision boundary

In [None]:
# Attribution threshold found on the validation explanations.
print(explanation_decision_boundary)

0.098


In [None]:
from explainable_medical_coding.eval.plausibility_metrics import attributions2token_ids

def decode_target_tokens(token_ids):
    """Decode target ids into code strings using the global ``target_tokenizer``."""
    return target_tokenizer.decode(token_ids)


# One row per (note, code). It pairs the token ids selected from the attributions
# (via attributions2token_ids at explanation_decision_boundary) with the annotated
# evidence token ids. Return dtypes are given explicitly, as in the
# reconstruction cell, so polars does not have to guess them from the first value.
pred_groundtruth = explanations_test_df.select(
    note_id=pl.col("note_id"),
    target_text=pl.col("target_id").map_elements(
        lambda x: decode_target_tokens([x]),
        return_dtype=pl.List(pl.Utf8),
    ),
    predicted_token_ids=pl.col("attributions").map_elements(
        lambda x: attributions2token_ids(x, explanation_decision_boundary),
        return_dtype=pl.List(pl.Int64),
    ),
    evidence_token_ids=pl.col("evidence_token_ids"),
)



In [None]:
from datasets import load_dataset
import polars as pl
from transformers import AutoTokenizer


# Load the dataset once to avoid repeated calls. trust_remote_code=True allows
# the repository's own loading script to run without the interactive [y/N]
# prompt shown when the dataset was first loaded in this notebook.
check_dataset = load_dataset(
    "explainable_medical_coding/datasets/mdace_inpatient_icd9.py",
    name="mdace_inpatient_icd9",
    trust_remote_code=True,
)

# Fast lookup note_id -> original note text (test split only). The two columns
# are read directly, which avoids materialising every row as a dict.
test_split = check_dataset["test"]
note_text_lookup = dict(zip(test_split["note_id"], test_split["text"]))

import functools


@functools.lru_cache(maxsize=256)
def _token_char_offsets(text: str) -> tuple[tuple[int, int], ...]:
    """Return the (start, end) character span of every token of ``text``.

    Cached because reconstruct_text is called for every (note, code) row and
    every text column, while a note's tokenisation is the same each time.
    """
    encoded = text_tokenizer(text, return_offsets_mapping=True)
    return tuple(tuple(span) for span in encoded["offset_mapping"])


def reconstruct_text(note_id: str, token_ids: list[int]) -> str:
    """Return the parts of a note's original text covered by ``token_ids``.

    The note text is looked up in ``note_text_lookup`` and re-tokenised with
    ``text_tokenizer`` to get each token's character span. A run of consecutive
    token ids is returned as one contiguous slice of the original text, so
    sub-word pieces such as "G" + "ERD" come back as "GERD" rather than
    "G ERD". Separate runs are joined with a single space.

    NOTE(review): assumes ``token_ids`` index the same tokenisation as the model
    input, including any special tokens the tokenizer adds by default and no
    truncation differences. Confirm against ``load_and_prepare_dataset``.

    Args:
        note_id: Identifier of the note in ``note_text_lookup``.
        token_ids: Token positions to extract. Duplicates, out-of-range positions
            and zero-width tokens are ignored.

    Returns:
        The reconstructed text ("" if no position matches), or "Unknown Note ID"
        when the note is not in the lookup.
    """
    original_text = note_text_lookup.get(note_id)
    if original_text is None:
        return "Unknown Note ID"

    offsets = _token_char_offsets(original_text)
    # Set: O(1) membership and de-duplication instead of scanning a list per token.
    wanted = set(token_ids) if token_ids is not None else set()

    segments: list[str] = []
    run_start = run_end = None
    previous_idx = None
    for idx in sorted(wanted):
        if not 0 <= idx < len(offsets):
            continue  # position does not exist in this tokenisation
        start, end = offsets[idx]
        if start == end:
            continue  # zero-width span, e.g. special tokens mapped to (0, 0)
        if previous_idx is not None and idx == previous_idx + 1:
            run_end = end  # extend the current contiguous run
        else:
            if run_start is not None:
                segments.append(original_text[run_start:run_end])
            run_start, run_end = start, end
        previous_idx = idx
    if run_start is not None:
        segments.append(original_text[run_start:run_end])

    return " ".join(segments)

# Apply the function to a Polars DataFrame
# Add a readable text column for the predicted and for the evidence token ids,
# in that order.
df = pred_groundtruth
for token_column, text_column in (
    ("predicted_token_ids", "predicted_text"),
    ("evidence_token_ids", "evidence_text"),
):
    df = df.with_columns(
        pl.struct(["note_id", token_column]).map_elements(
            lambda row, column=token_column: reconstruct_text(row["note_id"], row[column]),
            return_dtype=pl.Utf8,  # Specify output type to avoid warnings
        ).alias(text_column)
    )

In [None]:
# Compare reconstructed predicted vs. evidence text for the first rows.
print(df.head())

shape: (5, 6)
┌─────────┬─────────────┬───────────────────┬───────────────────┬──────────────────┬───────────────┐
│ note_id ┆ target_text ┆ predicted_token_i ┆ evidence_token_id ┆ predicted_text   ┆ evidence_text │
│ ---     ┆ ---         ┆ ds                ┆ s                 ┆ ---              ┆ ---           │
│ str     ┆ list[str]   ┆ ---               ┆ ---               ┆ str              ┆ str           │
│         ┆             ┆ list[i64]         ┆ list[i64]         ┆                  ┆               │
╞═════════╪═════════════╪═══════════════════╪═══════════════════╪══════════════════╪═══════════════╡
│ 23168   ┆ ["493.22"]  ┆ []                ┆ []                ┆                  ┆               │
│ 23168   ┆ ["530.81"]  ┆ [651, 657, …      ┆ [658, 659]        ┆ - / G ERD / hi   ┆ G ERD         │
│         ┆             ┆ 3120]             ┆                   ┆ atal hernia # G… ┆               │
│ 23168   ┆ ["493.92"]  ┆ [2202, 3375]      ┆ [139, 633]        ┆ acerb acerb

### Step-by-step check on a single test example

In [None]:
from explainable_medical_coding.utils.analysis import predict

test_dataset = dataset["test"]
example_idx = 0

# Fields of one test example.
note_id = test_dataset["note_id"][example_idx]
input_ids = test_dataset["input_ids"][example_idx].to(device).unsqueeze(0)  # add batch axis
ground_truth_target_ids = test_dataset["target_ids"][example_idx].tolist()
evidence_input_ids = test_dataset["evidence_input_ids"][example_idx]

# The i-th evidence entry belongs to the i-th ground-truth code.
target_id2evidence_input_ids = {}
for position, ground_truth_target_id in enumerate(ground_truth_target_ids):
    target_id2evidence_input_ids[ground_truth_target_id] = evidence_input_ids[position]

# Codes whose probability exceeds the model's decision boundary.
y_probs = predict(model, input_ids, device).cpu()[0]
predicted_target_ids = torch.where(y_probs > decision_boundary)[0].tolist()

# Explain every code that is annotated or predicted.
codes_to_explain = set(ground_truth_target_ids) | set(predicted_target_ids)
target_ids = torch.tensor(list(codes_to_explain))

In [None]:
# Number of codes explained for this example (ground-truth ∪ predicted).
print(target_ids.shape)

torch.Size([11])


In [None]:
# Optional (disabled): overwrite the installed transformers RoBERTa implementation
# with the repository's modeling_roberta.py.
# !mv /content/explain-icd/modeling_roberta.py /usr/local/lib/python3.11/dist-packages/transformers/models/roberta/modeling_roberta.py

In [None]:
from explainable_medical_coding.explainability.helper_functions import create_attention_mask

# Re-run the encoder part of the model by hand (adapted from the explainability
# methods and models.py), so intermediate tensors can be inspected cell by cell.
input_ids = input_ids.to(device)
sequence_length = input_ids.shape[1]  # unchunked token count, used later to drop padding
attention_mask = create_attention_mask(input_ids)
print(input_ids.shape)
print(attention_mask.shape)

# Split the flat sequence into fixed-size chunks. input_ids becomes
# [batch_size, num_chunks, chunk_size], e.g. [1, 29, 128] in the printed output.
input_ids = model.split_input_into_chunks(input_ids, model.pad_token_id)
# Bind attention_masks on both branches so the encoder call below never reads an
# undefined name (or a stale value from an earlier run of this cell).
attention_masks = (
    model.get_chunked_attention_masks(attention_mask)
    if attention_mask is not None
    else None
)

batch_size, num_chunks, chunk_size = input_ids.size()
last_hidden_state, pooler_output = model.roberta_encoder(
    input_ids=input_ids.view(-1, chunk_size),
    attention_mask=(
        attention_masks.view(-1, chunk_size) if attention_masks is not None else None
    ),
    return_dict=False,
)
# last_hidden_state, pooler_output, hidden_states, attentions = model.roberta_encoder(
#             input_ids=input_ids.view(-1, chunk_size),
#             attention_mask=attention_masks.view(-1, chunk_size)
#             if attention_masks is not None
#             else None,
#             return_dict=False,
#             output_hidden_states=True,
#             output_attentions=True,
#         )

# last_hidden_state, pooler_output, norms  = model.roberta_encoder(
#             input_ids=input_ids.view(-1, chunk_size),
#             attention_mask=attention_masks.view(-1, chunk_size)
#             if attention_masks is not None
#             else None,
#             return_dict=False,
#             output_norms=True
#         )

# Glue the chunks back into one sequence: [batch_size, num_chunks * chunk_size, hidden].
hidden_output = last_hidden_state.view(batch_size, num_chunks * chunk_size, -1)

torch.Size([1, 3592])
torch.Size([1, 3592])


In [None]:
# After chunking, input_ids is 3-D while attention_mask keeps its flat shape.
print(input_ids.shape)
print(attention_mask.shape)

torch.Size([1, 29, 128])
torch.Size([1, 3592])


In [None]:
# Run the model's label-wise attention on the chunk-padded encoder output and
# keep the attention matrix. The flat (un-padded) token mask is passed as is.
_, label_attentions = model.label_wise_attention(
            hidden_output,
            attention_masks=attention_mask,
            output_attention=True,
            attn_grad_hook_fn=None,
        )

In [None]:
# [batch, num_classes, chunk-padded seq_len]: [1, 8943, 3712] in the run shown.
print(label_attentions.shape)

torch.Size([1, 8943, 3712])


### Trying the new explainer step by step

In [None]:
lable_cross_attention = model.label_wise_attention

# Values and keys are projections of the encoder output; the learned label
# representations play the role of queries.
V, K = (
    lable_cross_attention.weights_v(hidden_output),  # [batch_size, seq_len, input_size]
    lable_cross_attention.weights_k(hidden_output),  # [batch_size, seq_len, input_size]
)
Q = lable_cross_attention.label_representations  # [num_classes, input_size]

In [None]:
# Shapes of values, keys (per token) and queries (per label).
print(V.shape)
print(K.shape)
print(Q.shape)

torch.Size([1, 3712, 768])
torch.Size([1, 3712, 768])
torch.Size([8943, 768])


In [None]:
# Token-by-label compatibility scores (Q is 2-D, so Q.T swaps its two axes).
att_weights = K @ Q.T  # [batch_size, seq_len, num_classes]

In [None]:
# [batch, chunk-padded seq_len, num_classes]
print(att_weights.shape)

torch.Size([1, 3712, 8943])


In [None]:
# Extend the token mask with zeros over the chunk padding appended by the
# encoder, so it matches hidden_output's sequence length.
attention_mask = torch.nn.functional.pad(
    attention_mask, (0, hidden_output.size(1) - attention_mask.size(1)), value=0
)
attention_mask = attention_mask.to(torch.bool)
# Broadcast the per-token mask over every class: [batch, seq_len, num_classes].
# expand() returns a view, whereas repeat() would allocate a full copy.
attention_mask = attention_mask.unsqueeze(2).expand(
    -1, -1, lable_cross_attention.num_classes
)
# Put -inf into the *scores* at masked positions. Filling -inf into the bool mask
# itself casts -inf to True: the mask becomes all True and adding it just adds 1
# everywhere, i.e. nothing gets masked.
# Note: the softmax here is taken over the class dimension (dim=2), so a fully
# masked (padding) row becomes NaN. Those rows must be sliced off
# (`[:sequence_length]`) before use.
att_weights.masked_fill_(attention_mask.logical_not(), float("-inf"))

In [None]:
# The mask is broadcast to the shape of att_weights.
print(attention_mask.shape)

torch.Size([1, 3712, 8943])


In [None]:
# Normalise over the class dimension (dim=2): every token gets a probability
# distribution over all labels.
scaled_scores = att_weights / lable_cross_attention.scale
attention = scaled_scores.softmax(dim=2)  # [batch_size,  seq_len, num_classes]

In [None]:
# The softmax keeps the shape of the scores.
print(attention.shape)
print(att_weights.shape)

torch.Size([1, 3712, 8943])
torch.Size([1, 3712, 8943])


In [None]:
# Drop the batch axis and transpose to [num_classes, sequence_length+padding].
# Then keep only the explained codes and the real (non-padding) tokens.
attention = attention.squeeze(0).detach().T[target_ids, :sequence_length]

In [None]:
# Token-major layout on the CPU: [sequence_length, num_explained_codes].
attributions = attention.T.cpu()

In [None]:
# [sequence_length, num_explained_codes]
print(attributions.shape)

torch.Size([3592, 11])


### Process the attributions

In [None]:
#form analysis
rows = []
for idx, target_id in enumerate(target_ids):
    row = [
        note_id,
        target_id.item(),
        y_probs[target_id].item(),
        attributions[:, idx].tolist(),
        target_id2evidence_input_ids.get(target_id.item(), []),
    ]
    rows.append(row)

In [None]:
import polars as pl

# Column types of the explanation table; each entry of `rows` is one row with
# its values in this column order.
schema = {
    "note_id": pl.Utf8,
    "target_id": pl.Int64,
    "y_prob": pl.Float64,
    "attributions": pl.List(pl.Float64),
    "evidence_token_ids": pl.List(pl.Int64),
}
# orient="row" states that each inner list is a row. Without it, polars has to
# infer the orientation and emits the warning shown in this cell's output.
df = pl.DataFrame(schema=schema, data=rows, orient="row")

  df = pl.DataFrame(schema=schema, data=rows)
