In [45]:
from pathlib import Path
import random
import os

import clip
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split, Subset
import torchmetrics
import torchvision
from tqdm import tqdm
from PIL import Image

from oma_recipeclassifier.src.model_dev.few_shot.balanced_batch_sampler import BalancedBatchSampler

In [46]:
def finetune():
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)
    
    SAVE_INTERVAL = 9
    BATCH_SIZE = 32
    NUM_EPOCHS = 10

    # GPT - Inconsistencies in handling model precision
    def convert_models_to_fp32(model):
        for p in model.parameters():
            p.data = p.data.float()
            if p.requires_grad():
                p.grad.data = p.grad.data.float()

    # ----------
    # Load model 
    # ----------
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False) # jit=False to disable TorchScript for fine-tuning
    # GPT - These lines are unnecessary?
    if device == "cpu":
        model.float() # Converts model to fp32
    else:
        clip.model.convert_weights(model) # Converts model to fp16 (unnecessary since CLIP already uses fp16 by default)
    
    # TensorBoard writer
    writer = SummaryWriter()
    weights_path = Path("model_checkpoints") # Dir to save model checkpoints
    weights_path.mkdir(exist_ok=True)

    # ----------
    # Load dataset 
    # ----------
    dataset = torchvision.datasets.ImageFolder("dataset", transform=preprocess)
    
    # Split dataset into train/test
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_indices, test_indices = random_split(range(len(dataset)), [train_size, test_size])
    train_dataset, test_dataset = Subset(dataset, train_indices), Subset(dataset, test_indices)
    
    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")
    
    # Extract train/test labels
    train_labels = torch.tensor([dataset.targets[i] for i in train_indices])
    test_labels = torch.tensor([dataset.targets[i] for i in test_indices])
    
    # Create DataLoaders
    train_sampler = BalancedBatchSampler(train_labels, BATCH_SIZE, 1)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_sampler)
    test_sampler = BalancedBatchSampler(test_labels, BATCH_SIZE, 1)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_sampler=test_sampler)
    
    # Display train/test dataset distribution
    
    # ----------
    # Training 
    # ----------
    # Loss functions
    loss_img = torch.nn.CrossEntropyLoss()
    loss_txt = torch.nn.CrossEntropyLoss(ignore_index=-1)
    
    # Optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=1e-7, weight_decay=1e-4)
    
    num_batches_train = len(train_dataloader.dataset) / BATCH_SIZE
    num_batches_test = len(test_dataloader.dataset) / BATCH_SIZE
    
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
        epoch_train_loss = 0
        model.train()
        
        for batch in tqdm(train_dataloader, total=num_batches_train):
            optimizer.zero_grad() # Clear gradients from previous iteration
            
            images, label_ids = batch
            
            images = torch.stack([img for img in images], dim=0).to(device) # Stack images into a single tensor (adds an extra dim representing the batch)
            # Generate text prompts: the number of text prompts will be equal to the number of images in the batch (label_ids)
            # Scenario: matching each image with its corresponding text prompt, doesn't allow for comparison against other text prompts
            # texts = [f"A photo of a {train_dataset.dataset.classes[label_id]}" for label_id in label_ids]
            texts = [f"A photo of a {dataset.classes[label_id]}" for label_id in label_ids]
            text = clip.tokenize(texts).to(device) # Tokenize text prompts
            
            print(f"images shape: {images.shape}")
            print(f"label_ids shape: {label_ids.shape}")
            print(f"label_ids: {label_ids}")
            print(f"text shape: {text.shape}")
            
            logits_per_image, logits_per_text = model(images, text) # Forward pass
    
            # Ground truth labels: For each batch, the i-th image corresponds to the i-th text
            # Therefore, the i-th image should have the same label as the i-th text, i.e. [0, 1, 2, ..., BATCH_SIZE - 1]
            # The same happens for text, so we use the same ground truth for both image and text
            ground_truth = torch.arange(logits_per_image.shape[0], dtype=torch.long, device=device)
            
            print(f"logits_per_image shape: {logits_per_image.shape}")
            print(f"logits_per_text shape: {logits_per_text.shape}")
            print(f"ground_truth shape: {ground_truth.shape}")
            print(f"ground_truth: {ground_truth}")
    
            # Compute loss
            total_train_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
            total_train_loss.backward() # Backward pass
            epoch_train_loss += total_train_loss
            
            torch.nn.utils.clip_grad_norm_(params, 1.0) # Clip gradients to prevent exploding gradients
            
            if device == "cpu":
                optimizer.step() # Update weights
            else:
                convert_models_to_fp32(model)
                optimizer.step() # Update weights
                clip.model.convert_weights(model)
            
        epoch_train_loss /= num_batches_train # Average loss per epoch
        writer.add_scalar("Loss/train", epoch_train_loss, epoch) # Log loss to TensorBoard
        
        # Save model weights
        if epoch % SAVE_INTERVAL == 0:
            torch.save({"epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, weights_path / f"epoch_{epoch}.pt")
            print(f"Saved weights under {weights_path}/epoch_{epoch}.pt")
           
        # ----------
        # Testing
        # ----------
        num_batches_test = len(test_dataloader.dataset) / BATCH_SIZE 
        epoch_test_loss = 0
        model.eval()
        
        acc_top3_list = []
        acc_top1_list = []
        
        print("------ Testing ------")
        num_classes = len(dataset.classes)
        classes = torch.arange(num_classes, device=device)
        
        print(f"num_classes: {num_classes}")
        print(f"dataset.classes: {dataset.classes}")
        print(f"classes: {classes}")
        
        for i, batch in enumerate(tqdm(test_dataloader, total=num_batches_test)):
            images, label_ids = batch
            images = images.to(device)
            label_ids = label_ids.to(device)
            
            # Generate text prompts: the number of text prompts will be equal to the number of classes in the dataset (classes)
            # Scenario: classifying each image against all possible classes, allows for comparison against all classes
            texts = torch.cat([clip.tokenize(f"A photo of a {c}") for c in dataset.classes]).to(device) # Concatenate text prompts
            
            print(f"images shape: {images.shape}")
            print(f"label_ids shape: {label_ids.shape}")
            print(f"label_ids: {label_ids}")
            print(f"texts shape: {texts.shape}")

            
            with torch.no_grad():
                image_features = model.encode_image(images)
                text_features = model.encode_text(texts)
                
                logits_per_image, logits_per_text = model(images, texts) # Forward pass
                
                ground_truth_img = torch.arange(logits_per_image.shape[0], dtype=torch.long, device=device)
                # Ground truth labels for text
                ground_truth_txt = -1 * torch.ones(len(classes), dtype=torch.long, device=device)
                for idx, class_label in enumerate(classes):
                    if class_label in label_ids:
                        ground_truth_txt[idx] = (label_ids == class_label).nonzero(as_tuple=True)[0].item()

                print(f"logits_per_image shape: {logits_per_image.shape}")
                print(f"logits_per_text shape: {logits_per_text.shape}")
                print(f"ground_truth shape: {ground_truth.shape}")
                print(f"ground_truth_img: {ground_truth_img}")
                print(f"ground_truth_txt: {ground_truth_txt}")
                
                # Compute loss
                img_loss = loss_img(logits_per_image, ground_truth_img)
                txt_loss = loss_txt(logits_per_text, ground_truth_txt)
                total_loss = (img_loss + txt_loss) / 2
                epoch_test_loss += total_loss
            
            # Normalize features
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            
            assert torch.equal(logits_per_image.T, logits_per_text), "Logits are not equal"
            
            # Compute cosine similarity
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            
            # Compute top accuracy
            acc_top1 = torchmetrics.functional.accuracy(similarity, label_ids, task="multiclass", num_classes=num_classes)
            acc_top3 = torchmetrics.functional.accuracy(similarity, label_ids, task="multiclass", num_classes=num_classes, top_k=3)
            
            acc_top1_list.append(acc_top1)
            acc_top3_list.append(acc_top3)

        epoch_test_loss /= num_batches_test # Average loss per epoch
        writer.add_scalar("Loss/test", epoch_test_loss, epoch) # Log loss to TensorBoard

        print(f"Epoch {epoch} train loss: {epoch_train_loss / num_batches_train}")
        print(f"Epoch {epoch} test loss: {epoch_test_loss / num_batches_test}")
        
        # Compute mean top3 and top1 accuracy
        mean_top3_accuracy = torch.stack(acc_top3_list).mean().cpu().numpy()
        print(f"Mean Top 3 Accuracy: {mean_top3_accuracy*100:.2f}%")
        writer.add_scalar("Test Accuracy/Top3", mean_top3_accuracy, epoch)
        mean_top1_accuracy = torch.stack(acc_top1_list).mean().cpu().numpy()
        print(f"Mean Top 1 Accuracy: {mean_top1_accuracy*100:.2f}%")
        writer.add_scalar("Test Accuracy/Top1", mean_top1_accuracy, epoch)
        
    writer.flush() # Make sure all pending writes are completed
    writer.close()

In [47]:
finetune()

Train dataset size: 74
Test dataset size: 19
Epoch 1/10


  0%|          | 0/2.3125 [00:00<?, ?it/s]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 7,  6,  8,  9, 10,  1,  4,  5,  3,  0,  2])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 43%|████▎     | 1/2.3125 [00:05<00:06,  5.16s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 8,  1,  5,  2,  0,  6,  7,  9,  3, 10,  4])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 86%|████████▋ | 2/2.3125 [00:11<00:01,  6.01s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 2,  4,  6,  8,  3,  7,  9,  1,  0, 10,  5])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


3it [00:16,  5.30s/it]                            

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 5,  9,  7,  6,  3, 10,  8,  4,  2,  1,  0])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


4it [00:21,  5.43s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 9,  0,  6,  4,  5,  3,  8,  1,  2,  7, 10])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


5it [00:27,  5.66s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 8, 10,  7,  5,  0,  3,  6,  2,  4,  1,  9])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


6it [00:32,  5.42s/it]


Saved weights under model_checkpoints/epoch_0.pt
------ Testing ------
num_classes: 11
dataset.classes: ['chopping-board', 'glass-bowl-large', 'glass-bowl-medium', 'glass-bowl-small', 'group_step', 'oven-dish', 'oven-tray', 'pan', 'pot-one-handle', 'pot-two-handles-medium', 'pot-two-handles-small']
classes: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


  0%|          | 0/0.59375 [00:00<?, ?it/s]

images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([7, 3, 6, 4, 5, 0, 8])
texts shape: torch.Size([11, 77])


  full_bar = Bar(frac,
168%|██████████| 1/0.59375 [00:03<-1:59:59,  3.01s/it]

logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 5, -1, -1,  1,  3,  4,  2,  0,  6, -1, -1])
images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([5, 7, 4, 3, 8, 6, 0])
texts shape: torch.Size([11, 77])


2it [00:05,  2.82s/it]                                


logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 6, -1, -1,  3,  2,  0,  5,  1,  4, -1, -1])
Epoch 0 train loss: 1.845449686050415
Epoch 0 test loss: 14.577085494995117
Mean Top 3 Accuracy: 85.71%
Mean Top 1 Accuracy: 50.00%
Epoch 2/10


  0%|          | 0/2.3125 [00:00<?, ?it/s]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 1,  8,  7,  4,  6,  0,  5,  9, 10,  3,  2])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 43%|████▎     | 1/2.3125 [00:06<00:08,  6.61s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 7, 10,  1,  6,  0,  5,  9,  3,  4,  8,  2])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 86%|████████▋ | 2/2.3125 [00:12<00:01,  6.26s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 2,  7, 10,  4,  0,  1,  3,  8,  5,  6,  9])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


3it [00:16,  5.32s/it]                            

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 2,  4,  1,  3, 10,  0,  8,  7,  5,  6,  9])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


4it [00:20,  4.76s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 1,  2,  4,  0,  8,  5,  7,  9,  3,  6, 10])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


5it [00:26,  5.12s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 7,  9,  1,  3,  2,  0,  4,  8,  5,  6, 10])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


6it [00:30,  5.15s/it]


------ Testing ------
num_classes: 11
dataset.classes: ['chopping-board', 'glass-bowl-large', 'glass-bowl-medium', 'glass-bowl-small', 'group_step', 'oven-dish', 'oven-tray', 'pan', 'pot-one-handle', 'pot-two-handles-medium', 'pot-two-handles-small']
classes: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


  0%|          | 0/0.59375 [00:00<?, ?it/s]

images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([7, 8, 0, 6, 4, 5, 3])
texts shape: torch.Size([11, 77])


168%|██████████| 1/0.59375 [00:02<00:00,  2.45s/it]

logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 2, -1, -1,  6,  4,  5,  3,  0,  1, -1, -1])
images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([6, 8, 4, 0, 3, 7, 5])
texts shape: torch.Size([11, 77])


2it [00:04,  2.46s/it]                             


logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 3, -1, -1,  4,  2,  6,  0,  5,  1, -1, -1])
Epoch 1 train loss: 1.6538504362106323
Epoch 1 test loss: 15.908926010131836
Mean Top 3 Accuracy: 85.71%
Mean Top 1 Accuracy: 42.86%
Epoch 3/10


  0%|          | 0/2.3125 [00:00<?, ?it/s]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 3,  1, 10,  8,  9,  7,  6,  5,  2,  0,  4])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 43%|████▎     | 1/2.3125 [00:04<00:05,  4.54s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 9,  2,  5,  4,  6,  7,  8,  3,  0, 10,  1])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 86%|████████▋ | 2/2.3125 [00:09<00:01,  4.74s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 0,  9,  4,  8, 10,  1,  5,  6,  2,  3,  7])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


3it [00:13,  4.48s/it]                            

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 0,  1,  4,  8,  6,  9,  3,  7,  2, 10,  5])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


4it [00:17,  4.33s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([10,  3,  7,  5,  4,  2,  8,  6,  0,  1,  9])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


5it [00:22,  4.41s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 1,  3,  8,  7,  5,  9,  6,  2, 10,  4,  0])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


6it [00:26,  4.40s/it]


------ Testing ------
num_classes: 11
dataset.classes: ['chopping-board', 'glass-bowl-large', 'glass-bowl-medium', 'glass-bowl-small', 'group_step', 'oven-dish', 'oven-tray', 'pan', 'pot-one-handle', 'pot-two-handles-medium', 'pot-two-handles-small']
classes: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


  0%|          | 0/0.59375 [00:00<?, ?it/s]

images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([8, 4, 3, 6, 7, 0, 5])
texts shape: torch.Size([11, 77])


168%|██████████| 1/0.59375 [00:02<00:00,  2.45s/it]

logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 5, -1, -1,  2,  1,  6,  3,  4,  0, -1, -1])
images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([3, 5, 0, 7, 4, 8, 6])
texts shape: torch.Size([11, 77])


2it [00:05,  3.00s/it]                             


logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 2, -1, -1,  0,  4,  1,  6,  3,  5, -1, -1])
Epoch 2 train loss: 1.428858995437622
Epoch 2 test loss: 17.54129981994629
Mean Top 3 Accuracy: 85.71%
Mean Top 1 Accuracy: 57.14%
Epoch 4/10


  0%|          | 0/2.3125 [00:00<?, ?it/s]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 1,  4,  0,  7,  6, 10,  8,  5,  2,  9,  3])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 43%|████▎     | 1/2.3125 [00:06<00:08,  6.10s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 8,  3,  4,  9,  0,  5, 10,  6,  1,  7,  2])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 86%|████████▋ | 2/2.3125 [00:10<00:01,  5.35s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 8,  0,  2,  4, 10,  1,  9,  6,  3,  7,  5])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


3it [00:15,  4.84s/it]                            

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 7,  6,  2, 10,  0,  5,  9,  1,  4,  8,  3])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


4it [00:19,  4.61s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 1,  3,  9,  6,  7,  5,  2,  8,  0, 10,  4])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


5it [00:23,  4.52s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 0,  8,  5,  7,  6,  1,  2,  4, 10,  3,  9])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


6it [00:28,  4.72s/it]


------ Testing ------
num_classes: 11
dataset.classes: ['chopping-board', 'glass-bowl-large', 'glass-bowl-medium', 'glass-bowl-small', 'group_step', 'oven-dish', 'oven-tray', 'pan', 'pot-one-handle', 'pot-two-handles-medium', 'pot-two-handles-small']
classes: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


  0%|          | 0/0.59375 [00:00<?, ?it/s]

images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([5, 4, 7, 8, 6, 0, 3])
texts shape: torch.Size([11, 77])


168%|██████████| 1/0.59375 [00:02<-1:59:59,  2.72s/it]

logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 5, -1, -1,  6,  1,  0,  4,  2,  3, -1, -1])
images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([5, 8, 6, 4, 0, 3, 7])
texts shape: torch.Size([11, 77])


2it [00:05,  2.64s/it]                                


logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 4, -1, -1,  5,  3,  0,  2,  6,  1, -1, -1])
Epoch 3 train loss: 1.4161595106124878
Epoch 3 test loss: 17.38113784790039
Mean Top 3 Accuracy: 85.71%
Mean Top 1 Accuracy: 71.43%
Epoch 5/10


  0%|          | 0/2.3125 [00:00<?, ?it/s]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 5,  1,  2,  8,  9,  0, 10,  7,  4,  6,  3])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 43%|████▎     | 1/2.3125 [00:05<00:07,  5.36s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 9,  5,  4,  7,  2,  1,  6,  3, 10,  0,  8])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 86%|████████▋ | 2/2.3125 [00:09<00:01,  4.84s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 9,  5,  2,  7,  8,  4,  1,  6, 10,  0,  3])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


3it [00:15,  5.28s/it]                            

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 6,  5,  0,  7,  8, 10,  2,  3,  1,  4,  9])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


4it [00:20,  5.02s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 1,  0,  9,  5,  8,  4,  6, 10,  7,  3,  2])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


5it [00:26,  5.53s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 9,  6,  4,  7,  1,  5,  3,  2, 10,  8,  0])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


6it [00:31,  5.20s/it]


------ Testing ------
num_classes: 11
dataset.classes: ['chopping-board', 'glass-bowl-large', 'glass-bowl-medium', 'glass-bowl-small', 'group_step', 'oven-dish', 'oven-tray', 'pan', 'pot-one-handle', 'pot-two-handles-medium', 'pot-two-handles-small']
classes: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


  0%|          | 0/0.59375 [00:00<?, ?it/s]

images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([5, 8, 0, 3, 4, 7, 6])
texts shape: torch.Size([11, 77])


168%|██████████| 1/0.59375 [00:02<00:00,  2.34s/it]

logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 2, -1, -1,  3,  4,  0,  6,  5,  1, -1, -1])
images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([6, 7, 8, 3, 5, 0, 4])
texts shape: torch.Size([11, 77])


2it [00:04,  2.40s/it]                             


logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 5, -1, -1,  3,  6,  4,  0,  1,  2, -1, -1])
Epoch 4 train loss: 1.2371553182601929
Epoch 4 test loss: 11.927498817443848
Mean Top 3 Accuracy: 85.71%
Mean Top 1 Accuracy: 78.57%
Epoch 6/10


  0%|          | 0/2.3125 [00:00<?, ?it/s]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 9,  4,  6,  3,  2,  7,  5,  0,  1, 10,  8])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 43%|████▎     | 1/2.3125 [00:04<00:06,  4.77s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 1,  0, 10,  2,  9,  7,  4,  6,  8,  3,  5])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 86%|████████▋ | 2/2.3125 [00:08<00:01,  4.44s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 3,  7,  2,  4,  8,  1,  0,  6, 10,  5,  9])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


3it [00:13,  4.47s/it]                            

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 1,  0,  4,  2,  3,  9,  5,  7,  8,  6, 10])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


4it [00:18,  4.61s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 5,  4,  0, 10,  3,  6,  9,  2,  1,  8,  7])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


5it [00:22,  4.45s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 8,  5,  9,  7,  1,  4,  3,  2,  0, 10,  6])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


6it [00:27,  4.51s/it]


------ Testing ------
num_classes: 11
dataset.classes: ['chopping-board', 'glass-bowl-large', 'glass-bowl-medium', 'glass-bowl-small', 'group_step', 'oven-dish', 'oven-tray', 'pan', 'pot-one-handle', 'pot-two-handles-medium', 'pot-two-handles-small']
classes: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


  0%|          | 0/0.59375 [00:00<?, ?it/s]

images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([4, 7, 3, 8, 0, 5, 6])
texts shape: torch.Size([11, 77])


168%|██████████| 1/0.59375 [00:02<00:00,  2.45s/it]

logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 4, -1, -1,  2,  0,  5,  6,  1,  3, -1, -1])
images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([8, 6, 3, 5, 4, 0, 7])
texts shape: torch.Size([11, 77])


2it [00:04,  2.45s/it]                             


logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 5, -1, -1,  2,  4,  3,  1,  6,  0, -1, -1])
Epoch 5 train loss: 1.2686586380004883
Epoch 5 test loss: 13.416112899780273
Mean Top 3 Accuracy: 92.86%
Mean Top 1 Accuracy: 78.57%
Epoch 7/10


  0%|          | 0/2.3125 [00:00<?, ?it/s]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 3,  9,  8,  6,  5,  2,  0,  4,  7,  1, 10])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 43%|████▎     | 1/2.3125 [00:04<00:06,  4.81s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 3,  5,  7,  9, 10,  8,  1,  6,  2,  0,  4])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 86%|████████▋ | 2/2.3125 [00:09<00:01,  4.53s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 6,  3,  9,  7,  2,  0,  5,  1,  8, 10,  4])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


3it [00:13,  4.44s/it]                            

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 3,  5,  1,  7,  0,  6,  4, 10,  8,  2,  9])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


4it [00:19,  5.15s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 6,  0,  2,  1, 10,  7,  4,  5,  8,  3,  9])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


5it [00:23,  4.79s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 9,  0,  3,  7,  6,  4, 10,  1,  2,  8,  5])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


6it [00:28,  4.70s/it]


------ Testing ------
num_classes: 11
dataset.classes: ['chopping-board', 'glass-bowl-large', 'glass-bowl-medium', 'glass-bowl-small', 'group_step', 'oven-dish', 'oven-tray', 'pan', 'pot-one-handle', 'pot-two-handles-medium', 'pot-two-handles-small']
classes: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


  0%|          | 0/0.59375 [00:00<?, ?it/s]

images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([6, 7, 8, 3, 5, 4, 0])
texts shape: torch.Size([11, 77])


168%|██████████| 1/0.59375 [00:02<00:00,  2.37s/it]

logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 6, -1, -1,  3,  5,  4,  0,  1,  2, -1, -1])
images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([6, 5, 3, 7, 4, 8, 0])
texts shape: torch.Size([11, 77])


2it [00:05,  2.74s/it]                             


logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 6, -1, -1,  2,  4,  1,  0,  3,  5, -1, -1])
Epoch 6 train loss: 1.1103264093399048
Epoch 6 test loss: 12.412967681884766
Mean Top 3 Accuracy: 100.00%
Mean Top 1 Accuracy: 85.71%
Epoch 8/10


  0%|          | 0/2.3125 [00:00<?, ?it/s]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 2,  7,  9,  8,  0,  6, 10,  3,  1,  5,  4])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 43%|████▎     | 1/2.3125 [00:05<00:07,  5.48s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 8,  2,  5,  6,  4,  1,  7,  0,  9,  3, 10])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 86%|████████▋ | 2/2.3125 [00:10<00:01,  5.08s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 9,  7,  5, 10,  2,  3,  8,  0,  4,  1,  6])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


3it [00:15,  5.13s/it]                            

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 8,  4,  9,  0,  5,  3,  1,  6,  2, 10,  7])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


4it [00:19,  4.70s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 9,  2,  6,  7,  4, 10,  1,  0,  8,  3,  5])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


5it [00:23,  4.48s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 1, 10,  2,  8,  4,  3,  7,  9,  6,  0,  5])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


6it [00:28,  4.71s/it]


------ Testing ------
num_classes: 11
dataset.classes: ['chopping-board', 'glass-bowl-large', 'glass-bowl-medium', 'glass-bowl-small', 'group_step', 'oven-dish', 'oven-tray', 'pan', 'pot-one-handle', 'pot-two-handles-medium', 'pot-two-handles-small']
classes: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


  0%|          | 0/0.59375 [00:00<?, ?it/s]

images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([4, 5, 8, 6, 0, 3, 7])
texts shape: torch.Size([11, 77])


168%|██████████| 1/0.59375 [00:02<-1:59:59,  2.49s/it]

logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 4, -1, -1,  5,  0,  1,  3,  6,  2, -1, -1])
images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([6, 7, 5, 3, 4, 8, 0])
texts shape: torch.Size([11, 77])


2it [00:04,  2.45s/it]                                


logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 6, -1, -1,  3,  4,  2,  0,  1,  5, -1, -1])
Epoch 7 train loss: 0.960955023765564
Epoch 7 test loss: 14.409374237060547
Mean Top 3 Accuracy: 100.00%
Mean Top 1 Accuracy: 85.71%
Epoch 9/10


  0%|          | 0/2.3125 [00:00<?, ?it/s]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 9,  2,  1,  4,  7,  8,  3,  6,  5, 10,  0])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 43%|████▎     | 1/2.3125 [00:05<00:07,  5.95s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 3,  5,  1,  2,  4,  9,  6,  8, 10,  0,  7])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 86%|████████▋ | 2/2.3125 [00:11<00:01,  5.98s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 5,  8,  6,  3,  9,  0,  1, 10,  7,  2,  4])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


3it [00:16,  5.17s/it]                            

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([10,  7,  6,  4,  5,  9,  8,  0,  3,  1,  2])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


4it [00:20,  4.97s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 0,  7,  9,  8,  4, 10,  6,  1,  5,  2,  3])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


5it [00:25,  4.91s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 0,  4,  2,  9,  7,  8,  1,  5, 10,  6,  3])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


6it [00:29,  4.99s/it]


------ Testing ------
num_classes: 11
dataset.classes: ['chopping-board', 'glass-bowl-large', 'glass-bowl-medium', 'glass-bowl-small', 'group_step', 'oven-dish', 'oven-tray', 'pan', 'pot-one-handle', 'pot-two-handles-medium', 'pot-two-handles-small']
classes: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


  0%|          | 0/0.59375 [00:00<?, ?it/s]

images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([7, 5, 4, 3, 8, 0, 6])
texts shape: torch.Size([11, 77])


168%|██████████| 1/0.59375 [00:02<-1:59:59,  2.46s/it]

logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 5, -1, -1,  3,  2,  1,  6,  0,  4, -1, -1])
images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([6, 8, 0, 3, 4, 5, 7])
texts shape: torch.Size([11, 77])


2it [00:04,  2.44s/it]                                


logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 2, -1, -1,  3,  4,  5,  0,  6,  1, -1, -1])
Epoch 8 train loss: 0.922077476978302
Epoch 8 test loss: 12.589924812316895
Mean Top 3 Accuracy: 100.00%
Mean Top 1 Accuracy: 92.86%
Epoch 10/10


  0%|          | 0/2.3125 [00:00<?, ?it/s]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 6,  4, 10,  5,  2,  1,  3,  8,  7,  0,  9])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 43%|████▎     | 1/2.3125 [00:05<00:07,  5.44s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 6, 10,  4,  8,  0,  2,  9,  7,  1,  3,  5])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


 86%|████████▋ | 2/2.3125 [00:10<00:01,  4.97s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 7,  3,  2,  8,  5, 10,  6,  4,  9,  1,  0])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


3it [00:15,  5.05s/it]                            

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([10,  5,  1,  8,  3,  9,  6,  4,  7,  0,  2])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


4it [00:19,  4.73s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 2,  7,  1, 10,  8,  3,  5,  6,  0,  9,  4])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


5it [00:23,  4.66s/it]

images shape: torch.Size([11, 3, 224, 224])
label_ids shape: torch.Size([11])
label_ids: tensor([ 8,  5,  7, 10,  3,  4,  6,  1,  9,  2,  0])
text shape: torch.Size([11, 77])
logits_per_image shape: torch.Size([11, 11])
logits_per_text shape: torch.Size([11, 11])
ground_truth shape: torch.Size([11])
ground_truth: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


6it [00:28,  4.83s/it]


Saved weights under model_checkpoints/epoch_9.pt
------ Testing ------
num_classes: 11
dataset.classes: ['chopping-board', 'glass-bowl-large', 'glass-bowl-medium', 'glass-bowl-small', 'group_step', 'oven-dish', 'oven-tray', 'pan', 'pot-one-handle', 'pot-two-handles-medium', 'pot-two-handles-small']
classes: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


  0%|          | 0/0.59375 [00:00<?, ?it/s]

images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([4, 7, 0, 8, 3, 5, 6])
texts shape: torch.Size([11, 77])


168%|██████████| 1/0.59375 [00:03<-1:59:59,  3.74s/it]

logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 2, -1, -1,  4,  0,  5,  6,  1,  3, -1, -1])
images shape: torch.Size([7, 3, 224, 224])
label_ids shape: torch.Size([7])
label_ids: tensor([8, 0, 7, 5, 3, 6, 4])
texts shape: torch.Size([11, 77])


2it [00:06,  3.22s/it]                                

logits_per_image shape: torch.Size([7, 11])
logits_per_text shape: torch.Size([11, 7])
ground_truth shape: torch.Size([11])
ground_truth_img: tensor([0, 1, 2, 3, 4, 5, 6])
ground_truth_txt: tensor([ 1, -1, -1,  4,  6,  3,  5,  2,  0, -1, -1])
Epoch 9 train loss: 0.8230394124984741
Epoch 9 test loss: 15.124711036682129
Mean Top 3 Accuracy: 100.00%
Mean Top 1 Accuracy: 85.71%





In [53]:
# Model inference

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load model
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Load the fine-tuned weights
checkpoint_path = Path("model_checkpoints/epoch_9.pt")
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Load image
image_path = os.path.join("test_imgs/test_group.jpg")
image = Image.open(image_path)

# Preprocess image
image = preprocess(image).unsqueeze(0).to(device)

# Generate text prompt
keywords = ["chopping-board", "glass-bowl-large", "glass-bowl-medium", "glass-bowl-small", "group_step", "oven-dish", "oven-tray", "pan", "pot-one-handle", "pot-two-handles-medium", "pot-two-handles-small"]
text_prompts = [f"A photo of a {keyword}" for keyword in keywords]
tokenized_text = clip.tokenize(text_prompts).to(device)

# Generate features
with torch.no_grad():
    image_features = model.encode_image(image)
    image_features /= image_features.norm(dim=-1, keepdim=True) # Normalize features
    
    text_features = model.encode_text(tokenized_text)
    text_features /= text_features.norm(dim=-1, keepdim=True) # Normalize features

# Compute cosine similarity
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# Get predicted keyword
predicted_prob, predicted_keyword_idx = similarity.topk(1, dim=-1)

# Print prediction
predicted_keyword = keywords[predicted_keyword_idx.item()]
print(f"Predicted keyword: {predicted_keyword} with probability {predicted_prob.item() * 100:.2f}%")

Predicted keyword: chopping-board with probability 44.89%
