# Language + Vision

## Import libraries

In [None]:
# Colab setup: install the third-party packages imported below (versions are not pinned).
! pip install datasets
! pip install peft bitsandbytes accelerate
! pip install trl
! pip install lightning
! pip install transformers

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.

In [None]:
# Download the LLaVA CC3M 595K pretraining repo (images.zip and chat.json are used below).
!git clone https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K

In [None]:
import os
import zipfile

zip_file_path = "/content/LLaVA-CC3M-Pretrain-595K/images.zip"
# Directory to extract the images into
extract_to ="/content/LLaVA-CC3M-Pretrain-595K/extracted_images"
# Marker written only after a complete extraction. Re-running this cell then skips the
# slow unzip, while an interrupted extraction (no marker yet) is redone.
extraction_marker = os.path.join(extract_to, ".extraction_complete")

if not os.path.exists(extraction_marker):
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    with open(extraction_marker, "w"):
        pass

In [None]:
import torch
# Trade float32 matmul precision for speed ("medium" allows reduced-precision internal kernels).
torch.set_float32_matmul_precision('medium')

In [None]:
# Mount Google Drive; checkpoints and the trained projector are stored under /content/drive/MyDrive/model.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import argparse
import sys
import torch
import torch.nn as nn
from PIL import Image
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    SiglipImageProcessor,
    SiglipVisionModel,
    AutoProcessor,
    TrainingArguments,
    LlavaForConditionalGeneration,
)
from transformers import TextStreamer
from peft import get_peft_model, LoraConfig
from torch.utils.data import Dataset

from transformers.data.data_collator import DataCollatorForLanguageModeling
import lightning as L
from datasets import load_dataset
from torch.utils.data import DataLoader
import re
from nltk import edit_distance
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from io import BytesIO
import requests
import zipfile
import os
from trl import SFTTrainer, SFTConfig

In [None]:
# Hyper-parameters and model names read throughout the notebook via config.get(...).
config = {"max_epochs": 1,
          "val_check_interval": 0.05, # fraction of a training epoch between validation runs (0.05 -> every 5% of the epoch)
          "check_val_every_n_epoch": 1,
          "gradient_clip_val": 1.0,
          "accumulate_grad_batches": 8,
          "lr": 1e-5,
          "batch_size": 2,
          # "seed":24,
          "num_nodes": 1,  # NOTE(review): not read by any visible cell
          "warmup_steps": 50,  # NOTE(review): not read by any visible cell (no LR scheduler is configured)
          "result_path": "./result",  # NOTE(review): not read by any visible cell
          "verbose": True,  # print the first prediction/answer pair of every validation batch
          "except_image_max_length": 64,  # token limit for each text chunk around the <image> tag
          "model_name": "unsloth/llama-3-8b-Instruct",
          "vision_model_name": "google/siglip-so400m-patch14-384",
          "model_embedding_size": 4096,  # projection output size; must equal the LLM hidden size
          "vision_model_embedding_size": 1152,  # projection input size; must equal the vision encoder hidden size
          "image_size" : 384  # images are resized to image_size x image_size
}

# Other candidate values for "model_name":
# "MLP-KTLim/llama-3-Korean-Bllossom-8B"
# "unsloth/llama-3-8b-Instruct"
# "microsoft/Phi-3-mini-4k-instruct"

In [None]:
# !git clone https://huggingface.co/qresearch/llama-3-vision-alpha

## model

In [None]:
def initialize_models():
    """Load the tokenizer, the frozen 4-bit LLM, the SigLIP vision encoder and its image processor.

    Model names are read from the global ``config`` dict (``"model_name"`` and
    ``"vision_model_name"``).

    Returns:
        tuple: ``(tokenizer, model, vision_model, processor)``. ``model`` is the 4-bit
        ``LlamaForCausalLM`` with its base-model parameters frozen. ``vision_model`` is the
        SigLIP vision tower moved to CUDA.
    """
    llm = config.get("model_name")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        # The option is `bnb_4bit_compute_dtype`. The previous `bnb_4bit_compute_type` is not
        # a recognised argument, so the requested bf16 compute dtype was silently ignored.
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(llm, use_fast=True)
    tokenizer.padding_side = "right"

    model = LlamaForCausalLM.from_pretrained(
        llm,
        device_map="auto",
        quantization_config=bnb_config,
        attn_implementation="eager",
        output_hidden_states=True,
        # validation_step indexes generate(...)["sequences"], which needs dict-style output.
        return_dict_in_generate=True,
    )

    # The LLM stays frozen; only the projection module is optimized later.
    for param in model.base_model.parameters():
        param.requires_grad = False

    vision_model_name = config.get("vision_model_name")
    vision_model = SiglipVisionModel.from_pretrained(vision_model_name)
    processor = SiglipImageProcessor.from_pretrained(vision_model_name)
    vision_model = vision_model.to("cuda")

    return tokenizer, model, vision_model, processor

In [None]:
class ProjectionModule(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) mapping vision features into the LLM embedding space.

    Args:
        mm_hidden_size: size of the last dimension of the incoming vision features.
        hidden_size: size of the last dimension of the output (the LLM embedding size).

    The layers are stored in ``self.model`` as an ``nn.Sequential``, so state-dict keys are
    ``model.0.*`` (first Linear) and ``model.2.*`` (second Linear).
    """

    def __init__(self, mm_hidden_size, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        # The attribute name `model` fixes the saved state-dict keys; do not rename it.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape (..., mm_hidden_size) to (..., hidden_size)."""
        projected = self.model(x)
        return projected

In [None]:
# def load_projection_module(mm_hidden_size=1152, hidden_size=4096, device="cuda"):
#     projection_module = ProjectionModule(mm_hidden_size, hidden_size)
#     checkpoint = torch.load("./mm_projector.bin")
#     # checkpoint = state_dict
#     checkpoint = {k.replace("mm_projector.", ""): v for k, v in checkpoint.items()}
#     projection_module.load_state_dict(checkpoint)
#     projection_module = projection_module.to(device).half()
#     return projection_module

In [None]:
# Load the tokenizer, frozen 4-bit LLM, vision encoder and image processor.
tokenizer, model, vision_model, processor = initialize_models()
# CustomDataset's validation branch splits the rendered chat text on this token.
tokenizer.eos_token = "<|eot_id|>"
# Trainable MLP: vision feature size -> LLM embedding size (both taken from config).
projection_module = ProjectionModule(mm_hidden_size=config.get("vision_model_embedding_size"), hidden_size=config.get("model_embedding_size")) #4096, 3072
projection_module = projection_module.to("cuda") #.half()
# state_dict = new_dict()
# projection_module = load_projection_module()

## dataset

In [None]:
def tokenizer_image_token(prompt, tokenizer, max_length=None, image_token_index=500000):
    """Tokenize ``prompt`` and replace every ``<image>`` tag with a single sentinel id.

    The prompt is split on ``"<image>"``. Each piece is tokenized separately and truncated
    to ``max_length`` tokens. The pieces are then re-joined with ``image_token_index``
    between them. The sentinel is later located and replaced by projected image features,
    so it must not be a real vocabulary id. A prompt without ``<image>`` yields no sentinel.

    Args:
        prompt: text possibly containing ``<image>`` tags.
        tokenizer: callable tokenizer returning an object with ``.input_ids``; its
            ``bos_token_id`` (if any) is stripped from every chunk after the first.
        max_length: per-chunk token limit. ``None`` means
            ``config["except_image_max_length"]``, read at call time so later edits to
            ``config`` take effect.
        image_token_index: sentinel id inserted in place of each ``<image>`` tag.

    Returns:
        tuple: ``(input_ids, attention_mask)``, two 1-D ``torch.long`` tensors of equal
        length. The mask is all ones.
    """
    if max_length is None:
        max_length = config.get("except_image_max_length")

    bos_id = getattr(tokenizer, "bos_token_id", None)
    input_ids = []
    for chunk_index, chunk in enumerate(prompt.split("<image>")):
        chunk_ids = list(tokenizer(chunk, truncation=True, max_length=max_length).input_ids)
        if chunk_index > 0:
            input_ids.append(image_token_index)
            # Continuation chunks must not carry their own BOS. Drop it only when it is
            # actually present, because not every tokenizer prepends one.
            if chunk_ids and bos_id is not None and chunk_ids[0] == bos_id:
                chunk_ids = chunk_ids[1:]
        input_ids.extend(chunk_ids)

    attention_mask = torch.ones(len(input_ids), dtype=torch.long)
    return torch.tensor(input_ids, dtype=torch.long), attention_mask

In [None]:
# Jinja chat template for LLaVA-style records with 'from'/'value' keys:
#   'human' turns render as "USER: <value>", every other turn as "ASSISTANT: <value>".
# eos_token is emitted after every turn whose 'from' is NOT 'gpt' (i.e. after the human
# turn) and never after the assistant turn. CustomDataset's eval branch relies on this
# when it splits the rendered text on "<|eot_id|>" into question and answer.
# A trailing backslash inside this non-raw string literal is a Python line continuation,
# so those newlines vanish. All other spaces and newlines are rendered verbatim (no Jinja
# whitespace control), which pads the rendered text with extra whitespace.
LLAVA_CHAT_TEMPLATE = """{% for message in messages %} \
  {% if message['from'] == 'human' %}
    USER: {{ message['value'] }} \
  {% else %}
    ASSISTANT: {{ message['value'] }} \
  {% endif %} \
  {% if message['from'] == 'gpt' %} \
  {% else %} \
      {{ eos_token }} \
  {% endif %} \
{% endfor %}"""

# apply_chat_template(...) on this tokenizer will now use the template above.
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset of tokenized LLaVA conversations plus preprocessed images.

    Each record of ``dataset`` must provide ``'conversations'`` (rendered with
    ``tokenizer.apply_chat_template``) and ``'image'`` (a file name relative to ``image_dir``).

    Train mode tokenizes the whole rendered conversation. Eval mode tokenizes only the
    prompt: the text before the first ``<|eot_id|>`` followed by ``" ASSISTANT:"``. This
    makes generation start where the answer should begin. The stripped reference answer is
    returned under ``'answers'``.

    Args:
        dataset: indexable collection of records (e.g. a ``datasets.Dataset`` split).
        tokenizer: tokenizer with the LLaVA chat template installed.
        processor: image processor called with ``images=``, ``return_tensors="pt"``, ``size=``.
        train: select train-mode or eval-mode items (see above).
        max_length: per-chunk token limit passed to ``tokenizer_image_token``.
        image_size: square resize target passed to the processor.
        image_dir: directory containing the extracted images.
    """

    def __init__(self, dataset, tokenizer, processor, train=True, max_length=512, image_size=256,
                 image_dir="/content/LLaVA-CC3M-Pretrain-595K/extracted_images/"):
        super().__init__()
        self.tokenizer = tokenizer
        self.processor = processor
        self.max_length = max_length
        self.train = train
        self.dataset = dataset
        self.image_size = image_size
        self.image_dir = image_dir
        self.dataset_length = len(self.dataset)

    def __len__(self) -> int:
        return self.dataset_length

    def _load_pixel_values(self, img_name):
        """Open ``img_name`` from ``image_dir`` as RGB and return the processor's ``pixel_values``."""
        img_path = os.path.join(self.image_dir, img_name)
        # Closing the file matters with several DataLoader workers opening images in a loop.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image_inputs = self.processor(
            images=image,
            return_tensors="pt",
            do_resize=True,
            size={"height": self.image_size, "width": self.image_size},
        )
        return image_inputs["pixel_values"]

    def __getitem__(self, idx):
        record = self.dataset[idx]
        text = self.tokenizer.apply_chat_template(record['conversations'], tokenize=False)
        pixel_values = self._load_pixel_values(record['image'])

        if self.train:
            # Tokenize once and reuse both outputs.
            input_id, attention_mask = tokenizer_image_token(text, self.tokenizer, max_length=self.max_length)
            return {'input_ids': input_id.unsqueeze(0), 'attention_mask': attention_mask.unsqueeze(0),
                    'pixel_values': pixel_values, 'texts': text}

        parts = text.split("<|eot_id|>")
        question = parts[0] + " ASSISTANT:"
        # The chat template pads the text with whitespace; strip it from the reference.
        answer = parts[1].replace("ASSISTANT:", "").strip()
        # Feed only the prompt. If the answer were included, generation would merely
        # continue after the reference text instead of producing it.
        input_id, attention_mask = tokenizer_image_token(question, self.tokenizer, max_length=self.max_length)
        return {'input_ids': input_id.unsqueeze(0), 'attention_mask': attention_mask.unsqueeze(0),
                'pixel_values': pixel_values, 'questions': question, 'answers': answer}

In [None]:
from datasets import load_dataset  # NOTE(review): duplicate of the import in the imports cell
# raw_datasets = load_dataset("lmms-lab/LLaVA-OneVision-Data", 'TabMWP(MathV360K)')
# Load only chat.json; its records supply the 'conversations' and 'image' fields used by CustomDataset.
raw_datasets = load_dataset("liuhaotian/LLaVA-CC3M-Pretrain-595K",data_files = "chat.json")


In [None]:
train = raw_datasets["train"]

# Hold out 20% of the records for validation; the fixed seed makes the split reproducible.
train_valid_split = train.train_test_split(test_size=0.2, shuffle=True, seed=42)

# Access the train and validation sets
train_split = train_valid_split['train']
valid_split = train_valid_split['test']

In [None]:
# train=False items additionally carry 'questions' and 'answers' for the edit-distance metric.
train_dataset = CustomDataset(train_split,tokenizer, processor , train = True,max_length=config.get("except_image_max_length"), image_size = config.get("image_size"))
val_dataset = CustomDataset(valid_split,tokenizer, processor , train = False,max_length=config.get("except_image_max_length"), image_size = config.get("image_size"))

In [None]:
class DataCollatorForCustomVLM(DataCollatorForLanguageModeling):
    """Collate ``CustomDataset`` items into padded batch tensors for the custom VLM.

    Items are dicts whose ``'input_ids'`` / ``'attention_mask'`` have a leading dimension
    of 1, which is squeezed before padding. ``'pixel_values'`` are concatenated along
    dim 0; presumably each is (1, C, H, W), so the result is (batch, C, H, W).

    Train mode right-pads. Eval mode (``train=False``) left-pads, so each prompt ends at
    the last column, where batched generation reads its first next-token logits. It also
    returns the reference ``'answers'``.

    ``labels`` are a copy of ``input_ids`` with padded positions set to -100.
    """

    def __init__(self, tokenizer, mlm=False, train=True):
        super().__init__(tokenizer, mlm)
        self.train = train

    def _pad(self, sequences, padding_value):
        """Stack 1-D tensors into (batch, max_len): right padding in train mode, left padding in eval mode."""
        if self.train:
            return torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=padding_value)
        # Left padding: pad the reversed sequences on the right, then flip back.
        reversed_seqs = [seq.flip(0) for seq in sequences]
        padded = torch.nn.utils.rnn.pad_sequence(reversed_seqs, batch_first=True, padding_value=padding_value)
        return padded.flip(1)

    def __call__(self, batch):
        input_ids = [item['input_ids'].squeeze(0) for item in batch]
        attention_mask = [item['attention_mask'].squeeze(0) for item in batch]
        pixel_values = torch.cat([item['pixel_values'] for item in batch], dim=0)

        # Padded ids are masked out of attention and loss, so any in-vocabulary id works
        # when the tokenizer defines no pad token.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0

        input_ids_padded = self._pad(input_ids, pad_id)
        # The mask must be padded with 0 (ignore). Padding it with the pad token id made
        # padded positions look attended.
        attention_mask = self._pad(attention_mask, 0)

        labels = input_ids_padded.clone()
        # Mask padding by position, so real tokens that share the pad id (e.g. pad == eos)
        # still count in the loss.
        labels[attention_mask == 0] = -100

        returns = {
            'input_ids': input_ids_padded,
            'attention_mask': attention_mask,
            'pixel_values': pixel_values,
            'labels': labels,
        }
        if not self.train:
            returns['answers'] = [item['answers'] for item in batch]
        return returns


## Custom vision-language model

In [None]:
def process_tensors(input_ids, attention_mask, image_features, embedding_layer, image_token_index=500000):
    """Splice image features into the token embeddings at the image sentinel.

    For each row ``i``, the single sentinel in ``input_ids[i]`` is replaced by the whole
    ``image_features[i]`` sequence. The attention mask gets ones at the inserted image
    positions.

    Args:
        input_ids: (batch, seq) integer tensor; each row must contain ``image_token_index``
            exactly once.
        attention_mask: (batch, seq) tensor aligned with ``input_ids``.
        image_features: (batch, n_img, dim) tensor; ``dim`` must match the embedding size.
        embedding_layer: module mapping token ids to embeddings
            (e.g. ``model.get_input_embeddings()``).
        image_token_index: sentinel id marking the image position.

    Returns:
        tuple: ``(embeddings, attention_mask)`` of shapes (batch, seq - 1 + n_img, dim)
        and (batch, seq - 1 + n_img). Embeddings are on ``image_features``' device; the
        mask is on ``attention_mask``'s device.

    Raises:
        ValueError: if a row does not contain the sentinel exactly once.
    """
    total_embeds = []
    total_attn = []

    for row_ids, row_attn, row_image in zip(input_ids, attention_mask, image_features):
        positions = (row_ids == image_token_index).nonzero(as_tuple=True)[0]
        if positions.numel() != 1:
            raise ValueError(
                f"expected exactly one image token ({image_token_index}) per sequence, "
                f"found {positions.numel()}"
            )
        split_index = int(positions[0])

        image = row_image.unsqueeze(0)
        device = image.device
        # The sentinel itself is never embedded (it is outside the vocabulary).
        embeddings_1 = embedding_layer(row_ids[:split_index].unsqueeze(0)).to(device)
        embeddings_2 = embedding_layer(row_ids[split_index + 1:].unsqueeze(0)).to(device)
        total_embeds.append(torch.cat([embeddings_1, image, embeddings_2], dim=1))

        # Image positions are always attended to.
        image_mask = torch.ones(image.shape[:2], dtype=torch.long, device=row_attn.device)
        total_attn.append(
            torch.cat(
                [row_attn[:split_index].unsqueeze(0), image_mask, row_attn[split_index + 1:].unsqueeze(0)],
                dim=1,
            )
        )

    return torch.cat(total_embeds, dim=0), torch.cat(total_attn, dim=0)

def process_labels(input_ids, image_features, image_token_index=500000, ignore_index=-100):
    """Expand label ids so they align with the spliced image embeddings.

    The single sentinel in each row is replaced by ``n_img`` copies of ``ignore_index``,
    where ``n_img = image_features.shape[1]``. Image positions therefore contribute
    nothing to the loss, and the output lines up with ``process_tensors``' embeddings.

    Args:
        input_ids: (batch, seq) integer tensor of labels or token ids; each row must
            contain ``image_token_index`` exactly once.
        image_features: (batch, n_img, dim) tensor; only its shape is used.
        image_token_index: sentinel id marking the image position.
        ignore_index: label value written at image positions.

    Returns:
        torch.Tensor: (batch, seq - 1 + n_img) tensor on ``input_ids``' device.

    Raises:
        ValueError: if a row does not contain the sentinel exactly once.
    """
    rows = []
    for row_ids, row_image in zip(input_ids, image_features):
        positions = (row_ids == image_token_index).nonzero(as_tuple=True)[0]
        if positions.numel() != 1:
            raise ValueError(
                f"expected exactly one image token ({image_token_index}) per sequence, "
                f"found {positions.numel()}"
            )
        split_index = int(positions[0])

        image_labels = torch.full(
            (row_image.shape[0],), ignore_index, dtype=torch.long, device=row_ids.device
        )
        rows.append(torch.cat([row_ids[:split_index].long(), image_labels, row_ids[split_index + 1:].long()]))

    return torch.stack(rows, dim=0)

In [None]:
class custom_vlm(L.LightningModule):
    """Lightning wrapper that trains a vision-to-LLM projection on top of frozen backbones.

    Per batch, the vision encoder embeds ``pixel_values``. Its second-to-last hidden state
    is mapped by ``projection_module`` into the LLM embedding space and spliced into the
    token embeddings at the image sentinel (``process_tensors``). Only the projection
    module's parameters are passed to the optimizer.

    Args:
        config: hyper-parameter dict. Reads ``batch_size``, ``lr``, ``verbose`` and the
            optional ``max_new_tokens`` (default 128).
        vision_model: vision encoder called with ``output_hidden_states=True``.
        model: causal LM accepting ``inputs_embeds`` and supporting ``generate``.
        projection_module: trainable module mapping vision features to LLM embeddings.
        tokenizer: used to decode generations and to build the collators.
    """

    # Load checkpoints with strict=False (state dicts need not match exactly).
    strict_loading = False

    def __init__(self, config, vision_model, model, projection_module, tokenizer):
        super().__init__()
        # The large modules and the tokenizer are excluded; only `config` becomes hparams.
        self.save_hyperparameters(ignore=["vision_model", "model", "projection_module", "tokenizer"])
        self.config = config
        self.model = model
        self.vision_model = vision_model
        self.tokenizer = tokenizer
        self.projection_module = projection_module
        self.batch_size = config.get("batch_size")
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
        self.automatic_optimization = True  # set to False to step the optimizer manually

    def on_fit_start(self):
        # The frozen backbones run in eval mode; only the projection trains.
        self.model.eval()
        self.vision_model.eval()
        self.projection_module.train()

    def _embed_inputs(self, input_ids, pixel_values, attention_mask):
        """Encode and project the images, then splice them into the token embeddings.

        Returns:
            tuple: ``(inputs_embeds, attention_mask, projected_image_features)``, with the
            first two moved to ``self.device``.
        """
        image_forward_outs = self.vision_model(
            pixel_values.to(device=self.device, dtype=torch.float16),
            output_hidden_states=True,
        )
        # The second-to-last hidden layer of the vision encoder provides the image features.
        image_features = image_forward_outs.hidden_states[-2]
        projected_embeddings = self.projection_module(image_features).to(self.device)

        # Handles batched data: each row gets its own image features spliced in.
        new_embeds, attn_mask = process_tensors(
            input_ids, attention_mask, projected_embeddings, self.model.get_input_embeddings()
        )
        return new_embeds.to(self.device), attn_mask.to(self.device), projected_embeddings

    def training_step(self, batch, batch_idx):
        """Next-token cross-entropy over the text; image and padded positions are ignored (-100)."""
        self.projection_module.train()

        input_ids = batch['input_ids'].to(self.device)
        new_embeds, attn_mask, projected_embeddings = self._embed_inputs(
            input_ids, batch['pixel_values'].to(self.device), batch['attention_mask']
        )

        # Use the collator's labels (padding already set to -100). Building labels from raw
        # input_ids would make pad tokens training targets.
        label_ids = batch.get('labels')
        if label_ids is None:
            label_ids = input_ids
        labels = process_labels(label_ids.to(self.device), projected_embeddings).to(self.device)

        outputs = self.model(inputs_embeds=new_embeds, attention_mask=attn_mask)
        logits = outputs.logits

        # The output at step t predicts the token at step t+1.
        shifted_logits = logits[:, :-1, :]
        shifted_labels = labels[:, 1:]

        loss = self.loss_fn(
            shifted_logits.contiguous().view(-1, shifted_logits.size(-1)),
            shifted_labels.contiguous().view(-1),
        )

        self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True,
                 batch_size=input_ids.size(0))
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        """Generate captions and log the mean normalized edit distance to the reference answers."""
        self.projection_module.eval()

        answers = batch['answers']
        input_ids = batch['input_ids'].to(self.device)
        new_embeds, attn_mask, _ = self._embed_inputs(
            input_ids, batch['pixel_values'].to(self.device), batch['attention_mask']
        )

        generated = self.model.generate(
            inputs_embeds=new_embeds.to(dtype=torch.float16),
            attention_mask=attn_mask,
            max_new_tokens=self.config.get("max_new_tokens", 128),
            return_dict_in_generate=True,
        )
        # When only `inputs_embeds` is given (no input_ids), generate() returns just the new
        # tokens. There is no prompt prefix to cut; slicing by input_ids.size(1) would drop
        # the start of every prediction.
        sequences = generated["sequences"]
        predictions = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)

        scores = []
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred).strip()
            answer = answer.strip()
            longest = max(len(pred), len(answer))
            # Two empty strings are identical; avoid dividing by zero.
            scores.append(edit_distance(pred, answer) / longest if longest else 0.0)

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        # Aggregate over the validation run: ModelCheckpoint(monitor="val_edit_distance")
        # then sees the mean over all batches rather than the last batch's value.
        self.log("val_edit_distance", float(np.mean(scores)), on_step=False, on_epoch=True,
                 prog_bar=True, logger=True, batch_size=len(scores))

        return scores

    def configure_optimizers(self):
        # Only the projection module is optimized; the backbones stay frozen.
        optimizer = torch.optim.AdamW(self.projection_module.parameters(), lr=self.config.get("lr"))
        return optimizer

    def train_dataloader(self):
        # NOTE: reads the notebook-level `train_dataset`.
        data_collator = DataCollatorForCustomVLM(tokenizer=self.tokenizer, mlm=False, train=True)
        return DataLoader(train_dataset, collate_fn=data_collator, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        # NOTE: reads the notebook-level `val_dataset`.
        data_collator = DataCollatorForCustomVLM(tokenizer=self.tokenizer, mlm=False, train=False)
        return DataLoader(val_dataset, collate_fn=data_collator, batch_size=self.batch_size, shuffle=False, num_workers=4)

## training

In [None]:
# Sanity check: show the LLM's token-embedding module and the vision encoder's input-embedding module.
print(model.get_input_embeddings())
print(vision_model.get_input_embeddings())

In [None]:
# Make sure every Linear layer of the projection MLP is trainable (weights and biases).
linear_layers = [layer for layer in projection_module.modules() if isinstance(layer, torch.nn.Linear)]
for layer in linear_layers:
    layer.weight.requires_grad = True
    layer.bias.requires_grad = True

In [None]:
# Freeze the vision encoder: none of its parameters receive gradients.
vision_model.requires_grad_(False);

In [None]:
# Freeze the whole LLM (including parts outside base_model, such as the LM head).
model.requires_grad_(False);

In [None]:
# Wrap the backbones and the projection in the LightningModule.
# To resume instead, switch to custom_vlm.load_from_checkpoint(ckpt_path, ..., strict=False).
# model_module = custom_vlm(config, vision_model, model, projection_module, tokenizer)
# ckpt_path = "/content/drive/MyDrive/model/lightning_logs/version_10/checkpoints/epoch=0-step=11907.ckpt"

model_module = custom_vlm( # .load_from_checkpoint
    # ckpt_path,
    config=config,
    vision_model=vision_model,
    model=model,
    projection_module=projection_module,
    tokenizer=tokenizer,
    # strict=False  # add this!
)

In [None]:
# Number of trainable parameters in the projection MLP (shown as the cell output).
sum(p.numel() for p in projection_module.parameters() if p.requires_grad == True)

In [None]:
# Trainable parameters across the whole LightningModule; with the freezing above this should equal the projection MLP.
trainable_counts = [weight.numel() for weight in model_module.parameters() if weight.requires_grad]
total_params = sum(trainable_counts)
print(f"Total trainable parameters: {total_params}")

In [None]:
# Work inside the Drive folder so Lightning's default lightning_logs/ (and checkpoints) land on Drive.
cd /content/drive/MyDrive/model

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger  # NOTE(review): only used by the commented-out logger below

# wandb_logger = WandbLogger(project=WANDB_PROJECT, name=WANDB_NAME)

# Checkpoint callbacks. NOTE(review): neither is active, because `callbacks=` in the Trainer is commented out.
# Original note: "e.g. save every 1000 steps"; last_callback actually uses every_n_train_steps=10000.
checkpoint_callback = ModelCheckpoint(    # save location: no dirpath given, so Lightning's default directory is used
    filename="step-{step}",         # checkpoint file-name pattern
    monitor = "val_edit_distance",  # keep the checkpoint with the lowest validation edit distance
    mode = "min",
    save_top_k=1,
)

last_callback = ModelCheckpoint(    # save location: no dirpath given, so Lightning's default directory is used
    filename="latest-step-{step}",         # checkpoint file-name pattern
    monitor = "step",  # presumably keeps only the most recent checkpoint (highest step) — confirm "step" is logged
    mode = "max",
    every_n_train_steps= 10000,
    save_top_k=1,
)

trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_epochs=config.get("max_epochs"),
    accumulate_grad_batches=config.get("accumulate_grad_batches"),
    check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
    gradient_clip_val=config.get("gradient_clip_val"),
    precision="bf16-mixed",
    limit_val_batches=10,  # each validation run evaluates only 10 batches
    num_sanity_val_steps=0,
    val_check_interval=config.get("val_check_interval"),
    gradient_clip_algorithm="norm",
    # callbacks=[last_callback,checkpoint_callback],
)

In [None]:
# Train the projection module; validation runs at the fraction of each epoch set by config["val_check_interval"].
trainer.fit(model_module)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name              | Type              | Params | Mode 
----------------------------------------------------------------
0 | model             | LlamaForCausalLM  | 4.5 B  | eval 
1 | vision_model      | SiglipVisionModel | 428 M  | eval 
2 | projection_module | ProjectionModule  | 21.5 M | train
3 | loss_fn           | CrossEntropyLoss  | 0      | train
----------------------------------------------------------------
21.5 M    Trainable params
5.0 B     Non-trainable params
5.0 B     Total params
19,961.320Total estimated model params size (MB)
6         Modules in train mode
763       Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name              | Type              | Params | Mode 
----------------------------------------------------------------
0 | model             | LlamaForCausalLM  | 4.5 B  | eval 
1 | vision_mo

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: ? -, -,,???,!, |,.?, - -, -.? -.? -. - - -, -?.? - - -,,, •. -. -,? -, - -, - • - -. - - - -..?.. -? - - •,. -,, -? n
. - - - -.
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 0.7890625
Prediction: “ influential://,.swing.awt-old://ggee.swing influential.log influential.swing respected - — tryingishment able.getElementById://://://://. versa.swing.swing versa celebrated versa versa.ComponentPlacement://.swing.swing.ComponentPlacement ableeenth зрения.ComponentPlacement acclaimed influential importantlyug Arabia — versa importantly versa ableeg.logeg://s influential:// |
://.swing accused versa://://.swing:// educated importantlyishment-old://eg hasn versa invited importantly-old acclaimed invited been versa importantly been-old-old explained
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.9023508137432188
Prediction: .
.
.
:// /.
g.
g://.
.swing-old://g.swing.sw

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction:  a
 - - - - -? - - -, -??? -? - •? - -? - - - - -? -?? - - - - - - -? - -, - •? -? -? - -? - - - - - - -?? - - -? -? - - - • -? - - -,, - -? - • - - -
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 0.76
Prediction: .. |, —|.S. — —. - - - - —, - —ates versa:// — - -, - - - - - - |://, —. —.swing hasn:// —.swing:// — influential importantly, - -:// -eg —. —g:// - — -://..log:// -:// —.swing://. acclaimed://.:// — acclaimed.util.swing:// - intelligent versa importantly:// - informed.ComponentPlacement educated
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.8619528619528619
Prediction: .,.. —. —.,,,, — -...images.S. - —. —://.eg importantly.logishment versa informed influential importantly.swing versa versa appreciated://, — extensively able - —.eri.S.Segeg —, -er://.. |://,...://. acclaimed:// —the://..log://..Sise.log.swing accused versa.swing://. versa.awt
    A

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction:  as
,,?,??,, -? a? a -? -,??????????, -????????,, -,????? as?, a? -, the?? -,???? as????? and as??,?,??? and? -.
? or -?
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 0.8166666666666667
Prediction: ,, — -. | the,, — -, the -,.,... the, —.,,, —inperson.swing зрения.log informed.swing amazing influential:// — importantly://.!..://,...://.. -. the.://,. -.,of://,remnonthe versa://, - -the:// -://,the-old acclaimed versa Angeles:// the -:// -
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.8278688524590164
Prediction:  the.. the. —, - the an - -, the.. the. |.., —..idour.swing influential.swing able educated://g.util.S://..
.://,. | -. -the://, the,.., the the the the the,..the -:// |:// j| the,:// a/ the the |://,://. —:// | versa
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 0.8018433179723502
Predictio

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: 
 and part and? a
, as as a as of? people? a as as? a? a a?????? of of??? and,?? and? as and?? a
 in as? and the? and and as?? or as an??? and a??? and for a a
 all as as
 to as as a a the part as a?

    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 0.725
Prediction:  — — — — -, — the —, — the the the, — — — —the://,, — —end:// -, —um:// —:// —. —:// -. the —://,the:// — the - |:// - |gr? — | the the the - the athe. —. —.util:// the — | the -:// the —:// the the —://, the the the —.swing
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.8258928571428571
Prediction:  — — and. - |,. a the the the the the j.,,, —,,., — –li., —., the -Ã://,, -,, of the |., - | the an the a the, - the the the the and the,.,.,.,. the |.
 —://|. —.
 the the | | - —:// and
    Answer:             property image # - beautiful new log cabin nestled in person          
 Normed ED: 0.7688172043

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction:  a for as
 with as in with the the as ofeg b in all is as an as in b
 a a
 in as a - of as and as  a as as as in as a as and all in a as the a in the the as a or the as b as a as to b an a as all as of as and as b as to the an as and the as the to 2 to

    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 0.7509881422924901
Prediction:  - — — - -, | | -, — the — | — | - — |, | | — —, | —g -sw://://.swing Angelessp.thethe. thesp.log importantlyperson://-old:// — — |.| the | the the ||sm – -people thethe|://the —..S Angeles://.S theof acclaimed.swing.log://person://ishment Angeles://://inperson:// Arabia
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.8450184501845018
Prediction:  — | | | the, the | theerihave the - the, | |in — |, –, | – –.
 |.
||g.
 -Ã.S.
 – -eri://eri.log. – – |.
eri | the, the the – the – ||.
 |...
eri the the://. the the.://, versa...
 |!:// the an

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction:  as of and a part as as the is as  the in the and a as with as as or a, as and as as as an in to and as with of as all as as as as as the as on as the as as as as a as as a the is at with as an b as as or as with as
 a as as with with with as of as as as as or with with as b
 as
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 0.7670250896057348
Prediction:  the - to, – |presthe, the the the the | the the - the - — - -, | -, —periodinsp://the.spum the the - and, conducted..| the — the and the thethethe the the the the the andthe, theth thethe thethe the the theof://.log://the the..://the the. acclaimed://.S the the the.S.
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.8327137546468402
Prediction:  and the ( in | the - the the the the - the the.. and - - and, by, the - the --.
 and the |. –. -.the the lastg.
 ( ( the - the | and an ( the the the |the the the – an 

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction:   a
 as a little
 in in as
 aseg as, but  with as of a and to in of and  with  a, as  a  or  on
 as with as as as in the in as the as as as as as as as in to a and b a to in fashion b a as  as in  as a  the as as with in as and and as of
 or a
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 0.7366255144032922
Prediction: ,,, the,,| —, the, the the -, the, |, - and — | — -,,person,thein hundredthethein://the the the.theintÃ.non://the - the,thefilm; the the the the thethe the theintthethe theintthe theleg://://://. thein:// Angeles.
the thethethe the.spouslin://://
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.8414634146341463
Prediction:  in the | | | the the the the the the the the. the.
, – | a|  as - and – fromg. - Ã/ the the | the. -, the - - | - the, | – and the an, the the the the the the and, last, | the - and the..
 the the.the|...
 the. the..
 the the|.

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction:  as bent a in a with a as on a of as a a a weekend as at at with foreign but
 as as a more and a to  in as is b a
 with a as
 a
 on as to in as as on a little as a  in at to as to is as the a person in the a in a to a with a an but as a as and as as at to as in as
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 0.7424242424242424
Prediction:  the the the the a the the the the the the the the the the the the the | the, the. the in, and | the and the and, thethe thethe the the the thethe theperson. the — the the thethe the the the the the the the the the the the the the the the thethe the the theperson| the the the theperson the the the thethethe the the the,the
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.8641975308641975
Prediction:  and the, the, the the the the the the the the the in the|       —        for -- last — last —s - last in the an the - the,the | 

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction:  the
 as a
 with as in as in as weekend of  and as in b more as with to  to  at with as more in of with as b a breakdown with
 the as  the a with but at at as as  as or an as as with as little as a powers in as b a an the for
 as more in the c as more as as as a as as and with as or or all
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 0.7655172413793103
Prediction:  — the, the the |insthe the the the the the the the the - the the |sex —— — the forthethe thepres influential thebusiness://thethe thethe.leg://person.swing. -the the the the theremthe the | the the the thepig - thetheins. the. thethis!the!the the theth versa the thethe the.://it://thethe the the the
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.8504983388704319
Prediction:  as as in in the - the the the the the - - - - - —. - -    - and -- — — the - --spthe the – –the.
 - | - - the - | | — – as in

Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: 
 F
 an but as or a
 with aepent on but at by as with 191 b as is  ```
  or extra 198 in more as a as as
   more as
 from at b as as as  or a to as a to forum with b a but foreign at as with as as a 3 as as with as all as as on to as fashion as or as as
 in 
    Answer:             a man cooking up dishes at a stall at the night food market in capital          
 Normed ED: 0.751937984496124
Prediction: , the the, the it an, the the the the the theins the the the, but |person - - - | the the thetheremaaaaâthe the thethethethe thethe thethe the the the the the theremthe | the the the the thethe the thethe thethethethe,allyÃ thethe thethe thethethe the the thethe the thethe the thethe the thethe
    Answer:             cute baby elephant walks towards the camera          
 Normed ED: 0.8542372881355932
Prediction: , it, - in, the the the the the the the, - --.
 | | - – –, - the - - fromlinthe. the the the.inÃthe, | the - the - - ( |. | – the an the the the the the the - – | th

In [None]:
import os

# Persist only the trained projection (vision features -> LLM embedding space) weights.
# NOTE: keys in this state_dict are relative to `projection_module` (no
# "projection_module." prefix), so reload with
# `model_module.projection_module.load_state_dict(...)`.
PROJECTION_CKPT_PATH = "/content/drive/MyDrive/model/llama3_multi.bin"
# torch.save does not create missing parent directories; make sure the Drive
# folder exists so the save cannot fail after a long training run.
os.makedirs(os.path.dirname(PROJECTION_CKPT_PATH), exist_ok=True)
torch.save(model_module.projection_module.state_dict(), PROJECTION_CKPT_PATH)

## metric

In [None]:
# Reload the projection weights saved by the previous cell.
# NOTE(review): the checkpoint was written from `model_module.projection_module.state_dict()`,
# so its keys lack the "projection_module." prefix. Loading it into `model_module` with
# strict=False will likely match no keys and silently load nothing. Prefer:
#   model_module.projection_module.load_state_dict(torch.load("/content/drive/MyDrive/model/llama3_multi.bin"))
# model_module.load_state_dict(torch.load("/content/drive/MyDrive/model/llama3_multi.bin"), strict=False)

In [None]:
# Place the whole module (vision tower, projection, LLM) on the GPU; the module repr is displayed.
model_module.to(device="cuda")

In [None]:
# Display the device the module currently reports (expected to be cuda after the previous cell).
model_module.device

In [None]:
# Draw one shuffled training batch (batch_size=2) to step through the
# forward pass manually, one cell at a time.
sample_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    collate_fn=data_collator,
)
batch = next(iter(sample_loader))

input_ids = batch["input_ids"].to(device=model_module.device)  # token ids (long)
pixel_values = batch["pixel_values"].to(device=model_module.device)

In [None]:
# Padding mask from the collated batch; not moved to the GPU in this cell.
attention_mask = batch['attention_mask']

In [None]:
# Inspect the image batch shape; presumably (batch, channels, height, width) — TODO confirm.
pixel_values.shape

In [None]:
# Run the vision tower on the image batch, keeping every layer's hidden states
# (the second-to-last layer is used as the image feature map below).
image_forward_outs = model_module.vision_model(
    pixel_values.to(device=model_module.device),
    output_hidden_states=True,
)  # outputs are float16


In [None]:
# Shape of the vision tower's final hidden states (for comparison with hidden_states[-2]).
image_forward_outs.last_hidden_state.shape

In [None]:
# Take the second-to-last vision layer and map it into the LLM embedding space.
# (projection module weights are float32, incoming features float16)
image_features = image_forward_outs.hidden_states[-2]
projected_embeddings = model_module.projection_module(image_features)
projected_embeddings = projected_embeddings.to(model_module.device)


In [None]:
# Inspect projected image tokens; presumably (batch, n_patches, llm_hidden) — TODO confirm.
projected_embeddings.shape

In [None]:
# The LLM's token-embedding layer, used to embed the text around the image placeholder.
embedding_layer = model_module.model.get_input_embeddings()

In [None]:
# Inspect the token id batch shape (batch, seq_len).
input_ids.shape

In [None]:
# Splice the projected image features into each row's token embeddings at the
# image placeholder token, and build the matching attention mask per row.
IMAGE_TOKEN_ID = 500000  # placeholder id marking where the image goes in the prompt

total_ids = []
total_attn = []

for row_ids, row_attention, row_image in zip(input_ids, attention_mask, projected_embeddings):
    # Restore a leading batch dim of 1 so slicing/concatenation works on dim=1.
    row_ids = row_ids.unsqueeze(0)
    row_attention = row_attention.unsqueeze(0)
    row_image = row_image.unsqueeze(0)

    positions = (row_ids == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[1]
    if positions.numel() != 1:
        # Fail with a clear message instead of a cryptic tensor-indexing error.
        raise ValueError(
            f"expected exactly one image token ({IMAGE_TOKEN_ID}) per row, "
            f"found {positions.numel()}"
        )
    split_index = positions.item()

    # Text embeddings before / after the placeholder (the placeholder itself is dropped).
    token_embeddings_part1 = embedding_layer(row_ids[:, :split_index]).to(row_image.device)
    token_embeddings_part2 = embedding_layer(row_ids[:, split_index + 1 :]).to(row_image.device)
    concatenated_embedding = torch.cat(
        [token_embeddings_part1, row_image, token_embeddings_part2], dim=1
    )

    # Image positions are always attended; text positions keep the collator's padding mask.
    image_mask = torch.ones(
        row_image.shape[:2], dtype=torch.long, device=row_attention.device
    )
    cat_attention_mask = torch.cat(
        [row_attention[:, :split_index], image_mask, row_attention[:, split_index + 1 :]],
        dim=1,
    )

    total_ids.append(concatenated_embedding)
    total_attn.append(cat_attention_mask)

# Rows can only be stacked if they all have the same length after splicing.
concatenated_embeddings = torch.cat(total_ids, dim=0)
cat_attention_masks = torch.cat(total_attn, dim=0)

In [None]:
# Attention mask of the last row processed by the splice loop above.
cat_attention_mask

In [None]:
# Position of the image placeholder token (id 500000) in the batch.
# Slicing the whole batch with a single index only works when every row has
# exactly one placeholder at the same position; with batch_size > 1 the raw
# nonzero() result has one entry per row and cannot be used as a slice bound.
is_image_token = input_ids == 500000
assert bool((is_image_token.sum(dim=1) == 1).all()), "expected exactly one image token per row"
image_token_positions = is_image_token.nonzero(as_tuple=True)[1]
assert bool((image_token_positions == image_token_positions[0]).all()), (
    "batch slicing needs the image token at the same position in every row"
)
split_index = image_token_positions[0].item()

In [None]:
# Column index of every image placeholder token (id 500000) in the batch, one entry per match.
(input_ids == 500000).nonzero(as_tuple=True)[1]

In [None]:
# Text before and after the image placeholder (the placeholder itself is dropped),
# mapped into the LLM's input embedding space.
input_ids_1, input_ids_2 = input_ids[:, :split_index], input_ids[:, split_index + 1 :]
embeddings_1, embeddings_2 = embedding_layer(input_ids_1), embedding_layer(input_ids_2)


In [None]:
# Put both text-embedding halves on the image features' device before concatenation.
device = image_features.device
token_embeddings_part1, token_embeddings_part2 = (
    embeddings_1.to(device),
    embeddings_2.to(device),
)

In [None]:
# Sequence layout: [text before image | projected image patches | text after image].
concatenated_embeddings = torch.cat(
    (token_embeddings_part1, projected_embeddings, token_embeddings_part2),
    dim=1,
)

In [None]:
# All-ones mask matching the spliced embedding sequence.
# NOTE: this ignores the collator's padding mask, so padded text positions are
# attended too; the per-row splice loop earlier builds a padding-aware mask.
# Stored under its own name so the batch's `attention_mask` is not clobbered,
# and created on the embeddings' device so it can be fed to the model directly.
concat_attention_mask = torch.ones(
    concatenated_embeddings.shape[:2],
    dtype=torch.long,
    device=concatenated_embeddings.device,
)

In [None]:


# Splice image features into the token embeddings for the whole batch (works on batched data).
# The collator's padding mask is passed explicitly, matching the 4-argument
# process_tensors(input_ids, attention_mask, ...) call used for validation, rather than
# relying on the `attention_mask` global, which an exploratory cell may have overwritten.
new_embeds, attn_mask = process_tensors(
    input_ids, batch['attention_mask'], projected_embeddings, embedding_layer
)  # embeds dtype presumably follows embedding/projection output (float16) — TODO confirm

In [None]:
# Check where process_tensors left the attention mask before it is moved to the model device.
attn_mask.device

In [None]:
# Pieces for building label ids: text tokens before/after the image placeholder,
# plus a block of -100 covering every image patch position.
# NOTE: [1][0] takes the first match in the batch, i.e. assumes the placeholder sits
# at the same position in every row.
split_index = (input_ids == 500000).nonzero(as_tuple=True)[1][0]

input_ids_1 = input_ids[:, :split_index]
input_ids_2 = input_ids[:, split_index + 1 :]

device = image_features.device
pbatch, pseq = image_features.shape[0], image_features.shape[1]
# -100 is the default ignore_index of torch.nn.CrossEntropyLoss.
image_token = torch.full(
    [pbatch, pseq], -100, dtype=torch.long, device=model_module.device
)



In [None]:
# Confirm the trailing text ids are on the same device as `image_token` before concatenating.
input_ids_2.device

In [None]:
# Despite the variable name, this is a *label id* sequence, not embeddings:
# [text ids before image | -100 per image patch | text ids after image].
concatenated_embeddings = torch.cat((input_ids_1, image_token, input_ids_2), dim=1)

In [None]:
# Labels aligned with the spliced (text + image) sequence, presumably mirroring the manual
# construction above (-100 over the image span) — TODO confirm against process_labels.
labels = process_labels(
    input_ids, projected_embeddings
) # NOTE(review): these are label ids (presumably integer), not float16 values



In [None]:
# Put the spliced inputs on the model's device before the forward pass.
attn_mask, new_embeds = (
    attn_mask.to(model_module.device),
    new_embeds.to(model_module.device),
)

In [None]:

labels = labels.to(model_module.device)

# Plain forward pass on pre-spliced embeddings (module weights float32, inputs float16).
# `labels` is moved to the device but not passed to the model here.
outputs = model_module.model(attention_mask=attn_mask, inputs_embeds=new_embeds)
logits = outputs.logits


In [None]:
# Pull one shuffled validation batch (collated with data_collator2, which also
# yields the reference `answers`) for a manual generation check.
dd = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    collate_fn=data_collator2,
)
batch = next(iter(dd))
device = model_module.device

input_ids = batch["input_ids"].to(device)  # token ids (long)
pixel_values = batch["pixel_values"].to(device)
attention_mask = batch["attention_mask"]  # left on CPU here
answers = batch["answers"]

In [None]:
# Inference-only pass: build the spliced (text + image) input embeddings for generation.
# torch.no_grad() avoids recording an autograd graph through the vision tower,
# projection and embedding layer, which would only waste GPU memory here.
with torch.no_grad():
    image_forward_outs = model_module.vision_model(
        pixel_values.to(device=device),
        output_hidden_states=True,
    )  # outputs are float16

    # Second-to-last vision layer is the feature map fed to the projection.
    image_features = image_forward_outs.hidden_states[-2]
    projected_embeddings = model_module.projection_module(image_features).to(device)

    embedding_layer = model_module.model.get_input_embeddings()

    # process_tensors works on batched data.
    new_embeds, attn_mask = process_tensors(
        input_ids, attention_mask, projected_embeddings, embedding_layer
    )

attn_mask = attn_mask.to(device)
new_embeds = new_embeds.to(device)



In [None]:
# Should match new_embeds.shape[:2] (one mask entry per spliced position).
attn_mask.shape

In [None]:
# Spliced input embeddings; presumably (batch, seq_len + n_patches - 1, hidden) — TODO confirm.
new_embeds.shape

In [None]:
# Autoregressively generate token ids from the spliced embeddings.
# When generate() receives only `inputs_embeds` (no `input_ids`), a decoder-only model
# returns just the newly generated tokens, so there is no prompt to chop off:
# slicing by input_ids.size(1) would discard the first generated tokens.
generated_ids = model_module.model.generate(
    inputs_embeds=new_embeds.to(dtype=torch.float16),
    attention_mask=attn_mask,
    max_new_tokens=128,
)
# Turn ids back into text; special tokens are stripped (skip_special_tokens=True).
predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
# Text-only sanity check: embed a plain prompt (no image) with the LLM embedding layer.
text = "What is transformer attention?"
input = tokenizer(
    [text],
    return_tensors="pt",
)
embed = embedding_layer(input["input_ids"].to(model_module.device))

In [None]:
# Generate from the text-only prompt embeddings.
# The tokenizer returns its attention mask on the CPU while `embed` lives on the
# model's device, so move the mask to match before generation.
generate_ids = model_module.model.generate(
    inputs_embeds=embed.to(dtype=torch.float16),
    attention_mask=input["attention_mask"].to(model_module.device),
    max_new_tokens=512,
)

In [None]:
# Decode the single generated sequence to text, dropping special tokens.
tokenizer.decode(
    generate_ids[0],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)