# Training Vision Transformer (ViT) on custom dataset

In [5]:
from datasets import load_from_disk

ds = load_from_disk('../saved_dataset')
train_dataset = ds['train']
val_dataset = ds['validation']
test_dataset = ds['test']
train_dataset.features

{'image': Image(mode=None, decode=True, id=None),
 'caption': Value(dtype='string', id=None),
 'date': Value(dtype='string', id=None),
 'location': Value(dtype='string', id=None),
 'coordinates': Value(dtype='string', id=None)}

## Pre-processing the data

In [6]:
from transformers import AutoImageProcessor, VisionEncoderDecoderModel, AutoTokenizer

# ViT Encoder - Decoder Model
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to("cpu")

processor = AutoImageProcessor.from_pretrained('nlpconnect/vit-gpt2-image-captioning')
tokenizer = AutoTokenizer.from_pretrained('nlpconnect/vit-gpt2-image-captioning')



In [31]:
from PIL import Image

def tokenization_fn(captions, max_target_length):
    """Run tokenization on captions."""
    labels = tokenizer(captions, 
                      padding="max_length", 
                      max_length=max_target_length).input_ids

    return labels

def feature_extraction_fn(images):
    """
    Run feature extraction on images
    """

    image_processor = AutoImageProcessor.from_pretrained('nlpconnect/vit-gpt2-image-captioning')
    encoder_inputs = image_processor(images=images, return_tensors="np")

    return encoder_inputs.pixel_values

    images = examples['image']
    captions = examples['caption']    
    
    model_inputs = {}
    model_inputs['labels'] = tokenization_fn(captions, max_target_length)
    model_inputs['pixel_values'] = feature_extraction_fn(images)

    return model_inputs

def preprocess_fn(examples, max_target_length):
    """Run tokenization + image feature extraction"""
    images = examples['image']
    captions = examples['caption']    
    
    model_inputs = {}
    model_inputs['labels'] = tokenization_fn(captions, max_target_length)
    model_inputs['pixel_values'] = feature_extraction_fn(images)

    return model_inputs    

In [32]:
processed_train_dataset = train_dataset.map(
    function=preprocess_fn,
    batched=True,
    fn_kwargs={"max_target_length": 128},
    remove_columns=train_dataset.column_names
)
processed_test_dataset = test_dataset.map(
    function=preprocess_fn,
    batched=True,
    fn_kwargs={"max_target_length": 128},
    remove_columns=test_dataset.column_names
)

Map:   0%|          | 0/144 [00:00<?, ? examples/s]



In [9]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import accelerate


training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    output_dir="./image-captioning-output",
)

In [37]:
import numpy as np
from transformers import TrainingArguments, Trainer
from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge

import evaluate
metric = evaluate.load("rouge")

import numpy as np

ignore_pad_token_for_loss = True


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds,
                                                     decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result

# from sklearn.metrics import accuracy_score

# def compute_metrics(eval_pred):
#     predictions, labels = eval_pred
#     predictions = np.argmax(predictions, axis=1)
#     return dict(accuracy=accuracy_score(predictions, labels))

# 1. Define the training and evaluation steps
# def compute_metrics(eval_pred):
#     predictions, labels = eval_pred
#     predictions = np.argmax(predictions, axis=-1)  # get the predicted token IDs

#     # Assuming your tokenizer is available as `tokenizer`
#     captions_pred = [tokenizer.decode(pred, skip_special_tokens=True) for pred in predictions.tolist()]
#     captions_true = [tokenizer.decode(true, skip_special_tokens=True) for true in labels.tolist()]
    
#     # Compute metrics for image captioning
#     bleu_scores = [sentence_bleu([true.split()], pred.split()) for true, pred in zip(captions_true, captions_pred)]
#     bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0

#     rouge_scores = [Rouge().get_scores(pred, true, avg=True)['rouge-l']['f'] for true, pred in zip(captions_true, captions_pred)]
#     rouge = sum(rouge_scores) / len(rouge_scores) if rouge_scores else 0

#     return {
#         'bleu': bleu,
#         'rouge': rouge,
#     }

# 2. Initialize the Trainer
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=processed_train_dataset,         # training dataset
    eval_dataset=processed_test_dataset,           # evaluation dataset
    compute_metrics=compute_metrics,     # the callback that computes metrics of interest
)


In [18]:
# 3. Train the model
trainer.train()

  0%|          | 0/27 [00:00<?, ?it/s]

{'train_runtime': 54.7729, 'train_samples_per_second': 7.887, 'train_steps_per_second': 0.493, 'train_loss': 0.6030247299759476, 'epoch': 3.0}


TrainOutput(global_step=27, training_loss=0.6030247299759476, metrics={'train_runtime': 54.7729, 'train_samples_per_second': 7.887, 'train_steps_per_second': 0.493, 'total_flos': 7.79604009638953e+16, 'train_loss': 0.6030247299759476, 'epoch': 3.0})

In [19]:
# 4. Save the model
trainer.save_model()

In [40]:
# 5. Evaluate the model
# trainer.evaluate()
trainer.eval_dataset[0]

{'labels': [1925,
  5580,
  24380,
  319,
  257,
  5509,
  5762,
  607,
  2042,
  35685,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  