In [None]:
from transformers import Blip2Processor, Blip2ForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel, prepare_model_for_kbit_training
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

import gc

base_model_id = "Salesforce/blip2-flan-t5-xl"
# NOTE(review): "bilp2" looks like a misspelling of "blip2"; left unchanged because it has to
# match the directory the fine-tuned adapter was actually saved to — confirm on disk.
trained_model_id = "./model/finetuned-bilp2-flan-t5-xl"

# 4-bit NF4 quantization (double quantization, float16 compute). It is applied only to the
# T5 decoder loaded below; the rest of BLIP-2 stays in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

# Load the processor of the base BLIP-2 checkpoint (used later for both images and text).
processor = Blip2Processor.from_pretrained(base_model_id, use_fast=True)

# Load only the T5 language model, in 4-bit.
t5 = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xl",  # the T5 model BLIP-2 uses internally
    device_map="auto",
    # device_map={"": 0},  # put every module on GPU 0
    quantization_config=quantization_config
)

# Load the full BLIP-2 model in float16, including its own float16 language model,
# which is replaced by the 4-bit T5 right below.
model_fp = Blip2ForConditionalGeneration.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    # device_map={"": 0},  # put every module on GPU 0
)

# Swap the decoder for the 4-bit T5.
model_fp.language_model = t5
# The float16 decoder that was just replaced has no Python name left, but device-dispatch
# hooks may keep it alive in a reference cycle. Collect it explicitly, then hand the freed
# blocks back from PyTorch's CUDA caching allocator (a no-op when CUDA is not in use).
gc.collect()
torch.cuda.empty_cache()

# Apply the LoRA fine-tuned adapter weights.
model = PeftModel.from_pretrained(model_fp, trained_model_id)


In [None]:
def count_parameters(model):
    """Print the total and trainable parameter counts of a model.

    bitsandbytes 4-bit weights (class ``Params4bit``) pack several 4-bit values into each
    storage element, so ``numel()`` under-reports them. Those tensors are scaled back to
    their logical element count the same way PEFT's own parameter counter does.

    Args:
        model: any object exposing ``parameters()`` (e.g. a ``torch.nn.Module`` or a
            ``PeftModel``).

    Returns:
        None; the two counts are printed.
    """
    def _logical_numel(param):
        count = param.numel()
        # Compare by class name so bitsandbytes does not need to be imported here.
        if param.__class__.__name__ == "Params4bit":
            # Each storage element of element_size() bytes holds 2 * element_size() 4-bit values.
            count *= 2 * param.element_size()
        return count

    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = _logical_numel(param)
        total_params += count
        if param.requires_grad:
            trainable_params += count
    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")

# Report parameter counts of the PEFT-wrapped model. The adapter is loaded without an
# explicit is_trainable argument, so the trainable count is presumably 0 — NOTE(review): confirm.
count_parameters(model)


In [None]:
import os
import re
import torch
import warnings
import pandas as pd
from PIL import Image
from tqdm import tqdm

# Session settings.
# Globally suppress all Python warnings for the rest of this kernel session.
warnings.filterwarnings("ignore")

# Restrict PyTorch to GPU 0. CUDA reads CUDA_VISIBLE_DEVICES only when it is first
# initialized; if an earlier cell already put tensors on the GPU, the assignment below is a
# no-op for the running kernel, so report that instead of failing silently.
# (torch.cuda.is_initialized() only inspects PyTorch's state and does not touch the device.)
if torch.cuda.is_initialized():
    print("NOTE: CUDA is already initialized in this kernel; "
          "setting CUDA_VISIBLE_DEVICES='0' here has no effect until the kernel is restarted.")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [None]:
# helper function for extracting answer letter [A/B/C/D]
def extract_answer_letter(text):
    """Pull the predicted option letter out of generated text.

    Returns the first uppercase letter A-D that stands alone as a word (word boundaries
    on both sides, so "B", "(C)" and "D." match but the "A" in "ANSWER" does not).
    Lowercase a-d are not matched. Returns "?" when no such letter exists.
    """
    first_hit = next(re.finditer(r"\b([A-D])\b", text), None)
    if first_hit is None:
        return "?"
    return first_hit.group(1).upper()

In [None]:
test_dataset_dir = './dataset/given/'
test = pd.read_csv(os.path.join(test_dataset_dir, 'test.csv'))
results = []


def _to_model_inputs(image, prompt):
    """Encode one (image, prompt) pair with the BLIP-2 processor and move it to `device`.

    float32 tensors are cast to float16 to match the model's float16 weights
    (loaded with torch_dtype=torch.float16); all other tensors keep their dtype.
    """
    encoded = processor(images=image, text=prompt, return_tensors="pt")
    return {
        k: (v.half().to(device) if v.dtype == torch.float32 else v.to(device))
        for k, v in encoded.items()
    }


def _generate_text(image, prompt, max_new_tokens):
    """Greedy-decode up to `max_new_tokens` tokens for (image, prompt) and return the stripped text."""
    output = model.generate(
        **_to_model_inputs(image, prompt), max_new_tokens=max_new_tokens, do_sample=False
    )
    return processor.tokenizer.decode(output[0], skip_special_tokens=True).strip()


for _row_idx, row in tqdm(test.iterrows(), total=len(test)):
    # Read the pixels and release the file handle right away; convert() returns an independent copy.
    with Image.open(os.path.join(test_dataset_dir, row['img_path'])) as img_file:
        image = img_file.convert("RGB")

    ### Step 1: generate a free-form description of the image for this question ###
    # NOTE(review): USER:/ASSISTANT: chat-style template — presumably the format used during
    # fine-tuning; confirm.
    desc_prompt = (
        "USER: Based on the image and question, write a description.\n"
        f"Question: {row['Question']}\n\n"
        "Description:\n"
        "ASSISTANT:"
    )
    generated_description = _generate_text(image, desc_prompt, max_new_tokens=128)
    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.write(f"\n[Step 1] Generated Description: {generated_description}")

    ### Step 2: build the prompt with the description and options, then predict the answer ###
    final_prompt = (
        "USER: Based on the image, description, and question, choose the best option from A, B, C, or D.\n"
        f"Description: {generated_description}\n"
        f"Question: {row['Question']}\n"
        f"A. {row['A']}\n"
        f"B. {row['B']}\n"
        f"C. {row['C']}\n"
        f"D. {row['D']}\n\n"
        "Answer:"
    )
    # Only a very short completion is generated; the option letter is extracted from it below.
    decoded = _generate_text(image, final_prompt, max_new_tokens=3)
    tqdm.write(f"[Step 2] Final Answer Prediction: {decoded}")
    tqdm.write("==========================================================")

    results.append(extract_answer_letter(decoded))


In [None]:
submission_base_file_path = './dataset/given/sample_submission.csv'
submission_save_file_path = './test_inference_final.csv'

submission = pd.read_csv(submission_base_file_path)

# `results` is positional (one entry per row of test.csv, in iteration order). It only lines
# up with the template if both files list the same rows in the same order — NOTE(review):
# confirm that ordering; here we at least refuse to write a submission of the wrong length.
if len(results) != len(submission):
    raise ValueError(
        f"Got {len(results)} predictions but the submission template has "
        f"{len(submission)} rows; refusing to write a misaligned submission."
    )

# extract_answer_letter returns "?" when no standalone A-D letter was found; surface how many.
n_unparsed = results.count("?")
if n_unparsed:
    print(f"Warning: {n_unparsed} prediction(s) could not be parsed and were written as '?'.")

submission['answer'] = results
submission.to_csv(submission_save_file_path, index=False)
print("Done.")