In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, IterableDataset
from transformers import (
    InstructBlipForConditionalGeneration,
    BlipImageProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
    InstructBlipQFormerConfig,
    InstructBlipQFormerModel
)
from datasets import load_dataset
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast
import os
import random
import numpy as np

def set_seed(seed: int = 42):
    """Seed every RNG this notebook relies on and pin cuDNN to deterministic kernels.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU, the current
    CUDA device, and all CUDA devices), then asks cuDNN for deterministic
    algorithms and disables its benchmark-mode autotuning.

    Args:
        seed: The seed value passed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# --- Basic configuration ---
_cuda_available = torch.cuda.is_available()
device = "cuda" if _cuda_available else "cpu"
print(f"사용 디바이스: {device}")
# bfloat16 when the GPU supports it, otherwise float16 (including on CPU).
# NOTE(review): the training loop below uses no loss scaling, so the float16
# fallback trains in pure fp16 — confirm this is acceptable on non-bf16 GPUs.
dtype = torch.bfloat16 if _cuda_available and torch.cuda.is_bf16_supported() else torch.float16

In [None]:
# --- 1. Rebuild the same model architecture that was used for pre-training ---
# Before SFT starts, build an empty model "shell" to receive the pre-training
# state_dict. Its structure must match the saved checkpoint exactly.

print("--- SFT 모델 아키텍처 재현 시작 ---")

_BLIP_REPO_ID = "Salesforce/instructblip-flan-t5-xl"
_QWEN_REPO_ID = "Qwen/Qwen-1_8B-Chat"

# 1-1) Load the base models (source of the vision encoder, Q-Former config, tokenizer, ...).
print("1) 베이스 모델 로드 (Hugging Face Hub)")
base_blip_model = InstructBlipForConditionalGeneration.from_pretrained(
    _BLIP_REPO_ID, torch_dtype=dtype, use_safetensors=True,
)
base_llm_model = AutoModelForCausalLM.from_pretrained(
    _QWEN_REPO_ID, torch_dtype=dtype, trust_remote_code=True, use_safetensors=True,
)
llm_tokenizer = AutoTokenizer.from_pretrained(_QWEN_REPO_ID, trust_remote_code=True)
image_processor = BlipImageProcessor.from_pretrained(_BLIP_REPO_ID)

# 1-2) Build the components exactly as in pre-training.
print("2) Pre-training과 동일한 모델 컴포넌트 생성 중...")
vision_model = base_blip_model.vision_model  # pretrained vision encoder
query_tokens = base_blip_model.query_tokens  # pretrained learnable query embeddings

# Shrink the Q-Former to 10 layers, as in pre-training. This edits the base
# model's config object in place, and the resulting Q-Former is freshly
# (randomly) initialized rather than copied from the base model's Q-Former.
qformer_config = base_blip_model.qformer.config
qformer_config.num_hidden_layers = 10
shrunken_qformer = InstructBlipQFormerModel(qformer_config)

llm_model = base_llm_model  # the LLM is used unchanged

# Projector from the Q-Former hidden size to the LLM hidden size.
qformer_hidden_size, llm_hidden_size = (
    shrunken_qformer.config.hidden_size,
    llm_model.config.hidden_size,
)
image_proj = nn.Linear(qformer_hidden_size, llm_hidden_size)
print("모델 로딩 완료")


# Drop the wrapper references; the submodules extracted above stay alive.
del base_blip_model, base_llm_model

In [None]:
# 1-3) Pre-training때와 동일한 최종 모델 클래스 정의
class VisionLLM_QFormer(nn.Module):
    """Vision encoder -> Q-Former -> linear projector -> causal LLM.

    The Q-Former's query outputs are projected into the LLM embedding space and
    prepended to the text token embeddings. The prepended positions are always
    attended to (mask 1) and excluded from the loss (label -100).

    Args:
        vision_model: Image encoder called as
            ``vision_model(pixel_values=..., return_dict=True)``; its output
            exposes ``last_hidden_state``.
        qformer: Q-Former called with ``input_ids``, ``query_embeds``,
            ``encoder_hidden_states`` and ``encoder_attention_mask``; its output
            exposes ``last_hidden_state``.
        query_tokens: Learnable query embeddings, expanded along dim 0 to the
            batch size (presumably ``(1, num_queries, qformer_hidden)`` —
            confirm against the base model).
        image_proj: Module mapping the Q-Former hidden size to the LLM hidden size.
        llm_model: Causal LM accepting ``inputs_embeds``, ``attention_mask`` and
            ``labels``, and exposing ``get_input_embeddings()``.
    """

    def __init__(self, vision_model, qformer, query_tokens, image_proj, llm_model):
        super().__init__()
        self.vision_model = vision_model
        self.qformer = qformer
        self.image_proj = image_proj
        self.llm_model = llm_model
        self.query_tokens = query_tokens
        # Enable gradient checkpointing on all three transformer stacks.
        for module in (self.vision_model, self.llm_model, self.qformer):
            self._enable_gradient_checkpointing(module)

    @staticmethod
    def _enable_gradient_checkpointing(module):
        """Turn on gradient checkpointing, preferring the non-reentrant variant.

        With reentrant checkpointing, a checkpointed layer's output requires
        grad only if one of its *inputs* does. When the vision embeddings and
        leading layers are frozen, the hidden states reaching the first
        trainable vision layer carry no grad. The trainable layers after them
        would then silently receive no gradients. Non-reentrant checkpointing
        still produces parameter gradients in that case.

        Older ``transformers`` versions do not accept
        ``gradient_checkpointing_kwargs``; in that case the default call is
        used instead.
        """
        try:
            module.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
        except TypeError:
            module.gradient_checkpointing_enable()

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        """Run the full multimodal forward pass and return the LLM output.

        Args:
            pixel_values: Image batch passed to the vision encoder.
            input_ids: Text token ids, ``(batch, text_len)``.
            attention_mask: Text attention mask, same shape as ``input_ids``.
            labels: Optional text labels (same shape as ``input_ids``; -100 =
                ignored). Image positions are prepended as -100.

        Returns:
            Whatever ``llm_model(..., return_dict=True)`` returns.
        """
        image_embeds = self.vision_model(pixel_values=pixel_values, return_dict=True).last_hidden_state
        image_device = image_embeds.device
        image_attention_mask = torch.ones(image_embeds.shape[:-1], dtype=torch.long, device=image_device)

        batch_size = image_embeds.shape[0]
        expanded_query_tokens = self.query_tokens.expand(batch_size, -1, -1)
        num_queries = expanded_query_tokens.shape[1]
        # The Q-Former requires input_ids; all-zero dummy ids are passed
        # alongside the real query embeddings. Create them on the activations'
        # device rather than relying on a module-global `device`.
        dummy_query_input_ids = torch.zeros((batch_size, num_queries), dtype=torch.long, device=image_device)
        query_outputs = self.qformer(
            input_ids=dummy_query_input_ids,
            query_embeds=expanded_query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True
        ).last_hidden_state

        projected_feats = self.image_proj(query_outputs)
        input_embeds = self.llm_model.get_input_embeddings()(input_ids)
        full_embeds = torch.cat([projected_feats, input_embeds], dim=1)

        # Image positions are always attended to.
        full_attention_mask = torch.cat([
            torch.ones(projected_feats.shape[:2], dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask
        ], dim=1)

        full_labels = None
        if labels is not None:
            # Image positions never contribute to the loss.
            image_labels = torch.full(projected_feats.shape[:2], -100, dtype=torch.long, device=labels.device)
            full_labels = torch.cat([image_labels, labels], dim=1)

        return self.llm_model(
            inputs_embeds=full_embeds,
            attention_mask=full_attention_mask,
            labels=full_labels,
            return_dict=True
        )

# --- 3. SFT용 데이터셋 및 Collate 함수 정의 ---
class SFTStreamingMultiTurnDataset(IterableDataset):
    """Stream multi-turn image conversations as tokenized SFT examples.

    Each sample of ``hf_dataset`` is read as a dict with an ``"image"`` (an
    object exposing ``.convert("RGB")``, e.g. a PIL image) and a
    ``"conversations"`` list of ``{"from": ..., "value": ...}`` turns.

    - The ``<image>`` marker and human turns are masked with -100 in the labels.
    - ``gpt`` turns, each followed by an EOS token, are supervised.
    - Turns from any other role are ignored.
    - Samples missing either key are skipped.
    - Samples that fail to process are reported and skipped.
    - Samples left with no supervised token after truncation are skipped.

    Args:
        hf_dataset: Iterable of samples (e.g. a streaming ``datasets`` split).
        tokenizer: Callable tokenizer whose result exposes ``.input_ids`` and
            which has an ``eos_token_id`` attribute.
        image_processor: Callable ``(image, return_tensors="pt")`` returning a
            mapping with a ``"pixel_values"`` tensor.
        max_length: Maximum number of text tokens kept per sample (longer
            sequences are truncated from the right).
    """

    _IGNORE_INDEX = -100

    def __init__(self, hf_dataset, tokenizer, image_processor, max_length=1024):
        super().__init__()
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length
        self.eos_token_id = self.tokenizer.eos_token_id

    def __iter__(self):
        ignore = self._IGNORE_INDEX
        # The image prompt is constant: tokenize it once per pass, not twice per sample.
        image_prompt_ids = self.tokenizer("<image>\n").input_ids

        for sample in self.hf_dataset:
            try:
                if 'image' not in sample or 'conversations' not in sample:
                    continue

                image = sample["image"].convert("RGB")
                pixel_values = self.image_processor(image, return_tensors="pt")["pixel_values"]

                full_input_ids = []
                full_labels = []

                for conv in sample["conversations"]:
                    role = conv.get("from")
                    value = conv.get("value")

                    if not value or not isinstance(value, str):
                        continue

                    # Replace the inline <image> marker with an explicit, masked
                    # image prompt placed before the turn's text.
                    if "<image>" in value:
                        value = value.replace("<image>", "").strip()
                        full_input_ids.extend(image_prompt_ids)
                        full_labels.extend([ignore] * len(image_prompt_ids))

                    if role == 'human':
                        human_tokens = self.tokenizer(value).input_ids
                        full_input_ids.extend(human_tokens)
                        full_labels.extend([ignore] * len(human_tokens))

                    elif role == 'gpt':
                        gpt_tokens = self.tokenizer(value).input_ids
                        gpt_tokens.append(self.eos_token_id)
                        full_input_ids.extend(gpt_tokens)
                        full_labels.extend(gpt_tokens)

                full_input_ids = full_input_ids[:self.max_length]
                full_labels = full_labels[:self.max_length]

                # Nothing to learn from (no gpt turn, or it was truncated away).
                # Such samples only waste compute, and a batch made solely of
                # them yields a NaN loss.
                if all(label == ignore for label in full_labels):
                    continue

                yield {
                    # Drop only the batch dim added by the processor.
                    "pixel_values": pixel_values.squeeze(0),
                    "input_ids": torch.tensor(full_input_ids, dtype=torch.long),
                    "labels": torch.tensor(full_labels, dtype=torch.long)
                }
            except Exception as e:
                # Best effort: a malformed sample must not stop the stream.
                print(f"!!! 데이터 처리 중 에러 발생, 샘플 건너뜁니다. Error: {e}")
                continue

def collate_fn(batch, pad_token_id=None, padding_side=None):
    """Pad a list of dataset examples into a training batch.

    Padded label positions are filled with -100 directly instead of being
    recovered afterwards by matching ``pad_token_id``. Matching fails here
    because the pad token is the EOS token. Every genuine EOS label would also
    be masked, and the model would never learn to stop.

    Args:
        batch: List of dicts with ``"pixel_values"`` (stackable tensors) and
            1-D ``"input_ids"`` / ``"labels"`` tensors of equal length per item.
        pad_token_id: Id used to pad ``input_ids``. Defaults to
            ``llm_tokenizer.pad_token_id``.
        padding_side: ``"left"`` or ``"right"``. Defaults to
            ``llm_tokenizer.padding_side`` (``"right"`` if absent).

    Returns:
        Dict with ``pixel_values``, ``input_ids``, ``attention_mask`` and
        ``labels``. The last three are ``(batch, longest)`` long tensors.
    """
    if pad_token_id is None:
        pad_token_id = llm_tokenizer.pad_token_id
    if padding_side is None:
        padding_side = getattr(llm_tokenizer, "padding_side", "right")

    pixel_values = torch.stack([item["pixel_values"] for item in batch])

    batch_size = len(batch)
    max_len = max(item["input_ids"].numel() for item in batch)
    input_ids = torch.full((batch_size, max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long)
    labels = torch.full((batch_size, max_len), -100, dtype=torch.long)

    for row, item in enumerate(batch):
        length = item["input_ids"].numel()
        cols = slice(max_len - length, max_len) if padding_side == "left" else slice(0, length)
        input_ids[row, cols] = item["input_ids"]
        attention_mask[row, cols] = 1
        labels[row, cols] = item["labels"]

    return {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# share 전략으로 Pretraining. Vision Encoder Layer 절반 동결, 나머지 전체 학습
def freeze_for_pretraining(vision_model, num_freeze_layers=20):
    """Freeze the vision encoder's embeddings and its first N transformer layers.

    Parameters outside those modules are left untouched. Progress is printed
    as each group is frozen.

    Args:
        vision_model: Module exposing ``.embeddings`` and ``.encoder.layers``.
        num_freeze_layers: How many leading ``encoder.layers`` to freeze
            (standard slice semantics).
    """
    print("[학습 전략] Pre-training을 위한 모델 동결을 시작합니다.")

    print("   - Vision Encoder의 Embedding 레이어 동결 중...")
    vision_model.embeddings.requires_grad_(False)

    print(f"   - Vision Encoder의 앞 {num_freeze_layers}개 Transformer 레이어 동결 중...")
    for frozen_layer in vision_model.encoder.layers[:num_freeze_layers]:
        frozen_layer.requires_grad_(False)

    print("동결/활성화 설정 완료.")

In [None]:
# 1-4) Build the final model "shell".
multimodal_model = VisionLLM_QFormer(
    vision_model, shrunken_qformer, query_tokens, image_proj, llm_model
).to(device, dtype=dtype)

print("SFT 모델 아키텍처 재현 완료.")

# --- 2. Load the pre-trained weights ---
PRETRAINED_MODEL_PATH = "share_instructblip_qwen_pt.pth"
if os.path.exists(PRETRAINED_MODEL_PATH):
    print(f"\n사전 학습된 모델 가중치를 로드합니다: {PRETRAINED_MODEL_PATH}")
    # Load onto the CPU: the model already lives on `device`, and
    # load_state_dict copies each tensor into the existing parameters. Mapping
    # the checkpoint to the GPU as well would temporarily hold a second full
    # copy of the weights in GPU memory.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    pretrained_state = torch.load(PRETRAINED_MODEL_PATH, map_location="cpu")
    multimodal_model.load_state_dict(pretrained_state)
    del pretrained_state
else:
    print(f"경고: 사전 학습된 모델 파일({PRETRAINED_MODEL_PATH})을 찾을 수 없습니다. 랜덤 가중치로 SFT를 시작합니다.")


# Configure the Qwen tokenizer's pad_token (reuses EOS when no pad token is set).
if llm_tokenizer.pad_token is None:
    print("\npad_token이 설정되지 않았습니다. Qwen 모델에 맞게 수동으로 설정합니다.")
    if llm_tokenizer.eos_token is None:
        llm_tokenizer.eos_token = '<|endoftext|>'
    llm_tokenizer.pad_token = llm_tokenizer.eos_token

print("--- 토크나이저 설정 확인 ---")
print(f"pad_token: {llm_tokenizer.pad_token}, pad_token_id: {llm_tokenizer.pad_token_id}")
assert llm_tokenizer.pad_token_id == 151643, "pad_token_id가 올바르게 설정되지 않았습니다!"
print("토크나이저 설정이 올바릅니다.")


# For SFT, freeze the vision encoder's embeddings and its first 30 layers.
freeze_for_pretraining(
    vision_model=multimodal_model.vision_model,
    num_freeze_layers=30
)

In [None]:
# --- 4. Run SFT training ---
# Hyperparameters
NUM_EPOCHS = 1  # NOTE(review): not read below; run length is set by NUM_TRAIN_SAMPLES.
PER_DEVICE_BATCH_SIZE = 16
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 2e-5
NUM_TRAIN_SAMPLES = 645000
# NOTE(review): a full state_dict is saved every SAVE_CHECKPOINT_STEPS optimizer
# updates. With NUM_TRAIN_SAMPLES // (16 * 8) = 5039 updates, a value of 3 writes
# ~1.7k full-model checkpoints — confirm this is intended before a long run.
SAVE_CHECKPOINT_STEPS = 3

print("\nLLaVA-NeXT-Data 멀티턴 데이터셋을 스트리밍합니다...")
streaming_dataset = load_dataset("lmms-lab/LLaVA-NeXT-Data", split="train", streaming=True)
train_dataset = SFTStreamingMultiTurnDataset(streaming_dataset, llm_tokenizer, image_processor, max_length=512)
train_dataloader = DataLoader(train_dataset, batch_size=PER_DEVICE_BATCH_SIZE, collate_fn=collate_fn)

print("\nSFT에서는 Vision Encoder 앞 30개 Layer만 동결, 나머지 전체 학습")
trainable_params = [p for p in multimodal_model.parameters() if p.requires_grad]
total_trainable_params = sum(p.numel() for p in trainable_params)
print(f"총 학습 가능 파라미터 수: {total_trainable_params / 1_000_000:.2f}M")

optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE, eps=1e-6)

loss_history = []
print("\nShare_InstructBlip_Qwen SFT 시작")
effective_batch_size = PER_DEVICE_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
total_steps = NUM_TRAIN_SAMPLES // effective_batch_size
multimodal_model.train()
progress_bar = tqdm(range(total_steps), desc=f"Epoch {1} (SFT)")
completed_steps = 0
trained_sample_count = 0

# State of the current accumulation window (one optimizer update). It is
# tracked independently of `step`, so a skipped NaN/Inf micro-batch neither
# shifts the update boundary nor discards gradients accumulated earlier in the
# window.
window_micro_batches = 0
window_samples = 0
window_loss = 0.0  # sum of 1/GA-scaled losses == mean micro-batch loss of a full window

for step, batch in enumerate(train_dataloader):
    if completed_steps >= total_steps:
        print("목표 스텝에 도달하여 학습을 조기 종료합니다.")
        break

    pixel_values = batch["pixel_values"].to(device, dtype=dtype)
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)

    with autocast(dtype=dtype):
        outputs = multimodal_model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS

    if torch.isnan(loss) or torch.isinf(loss):
        print(f"경고: 스텝 {step}에서 Loss가 NaN 또는 Inf입니다. 이 배치의 학습을 건너뜁니다.")
        # Drop only this micro-batch. It was never back-propagated, so the
        # gradients already accumulated in this window are still valid.
        continue

    loss.backward()
    window_micro_batches += 1
    window_samples += pixel_values.shape[0]
    window_loss += loss.item()

    if window_micro_batches == GRADIENT_ACCUMULATION_STEPS:
        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        optimizer.step()
        optimizer.zero_grad()

        completed_steps += 1
        trained_sample_count += window_samples
        current_loss = window_loss
        loss_history.append(current_loss)

        window_micro_batches = 0
        window_samples = 0
        window_loss = 0.0

        progress_bar.update(1)
        progress_bar.set_postfix({"loss": f"{current_loss:.4f}"})

        if SAVE_CHECKPOINT_STEPS > 0 and (completed_steps % SAVE_CHECKPOINT_STEPS == 0):
            checkpoint_path = f"checkpoint_share_instructblip_qwen_sft_step_{completed_steps}.pth"
            torch.save(multimodal_model.state_dict(), checkpoint_path)
            print(f"SFT Checkpoint 저장 완료: {checkpoint_path}")

# NOTE(review): if the stream ends mid-window, the gradients accumulated in
# that final partial window are not applied.
progress_bar.close()
print("\nShare_InstructBlip_Qwen SFT 완료! 최종 모델을 저장합니다...")
torch.save(multimodal_model.state_dict(), "share_instructblip_qwen_sft.pth")
print("최종 SFT 모델 저장 완료! (share_instructblip_qwen_sft.pth)")

# SFT training summary
print("\n" + "="*50)
print("SFT 학습 결과 요약")
print(f"  - 완료된 스텝 수 (Updates): {completed_steps}")
print(f"  - 실제 학습에 사용된 총 샘플 수: {trained_sample_count:,}")
print("="*50)

if loss_history:
    plt.figure(figsize=(10, 5))
    plt.plot(loss_history)
    plt.title("SFT Loss Curve (share_instructblip_qwen_sft)")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.savefig("share_instructblip_qwen_sft_loss_curve.png")
    print("SFT 학습 곡선 이미지를 'share_instructblip_qwen_sft_loss_curve.png'로 저장했습니다.")
    plt.show()