In [1]:
import torch
import requests
from PIL import Image
from transformers import TrainingArguments,Trainer,BlipModel,BlipForConditionalGeneration,AutoProcessor
import accelerate
from peft import get_peft_model, LoraConfig,TaskType
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
from textwrap import wrap
from evaluate import load

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_checkpoint = 'Salesforce/blip-image-captioning-base'

In [3]:
processor = AutoProcessor.from_pretrained(model_checkpoint)

In [4]:
model = BlipForConditionalGeneration.from_pretrained(model_checkpoint)

In [5]:
# peft_config = LoraConfig(task_type='image_caption',
#                          target_modules = ["q", "v"],
#                          inference_mode=False, 
#                          r=8, 
#                          lora_alpha=32, 
#                          lora_dropout=0.1)

In [6]:
# model = get_peft_model(model,peft_config)
# model.print_trainable_parameters()

In [7]:
root='/mnt/storage-ssd/wangcheng/dataset/rgb/GIT/total/'

In [8]:
dataset = load_dataset(root,split='train')

Resolving data files: 100%|██████████| 5897/5897 [00:00<00:00, 6859.70it/s] 
Found cached dataset imagefolder (/home/wangcheng/.cache/huggingface/datasets/imagefolder/total-dfe73fa9e4870104/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)


In [9]:
from torch.utils.data import Dataset, DataLoader

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], text=item["text"], padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        return encoding

In [10]:
train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2)

In [11]:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()

for epoch in range(50):
  print("Epoch:", epoch)
  for idx, batch in enumerate(train_dataloader):
    input_ids = batch.pop("input_ids").to(device)
    pixel_values = batch.pop("pixel_values").to(device)

    outputs = model(input_ids=input_ids,
                    pixel_values=pixel_values,
                    labels=input_ids)
    
    loss = outputs.loss

    print("Loss:", loss.item())

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

Epoch: 0
Loss: 13.222488403320312
Loss: 10.39437198638916
Loss: 10.316842079162598
Loss: 10.265515327453613
Loss: 10.240105628967285
Loss: 10.241506576538086
Loss: 10.2442626953125
Loss: 10.227607727050781
Loss: 10.217031478881836
Loss: 10.213162422180176
Loss: 10.191887855529785
Loss: 10.18581485748291
Loss: 10.190314292907715
Loss: 10.174428939819336
Loss: 10.125255584716797
Loss: 9.865944862365723
Loss: 9.829734802246094
Loss: 9.217122077941895
Loss: 9.081928253173828
Loss: 8.883038520812988
Loss: 8.691396713256836
Loss: 8.508221626281738
Loss: 8.309130668640137
Loss: 8.149846076965332
Loss: 7.961730480194092
Loss: 7.822160720825195
Loss: 7.691000938415527
Loss: 7.520909309387207
Loss: 7.39722204208374
Loss: 7.259430408477783
Loss: 7.128566741943359
Loss: 6.998673439025879
Loss: 6.845564365386963
Loss: 6.73099946975708
Loss: 6.585529804229736
Loss: 6.432960510253906
Loss: 6.3031697273254395
Loss: 6.165825843811035
Loss: 6.036649703979492
Loss: 5.892551422119141
Loss: 5.7352747917175

In [None]:
def plot_images(images,captions):
    plt.figure(figsize=(20,20))
    for i in range(len(images)):
        ax = plt.subplot(1,len(images),i+1)
        caption = captions[i]
        caption = '\n'.join(wrap(caption,12))
        plt.title(caption)
        plt.imshow(images[i])
        plt.axis('off')

In [None]:
sample_images_to_visualize = [np.array(train_ds[i]["image"]) for i in range(5)]
sample_captions = [train_ds[i]["text"] for i in range(5)]
plot_images(sample_images_to_visualize, sample_captions)

In [None]:
max_length = 16 # max length of the captions in tokens

In [None]:
def preprocess(items):
  # preprocess the image
  pixel_values = image_processor(items["image"], return_tensors="pt").pixel_values.to(device)
  # tokenize the caption with truncation and padding
  targets = tokenizer(items['text'], 
                      max_length=max_length, padding="max_length", truncation=True, return_tensors="pt").to(device)
#   inputs = image_processor(images=pixel_values, text=targets, padding="max_length")
#   inputs.update({"labels": inputs["input_ids"]})
#   return inputs
  return {'pixel_values': pixel_values, 'labels': targets["input_ids"]}

train_dataset = train_ds.with_transform(preprocess)
valid_dataset = valid_ds.with_transform(preprocess)
test_dataset  = test_ds.with_transform(preprocess)

In [None]:
# a function we'll use to collate the batches
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.stack([x['labels'] for x in batch])
    }

In [None]:
import evaluate
# load the rouge and bleu metrics
rouge = evaluate.load("rouge")


In [None]:
from bleu_script.bleu import Bleu
bleu =Bleu()

In [None]:
def compute_metrics(eval_pred):
  preds = eval_pred.label_ids
  labels = eval_pred.predictions
  # decode the predictions and labels
  pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
  labels_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
  # compute the rouge score
  rouge_result = rouge.compute(predictions=pred_str, references=labels_str)
  # multiply by 100 to get the same scale as the rouge score
  rouge_result = {k: round(v * 100, 4) for k, v in rouge_result.items()}
  # compute the bleu score
  bleu_result = bleu.compute(predictions=pred_str, references=labels_str)
  # get the length of the generated captions
  generation_length = bleu_result["translation_length"]
  return {
        **rouge_result, 
        "bleu": round(bleu_result["bleu"] * 100, 4), 
        "gen_len": bleu_result["translation_length"] / len(preds)
  }

In [None]:
num_epochs = 2 # number of epochs
batch_size = 16 # the size of batches

In [None]:
# define the training arguments
training_args = TrainingArguments(
    predict_with_generate=True,             # use generate to calculate the loss
    num_train_epochs=num_epochs,            # number of epochs
    evaluation_strategy="steps",            # evaluate after each eval_steps
    eval_steps=20,                        # evaluate after each 500 steps
    logging_steps=20,                     # log after each 500 steps
    save_steps=20,                       # save after each 500 steps
    per_device_train_batch_size=batch_size, # batch size for training
    per_device_eval_batch_size=batch_size,  # batch size for evaluation
    output_dir="vit-swin-base-224-gpt2-galaxy-captioning", # output directory
    # push_to_hub=True # whether you want to push the model to the hub,
    # check this guide for more details: https://huggingface.co/transformers/model_sharing.html
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()