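# Fine-tuning an image-captioning model (BLIP + LoRA)

Fine-tunes `Salesforce/blip-image-captioning-base`, loaded in 8-bit with LoRA adapters, on the image/description pairs listed in `ids_train.csv` and `ids_test.csv`.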

In [1]:
# Sanity check: report whether PyTorch can see a CUDA-capable GPU.
import torch

gpu_available = torch.cuda.is_available()
print(gpu_available)


True


In [2]:
# Optional, for CPU-only runs: check whether IPEX's CPU runtime extension is enabled
#import intel_extension_for_pytorch as ipex
#
#print(ipex.cpu.runtime.is_runtime_ext_enabled())


## Load the pretrained model (8-bit) and attach LoRA adapters

In [3]:
from transformers import BlipProcessor, BlipForConditionalGeneration, default_data_collator, get_linear_schedule_with_warmup
from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model, prepare_model_for_int8_training

# Preferred device. NOTE(review): `device` is not used in this cell. `.to(device)`
# is intentionally not called on the 8-bit model (presumably because transformers
# disallows moving 8-bit-loaded models; placement happens at load time).
# Confirm whether `device` is still needed.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = "Salesforce/blip-image-captioning-base"

# LoRA adapters on every module named "qkv".
# NOTE(review): the reported 1,179,648 trainable params equals 12 x 32 x (768 + 2304).
# That is consistent with adapters only on the 12 vision-encoder qkv projections,
# which would leave the text decoder frozen. Confirm this is intended.
config = LoraConfig(r=32, lora_alpha=64, target_modules=["qkv"], lora_dropout=0.05, bias="none")

processor = BlipProcessor.from_pretrained(checkpoint)
# Load the weights quantised to int8 to reduce memory use.
model = BlipForConditionalGeneration.from_pretrained(checkpoint, load_in_8bit=True)
print(model.get_memory_footprint())  # footprint before adding adapters

# Prepare the quantised model for training, then wrap it with the LoRA adapters.
model = prepare_model_for_int8_training(model)
model = get_peft_model(model, config)
# print_trainable_parameters() prints its own summary and returns None.
# Wrapping it in print() only added a stray "None" line.
model.print_trainable_parameters()

print(model.get_memory_footprint())  # footprint after preparation + LoRA


  from .autonotebook import tqdm as notebook_tqdm


296122608
trainable params: 1,179,648 || all params: 248,624,248 || trainable%: 0.4744702133799918
None
398189024




## Load data

In [4]:
import pandas as pd

# Train/test splits; `prep_data` reads their `media_id` and `description` columns.
df_train, df_test = pd.read_csv('ids_train.csv'), pd.read_csv('ids_test.csv')


In [5]:
from datasets import Dataset, Image

def prep_data(df):
  files = [f"media/{media_id}.jpg" for media_id in df['media_id'].to_list()]
  descriptions = df['description'].to_list()

  return Dataset.from_dict({ "image": files, "text": descriptions }).cast_column("image", Image())


In [6]:
# Convert both splits to image/caption Datasets, then show the first training example.
ds_train, ds_test = (prep_data(split_df) for split_df in (df_train, df_test))

ds_train[0]


{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x747 at 0x7FEACC9D3E80>,
 'text': 'Landscape with two horsemen enter the composition from the left foreground, one with a falcon standing on his left hand. Falcons hunt birds in the background. Framing line at the bottom.'}

In [7]:
# The DataFrames are not needed once they have been converted to Datasets.
del df_train, df_test


In [8]:
#from textwrap import wrap
#import matplotlib.pyplot as plt
#import numpy as np
#
#def plot_images(images, captions):
#    plt.figure(figsize=(20, 20))
#    for i in range(len(images)):
#        ax = plt.subplot(1, len(images), i + 1)
#        caption = captions[i]
#        caption = "\n".join(wrap(caption, 12))
#        plt.title(caption)
#        plt.imshow(images[i])
#        plt.axis("off")
#
#sample_images_to_visualize = [np.array(ds_train[i]["image"]) for i in range(5)]
#sample_captions = [ds_train[i]["text"] for i in range(5)]
#plot_images(sample_images_to_visualize, sample_captions)


In [9]:
# https://huggingface.co/docs/transformers/tasks/image_captioning
def transforms(example_batch):
    images = [x.convert("RGB").resize((100,100)) for x in example_batch["image"]]
    captions = [x for x in example_batch["text"]]
    inputs = processor(images=images, text=captions, padding="max_length")
    inputs.update({"labels": inputs["input_ids"]})
    return inputs


# Run `transforms` on the fly, batch by batch, whenever either split is indexed.
for split_ds in (ds_train, ds_test):
    split_ds.set_transform(transforms)


## Train

In [10]:
import evaluate
import torch

# Alternative metric (disabled):
#rouge = evaluate.load('rouge')
# Word error rate between decoded predictions and reference captions (lower is better).
wer = evaluate.load('wer')

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predicted = logits.argmax(-1)
    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
    #score = rouge.compute(predictions=decoded_predictions, references=decoded_labels)
    score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"score": score}


In [16]:
from transformers import TrainingArguments, Trainer

# Name outputs after the checkpoint's final path component ("blip-image-captioning-base").
# rsplit also handles ids without "/" or with several segments, where split("/")[1]
# would raise IndexError or pick the wrong part.
model_name = checkpoint.rstrip("/").rsplit("/", 1)[-1]

training_args = TrainingArguments(
    # Model saving
    output_dir=f"{model_name}-finetune",
    push_to_hub=False,
    # Hardware support
    fp16=True,
    #use_cpu=True,
    #use_ipex=True,
    # Basics
    num_train_epochs=5,
    learning_rate=5e-5,
    label_names=["labels"],
    # Other
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    save_total_limit=3,  # keep only the 3 most recent checkpoints on disk
    # No evaluation during training, so compute_metrics is not invoked by train().
    evaluation_strategy="no",
    #evaluation_strategy="steps",
    #eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    logging_steps=50,
    # Keep the raw "image"/"text" columns: the set_transform hook needs them.
    remove_unused_columns=False,
    #load_best_model_at_end=True,
)


In [17]:
# Combine the LoRA-wrapped model, training configuration, datasets and metric.
# compute_metrics is only used when evaluation is enabled in training_args.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=ds_train,
    eval_dataset=ds_test,
    compute_metrics=compute_metrics,
)


In [18]:
import gc

# Reclaim unreachable Python objects before the long training run; the
# returned count is displayed. torch.cuda.empty_cache() would also release
# PyTorch's cached GPU memory, but it is left disabled.
unreachable_found = gc.collect()
unreachable_found


1709

In [19]:
# Run fine-tuning; the returned TrainOutput reports global_step, training loss and runtime metrics.
train_output = trainer.train()
train_output




Step,Training Loss
50,12.4717
100,9.9101
150,9.2828
200,9.0597
250,8.9455
300,8.8867
350,8.8372
400,8.7938
450,8.762
500,8.7405




TrainOutput(global_step=11330, training_loss=8.552758045651064, metrics={'train_runtime': 14348.5471, 'train_samples_per_second': 6.316, 'train_steps_per_second': 0.79, 'total_flos': 5.406712717285786e+19, 'train_loss': 8.552758045651064, 'epoch': 5.0})