# language+vision

## import library

In [None]:
! pip install datasets
! pip install peft bitsandbytes accelerate
! pip install trl
! pip install lightning
! pip install transformers



In [None]:
# Mount Google Drive into the Colab VM at /content/drive.
# NOTE(review): no visible cell reads from or writes to /content/drive —
# confirm whether this mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import argparse
import sys
import torch
import torch.nn as nn
from PIL import Image
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    SiglipImageProcessor,
    SiglipVisionModel,
    AutoProcessor,
    TrainingArguments,
    LlavaForConditionalGeneration,
)
from transformers import TextStreamer
from peft import get_peft_model, LoraConfig
from torch.utils.data import Dataset

from transformers.data.data_collator import DataCollatorForLanguageModeling
import lightning as L
from datasets import load_dataset
from torch.utils.data import DataLoader
import re
from nltk import edit_distance
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from io import BytesIO
import requests
import zipfile
import os
from trl import SFTTrainer, SFTConfig

In [None]:
# Central hyper-parameter / model-name configuration; the cells below read it via config.get(...).
# NOTE(review): "num_nodes", "warmup_steps" and "result_path" are not read by any visible cell.
config = {"max_epochs": 2,
          "val_check_interval": 0.5, # fraction of a training epoch between validation runs (forwarded to L.Trainer)
          "check_val_every_n_epoch": 1,
          "gradient_clip_val": 1.0,
          "accumulate_grad_batches": 8,
          "lr": 1e-5, # learning rate used by custom_vlm.configure_optimizers
          "batch_size": 2, # per-step batch size of both dataloaders
          # "seed":2022,
          "num_nodes": 1,
          "warmup_steps": 50,
          "result_path": "./result",
          "verbose": True, # print the first prediction/answer pair of each validation batch
          "max_length": 1024, # per-text-chunk truncation length handed to the datasets
          "except_image_max_length": 512, # default max_length of tokenizer_image_token
          "model_name": "unsloth/llama-3-8b-Instruct",
          "vision_model_name": "google/siglip-so400m-patch14-384",
          "model_embedding_size": 4096, # output width of the projector (LLM hidden size)
          "vision_model_embedding_size": 1152, # input width of the projector (vision hidden size)
}

# "MLP-KTLim/llama-3-Korean-Bllossom-8B"
# "unsloth/llama-3-8b-Instruct"
# "microsoft/Phi-3-mini-4k-instruct"

In [None]:
# !git clone https://huggingface.co/qresearch/llama-3-vision-alpha

## model

In [None]:
def initialize_models():
    """Load the tokenizer, the 4-bit language model, the vision tower and its image processor.

    Model identifiers are read from the module-level ``config`` dict
    (``"model_name"`` and ``"vision_model_name"``).

    Returns:
        tuple: ``(tokenizer, model, vision_model, processor)`` where ``model`` is the
        NF4-quantized ``LlamaForCausalLM`` with all base parameters frozen and
        ``vision_model`` is the ``SiglipVisionModel`` moved to ``"cuda"``.
    """
    llm_name = config.get("model_name")

    # NF4 4-bit quantization.
    # NOTE(review): the kwarg that used to be commented out here was spelled
    # `bnb_4bit_compute_type`; the documented name is `bnb_4bit_compute_dtype`.
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
    )

    tokenizer = AutoTokenizer.from_pretrained(llm_name, use_fast=True)
    tokenizer.padding_side = "right"

    model = LlamaForCausalLM.from_pretrained(
        llm_name,
        device_map="auto",
        quantization_config=quantization,
        attn_implementation="eager",
        output_hidden_states=True,
    )

    # Freeze the language model; adapters are attached in a later cell.
    for weight in model.base_model.parameters():
        weight.requires_grad_(False)

    vision_name = config.get("vision_model_name")
    vision_model = SiglipVisionModel.from_pretrained(vision_name)
    processor = SiglipImageProcessor.from_pretrained(vision_name)

    vision_model = vision_model.to("cuda")

    return tokenizer, model, vision_model, processor

In [None]:
class ProjectionModule(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) mapping vision features to the LLM width.

    Args:
        mm_hidden_size: size of the last dimension of the incoming features.
        hidden_size: size of the last dimension of the projected output.
    """

    def __init__(self, mm_hidden_size, hidden_size):
        super().__init__()

        # The attribute stays named `model` so state_dict keys remain `model.<idx>.*`.
        layers = [
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` along its last dimension."""
        projected = self.model(x)
        return projected

In [None]:
# def load_projection_module(mm_hidden_size=1152, hidden_size=4096, device="cuda"):
#     projection_module = ProjectionModule(mm_hidden_size, hidden_size)
#     checkpoint = torch.load("./mm_projector.bin")
#     # checkpoint = state_dict
#     checkpoint = {k.replace("mm_projector.", ""): v for k, v in checkpoint.items()}
#     projection_module.load_state_dict(checkpoint)
#     projection_module = projection_module.to(device).half()
#     return projection_module

In [None]:
# Load tokenizer, quantized LLM, vision tower and image processor.
tokenizer, model, vision_model, processor = initialize_models()
# Use the Llama-3 end-of-turn marker as EOS; the chat template and CustomDataset2 rely on this string.
tokenizer.eos_token = "<|eot_id|>"

# Freshly (randomly) initialised projector from vision width to LLM width.
projection_module = ProjectionModule(mm_hidden_size=config.get("vision_model_embedding_size"), hidden_size=config.get("model_embedding_size")) #4096, 3072
projection_module = projection_module.to("cuda") #.half()
# state_dict = new_dict()
# projection_module = load_projection_module()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# Attach rank-4 LoRA adapters to the attention projections of the language model.
# NOTE: this cell rebinds `model`, so re-running it wraps the already wrapped model again.
lora_config = LoraConfig(
    r=4,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424


In [None]:
# Attach rank-2 LoRA adapters to the attention projections of the vision tower.
# NOTE: `lora_config` is reused here and now refers to the vision config, not the LLM one.
# NOTE: this cell rebinds `vision_model`, so re-running it wraps the model again.
lora_config = LoraConfig(
    r=2,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj"],
    # task_type="CAUSAL_LM",
)
vision_model = get_peft_model(vision_model, lora_config)
vision_model.print_trainable_parameters()

trainable params: 373,248 || all params: 428,598,848 || trainable%: 0.0871


## dataset

In [None]:
# initial_prompt = {
#         "role": "system",
#         "content": "Present a chat with an assistant and a user with a lot of questions. The assistant understands the intent and purpose of the question and answers it accurately. <image>",
#     }

In [None]:
# def tokenizer_image_token(prompt, tokenizer,max_length=config.get("except_image_max_length"),  image_token_index= 500000):

#     prompt_chunks = prompt.split("<image>")
#     tokenized_chunks = [tokenizer(chunk, truncation = True, padding = True,max_length=max_length).input_ids for chunk in prompt_chunks]
#     input_ids = tokenized_chunks[0]

#     for chunk in tokenized_chunks[1:]:
#         input_ids.append(image_token_index)
#         input_ids.extend(chunk[1:])  # Exclude BOS token on nonzero index

#     return torch.tensor(input_ids, dtype=torch.long)

In [None]:
# # for cmarkea/table-vqa
# class CustomDataset(Dataset):
#     def __init__(self,dataset, tokenizer, processor,split = "train" ,max_length=2048):
#         super().__init__()
#         self.tokenizer = tokenizer
#         self.processor = processor
#         self.max_length = max_length
#         self.split = split
#         self.dataset = dataset
#         self.dataset_length = len(self.dataset)

#     def __len__(self) -> int:
#         return self.dataset_length

#     def __getitem__(self, idx):

#         batch = self.dataset[idx]
#         texts = [x for x in batch['qa']['en']]
#         imgs = batch['image']

#         conversation = []
#         conversation.append(initial_prompt)
#         for _, qa in enumerate(texts):
#             conversation.append({"role": "user", "content": qa["question"]})
#             conversation.append({"role": "assistant", "content": qa["answer"]})
#         text = self.tokenizer.apply_chat_template(conversation,tokenize=False)

#         input_id = tokenizer_image_token(text, self.tokenizer, max_length=self.max_length).unsqueeze(0)

#         image = imgs.convert("RGB")
#         image_inputs = self.processor(
#             images=image, # [image],
#             return_tensors="pt",
#             do_resize=True,
#             size={"height": 384, "width": 384},
#         )
#         pixel_values = image_inputs["pixel_values"]

#         result = {'input_ids': input_id, 'pixel_values': pixel_values , 'texts': text}
#         return result

In [None]:
# # for cmarkea/table-vqa
# class CustomDataset2(Dataset):
#     def __init__(self,dataset, tokenizer, processor,split = "train" ,max_length=2048):
#         super().__init__()
#         self.tokenizer = tokenizer
#         self.processor = processor
#         self.max_length = max_length
#         self.split = split
#         self.dataset = dataset
#         self.dataset_length = len(self.dataset)


#     def __len__(self) -> int:
#         return self.dataset_length

#     def __getitem__(self, idx):

#         batch = self.dataset[idx]
#         texts = [x for x in batch['qa']['en']]
#         imgs = batch['image']

#         questions = []
#         answers = []

#         for _, qa in enumerate(texts):
#             question = []
#             question.append(initial_prompt)
#             question.append({"role": "user", "content": qa["question"]})
#             questions.append(question)
#             answers.append(qa["answer"])
#         text = self.tokenizer.apply_chat_template(questions,tokenize=False)


#         new_text = []
#         for t in text:
#           prompt_chunks = t.split("<image>")
#           tokenized_chunks = [tokenizer(chunk, truncation = True, padding = True,max_length=self.max_length).input_ids for chunk in prompt_chunks]
#           input_ids = tokenized_chunks[0]

#           for chunk in tokenized_chunks[1:]:
#               input_ids.append(-200)
#               input_ids.extend(chunk[1:])  # Exclude BOS token on nonzero index

#           input_id =  torch.tensor(input_ids, dtype=torch.long)

#           new_text.append(input_id)

#         input_id = torch.nn.utils.rnn.pad_sequence(new_text, batch_first=True, padding_value=self.tokenizer.pad_token_id)


#         image = imgs.convert("RGB")
#         image_inputs = self.processor(
#             images=image, # [image],
#             return_tensors="pt",
#             do_resize=True,
#             size={"height": 384, "width": 384},
#         )
#         pixel_values = image_inputs["pixel_values"]

#         new_pix = []
#         for _ in range(len(text)):
#           new_pix.append(pixel_values)

#         pixel_values = torch.cat(new_pix, dim = 0)

#         result = {'input_ids': input_id, 'pixel_values': pixel_values , 'answers': answers,  'texts': text}
#         return result

In [None]:
# # for cmarkea/table-vqa
# class DataCollatorForCustomVLM(DataCollatorForLanguageModeling):
#     def __init__(self, tokenizer, mlm=False):
#         super().__init__(tokenizer, mlm)

#     def __call__(self, batch):
#         input_ids = [item['input_ids'].squeeze(0) for item in batch]
#         pixel_values = [item['pixel_values'] for item in batch]

#         input_ids_padded = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

#         pixel_values = torch.cat(pixel_values, dim = 0)

#         labels = input_ids_padded.clone()
#         if self.tokenizer.pad_token_id is not None:
#             labels[labels == self.tokenizer.pad_token_id] = -100

#         return {
#             'input_ids': input_ids_padded,
#             'pixel_values': pixel_values,
#             'labels': labels
#         }

# data_collator = DataCollatorForCustomVLM(tokenizer=tokenizer, mlm=False)

In [None]:
# # for cmarkea/table-vqa
# class DataCollatorForCustomVLM2:
#     def __init__(self):
#       pass

#     def __call__(self, batch):

#         return batch

# data_collator2 = DataCollatorForCustomVLM2()

In [None]:
# # raw_datasets = load_dataset("liuhaotian/LLaVA-Instruct-150K")
# raw_datasets = load_dataset("cmarkea/table-vqa")
# train = raw_datasets["train"]
# valid = raw_datasets["test"]

In [None]:
# train_dataset = CustomDataset(train, tokenizer, processor,split = "train" ,max_length=config.get("max_length"))
# val_dataset = CustomDataset2(valid, tokenizer, processor,split = "test" ,max_length=config.get("max_length"))

In [None]:
# train[0]

In [None]:
# print(train_dataset[0]['texts'])

## dataset2

In [None]:
def tokenizer_image_token(prompt, tokenizer,max_length=config.get("except_image_max_length"),  image_token_index= 500000):
    """Tokenize ``prompt`` and mark every ``<image>`` placeholder with a sentinel id.

    The prompt is split on the literal ``"<image>"``; each piece is tokenized
    separately (truncated to ``max_length`` per piece) and the pieces are joined
    with ``image_token_index`` in between. The first id of every piece after the
    first one is dropped (the original author's intent: skip a repeated BOS).

    Returns:
        tuple: ``(input_ids, attention_mask)`` — two 1-D ``torch.long`` tensors of
        equal length; the mask is all ones.
    """
    encoded_segments = [
        tokenizer(segment, truncation = True, padding = True,max_length=max_length).input_ids
        for segment in prompt.split("<image>")
    ]

    head, *tail = encoded_segments
    token_ids = head
    for segment_ids in tail:
        # sentinel for the image, then the segment without its leading token
        token_ids += [image_token_index, *segment_ids[1:]]

    ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    mask_tensor = torch.ones(len(token_ids), dtype=torch.long)
    return ids_tensor, mask_tensor

In [None]:
# Jinja chat template for LLaVA-style records whose messages carry 'from' / 'value' keys.
#   * 'human' messages are rendered as "USER: <value>", everything else as "ASSISTANT: <value>".
#   * `eos_token` is emitted after every message that is NOT from 'gpt', i.e. after the
#     user turn and not after the assistant answer. CustomDataset2 relies on this when it
#     splits the rendered text on "<|eot_id|>" into question and answer.
# NOTE(review): because no EOS follows the assistant turn, training text never ends with
# an end-of-turn token — confirm this is intended.
LLAVA_CHAT_TEMPLATE = """{% for message in messages %} \
  {% if message['from'] == 'human' %}
    USER: {{ message['value'] }} \
  {% else %}
    ASSISTANT: {{ message['value'] }} \
  {% endif %} \
  {% if message['from'] == 'gpt' %} \
  {% else %} \
      {{ eos_token }} \
  {% endif %} \
{% endfor %}"""

tokenizer.chat_template = LLAVA_CHAT_TEMPLATE

In [None]:
class CustomDataset(Dataset):
    """Training dataset: full conversation text plus preprocessed image.

    Each record of ``dataset`` must provide ``'conversations'`` (fed to the
    tokenizer's chat template) and ``'image'`` (file name relative to ``image_dir``).

    Args:
        dataset: indexable collection of records.
        tokenizer: tokenizer with a chat template configured.
        processor: image processor called with ``images=...`` and returning ``"pixel_values"``.
        split: free-form label stored on the instance.
        max_length: per-chunk truncation length forwarded to ``tokenizer_image_token``.
        image_dir: directory that contains the image files (defaults to the
            previously hard-coded ``/content/extracted_images/``).
    """

    def __init__(self,dataset,tokenizer, processor,split = "train" ,max_length=2048, image_dir="/content/extracted_images/"):
        super().__init__()
        self.tokenizer = tokenizer
        self.processor = processor
        self.max_length = max_length
        self.split = split
        self.dataset = dataset
        self.dataset_length = len(self.dataset)
        self.image_dir = image_dir

    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, idx):
        """Return ``input_ids``/``attention_mask`` (each with a leading batch dim of 1),
        ``pixel_values`` and the rendered ``texts`` for record ``idx``."""
        sample = self.dataset[idx]
        text = self.tokenizer.apply_chat_template(sample['conversations'],tokenize=False)

        # Tokenize once and unpack both outputs (previously the text was tokenized twice).
        input_id, attention_mask = tokenizer_image_token(text, self.tokenizer, max_length=self.max_length)

        image_path = os.path.join(self.image_dir, sample['image'])
        # convert() loads the pixel data and returns a new image, so the file can be closed right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        image_inputs = self.processor(
            images=image,
            return_tensors="pt",
            do_resize=True,
            size={"height": 384, "width": 384},
        )

        return {
            'input_ids': input_id.unsqueeze(0),
            'attention_mask': attention_mask.unsqueeze(0),
            'pixel_values': image_inputs["pixel_values"],
            'texts': text,
        }

In [None]:
class CustomDataset2(Dataset):
    """Validation dataset: question prompt (tokenized) plus reference answer (text).

    The rendered conversation is split on ``"<|eot_id|>"``: the part before it,
    with ``" ASSISTANT:"`` appended, becomes the prompt; the part after it, with
    ``"ASSISTANT:"`` removed, becomes the reference answer.

    Args:
        dataset: indexable collection of records with ``'conversations'`` and ``'image'``.
        tokenizer: tokenizer with a chat template configured.
        processor: image processor called with ``images=...`` and returning ``"pixel_values"``.
        split: free-form label stored on the instance.
        max_length: per-chunk truncation length forwarded to ``tokenizer_image_token``.
        image_dir: directory that contains the image files (defaults to the
            previously hard-coded ``/content/extracted_images/``).
    """

    def __init__(self,dataset,tokenizer, processor,split = "train" ,max_length=2048, image_dir="/content/extracted_images/"):
        super().__init__()
        self.tokenizer = tokenizer
        self.processor = processor
        self.max_length = max_length
        self.split = split
        self.dataset = dataset
        self.dataset_length = len(self.dataset)
        self.image_dir = image_dir

    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, idx):
        """Return prompt tensors (leading batch dim of 1), ``pixel_values``,
        the prompt string under ``'questions'`` and the reference under ``'answers'``."""
        sample = self.dataset[idx]
        text = self.tokenizer.apply_chat_template(sample['conversations'],tokenize=False)

        # Split once; index 0 is the user turn, index 1 the assistant turn.
        turns = text.split("<|eot_id|>")
        question = turns[0] + " ASSISTANT:"
        answer = turns[1].replace("ASSISTANT:", "")

        # Tokenize once and unpack both outputs (previously the prompt was tokenized twice).
        input_id, attention_mask = tokenizer_image_token(question, self.tokenizer, max_length=self.max_length)

        image_path = os.path.join(self.image_dir, sample['image'])
        # convert() loads the pixel data and returns a new image, so the file can be closed right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        image_inputs = self.processor(
            images=image,
            return_tensors="pt",
            do_resize=True,
            size={"height": 384, "width": 384},
        )

        return {
            'input_ids': input_id.unsqueeze(0),
            'attention_mask': attention_mask.unsqueeze(0),
            'pixel_values': image_inputs["pixel_values"],
            'questions': question,
            'answers': answer,
        }

In [None]:
# raw_datasets = load_dataset("lmms-lab/LLaVA-OneVision-Data", 'TabMWP(MathV360K)')
# Load only the conversation metadata (chat.json) of the LLaVA pre-training set;
# the image files themselves are fetched and unpacked by the cells below.
raw_datasets = load_dataset("liuhaotian/LLaVA-CC3M-Pretrain-595K",data_files = "chat.json")


In [None]:
# Clone the dataset repository to obtain images.zip (git reports a fatal error if the target directory already exists).
!git clone https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K

fatal: destination path 'LLaVA-CC3M-Pretrain-595K' already exists and is not an empty directory.


In [None]:
# Unpack the image archive of the cloned dataset repository.
zip_file_path = "/content/LLaVA-CC3M-Pretrain-595K/images.zip"
# Directory the archive is extracted into.
# NOTE(review): this is a relative path while the datasets read from the absolute
# "/content/extracted_images/" — the two only coincide when the working directory is /content.
extract_to = "extracted_images"

with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)

In [None]:
# The hub dataset only exposes a "train" split.
train = raw_datasets["train"]

# Hold out 20% of the records for validation; the fixed seed keeps the split reproducible.
train_valid_split = train.train_test_split(test_size=0.2, shuffle=True, seed=42)

# Access the train and validation sets
train_split = train_valid_split['train']
valid_split = train_valid_split['test']

In [None]:
# Sanity check: the 'image' field holds a bare file name, not a path.
train_split[0]['image']

'GCC_train_002140770.jpg'

In [None]:
# Wrap the two splits: CustomDataset yields full conversations for training,
# CustomDataset2 yields prompt / reference-answer pairs for validation.
# NOTE(review): the validation dataset is also labelled split="train"; the label is
# only stored on the instance in the visible code.
train_dataset = CustomDataset(train_split,tokenizer, processor,split = "train" ,max_length=config.get("max_length"))
val_dataset = CustomDataset2(valid_split,tokenizer, processor,split = "train" ,max_length=config.get("max_length"))

In [None]:
# Sanity check: per-sample attention mask has shape (1, sequence_length).
print(train_dataset[1]['attention_mask'].shape)

torch.Size([1, 33])


In [None]:
class DataCollatorForCustomVLM(DataCollatorForLanguageModeling):
    """Collate training samples into right-padded tensors.

    Each item must provide ``'input_ids'`` and ``'attention_mask'`` with a leading
    dimension of 1, and ``'pixel_values'`` whose first dimension is concatenated
    across the batch.
    """

    def __init__(self, tokenizer, mlm=False):
        super().__init__(tokenizer, mlm)

    def __call__(self, batch):
        """Return a dict with ``input_ids``, ``attention_mask``, ``pixel_values`` and ``labels``."""
        input_ids = [item['input_ids'].squeeze(0) for item in batch]
        attention_mask = [item['attention_mask'].squeeze(0) for item in batch]
        pixel_values = [item['pixel_values'] for item in batch]

        # Token ids are padded with the pad id (0 if the tokenizer defines none;
        # the value is irrelevant because those positions are masked out below).
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0
        input_ids_padded = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_id)

        # The attention mask must be padded with 0. Padding it with the pad token id
        # (a large non-zero value) marks padding positions as attended.
        attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)

        pixel_values = torch.cat(pixel_values, dim = 0)

        # Ignore padding in the loss. Masking by attention mask (rather than by pad id)
        # keeps real tokens that happen to share the pad id.
        labels = input_ids_padded.clone()
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids_padded,
            'attention_mask':attention_mask,
            'pixel_values': pixel_values,
            'labels': labels
        }

class DataCollatorForCustomVLM2(DataCollatorForCustomVLM):
    """Validation collator: same tensors as ``DataCollatorForCustomVLM`` plus the
    list of reference answers under ``'answers'``."""

    def __init__(self, tokenizer, mlm=False):
        super().__init__(tokenizer, mlm)

    def __call__(self, batch):
        """Return the training collation extended with ``answers`` (list of str)."""
        collated = super().__call__(batch)
        collated['answers'] = [item['answers'] for item in batch]
        return collated

data_collator = DataCollatorForCustomVLM(tokenizer=tokenizer, mlm=False)
data_collator2 = DataCollatorForCustomVLM2(tokenizer=tokenizer, mlm=False)

## custom vision language model

In [None]:
def process_tensors(input_ids, attention_mask, image_features, embedding_layer, image_token_index=500000):
    """Replace the image sentinel in each sequence with that sample's image embeddings.

    For every row, the single position holding ``image_token_index`` is removed and
    the row's image features are spliced in at that place; the attention mask is
    expanded with ones over the inserted span.

    Args:
        input_ids: 2-D integer tensor (batch, seq) containing exactly one
            ``image_token_index`` per row.
        attention_mask: 2-D tensor aligned with ``input_ids``.
        image_features: 3-D tensor (batch, patches, dim) of projected image embeddings.
        embedding_layer: callable mapping token ids to embeddings of width ``dim``.
        image_token_index: sentinel id marking the image position.

    Returns:
        tuple: ``(embeddings, attention_mask)`` with shapes
        (batch, seq - 1 + patches, dim) and (batch, seq - 1 + patches). Both live on
        the device of ``image_features``.

    Raises:
        ValueError: if a row does not contain exactly one image sentinel.
    """
    merged_embeds = []
    merged_masks = []

    for row in range(input_ids.shape[0]):
        ids = input_ids[row].unsqueeze(0)
        mask = attention_mask[row].unsqueeze(0)
        image = image_features[row].unsqueeze(0)
        device = image.device

        positions = (ids == image_token_index).nonzero(as_tuple=True)[1]
        if positions.numel() != 1:
            raise ValueError(
                f"expected exactly one image token ({image_token_index}) in sample {row}, "
                f"found {positions.numel()}"
            )
        # Plain int so slicing is unambiguous (and consistent with process_labels).
        split_index = int(positions[0])

        # Embed the text on either side of the sentinel; the sentinel itself is never embedded.
        embeds_before = embedding_layer(ids[:, :split_index]).to(device)
        embeds_after = embedding_layer(ids[:, split_index + 1 :]).to(device)
        merged_embeds.append(torch.cat([embeds_before, image, embeds_after], dim=1))

        # Image positions are always attended.
        image_mask = torch.ones(image.shape[:2], dtype=torch.long, device=device)
        merged_masks.append(
            torch.cat(
                [mask[:, :split_index].to(device), image_mask, mask[:, split_index + 1 :].to(device)],
                dim=1,
            )
        )

    concatenated_embeddings = torch.cat(merged_embeds, dim=0)
    cat_attention_masks = torch.cat(merged_masks, dim=0)
    return concatenated_embeddings , cat_attention_masks

def process_labels(input_ids, image_features, image_token_index=500000):
    """Build loss labels aligned with the sequence produced by ``process_tensors``.

    In every row the first occurrence of ``image_token_index`` is replaced by a run
    of ``-100`` (the ignore index) whose length equals the number of image
    embeddings of that sample, so image positions never contribute to the loss.

    Args:
        input_ids: 2-D integer tensor (batch, seq) containing the image sentinel.
        image_features: 3-D tensor (batch, patches, dim); only its shape and device are used.
        image_token_index: sentinel id marking the image position.

    Returns:
        torch.Tensor: integer tensor of shape (batch, seq - 1 + patches) on the
        device of ``image_features``.

    Raises:
        ValueError: if a row contains no image sentinel.
    """
    expanded_rows = []

    for row in range(input_ids.shape[0]):
        ids = input_ids[row].unsqueeze(0)
        image = image_features[row].unsqueeze(0)
        device = image.device

        positions = (ids == image_token_index).nonzero(as_tuple=True)[1]
        if positions.numel() == 0:
            raise ValueError(f"no image token ({image_token_index}) found in sample {row}")
        split_index = int(positions[0])

        # One ignore label per image embedding.
        ignore_span = torch.full(image.shape[:2], -100, dtype=torch.long, device=device)

        expanded_rows.append(
            torch.cat(
                [ids[:, :split_index].to(device), ignore_span, ids[:, split_index + 1 :].to(device)],
                dim=1,
            )
        )

    return torch.cat(expanded_rows, dim=0)

In [None]:
class custom_vlm(L.LightningModule):
    """LightningModule tying together the vision tower, the projector and the language model.

    Training minimises next-token cross-entropy on the conversation text; validation
    generates an answer from the prompt and reports a normalised edit distance
    against the reference answer.

    Args:
        config: dict read with ``.get`` for ``"batch_size"``, ``"lr"`` and ``"verbose"``.
        vision_model: image encoder called with pixel values and ``output_hidden_states=True``.
        model: causal language model accepting ``inputs_embeds`` / ``attention_mask``.
        projection_module: maps vision hidden states to the language-model embedding width.
        tokenizer: tokenizer used to decode generated ids.
    """

    def __init__(self, config, vision_model, model, projection_module, tokenizer):
        super().__init__()
        self.config = config
        self.model = model
        self.vision_model = vision_model
        self.tokenizer = tokenizer
        self.projection_module = projection_module
        self.batch_size = config.get("batch_size")
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
        # self.automatic_optimization = False  # optional: switch to manual optimization

    def _set_training(self, training):
        """Put all three sub-modules into train (True) or eval (False) mode."""
        for module in (self.model, self.vision_model, self.projection_module):
            module.train(training)

    def _build_multimodal_inputs(self, batch):
        """Encode the images and splice their projected embeddings into the text embeddings.

        Returns:
            tuple: ``(inputs_embeds, attention_mask, projected_embeddings)``, all on ``self.device``.
        """
        input_ids = batch['input_ids'].to(self.device)  # integer token ids incl. image sentinel
        pixel_values = batch['pixel_values'].to(self.device)
        attention_mask = batch['attention_mask']

        image_forward_outs = self.vision_model(
            pixel_values.to(device=self.device,dtype=torch.float16),
            output_hidden_states=True,
        )

        # Penultimate hidden state of the vision tower is used as image representation.
        image_features = image_forward_outs.hidden_states[-2]
        projected_embeddings = self.projection_module(image_features).to(self.device)

        embedding_layer = self.model.get_input_embeddings()

        # Works on whole batches.
        new_embeds, attn_mask = process_tensors(
            input_ids, attention_mask, projected_embeddings, embedding_layer
        )
        return new_embeds.to(self.device), attn_mask.to(self.device), projected_embeddings

    def on_train_start(self):
        self._set_training(True)

    def training_step(self, batch, batch_idx):
        """Compute the next-token cross-entropy loss for one batch."""
        self._set_training(True)
        # opt = self.optimizers()  # only needed with manual optimization

        new_embeds, attn_mask, projected_embeddings = self._build_multimodal_inputs(batch)

        # Prefer the collator's labels (padding already set to -100) so padding is
        # excluded from the loss; fall back to the raw ids if they are absent.
        label_source = batch['labels'] if 'labels' in batch else batch['input_ids']
        labels = process_labels(label_source.to(self.device), projected_embeddings).to(self.device)

        outputs = self.model(inputs_embeds=new_embeds, attention_mask=attn_mask)
        logits = outputs.logits

        # Causal LM objective: the logits at position t predict the token at t + 1,
        # so drop the last logit and the first label before comparing.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        loss = self.loss_fn(
            shift_logits.view(-1, shift_logits.size(-1)),  # [batch * (seq - 1), vocab]
            shift_labels.view(-1),                         # [batch * (seq - 1)]
        )
        # opt.zero_grad()            # manual optimization only
        # self.manual_backward(loss) # manual optimization only
        # opt.step()                 # manual optimization only

        self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        """Generate answers for one batch and log the mean normalised edit distance."""
        self._set_training(False)

        answers = batch['answers']
        new_embeds, attn_mask, _ = self._build_multimodal_inputs(batch)

        # autoregressively generate token IDs
        generated_ids = self.model.generate(inputs_embeds=new_embeds.to(dtype=torch.float16), attention_mask=attn_mask, max_new_tokens=128)

        # When only `inputs_embeds` is given, generate() returns just the newly generated
        # tokens, so the prompt must NOT be sliced off here (doing so removed the start
        # of every prediction). Special tokens are skipped while decoding.
        predictions = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        scores = []
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred).strip()
            # The reference carries template whitespace; compare trimmed strings.
            answer = answer.strip()
            longest = max(len(pred), len(answer))
            # Two empty strings are identical: distance 0 instead of a division by zero.
            scores.append(edit_distance(pred, answer) / longest if longest else 0.0)

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        self.log("val_edit_distance", np.mean(scores), on_step=True, on_epoch=False, prog_bar=True, logger=True)

        return scores

    def configure_optimizers(self):
        # you could also add a learning rate scheduler if you want
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.get("lr"))

        return optimizer

    def train_dataloader(self):
        return DataLoader(train_dataset, collate_fn=data_collator, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        # Not shuffled: with a limited number of validation batches every validation
        # run then scores the same samples, which keeps the metric comparable.
        return DataLoader(val_dataset, collate_fn=data_collator2, batch_size=self.batch_size, shuffle=False, num_workers=4)

In [None]:
# Bundle the three networks and the tokenizer into the LightningModule used for training.
model_module = custom_vlm(config, vision_model, model, projection_module, tokenizer)

## training

In [None]:
# Inspect the input embedding layers: a token Embedding for the LLM, a patch Conv2d for the vision tower.
print(model.get_input_embeddings())
print(vision_model.get_input_embeddings())

Embedding(128256, 4096, padding_idx=128255)
Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)


In [None]:
# Number of trainable parameters in the projector alone.
sum(p.numel() for p in projection_module.parameters() if p.requires_grad == True)

21504000

In [None]:
# Number of trainable parameters across the whole LightningModule (LLM adapters, vision adapters, projector).
total_params = sum(p.numel() for p in model_module.parameters() if p.requires_grad == True)
print(f"Total trainable parameters: {total_params}")

Total trainable parameters: 25285120


In [None]:
# Freeze every Linear layer of the projector (weights and biases).
# NOTE(review): the projector is constructed with fresh random weights in this notebook
# (checkpoint loading is commented out), so after this cell it stays at its random
# initialisation during training — confirm this is intended.
for name, module in projection_module.named_modules():
    if isinstance(module,torch.nn.Linear):
        module.weight.requires_grad = False
        module.bias.requires_grad = False

In [None]:
from lightning.pytorch.loggers import WandbLogger

# The Weights & Biases logger is left disabled; no logger is passed to the Trainer below.
# wandb_logger = WandbLogger(project=WANDB_PROJECT, name=WANDB_NAME)

trainer = L.Trainer(
    # --- hardware / numerics ---
    accelerator="gpu",
    devices=[0],
    precision="16-mixed",  # 16
    # --- training schedule ---
    max_epochs=config.get("max_epochs"),
    accumulate_grad_batches=config.get("accumulate_grad_batches"),  # remove when automatic optimization is disabled
    # --- gradient clipping ---
    gradient_clip_val=config.get("gradient_clip_val"),  # remove when automatic optimization is disabled
    gradient_clip_algorithm="norm",  # selects the gradient clipping algorithm
    # --- validation cadence ---
    check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
    val_check_interval=config.get("val_check_interval"),  # % of an epoch
    limit_val_batches=10,
    num_sanity_val_steps=0,
)

# The module supplies its own train/val dataloaders, so only the module is passed here.
trainer.fit(model_module)

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:lightning.pytorch.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cor

Training: |          | 0/? [00:00<?, ?it/s]

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                           

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: ::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             battle

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                                                         
    A

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::ıldığındaЎыџNıldığında
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:  

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                           

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                           

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                           

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                           

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                           

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                           

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                           

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                           

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                           

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 1.0
Prediction:                                                                                                         
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.7307692307692307
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 1.0
Prediction: :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    Answer:             a student gets sleepy during state school .          
 Normed ED: 1.0
Prediction:                                                                           

INFO: 
Detected KeyboardInterrupt, attempting graceful shutdown ...
INFO:lightning.pytorch.utilities.rank_zero:
Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [None]:
# Persist only the projection module's weights to Google Drive.
projection_ckpt_path = "/content/drive/MyDrive/model/llama3_multi.bin"
# torch.save does not create missing parent folders, so make sure the target
# directory exists before writing; otherwise the save fails after training.
os.makedirs(os.path.dirname(projection_ckpt_path), exist_ok=True)
torch.save(model_module.projection_module.state_dict(), projection_ckpt_path)

In [None]:
# Ask the wrapper to write itself into the Drive model folder.
# NOTE(review): model_module is a custom_vlm instance whose visible methods are
# Lightning hooks (configure_optimizers, train_dataloader, val_dataloader);
# confirm the class really provides save_pretrained, otherwise this call raises
# AttributeError.
model_module.save_pretrained("/content/drive/MyDrive/model/")