In [1]:
import numpy as np
import time
import sys
import os
import tqdm

import torch
from torch import nn
import torch.nn.functional as F


from catr.configuration import Config
from catr.models.utils import NestedTensor, nested_tensor_from_tensor_list, get_rank
from catr.models.backbone import build_backbone
from catr.models.transformer import build_transformer
from catr.models.position_encoding import PositionEmbeddingSine
from catr.models.caption import MLP

import json

from dataset.dataset import ImageFeatureDataset
from torch.utils.data import DataLoader
from transformer_ethan import *
sys.path.append(os.path.join(os.path.dirname("__file__"), "catr"))
from engine import train_one_epoch, evaluate

In [2]:
# GloVe embedding matrix; its first dimension is used below as the vocabulary size.
words = np.load("glove_embed.npy")

# Vocabulary lookup tables loaded from JSON (token -> index and index -> token).
with open('word2ind.json') as json_file:
    word2ind = json.load(json_file)
with open('ind2word.json') as json_file:
    ind2word = json.load(json_file)

config = Config()

# Runtime / data-location settings.
config.device = 'cpu' # if running without GPU
config.dir = '../mimic_features'
config.batch_size = 8

# Model dimensions.
config.feature_dim = 1024
config.hidden_dim = 300
config.nheads = 10

# Vocabulary-dependent settings.
config.vocab_size = len(words)
# NOTE: the decoding helper further down passes this same id as the caption's
# start token.
config.pad_token_id = word2ind["<S>"]
# Written straight into the instance dict, exactly as the original cell did.
vars(config)["pre_embed"] = torch.from_numpy(words)

In [3]:
# Resolve the target device up front; it only depends on the config.
device = torch.device(config.device)

# `main` is not defined in this notebook -- presumably it is provided by the
# `transformer_ethan` star import (TODO confirm). It returns the captioning
# model together with its loss criterion.
model, criterion = main(config)

# Cast parameters to float32 and place the model on the chosen device.
# The trailing expression is kept last so the cell still displays the model.
model = model.float()
model.to(device)

Initializing Device: cpu
Number of params: 33908144


Xray_Captioner(
  (input_proj): Conv2d(1024, 300, kernel_size=(1, 1), stride=(1, 1))
  (position_embedding): PositionEmbeddingSine()
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=300, out_features=300, bias=True)
          )
          (linear1): Linear(in_features=300, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=300, bias=True)
          (norm1): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=3

In [4]:
# Two optimizer parameter groups, split on whether the parameter name contains
# "backbone". Frozen parameters (requires_grad == False) are left out of both.
param_dicts = [
    # Non-backbone parameters: no explicit "lr", so the optimizer default applies.
    {
        "params": [
            weight
            for name, weight in model.named_parameters()
            if weight.requires_grad and "backbone" not in name
        ],
    },
    # Backbone parameters: trained with their own learning rate.
    {
        "params": [
            weight
            for name, weight in model.named_parameters()
            if weight.requires_grad and "backbone" in name
        ],
        "lr": config.lr_backbone,
    },
]

In [5]:
# AdamW over the two parameter groups; groups without their own "lr" use config.lr.
optimizer = torch.optim.AdamW(
    params=param_dicts,
    lr=config.lr,
    weight_decay=config.weight_decay,
)

# Step-wise learning-rate decay every `config.lr_drop` scheduler steps
# (the training loop steps the scheduler once per epoch).
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=config.lr_drop,
)

In [6]:
# Train / validation splits of the pre-extracted image-feature dataset.
dataset_train = ImageFeatureDataset(config, mode='train')
dataset_val = ImageFeatureDataset(config, mode='val')

# Training batches are shuffled and incomplete batches dropped;
# validation is read sequentially and keeps the final partial batch.
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
batch_sampler_train = torch.utils.data.BatchSampler(
    sampler_train,
    batch_size=config.batch_size,
    drop_last=True,
)

data_loader_train = DataLoader(
    dataset_train,
    batch_sampler=batch_sampler_train,
    num_workers=config.num_workers,
)
data_loader_val = DataLoader(
    dataset_val,
    batch_size=config.batch_size,
    sampler=sampler_val,
    drop_last=False,
    num_workers=config.num_workers,
)

print(f"Train: {len(dataset_train)}")
print(f"Val: {len(dataset_val)}")

Train: 128
Val: 32


In [7]:
# Resume from an existing checkpoint, if one is present on disk.
if os.path.exists(config.checkpoint):
    print("Loading Checkpoint...")
    checkpoint = torch.load(config.checkpoint, map_location='cpu')

    # Restore each stateful component from its entry in the checkpoint,
    # in the same order as before: model, optimizer, LR scheduler.
    for component, state_key in (
        (model, 'model'),
        (optimizer, 'optimizer'),
        (lr_scheduler, 'lr_scheduler'),
    ):
        component.load_state_dict(checkpoint[state_key])

    # The stored epoch is the last one completed, so continue with the next.
    config.start_epoch = checkpoint['epoch'] + 1

print("Start Training..")

Start Training..


In [8]:
# Bind the engine's loss-evaluation routine under a private alias. A later cell
# in this notebook defines its own `evaluate(images)` decoding helper, which
# rebinds the bare name `evaluate`; using the alias keeps this cell re-runnable
# regardless of which cells were executed before it.
from engine import evaluate as evaluate_loss

# Per-epoch loss histories, kept for later inspection / plotting.
train_loss_hist = []
val_loss_hist = []

for epoch in range(config.start_epoch, config.epochs):
    print(f"Epoch: {epoch}")

    # One full pass over the training loader.
    epoch_loss = train_one_epoch(
        model, criterion, data_loader_train, optimizer, device, epoch, config.clip_max_norm)
    train_loss_hist.append(epoch_loss)
    lr_scheduler.step()
    print(f"Training Loss: {epoch_loss}")

    # Checkpoint after every epoch so training can be resumed; the stored
    # 'epoch' is the one just completed (the loader cell resumes at epoch + 1).
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
    }, config.checkpoint)

    # Validation loss on the held-out split.
    validation_loss = evaluate_loss(model, criterion, data_loader_val, device)
    val_loss_hist.append(validation_loss)
    print(f"Validation Loss: {validation_loss}")

    print()

  0%|          | 0/16 [00:00<?, ?it/s]

Epoch: 0


100%|██████████| 16/16 [01:00<00:00,  3.78s/it]


Training Loss: 9.027783334255219


100%|██████████| 4/4 [00:06<00:00,  1.61s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 7.783172607421875

Epoch: 1


100%|██████████| 16/16 [01:01<00:00,  3.82s/it]


Training Loss: 6.678520172834396


100%|██████████| 4/4 [00:06<00:00,  1.62s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 4.76054060459137

Epoch: 2


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Training Loss: 4.878822475671768


100%|██████████| 4/4 [00:08<00:00,  2.02s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 4.130942761898041

Epoch: 3


100%|██████████| 16/16 [01:13<00:00,  4.60s/it]


Training Loss: 4.056869998574257


100%|██████████| 4/4 [00:07<00:00,  1.90s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 3.4725332856178284

Epoch: 4


100%|██████████| 16/16 [01:09<00:00,  4.37s/it]


Training Loss: 3.524763137102127


100%|██████████| 4/4 [00:07<00:00,  1.87s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 3.2662988901138306

Epoch: 5


100%|██████████| 16/16 [01:10<00:00,  4.38s/it]


Training Loss: 3.2473490238189697


100%|██████████| 4/4 [00:08<00:00,  2.01s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.9581229090690613

Epoch: 6


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Training Loss: 2.985088735818863


100%|██████████| 4/4 [00:07<00:00,  1.94s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.7227672338485718

Epoch: 7


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Training Loss: 2.8182372748851776


100%|██████████| 4/4 [00:07<00:00,  1.86s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.6525256037712097

Epoch: 8


100%|██████████| 16/16 [01:09<00:00,  4.32s/it]


Training Loss: 2.768174171447754


100%|██████████| 4/4 [00:07<00:00,  1.83s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.656706690788269

Epoch: 9


100%|██████████| 16/16 [01:09<00:00,  4.35s/it]


Training Loss: 2.753058820962906


100%|██████████| 4/4 [00:07<00:00,  1.94s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.6459078192710876

Epoch: 10


100%|██████████| 16/16 [01:09<00:00,  4.36s/it]


Training Loss: 2.7338963747024536


100%|██████████| 4/4 [00:07<00:00,  1.86s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.629267603158951

Epoch: 11


100%|██████████| 16/16 [01:09<00:00,  4.32s/it]


Training Loss: 2.7085530534386635


100%|██████████| 4/4 [00:07<00:00,  1.94s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.6252678632736206

Epoch: 12


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Training Loss: 2.667819380760193


100%|██████████| 4/4 [00:07<00:00,  1.97s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.553814172744751

Epoch: 13


100%|██████████| 16/16 [01:08<00:00,  4.28s/it]


Training Loss: 2.593202792108059


100%|██████████| 4/4 [00:07<00:00,  1.83s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.500924229621887

Epoch: 14


100%|██████████| 16/16 [01:09<00:00,  4.34s/it]


Training Loss: 2.539571166038513


100%|██████████| 4/4 [00:07<00:00,  1.88s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.4681789577007294

Epoch: 15


100%|██████████| 16/16 [01:08<00:00,  4.27s/it]


Training Loss: 2.5090907737612724


100%|██████████| 4/4 [00:07<00:00,  1.84s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.454487979412079

Epoch: 16


100%|██████████| 16/16 [01:08<00:00,  4.31s/it]


Training Loss: 2.4758152812719345


100%|██████████| 4/4 [00:07<00:00,  1.88s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.4394366443157196

Epoch: 17


100%|██████████| 16/16 [01:08<00:00,  4.29s/it]


Training Loss: 2.436721995472908


100%|██████████| 4/4 [00:07<00:00,  1.79s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.4256466925144196

Epoch: 18


100%|██████████| 16/16 [01:09<00:00,  4.32s/it]


Training Loss: 2.3883387595415115


100%|██████████| 4/4 [00:07<00:00,  1.83s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.404610514640808

Epoch: 19


100%|██████████| 16/16 [01:07<00:00,  4.23s/it]


Training Loss: 2.349529005587101


100%|██████████| 4/4 [00:07<00:00,  1.91s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.3863262832164764

Epoch: 20


100%|██████████| 16/16 [01:07<00:00,  4.19s/it]


Training Loss: 2.2985946238040924


100%|██████████| 4/4 [00:07<00:00,  2.00s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.378326654434204

Epoch: 21


100%|██████████| 16/16 [01:05<00:00,  4.12s/it]


Training Loss: 2.294106349349022


100%|██████████| 4/4 [00:07<00:00,  1.80s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.3755583465099335

Epoch: 22


100%|██████████| 16/16 [01:07<00:00,  4.25s/it]


Training Loss: 2.283103197813034


100%|██████████| 4/4 [00:07<00:00,  1.85s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.374040812253952

Epoch: 23


100%|██████████| 16/16 [01:07<00:00,  4.21s/it]


Training Loss: 2.2768838480114937


100%|██████████| 4/4 [00:07<00:00,  1.79s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.3758350908756256

Epoch: 24


100%|██████████| 16/16 [01:07<00:00,  4.20s/it]


Training Loss: 2.27351263910532


100%|██████████| 4/4 [00:07<00:00,  1.83s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.3732486963272095

Epoch: 25


100%|██████████| 16/16 [01:06<00:00,  4.15s/it]


Training Loss: 2.268472835421562


100%|██████████| 4/4 [00:07<00:00,  1.78s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.373862475156784

Epoch: 26


100%|██████████| 16/16 [01:08<00:00,  4.27s/it]


Training Loss: 2.2637667655944824


100%|██████████| 4/4 [00:07<00:00,  1.89s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.3708381056785583

Epoch: 27


100%|██████████| 16/16 [01:06<00:00,  4.17s/it]


Training Loss: 2.2595183700323105


100%|██████████| 4/4 [00:07<00:00,  1.79s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.373511165380478

Epoch: 28


100%|██████████| 16/16 [01:07<00:00,  4.24s/it]


Training Loss: 2.256667785346508


100%|██████████| 4/4 [00:07<00:00,  1.86s/it]
  0%|          | 0/16 [00:00<?, ?it/s]

Validation Loss: 2.3728419840335846

Epoch: 29


100%|██████████| 16/16 [01:07<00:00,  4.25s/it]


Training Loss: 2.2482321113348007


100%|██████████| 4/4 [00:07<00:00,  1.89s/it]

Validation Loss: 2.3705135583877563






# Evaluation helper functions

In [9]:
# Edward: note this creates a new caption of the form (<S>, 0, ..., 0); shouldn't it be (<S>, <S>, ..., <S>) instead?
def create_caption_and_mask(start_token, max_length):
    """Build the initial caption buffer and its mask for step-wise decoding.

    Args:
        start_token: token id written into position 0 of the caption.
        max_length: number of token positions in the buffer.

    Returns:
        A ``(caption, mask)`` pair, both of shape ``(1, max_length)``:
        ``caption`` is a long tensor holding ``start_token`` at position 0 and
        0 everywhere else; ``mask`` is a bool tensor that is ``False`` at
        position 0 and ``True`` everywhere else.
    """
    # Start fully masked, then derive an all-zero caption of the same shape.
    mask = torch.ones((1, max_length), dtype=torch.bool)
    caption = torch.zeros_like(mask, dtype=torch.long)

    # Only the first position is populated and unmasked.
    mask[:, 0] = False
    caption[:, 0] = start_token

    return caption, mask

In [72]:
def make_report(captions):
    """Convert sequences of token ids into lists of word strings.

    Each sequence is cut just after its first ``"</s>"`` token (the end token
    itself is kept); sequences without an end token are decoded in full.
    Ids are looked up in the global ``ind2word`` table, whose keys are the
    string form of the integer id.

    Args:
        captions: iterable of 1-D id sequences. Anything ``np.asarray`` can
            convert is accepted (NumPy rows, plain lists, CPU tensors).

    Returns:
        A list with one list of word strings per input sequence.

    Raises:
        KeyError: if an id has no entry in ``ind2word`` or ``"</s>"`` is
            missing from ``word2ind``.
    """
    end_token = word2ind["</s>"]
    all_reports = []
    for report in captions:
        token_ids = np.asarray(report)
        # Locate the end token once instead of building the comparison twice.
        end_positions = np.flatnonzero(token_ids == end_token)
        if end_positions.size:
            token_ids = token_ids[:end_positions[0] + 1]
        # int() normalises NumPy scalars so the key is always a plain "123".
        all_reports.append([ind2word[str(int(token))] for token in token_ids])
    return all_reports

def reports_to_sentence(reports):
    """Decode id sequences with ``make_report`` and join each into one string.

    Args:
        reports: iterable of token-id sequences, passed unchanged to
            ``make_report``.

    Returns:
        A list of strings, one per sequence, with words separated by a
        single space.
    """
    sentences = []
    for words_in_report in make_report(reports):
        sentences.append(' '.join(words_in_report))
    return sentences

In [67]:
def evaluate(images):
    """Greedily decode one caption per image using the global ``model``.

    NOTE: defining this function rebinds the notebook-level name ``evaluate``,
    which the import cell binds to ``engine.evaluate``. Code that needs the
    engine version after this cell has run must reference it under a
    different name.

    Args:
        images: batch of inputs sliceable along its first axis; each
            ``images[i:i+1]`` slice is passed to ``model`` unchanged.

    Returns:
        A list with one NumPy array per image, each of shape
        ``(1, config.max_position_embeddings)``. Position 0 holds
        ``config.pad_token_id`` (used here as the start token); decoding stops
        after the first ``"</s>"`` prediction and untouched positions stay 0.
    """
    all_captions = []
    model.eval()
    end_token = word2ind["</s>"]
    max_length = config.max_position_embeddings

    # Inference only: without no_grad every decoding step would record an
    # autograd graph, wasting memory and time.
    with torch.no_grad():
        for image_index in range(len(images)):
            image = images[image_index:image_index + 1]
            caption, cap_mask = create_caption_and_mask(
                config.pad_token_id, max_length)

            # Distinct loop variable: the original reused `i` for both loops.
            for step in range(max_length - 1):
                predictions = model(image, caption, cap_mask)
                # Take the prediction at the current position.
                predicted_id = torch.argmax(predictions[:, step, :], dim=-1)

                # Append the chosen token and unmask its position.
                caption[:, step + 1] = predicted_id[0]
                cap_mask[:, step + 1] = False

                if predicted_id[0] == end_token:
                    break

            all_captions.append(caption.numpy())
    return all_captions

In [12]:
# Pull one batch from the training loader for a qualitative check. The loader
# uses a RandomSampler, so the batch differs between runs.
image, image_mask, note, note_mask = next(iter(data_loader_train))

In [68]:
# Greedy-decode captions for the sampled batch (this is the notebook's own
# `evaluate(images)` helper, not the engine's loss evaluation).
report = evaluate(image)

In [70]:
# Second decode of the same batch -- presumably a determinism check; `report2`
# is not used in the visible cells (TODO confirm it is still needed).
report2 = evaluate(image)

In [73]:
# `report` is a list of (1, max_len) arrays; stack them and drop the singleton
# axis so that each row is one caption, then render the rows as sentences.
report_np = np.squeeze(np.asarray(report), axis=1)
reports_to_sentence(report_np)

['<S> The and lateral views of the chest . <s> The and is is is is . <s> The is the the . <s> The is is is is is is are . <s> The is is is is is is . <s> The is is the is is . <s> </s>',
 '<S> The and lateral views of the chest . <s> The . <s> The is is . <s> The is are is . <s> No is . <s> The is are is . <s> No " " " " " . <s> </s>',
 '<S> The and lateral views of the chest . <s> The . <s> The is is . <s> The is are is is is are is . <s> The . <s> No . <s> No . <s> </s>',
 '<S> The and lateral views of the chest . <s> The . <s> The is is . <s> The is are is is is are is . <s> The is is are . <s> </s>',
 '<S> The and lateral the of the chest . <s> The . <s> The is is the . <s> The the the . <s> There is is the is is the the the . <s> There is is is is is . <s> The is is the is is the the . <s> There is is is is . <s> The . <s> </s>',
 '<S> The and lateral views of the chest . <s> The and the is is is . <s> The is the the . <s> The is is the is is the the the . <s> There is is the is i

In [60]:
# Render the ground-truth notes of the same batch for side-by-side comparison.
# NOTE(review): the `(array([..]),)` lines in the recorded output look like
# debug prints from an earlier `make_report` definition -- re-run to refresh.
reports_to_sentence(np.asarray(note))

(array([63]),)
(array([35]),)
(array([45]),)
(array([47]),)
(array([61]),)
(array([89]),)
(array([29]),)
(array([53]),)


['<S> AS COMPARED TO THE PRIOR EXAMINATION DATED ___ , THERE HAS BEEN NO SIGNIFICANT INTERVAL CHANGE . <S> THERE IS NO EVIDENCE OF FOCAL CONSOLIDATION , PLEURAL EFFUSION , PNEUMOTHORAX , OR FRANK PULMONARY EDEMA . <S> THE CARDIOMEDIASTINAL SILHOUETTE IS WITHIN NORMAL LIMITS . <S> THERE IS PERSISTENT THORACIC KYPHOSIS WITH MILD WEDGING OF A MID THORACIC VERTEBRAL BODY . <S> </S>',
 '<S> PA AND LATERAL CHEST RADIOGRAPHS AGAIN DEMONSTRATE SEVERE HYPERINFLATION AND DIFFUSE BRONCHIECTASIS . <S> THERE IS NO FOCAL CONSOLIDATION , PLEURAL EFFUSION , OR PNEUMOTHORAX . <S> THE CARDIOMEDIASTINAL SILHOUETTE IS STABLE . <S> </S>',
 '<S> RETICULAR OPACITIES AT THE LUNG BASES BILATERALLY LIKELY REPRESENT MILD ATELECTASIS . <S> THERE IS MILD <UNK> <UNK> SCARRING . <S> NO EVIDENCE OF PNEUMONIA , PLEURAL EFFUSION , OR PNEUMOTHORAX . <S> HEART SIZE AND MEDIASTINAL CONTOURS ARE WITHIN NORMAL LIMITS . <S> </S>',
 '<S> THE LUNGS ARE CLEAR OF FOCAL CONSOLIDATION , PLEURAL EFFUSION OR PNEUMOTHORAX . <S> THE H

In [58]:
# Inspect the shape of the token-id batch.
note.shape

torch.Size([8, 129])

In [23]:
# Decode only the first image of the batch (two leading singleton axes added).
# NOTE(review): the recorded output is a list of sentences, whereas the current
# `evaluate` returns id arrays -- this output appears to come from an earlier
# version of the helper; re-run or remove this cell.
evaluate(image[0].unsqueeze(0).unsqueeze(0))

['<S> THE AND LATERAL VIEWS OF THE CHEST . <S> THE AND IS IS IS IS . <S> THE IS THE THE . <S> THE IS IS IS IS IS IS ARE . <S> THE IS IS IS IS IS IS . <S> THE IS IS THE IS IS . <S> </S> " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " "']

In [64]:
# Shape check: the image batch with one extra leading axis added.
image.unsqueeze(0).shape

torch.Size([1, 8, 1024, 8, 8])