In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, IterableDataset
from transformers import (
    InstructBlipForConditionalGeneration,
    BlipImageProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
    InstructBlipQFormerConfig,
    InstructBlipQFormerModel
)
from datasets import load_dataset
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast
import random
import numpy as np

def set_seed(seed: int = 42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU generator, the current
    CUDA device and all CUDA devices), then sets the cuDNN ``deterministic``
    flag and clears the cuDNN ``benchmark`` flag.

    Args:
        seed: Value passed unchanged to each seeding function.
    """
    # Same seeding calls as a plain sequence of statements, in the same order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# --- Basic setup ---
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"사용 디바이스: {device}")
# Data type for mixed-precision training: bfloat16 only when CUDA is present
# and reports bf16 support, float16 in every other case.
dtype = torch.float16
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16

In [None]:
# --- Model loading ---

# 1) Load InstructBLIP, then extract its vision encoder and Q-Former.
print("1) InstructBLIP 모델 로드 후 Vision Encoder 및 Q-Former 추출 중... (Salesforce/instructblip-flan-t5-xl)")
full_blip_model = InstructBlipForConditionalGeneration.from_pretrained(
    "Salesforce/instructblip-flan-t5-xl",
    torch_dtype=dtype,
    use_safetensors=True
).to(device)

# Extract the vision encoder.
vision_model = full_blip_model.vision_model

# Grab query_tokens separately before the full model is deleted.
print("   - Q-Former를 위한 query_tokens 추출 중...")
query_tokens = full_blip_model.query_tokens.to(device, dtype=dtype)

# Edit the Q-Former config to shrink the layer stack (12 -> 10).
print("   - Q-Former 설정을 10개 레이어로 축소 중...")
qformer_config = full_blip_model.qformer.config
qformer_config.num_hidden_layers = 10

# Build a fresh Q-Former from the shrunken config.
shrunken_qformer = InstructBlipQFormerModel(qformer_config).to(device, dtype=dtype)

# Copy the pretrained weights into the shrunken Q-Former. strict=False is
# needed because the pretrained state dict carries entries for layers that
# no longer exist in the shrunken model.
print("   - 사전 학습된 Q-Former 가중치를 축소된 모델로 로드 중...")
qformer_load_result = shrunken_qformer.load_state_dict(
    full_blip_model.qformer.state_dict(), strict=False
)
# Extra (unexpected) keys are fine, but a key the new model needs and did not
# receive would leave it randomly initialised -- fail loudly instead.
if qformer_load_result.missing_keys:
    raise RuntimeError(
        "Shrunken Q-Former is missing pretrained weights for: "
        f"{qformer_load_result.missing_keys}"
    )

# The full BLIP model is no longer needed, so drop the reference and hand any
# cached CUDA blocks back to the driver so the LLM load below can use them.
del full_blip_model
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Load the image processor.
image_processor = BlipImageProcessor.from_pretrained(
    "Salesforce/instructblip-flan-t5-xl"
)

# 2) Load the LLM (Qwen) and its tokenizer.
print("2) Language Model & Tokenizer 로딩 중... (Qwen/Qwen-1_8B-Chat)")
llm_tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen-1_8B-Chat",
    trust_remote_code=True
)
llm_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-1_8B-Chat",
    torch_dtype=dtype,
    trust_remote_code=True,
    use_safetensors=True
).to(device)
print("모델 로딩 완료")


# Configure the Qwen tokenizer's pad_token: reuse the EOS token, defining the
# EOS token first when the tokenizer does not provide one.
if llm_tokenizer.pad_token is None:
    print("pad_token이 설정되지 않았습니다. Qwen 모델에 맞게 수동으로 설정합니다.")
    if llm_tokenizer.eos_token is None:
        llm_tokenizer.eos_token = '<|endoftext|>'
    llm_tokenizer.pad_token = llm_tokenizer.eos_token

print("--- 토크나이저 설정 확인 ---")
print(f"pad_token: {llm_tokenizer.pad_token}")
print(f"pad_token_id: {llm_tokenizer.pad_token_id}")
# Sanity check: the pad token must resolve to id 151643.
assert llm_tokenizer.pad_token_id == 151643, "pad_token_id가 올바르게 설정되지 않았습니다!"
print("토크나이저 설정이 올바릅니다.")


# --- Architecture definition (Q-Former connector & multimodal model) ---
print("\n--- 아키텍처 구성 (축소된 Q-Former 사용) ---")
vision_hidden_size = vision_model.config.hidden_size
qformer_hidden_size = shrunken_qformer.config.hidden_size
llm_hidden_size = llm_model.config.hidden_size

print(f"Vision Encoder Hidden Size: {vision_hidden_size}")
print(f"Q-Former Hidden Size: {qformer_hidden_size}")
print(f"LLM Hidden Size: {llm_hidden_size}")

# Single linear layer that maps Q-Former outputs into the LLM embedding width.
image_proj = nn.Linear(qformer_hidden_size, llm_hidden_size).to(device, dtype=dtype)

print("Vision Encoder -> Q-Former -> MLP Projector -> LLM 구조 생성 완료")


In [None]:
# Count the parameters of each component and of the whole pipeline.
vision_encoder_params = sum(p.numel() for p in vision_model.parameters())
qformer_params = sum(p.numel() for p in shrunken_qformer.parameters())
query_token_params = query_tokens.numel()  # the query tokens live outside the Q-Former module
proj_params = sum(p.numel() for p in image_proj.parameters())
llm_params = sum(p.numel() for p in llm_model.parameters())
total_params = vision_encoder_params + qformer_params + query_token_params + proj_params + llm_params

print("\n--- 모델 파라미터 수 ---")
print(f"Vision Encoder: {vision_encoder_params / 1_000_000:.2f}M")
# Take the layer count from the config so the label always matches the model
# (the label used to be hard-coded to a stale value).
print(f"Shrunken Q-Former ({shrunken_qformer.config.num_hidden_layers}-layers): {qformer_params / 1_000_000:.2f}M")
print(f"Query Tokens  : {query_token_params / 1_000_000:.2f}M")
print(f"MLP Projector : {proj_params / 1_000_000:.2f}M")
print(f"Language Model: {llm_params / 1_000_000:.2f}M")
print("-------------------------")
print(f"전체 파라미터   : {total_params / 1_000_000:.2f}M")


In [None]:
class VisionLLM_QFormer(nn.Module):
    """Vision encoder -> Q-Former -> linear projector -> causal LLM.

    The image is encoded, summarised by the Q-Former through the query
    tokens, projected to the LLM embedding width and prepended to the text
    embeddings before the LLM is called.
    """

    def __init__(self, vision_model, qformer, query_tokens, image_proj, llm_model):
        """Store the sub-modules and enable gradient checkpointing on them.

        Args:
            vision_model: Image encoder; called with ``pixel_values`` and
                expected to return an object with ``last_hidden_state``.
            qformer: Q-Former; called with the query embeddings and the
                image features as encoder states.
            query_tokens: Query embeddings, expanded along dim 0 to the
                batch size in ``forward``.
            image_proj: Module mapping Q-Former outputs to the LLM width.
            llm_model: Causal language model that consumes ``inputs_embeds``.
        """
        super().__init__()
        self.vision_model = vision_model
        self.qformer = qformer
        self.image_proj = image_proj
        self.llm_model = llm_model
        self.query_tokens = query_tokens

        self.vision_model.gradient_checkpointing_enable()
        self.llm_model.gradient_checkpointing_enable()
        self.qformer.gradient_checkpointing_enable()

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        """Run the multimodal forward pass.

        Args:
            pixel_values: Image batch for the vision encoder.
            input_ids: Text token ids, embedded with the LLM's input
                embedding layer.
            attention_mask: Mask for ``input_ids``; extended with ones for
                the image positions.
            labels: Optional text labels; image positions are filled with
                -100 so they do not contribute to the loss.

        Returns:
            Whatever ``llm_model`` returns for the combined sequence.
        """
        image_embeds = self.vision_model(pixel_values=pixel_values, return_dict=True).last_hidden_state
        # Take the device from the tensors being processed instead of a
        # module-level global, so the model works wherever it has been moved.
        feature_device = image_embeds.device
        image_attention_mask = torch.ones(image_embeds.shape[:-1], dtype=torch.long, device=feature_device)
        expanded_query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        # The Q-Former is given placeholder `input_ids`: an all-zero tensor
        # with one id per query token.
        dummy_query_input_ids = torch.zeros(
            expanded_query_tokens.shape[:2], dtype=torch.long, device=feature_device
        )

        # NOTE(review): the Hugging Face InstructBLIP reference keeps only the
        # first `query_tokens.size(1)` positions of the Q-Former output; here
        # the whole returned sequence is projected -- confirm this is intended.
        query_outputs = self.qformer(
            input_ids=dummy_query_input_ids,
            query_embeds=expanded_query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True
        ).last_hidden_state

        projected_feats = self.image_proj(query_outputs)
        input_embeds = self.llm_model.get_input_embeddings()(input_ids)
        # Image features come first, followed by the text embeddings.
        full_embeds = torch.cat([projected_feats, input_embeds], dim=1)

        full_attention_mask = torch.cat([
            torch.ones(projected_feats.shape[:2], dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask
        ], dim=1)

        full_labels = None
        if labels is not None:
            # -100 on the image positions excludes them from the loss.
            image_labels = torch.full(
                tuple(projected_feats.shape[:2]), -100, dtype=torch.long, device=labels.device
            )
            full_labels = torch.cat([image_labels, labels], dim=1)

        return self.llm_model(
            inputs_embeds=full_embeds,
            attention_mask=full_attention_mask,
            labels=full_labels,
            return_dict=True
        )


# --- 데이터셋 및 유틸리티 함수 ---
class PreTrainingStreamingDataset(IterableDataset):
    """Stream (image, caption) samples as tensors ready for pre-training.

    Each source sample is read through the keys ``"image"`` (an object with
    ``.convert("RGB")``) and ``"conversations"`` (a list of dicts with
    ``"from"`` and ``"value"``). Samples that lack these, or that raise
    while being processed, are skipped.
    """

    def __init__(self, hf_dataset, tokenizer, image_processor, max_length=512):
        """
        Args:
            hf_dataset: Iterable of sample dicts.
            tokenizer: Callable text tokenizer exposing ``eos_token``; its
                result must expose ``input_ids``.
            image_processor: Callable returning a mapping that contains
                ``"pixel_values"``.
            max_length: Truncation length for the tokenised text.
        """
        super().__init__()
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length

    def __iter__(self):
        """Yield dicts with ``pixel_values``, ``input_ids`` and ``labels``.

        ``labels`` equals ``input_ids`` with the instruction prefix replaced
        by -100, so only the response (and EOS) positions carry a target.
        """
        for sample in self.hf_dataset:
            try:
                if 'image' not in sample or 'conversations' not in sample:
                    continue
                image = sample["image"].convert("RGB")
                pixel_values = self.image_processor(image, return_tensors="pt")["pixel_values"]
                conversations = sample.get("conversations")
                if not conversations or not isinstance(conversations, list):
                    continue

                # First human turn and first gpt turn of the conversation.
                human_input = next((c.get("value") for c in conversations if c.get("from") == "human"), None)
                response = next((c.get("value") for c in conversations if c.get("from") == "gpt"), None)

                if human_input and "<image>" in human_input:
                    # Replace the bare "<image>" prompt with an explicit
                    # instruction; the newline separates marker and request.
                    instruction = "<image>\nDescribe the image in detail."
                else:
                    # Any other human input is used unchanged (safety net).
                    instruction = human_input

                if not isinstance(instruction, str) or not isinstance(response, str) or not instruction or not response:
                    continue

                full_text = instruction + response + self.tokenizer.eos_token
                full_input_ids = self.tokenizer(
                    full_text, truncation=True, max_length=self.max_length, return_tensors=None
                ).input_ids
                instruction_length = len(self.tokenizer(instruction).input_ids)
                # Bound the masked prefix by the real sequence length so the
                # mask can never index past the (possibly truncated) tokens.
                masked_prefix = min(instruction_length, len(full_input_ids))
                labels = [-100] * masked_prefix + list(full_input_ids[masked_prefix:])
                yield {
                    # Drop only the leading batch dimension; a bare squeeze()
                    # would also remove any other size-1 dimension.
                    "pixel_values": pixel_values.squeeze(0),
                    "input_ids": torch.tensor(full_input_ids, dtype=torch.long),
                    "labels": torch.tensor(labels, dtype=torch.long)
                }
            except Exception as e:
                # Best-effort streaming: report the bad sample and keep going.
                print(f"!!! 데이터 처리 중 예외 발생, 해당 샘플을 건너뜁니다. 오류: {e}, 샘플 ID: {sample.get('id', 'N/A')}")
                continue

def collate_fn(batch):
    """Stack images and pad the text fields of a batch to a common length.

    Args:
        batch: List of dicts with ``"pixel_values"``, ``"input_ids"`` and
            ``"labels"``; within each item ``labels`` must have the same
            length as ``input_ids`` so both are padded at the same positions.

    Returns:
        Dict with ``pixel_values``, ``input_ids``, ``attention_mask`` and
        ``labels``; padded label positions are set to -100.
    """
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    padded_inputs = llm_tokenizer.pad(
        {"input_ids": [item["input_ids"] for item in batch]},
        padding='longest',
        return_tensors="pt",
    )
    padded_labels = llm_tokenizer.pad(
        {"input_ids": [item["labels"] for item in batch]},
        padding='longest',
        return_tensors="pt",
    ).input_ids
    # Mask by padding position, not by token value: the pad token can be the
    # same token as EOS, and comparing against pad_token_id would also erase
    # every genuine EOS target.
    padded_labels[padded_inputs.attention_mask == 0] = -100
    return {
        "pixel_values": pixel_values,
        "input_ids": padded_inputs.input_ids,
        "attention_mask": padded_inputs.attention_mask,
        "labels": padded_labels,
    }


# Pre-train with the "share" strategy: freeze half of the Vision Encoder
# layers and train everything else.
def freeze_for_pretraining(vision_model, num_freeze_layers=20):
    """Freeze the vision encoder's embeddings and its leading layers.

    Sets ``requires_grad = False`` on every parameter of
    ``vision_model.embeddings`` and of the first ``num_freeze_layers``
    entries of ``vision_model.encoder.layers``, printing progress messages.

    Args:
        vision_model: Object exposing ``embeddings`` and ``encoder.layers``.
        num_freeze_layers: Number of leading encoder layers to freeze.
    """
    def _lock(module):
        # Turn off gradient tracking for every parameter of `module`.
        for weight in module.parameters():
            weight.requires_grad = False

    print("[학습 전략] Pre-training을 위한 모델 동결을 시작합니다.")
    print("   - Vision Encoder의 Embedding 레이어 동결 중...")
    _lock(vision_model.embeddings)

    print(f"   - Vision Encoder의 앞 {num_freeze_layers}개 Transformer 레이어 동결 중...")
    leading_layers = vision_model.encoder.layers[:num_freeze_layers]
    for block in leading_layers:
        _lock(block)

    print("동결/활성화 설정 완료.")

In [None]:
# --- Run training ---
NUM_EPOCHS = 1
PER_DEVICE_BATCH_SIZE = 16
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 2e-5
NUM_TRAIN_SAMPLES = 557800
SAVE_CHECKPOINT_STEPS = 1000

print("\nPre-training을 위한 스트리밍 데이터셋을 로드합니다...")
streaming_dataset = load_dataset("lmms-lab/LLaVA-ReCap-558K", split="train", streaming=True)
train_dataset = PreTrainingStreamingDataset(streaming_dataset, llm_tokenizer, image_processor, max_length=512)
train_dataloader = DataLoader(train_dataset, batch_size=PER_DEVICE_BATCH_SIZE, collate_fn=collate_fn)

multimodal_model = VisionLLM_QFormer(
    vision_model, shrunken_qformer, query_tokens, image_proj, llm_model
).to(device)

freeze_for_pretraining(
    vision_model=multimodal_model.vision_model,
    num_freeze_layers=20
)

# Only parameters that still require gradients are handed to the optimizer.
trainable_params = [p for p in multimodal_model.parameters() if p.requires_grad]
total_trainable_params = sum(p.numel() for p in trainable_params)
print(f"총 학습 가능 파라미터 수: {total_trainable_params / 1_000_000:.2f}M")

total_params = [p for p in multimodal_model.parameters()]
total__params = sum(p.numel() for p in total_params)
print(f"전체 파라미터 수: {total__params / 1_000_000:.2f}M")


optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE, eps=1e-6)

loss_history = []
print("\nShare_InstructBlip_Qwen Pre-training 시작")
effective_batch_size = PER_DEVICE_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
total_steps = NUM_TRAIN_SAMPLES // effective_batch_size
multimodal_model.train()
progress_bar = tqdm(range(total_steps), desc=f"Epoch {1}")
completed_steps = 0
trained_sample_count = 0

# State of the current gradient-accumulation window. It is tracked separately
# from the dataloader index so that a skipped batch cannot cause an optimizer
# step on fewer than GRADIENT_ACCUMULATION_STEPS micro-batches.
window_micro_batches = 0
window_loss = 0.0
window_samples = 0

for step, batch in enumerate(train_dataloader):
    if completed_steps >= total_steps:
        print("목표 스텝에 도달하여 학습을 조기 종료합니다.")
        break

    pixel_values = batch["pixel_values"].to(device, dtype=dtype)
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)

    with autocast(dtype=dtype):
        outputs = multimodal_model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        # Scale so the gradients summed over the window equal their mean.
        loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS

    if torch.isnan(loss):
        print(f"경고: 스텝 {step}에서 Loss가 NaN입니다. 이 배치의 학습을 건너뜁니다.")
        # The accumulated gradients are discarded, so the window restarts too.
        optimizer.zero_grad()
        window_micro_batches = 0
        window_loss = 0.0
        window_samples = 0
        continue

    loss.backward()
    window_micro_batches += 1
    window_loss += loss.item()
    window_samples += input_ids.size(0)

    if window_micro_batches == GRADIENT_ACCUMULATION_STEPS:
        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        optimizer.step()
        optimizer.zero_grad()

        completed_steps += 1
        # Count the samples that really fed this update (the final batch of a
        # stream, for instance, can be smaller than the nominal batch size).
        trained_sample_count += window_samples
        # Sum of the scaled micro-batch losses == mean loss over the window.
        current_loss = window_loss
        loss_history.append(current_loss)

        window_micro_batches = 0
        window_loss = 0.0
        window_samples = 0

        progress_bar.update(1)
        progress_bar.set_postfix({"loss": f"{current_loss:.4f}"})

        if SAVE_CHECKPOINT_STEPS > 0 and (completed_steps % SAVE_CHECKPOINT_STEPS == 0):
            checkpoint_path = f"checkpoint_share_instructblip_qwen_pt{completed_steps}.pth"
            torch.save(multimodal_model.state_dict(), checkpoint_path)
            print(f"Checkpoint 저장 완료: {checkpoint_path}")

# NOTE: gradients of an unfinished window at the end of the stream are not
# applied.
progress_bar.close()
print("\nShare_InstructBlip_Qwen Pre-training 완료! 최종 모델을 저장합니다...")
torch.save(multimodal_model.state_dict(), "share_instructblip_qwen_pt.pth")
print("모델 저장 완료!")

print("\n" + "="*50)
print("학습 결과 요약")
print(f"  - 목표 학습 샘플 수: {NUM_TRAIN_SAMPLES}")
print(f"  - 목표 스텝 수 (Updates): {total_steps}")
print(f"  - 완료된 스텝 수 (Updates): {completed_steps}")
print(f"  - 실제 학습에 사용된 총 샘플 수: {trained_sample_count:,}")
if trained_sample_count > 0 and total_steps > 0:
    usage_percentage = (trained_sample_count / (total_steps * effective_batch_size)) * 100
    print(f"  - 데이터 사용률: {usage_percentage:.2f}%")
print("="*50)

if loss_history:
    # One variable for the path keeps the saved file and the message in sync.
    loss_curve_path = "share_instructblip_qwen_pretraining_loss_curve.png"
    plt.figure(figsize=(10, 5))
    plt.plot(loss_history)
    plt.title("Training Loss Curve (share_instructblip_qwen)")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.savefig(loss_curve_path)
    print(f"학습 곡선 이미지를 '{loss_curve_path}'로 저장했습니다.")
    plt.show()