In [1]:
# "./ckpt/20250313_172710/best_model.pth"

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import datetime
import logging
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from peft import LoraConfig, get_peft_model
from clip_dataset import ImageTextDataset, collate_fn  # 사용자 정의 데이터셋 모듈
from loss import CLIPContrastiveLoss              # 사용자 정의 손실함수

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

class OWLVITCLIPModel:
    """
    Wrapper around an OwlViT object-detection model with optional LoRA adapters.

    The wrapper loads the processor and model, freezes every parameter, and
    then either attaches LoRA adapters to the projection layers or unfreezes
    those projection layers directly. `train()` additionally restricts the
    trainable set to parameters whose name contains ``box_head`` or
    ``class_head`` before running the optimisation loop.
    """
    def __init__(self, model_name="google/owlvit-base-patch32", device='cuda', use_lora=True, lora_config_params=None):
        """
        Load the processor/model and configure which parameters are trainable.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                processor and the model.
            device: Requested device string; used only when CUDA is available,
                otherwise the wrapper falls back to ``"cpu"``.
            use_lora: When True, wrap the model with LoRA adapters targeting
                ``text_projection`` and ``visual_projection``. When False,
                unfreeze those two layers and ``logit_scale`` instead.
            lora_config_params: Optional dict with keys ``"r"``,
                ``"lora_alpha"`` and ``"lora_dropout"``. Defaults to
                ``{"r": 4, "lora_alpha": 32, "lora_dropout": 0.1}``.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        # Load the processor and the base model.
        self.processor = OwlViTProcessor.from_pretrained(model_name)
        self.model = OwlViTForObjectDetection.from_pretrained(model_name).to(self.device)
        self.model.train()

        # Freeze everything first; the branches below re-enable a subset.
        for param in self.model.parameters():
            param.requires_grad = False

        if use_lora:
            # Default LoRA hyper-parameters (tune as needed).
            if lora_config_params is None:
                lora_config_params = {"r": 4, "lora_alpha": 32, "lora_dropout": 0.1}
            # task_type is intentionally left unset: "OTHER" is not a member of
            # peft.TaskType, and PEFT releases that validate task_type reject
            # unknown values. Without a task type, get_peft_model returns the
            # generic PeftModel wrapper.
            lora_config = LoraConfig(
                r=lora_config_params["r"],
                lora_alpha=lora_config_params["lora_alpha"],
                lora_dropout=lora_config_params["lora_dropout"],
                target_modules=["text_projection", "visual_projection"]
            )
            # Attach the LoRA adapters through the PEFT library.
            self.model = get_peft_model(self.model, lora_config)
        else:
            # Without LoRA, unfreeze text_projection / visual_projection directly.
            trainable_layers = [
                self.model.owlvit.text_projection,
                self.model.owlvit.visual_projection
            ]
            for layer in trainable_layers:
                for param in layer.parameters():
                    param.requires_grad = True
            self.model.owlvit.logit_scale.requires_grad = True

        trainable_params = self._count_trainable_params()
        logging.info(f"🚀 초기 trainable 파라미터: {trainable_params / 1e6:.2f}M")

    def _count_trainable_params(self):
        """Return the number of model parameter elements with requires_grad=True."""
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    def load_checkpoint(self, checkpoint_path):
        """
        Load the model ``state_dict`` stored under ``"model_state_dict"``.

        Args:
            checkpoint_path: Path to a checkpoint file written by `train()`.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources (or pass weights_only=True where supported).
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        logging.info(f"Checkpoint loaded from {checkpoint_path}")

    def freeze_except_heads(self):
        """
        Make only parameters whose name contains ``box_head`` or ``class_head``
        trainable; every other parameter is frozen.
        """
        for name, param in self.model.named_parameters():
            param.requires_grad = "box_head" in name or "class_head" in name
        trainable_params = self._count_trainable_params()
        logging.info(f"🚀 Head만 학습 가능하도록 설정됨. Trainable 파라미터: {trainable_params / 1e6:.2f}M")

    def reinitialize_heads(self):
        """
        Re-initialise every sub-module under a name containing ``box_head`` or
        ``class_head`` that exposes ``reset_parameters()``.
        """
        def _reinit_module(module, module_name):
            # Container modules have no reset_parameters and are skipped.
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()
                logging.info(f"{module_name} 재초기화됨.")
        for name, module in self.model.named_modules():
            if "box_head" in name or "class_head" in name:
                _reinit_module(module, name)

    def get_optimizer(self, lr=1e-4):
        """Return an AdamW optimizer over the currently trainable parameters."""
        return optim.AdamW(filter(lambda p: p.requires_grad, self.model.parameters()), lr=lr)

    def get_dataloaders(self, train_dir, val_dir, batch_size=16):
        """
        Build the training and validation data loaders.

        Augmentations are applied to the training dataset only; the training
        loader shuffles, the validation loader does not.

        Returns:
            Tuple ``(train_loader, val_loader)``.
        """
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomRotation(degrees=15),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        ])

        train_dataset = ImageTextDataset(train_dir, self.processor, transform=transform)
        val_dataset = ImageTextDataset(val_dir, self.processor)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
        return train_loader, val_loader

    def _contrastive_loss(self, batch, criterion):
        """
        Run one forward pass on ``batch`` and return ``criterion``'s loss.

        Shared by `train()` and `validate()` so both compute the loss the same
        way. ``batch`` must provide ``"pixel_values"``, ``"input_ids"`` and
        ``"attention_mask"`` tensors.
        """
        pixel_values = batch["pixel_values"].to(self.device)
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)

        outputs = self.model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Average image_embeds over dims 1 and 2 (presumably the patch grid —
        # TODO confirm) and drop dim 1 of text_embeds when it has size 1
        # (presumably one text query per image — TODO confirm).
        vision_embeds = outputs.image_embeds.mean(dim=(1, 2))
        text_embeds = outputs.text_embeds.squeeze(1)
        vision_embeds = self.model.owlvit.visual_projection(vision_embeds)
        text_embeds = self.model.owlvit.text_projection(text_embeds)

        return criterion(vision_embeds, text_embeds)

    def train(self, train_dir, val_dir, epochs=10, batch_size=16, lr=1e-4, ckpt_base_dir="ckpt"):
        """
        Training and validation loop.

        Calls `freeze_except_heads()` first, then for each epoch trains on the
        training loader, evaluates on the validation loader, and writes
        ``epoch_<n>.pth`` (plus ``best_model.pth`` whenever the validation loss
        improves) into ``<ckpt_base_dir>/<timestamp>/``.

        Raises:
            ValueError: If the training loader yields no batches.
            RuntimeError: If the computed loss is not connected to any
                trainable parameter.
        """
        # Restrict training to the detection heads.
        self.freeze_except_heads()

        train_loader, val_loader = self.get_dataloaders(train_dir, val_dir, batch_size)
        num_train_batches = len(train_loader)
        if num_train_batches == 0:
            # Fail early with a clear message instead of a ZeroDivisionError
            # when averaging the loss at the end of the first epoch.
            raise ValueError(f"No training batches found in {train_dir!r}")

        optimizer = self.get_optimizer(lr)
        criterion = CLIPContrastiveLoss().to(self.device)

        # Create a timestamped checkpoint directory.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        ckpt_dir = os.path.join(ckpt_base_dir, timestamp)
        os.makedirs(ckpt_dir, exist_ok=True)

        best_val_loss = float("inf")

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0
            for batch in train_loader:
                optimizer.zero_grad()

                # NOTE(review): the loss below is built from image_embeds /
                # text_embeds and the projection layers, which does not appear
                # to involve box_head or class_head. Training the heads needs a
                # loss computed from their outputs (and suitable targets).
                loss = self._contrastive_loss(batch, criterion)
                if not loss.requires_grad:
                    # backward() would raise a RuntimeError here anyway; raise
                    # the same type with an actionable message.
                    raise RuntimeError(
                        "Loss is not connected to any trainable parameter: only "
                        "box_head/class_head are unfrozen, but the contrastive loss "
                        "does not depend on them."
                    )
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            avg_train_loss = total_loss / num_train_batches
            avg_val_loss = self.validate(val_loader, criterion)
            logging.info(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

            # Save a checkpoint for this epoch.
            checkpoint = {
                "epoch": epoch + 1,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss
            }
            ckpt_path = os.path.join(ckpt_dir, f"epoch_{epoch+1}.pth")
            torch.save(checkpoint, ckpt_path)
            logging.info(f"Checkpoint saved: {ckpt_path}")

            # A NaN validation loss (empty validation set) never compares as
            # smaller, so it never replaces the best checkpoint.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_ckpt_path = os.path.join(ckpt_dir, "best_model.pth")
                torch.save(checkpoint, best_ckpt_path)
                logging.info(f"Best model updated: {best_ckpt_path}")

    def validate(self, val_loader, criterion):
        """
        Evaluate the model on ``val_loader`` without gradient tracking.

        The model is switched to eval mode and left in that mode on return.

        Returns:
            Mean loss over the batches, or NaN when ``val_loader`` is empty.
        """
        self.model.eval()
        num_batches = len(val_loader)
        if num_batches == 0:
            # Avoid ZeroDivisionError; NaN makes the missing data visible in
            # the logs without aborting training.
            logging.warning("Validation loader is empty; returning NaN as validation loss.")
            return float("nan")

        total_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                total_loss += self._contrastive_loss(batch, criterion).item()
        return total_loss / num_batches

if __name__ == "__main__":
    # Dataset locations (adjust to match the project layout).
    train_dataset_dir = "./total_dataset/train_dataset/"
    val_dataset_dir = "./total_dataset/val/"

    # Build the wrapper with LoRA adapters attached.
    model_wrapper = OWLVITCLIPModel(use_lora=True)

    # Restore weights from an earlier run before continuing training.
    checkpoint_path = "./ckpt/20250313_172710/best_model.pth"
    model_wrapper.load_checkpoint(checkpoint_path)
    # Optional: re-initialise the heads before training.
    # model_wrapper.reinitialize_heads()

    # train() restricts the trainable parameters to the heads, then runs the loop.
    model_wrapper.train(
        train_dir=train_dataset_dir, val_dir=val_dataset_dir,
        epochs=10, batch_size=16, lr=1e-4,
        ckpt_base_dir="ckpt",
    )


In [9]:
# Instantiate the wrapper with its default arguments (use_lora defaults to True).
model = OWLVITCLIPModel()

2025-03-13 18:37:16,478 - INFO - 🚀 Trainable Parameters: 0.01M


In [15]:
# Display the module tree of the wrapped model held by the wrapper.
model.model

PeftModel(
  (base_model): LoraModel(
    (model): OwlViTForObjectDetection(
      (owlvit): OwlViTModel(
        (text_model): OwlViTTextTransformer(
          (embeddings): OwlViTTextEmbeddings(
            (token_embedding): Embedding(49408, 512)
            (position_embedding): Embedding(16, 512)
          )
          (encoder): OwlViTEncoder(
            (layers): ModuleList(
              (0-11): 12 x OwlViTEncoderLayer(
                (self_attn): OwlViTAttention(
                  (k_proj): Linear(in_features=512, out_features=512, bias=True)
                  (v_proj): Linear(in_features=512, out_features=512, bias=True)
                  (q_proj): Linear(in_features=512, out_features=512, bias=True)
                  (out_proj): Linear(in_features=512, out_features=512, bias=True)
                )
                (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
                (mlp): OwlViTMLP(
                  (activation_fn): QuickGELUActivation()
  