<a href="https://colab.research.google.com/github/hmin27/2023_DL_Clip/blob/main/CLIP(Finetune)_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CLIP Fine tuning
- Food image classification
- Baseline of Fine Tuned CLIP model


In [None]:
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

In [None]:
import os
import clip
import torch
from torch import nn, optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.datasets import ImageFolder

%matplotlib inline
BATCH_SIZE = 1
EPOCH = 3
LR = 1e-5

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


# Prepare the Model and Data

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
model = model.to(torch.float32)

# torch.manual_seed(42)

In [None]:
# Creating image path, text list
import pandas as pd

data_folder = '/content/drive/MyDrive/Study/DL_CLIP/Food_fewshot/4_shot/train'

image_paths = []
text_descriptions = []
class_folders = os.listdir(data_folder)

for class_folder in class_folders:
    class_folder_path = os.path.join(data_folder, class_folder)
    image_files = os.listdir(class_folder_path)

    for image_file in image_files:
        image_path = os.path.join(class_folder_path, image_file)
        image_paths.append(image_path)

        # Create text description using class label
        text_description = f"a photo of {class_folder.replace('_',' ')}"
        text_descriptions.append(text_description)

print(text_descriptions)
len(text_descriptions)  # 96개


['a photo of apple pie', 'a photo of apple pie', 'a photo of apple pie', 'a photo of apple pie', 'a photo of burger', 'a photo of burger', 'a photo of burger', 'a photo of burger', 'a photo of butter naan', 'a photo of butter naan', 'a photo of butter naan', 'a photo of butter naan', 'a photo of chai', 'a photo of chai', 'a photo of chai', 'a photo of chai', 'a photo of chapati', 'a photo of chapati', 'a photo of chapati', 'a photo of chapati', 'a photo of cheesecake', 'a photo of cheesecake', 'a photo of cheesecake', 'a photo of cheesecake', 'a photo of chicken curry', 'a photo of chicken curry', 'a photo of chicken curry', 'a photo of chicken curry', 'a photo of chole bhature', 'a photo of chole bhature', 'a photo of chole bhature', 'a photo of chole bhature', 'a photo of dal makhani', 'a photo of dal makhani', 'a photo of dal makhani', 'a photo of dal makhani', 'a photo of dhokla', 'a photo of dhokla', 'a photo of dhokla', 'a photo of dhokla', 'a photo of fried rice', 'a photo of fr

132

In [None]:
# Few-shot learning

class MyDataset(Dataset):
    def __init__(self, data_folder, preprocess):
        self.data_folder = data_folder
        self.preprocess = preprocess
        self.image_paths = []
        self.text_descriptions = []
        self.labels = []

        # image path list, text list
        class_folders = os.listdir(data_folder)

        for label, class_folder in enumerate(class_folders):
            class_folder_path = os.path.join(data_folder, class_folder)
            image_files = os.listdir(class_folder_path)

            for image_file in image_files:
                image_path = os.path.join(class_folder_path, image_file)
                self.image_paths.append(image_path)

                text_description = clip.tokenize(f"a photo of {class_folder.replace('_',' ')}")
                self.text_descriptions.append(text_description)

                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx])
        image = self.preprocess(image)
        text = self.text_descriptions[idx]
        label = self.labels[idx]
        return image, text, label


data_folder = '/content/drive/MyDrive/Study/DL_CLIP/Food_fewshot/4_shot'
train_dataset = MyDataset(os.path.join(data_folder, 'train'), preprocess)
val_dataset = MyDataset(os.path.join(data_folder, 'validation'), preprocess)
test_dataset = MyDataset(os.path.join(data_folder, 'test'), preprocess)

trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# len(train_dataset)  # 96
len(val_dataset)  # 28
# len(test_dataset)  # 12

132

In [None]:
from numpy.lib import shape_base
for batch in trainloader:
    images, texts, labels = batch
    # Print the first batch
    print("Image Path:", images[0].shape)
    print("Text Description:", texts[0].shape)
    print("Label: ", labels[0])
    break

Image Path: torch.Size([3, 224, 224])
Text Description: torch.Size([1, 77])
Label:  tensor(14)


# Training

In [None]:
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        p.grad.data = p.grad.data.float()

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, len(trainloader)*EPOCH)

In [None]:
from tqdm.notebook import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def compute_accuracy(logits, ground_truth):
    _, predicted = logits.max(1)
    total = ground_truth.size(0)
    correct = (predicted == ground_truth).sum().item()
    return correct / total

for epoch in range(EPOCH):
    print(f"Epoch: {epoch+1}")

    # Training loop
    model.train()
    train_total, train_correct = 0, 0
    pbar = tqdm(trainloader, total=len(trainloader))
    for batch in pbar:
        optimizer.zero_grad()

        images, texts, _ = batch
        texts = texts.squeeze(1)
        images = images.to(device)
        texts = texts.to(device)

        logits_per_image, logits_per_text = model(images, texts)

        # Compute loss
        actual_batch_size = images.size(0)
        ground_truth = torch.arange(actual_batch_size).to(device)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

        # Compute train accuracy
        train_correct += (logits_per_image.argmax(dim=1) == ground_truth).float().sum().item()
        train_total += images.size(0)

        total_loss.backward()

        if device == "cpu":
            optimizer.step()
        else :
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        train_accuracy = 100 * train_correct / train_total
        pbar.set_description(f"Epoch {epoch+1}/{EPOCH}, Loss: {total_loss.item():.4f}, Train Acc: {train_accuracy:.2f}%")

    # Validation loop
    model.eval()
    val_total, val_correct = 0, 0
    with torch.no_grad():
        for batch in valloader:
            images, texts, _ = batch
            texts = texts.squeeze(1)
            images = images.to(device)
            texts = texts.to(device)

            logits_per_image, _ = model(images, texts)

            actual_batch_size = logits_per_image.size(0)
            ground_truth = torch.arange(actual_batch_size).to(device)

            val_correct += (logits_per_image.argmax(dim=1) == ground_truth).float().sum().item()
            val_total += images.size(0)

    val_accuracy = 100 * val_correct / val_total
    print(f"Validation Accuracy: {val_accuracy:.2f}%")

Epoch: 1


  0%|          | 0/132 [00:00<?, ?it/s]

Validation Accuracy: 100.00%
Epoch: 2


  0%|          | 0/132 [00:00<?, ?it/s]

Validation Accuracy: 100.00%
Epoch: 3


  0%|          | 0/132 [00:00<?, ?it/s]

Validation Accuracy: 100.00%


In [None]:
### Top 5 average in train and validation
#######

from tqdm.notebook import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def top_k_accuracy(logits, ground_truth, k):
    _, top_k_predictions = logits.topk(k, 1, True, True)
    correct = top_k_predictions.eq(ground_truth.view(-1, 1).expand_as(top_k_predictions))
    correct_k = correct.sum().item()
    return correct_k / len(ground_truth)

for epoch in range(EPOCH):
    print(f"Epoch: {epoch+1}")

    # Training loop
    model.train()
    train_total, train_correct = 0, 0
    pbar = tqdm(trainloader, total=len(trainloader))
    for batch in pbar:
        optimizer.zero_grad()

        images, texts = batch
        texts = texts.squeeze(1)
        images = images.to(device)
        texts = texts.to(device)

        logits_per_image, logits_per_text = model(images, texts)

        # Compute loss
        actual_batch_size = images.size(0)
        ground_truth = torch.arange(actual_batch_size).to(device)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

        # Compute train top-k accuracy
        train_top_k_acc = top_k_accuracy(logits_per_image, ground_truth, k=4)
        train_total += images.size(0)
        train_correct += train_top_k_acc * images.size(0)

        total_loss.backward()

        if device == "cpu":
            optimizer.step()
        else:
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        train_accuracy = 100 * train_correct / train_total
        pbar.set_description(f"Epoch {epoch+1}/{EPOCH}, Loss: {total_loss.item():.4f}, Train Top-4 Acc: {train_accuracy:.2f}%")

    # Validation loop
    model.eval()
    val_total, val_correct = 0, 0
    with torch.no_grad():
        for batch in valloader:
            images, texts = batch
            texts = texts.squeeze(1)
            images = images.to(device)
            texts = texts.to(device)

            logits_per_image, _ = model(images, texts)

            actual_batch_size = logits_per_image.size(0)
            ground_truth = torch.arange(actual_batch_size).to(device)

            # Compute validation top-k accuracy
            val_top_k_acc = top_k_accuracy(logits_per_image, ground_truth, k=4)
            val_total += images.size(0)
            val_correct += val_top_k_acc * images.size(0)

    # Calculate and print the average top-k accuracy for the epoch
    train_avg_accuracy = 100 * train_correct / train_total
    val_avg_accuracy = 100 * val_correct / val_total
    print(f"Average Train Top-4 Accuracy: {train_avg_accuracy:.2f}%")
    print(f"Average Validation Top-4 Accuracy: {val_avg_accuracy:.2f}%")


Epoch: 1


  0%|          | 0/27 [00:00<?, ?it/s]

Average Train Top-4 Accuracy: 48.84%
Average Validation Top-4 Accuracy: 45.81%
Epoch: 2


  0%|          | 0/27 [00:00<?, ?it/s]

Average Train Top-4 Accuracy: 52.56%
Average Validation Top-4 Accuracy: 46.51%
Epoch: 3


  0%|          | 0/27 [00:00<?, ?it/s]

Average Train Top-4 Accuracy: 56.74%
Average Validation Top-4 Accuracy: 47.21%


In [None]:
import shutil

torch.save(model.state_dict(), 'CLIP_4shot_v2.pth')
shutil.copy('CLIP_4shot_v2.pth', '/content/drive/MyDrive/Study/DL_CLIP/model/CLIP_4shot_v2.pth')

'/content/drive/MyDrive/Study/DL_CLIP/model/CLIP_4shot_v2.pth'

In [None]:
loaded_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
loaded_model = loaded_model.to(torch.float32)
loaded_model.load_state_dict(torch.load('/content/drive/MyDrive/Study/DL_CLIP/model/CLIP_4shot_v2.pth'))
loaded_model.to(device)

100%|███████████████████████████████████████| 338M/338M [00:05<00:00, 67.0MiB/s]


CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [None]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.2.0-py3-none-any.whl (805 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/805.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.9/805.2 kB[0m [31m4.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.10.0 torchmetrics-1.2.0
