In [1]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import wandb

# import constants
from constants import *

# Add the parent directory to the system path
sys.path.append("..")

from utils import load_data_multilabel_pre_split , create_dataloaders, initialize_model, create_subset_dataloader, create_vocabulary_label_pre_split, create_vocabulary
from HAN_model import HierarchicalAttentionNetwork

from metrics import *
from train import train

import captum
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_dataloader, valid_dataloader, vocab_size = create_dataloaders()

cache_path: ../cache_vocabulary_label_pik/mimic3-ds-50-HAN_word_vocabulary.pik file_exists: True
load_data.started...
load_data_multilabel_new.data_path: ../datasets/data/train_50_eamc.csv
load_data.ended...
load_data.started...
load_data_multilabel_new.data_path: ../datasets/data/dev_50_eamc.csv
load_data.ended...
shuffled training data


In [3]:
tiny_loader = create_subset_dataloader(valid_dataloader, 1)

In [4]:
model = HierarchicalAttentionNetwork(vocab_size=vocab_size, embed_size=EMBED_SIZE, hidden_size=HIDDEN_SIZE, num_sentences=NUM_SENTENCES, sentence_length=SENTENCE_LENGTH, num_classes=NUM_CLASSES)

In [5]:
model.load_state_dict(torch.load("../checkpoints/20231128_1247_100epochs_yPaLn/best_valid_loss.pt", map_location='cpu'))

<All keys matched successfully>

In [6]:
def forward_with_sigmoid(input):
    return torch.sigmoid(model(input))

In [7]:
vocabulary_word2index, vocabulary_index2word = create_vocabulary(WORD2VEC_MODEL_PATH,name_scope=DATASET + "-HAN")

cache_path: ../cache_vocabulary_label_pik/mimic3-ds-50-HAN_word_vocabulary.pik file_exists: True


In [25]:
vocabulary_word2index_label,vocabulary_index2word_label = create_vocabulary_label_pre_split(training_data_path=TRAINING_DATA_PATH, validation_data_path=VALIDATION_DATA_PATH, testing_data_path=TESTING_DATA_PATH, name_scope=DATASET + "-HAN")

In [8]:
PAD_ID = 0
token_reference = TokenReferenceBase(reference_token_idx=PAD_ID)

In [9]:
lig = LayerIntegratedGradients(model, model.embeddings)

In [10]:
vis_data_records_ig = []

In [11]:
x, y = next(iter(tiny_loader))
pred = forward_with_sigmoid(x)

In [12]:
pred.shape

torch.Size([1, 50])

In [13]:
pred

tensor([[6.9934e-01, 1.1955e-02, 8.1924e-02, 6.1842e-03, 3.3330e-01, 6.4586e-02,
         1.9102e-02, 2.0049e-02, 3.1443e-01, 7.5767e-01, 4.9988e-01, 8.2690e-02,
         4.2240e-02, 6.7354e-04, 2.8976e-02, 5.4142e-02, 1.5240e-01, 2.6216e-01,
         1.4898e-02, 6.3594e-03, 1.1615e-01, 1.7024e-03, 5.8901e-05, 3.7454e-02,
         2.1619e-01, 1.1237e-04, 9.1838e-03, 1.8380e-02, 1.5117e-03, 8.1204e-03,
         1.8971e-02, 5.1045e-03, 1.3901e-02, 5.1605e-02, 9.4627e-03, 4.3310e-02,
         5.1580e-03, 3.2039e-02, 8.2546e-02, 1.7999e-03, 3.4053e-01, 1.2028e-03,
         4.5430e-03, 2.8430e-03, 2.4618e-04, 7.3593e-02, 2.6221e-03, 1.6576e-02,
         1.6066e-03, 6.0232e-04]], grad_fn=<SigmoidBackward0>)

In [14]:
torch.argmax(pred, dim=1)

tensor([9])

In [15]:
reference_indices = token_reference.generate_reference(SEQUENCE_LENGTH, device='cpu').unsqueeze(0)


In [17]:
reference_indices.shape

torch.Size([1, 2500])

In [18]:
x.shape

torch.Size([1, 2500])

In [20]:
# compute attributions and approximation delta using layer integrated gradients
attributions_ig, delta = lig.attribute(x, reference_indices, target = torch.argmax(pred, dim=1).item(), return_convergence_delta=True)

In [24]:
attributions = attributions_ig.sum(dim=2).squeeze(0)
attributions = attributions / torch.norm(attributions)
attributions = attributions.cpu().detach().numpy()

In [45]:
vis_data_records_ig = []
pred_ind = torch.argmax(pred, dim=1)
pred_label = vocabulary_index2word_label[pred_ind.item()]
pred_label

'96.6'

In [30]:
y[:,pred_ind]

tensor([1.])

In [33]:
text = [vocabulary_index2word[i.item()] for i in x[0]]

In [46]:
pred_ind

tensor([9])

In [55]:
f"{pred[:,pred_ind.item()].item():.2f}"

'0.76'

In [56]:
vis_data_records_ig.append(visualization.VisualizationDataRecord(
    attributions,
    pred[:,pred_ind.item()].item(),
    pred_label,
    pred_label,
    pred_label,
    attributions.sum(),
    text,
    delta))
    


In [61]:
_ = visualization.visualize_text(vis_data_records_ig)

TypeError: unsupported format string passed to Tensor.__format__