# CLIP, CoOp, CoCoOp Implementation and Lightweight Approach

[0] 환경 설정

# use dassl as a codebase to develop any deep learning projects

In [None]:
!git clone https://github.com/mlvlab/ProMetaR.git

In [1]:
# CLIP 설치
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /private/var/folders/4p/kqk9nd_51cd1t2bqc3l1fyjr0000gn/T/pip-req-build-ttoy108d
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /private/var/folders/4p/kqk9nd_51cd1t2bqc3l1fyjr0000gn/T/pip-req-build-ttoy108d
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... done

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


[1] 라이브러리 import

In [4]:
from collections import OrderedDict

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms

[2] CLIP 설정

In [5]:
# Instantiate the CLIP tokenizer once at module level.
_tokenizer = _Tokenizer()


def load_clip_to_cpu(model_name="ViT-B/16"):
    """Load a CLIP model onto the CPU and return it in evaluation mode.

    Args:
        model_name: Name passed straight to ``clip.load``.

    Returns:
        The model object returned by ``clip.load`` after ``.eval()`` has been
        called on it; the second value returned by ``clip.load`` is discarded.
    """
    loaded = clip.load(model_name, device="cpu")
    clip_model = loaded[0]
    return clip_model.eval()


# Default device: use CUDA when PyTorch reports it as available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

[3] 기본 TextEncoder

In [6]:
class TextEncoder(nn.Module):
    """Run already-embedded prompts through the text branch of a CLIP model.

    The constructor borrows (does not copy) the transformer, positional
    embedding, final layer norm, text projection and dtype of ``clip_model``.
    """

    _BORROWED = (
        "transformer",
        "positional_embedding",
        "ln_final",
        "text_projection",
        "dtype",
    )

    def __init__(self, clip_model):
        super().__init__()
        # Same attributes, same registration order as assigning them one by one.
        for attr in self._BORROWED:
            setattr(self, attr, getattr(clip_model, attr))

    def forward(self, prompts, tokenized_prompts):
        """Encode prompt embeddings into one projected feature per sequence.

        Args:
            prompts: Prompt embeddings indexed as (sequence-in-batch, position,
                feature); the positional embedding is added to them.
            tokenized_prompts: Token ids for the same sequences. Their argmax
                along the last axis picks the position that is read out
                (presumably the end-of-text token of the CLIP tokenizer -
                TODO confirm against the tokenizer).

        Returns:
            The selected hidden state of every sequence multiplied by
            ``text_projection``.
        """
        positional = self.positional_embedding.type(self.dtype)
        # The transformer is fed with the first two axes swapped, and its
        # output is swapped back afterwards.
        hidden = (prompts + positional).permute(1, 0, 2)
        hidden = self.transformer(hidden).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        rows = torch.arange(hidden.shape[0])
        cols = tokenized_prompts.argmax(dim=-1)
        return hidden[rows, cols] @ self.text_projection

[4] CoOp 구현

In [7]:
class CoOpPromptLearner(nn.Module):
    """Prompt-learning module for CoOp.

    Holds one learnable block of context vectors (``self.ctx``) that is shared
    by every class, plus frozen buffers with the embeddings that surround the
    context in each class prompt (``token_prefix`` / ``token_suffix``).

    Args:
        clip_model: CLIP model providing ``token_embedding``, ``ln_final`` and
            ``dtype``.
        n_ctx: Number of context vectors when ``ctx_init`` is empty. When
            ``ctx_init`` is given it is replaced by the number of
            space-separated words in ``ctx_init``.
        n_cls: Number of classes. Used only when ``classnames`` is None, in
            which case the class "names" are the indices ``"0" .. "n_cls-1"``.
        ctx_init: Text used to initialise the context vectors; falsy for a
            random normal initialisation (std 0.02).
        classnames: Optional iterable of class names. When given, it overrides
            ``n_cls`` with its own length.
    """

    def __init__(self, clip_model, n_ctx=16, n_cls=1000, ctx_init="a photo of a",
                 classnames=None):
        super().__init__()
        ctx_dim = clip_model.ln_final.weight.shape[0]
        dtype = clip_model.dtype

        # The original code called ``.replace`` on the ints yielded by
        # ``range(n_cls)``, which raises AttributeError; stringify first.
        if classnames is None:
            classnames = [str(idx) for idx in range(n_cls)]
        else:
            classnames = [str(name) for name in classnames]
            n_cls = len(classnames)
        classnames = [name.replace("_", " ") for name in classnames]

        # Prompt initialisation
        if ctx_init:
            # Initialise from the embeddings of the given words.
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            # Position 0 is skipped; the next n_ctx positions are the context.
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # Random initialisation; "X" placeholders reserve the n_ctx slots.
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        # Register the context vectors as a trainable parameter.
        self.ctx = nn.Parameter(ctx_vectors)

        # Build one full prompt per class and embed it once.
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # Everything before the context slots (one position) and everything
        # after them stays fixed; only the n_ctx slots in between are learned.
        self.register_buffer("token_prefix", embedding[:, :1, :])
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts

    def forward(self):
        """Assemble the prompt embeddings for every class.

        Returns:
            Tensor made by concatenating, along dim 1, ``token_prefix``, the
            shared context vectors repeated for each of the ``n_cls`` classes,
            and ``token_suffix``.
        """
        ctx = self.ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        return torch.cat([self.token_prefix, ctx, self.token_suffix], dim=1)

In [8]:
class CustomCLIP(nn.Module):
    """Full CoOp model: CLIP image encoder + text encoder fed by learned prompts.

    Args:
        clip_model: CLIP model providing ``visual``, ``logit_scale``, ``dtype``
            and the text-branch modules used by ``TextEncoder``.
        prompt_learner: Module that returns the prompt embeddings when called
            with no arguments and exposes ``tokenized_prompts``.
    """

    def __init__(self, clip_model, prompt_learner):
        super().__init__()
        self.prompt_learner = prompt_learner
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image, label=None):
        """Score images against every class prompt.

        Args:
            image: Batch of images; cast to ``self.dtype`` before encoding.
            label: Optional class indices. When given, the cross-entropy loss
                is returned instead of the logits.

        Returns:
            The cross-entropy loss if ``label`` is not None, otherwise the
            logits (image features @ text features, scaled by
            ``exp(logit_scale)``).
        """
        image_features = self.image_encoder(image.type(self.dtype))

        # The original read ``self.prompts``, an attribute that is never set;
        # the prompts have to be produced by the prompt learner on each call so
        # that gradients reach its context vectors.
        prompts = self.prompt_learner()
        text_features = self.text_encoder(prompts, self.tokenized_prompts)

        # L2-normalise both sides so the product is a cosine similarity.
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        logits = image_features @ text_features.t() * self.logit_scale.exp()

        if label is not None:
            return F.cross_entropy(logits, label)
        return logits

[5] CoCoOp 구현

In [9]:
class CoCoOpPromptLearner(nn.Module):
    """Prompt-learning module for CoCoOp.

    Besides the shared learnable context vectors (``self.ctx``), a small
    two-layer ``meta_net`` maps image features to a bias that is added to every
    context vector, so each image gets its own set of prompts.

    Args:
        clip_model: CLIP model providing ``token_embedding``, ``ln_final``,
            ``visual.output_dim`` and ``dtype``.
        n_ctx: Number of context vectors when ``ctx_init`` is empty. When
            ``ctx_init`` is given it is replaced by the number of
            space-separated words in ``ctx_init``.
        n_cls: Number of classes. Used only when ``classnames`` is None, in
            which case the class "names" are the indices ``"0" .. "n_cls-1"``.
        ctx_init: Text used to initialise the context vectors; falsy for a
            random normal initialisation (std 0.02).
        classnames: Optional iterable of class names. When given, it overrides
            ``n_cls`` with its own length.
    """

    def __init__(self, clip_model, n_ctx=4, n_cls=1000, ctx_init="a photo of a",
                 classnames=None):
        super().__init__()

        # Dimensions taken from the CLIP model.
        ctx_dim = clip_model.ln_final.weight.shape[0]
        vis_dim = clip_model.visual.output_dim
        dtype = clip_model.dtype

        # The original code called ``.replace`` on the ints yielded by
        # ``range(n_cls)``, which raises AttributeError; stringify first.
        if classnames is None:
            classnames = [str(idx) for idx in range(n_cls)]
        else:
            classnames = [str(name) for name in classnames]
            n_cls = len(classnames)
        classnames = [name.replace("_", " ") for name in classnames]

        # Prompt initialisation
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            # Position 0 is skipped; the next n_ctx positions are the context.
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        # Stored only now: ``n_ctx`` may have been overridden by ``ctx_init``
        # above, and the suffix slice below uses the overridden value.
        self.n_cls = n_cls
        self.n_ctx = n_ctx

        # Learnable context vectors
        self.ctx = nn.Parameter(ctx_vectors)

        # Meta network: image feature -> one bias vector of size ctx_dim.
        self.meta_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(vis_dim, vis_dim // 16)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(vis_dim // 16, ctx_dim))
        ]))

        # Build one full prompt per class and embed it once.
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # Frozen embeddings on either side of the n_ctx learnable slots.
        self.register_buffer("token_prefix", embedding[:, :1, :])
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])

        self.tokenized_prompts = tokenized_prompts

    def forward(self, im_features):
        """Build per-image prompts.

        Args:
            im_features: Batch of image features fed to ``meta_net``.

        Returns:
            Stacked prompts: for each image, the prompt embeddings of all
            ``n_cls`` classes built with that image's shifted context.
        """
        prefix = self.token_prefix
        suffix = self.token_suffix

        # One bias vector per image, broadcast over all context vectors.
        bias = self.meta_net(im_features).unsqueeze(1)
        ctx_shifted = self.ctx.unsqueeze(0) + bias

        # For every image, repeat its shifted context for each class.
        prompts = []
        for ctx_shifted_i in ctx_shifted:
            ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
            prompts.append(self.construct_prompts(ctx_i, prefix, suffix))
        return torch.stack(prompts)

    def construct_prompts(self, ctx, prefix, suffix, label=None):
        """Concatenate prefix, context and suffix along dim 1.

        Args:
            ctx: Context vectors, one set per class row.
            prefix: Prefix embeddings.
            suffix: Suffix embeddings.
            label: Optional index used to select rows of ``prefix`` and
                ``suffix`` before concatenation.

        Returns:
            ``torch.cat([prefix, ctx, suffix], dim=1)``.
        """
        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]

        return torch.cat([prefix, ctx, suffix], dim=1)

In [10]:
class CoCoOpCustomCLIP(nn.Module):
    """Full CoCoOp model: image-conditioned prompts scored against each image.

    Args:
        clip_model: CLIP model providing ``visual``, ``logit_scale``, ``dtype``
            and the text-branch modules used by ``TextEncoder``.
        prompt_learner: Module called with the image features that returns one
            set of class prompts per image and exposes ``tokenized_prompts``.
    """

    def __init__(self, clip_model, prompt_learner):
        super().__init__()
        self.prompt_learner = prompt_learner
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image, label=None):
        """Score every image against its own set of class prompts.

        Args:
            image: Batch of images; cast to ``self.dtype`` before encoding.
            label: Optional class indices. When given, the cross-entropy loss
                is returned instead of the logits.

        Returns:
            The cross-entropy loss if ``label`` is not None, otherwise the
            logits scaled by ``exp(logit_scale)``.
        """
        tokenized_prompts = self.prompt_learner.tokenized_prompts
        logit_scale = self.logit_scale.exp()

        image_features = self.image_encoder(image.type(self.dtype))
        # The original skipped normalisation on both sides, so the logits were
        # raw dot products rather than the scaled cosine similarities that
        # CustomCLIP in this file computes. Normalise here as well.
        image_features = F.normalize(image_features, dim=-1)

        prompts = self.prompt_learner(image_features)

        logits = []
        # Each image has its own prompts, so text features are encoded per image.
        for pts_i, img_feat_i in zip(prompts, image_features):
            text_features = self.text_encoder(pts_i, tokenized_prompts)
            text_features = F.normalize(text_features, dim=-1)
            logits.append(logit_scale * (img_feat_i @ text_features.t()))
        logits = torch.stack(logits)

        if label is not None:
            return F.cross_entropy(logits, label)
        return logits

[6] 경량화된 CoCoOp 구현

In [11]:
class LightweightCoCoOpPromptLearner(nn.Module):
    """Lightweight CoCoOp prompt learner working in a reduced context space.

    The learnable context vectors and the meta-network output both live in a
    space of size ``ctx_dim // reduction_factor``; ``ctx_expansion`` maps the
    shifted context back to ``ctx_dim`` before the prompts are assembled.

    Args:
        clip_model: CLIP model providing ``token_embedding``, ``ln_final``,
            ``visual.output_dim`` and ``dtype``.
        n_ctx: Number of context vectors when ``ctx_init`` is empty. When
            ``ctx_init`` is given it is replaced by the number of
            space-separated words in ``ctx_init``.
        n_cls: Number of classes. Used only when ``classnames`` is None, in
            which case the class "names" are the indices ``"0" .. "n_cls-1"``.
        ctx_init: Text used to initialise the context vectors; falsy for a
            random normal initialisation (std 0.02).
        reduction_factor: Integer divisor applied to the context and visual
            dimensions.
        classnames: Optional iterable of class names. When given, it overrides
            ``n_cls`` with its own length.
    """

    def __init__(self, clip_model, n_ctx=4, n_cls=1000, ctx_init="a photo of a",
                 reduction_factor=8, classnames=None):
        super().__init__()

        # Dimensions taken from the CLIP model.
        ctx_dim = clip_model.ln_final.weight.shape[0]
        vis_dim = clip_model.visual.output_dim
        dtype = clip_model.dtype

        # Reduced dimensions
        self.reduced_ctx_dim = ctx_dim // reduction_factor
        self.reduced_vis_dim = vis_dim // reduction_factor

        # Layers for reducing / restoring the context dimension.
        self.ctx_reduction = nn.Linear(ctx_dim, self.reduced_ctx_dim)
        self.ctx_expansion = nn.Linear(self.reduced_ctx_dim, ctx_dim)

        # The original code called ``.replace`` on the ints yielded by
        # ``range(n_cls)``, which raises AttributeError; stringify first.
        if classnames is None:
            classnames = [str(idx) for idx in range(n_cls)]
        else:
            classnames = [str(name) for name in classnames]
            n_cls = len(classnames)
        classnames = [name.replace("_", " ") for name in classnames]

        # Prompt initialisation (in the reduced space)
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            # Keep the number of context slots equal to the number of words in
            # the prompt text; otherwise the suffix slice below would not line
            # up with the tokens that follow the context.
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
                ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
                ctx_vectors = self.ctx_reduction(ctx_vectors)
            prompt_prefix = ctx_init
        else:
            ctx_vectors = torch.empty(n_ctx, self.reduced_ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            # The original concatenated ``ctx_init`` itself here, which fails
            # for None and reserves no slots for an empty string; use "X"
            # placeholders like the other prompt learners in this file.
            prompt_prefix = " ".join(["X"] * n_ctx)

        self.n_cls = n_cls
        self.n_ctx = n_ctx

        # Learnable context vectors in the reduced space.
        self.ctx = nn.Parameter(ctx_vectors)

        # Lightweight meta network: image feature -> bias in the reduced space.
        self.meta_net = nn.Sequential(OrderedDict([
            ("reduction", nn.Linear(vis_dim, self.reduced_vis_dim)),
            ("linear1", nn.Linear(self.reduced_vis_dim, self.reduced_vis_dim // 4)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(self.reduced_vis_dim // 4, self.reduced_ctx_dim))
        ]))

        # Build one full prompt per class and embed it once.
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # Frozen embeddings on either side of the n_ctx learnable slots.
        self.register_buffer("token_prefix", embedding[:, :1, :])
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])
        self.tokenized_prompts = tokenized_prompts

    def forward(self, im_features):
        """Build per-image prompts.

        Args:
            im_features: Batch of image features fed to ``meta_net``.

        Returns:
            Stacked prompts: for each image, the prompt embeddings of all
            ``n_cls`` classes built with that image's shifted context.
        """
        # One reduced-space bias per image, broadcast over the context vectors.
        bias = self.meta_net(im_features).unsqueeze(1)
        ctx_shifted = self.ctx.unsqueeze(0) + bias

        # Restore the full context dimension.
        ctx_shifted = self.ctx_expansion(ctx_shifted)

        prompts = []
        for ctx_shifted_i in ctx_shifted:
            ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
            prompts.append(
                self.construct_prompts(ctx_i, self.token_prefix, self.token_suffix)
            )
        return torch.stack(prompts)

    def construct_prompts(self, ctx, prefix, suffix, label=None):
        """Concatenate prefix, context and suffix along dim 1.

        Args:
            ctx: Context vectors, one set per class row.
            prefix: Prefix embeddings.
            suffix: Suffix embeddings.
            label: Optional index used to select rows of ``prefix`` and
                ``suffix`` before concatenation.

        Returns:
            ``torch.cat([prefix, ctx, suffix], dim=1)``.
        """
        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]

        return torch.cat([prefix, ctx, suffix], dim=1)

[7] 데이터셋 및 데이터 로더

In [12]:
class Config:
    """Experiment settings, stored as plain instance attributes."""

    def __init__(self):
        # Model settings (n_cls = 101 is annotated "Caltech101" in the
        # original; reduction_factor is the lightweight reduction ratio).
        model_settings = {
            "n_ctx": 4,
            "n_cls": 101,
            "reduction_factor": 8,
        }
        # Training settings
        training_settings = {
            "learning_rate": 1e-4,
            "weight_decay": 0.01,
            "max_epoch": 50,
            "batch_size": 32,
            "print_freq": 10,
        }
        # Data settings
        data_settings = {
            "image_size": 224,
            "train_ratio": 0.8,
        }

        for group in (model_settings, training_settings, data_settings):
            for key, value in group.items():
                setattr(self, key, value)

# 데이터셋 정의
class Caltech101Dataset(Dataset):
    """Train or test split of an ``ImageFolder``-style Caltech101 directory.

    The folder is loaded with ``datasets.ImageFolder`` and split with
    ``random_split`` using a generator seeded with ``seed``. Because the seed
    is fixed, constructing the class once with ``train=True`` and once with
    ``train=False`` (same ``root_dir``, ``train_ratio`` and ``seed``) is
    intended to give complementary subsets.

    Args:
        root_dir: Directory passed to ``datasets.ImageFolder``.
        transform: Transform passed to ``datasets.ImageFolder``.
        train: Select the training subset if True, otherwise the test subset.
        train_ratio: Fraction of samples assigned to the training subset
            (the original hard-coded 0.8).
        seed: Seed of the generator used for the split (the original
            hard-coded 42).

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
    """

    def __init__(self, root_dir='./data/caltech101', transform=None, train=True,
                 train_ratio=0.8, seed=42):
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

        self.dataset = datasets.ImageFolder(root_dir, transform=transform)
        self.classes = self.dataset.classes

        # Train/test split sizes; the test split takes the remainder so the
        # two sizes always add up to the dataset length.
        train_size = int(train_ratio * len(self.dataset))
        test_size = len(self.dataset) - train_size
        train_dataset, test_dataset = random_split(
            self.dataset, [train_size, test_size],
            generator=torch.Generator().manual_seed(seed)
        )

        self.data = train_dataset if train else test_dataset

    def __len__(self):
        """Return the number of samples in the selected subset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at position ``idx`` of the subset."""
        image, label = self.data[idx]
        return image, label

# 데이터 전처리
def get_transforms(image_size=224):
    """Build the image preprocessing pipeline.

    Args:
        image_size: Side length, in pixels, of the square the image is resized
            to. Defaults to 224, the value that was previously hard-coded.

    Returns:
        A ``transforms.Compose`` that resizes to ``(image_size, image_size)``,
        converts to a tensor and normalises each channel with the mean/std
        constants below.
    """
    # Per-channel normalisation constants (presumably the statistics CLIP was
    # trained with - TODO confirm against the CLIP preprocessing code).
    mean = [0.48145466, 0.4578275, 0.40821073]
    std = [0.26862954, 0.26130258, 0.27577711]

    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

[8] 설정 클래스

In [13]:
class Config:
    """Experiment settings exposed as class attributes.

    This definition replaces the earlier ``Config`` in the notebook, so it
    carries every field that one had, including ``train_ratio``, which the
    redefinition previously dropped.
    """

    # Model settings
    n_ctx = 4
    n_cls = 101
    reduction_factor = 8

    # Training settings
    learning_rate = 1e-4
    weight_decay = 0.01
    max_epoch = 50
    batch_size = 32
    print_freq = 10

    # Data settings
    image_size = 224
    train_ratio = 0.8

[9] 학습 코드

In [14]:
def train_model(model_name='lightweight'):
    """Placeholder for the training entry point.

    Only a ``Config`` instance is created; the training code itself is elided
    in this notebook and ``model_name`` is not used yet.

    Returns:
        None.
    """
    config = Config()
    # ... (training code implemented earlier)
    return None

[10] 실행 및 테스트

In [15]:
class Trainer:
    """Trainer for models that return a loss when called with labels.

    The wrapped model is called as ``model(images, labels)`` during training
    (expected to return a scalar loss) and as ``model(images)`` during
    evaluation (expected to return per-class scores).

    Args:
        model: The ``nn.Module`` to train; moved to CUDA when available.
        config: Object exposing ``learning_rate``, ``weight_decay`` and
            ``print_freq`` (and ``max_epoch`` if ``train`` is called without
            an explicit epoch count).
    """

    def __init__(self, model, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Metric history
        self.best_acc = 0
        self.train_losses = []
        self.val_accuracies = []

    def train_epoch(self, train_loader):
        """Run one pass over ``train_loader`` and return the mean batch loss.

        Returns 0.0 when the loader yields no batches (the original divided by
        ``len(train_loader)`` and raised ZeroDivisionError in that case).
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            loss = self.model(images, labels)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if (batch_idx + 1) % self.config.print_freq == 0:
                print(f"Batch [{batch_idx+1}/{len(train_loader)}] Loss: {loss.item():.4f}")

        if num_batches == 0:
            return 0.0
        return total_loss / num_batches

    def evaluate(self, data_loader):
        """Return the top-1 accuracy (in percent) of the model on ``data_loader``.

        Returns 0.0 when the loader yields no samples.
        """
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in data_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                predicted = self.model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        if total == 0:
            return 0.0
        return 100 * correct / total

    def train(self, train_loader, val_loader=None, max_epoch=None):
        """Train for ``max_epoch`` epochs and return the final training accuracy.

        After every epoch the mean loss is appended to ``train_losses``; when
        ``val_loader`` is given, its accuracy is appended to
        ``val_accuracies`` and ``best_acc`` is updated.

        Args:
            train_loader: Iterable of ``(images, labels)`` training batches.
            val_loader: Optional iterable of validation batches.
            max_epoch: Number of epochs; defaults to ``config.max_epoch``.

        Returns:
            Accuracy in percent on ``train_loader``, measured after the last
            epoch with the model in eval mode.
        """
        if max_epoch is None:
            max_epoch = self.config.max_epoch

        for epoch in range(max_epoch):
            epoch_loss = self.train_epoch(train_loader)
            self.train_losses.append(epoch_loss)
            message = f"Epoch [{epoch + 1}/{max_epoch}] Loss: {epoch_loss:.4f}"

            if val_loader is not None:
                val_acc = self.evaluate(val_loader)
                self.val_accuracies.append(val_acc)
                self.best_acc = max(self.best_acc, val_acc)
                message += f" Val Acc: {val_acc:.2f}%"

            print(message)

        return self.evaluate(train_loader)

[11] 성능 비교 시각화

In [17]:
def plot_results():
    """Open a 10x5 matplotlib figure; the plotting code itself is elided here."""
    figure_size = (10, 5)
    plt.figure(figsize=figure_size)
    # ... (result-visualisation code)
    
    
class Trainer:
    """Trainer for models that return a loss when called with labels.

    The wrapped model is called as ``model(images, labels)`` during training
    (expected to return a scalar loss) and as ``model(images)`` during
    evaluation (expected to return per-class scores).

    Args:
        model: The ``nn.Module`` to train; moved to CUDA when available.
        config: Object exposing ``learning_rate``, ``weight_decay`` and
            ``print_freq`` (and ``max_epoch`` if ``train`` is called without
            an explicit epoch count).
    """

    def __init__(self, model, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Metric history
        self.best_acc = 0
        self.train_losses = []
        self.val_accuracies = []

    def train_epoch(self, train_loader):
        """Run one pass over ``train_loader`` and return the mean batch loss.

        Returns 0.0 when the loader yields no batches (the original divided by
        ``len(train_loader)`` and raised ZeroDivisionError in that case).
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            loss = self.model(images, labels)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if (batch_idx + 1) % self.config.print_freq == 0:
                print(f"Batch [{batch_idx+1}/{len(train_loader)}] Loss: {loss.item():.4f}")

        if num_batches == 0:
            return 0.0
        return total_loss / num_batches

    def evaluate(self, data_loader):
        """Return the top-1 accuracy (in percent) of the model on ``data_loader``.

        Returns 0.0 when the loader yields no samples.
        """
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in data_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                predicted = self.model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        if total == 0:
            return 0.0
        return 100 * correct / total

    def train(self, train_loader, val_loader=None, max_epoch=None):
        """Train for ``max_epoch`` epochs and return the final training accuracy.

        After every epoch the mean loss is appended to ``train_losses``; when
        ``val_loader`` is given, its accuracy is appended to
        ``val_accuracies`` and ``best_acc`` is updated.

        Args:
            train_loader: Iterable of ``(images, labels)`` training batches.
            val_loader: Optional iterable of validation batches.
            max_epoch: Number of epochs; defaults to ``config.max_epoch``.

        Returns:
            Accuracy in percent on ``train_loader``, measured after the last
            epoch with the model in eval mode.
        """
        if max_epoch is None:
            max_epoch = self.config.max_epoch

        for epoch in range(max_epoch):
            epoch_loss = self.train_epoch(train_loader)
            self.train_losses.append(epoch_loss)
            message = f"Epoch [{epoch + 1}/{max_epoch}] Loss: {epoch_loss:.4f}"

            if val_loader is not None:
                val_acc = self.evaluate(val_loader)
                self.val_accuracies.append(val_acc)
                self.best_acc = max(self.best_acc, val_acc)
                message += f" Val Acc: {val_acc:.2f}%"

            print(message)

        return self.evaluate(train_loader)

[12] 모델 테스트 및 전체 실험 실행

In [None]:
def test_model(model, test_loader, device):
    """개별 모델 테스트"""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    return accuracy

def run_experiment():
    """Train and test CoOp, CoCoOp and the lightweight CoCoOp variant.

    For each method a fresh CLIP model is loaded, wrapped with the matching
    prompt learner, trained with ``Trainer`` and evaluated with
    ``test_model``. Every parameter outside the prompt learner is frozen so
    that only the prompts are optimised.

    Returns:
        Dict mapping ``'CoOp'``, ``'CoCoOp'`` and ``'Lightweight'`` to
        ``{'train': <value returned by Trainer.train>, 'test': <test accuracy>}``.
    """
    # Function-scope import so this function works on its own.
    from torch.utils.data import random_split

    config = Config()
    transform = get_transforms()

    # Datasets: the dataset class only offers train/test, so a validation
    # subset (10%) is carved out of the training split with a fixed seed.
    full_train_dataset = Caltech101Dataset(train=True, transform=transform)
    test_dataset = Caltech101Dataset(train=False, transform=transform)
    # Use the number of class folders actually found on disk, so labels can
    # never exceed the number of prompts.
    n_cls = len(full_train_dataset.classes)

    val_size = len(full_train_dataset) // 10
    train_dataset, val_dataset = random_split(
        full_train_dataset,
        [len(full_train_dataset) - val_size, val_size],
        generator=torch.Generator().manual_seed(0),
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=2
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=2
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=2
    )

    def build_coop(clip_model):
        learner = CoOpPromptLearner(clip_model, n_ctx=config.n_ctx, n_cls=n_cls)
        return CustomCLIP(clip_model, learner)

    def build_cocoop(clip_model):
        learner = CoCoOpPromptLearner(clip_model, n_ctx=config.n_ctx, n_cls=n_cls)
        return CoCoOpCustomCLIP(clip_model, learner)

    def build_lightweight(clip_model):
        learner = LightweightCoCoOpPromptLearner(
            clip_model,
            n_ctx=config.n_ctx,
            n_cls=n_cls,
            reduction_factor=config.reduction_factor
        )
        # The lightweight learner has the same call interface as the CoCoOp
        # learner, so it is wrapped by the same model class.
        return CoCoOpCustomCLIP(clip_model, learner)

    # (results key, display name, model builder)
    methods = [
        ('CoOp', "CoOp", build_coop),
        ('CoCoOp', "CoCoOp", build_cocoop),
        ('Lightweight', "Lightweight CoCoOp", build_lightweight),
    ]

    results = {}

    for key, display_name, build in methods:
        print(f"\nTraining and Testing {display_name}...")

        # A fresh CLIP per method keeps the runs independent of each other.
        model = build(load_clip_to_cpu())

        # Freeze everything except the prompt learner.
        for param_name, param in model.named_parameters():
            param.requires_grad_("prompt_learner" in param_name)

        trainer = Trainer(model, config)
        train_acc = trainer.train(train_loader, val_loader, config.max_epoch)
        # Trainer moves the model to its own device; evaluate on the same one.
        test_acc = test_model(model, test_loader, trainer.device)

        results[key] = {'train': train_acc, 'test': test_acc}
        print(f"{display_name} Test Accuracy: {test_acc:.2f}%")

    # Final summary
    print("\nFinal Results:")
    for model_name, accs in results.items():
        print(f"{model_name}:")
        print(f"  Training Accuracy: {accs['train']:.2f}%")
        print(f"  Test Accuracy: {accs['test']:.2f}%")

    return results