In [1]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import requests
from torchvision.datasets import CIFAR100  # CIFAR-100 데이터셋 사용

# 기본 설정
BATCH_SIZE = 32
NUM_WORKERS = 4
LEARNING_RATE = 1e-3
NUM_EPOCHS = 50

# CLIP 모델 로드 함수
def load_clip_model(model_name="ViT-B/16"):
    import clip
    model, preprocess = clip.load(model_name)
    return model, preprocess

# 데이터셋 클래스
class CustomDataset(Dataset):
    def __init__(self, root, transform=None):
        self.dataset = CIFAR100(root=root, train=True, download=True)
        self.transform = transform
        self.classes = self.dataset.classes
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# 프롬프트 생성 함수
def generate_prompts(classnames):
    templates = [
        'a photo of a {}.',
        'an image of a {}.',
        'this is a photo of a {}.',
        'this is an image of a {}.',
    ]
    prompts = []
    for template in templates:
        prompts.extend([template.format(c) for c in classnames])
    return prompts

# 데이터 로더 설정
def get_data_loaders(root='./data'):
    # CLIP 기본 전처리
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), 
                           (0.26862954, 0.26130258, 0.27577711))
    ])
    
    dataset = CustomDataset(root=root, transform=transform)
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, 
                            shuffle=True, num_workers=NUM_WORKERS)
    
    return train_loader, dataset.classes

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm import tqdm

class CoOp(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx=16, ctx_init="a photo of a"):
        super().__init__()
        
        # CLIP 모델의 텍스트 인코더
        self.clip_model = clip_model
        self.dtype = clip_model.dtype
        
        # Context length
        n_cls = len(classnames)
        ctx_dim = clip_model.ln_final.weight.shape[0]
        
        # 프롬프트 초기화
        ctx_vectors = self.initialize_context_vectors(ctx_init, n_ctx, ctx_dim)
        self.ctx = nn.Parameter(ctx_vectors)
        
        # 클래스명 토큰화
        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(clip.tokenize(name)) for name in classnames]
        prompts = [f"{name}." for name in classnames]
        
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(self.dtype)
        
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # CLS, EOS
        
        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
    
    def initialize_context_vectors(self, ctx_init, n_ctx, ctx_dim):
        ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=self.dtype)
        nn.init.normal_(ctx_vectors, std=0.02)
        return ctx_vectors
    
    def forward(self, image):
        # 이미지 인코딩
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        
        # 프롬프트 구성
        ctx = self.ctx
        prefix = self.token_prefix
        suffix = self.token_suffix
        
        # 프롬프트 임베딩 생성
        prompts = torch.cat(
            [
                prefix,  # (n_cls, 1, dim)
                ctx.unsqueeze(0).expand(self.n_cls, -1, -1),  # (n_cls, n_ctx, dim)
                suffix,  # (n_cls, *, dim)
            ],
            dim=1,
        )
        
        # 텍스트 특징 추출
        text_features = self.clip_model.encode_text(prompts)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        
        # 유사도 계산
        logit_scale = self.clip_model.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()
        
        return logits

def train_coop(model, train_loader, num_epochs=50, device='cuda'):
    optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    best_acc = 0
    model = model.to(device)
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            # Forward pass
            logits = model(images)
            loss = F.cross_entropy(logits, labels)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # 통계
            total_loss += loss.item()
            pred = logits.argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
            
            pbar.set_postfix({'loss': total_loss / (pbar.n + 1),
                            'acc': 100 * correct / total})
        
        scheduler.step()
        
        # 현재 정확도
        epoch_acc = 100 * correct / total
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            torch.save(model.state_dict(), 'best_coop.pth')
        
        print(f'Epoch {epoch+1}: Loss = {total_loss/len(train_loader):.4f}, '
              f'Accuracy = {epoch_acc:.2f}%, Best = {best_acc:.2f}%')

# 평가 함수
def evaluate_coop(model, test_loader, device='cuda'):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Evaluating'):
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            pred = logits.argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

In [3]:
import clip

# CLIP 모델 로드
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/16", device=device)

# 데이터 로더 준비
train_loader, test_loader, classnames = get_data_loaders()

# CoOp 모델 초기화
coop_model = CoOp(clip_model, classnames, n_ctx=16)

# 학습 실행
train_coop(coop_model, train_loader, num_epochs=50, device=device)

# 평가
test_accuracy = evaluate_coop(coop_model, test_loader, device=device)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/cifar-100-python.tar.gz to ./data


ValueError: not enough values to unpack (expected 3, got 2)