In [1]:
from transformers import VisionEncoderDecoderModel,GPT2TokenizerFast,ViTImageProcessor,Seq2SeqTrainer,Seq2SeqTrainingArguments
from datasets import load_dataset
import numpy as np
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
from peft import get_peft_model, LoraConfig,TaskType
from textwrap import wrap
from evaluate import load

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
encoder_model = 'microsoft/swin-base-patch4-window7-224-in22k'
decoder_model ='gpt2'

In [3]:
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_model,
    decoder_model).to(device)

Some weights of the model checkpoint at microsoft/swin-base-patch4-window7-224-in22k were not used when initializing SwinModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_proj.bias', 'h.6.crossattention.c_attn.bias', 'h.9.ln_cross_attn.weight', 'h.9.crossattention.c_proj.bias', 'h.7.crossattention.q_attn.bias', 'h.7.ln_cross_attn.bias', 'h.3.ln_cross_attn.weight', 'h.5.ln_cross_attn.weight', 'h.2.crossattenti

In [4]:
tokenizer = GPT2TokenizerFast.from_pretrained(decoder_model)
image_processor =ViTImageProcessor.from_pretrained(encoder_model)

In [5]:
if "gpt2" in decoder_model:
  # gpt2 does not have decoder_start_token_id and pad_token_id
  # but has bos_token_id and eos_token_id
  tokenizer.pad_token = tokenizer.eos_token # pad_token_id as eos_token_id
  model.config.eos_token_id = tokenizer.eos_token_id
  model.config.pad_token_id = tokenizer.pad_token_id
  # set decoder_start_token_id as bos_token_id
  model.config.decoder_start_token_id = tokenizer.bos_token_id
else:
  # set the decoder start token id to the CLS token id of the tokenizer
  model.config.decoder_start_token_id = tokenizer.cls_token_id
  # set the pad token id to the pad token id of the tokenizer
  model.config.pad_token_id = tokenizer.pad_token_id

In [6]:
# peft_config = LoraConfig(task_type='image_caption',
#                          target_modules = ["query", "value"],
#                          inference_mode=False, 
#                          r=8, 
#                          lora_alpha=32, 
#                          lora_dropout=0.1)

In [7]:
# model = get_peft_model(model,peft_config)

In [8]:
# model.print_trainable_parameters()

In [9]:
root='/mnt/storage-ssd/wangcheng/dataset/rgb/GIT/total/'

ds = load_dataset(root)

Resolving data files: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5897/5897 [00:00<00:00, 6937.12it/s] 
Found cached dataset imagefolder (/home/wangcheng/.cache/huggingface/datasets/imagefolder/total-b9b0b8db427296d7/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1/1 [00:00<00:00, 32.28it/s]


In [10]:
ds = ds['train'].train_test_split(0.1)
test_ds = ds['test']


In [11]:
ds = ds['train'].train_test_split(0.15)
train_ds = ds['train']
valid_ds = ds['test']

In [12]:
len(train_ds),len(valid_ds),len(test_ds)

(4510, 796, 590)

In [13]:
max_length = 16 # max length of the captions in tokens

In [14]:
def preprocess(items):
  # preprocess the image
  pixel_values = image_processor(items["image"], return_tensors="pt").pixel_values.to(device)
  # tokenize the caption with truncation and padding
  targets = tokenizer([ sentence["raw"] for sentence in items["sentences"] ], 
                      max_length=max_length, padding="max_length", truncation=True, return_tensors="pt").to(device)
  return {'pixel_values': pixel_values, 'labels': targets["input_ids"]}

# using with_transform to preprocess the dataset during training
train_dataset = train_ds.with_transform(preprocess)
valid_dataset = valid_ds.with_transform(preprocess)
test_dataset  = test_ds.with_transform(preprocess)

In [15]:
def preprocess(items):
  # preprocess the image
  pixel_values = image_processor(items["image"], return_tensors="pt").pixel_values.to(device)
  # tokenize the caption with truncation and padding
  targets = tokenizer(items['text'], 
                      max_length=max_length, padding="max_length", truncation=True, return_tensors="pt").to(device)
#   inputs = image_processor(images=pixel_values, text=targets, padding="max_length")
#   inputs.update({"labels": inputs["input_ids"]})
#   return inputs
  return {'pixel_values': pixel_values, 'labels': targets["input_ids"]}

train_dataset = train_ds.with_transform(preprocess)
valid_dataset = valid_ds.with_transform(preprocess)
test_dataset  = test_ds.with_transform(preprocess)

In [16]:
# a function we'll use to collate the batches
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.stack([x['labels'] for x in batch])
    }

In [17]:
import evaluate
# load the rouge and bleu metrics
rouge = evaluate.load("rouge")


In [18]:
from bleu_script.bleu import Bleu
bleu =Bleu()

In [19]:
def compute_metrics(eval_pred):
  preds = eval_pred.label_ids
  labels = eval_pred.predictions
  # decode the predictions and labels
  pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
  labels_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
  # compute the rouge score
  rouge_result = rouge.compute(predictions=pred_str, references=labels_str)
  # multiply by 100 to get the same scale as the rouge score
  rouge_result = {k: round(v * 100, 4) for k, v in rouge_result.items()}
  # compute the bleu score
  bleu_result = bleu.compute(predictions=pred_str, references=labels_str)
  # get the length of the generated captions
  generation_length = bleu_result["translation_length"]
  return {
        **rouge_result, 
        "bleu": round(bleu_result["bleu"] * 100, 4), 
        "gen_len": bleu_result["translation_length"] / len(preds)
  }

In [20]:
num_epochs = 2 # number of epochs
batch_size = 16 # the size of batches

In [21]:
# define the training arguments
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,             # use generate to calculate the loss
    num_train_epochs=num_epochs,            # number of epochs
    evaluation_strategy="steps",            # evaluate after each eval_steps
    eval_steps=20,                        # evaluate after each 500 steps
    logging_steps=20,                     # log after each 500 steps
    save_steps=20,                       # save after each 500 steps
    per_device_train_batch_size=batch_size, # batch size for training
    per_device_eval_batch_size=batch_size,  # batch size for evaluation
    output_dir="vit-swin-base-224-gpt2-galaxy-captioning", # output directory
    # push_to_hub=True # whether you want to push the model to the hub,
    # check this guide for more details: https://huggingface.co/transformers/model_sharing.html
)

In [22]:
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,                     # the instantiated ðŸ¤— Transformers model to be trained
    tokenizer=image_processor,       # we use the image processor as the tokenizer
    args=training_args,              # pass the training arguments
    compute_metrics=compute_metrics, 
    train_dataset=train_dataset,     
    eval_dataset=valid_dataset,      
    data_collator=collate_fn,        
)

In [23]:
from torch.utils.data import DataLoader

def get_eval_loader(eval_dataset=None):
  return DataLoader(valid_dataset, collate_fn=collate_fn, batch_size=batch_size)

def get_test_loader(eval_dataset=None):
  return DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size)

# override the get_train_dataloader, get_eval_dataloader and
# get_test_dataloader methods of the trainer
# so that we can properly load the data
trainer.get_train_dataloader = lambda: DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size)
trainer.get_eval_dataloader = get_eval_loader
trainer.get_test_dataloader = get_test_loader

In [24]:
# train the model
trainer.train()



Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Bleu,Gen Len
20,2.067,0.296111,64.5604,37.0471,64.5446,64.61,37.8654,4.346734
40,0.2694,0.13933,64.2901,36.4238,64.2635,64.3159,36.8821,4.346734
60,0.1981,0.127372,70.7439,47.3614,70.7099,70.7603,44.4146,4.346734
80,0.1901,0.10492,77.4605,58.746,77.4951,77.4968,53.4623,4.346734
100,0.1745,0.095059,74.6172,52.6086,74.6505,74.6794,51.0397,4.346734
120,0.1462,0.088374,79.3286,63.4694,79.36,79.3078,56.8634,4.346734
140,0.1183,0.082198,82.9745,67.6484,82.9544,82.9543,61.3,4.346734
160,0.1127,0.113399,81.9649,66.7211,81.9473,81.9772,57.8202,4.346734
180,0.1454,0.086789,82.2494,66.7551,82.217,82.2417,59.0973,4.346734
200,0.0978,0.076,84.3141,71.2985,84.2654,84.3485,65.9375,4.346734




TrainOutput(global_step=564, training_loss=0.16676992431600043, metrics={'train_runtime': 2979.5882, 'train_samples_per_second': 3.027, 'train_steps_per_second': 0.189, 'total_flos': 1.637080974186578e+18, 'train_loss': 0.16676992431600043, 'epoch': 2.0})