In [1]:
import numpy as np
import time
import sys
import os

import torch
from torch import nn
import torch.nn.functional as F


from catr.configuration import Config
from catr.models.utils import NestedTensor, nested_tensor_from_tensor_list, get_rank
from catr.models.backbone import build_backbone
from catr.models.transformer import build_transformer
from catr.models.position_encoding import PositionEmbeddingSine
from catr.models.caption import MLP


### First, define our custom caption class

In [2]:
class Xray_Captioner(nn.Module):
    def __init__(self, transformer, feature_dim, hidden_dim, vocab_size):
        super().__init__()
        self.input_proj = nn.Conv2d(
            feature_dim, hidden_dim, kernel_size=1) # project feature dimension
        self.position_embedding = PositionEmbeddingSine(hidden_dim//2, normalize=True)
        self.transformer = transformer
        self.mlp = MLP(hidden_dim, 512, vocab_size, 3)

    def forward(self, img_features, target, target_mask):
        # The input mask here is all zeros, meaning we look at the whole image
        # The mask here is more of a formality, oringinally implemented to 
        # let the model accept different image sizes. Not needed here.
        b, c, h, w = img_features.shape
        mask = torch.zeros((b, h, w), dtype=torch.bool, device=img_features.device)

        # Get projected image features and positional embedding
        img_embeds = self.input_proj(img_features)
        pos = self.position_embedding(NestedTensor(img_embeds, mask))
        
        # Run through transformer -> linear -> softmax
        hs = self.transformer(img_embeds, mask,
                              pos, target, target_mask)
        out = self.mlp(hs.permute(1, 0, 2))
        return out
    

def build_model(config):
    transformer = build_transformer(config)
    model = Xray_Captioner(transformer, config.feature_dim, config.hidden_dim, config.vocab_size)
    criterion = torch.nn.CrossEntropyLoss()

    return model, criterion

### This method builds the model like we will during training/inference

In [3]:
''' 
This method uses a config file to appropriately create the model.
This includes setting the device and specifying a random seed    
''' 
def main(config):
    # initialize device we're runnign this on
    device = torch.device(config.device)
    print(f'Initializing Device: {device}')

    # specify the random seed for deterministic behavior
    seed = config.seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # create the model
    model, criterion = build_model(config)
    model.to(device)
    
    # sanity check
    n_parameters = sum(p.numel()
                       for p in model.parameters() if p.requires_grad)
    print(f"Number of params: {n_parameters}")
    
    return model

### Create a model

In [4]:
# Create a sample config file
# feature_dim is not specified by default, so we need to set it
config = Config()
config.device = 'cpu' # if running without GPU
config.feature_dim = 1024

# Create the model!
xray_model = main(config)

Initializing Device: cpu
Number of params: 41525306


### Run some random stuff through

In [5]:
# Helper function to create initial caption and mask
def create_caption_and_mask(start_token, max_length):
    caption_template = torch.zeros((1, max_length), dtype=torch.long)
    mask_template = torch.ones((1, max_length), dtype=torch.bool)

    caption_template[:, :] = start_token
    mask_template[:, 0] = False

    return caption_template, mask_template

In [6]:
# from get_matrix_from_cnn import *
# from get_notes_and_image_paths import *
from get_word_embeddings import *
# from image_feature_dataset import *
from load_glove_840B_300d import *
# from make_study_dictionary import *
from dataset.dataset import ImageFeatureDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import DataLoader
from torch.autograd import Variable
import math
import torch.optim as optim
import numpy as np
import copy

In [7]:
# test_dict = main_get_notes_and_image_paths("test")

In [8]:
# test_dict["p10/p10274145/s53356050/4e60f3da-37ed157d-a469a568-0b2ee907-4b01c924.jpg"]

In [9]:
dataset = ImageFeatureDataset('../mimic_features/paths.csv',
                             '../mimic_features/')
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [10]:
features, note = next(iter(dataloader))
print(features.shape)
print(note)

torch.Size([1, 1024, 8, 8])
[('<S>',), ('No',), ('evidence',), ('of',), ('consolidation',), ('to',), ('suggest',), ('pneumonia',), ('is',), ('seen',), ('.',), ('<s>',), ('There',), ('is',), ('some',), ('retrocardiac',), ('atelectasis',), ('.',), ('<s>',), ('A',), ('small',), ('left',), ('pleural',), ('effusion',), ('may',), ('be',), ('present',), ('.',), ('<s>',), ('No',), ('pneumothorax',), ('is',), ('seen',), ('.',), ('<s>',), ('No',), ('pulmonary',), ('edema',), ('.',), ('<s>',), ('A',), ('right',), ('granuloma',), ('is',), ('unchanged',), ('.',), ('<s>',), ('The',), ('heart',), ('is',), ('mildly',), ('enlarged',), (',',), ('unchanged',), ('.',), ('<s>',), ('There',), ('is',), ('tortuosity',), ('of',), ('the',), ('aorta',), ('.',), ('<s>',), ('</s>',)]


In [11]:
# test = np.load('../mimic_features/eb2fabb7-4bbc8aab-d7371282-08e5bcb5-de2e430a.npy', allow_pickle=True)

In [12]:
words = main_get_word_embeddings()

Loading Glove Model
Done. 23094  words loaded!


In [13]:
# make word embedding dictionary
weights = np.zeros((len(set(words.keys())), 300))
word2ind = {}
ind2word = {}
current_ind = 0
for word in sorted(words.keys()):
    word2ind[word] = current_ind
    ind2word[current_ind] = word
    weights[current_ind] = words[word]
    current_ind += 1

In [14]:
# Create starter caption and caption mask
start_token = word2ind["<S>"]
caption, cap_mask = create_caption_and_mask(
    start_token, config.max_position_embeddings)

In [81]:
# matrix = extract_image_features("/Users/ethanschonfeld/Desktop/mimic_cxr", 
#                                 ["p10/p10274145/s53356050/4e60f3da-37ed157d-a469a568-0b2ee907-4b01c924.jpg"], 
#                                 "densenet121")

In [17]:
# features_torch = torch.from_numpy(features.transpose((0, 3, 1, 2)))

In [18]:
# getting predicted sentence for an image
current_caption, current_mask = create_caption_and_mask(start_token, config.max_position_embeddings)
iteration_number = 1
last_word = word2ind["<S>"]
while iteration_number <= config.max_position_embeddings and last_word != word2ind["</s>"]:
    predictions = xray_model(features, current_caption, current_mask)
    # get highest predicted word
    word = torch.argmax(predictions[:,0,:], axis=-1)
    current_caption[:, iteration_number] = word
    current_mask[:, iteration_number] = False
    iteration_number += 1

IndexError: index 128 is out of bounds for dimension 1 with size 128

In [20]:
for index in range(current_caption.shape[1]):
    print(ind2word[current_caption[0, index].item()])

<S>
uniformly
Outward
Flexure
Flexure
these
Outward
uniformly
these
emanating
Flexure
Flexure
Month
Flexure
Flexure
these
Flexure
13:38
emanating
Yestarday
these
Flexure
these
Flexure
Flexure
Flexure
Flexure
Flexure
Yestarday
Flexure
13:38
Flexure
emanating
affects
emanating
these
Flexure
Flexure
these
Flexure
Flexure
emanating
Flexure
Flexure
Flexure
Flexure
Make
emanating
these
emanating
Flexure
these
these
emanating
these
these
Flexure
uniformly
these
Flexure
Flexure
Flexure
Flexure
Flexure
these
Flexure
these
Flexure
these
these
Flexure
Flexure


KeyError: 28393