In [1]:
import os
import random
import sys
import time

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from catr.configuration import Config
from catr.models.utils import NestedTensor, nested_tensor_from_tensor_list, get_rank
from catr.models.backbone import build_backbone
from catr.models.transformer import build_transformer
from catr.models.position_encoding import PositionEmbeddingSine
from catr.models.caption import MLP


### First, define our custom captioning model class

In [2]:
class Xray_Captioner(nn.Module):
    """Transformer captioner that works directly on precomputed image features.

    The model does not run a backbone itself. ``forward`` receives feature
    maps and processes them in four steps:

    1. Project them to ``hidden_dim`` channels with a 1x1 convolution.
    2. Build a sine positional embedding.
    3. Run the transformer.
    4. Map the transformer output to per-position vocabulary logits with an MLP.

    Args:
        transformer: module called as
            ``transformer(src, mask, pos, target, target_mask)``.
        feature_dim: number of channels in the incoming feature maps.
        hidden_dim: channel width the features are projected to before they
            are handed to the transformer.
        vocab_size: number of output logits per caption position.
        mlp_hidden_dim: width of the output MLP's hidden layers (default 512).
        mlp_num_layers: number of layers in the output MLP (default 3).
    """

    def __init__(self, transformer, feature_dim, hidden_dim, vocab_size,
                 mlp_hidden_dim=512, mlp_num_layers=3):
        super().__init__()
        # 1x1 conv: changes only the channel count (feature_dim -> hidden_dim)
        # and leaves the spatial grid untouched.
        self.input_proj = nn.Conv2d(feature_dim, hidden_dim, kernel_size=1)
        # hidden_dim // 2 features per spatial axis. Presumably chosen so the
        # combined y/x embedding has hidden_dim channels; confirm against
        # PositionEmbeddingSine.
        self.position_embedding = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
        self.transformer = transformer
        self.mlp = MLP(hidden_dim, mlp_hidden_dim, vocab_size, mlp_num_layers)

    def forward(self, img_features, target, target_mask):
        """Predict vocabulary logits for every caption position.

        Args:
            img_features: feature maps of shape (batch, feature_dim, h, w).
            target: caption token ids, passed through to the transformer.
            target_mask: caption mask, passed through to the transformer.

        Returns:
            Raw, un-normalized logits from the MLP. No softmax is applied, so
            they can be passed straight to ``nn.CrossEntropyLoss``. The first
            two axes of the transformer output are swapped before the MLP;
            presumably the result is (batch, seq, vocab_size). Confirm against
            the transformer.
        """
        batch, _, height, width = img_features.shape
        # All-False image mask: no spatial position is treated as padding, so
        # the whole feature map is used. The mask exists because the upstream
        # code was written to accept variable image sizes; every feature map
        # here has the same size.
        mask = torch.zeros((batch, height, width), dtype=torch.bool,
                           device=img_features.device)

        # Project features to hidden_dim and build their positional embedding.
        img_embeds = self.input_proj(img_features)
        pos = self.position_embedding(NestedTensor(img_embeds, mask))

        # transformer -> MLP. The output is logits; softmax is left to the loss.
        hs = self.transformer(img_embeds, mask, pos, target, target_mask)
        return self.mlp(hs.permute(1, 0, 2))
    

def build_model(config):
    """Create the captioning model and its training loss from ``config``.

    Uses ``config.feature_dim``, ``config.hidden_dim`` and
    ``config.vocab_size``, plus whatever ``build_transformer`` reads.

    Returns:
        A ``(model, criterion)`` pair: an ``Xray_Captioner`` and a
        ``CrossEntropyLoss`` instance.
    """
    captioner = Xray_Captioner(
        build_transformer(config),
        config.feature_dim,
        config.hidden_dim,
        config.vocab_size,
    )
    loss_fn = torch.nn.CrossEntropyLoss()
    return captioner, loss_fn

### Build the model the same way we will during training/inference

In [3]:
def main(config):
    """Build the captioning model from ``config``, as done for training/inference.

    Seeds the random number generators, builds the model with
    ``build_model``, and moves it to ``config.device``. As a sanity check it
    prints the device and the number of trainable parameters.

    Args:
        config: configuration object. This function reads ``config.device``
            and ``config.seed``. ``build_model`` reads further fields
            (``feature_dim``, ``hidden_dim``, ``vocab_size``, ...).

    Returns:
        The model, placed on ``config.device``. The loss built by
        ``build_model`` is not returned.
    """
    # Initialize the device we're running on.
    device = torch.device(config.device)
    print(f'Initializing Device: {device}')

    # Seed every RNG in use for deterministic behavior. get_rank() is
    # presumably the distributed process rank (0 when not distributed), which
    # would give each process its own seed.
    seed = config.seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Create the model; the criterion is not needed here.
    model, _ = build_model(config)
    model.to(device)

    # Sanity check: count trainable parameters.
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of params: {n_parameters}")

    return model

### Create a model

In [4]:
# Start from the default configuration object.
config = Config()

# Config provides no default feature_dim, so set it to the channel count of
# the image features the model will receive.
config.feature_dim = 1024

# Run on the CPU; switch to a GPU device string if one is available.
config.device = 'cpu'

# Build the model. This prints the device and the trainable-parameter count.
xray_model = main(config)

Initializing Device: cpu
Number of params: 41525306


### Run random inputs through the model

In [5]:
# Helper function to create the initial caption and its mask
def create_caption_and_mask(start_token, max_length, batch_size=1):
    """Build starting caption(s) and the matching caption mask for decoding.

    Each caption holds ``start_token`` at position 0 and zeros elsewhere.
    The mask is ``False`` at position 0 and ``True`` everywhere else.
    Presumably ``True`` marks not-yet-generated (padding) positions, in the
    style of a PyTorch key-padding mask; confirm against the transformer.

    Args:
        start_token: token id placed at position 0 of every caption.
        max_length: number of positions in each caption.
        batch_size: number of identical captions to build (default 1).

    Returns:
        A ``(caption, mask)`` pair: a ``torch.long`` tensor and a
        ``torch.bool`` tensor, both of shape ``(batch_size, max_length)``.

    Raises:
        ValueError: if ``max_length`` is less than 1, because there would be
            no slot for the start token.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be at least 1, got {max_length}")

    caption_template = torch.zeros((batch_size, max_length), dtype=torch.long)
    mask_template = torch.ones((batch_size, max_length), dtype=torch.bool)

    caption_template[:, 0] = start_token
    mask_template[:, 0] = False

    return caption_template, mask_template


# Create the starter caption and caption mask; the caption length is
# config.max_position_embeddings.
# NOTE(review): 102 is an arbitrary token id for this smoke test. The
# vocabulary size printed below is 30522, which matches BERT's uncased
# vocabulary, whose start ([CLS]) id is 101. Confirm which id the real
# pipeline uses.
start_token = 102
caption, cap_mask = create_caption_and_mask(
    start_token, config.max_position_embeddings)

# Random stand-in for extracted image features. The channel count is read
# from config.feature_dim so it always matches the model's input projection.
FEATURE_MAP_SIZE = 8  # spatial height/width of the fake feature grid
fake_features = torch.randn(1, config.feature_dim, FEATURE_MAP_SIZE, FEATURE_MAP_SIZE)

In [6]:
# Get some predictions in inference mode.
# eval() switches train-only behavior (e.g. dropout, if the transformer or MLP
# use any) to inference mode, so the same inputs give the same logits.
# no_grad() skips building the autograd graph, since nothing is trained here.
xray_model.eval()
with torch.no_grad():
    predictions = xray_model(fake_features, caption, cap_mask)

# Logits for the first caption position. Presumably predictions is
# (batch, seq, vocab_size); argmax over the last axis gives a token id.
word_preds = predictions[:, 0, :]
word = torch.argmax(word_preds, dim=-1)
print("Vocab size:", config.vocab_size)
print("My word:", word.item())

Vocab size: 30522
My word: 22644
