# Fine-tuning an Image Caption model

In [None]:
# Sanity check: report whether PyTorch's CUDA backend is available.
import torch
# Prints True when CUDA is available, False otherwise.
print(torch.cuda.is_available())


In [None]:
# Check if CPU is supported by IPEX
#import intel_extension_for_pytorch as ipex
#
#print(ipex.cpu.runtime.is_runtime_ext_enabled())


## Download pretrained model

In [None]:
from transformers import BlipProcessor, BlipForConditionalGeneration, default_data_collator, get_linear_schedule_with_warmup
from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model, prepare_model_for_int8_training

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not used below now that the `.to(device)` call is
# commented out — presumably placement is left to the 8-bit loader; confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'
# Hub identifier of the pretrained checkpoint to fine-tune.
checkpoint = "Salesforce/blip-image-captioning-base"
# Repository name without the organisation prefix; reused for output directory names.
model_name = checkpoint.split("/")[1]

# LoRA adapter settings: rank 32, alpha 64, 5% dropout, no bias training.
# Adapters are attached only to modules selected by the "qkv" target name.
# NOTE(review): presumably "qkv" only matches fused attention projections; if
# some attention layers use separately named projections they would stay
# frozen — verify against the printed model before relying on this.
config = LoraConfig(r=32, lora_alpha=64, target_modules=["qkv"], lora_dropout=0.05, bias="none")

# Load the processor and the model from the same checkpoint; the model is
# requested in 8-bit via `load_in_8bit=True`.
processor = BlipProcessor.from_pretrained(checkpoint)
model = BlipForConditionalGeneration.from_pretrained(checkpoint, load_in_8bit=True)#.to(device)
#print(model)

# Memory footprint of the freshly loaded model (before PEFT wrapping).
print(model.get_memory_footprint())

#model.save_pretrained(f"{model_name}-8bit-pre-peft")
#processor.save_pretrained(f"{model_name}-8bit-pre-peft")


## PEFT Prep

In [None]:
# Prepare the 8-bit base model for training, then wrap it with the LoRA
# adapters described by `config`.
# NOTE(review): newer PEFT releases expose this helper as
# `prepare_model_for_kbit_training` — confirm against the installed version.
model = prepare_model_for_int8_training(model)
model = get_peft_model(model, config)
# print_trainable_parameters() writes its summary itself and returns None, so
# it must not be wrapped in print() (that only adds a stray "None" line).
model.print_trainable_parameters()

# Memory footprint after wrapping, for comparison with the value printed when
# the model was loaded.
print(model.get_memory_footprint())


In [None]:
#model.save_pretrained(f"{model_name}-quantized")
#processor.save_pretrained(f"{model_name}-quantized")


## Load data

In [None]:
import pandas as pd

# Load the training metadata table. Later cells read its 'media_id' and
# 'description' columns.
df_train = pd.read_csv('ids_train.csv')
#df_test = pd.read_csv('ids_test.csv')


In [None]:
import nltk

# Fetch the NLTK resources needed by the caption-cleaning step: 'punkt' for
# word_tokenize and 'stopwords' for the English stop-word list.
nltk.download('punkt')
nltk.download('stopwords')


In [None]:
from nltk.corpus import stopwords

# Build the English stop-word collection once, as a set, so the per-token
# membership test in remove_junk is a set lookup.
stop_words = set(stopwords.words('english'))


In [None]:
from nltk.tokenize import word_tokenize

# Remove Stopwords & Commas
# "Why?" - this shouldn't be necessary but the model is pretty much only
# outputting stopwords or commas; we're removing them to try to force it
# to not do this. This may break semantic understanding slightly but
# hopefully will give a better result.
def remove_junk(text):
    """Return `text` with commas and stop words stripped out.

    Commas are deleted before tokenizing with `word_tokenize`. Tokens whose
    lower-cased form appears in `stop_words` are dropped, and the survivors
    are re-joined with single spaces.
    """
    comma_free = text.replace(",", "")
    kept_words = []
    for word in word_tokenize(comma_free):
        if word.lower() in stop_words:
            continue
        kept_words.append(word)
    return " ".join(kept_words)


In [None]:
from datasets import Dataset, Image

def prep_data(df):
    """Build an image/caption `Dataset` from a metadata frame.

    Reads the 'media_id' and 'description' columns of `df`: every id becomes
    the path ``media/<id>.jpg`` and every description is cleaned with
    `remove_junk`. The result has an "image" column (cast to the `Image`
    feature type) and a "text" column.
    """
    image_paths = [f"media/{media_id}.jpg" for media_id in df['media_id'].to_list()]
    cleaned_captions = list(map(remove_junk, df['description'].to_list()))

    columns = {"image": image_paths, "text": cleaned_captions}
    dataset = Dataset.from_dict(columns)
    return dataset.cast_column("image", Image())


In [None]:
# Build the training dataset from only the first 1000 rows of the table.
ds_train = prep_data(df_train.head(1000))
#ds_test = prep_data(df_test)

# Display the first example as a quick sanity check.
ds_train[0]


In [None]:
# Drop the reference to the source DataFrame now that the dataset is built.
del df_train
#del df_test


In [None]:
#from textwrap import wrap
#import matplotlib.pyplot as plt
#import numpy as np
#
#def plot_images(images, captions):
#    plt.figure(figsize=(20, 20))
#    for i in range(len(images)):
#        ax = plt.subplot(1, len(images), i + 1)
#        caption = captions[i]
#        caption = "\n".join(wrap(caption, 12))
#        plt.title(caption)
#        plt.imshow(images[i])
#        plt.axis("off")
#
#sample_images_to_visualize = [np.array(ds_train[i]["image"]) for i in range(5)]
#sample_captions = [ds_train[i]["text"] for i in range(5)]
#plot_images(sample_images_to_visualize, sample_captions)


In [None]:
# https://huggingface.co/docs/transformers/tasks/image_captioning
def transforms(example_batch):
    """Turn a batch of raw examples into processor outputs with labels.

    Every entry of ``example_batch["image"]`` is converted to RGB and resized
    to 100x100; the images and the ``example_batch["text"]`` captions are then
    passed to `processor` with ``padding="max_length"``. The produced
    ``input_ids`` are additionally stored under ``"labels"``.

    NOTE(review): the labels are the padded input ids unchanged — confirm the
    model's loss is meant to see the padding positions.
    """
    rgb_thumbnails = []
    for picture in example_batch["image"]:
        rgb_thumbnails.append(picture.convert("RGB").resize((100, 100)))
    caption_texts = list(example_batch["text"])

    encoded = processor(images=rgb_thumbnails, text=caption_texts, padding="max_length")
    encoded.update(labels=encoded["input_ids"])
    return encoded


# Register `transforms` on the dataset so examples are preprocessed when they
# are read rather than ahead of time.
ds_train.set_transform(transforms)
#ds_test.set_transform(transforms)


## Train

In [None]:
import evaluate
import torch

# Load the ROUGE metric used by compute_metrics; the WER and BLEU
# alternatives are kept commented out.
rouge = evaluate.load('rouge')
#wer = evaluate.load('wer')
#bleu = evaluate.load('bleu')

def compute_metrics(eval_pred):
    """Score argmax predictions against the labels with ROUGE.

    `eval_pred` unpacks into ``(logits, labels)``. The highest-scoring token
    id is taken along the last axis of `logits`; labels and predictions are
    both decoded with `processor.batch_decode` (special tokens skipped) and
    compared with `rouge`.

    Returns a dict with the single key ``"score"`` holding whatever
    `rouge.compute` returned.
    """
    logits, labels = eval_pred
    # Most likely token id at every position.
    token_ids = logits.argmax(-1)
    reference_texts = processor.batch_decode(labels, skip_special_tokens=True)
    predicted_texts = processor.batch_decode(token_ids, skip_special_tokens=True)
    # Alternative metrics (need the matching loaders to be enabled):
    #   wer.compute(predictions=predicted_texts, references=reference_texts)
    #   bleu.compute(predictions=predicted_texts, references=reference_texts)
    return {
        "score": rouge.compute(predictions=predicted_texts, references=reference_texts)
    }


In [None]:
from transformers import TrainingArguments, Trainer

# Hyperparameters and bookkeeping options for the Trainer run below.
training_args = TrainingArguments(
    # Model saving: checkpoints go to "<model_name>-wip-2"; nothing is pushed
    # to the Hub.
    output_dir=f"{model_name}-wip-2",
    push_to_hub=False,
    # Hardware support (all optional switches currently disabled)
    #fp16=True,
    #use_cpu=True,
    #use_ipex=True,
    # Basics
    num_train_epochs=5,
    learning_rate=5e-5,
    # "labels" is the key that `transforms` adds to every batch.
    label_names=["labels"],
    # Other
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    # Keep at most 3 checkpoints on disk.
    save_total_limit=3,
    # Evaluation is switched off; the step-based variant is kept for reference.
    evaluation_strategy="no",
    #evaluation_strategy="steps",
    #eval_steps=50,
    # Save a checkpoint and log every 250 steps.
    save_strategy="steps",
    save_steps=250,
    logging_steps=250,
    # NOTE(review): presumably required so the raw "image"/"text" columns are
    # still present when `transforms` runs — confirm.
    remove_unused_columns=False,
    #load_best_model_at_end=True,

)


In [None]:
# Wire the PEFT-wrapped model, the training arguments and the training
# dataset into a Trainer.
# NOTE(review): no eval dataset is supplied and evaluation_strategy is "no",
# so compute_metrics is presumably never called during this run — confirm.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    #eval_dataset=ds_test,
    compute_metrics=compute_metrics,
)


In [None]:
import gc

# Force a garbage-collection pass before training starts; the CUDA cache
# flush is left disabled.
gc.collect()

#torch.cuda.empty_cache()


In [None]:
# Run the fine-tuning loop with the Trainer built earlier.
trainer.train()


In [None]:
# Save the trained model and its processor side by side in
# "<model_name>-quicktest".
# NOTE(review): `model` is the PEFT-wrapped model at this point, so presumably
# only the adapter weights are written — confirm before loading it elsewhere.
model.save_pretrained(f"{model_name}-quicktest")
processor.save_pretrained(f"{model_name}-quicktest")


## Eval