<a href="https://colab.research.google.com/github/hmin27/2023_DL_Clip/blob/main/CLIP(Finetune)_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CLIP Fine tuning
- Food image classification
- Baseline of Fine Tuned CLIP model


In [None]:
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

Collecting ftfy
  Downloading ftfy-6.1.3-py3-none-any.whl (53 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/53.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.4/53.4 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Collecting wcwidth<0.3.0,>=0.2.12 (from ftfy)
  Downloading wcwidth-0.2.12-py2.py3-none-any.whl (34 kB)
Installing collected packages: wcwidth, ftfy
  Attempting uninstall: wcwidth
    Found existing installation: wcwidth 0.2.10
    Uninstalling wcwidth-0.2.10:
      Successfully uninstalled wcwidth-0.2.10
Successfully installed ftfy-6.1.3 wcwidth-0.2.12


Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-c05wvvu9
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-c05wvvu9
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369500 sha256=12199eece4fc64b78dd5f6ebf0d68071aaf664c0ba4e98b51db4b53093c1260a
  Stored in directory: /tmp/pip-ephem-wheel-cache-t1gn9wgn/wheels/da/2b/4c/d6691fa9597aac8bb85d2ac13b112deb897d5b50f5ad9a37e4
Successfully built clip
Installing collected packages: clip
Successfully installed clip-1.0


In [None]:
import os
import clip
import torch
from torch import nn, optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm
import shutil

%matplotlib inline
BATCH_SIZE = 16
EPOCH = 10
LR = 1e-7

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


# Prepare the Model and Data

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
model = model.to(torch.float32)

In [None]:
# Creating image path, text list
import pandas as pd

data_folder = '/content/drive/MyDrive/Study/DL_CLIP/Food_fewshot/4_shot/train'

image_paths = []
text_descriptions = []
class_folders = os.listdir(data_folder)
class_names = []

for class_folder in class_folders:
    class_folder_path = os.path.join(data_folder, class_folder)
    image_files = os.listdir(class_folder_path)

    class_names.append(class_folder)
    for image_file in image_files:
        image_path = os.path.join(class_folder_path, image_file)
        image_paths.append(image_path)

        # Create text description using class label
        text_description = f"a photo of {class_folder.replace('_',' ')}"
        text_descriptions.append(text_description)



print(text_descriptions)
# len(text_descriptions)
print(class_names)


['a photo of apple pie', 'a photo of apple pie', 'a photo of apple pie', 'a photo of apple pie', 'a photo of burger', 'a photo of burger', 'a photo of burger', 'a photo of burger', 'a photo of butter naan', 'a photo of butter naan', 'a photo of butter naan', 'a photo of butter naan', 'a photo of chai', 'a photo of chai', 'a photo of chai', 'a photo of chai', 'a photo of chapati', 'a photo of chapati', 'a photo of chapati', 'a photo of chapati', 'a photo of cheesecake', 'a photo of cheesecake', 'a photo of cheesecake', 'a photo of cheesecake', 'a photo of chicken curry', 'a photo of chicken curry', 'a photo of chicken curry', 'a photo of chicken curry', 'a photo of chole bhature', 'a photo of chole bhature', 'a photo of chole bhature', 'a photo of chole bhature', 'a photo of dal makhani', 'a photo of dal makhani', 'a photo of dal makhani', 'a photo of dal makhani', 'a photo of dhokla', 'a photo of dhokla', 'a photo of dhokla', 'a photo of dhokla', 'a photo of fried rice', 'a photo of fr

In [None]:
# Few-shot learning

class FewshotDataset(Dataset):
    def __init__(self, data_folder, preprocess):
        self.data_folder = data_folder
        self.preprocess = preprocess
        self.image_paths = []
        self.text_descriptions = []
        self.labels = []

        # image path list, text list
        class_folders = os.listdir(data_folder)

        for label, class_folder in enumerate(class_folders):
            class_folder_path = os.path.join(data_folder, class_folder)
            image_files = os.listdir(class_folder_path)

            for image_file in image_files:
                image_path = os.path.join(class_folder_path, image_file)
                self.image_paths.append(image_path)

                text_description = f"a photo of {class_folder.replace('_',' ')}"
                self.text_descriptions.append(text_description)

                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx])
        image = self.preprocess(image)
        text = self.text_descriptions[idx]
        label = self.labels[idx]
        return image, text, label


data_folder = '/content/drive/MyDrive/Study/DL_CLIP/Food_fewshot/4_shot'
train_dataset = FewshotDataset(os.path.join(data_folder, 'train'), preprocess)
val_dataset = FewshotDataset(os.path.join(data_folder, 'validation'), preprocess)
test_dataset = FewshotDataset(os.path.join(data_folder, 'test'), preprocess)

trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# len(train_dataset)  # 96
# len(val_dataset)  # 28
# len(test_dataset)  # 12

In [None]:
from numpy.lib import shape_base
for batch in valloader:
    images, texts, labels = batch
    # Print the first batch
    print("Image Path:", images[0].shape)
    print("Text Description:", texts[0])
    print("Label: ", labels)
    break

Image Path: torch.Size([3, 224, 224])
Text Description: a photo of apple pie
Label:  tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3])


# Training

In [None]:
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        p.grad.data = p.grad.data.float()

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, len(trainloader)*EPOCH)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


In [None]:
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, top_k_accuracy_score
import numpy as np

for epoch in range(EPOCH):
    print(f"Epoch: {epoch+1}")

    # Training loop
    model.train()
    train_total, train_correct, train_top1_correct, train_top3_correct = 0, 0, 0, 0
    pbar = tqdm(trainloader, total=len(trainloader))
    for batch in pbar:
        optimizer.zero_grad()

        images, texts, _ = batch
        texts = clip.tokenize(texts).to(device)
        images = images.to(device)

        logits_per_image, logits_per_text = model(images, texts)

        # Compute loss
        actual_batch_size = images.size(0)
        ground_truth = torch.arange(actual_batch_size).to(device)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

        # Compute train accuracy
        train_correct += (logits_per_image.argmax(dim=1) == ground_truth).float().sum().item()

        # Compute top-1 and top-5 accuracy
        train_top1_correct += accuracy_score(ground_truth.cpu().detach().numpy(), logits_per_image.argmax(dim=1).cpu().detach().numpy()) * actual_batch_size
        train_top3_correct += top_k_accuracy_score(ground_truth.cpu().detach().numpy(),logits_per_image.cpu().detach().numpy(), k=3) * actual_batch_size

        train_total += actual_batch_size

        total_loss.backward()

        if device == "cpu":
            optimizer.step()
        else:
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        train_accuracy = 100 * train_correct / train_total
        train_top1_accuracy = 100 * train_top1_correct / train_total
        train_top3_accuracy = 100 * train_top3_correct / train_total
        pbar.set_description(f"Epoch {epoch+1}/{EPOCH}, Loss: {total_loss.item():.4f}, Train Top-1 Acc: {train_top1_accuracy:.2f}%, Train Top-3 Acc: {train_top3_accuracy:.2f}%")

    # Validation loop
    model.eval()
    val_total, val_correct_top1, val_correct_top3 = 0, 0, 0

    with torch.no_grad():
        for batch in valloader:
            images, texts, labels = batch
            texts = clip.tokenize(texts).to(device)
            images = images.to(device)
            labels = labels.to(device)

            logits_per_image, _ = model(images, texts)
            print("Logits per Image:", logits_per_image)

            print("Predicted Labels:", logits_per_image.argmax(dim=1))
            print("Actual Labels:", labels)

            # Compute top-1 and top-3 accuracy
            val_correct_top1 += (logits_per_image.argmax(dim=1) == labels).float().sum().item()
            _, top3_predictions = logits_per_image.topk(3, dim=1)
            val_correct_top3 += torch.any(top3_predictions == labels.view(-1, 1), dim=1).float().sum().item()

            val_total += images.size(0)

    # Compute and print validation accuracy
    val_accuracy_top1 = 100 * val_correct_top1 / val_total
    val_accuracy_top3 = 100 * val_correct_top3 / val_total

    print(f"Validation Top-1 Acc: {val_accuracy_top1:.2f}%, Validation Top-3 Acc: {val_accuracy_top3:.2f}%")



Epoch: 1


  0%|          | 0/9 [00:00<?, ?it/s]

Logits per Image: tensor([[32.2812, 32.2812, 32.2812, 32.2812, 22.1250, 22.1250, 22.1250, 22.1250,
         22.3125, 22.3125, 22.3125, 22.3125, 21.9219, 21.9219, 21.9219, 21.9219],
        [31.7188, 31.7188, 31.7188, 31.7188, 21.7188, 21.7188, 21.7188, 21.7188,
         21.1406, 21.1406, 21.1406, 21.1406, 19.4844, 19.4844, 19.4844, 19.4844],
        [32.1562, 32.1562, 32.1562, 32.1562, 20.0000, 20.0000, 20.0000, 20.0000,
         21.8281, 21.8281, 21.8281, 21.8281, 20.6719, 20.6719, 20.6719, 20.6719],
        [29.2812, 29.2812, 29.2812, 29.2812, 17.7344, 17.7344, 17.7344, 17.7344,
         18.2812, 18.2812, 18.2812, 18.2812, 18.0938, 18.0938, 18.0938, 18.0938],
        [20.8750, 20.8750, 20.8750, 20.8750, 28.7500, 28.7500, 28.7500, 28.7500,
         17.5625, 17.5625, 17.5625, 17.5625, 19.0625, 19.0625, 19.0625, 19.0625],
        [18.5625, 18.5625, 18.5625, 18.5625, 26.7344, 26.7344, 26.7344, 26.7344,
         16.9219, 16.9219, 16.9219, 16.9219, 18.5938, 18.5938, 18.5938, 18.5938],
    

  0%|          | 0/9 [00:00<?, ?it/s]

Logits per Image: tensor([[32.2812, 32.2812, 32.2812, 32.2812, 22.1094, 22.1094, 22.1094, 22.1094,
         22.3281, 22.3281, 22.3281, 22.3281, 21.9219, 21.9219, 21.9219, 21.9219],
        [31.7031, 31.7031, 31.7031, 31.7031, 21.7031, 21.7031, 21.7031, 21.7031,
         21.1406, 21.1406, 21.1406, 21.1406, 19.4844, 19.4844, 19.4844, 19.4844],
        [32.1562, 32.1562, 32.1562, 32.1562, 19.9844, 19.9844, 19.9844, 19.9844,
         21.8281, 21.8281, 21.8281, 21.8281, 20.6719, 20.6719, 20.6719, 20.6719],
        [29.2812, 29.2812, 29.2812, 29.2812, 17.7344, 17.7344, 17.7344, 17.7344,
         18.2812, 18.2812, 18.2812, 18.2812, 18.0938, 18.0938, 18.0938, 18.0938],
        [20.8750, 20.8750, 20.8750, 20.8750, 28.7344, 28.7344, 28.7344, 28.7344,
         17.5781, 17.5781, 17.5781, 17.5781, 19.0781, 19.0781, 19.0781, 19.0781],
        [18.5625, 18.5625, 18.5625, 18.5625, 26.7031, 26.7031, 26.7031, 26.7031,
         16.9219, 16.9219, 16.9219, 16.9219, 18.5938, 18.5938, 18.5938, 18.5938],
    

  0%|          | 0/9 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [None]:
torch.save(model.state_dict(), 'CLIP_4shot_v4.pth')
shutil.copy('CLIP_4shot_v4.pth', '/content/drive/MyDrive/Study/DL_CLIP/model/CLIP_4shot_v4.pth')

'/content/drive/MyDrive/Study/DL_CLIP/model/CLIP_4shot_v4.pth'

In [None]:
loaded_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
loaded_model = loaded_model.to(torch.float32)
loaded_model.load_state_dict(torch.load('/content/drive/MyDrive/Study/DL_CLIP/model/CLIP_4shot_v3.pth'))
loaded_model.to(device)

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [None]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.2.0-py3-none-any.whl (805 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.10.0 torchmetrics-1.2.0


In [None]:
from sklearn.metrics import precision_recall_fscore_support

# 모델을 평가하기 전에 모델을 evaluation 모드로 설정
model.eval()

# 테스트 루프
test_correct = 0
test_total = 0
true_labels = []
predicted_labels = []

with torch.no_grad():
    for batch in testloader:
        images, texts, labels_tuple = batch
        texts = texts.squeeze(1)
        images = images.to(device)
        texts = texts.to(device)
        labels = labels_tuple[1].to(device)  # 레이블을 적절히 추출

        logits_per_image, _ = model(images, texts)

        # 예측값과 실제 레이블 저장
        _, predicted = logits_per_image.max(1)
        true_labels.extend(labels.cpu().numpy())
        predicted_labels.extend(predicted.cpu().numpy())

        # 정확도 계산
        test_correct += (predicted == labels).sum().item()
        test_total += labels.size(0)

# 전체 테스트 세트에 대한 정확도 출력
test_accuracy = 100 * test_correct / test_total
print(f"Test Accuracy: {test_accuracy:.2f}%")

# Precision, Recall, F1-score 계산
precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predicted_labels, average='weighted')

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-score: {f1:.4f}")


Test Accuracy: 3.03%
Precision: 0.0061
Recall: 0.0303
F1-score: 0.0101


  _warn_prf(average, modifier, msg_start, len(result))
