# Training Vision Transformer (ViT) on custom dataset

In [7]:
from datasets import load_from_disk

ds = load_from_disk('../saved_dataset')
train_dataset = ds['train']
val_dataset = ds['validation']
test_dataset = ds['test']


## Pre-processing the data

In [3]:
from transformers import AutoImageProcessor, VisionEncoderDecoderModel, AutoTokenizer

# ViT Encoder - Decoder Model
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to("cpu")

processor = AutoImageProcessor.from_pretrained('nlpconnect/vit-gpt2-image-captioning')
tokenizer = AutoTokenizer.from_pretrained('nlpconnect/vit-gpt2-image-captioning')



In [8]:
from PIL import Image

def tokenization_fn(captions, max_target_length):
    """Run tokenization on captions."""
    labels = tokenizer(captions, 
                      padding="max_length", 
                      max_length=max_target_length).input_ids

    return labels

def feature_extraction_fn(images):
    """
    Run feature extraction on images
    """

    image_processor = AutoImageProcessor.from_pretrained('nlpconnect/vit-gpt2-image-captioning')
    encoder_inputs = image_processor(images=images, return_tensors="np")

    return encoder_inputs.pixel_values

    images = examples['image']
    captions = examples['caption']    
    
    model_inputs = {}
    model_inputs['labels'] = tokenization_fn(captions, max_target_length)
    model_inputs['pixel_values'] = feature_extraction_fn(images)

    return model_inputs

def preprocess_fn(examples, max_target_length):
    """Run tokenization + image feature extraction"""
    images = examples['image']
    captions = examples['caption']    
    
    model_inputs = {}
    model_inputs['labels'] = tokenization_fn(captions, max_target_length)
    model_inputs['pixel_values'] = feature_extraction_fn(images)

    return model_inputs    

In [22]:
processed_train_dataset = train_dataset.map(
    function=preprocess_fn,
    batched=True,
    fn_kwargs={"max_target_length": 128},
    remove_columns=train_dataset.column_names
)
processed_test_dataset = test_dataset.map(
    function=preprocess_fn,
    batched=True,
    fn_kwargs={"max_target_length": 128},
    remove_columns=test_dataset.column_names
)

Map:   0%|          | 0/18 [00:00<?, ? examples/s]



In [11]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import accelerate


training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    output_dir="./image-captioning-output",
)

In [33]:
from transformers import TrainingArguments, Trainer
from sklearn.metrics import precision_score, recall_score, f1_score
from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge

# 1. Define the training and evaluation steps
def compute_metrics(eval_pred):
    print(eval_pred)
    predictions, labels = eval_pred
    if not predictions or not labels:
        return {}

    print(predictions)
    print(labels)

    # Assuming predictions and labels are lists of tuples
    # Each tuple contains two elements: detected_objects and generated_caption
    detected_objects_pred, captions_pred = zip(*predictions)
    detected_objects_true, captions_true = zip(*labels)

    # Compute metrics for object detection
    precision = precision_score(detected_objects_true, detected_objects_pred, average='weighted') if detected_objects_true and detected_objects_pred else 0
    recall = recall_score(detected_objects_true, detected_objects_pred, average='weighted') if detected_objects_true and detected_objects_pred else 0
    f1 = f1_score(detected_objects_true, detected_objects_pred, average='weighted') if detected_objects_true and detected_objects_pred else 0

    # Compute metrics for image captioning
    bleu = sentence_bleu(captions_true, captions_pred) if captions_true and captions_pred else 0
    rouge = Rouge().get_scores(captions_pred, captions_true, avg=True) if captions_true and captions_pred else 0

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'bleu': bleu,
        'rouge': rouge,
    }

# 2. Initialize the Trainer
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=processed_train_dataset,         # training dataset
    eval_dataset=processed_test_dataset,           # evaluation dataset
    compute_metrics=compute_metrics,     # the callback that computes metrics of interest
)


In [None]:
# 3. Train the model
trainer.train()

In [27]:
# 4. Save the model
trainer.save_model()

In [34]:
# 5. Evaluate the model
trainer.evaluate()

  0%|          | 0/1 [00:00<?, ?it/s]

<transformers.trainer_utils.EvalPrediction object at 0x32105e690>


ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()