# CLIP 모델에 간단한 fine tuning을 적용해 모델을 사용해보자.
https://devs0n.tistory.com/195 해당 글을 참고했습니다.

In [1]:
%config Completer.use_jedi = False

In [2]:
# 모델 불러오기
import torch
from transformers import CLIPProcessor, CLIPModel
from torch.optim import AdamW

device = "cuda" if torch.cuda.is_available() else "cpu"

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
optimizer = AdamW(model.parameters(), lr=5e-5)

config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

In [3]:
import torch, torchvision

class CIFAR10Dataset(torch.utils.data.Dataset):
    def __init__(self, clip_processor, is_train):
        self.clip_processor = clip_processor
        self.is_train = is_train
        self.dataset = torchvision.datasets.CIFAR10('/kaggle/working',download=True,train=is_train)
        self.class_texts = [
            f"this is {class_}."
            for class_ in self.dataset.class_to_idx.keys()
        ]
        
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        text = self.class_texts[label]
        return {
            "image": image,
            "label": label,
            "text": text,
        }
    
    def preprocess(self, batch):
        images = [data["image"] for data in batch]
        labels = [data["label"] for data in batch]
        texts = [data["text"] for data in batch]

        inputs = self.clip_processor(
            text=texts,
            images=images,
            return_tensors="pt", 
            padding=True
        )
        
        return {
            "text": texts,
            "label": torch.tensor(labels),
            **inputs,
        }

batch_size = 256

train_dataset = CIFAR10Dataset(processor, is_train=True)
test_dataset = CIFAR10Dataset(processor, is_train=False)

train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               collate_fn=train_dataset.preprocess, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                              collate_fn=train_dataset.preprocess, batch_size=batch_size)

print(len(train_dataset))
print(len(test_dataset))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /kaggle/working/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:10<00:00, 15734291.35it/s]


Extracting /kaggle/working/cifar-10-python.tar.gz to /kaggle/working
Files already downloaded and verified
50000
10000


In [4]:
import torch.nn.functional as F

def loss_fn(logits_per_image, logits_per_text):
    assert logits_per_image.shape[0] == logits_per_image.shape[0] # logits' shape should be (nxn)
    assert logits_per_image.shape == logits_per_text.shape
    
    labels = torch.arange(logits_per_image.shape[0], device=device)
    loss_i = F.cross_entropy(logits_per_image, labels)
    loss_t = F.cross_entropy(logits_per_text, labels)
    loss = (loss_i + loss_t) / 2
    
    return loss

In [5]:
from tqdm import tqdm

def train(model, dataloader, optimizer):
    for batch in tqdm(dataloader, position=0, desc="batch", leave=False):
        optimizer.zero_grad()

        outputs = model(
            pixel_values=batch["pixel_values"].to(device),
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        )

        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text # logits_per_text == logits_per_image.T
        loss = loss_fn(logits_per_image, logits_per_text)
        loss.backward()
                
        optimizer.step()

    print(loss)

epochs = 5

for epoch in range(epochs):
    train(model, train_dataloader, optimizer)

                                                        

tensor(2.6383, device='cuda:0', grad_fn=<DivBackward0>)


                                                        

tensor(2.4036, device='cuda:0', grad_fn=<DivBackward0>)


                                                        

tensor(2.2878, device='cuda:0', grad_fn=<DivBackward0>)


                                                        

tensor(2.2076, device='cuda:0', grad_fn=<DivBackward0>)


                                                        

tensor(2.2722, device='cuda:0', grad_fn=<DivBackward0>)


In [7]:
import torch.nn.functional as F

all_class_texts = processor.tokenizer(test_dataset.class_texts)
all_class_texts = {k: torch.tensor(v, device=device) for k, v in all_class_texts.items()}

model.eval()
correct_count = 0
ce_loss_sum = 0

with torch.no_grad():
    for batch in tqdm(test_dataloader):
        outputs = model(
            pixel_values=batch["pixel_values"].to(device),
            **all_class_texts,
        )
        
        probs = outputs.logits_per_image.cpu().softmax(dim=1)
        pred = probs.argmax(dim=1)
        label = batch["label"]

        correct_count += (pred == label).sum().item()
        ce_loss_sum += F.cross_entropy(probs, label).item()
    
accuracy = correct_count / len(test_dataloader.dataset)
ce_loss = ce_loss_sum / len(test_dataloader)
print(f"Test cross entropy loss: {ce_loss:.4}, Test accuracy: {accuracy:.4}")

100%|██████████| 40/40 [00:37<00:00,  1.06it/s]

Test cross entropy loss: 1.53, Test accuracy: 0.94



