In [1]:
import os
import sys
sys.path.insert(0, os.path.abspath('../'))

import captum

import torch
import torchtext
import torchtext.data

import torch.nn as nn
import torch.nn.functional as F

from torchtext.vocab import Vocab

from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization
from torchtext.data.utils import get_tokenizer

from models.cnn.model import CNN
from datasets.scar import SCAR


tokenizer = get_tokenizer('basic_english')

In [2]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"

In [3]:
# Seeing a psychiatrist
ckpt_name = "CNN_20230222-2048"
ckpt_path = os.path.join(r"C:\Users\jjnunez\PycharmProjects\scar_nlp_psych\results\final_results\dspln_PSYCHIATRY_12\CNN", ckpt_name + ".pt")


checkpoint = torch.load(ckpt_path)

model = CNN(config=checkpoint['config'])
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
model = model.to(device)
print(model)

CNN(
  (embed): Embedding(22338, 300)
  (conv1): Conv2d(1, 500, kernel_size=(3, 300), stride=(1, 1), padding=(2, 0))
  (conv2): Conv2d(1, 500, kernel_size=(4, 300), stride=(1, 1), padding=(3, 0))
  (conv3): Conv2d(1, 500, kernel_size=(5, 300), stride=(1, 1), padding=(4, 0))
  (dropout): Dropout(p=0.85, inplace=False)
  (fc1): Linear(in_features=1500, out_features=1, bias=True)
)


In [4]:
def forward_with_sigmoid(input):
    return torch.sigmoid(model(input))

In [5]:
config = checkpoint['config']
scar = SCAR(config.batch_size, config.data_dir, config.target, eval_only=False)



In [6]:
vocab = scar.vocab
itos = vocab.get_itos()
print(len(vocab))

22338


In [7]:
# PAD_IND = TEXT.vocab.stoi[TEXT.pad_token]
print(vocab['doctor'])
PAD_IND = vocab['<PAD>']
token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)
lig = LayerIntegratedGradients(model, model.embed)

909


In [8]:
# accumalate couple samples in this array for visualization purposes
vis_data_records_ig = []

def interpret_sentence(model, sentence, min_len = 1500, label = 0):
    text = [tok for tok in tokenizer(sentence.lower())]
    if len(text) < min_len:
        text += ['<PAD>'] * (min_len - len(text))
    indexed = [vocab[t] for t in text]

    model.zero_grad()
   
    input_indices = torch.tensor(indexed, device=device)
    input_indices = input_indices.unsqueeze(0)
    
    # input_indices dim: [sequence_length]
    seq_length = min_len

    # predict
    pred = forward_with_sigmoid(input_indices).item()
    pred_ind = round(pred)
    
    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)
    
    print(f'Here is the input_indices.size {input_indices.size()}')
    print(f'Here is the reference_indices.size {reference_indices.size()}')
    
    print(f'Here is the input_indices {input_indices}')
    print(f'Here is the reference_indices {reference_indices}')
    
    # compute attributions and approximation delta using layer integrated gradients
    attributions_ig, delta = lig.attribute(input_indices, reference_indices, \
                                           n_steps=500, return_convergence_delta=True
                                          )
    # Replace Label with Text below
    print(itos[pred_ind])
    print('pred: ', itos[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))

    add_attributions_to_visualizer(attributions_ig, text, pred, pred_ind, label, delta, vis_data_records_ig)
    
def add_attributions_to_visualizer(attributions, text, pred, pred_ind, label, delta, vis_data_records):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()

    # storing couple samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
                            attributions,
                            pred,
                            itos[pred_ind],
                            itos[label],
                            itos[1],
                            attributions.sum(),
                            text,
                            delta))

In [9]:
file = open('anon.txt',mode='r')# 
text = file.read()
file.close()

interpret_sentence(model, text, label=1)


  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


Here is the input_indices.size torch.Size([1, 1500])
Here is the reference_indices.size torch.Size([1, 1500])
Here is the input_indices tensor([[502,  17, 221,  ...,   3,   3,   3]])
Here is the reference_indices tensor([[3, 3, 3,  ..., 3, 3, 3]])
here is the single layer: Embedding(22338, 300)
line 294 if FALSE
entered hook_wrapper
entered registered forward_hooks
handle: <torch.utils.hooks.RemovableHandle object at 0x000000743308FB20>
about to run_forward
Entered run_forward
in forward_run, len(forward_func_args) != 0
entering forward_func: CNN(
  (embed): Embedding(22338, 300)
  (conv1): Conv2d(1, 500, kernel_size=(3, 300), stride=(1, 1), padding=(2, 0))
  (conv2): Conv2d(1, 500, kernel_size=(4, 300), stride=(1, 1), padding=(3, 0))
  (conv3): Conv2d(1, 500, kernel_size=(5, 300), stride=(1, 1), padding=(4, 0))
  (dropout): Dropout(p=0.85, inplace=False)
  (fc1): Linear(in_features=1500, out_features=1, bias=True)
)
entered forward_hook
if FALSE, here is the saved_layer: defaultdict(<

In [10]:
print('Visualize attributions based on Integrated Gradients')
_ = visualization.visualize_text(vis_data_records_ig)

Visualize attributions based on Integrated Gradients


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
,(0.53),,0.59,"reason for referral ms . chan is a 41 year old nulligravid woman who has been referred for management of a large pelvic mass . history of present illness for the past few years , emma has noted more irregular cycles and had seen her family physician several times regarding this . she eventually went back to taiwan to obtain imaging in late 2012 . apparently this showed that she had ovarian cysts . when she returned to canada , she had a repeat pelvic ultrasound performed on april 14 , 2012 , revealing complex right adnexal lesion with a cystic component measuring 8 . 5 cm and a solid irregular marrow component measuring 3 . 3 cm . the uterus and left ovary appeared normal . she went on to see dr . cohen in late april and an endometrial biopsy was negative for malignancy . she had repeat ultrasound on may 22 , 2013 , which revealed that the right adnexal lesion was enlarging , now measuring 10 . 1 x 10 . 0 x 9 . 3 cm with multiple peripheral solid nodules with the largest measuring 4 . 5 cm , increased from 3 . 2 cm . the same day tumour markers revealed a slightly elevated ca125 of 42 , ca19 9 of 64 and a normal ca15 3 and cea . the patient is quite symptomatic from this mass and is noticing increasing lower abdominal pain especially on the right over the past 2 weeks . she finds that it is worse in the morning and has difficulty moving , but has been able to continue working after she discovered that a herbal supplement gives her relief . she has also noticed increasing bloating and bowel changes over the past 2 3 weeks . she now has loose bowel movements 3 4 times a day . she has increased urinary frequency and hesitancy . her appetite is lower , but she has not had any weight loss . gynecologic history she underwent menarche at age 13 . as described previously , she has noticed shorter cycles over the past 2 years every 20 25 days . her periods are light and she denies any dysmenorrhea . she denies any intermenstrual bleeding . she is nulligravid . she is sexually active and has no history of stis . she had a history of cin iii in 2008 which was treated with leep and her pap test has since been normal . she had her last mammogram about 3 4 years ago . past history healthy . past surgical history nil . medications multivitamin , herbal supplement for bloating which she recently started . allergies none known . social history she works in retail . she lives alone , but is currently in a relationship . she moved from taipei in 2001 . she is a nonsmoker and consumes approximately 3 alcoholic beverages a week . family history her maternal grandmother had breast at age 42 . her own mother is well . she also had a paternal grandmother with lymphoma . physical examination height 155 cm , weight 61 kg . on examination she appears younger than her stated . there is no supraclavicular or cervical lymphadenopathy . lungs are clear with no crackles or wheezes . no pleural effusion . cardiac examination reveals blood pressure of 109 65 and pulse 72 beats per minute . normal heart sounds and a very soft systolic murmur is heard at the apex . on abdominal exam , a mass is palpated up to 3 cm below the umbilicus . it is tender . she has hyperpigmented skin around her umbilicus from a previous burn . no inguinal lymphadenopathy . a speculum exam reveals a normal appearing cervix . please see dr . murphy ' s note for the rest of the pelvic examination . impression and plan this is a 41 year old woman who presents with an enlarging right adnexal mass . the recommendation would be for surgical removal however she is reluctant to undergo surgery and is insistent on keeping her ovaries . she has consented to at most an unilateral salpingo oophorectomy after discussion with dr . murphy . we will obtain baseline blood work and chest x ray today . we hope to perform the surgery as soon as possible . thank you for this referral . #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD #PAD"
,,,,


In [11]:
print(tokenizer)

<function _basic_english_normalize at 0x0000007402C0B790>


In [12]:
print(type(vis_data_records_ig))

<class 'list'>


In [34]:
viz_records = vis_data_records_ig[0]

print('hello')

print(viz_records.__dir__())

print(f'Len of raw_input is: {len(viz_records.raw_input)}')
print(f'Len of word_attributions is: {len(viz_records.word_attributions)}')

raw_words = viz_records.raw_input
importances = viz_records.word_attributions

cutoff = 0.05

# for i in range(10):
#    word = viz_records.raw_input[i]
#    importance = viz_records.word_attributions[i]
#    print(f'{word} {importance}')
    
for i in range(len(importances)):
    if importances[i] > 0.05:
        print(raw_words[i])
        
# Next, need to extract things at the sentance level. 

hello
['__module__', '__doc__', '__slots__', '__init__', 'attr_class', 'attr_score', 'convergence_score', 'pred_class', 'pred_prob', 'raw_input', 'true_class', 'word_attributions', '__repr__', '__hash__', '__str__', '__getattribute__', '__setattr__', '__delattr__', '__lt__', '__le__', '__eq__', '__ne__', '__gt__', '__ge__', '__new__', '__reduce_ex__', '__reduce__', '__subclasshook__', '__init_subclass__', '__format__', '__sizeof__', '__dir__', '__class__']
Len of raw_input is: 1500
Len of word_attributions is: 1500
referral
41
year
old
mass
noted
on
may
multiple
peripheral
also
noticed
increasing
menarche
at
age
mammogram
which
started
her
maternal
grandmother
had
breast
at
age
also
appears
lungs
clear
reveals
will
x
today
referral
