In [None]:
# Install library
!pip install -U transformers
!pip install accelerate
# Download Dataset
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
!unzip /content/Flickr8k_Dataset.zip
!unzip /content/Flickr8k_text.zip

!echo "Downloaded Flickr8k dataset successfully."
!mkdir saved_model

In [None]:
# accelerate is already installed by the setup cell at the top of the notebook; no second install needed.

In [None]:
import os
from getpass import getpass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, AutoImageProcessor, ViTModel, CLIPVisionModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Never hard-code access tokens in a notebook: they leak through version control, shared links and saved outputs.
# Read the token from the HF_TOKEN environment variable (e.g. a Colab secret), otherwise prompt for it without echoing.
# NOTE(review): the token that used to be hard-coded here must be treated as compromised — revoke it on huggingface.co.
HUGGINGFACE_TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
login(HUGGINGFACE_TOKEN)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful


# Architecture

In [None]:
class MyAdaptor(nn.Module) :
  """Summarise a sequence of vision tokens into a fixed number of LM-sized embeddings.

  A learned set of ``num_vis_token_summary`` query vectors attends (dot product,
  softmax over the vision tokens) to the vision-encoder output. Each attended
  summary is then projected by a 2-layer MLP to ``word_embedding_size``.

  Args:
    num_vis_token_summary: number of learned queries, i.e. number of output embeddings.
    vis_token_embedding_size: feature size of the incoming vision tokens.
    word_embedding_size: feature size of the produced embeddings.
    num_vocab: stored for reference only; not used by ``forward``.
  """
  def __init__(self, num_vis_token_summary, vis_token_embedding_size, word_embedding_size, num_vocab) :
    super(MyAdaptor, self).__init__()
    self.num_vis_token_summary = num_vis_token_summary
    self.vis_token_embedding_size = vis_token_embedding_size
    self.word_embedding_size = word_embedding_size
    self.num_vocab = num_vocab

    # Learned queries, shape (1, num_vis_token_summary, vis_token_embedding_size), small random init.
    self.adapter_embed = nn.Parameter(torch.randn(1,self.num_vis_token_summary, self.vis_token_embedding_size)*0.001)
    self.adapter_MLP = nn.Sequential(
        nn.Linear(self.vis_token_embedding_size, self.vis_token_embedding_size*4),
        nn.ReLU(),
        nn.Linear(self.vis_token_embedding_size*4, self.word_embedding_size)
    )

  def forward(self, img_output) :
    """Attend the learned queries over ``img_output`` and project the result.

    Args:
      img_output: vision features whose last dimension is ``vis_token_embedding_size``;
        with the usual (B, N, D) input the result is (B, num_vis_token_summary, word_embedding_size).
    Returns:
      The projected summary embeddings.
    """
    # Match the input's device/dtype (e.g. bfloat16) without moving any registered submodule.
    queries = self.adapter_embed.to(device=img_output.device, dtype=img_output.dtype)

    attn_weight = torch.softmax(queries @ img_output.transpose(-2, -1), -1)   # weights over vision tokens
    img_embed = attn_weight @ img_output

    return self.adapter_MLP(img_embed)

  def adaptor_embedding(self, img_output) :
    """Alias for calling the module; kept because callers use this name."""
    return self(img_output)

In [None]:
class LearableToken(nn.Module) :
  """Holder for learnable prompt-embedding vectors (name spelled as-is for compatibility).

  The module has no ``forward``; callers read its parameters directly.

  Attributes:
    bias_token: parameter of shape (num_token, word_embedding_size), init N(0, 1) * 0.01.
    instruct_bias_token: parameter of shape (8 * num_token, word_embedding_size), same init.
  """
  def __init__(self, num_token, word_embedding_size) :
    super().__init__()
    self.num_token, self.word_embedding_size = num_token, word_embedding_size
    init_scale = 0.01
    # Creation order is kept (bias_token first) so a fixed seed yields the same initial values.
    self.bias_token = nn.Parameter(init_scale * torch.randn(num_token, word_embedding_size))
    self.instruct_bias_token = nn.Parameter(init_scale * torch.randn(8 * num_token, word_embedding_size))


In [None]:
class MyModel(nn.Module) :
  """Image captioner: frozen CLIP vision encoder -> trainable adaptor -> frozen Gemma-2 LM.

  Prompts contain plain-text placeholders ("the the ... the") whose token
  embeddings are overwritten before they reach the LM:
    * ``num_bias_token`` placeholders right after BOS are replaced by the
      learnable ``self.learnableToken.bias_token`` vectors (training path);
    * ``num_vis_token_summary`` placeholders right after ``<start_image>`` are
      replaced by the adaptor's image embeddings.

  Args:
    num_bias_token: number of learnable bias tokens inserted after BOS.
      NOTE(review): the notebook never defined this value; 8 is a placeholder default — tune as needed.
  """
  def __init__(self, num_bias_token=8) :
    super(MyModel, self).__init__()
    self.model_language = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", torch_dtype=torch.bfloat16)
    self.tokenizer_language = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", padding_side= 'right')
    self.image_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32").image_processor
    self.model_image = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")

    print(self.model_language)
    print(self.model_image)

    # Read sizes from the loaded configs: a hard-coded 2048 (gemma-2b v1's width) does not match
    # gemma-2-2b's embeddings and makes the embedding splice below fail.
    self.word_embedding_size = self.model_language.config.hidden_size
    self.num_vocab = self.model_language.config.vocab_size

    self.trigger_str_img = "<start_image>"
    self.num_vis_token_summary = 32
    self.vis_token_embedding_size = self.model_image.config.hidden_size
    self.adaptorListCaption = MyAdaptor(self.num_vis_token_summary,self.vis_token_embedding_size,self.word_embedding_size,self.num_vocab )

    # Learnable vectors that overwrite the bias placeholders in forward_loss.
    self.num_bias_token = num_bias_token
    self.learnableToken = LearableToken(self.num_bias_token, self.word_embedding_size)

    self.dummy_img_token = (" ".join(["the"]*self.num_vis_token_summary)).strip()
    self.dummy_bias_token = (" ".join(["the"]*self.num_bias_token)).strip()
    # Replacement is position based, so each placeholder word must be exactly one token.
    self._check_placeholder_length(self.dummy_img_token, self.num_vis_token_summary)
    self._check_placeholder_length(self.dummy_bias_token, self.num_bias_token)

  def _check_placeholder_length(self, placeholder, expected) :
    """Raise ValueError unless ``placeholder`` tokenizes (without special tokens) to ``expected`` tokens."""
    num_tokens = len(self.tokenizer_language(placeholder, add_special_tokens=False)['input_ids'])
    if num_tokens != expected :
      raise ValueError(f"Placeholder text tokenizes to {num_tokens} tokens, expected {expected}")

  def _model_device(self) :
    """Device holding the language-model weights; all inputs are moved here."""
    return next(self.model_language.parameters()).device

  def search_trigger_idx(self, text_token, trigger_str) :
    """Return the index just past the token at which ``trigger_str`` first appears, or None.

    Decodes growing prefixes of the 1-D token-id sequence ``text_token`` (special
    tokens skipped) until the decoded text contains ``trigger_str``. The returned
    index is where the placeholder tokens following the trigger begin.
    """
    prefix_ids = []
    for token_idx in range(len(text_token)) :
      prefix_ids.append(int(text_token[token_idx]))
      decoded = self.tokenizer_language.batch_decode([prefix_ids],skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
      if trigger_str in decoded :
        return token_idx + 1
    return None

  def preproc_image(self, image_local_link) :
    """Open the image file and return the CLIP image-processor output (``return_tensors="pt"``)."""
    # The context manager releases the file handle once the processor has read the pixels.
    with Image.open(image_local_link) as image :
      inputs = self.image_processor(image, return_tensors="pt")
    return inputs

  def get_image_embed(self, image_input) :
    """Encode preprocessed pixel values with CLIP and summarise them with the adaptor."""
    img_output = self.model_image(image_input)['last_hidden_state']
    return self.adaptorListCaption(img_output)

  def get_bias_token_attn_mask(self,attention_mask_late,bias_token_now) :
    """Build a (B, L, L) 0/1 attention mask (1 = query row may attend to key column).

    Starts from a causal mask restricted to non-padded keys. The bias tokens at
    positions [1, 1 + len(bias_token_now)) then attend only to each other
    (bidirectionally), while later tokens still see them causally. Padded query
    rows keep attending to earlier real tokens, so no row is fully masked; a fully
    masked row can produce NaN/garbage that propagates through later layers.
    """
    key_mask = attention_mask_late.to(float)                              # (B, L), 1 = real token
    B, L = key_mask.shape
    causal = torch.tril(torch.ones(L, L, dtype=key_mask.dtype, device=key_mask.device))
    new_attn_mask_now = causal.unsqueeze(0) * key_mask.reshape(B, 1, L)   # B x L x L

    num_bias = len(bias_token_now)
    new_attn_mask_now[:,1:1+num_bias,:] = 0
    new_attn_mask_now[:,1:1+num_bias,1:1+num_bias] = 1
    return new_attn_mask_now

  @staticmethod
  def _to_additive_4d_mask(mask_01, dtype) :
    """Turn a (B, L, L) 0/1 mask into the additive (B, 1, L, L) form accepted by HF decoder models.

    Allowed positions become 0 and blocked positions become the most negative value of ``dtype``.
    """
    additive = torch.zeros(mask_01.shape, dtype=dtype, device=mask_01.device)
    additive = additive.masked_fill(mask_01 == 0, torch.finfo(dtype).min)
    return additive.unsqueeze(1)

  def replace_embedding_hook(self, image_input) :
    """Return a forward hook for the LM embedding layer that writes one image embedding over its placeholders.

    The hook only acts when several tokens are embedded at once (the prompt pass);
    single-token decoding steps are left untouched.
    """
    image_feature = self.get_image_embed(image_input)
    assert len(image_feature) == 1
    num_vis = self.num_vis_token_summary

    def now_hook(module, inputs, output) :
      token_ids = inputs[0]
      batch_size, token_len = token_ids.shape
      if token_len > 1 :
        assert batch_size == 1
        start = self.search_trigger_idx(token_ids[0], self.trigger_str_img)
        if start is None :
          raise ValueError(f"Trigger {self.trigger_str_img!r} not found in the prompt tokens")
        output[:, start:start+num_vis] = image_feature[0]
      return output
    return now_hook

  def split_and_replace(self, now_input_tokens, replacement_embed, start_loc) :
    """Return ``now_input_tokens`` with rows [start_loc, start_loc + len(replacement_embed)) replaced along dim 0.

    The length along dim 0 is unchanged. ``replacement_embed`` is cast to the dtype of ``now_input_tokens``.
    Raises ValueError if ``start_loc`` is None or the replacement would run past the end.
    """
    num_token = len(replacement_embed)
    if start_loc is None or start_loc + num_token > len(now_input_tokens) :
      raise ValueError(f"Cannot place {num_token} embeddings at {start_loc} in a sequence of length {len(now_input_tokens)}")

    start_embed = now_input_tokens[0:start_loc]
    end_embed = now_input_tokens[start_loc+num_token:]
    return torch.cat((start_embed, replacement_embed.to(now_input_tokens.dtype), end_embed),0)

  @staticmethod
  def _masked_caption_nll(caption_logits, caption_labels, caption_mask) :
    """Mean over the batch of each caption's average token negative log-likelihood, ignoring padding.

    Uses ``log_softmax`` + ``gather`` in float32. ``log(softmax(.))`` in bfloat16 can hit
    log(0) = -inf, and 0 * -inf = NaN. This also avoids building a one-hot tensor of vocabulary size.
    """
    log_probs = F.log_softmax(caption_logits.float(), -1)
    token_nll = -log_probs.gather(-1, caption_labels.unsqueeze(-1)).squeeze(-1)
    per_caption = torch.sum(token_nll*caption_mask,-1)/torch.clamp(torch.sum(caption_mask,-1), min=1)
    return torch.mean(per_caption)

  def forward_loss(self, image_input_raw, caption_output_raw) :
    """Teacher-forced captioning loss for a batch of images and reference captions.

    Args:
      image_input_raw: images accepted by the CLIP image processor (e.g. a list of PIL images).
      caption_output_raw: list of caption strings, one per image.
    Returns:
      (loss_lm, caption_prediction_now): the scalar batch mean of per-caption average token NLL,
      and the softmax over the vocabulary at the caption positions.
    """
    device_now = self._model_device()
    instruction_now =  self.dummy_bias_token + "<start_of_turn>user\n"
    instruction_now += f"<start_image> {self.dummy_img_token}\n<end_image>\n"
    instruction_now += f"Create a simple description of the image!\n<end_of_turn>\n<start_of_turn>model\n"

    image_input = self.image_processor(image_input_raw, return_tensors="pt")['pixel_values'].to(device_now)
    img_output = self.model_image(image_input)['last_hidden_state']
    img_embed = self.adaptorListCaption(img_output)

    # Round-trip captions through the tokenizer to drop special tokens, then prepend the prompt.
    caption_ids = self.tokenizer_language(caption_output_raw,padding=True,return_tensors="pt")['input_ids']
    caption_text = self.tokenizer_language.batch_decode(caption_ids, skip_special_tokens=True)
    all_tokens_with_prompt = self.tokenizer_language([instruction_now + text for text in caption_text], padding=True, return_tensors="pt")
    input_ids = all_tokens_with_prompt['input_ids'].to(device_now)
    attention_mask = all_tokens_with_prompt['attention_mask'].to(device_now)

    # Right padding keeps the shared prompt at identical positions, so every caption starts at prompt_len.
    prompt_len = len(self.tokenizer_language([instruction_now])['input_ids'][0])
    caption_label_now = input_ids[:,prompt_len:]
    attn_mask_now = attention_mask[:,prompt_len:].to(torch.float32)

    all_token_prompt_embed = self.model_language.model.embed_tokens(input_ids)
    bias_token = self.learnableToken.bias_token
    all_replaced_feature = []
    for temp_idx in range(len(input_ids)) :
      dummy_location_caption = self.search_trigger_idx(input_ids[temp_idx].detach().cpu(), self.trigger_str_img)
      if dummy_location_caption is None :
        raise ValueError(f"Trigger {self.trigger_str_img!r} not found in the tokenized prompt")
      # Bias placeholders start right after BOS (position 0).
      replaced_begin_task = self.split_and_replace(all_token_prompt_embed[temp_idx], bias_token, 1)
      all_replaced_feature.append(self.split_and_replace(replaced_begin_task, img_embed[temp_idx], dummy_location_caption))
    all_replaced_feature = torch.stack(all_replaced_feature)

    attn_mask_01 = self.get_bias_token_attn_mask(attention_mask, bias_token)
    attn_mask_4d = self._to_additive_4d_mask(attn_mask_01, all_replaced_feature.dtype)
    logits_now = self.model_language(inputs_embeds=all_replaced_feature, attention_mask=attn_mask_4d)['logits']

    # The logit at position t predicts token t+1, hence the shift by one.
    caption_logits = logits_now[:,prompt_len-1:-1]
    loss_lm = self._masked_caption_nll(caption_logits, caption_label_now, attn_mask_now)
    caption_prediction_now = torch.softmax(caption_logits,-1)
    return loss_lm, caption_prediction_now

  def generate_aswer_image(self, input_with_dummy_prompt, pil_image = None, max_new_tokens = 32, do_sample=True, top_k=50, top_p=0.95, temperature =1 ) :
    """Generate text for a single prompt, optionally conditioned on one image.

    If the prompt contains ``<start_image>``, the image embedding is injected by a
    temporary forward hook on the LM token-embedding layer. The hook is always
    removed afterwards, even if generation raises.

    NOTE(review): unlike ``forward_loss``, no learnable bias tokens or custom attention
    mask are applied here, so inference does not exactly mirror training — confirm intended.

    Returns:
      The decoded ``generate`` output for the single prompt, special tokens removed.
    Raises:
      ValueError: the prompt references an image but ``pil_image`` is None.
      NotImplementedError: more than one image was supplied.
    """
    device_now = self._model_device()
    dummy_input = self.tokenizer_language(input_with_dummy_prompt,padding=True,return_tensors="pt")
    dummy_input['input_ids'] = dummy_input['input_ids'].to(device_now)
    dummy_input['attention_mask'] = dummy_input['attention_mask'].to(device_now)
    assert len(dummy_input['input_ids']) == 1

    handler_image = None
    if self.trigger_str_img in input_with_dummy_prompt :
      if pil_image is None :
        raise ValueError(f"Prompt contains {self.trigger_str_img!r} but no image was given")
      image_input = self.image_processor(pil_image, return_tensors="pt")['pixel_values'].to(device_now)
      if len(image_input) != 1 :
        raise NotImplementedError("Only a single image per prompt is supported")
      # Inference only: no autograd graph needs to be kept alive by the hook closure.
      with torch.no_grad() :
        hook_now_image = self.replace_embedding_hook(image_input)
      handler_image = self.model_language.model.embed_tokens.register_forward_hook(hook_now_image)

    try :
      output_now = self.model_language.generate(**dummy_input,
                                                max_new_tokens = max_new_tokens,
                                                do_sample=do_sample,
                                                temperature=temperature,
                                                top_k=top_k,
                                                top_p=top_p,
                                                )
    finally :
      if handler_image is not None :
        handler_image.remove()

    return self.tokenizer_language.batch_decode(output_now, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]



# Hyperparameters

In [None]:
CAPTION_PATH = "/content/Flickr8k.token.txt"
IMAGES_FILE_PATH = "/content/Flicker8k_Dataset"   # spelling "Flicker" must match the extracted folder name
SAVED_PATH = "/content/saved_model/adaptor_caption.pt"

BATCH_SIZE = 4
NUM_ITERATION = 2000
SAVE_EVERY = 500          # iterations between adaptor checkpoints
LEARNING_RATE = 1e-4
TRAIN_DATA_NUM = 7500     # intended number of training images
SEED = 42

# Seed every RNG the notebook draws from (numpy for sampling, torch for parameter init)
# so that a Restart & Run All reproduces the same run.
np.random.seed(SEED)
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Utility Functions

In [None]:
def check_model_nan(model) :
  """Return True if any parameter of ``model`` contains a NaN, otherwise False.

  Stops at the first offending parameter and always returns a plain Python bool.
  """
  return any(bool(torch.isnan(param).any()) for param in model.parameters())

In [None]:
def getLabelDictionary(file_path) :
  """Parse a Flickr8k-style caption file into ``{image_file_name: [caption, ...]}``.

  Each line is expected to look like ``<image_name>#<caption_index>\\t<caption>``.
  Lines without a tab are skipped, and the ``#<index>`` suffix is optional.
  Captions are kept in file order.
  """
  label_dictionary = {}
  with open(file_path, encoding="utf-8") as file_now :
    for raw_line in file_now :
      splitted_line = raw_line.rstrip('\n').split('\t')
      if len(splitted_line) > 1 :
        # "#<n>" only numbers the caption; the file name is everything before it.
        file_name_now = splitted_line[0].partition('#')[0]
        label_dictionary.setdefault(file_name_now, []).append(splitted_line[1])
  return label_dictionary

In [None]:
def count_model_param(model_now) :
  """Return the total number of scalar parameters in ``model_now`` as an exact Python int.

  ``numel()`` needs no allocation and stays exact for low-precision (e.g. bfloat16)
  models, where summing a tensor of ones cannot represent large integers exactly.
  """
  return sum(param.numel() for param in model_now.parameters())

In [None]:
def sample_data_caption(file_list_now, caption_dict_now, n, base_path=None) :
  """Draw ``n`` (image, caption) pairs uniformly at random, with replacement.

  Args:
    file_list_now: image file names relative to ``base_path``; each must be a key of ``caption_dict_now``.
    caption_dict_now: mapping file name -> list of captions (as built by ``getLabelDictionary``).
    n: number of pairs to draw.
    base_path: image directory; defaults to the notebook-level ``IMAGES_FILE_PATH``.
  Returns:
    (images, captions): a list of in-memory PIL images and a list of caption strings, each of length ``n``.
  """
  if base_path is None :
    base_path = IMAGES_FILE_PATH
  rand_idx = np.random.randint(0,len(file_list_now), n)

  all_image = []
  all_text = []
  for idx_now in rand_idx :
    file_name = file_list_now[idx_now]
    # Copy into memory inside the context manager so the file handle is released immediately.
    with Image.open(os.path.join(base_path, file_name)) as image_file :
      all_image.append(image_file.copy())

    text_list_now = caption_dict_now[file_name]
    all_text.append(text_list_now[np.random.randint(0,len(text_list_now))])

  return all_image, all_text

# Initialization

In [None]:
label_dictionary = getLabelDictionary(CAPTION_PATH)
# Keep only images that have at least one caption; sampling a caption-less file would raise KeyError.
# Sorting fixes the order (os.listdir order is arbitrary), so slicing and seeded sampling are reproducible.
all_file = sorted(file_name for file_name in os.listdir(IMAGES_FILE_PATH) if file_name in label_dictionary)

In [None]:
# Build the model, then move it to the target device and cast its floating-point weights to bfloat16 in one call.
# NOTE(review): this also casts the trainable adaptor/bias-token weights to bfloat16, where small Adam
# updates can be rounded away — consider keeping the trainable parts in float32.
model = MyModel().to(device=device, dtype=torch.bfloat16)

In [None]:
# Resume the adaptor from a saved checkpoint, if one exists.
# map_location lets a checkpoint written on GPU load on a CPU-only runtime (and vice versa);
# weights_only=True restricts unpickling to tensors/containers, which is all a state_dict needs.
if os.path.exists(SAVED_PATH) :
  model.adaptorListCaption.load_state_dict(torch.load(SAVED_PATH, map_location=device, weights_only=True))

In [None]:
# Train only the new parts: enable gradients everywhere, then switch them off for both pretrained backbones.
model.requires_grad_(True)
model.model_language.requires_grad_(False)
model.model_image.requires_grad_(False)

In [None]:
# Give Adam only the parameters that are actually trained (the frozen backbones have requires_grad=False),
# so its param groups and state_dict reflect exactly what is being optimized.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optim = torch.optim.Adam(trainable_params, LEARNING_RATE)

# Training

In [None]:
model.train()
# The pretrained backbones are frozen; keep them in eval mode so any dropout/normalisation layers behave as at inference.
model.model_language.eval()
model.model_image.eval()

# Sample only from the first TRAIN_DATA_NUM listed images; the remainder is held out.
train_files = all_file[:TRAIN_DATA_NUM]
os.makedirs(os.path.dirname(SAVED_PATH), exist_ok=True)

for itr in range(NUM_ITERATION) :
  rand_image, rand_targets = sample_data_caption(train_files, label_dictionary, BATCH_SIZE)
  loss, _ = model.forward_loss(rand_image, rand_targets)
  optim.zero_grad()
  loss.backward()
  optim.step()
  print("ITERATION:", itr, "/", NUM_ITERATION, "loss:", loss.item())

  # Checkpoint periodically and at the end; the loading cell above restores this adaptor state_dict.
  # NOTE(review): the learnable bias tokens are not included in this checkpoint.
  if (itr + 1) % SAVE_EVERY == 0 or itr + 1 == NUM_ITERATION :
    torch.save(model.adaptorListCaption.state_dict(), SAVED_PATH)
