In [3]:
import clip
import torch
import numpy as np
import pickle
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, BatchSampler
from tqdm import tqdm


device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
print(device)

cuda


In [4]:
BATCH_SIZE = 32
class FashionDataset(Dataset):
    def __init__(self, images_list, prompt_list, label_list, preprocess):
        self.images_list = images_list
        self.prompt_list = prompt_list
        self.label_list = label_list
        self.preprocess = preprocess

    def __len__(self):
        return len(self.images_list)

    def __getitem__(self, idx):
        image = self.images_list[idx]
        image_tensor = self.preprocess(image)

        prompt = self.prompt_list[idx]
        prompt_token = clip.tokenize([prompt])[0]

        label = self.label_list[idx]
        return image_tensor, prompt_token, label

In [5]:
class BalancedBatchSampler(BatchSampler):
    """
    Returns batches of size n_classes * n_samples
    """

    def __init__(self, labels, n_classes, n_samples):
        self.labels = labels
        self.labels_set = list(set(self.labels.numpy()))
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(self.labels)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                indices.extend(self.label_to_indices[class_][
                               self.used_label_indices_count[class_]:self.used_label_indices_count[
                                                                         class_] + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return self.n_dataset // self.batch_size

In [6]:
with open('data_loader/train_dataloader.pkl', 'rb') as f:
    loaded_train_dataloader = pickle.load(f)

with open('data_loader/test_dataloader.pkl', 'rb') as f:
    loaded_test_dataloader = pickle.load(f)

In [7]:
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        p.grad.data = p.grad.data.float()

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

In [9]:
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
num_epochs = 10

for epoch in range(1, num_epochs + 1):
    pbar = tqdm(loaded_train_dataloader, total=len(loaded_train_dataloader))
    train_loss = 0
    for batch in pbar:
        optimizer.zero_grad()

        image_tensor, prompt_token, _ = batch

        image_tensor = image_tensor.to(device)
        prompt_token = prompt_token.to(device)

        # Forward pass
        logits_per_image, logits_per_text = model(image_tensor, prompt_token)
        ground_truth = torch.arange(BATCH_SIZE).to(device)

        # Compute loss
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

        # Backward pass
        if device == "cpu":
            total_loss.backward()
            train_loss += total_loss
            optimizer.step()
        else:
            total_loss.backward()
            train_loss += total_loss
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss.item():.4f}")

    test_loss = 0
    with torch.no_grad():
        for batch in loaded_test_dataloader:
            image_tensor, prompt_token, _ = batch

            image_tensor = image_tensor.to(device)
            prompt_token = prompt_token.to(device)

            # Forward pass
            logits_per_image, logits_per_text = model(image_tensor, prompt_token)
            ground_truth = torch.arange(BATCH_SIZE).to(device)

            # Compute loss
            total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
            test_loss += total_loss

    print(f"Epoch {epoch}, Train loss {train_loss / len(loaded_train_dataloader)}")
    print(f"Epoch {epoch}, Train loss {test_loss / len(loaded_test_dataloader)}")

Epoch 1/10, Loss: 0.3992: 100%|██████████| 343/343 [02:00<00:00,  2.84it/s]


Epoch 1, Train loss 0.5712890625
Epoch 1, Train loss 0.80078125


Epoch 2/10, Loss: 0.3926: 100%|██████████| 343/343 [01:48<00:00,  3.16it/s]


Epoch 2, Train loss 0.409912109375
Epoch 2, Train loss 0.96923828125


Epoch 3/10, Loss: 0.3577:  57%|█████▋    | 196/343 [01:14<00:55,  2.64it/s]


KeyboardInterrupt: ignored

In [10]:
torch.save(model, "model_storage/pretrained_clip.pt")