In [1]:
# Restart Kernel
!pip install -qq -U bitsandbytes

In [2]:
import pandas as pd
import numpy as np

import shutil
from pathlib import Path
import os
import random
import re
import json

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import cv2
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset

import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch import nn, optim

from peft import PeftModel
from transformers import (AutoModelForCausalLM, Blip2Processor, Blip2Model, BlipImageProcessor, AutoTokenizer, BitsAndBytesConfig,
    Trainer, TrainingArguments, default_data_collator, TrainerCallback)

from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

from datasets import load_dataset



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

save_path = '/kaggle/working/'

2025-07-05 03:41:26.783991: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751686886.978708      88 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751686887.034199      88 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Using device: cuda


In [3]:
class BLIP2ForPhi(nn.Module):
    def __init__(self, vision_model, q_former, phi_model, query_tokens):
        super().__init__()
        self.vision_model = vision_model
        self.q_former = q_former
        self.projection = nn.Linear(q_former.config.hidden_size, phi_model.config.hidden_size)
        self.phi_model = phi_model
        self.query_tokens = query_tokens
        

        print("Freezing vision_model and phi_model parameters...")
        for param in self.vision_model.parameters():
            param.requires_grad = False
        for param in self.phi_model.parameters():
            param.requires_grad = False
        
        print("Training q_former and projection layer...")
        for param in self.q_former.parameters():
            param.requires_grad = True
        for param in self.projection.parameters():
            param.requires_grad = True

    
    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        image_embeds = self.vision_model(pixel_values).last_hidden_state

        batch_size = image_embeds.shape[0]
        qformer_query_embeds = self.query_tokens.expand(batch_size, -1, -1)

        
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
        query_outputs = self.q_former(
            query_embeds=qformer_query_embeds,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask
        )[0]

        projected_query = self.projection(query_outputs)

        text_embeds = self.phi_model.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([projected_query, text_embeds], dim=1)
        
        # 입력 ID의 attention_mask와 쿼리의 attention_mask를 결합
        query_attention_mask = torch.ones(projected_query.size()[:-1], dtype=torch.long, device=projected_query.device)
        combined_attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1)

        
        outputs = self.phi_model(
            inputs_embeds=inputs_embeds,
            attention_mask=combined_attention_mask,
            labels=labels, 
        )
        return outputs

In [4]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)


image_processor = BlipImageProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

print("Mean:", image_processor.image_mean)
print("Std:", image_processor.image_std)

phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
if phi_tokenizer.pad_token is None:
    phi_tokenizer.pad_token = phi_tokenizer.eos_token

blip2_model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b")

vision_model = blip2_model.vision_model
q_former = blip2_model.qformer
query_tokens = blip2_model.query_tokens

phi_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-1_5",
    quantization_config=quantization_config,
    trust_remote_code=True,
)

preprocessor_config.json:   0%|          | 0.00/432 [00:00<?, ?B/s]

Mean: [0.48145466, 0.4578275, 0.40821073]
Std: [0.26862954, 0.26130258, 0.27577711]


tokenizer_config.json:   0%|          | 0.00/237 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/10.0G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

config.json:   0%|          | 0.00/736 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.84G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/74.0 [00:00<?, ?B/s]

In [5]:
model = BLIP2ForPhi(vision_model, q_former, phi_model, query_tokens)

Freezing vision_model and phi_model parameters...
Training q_former and projection layer...


In [6]:
class Stage1Dataset(Dataset):
    def __init__(self, dataframe, image_processor, tokenizer, num_query_tokens=32, max_length=128, is_train=True):
        self.dataset = dataframe
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_query_tokens = num_query_tokens
        self.image_dir = '/kaggle/input/flickr-image-dataset/flickr30k_images/flickr30k_images/'
        self.is_train = is_train

        if self.is_train:
            self.transforms = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
                transforms.RandomRotation(degrees=15), # -15도에서 +15도 사이로 랜덤하게 회전
                transforms.RandomHorizontalFlip(p=0.4),
                transforms.RandomVerticalFlip(p=0.4),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                transforms.ToTensor(), 
                transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]) # 정규화
            ])
        else: 
            self.transforms = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
            ])
        

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset.iloc[idx]
        
        image_name = row['image_name']
        captions = row[' comment'] # 캡션 리스트
        caption = random.choice(captions)
        
        image_path = os.path.join(self.image_dir, image_name)

        try:
            image = Image.open(image_path).convert("RGB")
        except FileNotFoundError:
            print(f"Warning: Image file not found at {image_path}. Skipping.")
            return self.__getitem__(random.randint(0, len(self) - 1))

        
        # pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
        pixel_values = self.transforms(image)

        inputs = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        
        text_labels = inputs.input_ids.clone()
        text_labels[text_labels == self.tokenizer.pad_token_id] = -100
        
        query_labels = torch.full((1, self.num_query_tokens), -100)

        combined_labels = torch.cat([query_labels, text_labels], dim=1)

        return {
            "pixel_values": pixel_values.squeeze(),
            "input_ids": inputs.input_ids.squeeze(),   
            "attention_mask": inputs.attention_mask.squeeze(),
            "labels": combined_labels.squeeze()
        }

train_captions = pd.read_csv('/kaggle/input/flickr-image-dataset/flickr30k_images/results.csv', delimiter='|')
train_captions = train_captions.dropna(subset=[' comment', 'image_name'])
train_captions = train_captions.groupby('image_name')[' comment'].agg(list).reset_index()

train_df, eval_df = train_test_split(
    train_captions,      
    test_size=0.2,   
    random_state=42  
)

train_dataset = Stage1Dataset(train_df, image_processor, phi_tokenizer)
valid_dataset = Stage1Dataset(eval_df, image_processor, phi_tokenizer)
train_debug = Subset(train_dataset, indices=range(50))
valid_debug = Subset(valid_dataset, indices=range(50))


In [7]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import get_cosine_schedule_with_warmup
import os
from pathlib import Path
from collections import deque

class CustomTrainer:
    """
    수동으로 학습 및 평가 루프를 제어하기 위한 커스텀 트레이너 클래스입니다.
    (단계별 학습 로직이 제거된 간소화 버전)
    """
    def __init__(self, model: nn.Module, optimizer, tokenizer, train_dataset, val_dataset=None, batch_size=8, save_dir="./checkpoints"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.optimizer = optimizer
        self.tokenizer = tokenizer
        
        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
        self.val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4, pin_memory=True) if val_dataset else None
        
        self.scaler = torch.cuda.amp.GradScaler() # 혼합 정밀도 학습용 스케일러
        self.scheduler = None # train 메서드 내에서 설정

        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        print(f"Using device: {self.device}")

    def _forward_step(self, batch: dict, return_preds: bool = False):
        """Train과 Eval에서 중복되는 모델 실행 및 손실 계산 로직을 통합합니다."""
        inputs = {k: v.to(self.device) for k, v in batch.items()}
        
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            outputs = self.model(**inputs)
            loss = outputs.loss

        if return_preds:
            pred_ids = torch.argmax(outputs.logits, dim=-1)
            decoded_preds = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            
            labels = inputs['labels'].clone()
            labels[labels == -100] = self.tokenizer.pad_token_id
            decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
            return loss, decoded_preds, decoded_labels

        return loss

    def train(self, num_epochs: int, resume_from_checkpoint: str = None):
        """지정된 에포크 수만큼 모델을 학습합니다."""
        total_steps = len(self.train_dataloader) * num_epochs
        warmup_steps = int(0.1 * total_steps)
        
        self.scheduler = get_cosine_schedule_with_warmup(
            optimizer=self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )

        if resume_from_checkpoint:
            start_epoch = self.load_checkpoint(resume_from_checkpoint)

        for epoch in range(start_epoch, num_epochs):
            self.model.train()
            epoch_loss = 0

            progress_bar = tqdm(self.train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")
            for batch in progress_bar:
                self.optimizer.zero_grad()
                
                loss = self._forward_step(batch)
                
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step()
                
                epoch_loss += loss.item()
                progress_bar.set_postfix(loss=loss.item(), lr=self.scheduler.get_last_lr()[0])
            
            avg_train_loss = epoch_loss / len(self.train_dataloader)
            print(f"Epoch {epoch+1} | Average Train Loss: {avg_train_loss:.4f}")

            if self.val_dataloader:
                avg_val_loss = self.evaluate(epoch)
            if epoch == num_epochs - 1:
                self.save_checkpoint(epoch, avg_train_loss, avg_val_loss)

    def evaluate(self, epoch: int):
        """검증 데이터셋으로 모델 성능을 평가합니다."""
        self.model.eval()
        total_loss = 0
        
        last_n_samples = 5
        last_preds = deque(maxlen=last_n_samples)
        last_labels = deque(maxlen=last_n_samples)

        with torch.no_grad():
            progress_bar = tqdm(self.val_dataloader, desc=f"Epoch {epoch+1} - Evaluating")
            for batch in progress_bar:
                loss, decoded_preds, decoded_labels = self._forward_step(batch, return_preds=True)
                total_loss += loss.item()
                
                last_preds.extend(decoded_preds)
                last_labels.extend(decoded_labels)

        avg_val_loss = total_loss / len(self.val_dataloader)
        print(f"\n--- Validation Results for Epoch {epoch+1} ---")
        print(f"Average Validation Loss: {avg_val_loss:.4f}")
        
        print("\n--- Last 5 Sample Predictions ---")
        for pred, label in zip(last_preds, last_labels):
            print(f"🔵 Pred:  {pred.strip()}")
            print(f"🟢 Label: {label.strip()}")
        print("---------------------------------------\n")

        return avg_val_loss

    def save_checkpoint(self, epoch: int, train_loss: float, val_loss: float):
        """모델의 체크포인트를 저장합니다."""
        save_path = self.save_dir / f"epoch_{epoch+1}"
        save_path.mkdir(parents=True, exist_ok=True)
        
        trainable_state_dict = {k: v for k, v in self.model.state_dict().items() if v.requires_grad}
        
        checkpoint = {
            "epoch": epoch + 1,
            "model_state_dict": trainable_state_dict,
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scaler_state_dict": self.scaler.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "train_loss": train_loss,
            "val_loss": val_loss,
        }
        torch.save(checkpoint, save_path / "checkpoint.pt")
        print(f"✅ Checkpoint saved to {save_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """저장된 체크포인트 파일을 불러와 학습 상태를 복원합니다."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        
        # strict=False: 체크포인트에 저장된 파라미터만 불러오고, 없는 파라미터는 무시
        self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        
        start_epoch = checkpoint['epoch']
        print(f"✅ Checkpoint loaded. Resuming from epoch {start_epoch}")
        return start_epoch

In [8]:
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.AdamW(trainable_params, lr=1e-5, weight_decay=0.01)

# 3. CustomTrainer 인스턴스화
# 생성자에 optimizer를 전달합니다.
trainer = CustomTrainer(
    model=model,
    optimizer=optimizer,
    tokenizer=phi_tokenizer,
    train_dataset=train_dataset,
    val_dataset=valid_dataset,
    batch_size=4
)

checkpoint_path = "/kaggle/input/blip2_epoch14/pytorch/default/1/epoch_14/checkpoint.pt"

# 4. 학습 시작
trainer.train(num_epochs=18, resume_from_checkpoint=checkpoint_path) 

  self.scaler = torch.cuda.amp.GradScaler() # 혼합 정밀도 학습용 스케일러


Using device: cuda
✅ Checkpoint loaded. Resuming from epoch 12


Epoch 13/14 - Training:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 13 | Average Train Loss: 4.7302


Epoch 13 - Evaluating:   0%|          | 0/13 [00:00<?, ?it/s]


--- Validation Results for Epoch 13 ---
Average Validation Loss: 3.5312

--- Last 5 Sample Predictions ---
🔵 Pred:  .,, a a
 a with
 with with with girl
AG
 with with with
 with a
 with.


 AA. A sun throws holding a baseball and with and
🟢 Label: The pitcher is wearing a red uniform shirt.
🔵 Pred:  ,,,,, with with with
 with with with
 with

 with with with with with with with with with a a a a a A A in walking or sitting in for line line car
🟢 Label: People either standing or sitting waiting in a subway.
🔵 Pred:  ,,, a

 a
 D D with

D D
 with with with
ANDANDAND standing WITH a..

. A man and through the street. catch for the..
🟢 Label: Young men run on the beach to train for football.
🔵 Pred:  ,,, a,


D with with

. D A with with with
 with.. with. a
 aMAN
 A A man in standingboarding in the lake.
🟢 Label: a man is wakeboarding in a lake
🔵 Pred:  .,,,

 with with
 with with


A's with with with

A�.. a

ED
 A A man man player is throwing a bat bat.
🟢 Label: A male baseball player

Epoch 14/14 - Training:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 14 | Average Train Loss: 3.4292


Epoch 14 - Evaluating:   0%|          | 0/13 [00:00<?, ?it/s]


--- Validation Results for Epoch 14 ---
Average Validation Loss: 4.2926

--- Last 5 Sample Predictions ---
🔵 Pred:  .............. man.. man. man ) ) ) ) ) )
 man man ) ) A man player in the uniform uniform. a ball to a court...
🟢 Label: A baseball player in a red jersey throwing a ball at the pitchers mound.
🔵 Pred:  ....... man man man. ).. man. man man ) ) ) ) ) ) ) ) ) ) ) ) ) A in walking on the street.. for the train to
🟢 Label: People are standing on a train platform waiting for the train.
🔵 Pred:  ....... man ).. ).. man.. ). ) ) ) ) ) ) ) ) ) ) ) ) A man of people children in playing towards. the track.
🟢 Label: A group of young men are running together on a beach.
🔵 Pred:  ...... ) man...... man man. ). ) ) ) ). ) ) ) ) man ) ) A man inhes flying in a sense of
🟢 Label: A man kite surfing creates a wave.
🔵 Pred:  .............. man...... ) ) ) ) ). ). ) ) A man in a uniform background glove standing a man background. standing a ball. the park field.
🟢 Label: A man in a gray b

In [None]:
shutil.make_archive("/kaggle/working/checkpoints", 'zip', "/kaggle/working/checkpoints")