# ViT

Here, we propose a combination of the *Google ViT* for the image feature extraction (i.e. encoder) and the *Bert* for the text generation (i.e. decoder).

We use a chekpoint from [HuggingFace](https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder) since we do not have enough resources for training such huge models from scratch (either have enough data for doing so).

In [1]:
import torch
import numpy as np
import pandas as pd
import data.preprocessing as pr
from torchvision import transforms
from transformers import ViTImageProcessor, BertTokenizer, VisionEncoderDecoderModel

# Get the data
uids = np.unique(pr.projections.index)[:300]

# Image preprocessing 
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=False)
])

train_data, train_loader, val_data, val_loader, test_data, test_loader = pr.create_dataloaders(uids, pr.IMAGES_PATH, batch_size=3, transform=transform)

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /home/mpizarro/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Training

In [2]:
# ViT training
# Instance model and optimizer
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", do_rescale=False, do_normalize=True)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "bert-base-uncased"
)
model.config.pad_token_id = tokenizer.pad_token_id
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.add_cross_attention = True

# Hyperparameters
n_epochs = 100
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

best_loss = np.inf
for epoch in range(n_epochs):
    t_loss = 0
    model.train()
    n_epochs_to_stop = 5
    for batch in train_data:
        imgs, reports = batch[0], batch[1]
        pixel_values = image_processor(imgs, return_tensors="pt").pixel_values
        labels, _, att = tokenizer(
            reports,
            padding=True,
            truncation=True,
            # vocab_file='./vocab.txt',
            return_tensors="pt",
        ).values()
        loss = model(pixel_values=pixel_values, labels=labels, decoder_attention_mask=att).loss
        # Some optimizations for training
        del pixel_values, labels, att, imgs, reports
        # torch.cuda.empty_cache()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        t_loss += loss.item() / len(train_loader)

    # Validation
    model.eval()
    test_loss = 0
    for batch in val_data:
        imgs, reports = batch[0], batch[1]
        pixel_values = image_processor(imgs, return_tensors="pt").pixel_values
        labels, _, att = tokenizer(
            reports,
            padding=True,
            truncation=True,
            # vocab_file='./vocab.txt',
            return_tensors="pt",
        ).values()
        loss = model(pixel_values=pixel_values, labels=labels, decoder_attention_mask=att).loss
        # Some optimizations for training
        del pixel_values, labels, att, imgs, reports
        test_loss += loss.item() / len(val_loader)

    print(f"Epoch {epoch} Training_Loss: {t_loss} || Validation_Loss: {test_loss}")
    if test_loss < best_loss:
        best_loss = t_loss
        model.save_pretrained('vit-bert-pretrained')

    # Early stopping
    if test_loss > best_loss:
        n_epochs_to_stop -= 1
        if n_epochs_to_stop == 0:
            print(f"Early stopping at epoch {epoch}")
            break

Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.1.crossattention.self.value.bias', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.6.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.8.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.5.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.7.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.output.dense.weight', 'bert.encoder.layer.4.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.self.value.bias', 'bert.encoder.layer.10.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.3.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.6.crossattention.self.query.weight', 'bert.encoder.layer.3.crossattention.output.LayerNorm.weight', 'bert.

Epoch 0 Training_Loss: 16.109341918894685 || Validation_Loss: 15.309488630294807
Epoch 1 Training_Loss: 14.706854710536717 || Validation_Loss: 15.25134525895119
Epoch 2 Training_Loss: 14.557857798264099 || Validation_Loss: 15.252157509326938
Epoch 3 Training_Loss: 14.464490270192634 || Validation_Loss: 15.282632225751877
Epoch 4 Training_Loss: 14.410946164510941 || Validation_Loss: 15.322292834520336
Epoch 5 Training_Loss: 14.36470880550621 || Validation_Loss: 15.344239586591717
Epoch 6 Training_Loss: 14.348493299653052 || Validation_Loss: 15.397416532039639
Epoch 7 Training_Loss: 14.301939221610011 || Validation_Loss: 15.43715462684631
Epoch 8 Training_Loss: 14.278168370238454 || Validation_Loss: 15.414231383800505
Epoch 9 Training_Loss: 14.261726054470097 || Validation_Loss: 15.460633134841917
Epoch 10 Training_Loss: 14.241429316259072 || Validation_Loss: 15.52247618436813
Epoch 11 Training_Loss: 14.631924027890221 || Validation_Loss: 15.528285914659497
Epoch 12 Training_Loss: 14.220

In [5]:
model.save_pretrained('vit-bert-pretrained_last_epoch')

# Testing

In [6]:
# # Load model
# model = VisionEncoderDecoderModel.from_pretrained('vit-bert-pretrained')
# image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", do_rescale=True, do_normalize=False)
# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# model.config.pad_token_id = tokenizer.pad_token_id
# model.config.decoder_start_token_id = tokenizer.cls_token_id
# model.config.add_cross_attention = True

# Testing
model.eval()
test_loss = 0
for batch in val_loader:
    imgs, reports = batch[0], batch[1]
    pixel_values = image_processor(imgs, return_tensors="pt").pixel_values
    labels, _, att = tokenizer(
        reports,
        padding=True,
        truncation=True,
        # vocab_file='./vocab.txt',
        return_tensors="pt",
    ).values()
    loss = model(pixel_values=pixel_values, labels=labels, decoder_attention_mask=att).loss
    # Some optimizations for training
    del pixel_values, labels, att, imgs, reports
    test_loss += loss.item() / len(val_loader)

print(f"Epoch {epoch} Training_Loss: {t_loss} || Test_Loss: {test_loss}")

Epoch 99 Training_Loss: 14.201556039067494 || Test_Loss: 8.624162828922273


## Predict 

In [25]:
from transformers import ViTImageProcessor, BertTokenizer, VisionEncoderDecoderModel


image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", do_rescale=True, do_normalize=False)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisionEncoderDecoderModel.from_pretrained("vit-bert-pretrained")

img, report = test_data[0] # This has to be changed
pixel_values = image_processor(img, return_tensors="pt").pixel_values
labels = tokenizer(report, return_tensors="pt").input_ids
logits = model(pixel_values=pixel_values, labels=labels).logits
predicted_ids = logits.argmax(-1)
tokenizer.convert_ids_to_tokens(predicted_ids[0])

['.', '.', '.', '.', '.']

In [36]:
# Load model directly
from transformers import AutoImageProcessor, AutoModelForImageClassification

processor = AutoImageProcessor.from_pretrained("nickmuchi/vit-finetuned-chest-xray-pneumonia")
model = AutoModelForImageClassification.from_pretrained("nickmuchi/vit-finetuned-chest-xray-pneumonia")

max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def predict_step(images):
  pixel_values = feature_extractor2(images=images, return_tensors="pt").pixel_values
  pixel_values = pixel_values

  output_ids = model2.generate(pixel_values, **gen_kwargs)

  preds = tokenizer2.batch_decode(output_ids, skip_special_tokens=True)
  preds = [pred.strip() for pred in preds]
  return preds

predict_step([train_data[90][0], train_data[1][0], train_data[2][0]])

Downloading config.json: 100%|██████████| 4.61k/4.61k [00:00<00:00, 1.85MB/s]
Downloading pytorch_model.bin: 100%|██████████| 982M/982M [00:31<00:00, 30.8MB/s] 
Downloading (…)rocessor_config.json: 100%|██████████| 228/228 [00:00<00:00, 863kB/s]
Downloading tokenizer_config.json: 100%|██████████| 241/241 [00:00<00:00, 368kB/s]
Downloading vocab.json: 100%|██████████| 798k/798k [00:00<00:00, 6.80MB/s]
Downloading merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 48.8MB/s]
Downloading tokenizer.json: 100%|██████████| 1.36M/1.36M [00:00<00:00, 12.6MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 120/120 [00:00<00:00, 138kB/s]
We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


['a black and white photo of a person standing in front of a light',
 'a black and white photo of a person standing in front of a light',
 'a black and white photo of a person standing in front of a light']

In [34]:
train_data[0][1]

'vague increased opacity which appears to be within the left lower lobe. question of this could be developing or resolving pneumonia. lungs are otherwise clear. no pleural effusions or pneumothoraces. heart and mediastinum are stable normal size heart. atherosclerotic vascular disease. degenerative changes in the thoracic spine.'

In [24]:
for i in logits[0]:
    idx = i.argmax().item()
    print(tokenizer.convert_ids_to_tokens([idx]))

['.']
['.']
['.']
['.']
['.']
