In [1]:
import numpy as np
import time
import sys
import os
import tqdm

import torch
from torch import nn
import torch.nn.functional as F


from catr.configuration import Config
from catr.models.utils import NestedTensor, nested_tensor_from_tensor_list, get_rank
from catr.models.backbone import build_backbone
from catr.models.transformer_double_encoder import build_transformer_double_encoder
from catr.models.position_encoding import PositionEmbeddingSine
from catr.models.caption import MLP

import json

from dataset.dataset import ImageDoubleFeatureDataset
from torch.utils.data import DataLoader
from transformer_ethan_double_encoder import *
sys.path.append(os.path.join(os.path.dirname("__file__"), "catr"))
from engine_double_encoder import train_one_epoch_double_encoder, evaluate_double_encoder

In [2]:
# Load Word Embeddings
# `words` is the embedding matrix read from disk; its first dimension is used
# below as the vocabulary size.
words = np.load("glove_embed.npy")
# Token -> id lookup (indexed by token strings such as "<S>" and "</s>").
with open('word2ind.json') as json_file: 
    word2ind = json.load(json_file) 
# Id -> token lookup. Keys are strings (JSON object keys), which is why
# make_report() indexes it with str(token_id).
with open('ind2word.json') as json_file: 
    ind2word = json.load(json_file) 
    
# Set up config file
# Start from the project defaults and override the fields this experiment needs.
config = Config()
config.feature_dim = 1024
# Despite the name, this holds the id of the "<S>" token; evaluate() passes it
# to create_caption_and_mask() as the first token of every generated caption.
config.pad_token_id = word2ind["<S>"]
# NOTE(review): 300 presumably matches the GloVe embedding width and must be
# divisible by nheads (300 / 10 = 30) — confirm against the model code.
config.hidden_dim = 300
config.nheads = 10
config.batch_size = 16
config.enc_layers = 6
config.vocab_size = words.shape[0]
# Path the training loop writes checkpoints to. The resume cell reads from a
# different, hard-coded file.
config.checkpoint = './checkpoint_double_TEST.pth'
config.dir = '../mimic_features_double'
# These two fields are written straight into the instance dict rather than via
# attribute assignment. NOTE(review): presumably they are not declared on
# Config — confirm.
config.__dict__["pre_embed"] = torch.from_numpy(words)
config.__dict__["encoder_type"] = 2

In [3]:
# Build the model and its loss through the project's `main` factory, cast the
# floating-point weights to float32 and move the model to the configured device.
model, criterion = main(config)
model = model.float()
device = torch.device(config.device)
model.to(device)

# Split the trainable parameters into two optimizer groups in a single pass:
# everything whose name does not contain "backbone" uses the optimizer's default
# learning rate, while "backbone" parameters get `config.lr_backbone`.
other_params = []
backbone_params = []
for param_name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if "backbone" in param_name:
        backbone_params.append(param)
    else:
        other_params.append(param)

param_dicts = [
    {"params": other_params},
    {"params": backbone_params, "lr": config.lr_backbone},
]

Initializing Device: cuda
Number of params: 45820808


In [4]:
# AdamW over the two parameter groups; `config.lr` is the default learning rate
# for any group that does not set its own "lr".
optimizer = torch.optim.AdamW(
    param_dicts,
    lr=config.lr,
    weight_decay=config.weight_decay,
)
# Step decay of the learning rate every `config.lr_drop` scheduler steps (the
# training loop steps the scheduler once per epoch).
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.lr_drop)

In [5]:
# Train / validation splits of the paired-image feature dataset.
dataset_train = ImageDoubleFeatureDataset(config, mode='train')
dataset_val = ImageDoubleFeatureDataset(config, mode='val')

# Training batches are drawn in random order and incomplete trailing batches are
# dropped; validation is read sequentially and keeps the final short batch.
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
batch_sampler_train = torch.utils.data.BatchSampler(
    sampler_train,
    config.batch_size,
    drop_last=True,
)

data_loader_train = DataLoader(
    dataset_train,
    batch_sampler=batch_sampler_train,
    num_workers=config.num_workers,
)
data_loader_val = DataLoader(
    dataset_val,
    batch_size=config.batch_size,
    sampler=sampler_val,
    drop_last=False,
    num_workers=config.num_workers,
)

print(f"Train: {len(dataset_train)}")
print(f"Val: {len(dataset_val)}")

Train: 253893
Val: 1196


In [6]:
# Resume from a previously saved run when its checkpoint file is present.
# Note that this path is separate from `config.checkpoint`, which is where the
# training loop writes new checkpoints.
resume_path = "./checkpoint_double_10_tf.pth"
if os.path.exists(resume_path):
    print("Loading Checkpoint...")
    # Load onto the CPU first; load_state_dict copies into the live objects.
    checkpoint = torch.load(resume_path, map_location='cpu')
    for component, key in (
        (model, 'model'),
        (optimizer, 'optimizer'),
        (lr_scheduler, 'lr_scheduler'),
    ):
        component.load_state_dict(checkpoint[key])
    # Continue with the epoch after the one stored in the checkpoint.
    config.start_epoch = checkpoint['epoch'] + 1

print("Start Training..")

Loading Checkpoint...
Start Training..


In [None]:
# Per-epoch metric histories for this session. They are re-created every time
# the cell runs, so they only cover epochs trained since the last execution.
train_loss_hist, val_loss_hist = [], []
train_bleu_hist, val_bleu_hist = [], []

# Train up to (but not including) epoch 20, starting from `config.start_epoch`.
for epoch in range(config.start_epoch, 20):
    print(f"Epoch: {epoch}")

    # Train for one epoch, then advance the learning-rate schedule.
    epoch_loss, train_bleu_score = train_one_epoch_double_encoder(
        model,
        criterion,
        data_loader_train,
        optimizer,
        device,
        epoch,
        config.clip_max_norm,
        word2ind,
    )
    train_loss_hist.append(epoch_loss)
    train_bleu_hist.append(train_bleu_score)
    lr_scheduler.step()
    print(f"Training Bleu Score: {train_bleu_score}")
    print(f"Training Loss: {epoch_loss}")

    # Persist everything needed to resume before running validation.
    checkpoint_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint_state, config.checkpoint)

    # Evaluate on the validation split.
    validation_loss, val_bleu_score = evaluate_double_encoder(
        model, criterion, data_loader_val, device, word2ind
    )
    val_loss_hist.append(validation_loss)
    val_bleu_hist.append(val_bleu_score)
    print(f"Validation Bleu Score: {val_bleu_score}")
    print(f"Validation Loss: {validation_loss}")

    print()

In [None]:
from matplotlib import pyplot as plt

# Loss curves for the epochs trained in this session. Index 0 is the first epoch
# run by the training cell, i.e. `config.start_epoch`, not necessarily epoch 0.
# The figure is labelled so it can be read on its own, and the validation loss
# collected by the training loop is shown alongside the training loss.
fig, ax = plt.subplots()
ax.plot(train_loss_hist, label="train")
ax.plot(val_loss_hist, label="validation")
ax.set_xlabel("Epochs trained in this session")
ax.set_ylabel("Loss")
ax.set_title("Loss per epoch")
ax.legend()
plt.show()

# Evaluation helper functions

In [7]:
# Edward: note this makes a new caption as (<S>, 0, ..., 0) shouldn't we want as (<S>, <S>, ..., <S>)?
def create_caption_and_mask(start_token, max_length):
    """Build the initial decoder input for a single sequence.

    Args:
        start_token: Token id written to position 0 of the caption.
        max_length: Number of positions in the caption and in the mask.

    Returns:
        A ``(caption, mask)`` pair of tensors, both shaped ``(1, max_length)``:
        ``caption`` is int64 and zero everywhere except position 0, which holds
        ``start_token``; ``mask`` is boolean and ``True`` everywhere except
        position 0, which is ``False``.
    """
    shape = (1, max_length)

    caption = torch.zeros(shape, dtype=torch.long)
    caption[:, 0] = start_token

    mask = torch.ones(shape, dtype=torch.bool)
    mask[:, 0] = False

    return caption, mask

In [8]:
def make_report(captions):
    """Convert sequences of token ids into lists of token strings.

    Each sequence is cut right after its first ``</s>`` token (the ``</s>``
    itself is kept); sequences without ``</s>`` are decoded in full. Ids are
    looked up in the notebook-level ``ind2word`` mapping using their string
    form.

    Args:
        captions: Iterable of 1-D id arrays supporting elementwise ``==``,
            ``.any()`` and ``.nonzero()`` (numpy arrays in this notebook).

    Returns:
        A list with one list of token strings per input sequence.
    """
    all_reports = []
    for report in captions:
        is_end = report == word2ind["</s>"]
        if is_end.any():
            # Position of the first end-of-report token.
            first_end = is_end.nonzero()[0][0]
            report = report[:first_end + 1]
        all_reports.append([ind2word[str(token_id)] for token_id in report])
    return all_reports

def reports_to_sentence(reports):
    """Decode id sequences with make_report() and join each one with spaces.

    Args:
        reports: Sequences of token ids, passed unchanged to make_report().

    Returns:
        A list with one space-separated string per input sequence.
    """
    sentences = []
    for tokens in make_report(reports):
        sentences.append(' '.join(tokens))
    return sentences

In [9]:
def evaluate(images):
    """Greedily decode one caption per sample of a paired-image batch.

    For every sample the caption starts as ``config.pad_token_id`` followed by
    zeros. At each step the model is run on the caption so far, the arg-max
    token at the current position is written to the next position, and decoding
    stops early once ``</s>`` is produced.

    The notebook-level ``model`` is switched to eval mode and left there.
    Decoding runs under ``torch.no_grad()`` so that no autograd graph is built
    for the (up to ``max_position_embeddings - 1``) forward passes per sample.

    Args:
        images: Indexable pair ``(first, second)``; both entries are sliced
            along their first axis to obtain one sample at a time and moved to
            ``device``.

    Returns:
        A list with one numpy array of token ids per sample, each shaped
        ``(1, config.max_position_embeddings)``. Use make_report() or
        reports_to_sentence() to turn the ids into text.
    """
    all_captions = []
    model.eval()
    end_token = word2ind["</s>"]

    with torch.no_grad():
        for sample_idx in range(len(images[0])):
            # Slicing (rather than indexing) keeps the leading batch axis.
            image1 = images[0][sample_idx:sample_idx + 1].to(device)
            image2 = images[1][sample_idx:sample_idx + 1].to(device)
            # NOTE(review): caption and cap_mask are not moved to `device` here;
            # presumably the model handles that internally — confirm.
            caption, cap_mask = create_caption_and_mask(
                config.pad_token_id, config.max_position_embeddings)

            for step in range(config.max_position_embeddings - 1):
                predictions = model(image1, image2, caption, cap_mask)
                predicted_id = torch.argmax(predictions[:, step, :], dim=-1)

                # Append the chosen token and unmask its position.
                caption[:, step + 1] = predicted_id[0]
                cap_mask[:, step + 1] = False

                if predicted_id[0] == end_token:
                    break

            all_captions.append(caption.numpy())
    return all_captions

In [116]:
import nltk
# Per-sample BLEU-1..4 scores, appended to by the scoring cell below.
# Re-running this cell clears the scores collected so far.
sample_bleu1, sample_bleu2, sample_bleu3, sample_bleu4 = [], [], [], []

In [115]:
# Batch iterator used for the qualitative / BLEU checks below.
# NOTE(review): this walks the *training* loader (randomly sampled), so the
# scores computed from it are on training data; use data_loader_val for
# held-out numbers.
iterations = iter(data_loader_train)

In [117]:
# Pull the next batch. `image` is indexed as a pair (image[0], image[1]) by
# evaluate(); `note` holds the ground-truth token ids decoded further down.
# `note_mask` is not used by the cells below.
image, note, note_mask = next(iterations)

In [118]:
# Greedy-decode a caption for every sample in the batch; the result is a list
# of per-sample id arrays.
report = evaluate(image)

In [119]:
# Stack the per-sample (1, max_len) id arrays and drop the singleton axis so
# that each row of `report_np` is one generated caption.
report_np = np.asarray(report).squeeze(1)

In [120]:
def _bleu_tokens(sentence):
    """Strip sentence markers and '.'/',' characters, then split into tokens."""
    for marker in ("<S>", "<s>", "</s>", ".", ","):
        sentence = sentence.replace(marker, "")
    # Splitting on single spaces and dropping empty strings also absorbs the
    # double spaces left behind by the removals above.
    return [token for token in sentence.split(" ") if token != '']


# Decode the whole batch once instead of re-decoding it for every sample.
truth_sentences = reports_to_sentence(np.asarray(note[:,:]))
generated_sentences = reports_to_sentence(report_np)

# Score each (ground truth, generated) pair with BLEU-1..4 and accumulate the
# results. Iterating over the decoded pairs (rather than config.batch_size)
# keeps the cell working for a batch that is shorter than the configured size.
# NOTE(review): the scores are unsmoothed, so a hypothesis with no matching
# 3-/4-grams gets a near-zero score and nltk emits a warning; consider
# nltk.translate.bleu_score.SmoothingFunction if smoothed scores are wanted.
for truth_sentence, generated_sentence in zip(truth_sentences, generated_sentences):
    truth = _bleu_tokens(truth_sentence)
    generated = _bleu_tokens(generated_sentence)
    bs4 = nltk.translate.bleu_score.sentence_bleu([truth], generated, weights=[0.25, 0.25, 0.25, 0.25])
    bs3 = nltk.translate.bleu_score.sentence_bleu([truth], generated, weights=[1./3., 1./3., 1./3.])
    bs2 = nltk.translate.bleu_score.sentence_bleu([truth], generated, weights=[0.5, 0.5])
    bs1 = nltk.translate.bleu_score.sentence_bleu([truth], generated, weights=[1.])
    sample_bleu4.append(bs4)
    sample_bleu3.append(bs3)
    sample_bleu2.append(bs2)
    sample_bleu1.append(bs1)
    print("Bleu score: ", bs1, bs2, bs3, bs4)

Bleu score:  0.2574819918958893 0.11329171097764855 6.010168564233855e-104 4.377547941707098e-155
Bleu score:  0.26021690648467544 0.10862066616770275 0.0654378833045959 1.3279403427149402e-78
Bleu score:  0.32692307692307687 0.21182963643408087 0.12152840862513396 0.07779637090949697
Bleu score:  0.16318055318565106 0.12639911298258805 0.09372411919266695 1.7305017688546433e-78
Bleu score:  0.2557647735180994 0.09884262861234175 5.674154754883453e-104 4.299120532299612e-155
Bleu score:  0.31281602148329835 0.21197381067415416 0.1667161816753367 0.12684067851857955
Bleu score:  0.4089086287241437 0.2932435749928753 0.2103531207276943 0.1269215692088095
Bleu score:  0.2653707856378107 0.11990727389027557 0.06424663373279794 1.503880657590319e-78
Bleu score:  0.11927262468906635 0.07040973772670185 3.6946708052182934e-104 2.6763768211153003e-155
Bleu score:  0.03693438612077796 0.017097298521688287 8.965692786400731e-105 6.492501115498954e-156
Bleu score:  0.3663840366994972 0.2358231967

The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


In [124]:
# Mean BLEU-4 over every sample scored so far. The denominator is taken from the
# same list as the numerator so the average stays correct even if the four
# score lists ever get out of step.
sum(sample_bleu4) / len(sample_bleu4)

0.03977312173436471

In [106]:
# Number of samples scored so far. This grows by one batch each time the
# scoring cell is re-run without resetting the accumulator lists.
len(sample_bleu1)

128

In [None]:
# Stack the generated id arrays into one row per sample and display the decoded
# captions as plain strings.
report_np = np.asarray(report).squeeze(1)
reports_to_sentence(report_np)

In [40]:
# Advance to the next batch from the same iterator. `report` and `report_np`
# still refer to the previous batch until evaluate() is re-run.
image, note, note_mask = next(iterations)

In [41]:
# Display the ground-truth reports of the current batch as plain strings.
reports_to_sentence(np.asarray(note))

['<S> Subtle patchy opacity along the left heart border on the frontal view , not substantiated on the lateral view , may be due to <unk> scarring or epicardial fat pad , less likely consolidation . <s> No focal consolidation seen elsewhere . <s> There is no pleural effusion or pneumothorax . <s> Cardiac and mediastinal silhouettes are stable . <s> Hilar contours are stable . <s> No overt pulmonary edema is seen . <s> Chronic changes at the right acromioclavicular joint are not well assessed . <s> </s>',
 '<S> AP upright and lateral views of the chest provided . <s> There is no focal consolidation , effusion , or pneumothorax . <s> The cardiomediastinal silhouette is normal . <s> Imaged osseous structures are intact . <s> No free air below the right hemidiaphragm is seen . <s> </s>',
 '<S> No focal consolidation is seen . <s> There is elevation of the mid to posterior left hemidiaphragm with minimal blunting of the left costophrenic angle without a definite pleural effusion seen on the