In [1]:
import numpy as np
import time
import sys
import os

import torch
from torch import nn
import torch.nn.functional as F


from catr.configuration import Config
from catr.models.utils import NestedTensor, nested_tensor_from_tensor_list, get_rank
from catr.models.backbone import build_backbone
from catr.models.transformer import build_transformer
from catr.models.position_encoding import PositionEmbeddingSine
from catr.models.caption import MLP

from get_matrix_from_cnn import *
from get_notes_and_image_paths import *
from get_word_embeddings import *
from image_feature_dataset import *
from load_glove_840B_300d import *
from make_study_dictionary import *

import math
import torch.optim as optim
import copy

### First, define our custom caption class

In [2]:
class Xray_Captioner(nn.Module):
    def __init__(self, transformer, feature_dim, hidden_dim, vocab_size):
        super().__init__()
        self.input_proj = nn.Conv2d(
            feature_dim, hidden_dim, kernel_size=1) # project feature dimension
        self.position_embedding = PositionEmbeddingSine(hidden_dim//2, normalize=True)
        self.transformer = transformer
        self.mlp = MLP(hidden_dim, 512, vocab_size, 3)

    def forward(self, img_features, target, target_mask):
        # The input mask here is all zeros, meaning we look at the whole image
        # The mask here is more of a formality, oringinally implemented to 
        # let the model accept different image sizes. Not needed here.
        b, c, h, w = img_features.shape
        mask = torch.zeros((b, h, w), dtype=torch.bool, device=img_features.device)

        # Get projected image features and positional embedding
        img_embeds = self.input_proj(img_features)
        pos = self.position_embedding(NestedTensor(img_embeds, mask))
        
        # Run through transformer -> linear -> softmax
        hs = self.transformer(img_embeds, mask,
                              pos, target, target_mask)
        out = self.mlp(hs.permute(1, 0, 2))
        return out
    

def build_model(config):
    transformer = build_transformer(config)
    model = Xray_Captioner(transformer, config.feature_dim, config.hidden_dim, config.vocab_size)
    criterion = torch.nn.CrossEntropyLoss()

    return model, criterion

### This method builds the model like we will during training/inference

In [3]:
''' 
This method uses a config file to appropriately create the model.
This includes setting the device and specifying a random seed    
''' 
def main(config):
    # initialize device we're runnign this on
    device = torch.device(config.device)
    print(f'Initializing Device: {device}')

    # specify the random seed for deterministic behavior
    seed = config.seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # create the model
    model, criterion = build_model(config)
    model.to(device)
    
    # sanity check
    n_parameters = sum(p.numel()
                       for p in model.parameters() if p.requires_grad)
    print(f"Number of params: {n_parameters}")
    
    return model

In [7]:
test_dict = main_get_notes_and_image_paths("test")

In [8]:
test_dict["p10/p10274145/s53356050/4e60f3da-37ed157d-a469a568-0b2ee907-4b01c924.jpg"]

['<S>',
 'Chest',
 'PA',
 'and',
 'lateral',
 'radiograph',
 'demonstrates',
 'unchanged',
 'cardiomediastinal',
 'and',
 'hilar',
 'contours',
 '.',
 '<s>',
 'No',
 'overt',
 'pulmonary',
 'edema',
 'is',
 'evident',
 'though',
 'chronic',
 'mild',
 'interstitial',
 'abnormalities',
 'are',
 'stable',
 '.',
 '<s>',
 'Faint',
 'opacification',
 'projecting',
 'over',
 'the',
 'left',
 'mid',
 'lung',
 'may',
 'represent',
 'developing',
 'infectious',
 'process',
 '.',
 '<s>',
 'There',
 'is',
 'no',
 'definitive',
 'correlate',
 'on',
 'the',
 'lateral',
 'radiograph',
 '.',
 '<s>',
 'No',
 'pleural',
 'effusion',
 'or',
 'pneumothorax',
 'present',
 '.',
 '<s>',
 'Mild',
 'separation',
 'of',
 'superior',
 'aspect',
 'of',
 'sternotomy',
 'line',
 'with',
 'intact',
 'sternotomy',
 'sutures',
 '.',
 '<s>',
 '</s>']

In [9]:
words = main_get_word_embeddings()

Loading Glove Model
Done. 23094  words loaded!


In [27]:
# make word embedding dictionary
weights = np.zeros((len(set(words.keys())), 300))
word2ind = {}
ind2word = {}
current_ind = 0
for word in sorted(words.keys()):
    word2ind[word] = current_ind
    ind2word[current_ind] = word
    weights[current_ind] = words[word]
    current_ind += 1

### Create a model

In [29]:
# Create a sample config file
# feature_dim is not specified by default, so we need to set it
config = Config()
config.device = 'cpu' # if running without GPU
config.feature_dim = 1024
config.pad_token_id = word2ind["<S>"]
config.hidden_dim = 300
config.nheads = 10
config.vocab_size = len(words.keys())
config.__dict__["pre_embed"] = torch.from_numpy(weights)

# Create the model!
xray_model = main(config)
xray_model = xray_model.double()

Initializing Device: cpu
Number of params: 33908144


In [13]:
# Helper function to create initial caption and mask
def create_evaluation_caption_and_mask(start_token, max_length):
    caption_template = torch.zeros((1, max_length), dtype=torch.long)
    mask_template = torch.ones((1, max_length), dtype=torch.bool)

    caption_template[:, :] = start_token
    mask_template[:, 0] = False

    return caption_template, mask_template

In [14]:
# Create starter caption and caption mask
start_token = word2ind["<S>"]
caption, cap_mask = create_evaluation_caption_and_mask(
    start_token, config.max_position_embeddings)

In [15]:
matrix = extract_image_features("../mimic_cxr", 
                                ["p10/p10274145/s53356050/4e60f3da-37ed157d-a469a568-0b2ee907-4b01c924.jpg"]
                                ,"densenet121")

In [16]:
features = matrix.transpose((0, 3, 1, 2))
features = torch.from_numpy(features)

In [31]:
# getting predicted sentence for an image
current_caption, current_mask = create_evaluation_caption_and_mask(start_token, config.max_position_embeddings)
iteration_number = 1
last_word = word2ind["<S>"]
while iteration_number < config.max_position_embeddings and last_word != word2ind["</s>"]:
    predictions = xray_model(features.double(), current_caption, current_mask)
    # get highest predicted word
    word = torch.argmax(predictions[:,0,:], axis=-1)
    try:
        print(ind2word[word.item()])
    except:
        pass
    current_caption[:, iteration_number] = word
    current_mask[:, iteration_number] = False
    iteration_number += 1

05:29
05:29
FAILURE
Mastectomy
05:29
05:29
Pars
Pars
p.o
Ends
Mastectomy
Narrowing
05:29
Worry
pneumonic
Pars
latter
Pars
Produces
4.8
p.o
Parrot
05:29


KeyboardInterrupt: 