# ViT

Here, we combine *Google ViT* for image feature extraction (the encoder) with *BERT* for text generation (the decoder).

We start from pretrained checkpoints available through [HuggingFace](https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder), since we have neither the compute resources nor enough data to train such large models from scratch.

In [None]:
import torch
import numpy as np
import pandas as pd
import data.preprocessing as pr
from torchvision import transforms
from transformers import ViTImageProcessor, BertTokenizer, VisionEncoderDecoderModel

# --- Data ---------------------------------------------------------------
# Unique study identifiers from the projections table, truncated to the
# first 300 so experiments stay small.
uids = np.unique(pr.projections.index)[:300]

# --- Image preprocessing ------------------------------------------------
# Convert each image to a tensor, then resize it to the 224x224 input
# resolution expected by the ViT checkpoint.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=False),
    ]
)

# --- Loaders ------------------------------------------------------------
(
    train_data, train_loader,
    val_data, val_loader,
    test_data, test_loader,
) = pr.create_dataloaders(
    uids,
    pr.IMAGES_PATH,
    batch_size=6,
    transform=transform,
)

## Training

In [None]:
# ViT-encoder / BERT-decoder fine-tuning with checkpointing and early stopping
# driven by the validation loss.

# Instance model, preprocessors and optimizer.
# NOTE(review): rescaling is disabled because `transform` applies ToTensor(),
# which (for PIL / uint8 inputs) already maps pixels to [0, 1]; only the ViT
# mean/std normalisation is applied here.
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", do_rescale=False, do_normalize=True)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "bert-base-uncased"
)
# BERT has no dedicated BOS token: decoding starts from [CLS]; [PAD] pads.
model.config.pad_token_id = tokenizer.pad_token_id
model.config.decoder_start_token_id = tokenizer.cls_token_id
# NOTE(review): the decoder is already built with cross-attention by
# from_encoder_decoder_pretrained; this composite-config flag is probably a no-op.
model.config.add_cross_attention = True

# Hyperparameters
n_epochs = 100
patience = 5  # epochs without validation improvement before stopping
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)


def _batch_loss(batch):
    """Return the teacher-forced cross-entropy loss of ``model`` on one batch.

    ``batch[0]`` holds the images and ``batch[1]`` the report strings.
    Padded label positions are set to -100 so the loss ignores them.
    """
    imgs, reports = batch[0], batch[1]
    pixel_values = image_processor(imgs, return_tensors="pt").pixel_values
    encoded = tokenizer(
        reports,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    att = encoded["attention_mask"]
    # -100 is the cross-entropy ignore_index; without it the decoder is also
    # trained to predict [PAD] after every (shorter) report in the batch.
    labels = encoded["input_ids"].masked_fill(att == 0, -100)
    return model(pixel_values=pixel_values, labels=labels, decoder_attention_mask=att).loss


best_loss = np.inf
# The patience counter must persist across epochs, so it is set outside the loop.
n_epochs_to_stop = patience
for epoch in range(n_epochs):
    # Training
    model.train()
    t_loss = 0
    for batch in train_loader:
        loss = _batch_loss(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        t_loss += loss.item() / len(train_loader)

    # Validation: no gradients are needed, so no autograd graph is kept.
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            test_loss += _batch_loss(batch).item() / len(val_loader)

    print(f"Epoch {epoch} Training_Loss: {t_loss} || Validation_Loss: {test_loss}")

    # Save whenever the *validation* loss improves; otherwise count down
    # towards early stopping.
    if test_loss < best_loss:
        best_loss = test_loss
        n_epochs_to_stop = patience
        model.save_pretrained('vit-bert-pretrained')
    else:
        n_epochs_to_stop -= 1
        if n_epochs_to_stop == 0:
            print(f"Early stopping at epoch {epoch}")
            break

## Testing

In [None]:
# Evaluate the checkpoint saved during training on the held-out test split.

# Load model
model = VisionEncoderDecoderModel.from_pretrained('vit-bert-pretrained')
# Preprocessing must match training: `transform` already yields [0, 1]
# tensors, so only the ViT normalisation is applied (no second rescale).
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", do_rescale=False, do_normalize=True)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

model.config.pad_token_id = tokenizer.pad_token_id
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.add_cross_attention = True

# Testing
model.eval()
test_loss = 0
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for batch in test_loader:
        imgs, reports = batch[0], batch[1]
        pixel_values = image_processor(imgs, return_tensors="pt").pixel_values
        encoded = tokenizer(
            reports,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        att = encoded["attention_mask"]
        # Exclude padding positions from the loss (-100 = cross-entropy ignore_index).
        labels = encoded["input_ids"].masked_fill(att == 0, -100)
        loss = model(pixel_values=pixel_values, labels=labels, decoder_attention_mask=att).loss
        test_loss += loss.item() / len(test_loader)

print(f"Test_Loss: {test_loss}")

## Predict 

In [None]:
from transformers import ViTImageProcessor, BertTokenizer, VisionEncoderDecoderModel


# Generate a report for a single test image with the fine-tuned model.
# Preprocessing matches training: pixels are already in [0, 1] from
# `transform`, so only the ViT normalisation is applied.
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", do_rescale=False, do_normalize=True)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisionEncoderDecoderModel.from_pretrained("vit-bert-pretrained")
model.eval()

sample_idx = 120  # This has to be changed: must be < len(test_data)
img, report = test_data[sample_idx]
pixel_values = image_processor(img, return_tensors="pt").pixel_values

# Decode autoregressively from the image alone. An argmax over teacher-forced
# logits would need the ground-truth report as decoder input, so it is not a
# real prediction.
with torch.no_grad():
    generated_ids = model.generate(
        pixel_values,
        decoder_start_token_id=tokenizer.cls_token_id,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.sep_token_id,  # BERT ends sequences with [SEP]
        max_length=128,  # arbitrary cap; tune to the typical report length
    )

print("Reference:", report)
tokenizer.decode(generated_ids[0], skip_special_tokens=True)