In [None]:
from transformers import Blip2Processor, Blip2ForConditionalGeneration, BitsAndBytesConfig
from transformers import T5ForConditionalGeneration
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, TaskType
import torch

base_model_id = "Salesforce/blip2-flan-t5-xl"

# Processor: bundles the BLIP-2 image processor and the Flan-T5 tokenizer.
processor = Blip2Processor.from_pretrained(base_model_id, use_fast=True)

# Model, loaded in 8-bit through bitsandbytes.
# Passing `load_in_8bit=True` directly to `from_pretrained` is deprecated; the
# supported route is an explicit BitsAndBytesConfig (imported at the top of this cell).
model = Blip2ForConditionalGeneration.from_pretrained(
    base_model_id,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

# Prepare the 8-bit quantized model for LoRA (k-bit) training.
model = prepare_model_for_kbit_training(model)

# LoRA configuration
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    # Intended to target only the BLIP-2 Q-Former layers (module-name suffix match).
    target_modules=["query", "key", "value", "dense"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,  # Flan-T5 is an encoder-decoder, hence SEQ_2_SEQ
)

# Wrap the model with PEFT, inserting LoRA adapter layers on top of the base weights.
model = get_peft_model(model, lora_config)

In [None]:
def count_parameters(model):
    """Print the total and trainable parameter counts of a torch module.

    Args:
        model: Any object exposing ``parameters()`` (e.g. ``torch.nn.Module``).

    Returns:
        None; the two counts are printed with thousands separators.
    """
    total = 0
    trainable = 0
    # One pass over the parameters accumulates both counts.
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    print(f"Total Parameters: {total:,}")
    print(f"Trainable Parameters: {trainable:,}")

# Report total vs. trainable parameter counts of the LoRA-wrapped model.
count_parameters(model)

In [None]:
import os
import torch
import pandas as pd
from PIL import Image
from datasets import Dataset
from transformers import Blip2Processor

# Load the training table. Its rows supply Question/A-D/Description/answer text
# and the img_path of each image used in preprocessing.
train_data_dir = "./dataset/generated/"
df = pd.read_csv(os.path.join(train_data_dir, "question_answer.csv"))

# Build prompt
def build_prompt(row):
    """Compose the encoder-side instruction prompt for one dataset row.

    Args:
        row: Mapping (e.g. a pandas row) with keys 'Question', 'A', 'B', 'C', 'D'.

    Returns:
        The prompt text: fixed instructions, the question, the four labelled
        options, then empty 'Description:' / 'Answer:' slots and 'ASSISTANT:'.
    """
    intro = (
        "USER: Based on the image, write a description and create a multiple-choice question with four options (A, B, C, D).\n"
        "Answer the question by selecting the best option from A, B, C, or D.\n"
        "Respond only with a single letter: A, B, C, or D.\n"
        "Follow this exact format:\n\n"
    )
    question_line = f"Question: {row['Question']}\n"
    option_lines = "".join(f"{letter}. {row[letter]}\n" for letter in ("A", "B", "C", "D"))
    # Leading "\n" yields the blank line between option D and the response slots.
    response_template = "\nDescription:\nAnswer:\n\nASSISTANT:"
    return intro + question_line + option_lines + response_template

# Encoder input text for every row.
df["prompt"] = df.apply(build_prompt, axis=1)

# Build the output (target) text: Description + Answer
def build_target(row):
    """Compose the decoder target text for one dataset row.

    Args:
        row: Mapping with keys 'Description' and 'answer'.
            NOTE(review): 'answer' is lowercase while the other columns are
            capitalized; presumably this matches the CSV header — verify.

    Returns:
        "Description: <description>\\nAnswer: <answer>".
    """
    description = row['Description']
    answer = row['answer']
    return "Description: {}\nAnswer: {}".format(description, answer)

# Decoder target text for every row.
df["target"] = df.apply(build_target, axis=1)

# Convert the DataFrame into a Hugging Face Dataset.
dataset = Dataset.from_pandas(df)

# Preprocessing function
def preprocess(example):
    """Turn one dataset row into model-ready tensors.

    Loads the image at ``example["img_path"]`` and encodes it together with
    ``example["prompt"]`` via the BLIP-2 processor. ``example["target"]`` is
    tokenized as decoder labels.

    Returns:
        dict with ``input_ids`` / ``attention_mask`` (int64, padded/truncated
        with max_length=384), ``pixel_values`` (float32) and ``labels`` (int64,
        padded/truncated to 128). Padding positions in ``labels`` are -100.
    """
    # Context manager closes the file handle deterministically. Pillow opens
    # lazily, and only auto-closes some image types after load.
    with Image.open(example["img_path"]) as img:
        image = img.convert("RGB")
    inputs = processor(
        text=example["prompt"],
        images=image,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=384,   # increased from 128 (256~384 range)
    )
    # Already a tensor; cast with .to() instead of re-wrapping in torch.tensor(),
    # which copies and emits a UserWarning.
    labels = processor.tokenizer(
        example["target"],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
    ).input_ids[0].to(torch.long)
    # Mask padding with -100 (the HF ignore index) using the tokenizer's real
    # pad id rather than assuming a hard-coded value.
    pad_id = processor.tokenizer.pad_token_id
    if pad_id is not None:
        labels = labels.masked_fill(labels == pad_id, -100)
    return {
        "input_ids": inputs["input_ids"][0].to(torch.long),
        "attention_mask": inputs["attention_mask"][0].to(torch.long),
        "pixel_values": inputs["pixel_values"][0].to(torch.float32),
        "labels": labels,
    }

# Apply preprocessing to every row. The raw columns (prompt, target, img_path, ...)
# are removed, so only the keys returned by preprocess remain.
processed_dataset = dataset.map(preprocess, remove_columns=dataset.column_names)

In [None]:
# Inspect the first raw example (before preprocessing).
dataset[0]

In [None]:
# Compare the raw columns with those left after preprocessing.
print(dataset.column_names)
print(processed_dataset.column_names)

In [None]:
# Tokenizer vocabulary size; sanitize_labels treats it as the exclusive upper
# bound for valid label ids.
vocab_size = processor.tokenizer.vocab_size

def sanitize_labels(example):
    """Replace label ids outside ``[0, vocab_size)`` with -100.

    Reads the module-level ``vocab_size``. The example's ``labels`` list is
    replaced in place and the same dict is returned (``Dataset.map`` style).
    """
    cleaned = []
    for tok in example["labels"]:
        in_range = 0 <= tok < vocab_size
        cleaned.append(tok if in_range else -100)
    example["labels"] = cleaned
    return example

# Replace out-of-range label ids with -100 in every example.
processed_dataset = processed_dataset.map(sanitize_labels)

def zero_to_ignore(example):
    """Replace every label id equal to 0 with -100 (the HF loss ignore index).

    NOTE(review): 0 is presumably the Flan-T5 pad token id — confirm against
    ``processor.tokenizer.pad_token_id``. Mutates and returns ``example``.
    """
    labels = example["labels"]
    example["labels"] = [tok if tok != 0 else -100 for tok in labels]
    return example

# Replace label id 0 with -100 so those positions are excluded from the loss.
processed_dataset = processed_dataset.map(zero_to_ignore)

In [None]:
def cast_attention_mask(example):
    """Cast any attention-mask fields of one example to float32 tensors.

    Handles ``attention_mask`` and, if present, ``decoder_attention_mask``;
    other keys are left untouched. Mutates and returns ``example``.
    NOTE(review): the reason for the float cast is not stated here —
    presumably for fp16 / quantized training compatibility; confirm.
    """
    for key in ("attention_mask", "decoder_attention_mask"):
        if key in example:
            example[key] = torch.tensor(example[key]).to(torch.float32)
    return example

# Cast attention masks to float32 in every example.
processed_dataset = processed_dataset.map(cast_attention_mask)

In [None]:
from transformers import Trainer

class CustomTrainer(Trainer):
    """Trainer that uses the model's own ``outputs.loss`` as the training loss.

    The batch dict is forwarded to the model unchanged, except that any
    ``num_items_in_batch`` entry is removed first.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # `num_items_in_batch` is not a model input; never let it reach forward().
        # Extra keyword arguments from the Trainer are accepted via **kwargs and
        # ignored (presumably newer versions pass num_items_in_batch here — verify).
        inputs.pop("num_items_in_batch", None)
        outputs = model(**inputs)
        if not return_outputs:
            return outputs.loss
        return outputs.loss, outputs

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./model/finetuned-blip2-flan-t5-xl",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,  # effective per-device batch = 8 * 4
    num_train_epochs=5,
    logging_steps=10,
    save_steps=200,
    learning_rate=5e-5,
    save_total_limit=3,
    fp16=True,
    # The model is wrapped by PEFT. Depending on the transformers version, the
    # Trainer may inspect the wrapper's forward() signature, which may not list
    # `pixel_values`, and silently drop that column. Every column left in
    # processed_dataset is a model input, so disabling the pruning is safe.
    remove_unused_columns=False,
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset,
    # NOTE(review): newer transformers releases deprecate `tokenizer=` in favour
    # of `processing_class=`; kept as-is for compatibility with older versions.
    tokenizer=processor.tokenizer,
)

trainer.train()

In [None]:
# Save only the LoRA adapter (trainer.model is the PEFT-wrapped model).
trainer.save_model()
# Save the full processor (tokenizer + image-processor config) so the output
# directory can be reloaded with Blip2Processor.from_pretrained; saving only the
# tokenizer would leave the image-processor config out.
processor.save_pretrained(training_args.output_dir)