# ChestX-Ray data

In [7]:
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize

class WordTokenizer:
    """Corpus vocabulary builder backed by an NLTK tokenizer.

    Every text is tokenized and wrapped in the start/end-of-sequence markers;
    the distinct tokens plus the four special tokens form the vocabulary.

    Args:
        texts: iterable of strings used to build the vocabulary.
        tokenizer: ``'word'`` (``nltk.word_tokenize``), ``'sent'``
            (``nltk.sent_tokenize``) or any callable ``str -> list[str]``.
        vocab_size: currently unused -- the vocabulary is NOT truncated
            (TODO: cap it, e.g. to the most frequent tokens).
        sos_token, eos_token, pad_token, unk_token: special token strings.

    Attributes:
        vocab: set of all tokens.
        word2idx / idx2word: token <-> index maps. Indices are deterministic:
            special tokens first in the order pad, sos, eos, unk, then the
            corpus tokens in sorted order.
        vocab_size: number of distinct tokens.
        pad_token / unk_token: the *indices* (not the strings) of the
            pad / unk tokens.

    Raises:
        ValueError: if ``tokenizer`` is neither ``'word'``, ``'sent'`` nor callable.
    """

    def __init__(self, texts, tokenizer, vocab_size=400, sos_token='<SOS>', eos_token='<EOS>', pad_token='<PAD>', unk_token='<UNK>'):
        if callable(tokenizer):
            self.tokenizer = tokenizer
        elif tokenizer == 'word':
            self.tokenizer = word_tokenize
        elif tokenizer == 'sent':
            self.tokenizer = sent_tokenize
        else:
            # Previously any unknown value silently fell back to sentence tokenization.
            raise ValueError(f"tokenizer must be 'word', 'sent' or a callable, got {tokenizer!r}")
        # Remembered so tokenize() wraps texts with the same markers the vocab was built with.
        self._sos_token = sos_token
        self._eos_token = eos_token
        self.vocab, self.word2idx, self.idx2word = self.__build_vocab__(texts, sos_token, eos_token, pad_token, unk_token)
        self.vocab_size = len(self.vocab)
        # NOTE: despite the attribute names, these hold the *indices* of the pad/unk tokens.
        self.pad_token = self.word2idx[pad_token]
        self.unk_token = self.word2idx[unk_token]

    def tokenize(self, text, sos_token=None, eos_token=None):
        """Split ``text`` and wrap the tokens in start/end markers.

        ``sos_token`` / ``eos_token`` default to the markers passed to the
        constructor (``'<SOS>'`` / ``'<EOS>'`` unless overridden).
        """
        sos = self._sos_token if sos_token is None else sos_token
        eos = self._eos_token if eos_token is None else eos_token
        return [sos] + list(self.tokenizer(text)) + [eos]

    def __build_vocab__(self, texts, sos_token, eos_token, pad_token, unk_token):
        """Return ``(vocab, word2idx, idx2word)`` built from ``texts``.

        Indices are NOT taken from ``set`` iteration order: for strings that order
        changes between interpreter sessions (hash randomization), which would
        silently remap token ids between runs. Instead the special tokens come
        first and the remaining corpus tokens follow in sorted order.
        """
        # dict.fromkeys de-duplicates while preserving order (in case specials coincide).
        ordered = list(dict.fromkeys([pad_token, sos_token, eos_token, unk_token]))
        specials = set(ordered)
        corpus_tokens = {
            token
            for text in texts
            for token in self.tokenize(text, sos_token, eos_token)
        }
        ordered.extend(sorted(corpus_tokens - specials))

        vocab = set(ordered)
        word2idx = {word: idx for idx, word in enumerate(ordered)}
        idx2word = dict(enumerate(ordered))
        return vocab, word2idx, idx2word

In [43]:
# Sentence-level vocabulary: each distinct sentence becomes a single token.
# NOTE(review): `reports` is built in the data-loading cell further down; on a fresh
# kernel (Restart & Run All) this cell fails with NameError until that cell has run.
tokenizer = WordTokenizer(texts=reports, tokenizer='sent')
sentence_vocab_size = tokenizer.vocab_size
print(sentence_vocab_size)

5521


In [19]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import models.training as tr
import models.reportnet as M
import data.preprocessing as pr
from torchvision import transforms

batch_size = 2

# Load data: one entry per unique uid taken from the projections index.
uids = np.unique(pr.projections.index)

# Image preprocessing: convert to tensor, then resize to a fixed 2048x2048.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((2048, 2048), antialias=False)
])

train_data, train_loader, val_data, val_loader, test_data, test_loader = pr.create_dataloaders(
    uids, pr.IMAGES_PATH, batch_size=batch_size, transform=transform
)

# Vocabulary: one findings text per uid; non-string entries (e.g. NaN) get a placeholder.
# Fail early with an actionable message -- the last run died here with a bare KeyError: 'findings'.
if 'findings' not in pr.reports.columns:
    raise KeyError(f"pr.reports has no 'findings' column; available columns: {list(pr.reports.columns)}")
reports = []
for uid in uids:
    findings = pr.reports.loc[uid, 'findings']  # single lookup per uid
    reports.append(findings if isinstance(findings, str) else 'No findings.')
tokenizer = WordTokenizer(reports, 'word')
# NOTE(review): later cells use `vocab_size`, which is not assigned in any cell shown here.

# Candidate head counts (divisors of the vocabulary size). Computed inline so this cell
# does not depend on `find_divisors`, which is defined in a later cell.
[n for n in range(1, tokenizer.vocab_size + 1) if tokenizer.vocab_size % n == 0]

[nltk_data] Downloading package punkt to /home/mdelucasg/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


KeyError: 'findings'

In [2]:
# Helper for choosing `nhead`: setting it to a value that does not evenly divide the
# relevant model size raised errors in earlier runs.
# TODO: Think a way of clamping the size of the vocabulary to a multiple of the number of heads
def find_divisors(n):
    """Return every positive divisor of ``n`` in ascending order (empty list if ``n`` < 1)."""
    divisors = []
    for candidate in range(1, n + 1):
        if not n % candidate:
            divisors.append(candidate)
    return divisors

# Candidate values for nhead given the vocabulary size.
# NOTE(review): `vocab_size` is not assigned in any cell shown here -- it comes from kernel state.
head_candidates = find_divisors(vocab_size)
print(f"N-heads might be: {head_candidates}")

N-heads might be: [1, 2, 4, 113, 226, 452]


# Decoder-only Transformer

In [3]:
# Hyperparameters for a quick smoke test of the decoder on one real batch.
d_model = 128    # embedding dimension: number of features per token
nhead = 2        # presumably must evenly divide d_model (attention-head split) -- confirm
num_layers = 6
dropout = 0.1
memory_len = 50  # sequence length of the stand-in memory

# Stand-in for the image features the CNN encoder should eventually provide.
# Shape is (seq length, batch size, d_model == emb_dim). A local fixed-seed generator keeps
# the cell reproducible without reseeding the global RNG used by the rest of the notebook.
memory = torch.rand(memory_len, batch_size, d_model, generator=torch.Generator().manual_seed(0))

# Target tensor of the first training batch.
tgt = next(iter(train_loader))[1]

decoder_generator = M.DecoderGenerator(vocab_size, d_model, nhead, num_layers, dropout)
# NOTE(review): 'hola' / 'adios' look like placeholders for mask arguments of
# DecoderGenerator.forward -- confirm the expected arguments and pass real masks.
out = decoder_generator(tgt, 'hola', memory, 'adios')
out.shape

  return self.softmax(out)


torch.Size([60, 2, 452])

# Training

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import models.training as tr
import models.reportnet as M

# Prefer the GPU when one is available.
# NOTE(review): `device` is not passed to the model or to tr.fit below -- confirm where
# device placement happens.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model: the width is tied to the vocabulary size (dim_model = 2 * vocab_size), so the head
# count presumably has to divide 2 * vocab_size.
model_kwargs = dict(
    num_tokens=vocab_size,
    dim_model=2 * vocab_size,
    num_heads=2,
    num_encoder_layers=3,
    num_decoder_layers=3,
    dropout_p=0.3,
)
model = M.Transformer(**model_kwargs)

# Training: Adam + cross-entropy that ignores padding positions of the targets.
optimizer = optim.Adam(model.parameters(), lr=0.01)
pad_idx = train_data.word2idx[pr.PAD_TOKEN]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
# NOTE(review): the last run failed inside tr.fit with "The shape of the 2D attn_mask is
# torch.Size([1, 1]), but should be (60, 60)". A 1x1 mask suggests it was built from a
# size-1 axis rather than the sequence length (60) -- check how the target mask is generated.
num_epochs = 2
trainer = tr.fit(model, optimizer, criterion, train_loader, val_loader, num_epochs)

Epoch 1
torch.Size([1, 60, 904])
torch.Size([1, 60, 904])
torch.Size([60, 1, 904])
torch.Size([60, 2, 904])


RuntimeError: The shape of the 2D attn_mask is torch.Size([1, 1]), but should be (60, 60).

# Predict