In [1]:
import json
import operator
import os
import sys
import time
from queue import PriorityQueue

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch.utils.data import DataLoader

from dataset.dataset import ImageFeatureDataset
from transformer_ethan import *
from catr.configuration import Config

In [2]:
# Load the word-embedding matrix (presumably GloVe vectors, judging by the file name)
# and the two vocabulary lookup tables: token -> id and id -> token.
# ind2word comes from JSON, so its keys are strings; make_report() looks ids up via str(id).
words = np.load("glove_embed.npy")
with open('word2ind.json') as json_file: 
    word2ind = json.load(json_file) 
with open('ind2word.json') as json_file: 
    ind2word = json.load(json_file) 

# Start from the default Config and override the fields this notebook needs.
config = Config()
config.device = 'cpu'
config.feature_dim = 1024
# The "<S>" id is stored as the pad id; evaluate() passes it on as the caption start token.
config.pad_token_id = word2ind["<S>"]
config.hidden_dim = 300
config.nheads = 10
config.batch_size = 64
config.encoder_type = 1
# Vocabulary size is taken from the number of rows in the embedding matrix.
config.vocab_size = words.shape[0]
config.dir = '../mimic_features_double'
# Stored through __dict__ instead of plain attribute assignment.
# NOTE(review): presumably to sidestep attribute handling on Config — confirm.
config.__dict__["pre_embed"] = torch.from_numpy(words).to(config.device)

In [3]:
# Build the model and its loss with `main` (not defined in this notebook; presumably
# provided by the star import from transformer_ethan — TODO confirm), cast the model
# to float32 and move it to the configured device.
model, criterion = main(config) 
model = model.float()
device = torch.device(config.device)
model.to(device)

Initializing Device: cpu
Number of params: 33370520


In [186]:
# Validation split, iterated in dataset order one sample at a time (batch size 1),
# so every batch yielded by the loader corresponds to exactly one report.
dataset_val = ImageFeatureDataset(config, mode='val')
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_val = DataLoader(dataset_val, 1,
                             sampler=sampler_val, drop_last=False, num_workers=config.num_workers)

In [4]:
# Load from checkpoint
CHECKPOINT_NAME = "../../models/checkpoint_imagenet_10_tf.pth"

# map_location remaps the stored tensors onto `device` (CPU in this notebook).
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(CHECKPOINT_NAME, map_location=device)
# The model weights are stored under the 'model' key of the checkpoint.
model.load_state_dict(checkpoint['model'])

# NOTE(review): this message is printed after the weights have already been loaded.
print("Loading Checkpoint...")

Loading Checkpoint...


In [5]:
import nltk

def get_bleu_score(truths, predicteds, weights_to_be_used=(0.25, 0.25, 0.25, 0.25)):
    """Average sentence-level BLEU over paired reference/prediction token lists.

    :param truths: sequence of reference sentences, each a list of tokens
    :param predicteds: sequence of predicted sentences, aligned with ``truths`` by index
    :param weights_to_be_used: n-gram weights forwarded to ``sentence_bleu``
    :return: mean of the per-sentence scores; ``0.0`` when ``truths`` is empty
    :raises IndexError: if ``predicteds`` has fewer entries than ``truths``
    """
    # Nothing to score: avoid dividing by zero below.
    if len(truths) == 0:
        return 0.0

    scores = []
    for index, truth in enumerate(truths):
        predicted = predicteds[index]
        try:
            score = nltk.translate.bleu_score.sentence_bleu([truth], predicted, weights=weights_to_be_used)
        except Exception:
            # Best effort: a sentence that cannot be scored counts as 0, but
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            score = 0
        scores.append(score)
    return sum(scores) / len(scores)

In [18]:
# Edward: note that this builds the new caption as (<S>, 0, ..., 0); should it be (<S>, <S>, ..., <S>) instead?
def create_caption_and_mask(start_token, max_length):
    """Build an empty caption buffer and its mask for one sequence.

    :param start_token: token id written into position 0 of the caption
    :param max_length: number of positions in the buffer
    :return: ``(caption, mask)``, both of shape ``(1, max_length)``; the caption
             is a long tensor of zeros except for position 0, the mask is a bool
             tensor that is True everywhere except position 0
    """
    shape = (1, max_length)

    caption = torch.zeros(shape, dtype=torch.long)
    caption[:, 0] = start_token

    # Only the start position begins unmasked.
    mask = torch.ones(shape, dtype=torch.bool)
    mask[:, 0] = False

    return caption, mask

def make_report(captions):
    """Convert sequences of token ids into lists of words.

    Each sequence is cut right after its first "</s>" token (the end token
    itself is kept); sequences without one are converted in full.

    :param captions: iterable of 1-D id arrays
    :return: list with one list of word strings per input sequence
    """
    reports = []
    for token_ids in captions:
        is_end = token_ids == word2ind["</s>"]
        if is_end.any():
            first_end = is_end.nonzero()[0][0]
            token_ids = token_ids[:first_end + 1]
        # ind2word is keyed by the string form of the id.
        reports.append([ind2word[str(token_id)] for token_id in token_ids])
    return reports

def reports_to_sentence(reports):
    """Decode id sequences with make_report() and join each one into a single space-separated string."""
    sentences = []
    for words_in_report in make_report(reports):
        sentences.append(' '.join(words_in_report))
    return sentences

def evaluate(images):
    """Greedily decode one caption per image using the global ``model``.

    For every image the caption buffer is filled left to right with the
    arg-max token of the model output at the previous position, stopping once
    "</s>" is produced or the buffer is full.

    :param images: batch of image inputs, sliced one sample at a time along dim 0
    :return: list of numpy arrays, one per image, each holding the caption ids
    """
    model.eval()
    decoded = []
    for sample_idx in range(len(images)):
        single_image = images[sample_idx:sample_idx + 1]
        caption, cap_mask = create_caption_and_mask(
            config.pad_token_id, config.max_position_embeddings)

        # `pos` is the caption slot being filled; it is predicted from slot pos - 1.
        for pos in range(1, config.max_position_embeddings):
            with torch.no_grad():
                logits = model(single_image, caption, cap_mask).to(config.device)
            next_id = logits[:, pos - 1, :].argmax(dim=-1)

            caption[:, pos] = next_id[0]
            cap_mask[:, pos] = False

            if next_id[0] == word2ind["</s>"]:
                break

        decoded.append(caption.numpy())
    return decoded

def bleu(truth_in, generated_in):
    """Sentence-level BLEU-1 to BLEU-4 between a reference and a generated report.

    Both strings are stripped of the "<S>", "<s>" and "</s>" markers and of
    periods and commas, then split on single spaces; empty tokens are dropped.

    :param truth_in: reference report as one string
    :param generated_in: generated report as one string
    :return: tuple ``(bleu1, bleu2, bleu3, bleu4)`` with uniform n-gram weights
    """
    def _content_tokens(text):
        # Same removal order as the markers are listed.
        for marker in ("<S>", "<s>", "</s>", ".", ","):
            text = text.replace(marker, "")
        return [tok for tok in text.replace("  ", " ").split(" ") if tok != '']

    references = [_content_tokens(truth_in)]
    hypothesis = _content_tokens(generated_in)

    # Highest order first; each order n uses n equal weights of 1/n.
    by_order = {}
    for order in (4, 3, 2, 1):
        by_order[order] = nltk.translate.bleu_score.sentence_bleu(
            references, hypothesis, weights=[1. / order] * order)
    return by_order[1], by_order[2], by_order[3], by_order[4]

In [229]:
class BeamSearchNode(object):
    '''
    One partial hypothesis explored by the beam search.

    Nodes are stored in a priority queue as ``(score, node)`` tuples, so they
    define an ordering: without it, two entries with equal scores make the
    queue compare the nodes themselves and raise TypeError.
    '''

    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        '''
        :param hiddenstate: decoder state carried by this hypothesis (beam_decode stores the caption mask here)
        :param previousNode: node this hypothesis was expanded from, or None for the root
        :param wordId: token data of this hypothesis (beam_decode stores the whole caption tensor here)
        :param logProb: accumulated score of the hypothesis
        :param length: number of tokens in the hypothesis, start token included
        '''
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length

    def eval(self, alpha=1.0):
        '''
        Length-normalised score: accumulated score divided by the number of
        generated tokens. The 1e-6 term avoids a division by zero for the root
        node (length 1). ``alpha`` is accepted for compatibility but not used.
        '''
        return self.logp / float(self.leng - 1 + 1e-6)

    def __lt__(self, other):
        '''
        Order nodes so that a higher normalised score sorts first, matching the
        ``-node.eval()`` priority used by the queue. Equal scores compare as
        "not less", which keeps insertion order in the heap instead of raising.
        '''
        if not isinstance(other, BeamSearchNode):
            return NotImplemented
        return self.eval() > other.eval()
    
# def beam_decode(model, config, target_tensor, decoder_hiddens, encoder_outputs=None):
def beam_decode(model, image, beam_width=5, topk=1):
    '''
    Best-first beam decoding of a caption for a single image.

    Hypotheses are kept in a priority queue ordered by length-normalised
    log-probability. The best one is popped, extended with its ``beam_width``
    most likely next tokens, and the children are pushed back. A hypothesis is
    finished when it ends in "</s>" or fills the whole caption buffer.

    :param model: captioning model, called as ``model(image, caption, mask)``; its output is indexed as [batch, position, vocab]
    :param image: model input for one sample
    :param beam_width: number of continuations generated per expanded hypothesis
    :param topk: number of finished captions to collect before stopping
    :return: list of caption tensors of shape [1, max_position_embeddings], best first;
             unused positions are filled with the "</s>" id. The list may hold fewer
             than ``topk`` captions when the search space is exhausted.
    '''
    model.eval()

    SOS_token = word2ind["<S>"]
    EOS_token = word2ind["</s>"]
    max_len = config.max_position_embeddings

    caption, cap_mask = create_caption_and_mask(SOS_token, max_len)
    # Pre-fill the tail with </s> so a caption is always terminated when decoded to text.
    caption[:, 1:] = EOS_token

    endnodes = []
    nodes = PriorityQueue()

    # Queue entries are (score, tiebreak, node). The monotonically increasing
    # tiebreak keeps equal scores from falling through to a node comparison.
    tiebreak = 0

    # starting node - mask, previous node, caption, logp, length
    node = BeamSearchNode(cap_mask, None, caption, 0, 1)
    nodes.put((-node.eval(), tiebreak, node))
    qsize = 1

    # An empty queue means every hypothesis has been consumed; a blocking
    # get() on it would hang forever, so the loop stops instead.
    while not nodes.empty():
        # give up when decoding takes too long
        if qsize > 5000:
            print("taking too long")
            break

        # fetch the best node
        score, _, n = nodes.get()
        decoder_input = n.wordid
        decoder_mask = n.h

        if n.leng == max_len or decoder_input[0, n.leng - 1].item() == EOS_token:
            endnodes.append((score, n))
            # stop once enough finished captions were collected
            if len(endnodes) >= topk:
                break
            continue

        with torch.no_grad():
            predictions = model(image, decoder_input, decoder_mask).to(config.device)

        # Normalise the scores of the current position into log-probabilities so
        # they can be summed along a hypothesis. This is a no-op if the model
        # already returns log-probabilities.
        log_probs = torch.log_softmax(predictions[0, n.leng - 1, :], dim=-1)

        # Never ask for more candidates than the vocabulary holds.
        k = min(beam_width, log_probs.size(-1))
        top_log_probs, top_ids = torch.topk(log_probs, k)

        for log_p, predicted_id in zip(top_log_probs.tolist(), top_ids):
            new_caption = decoder_input.detach().clone()
            new_mask = decoder_mask.detach().clone()

            new_caption[:, n.leng] = predicted_id
            new_mask[:, n.leng] = False

            child = BeamSearchNode(new_mask, n, new_caption, n.logp + log_p, n.leng + 1)
            tiebreak += 1
            nodes.put((-child.eval(), tiebreak, child))

        # one node was consumed, k were added
        qsize += k - 1

    # If the search stopped early, top up with the best unfinished hypotheses.
    while len(endnodes) < topk and not nodes.empty():
        score, _, n = nodes.get()
        endnodes.append((score, n))

    print("Got ", len(endnodes), " end nodes")

    # Lower score is better because scores are negated normalised log-probabilities.
    return [n.wordid for _, n in sorted(endnodes, key=operator.itemgetter(0))]

In [230]:
# Fixed "all findings normal" report, scored against every ground truth as a constant baseline.
ETHAN_NOTE = "<S> The heart size is normal . <s> The hilar and mediastinal contours are normal . <s> No focal consolidations concerning for pneumonia are identified . <s> There is no pleural effusion or pneumothorax . <s> The visualized osseous structures are unremarkable . <s> </s>"


def _print_bleu(label, truth, candidate):
    """Score `candidate` against `truth` with bleu(), print the four scores under `label`, return them as a list."""
    bs1, bs2, bs3, bs4 = bleu(truth, candidate)
    print("[Beam {0}] Bleu score: {1:.4f} {2:.4f} {3:.4f} {4:.4f}".format(label, bs1, bs2, bs3, bs4))
    return [bs1, bs2, bs3, bs4]


def _beam_search_bleu(image, truth, beam_width):
    """Decode `image` with beam_decode() at the given width, print the best caption and return its BLEU scores."""
    decoded_sentences = beam_decode(model, image, beam_width=beam_width, topk=1)
    gen = reports_to_sentence(decoded_sentences[0].numpy())[0]
    print("[Beam {0} Pred] ".format(beam_width), gen)
    return _print_bleu(beam_width, truth, gen)


def do_beam_iter(image, note, note_mask):
    """Compare greedy decoding, beam search (width 5 and 10) and the fixed baseline on one sample.

    Prints the ground truth, every generated report and its BLEU-1..4 scores.

    :param image: model input for one sample
    :param note: ground-truth token ids for the sample; the first row is used
    :param note_mask: mask belonging to ``note``; accepted for interface compatibility, not used
    :return: tuple ``(beam_1, beam_5, beam_10, beam_ethan)``, each a list ``[bleu1, bleu2, bleu3, bleu4]``
    """
    report = evaluate(image)
    report_np = np.asarray(report).squeeze(1)

    truth = reports_to_sentence(np.asarray(note[:,:]))[0]
    print('[GT]', truth)
    print()

    # Greedy decoding is reported as "Beam 1".
    generated = reports_to_sentence(report_np)[0]
    print('[Beam 1 Pred]', generated)
    beam_1 = _print_bleu(1, truth, generated)

    print()
    beam_5 = _beam_search_bleu(image, truth, 5)

    print()
    beam_10 = _beam_search_bleu(image, truth, 10)
    print()

    beam_ethan = _print_bleu("ethan", truth, ETHAN_NOTE)

    print('---------------------------------------------------------------')
    print()

    return beam_1, beam_5, beam_10, beam_ethan

In [231]:
# NOTE(review): `truth` is only assigned at top level in a cell further down this notebook,
# so on a fresh Restart & Run All this cell raises NameError.
truth

'<S> No evidence of consolidation to suggest pneumonia is seen . <s> There is some retrocardiac atelectasis . <s> A small left pleural effusion may be present . <s> No pneumothorax is seen . <s> No pulmonary edema . <s> A right granuloma is unchanged . <s> The heart is mildly enlarged , unchanged . <s> There is tortuosity of the aorta . <s> </s>'

In [232]:
# do_beam_iter()

In [233]:

# Per-sample BLEU scores, one list per decoding strategy.
beam_1_list, beam_5_list, beam_10_list, beam_ethan_list = [], [], [], []
score_lists = (beam_1_list, beam_5_list, beam_10_list, beam_ethan_list)

# Score the first 128 validation samples at most.
iters = 0
for image, note, note_mask in data_loader_val:
    beam_1, beam_5, beam_10, beam_ethan = do_beam_iter(image, note, note_mask)
    for bucket, scores in zip(score_lists, (beam_1, beam_5, beam_10, beam_ethan)):
        bucket.append(scores)

    iters += 1
    if iters >= 128:
        break
    

[GT] <S> No evidence of consolidation to suggest pneumonia is seen . <s> There is some retrocardiac atelectasis . <s> A small left pleural effusion may be present . <s> No pneumothorax is seen . <s> No pulmonary edema . <s> A right granuloma is unchanged . <s> The heart is mildly enlarged , unchanged . <s> There is tortuosity of the aorta . <s> </s>

[Beam 1 Pred] <S> AP upright and lateral views of the chest provided . <s> Lung volumes are low . <s> There is mild left basal atelectasis . <s> No convincing evidence for pneumonia or edema . <s> No large effusion or pneumothorax . <s> The cardiomediastinal silhouette is stable . <s> Imaged osseous structures are intact . <s> No free air below the right hemidiaphragm is seen . <s> </s>
[Beam 1] Bleu score: 0.3800 0.1525 0.0000 0.0000



KeyboardInterrupt: 

In [None]:
'''
Beam 1 Bleu score: 0.3220 0.2219 0.1090 0.0000
Beam 5 Bleu score: 0.0909 0.0615 0.0317 0.0000

Beam 1 Bleu score: 0.1315 0.0771 0.0517 0.0000
Beam 5 Bleu score: 0.0323 0.0162 0.0000 0.0000

Beam 1 Bleu score: 0.2559 0.1497 0.0876 0.0000
Beam 5 Bleu score: 0.0812 0.0501 0.0273 0.0000

Beam 1 Bleu score: 0.1667 0.0595 0.0000 0.0000
Beam 5 Bleu score: 0.1600 0.0808 0.0514 0.0000

Beam 1 Bleu score: 0.2652 0.1967 0.1603 0.1376
Beam 5 Bleu score: 0.1225 0.1083 0.0992 0.0908

Beam 1 Bleu score: 0.0869 0.0553 0.0363 0.0000
Beam 5 Bleu score: 0.1048 0.0595 0.0291 0.0000
'''


In [187]:
# Pull the first validation sample (the loader is sequential with batch size 1).
image, note, note_mask = next(iter(data_loader_val))

In [190]:
# Greedy decoding of the sample. evaluate() returns one array per image;
# squeeze(1) drops the singleton axis that each of those arrays carries.
report = evaluate(image)
report_np = np.asarray(report).squeeze(1)

In [194]:
# Ground-truth report of the sample as one space-joined string.
truth = reports_to_sentence(np.asarray(note[:,:]))[0]
truth

'<S> No evidence of consolidation to suggest pneumonia is seen . <s> There is some retrocardiac atelectasis . <s> A small left pleural effusion may be present . <s> No pneumothorax is seen . <s> No pulmonary edema . <s> A right granuloma is unchanged . <s> The heart is mildly enlarged , unchanged . <s> There is tortuosity of the aorta . <s> </s>'

In [195]:
# Greedy prediction for the same sample as one space-joined string.
generated = reports_to_sentence(report_np)[0]
generated

'<S> AP upright and lateral views of the chest provided . <s> Lung volumes are low . <s> There is mild left basal atelectasis . <s> No convincing evidence for pneumonia or edema . <s> No large effusion or pneumothorax . <s> The cardiomediastinal silhouette is stable . <s> Imaged osseous structures are intact . <s> No free air below the right hemidiaphragm is seen . <s> </s>'

In [196]:
# BLEU-1 to BLEU-4 of the greedy prediction against the ground truth.
bs1, bs2, bs3, bs4 = bleu(truth, generated)
print("Bleu score: {0:.4f} {1:.4f} {2:.4f} {3:.4f}".format(bs1, bs2, bs3, bs4))

Bleu score: 0.3800 0.1525 0.0000 0.0000


In [198]:
# Beam search with the function defaults (beam_width=5, topk=1).
# NOTE(review): the recorded output below reports 5 end nodes, so it was presumably
# produced by an earlier version of beam_decode — re-run to confirm.
decoded_sentences = beam_decode(model, image)

5 required.
Got  5  end nodes


In [203]:
# Print every caption returned by the beam search together with its BLEU-1..4
# scores against the ground truth held in `truth`.
for i in range(len(decoded_sentences)):
    gen = reports_to_sentence(decoded_sentences[i].numpy())[0]
    print("[iter ",i,"] ",gen)
    bs1, bs2, bs3, bs4 = bleu(truth, gen)
    print("[BLEU] Bleu score: {0:.4f} {1:.4f} {2:.4f} {3:.4f}".format(bs1, bs2, bs3, bs4))

[iter  0 ]  <S> AP upright and lateral views of the chest were obtained . <s> Lung volumes are low . <s> No focal consolidation is seen . <s> No pleural effusion or pneumothorax . <s> The cardiac and mediastinal silhouettes are stable . <s> No overt pulmonary edema is seen . <s> </s>
[BLEU] Bleu score: 0.3391 0.2105 0.0998 0.0000
[iter  1 ]  <S> AP upright and lateral views of the chest were obtained . <s> Lung volumes are low . <s> No focal consolidation is seen . <s> No pleural effusion or pneumothorax . <s> The cardiac and mediastinal silhouettes are stable . <s> No overt pulmonary edema is seen . <s> No displaced fracture is identified . <s> </s>
[BLEU] Bleu score: 0.3680 0.2213 0.1036 0.0000
[iter  2 ]  <S> AP upright and lateral views of the chest were obtained . <s> Lung volumes are low . <s> No focal consolidation is seen . <s> No pleural effusion or pneumothorax . <s> The cardiac and mediastinal silhouettes are stable . <s> No overt pulmonary edema is seen . <s> No displaced f

[iter  0 ]  <S> AP upright and lateral views of the chest were obtained . <s> Lung volumes are low . <s> No focal consolidation is seen . <s> No pleural effusion or pneumothorax . <s> The cardiac and mediastinal silhouettes are stable . <s> No overt pulmonary edema is seen . <s> </s>
[BLEU] Bleu score: 0.3391 0.2105 0.0998 0.0000
[iter  1 ]  <S> AP upright and lateral views of the chest were obtained . <s> Lung volumes are low . <s> No focal consolidation is seen . <s> No pleural effusion or pneumothorax . <s> The cardiac and mediastinal silhouettes are stable . <s> No overt pulmonary edema is seen . <s> No displaced fracture is identified . <s> </s>
[BLEU] Bleu score: 0.3680 0.2213 0.1036 0.0000
[iter  2 ]  <S> AP upright and lateral views of the chest were obtained . <s> Lung volumes are low . <s> No focal consolidation is seen . <s> No pleural effusion or pneumothorax . <s> The cardiac and mediastinal silhouettes are stable . <s> No overt pulmonary edema is seen . <s> No displaced f

In [172]:
# Echo the ground-truth report held in `truth`.
truth

'<S> No evidence of consolidation to suggest pneumonia is seen . <s> There is some retrocardiac atelectasis . <s> A small left pleural effusion may be present . <s> No pneumothorax is seen . <s> No pulmonary edema . <s> A right granuloma is unchanged . <s> The heart is mildly enlarged , unchanged . <s> There is tortuosity of the aorta . <s> </s>'