# Load the model

In [2]:
import torch
from torchvision import transforms, datasets
import torch.nn.functional as F
from tqdm import tqdm
from models.model_clip import CLIP
from dataset.caption_dataset import imagenet_templates
from transformers import AutoTokenizer, RobertaTokenizer
import os
import pandas as pd
from torch.utils.data import Dataset

  from .autonotebook import tqdm as notebook_tqdm


# Common

In [3]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [4]:
@torch.no_grad()
def zeroshot_transfer(model, data_loader, dataset_name, tokenizer, device):
    model.eval()

    if dataset_name == "imagenet":
        classes = [cls[0] for cls in data_loader.dataset.classes]
        templates = imagenet_templates
    else:
        print(f"===> Loading zeroshot transfer config for {dataset_name}")
        config = eval(open(f"zeroshot_transfer/{dataset_name}_classes.py", "r").read())
        classes, templates = config["classes"], config["templates"]

    text_embeddings = []
    for c in classes:  # Classes here are the string names of the classes
        texts = [template.format(c) for template in templates]
        text_inputs = tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=30,
            return_tensors="pt",
        ).to(device)
        text_outputs = model.text_encoder(
            text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
            output_hidden_states=False,
        )
        text_embeds = F.normalize(
            model.text_proj(text_outputs.last_hidden_state[:, 0, :]), dim=-1
        )
        text_embed = text_embeds.mean(dim=0)
        text_embed /= text_embed.norm()
        text_embeddings.append(text_embed)

    text_embeddings = torch.stack(text_embeddings, dim=1).to(device)

    topk = [1, 3, 5, 10]
    correct = {k: 0 for k in topk}

    for image, label in tqdm(data_loader, desc="Evaluating zeroshot transfer"):
        image, label = image.to(device), label.to(
            device
        )  # label is the index (number) of the class
        image_feat = model.visual_encoder(image)
        image_embed = model.vision_proj(image_feat)
        image_embedding = F.normalize(image_embed, dim=-1)

        logits = image_embedding @ text_embeddings
        ranks = logits.topk(max(topk), 1)[1].T
        predictions = ranks == label

        for k in topk:
            correct[k] += torch.sum(torch.any(predictions[:k], dim=0)).item()

    results = {f"zeroshot_top{k}": correct[k] / data_loader.num_samples for k in topk}

    return results

In [5]:
def create_zeroshot_dataloader(dataset_name, data_folder, image_size, train=False):
    if dataset_name == "cifar10":
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif dataset_name == "cifar100":
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    elif dataset_name == "mnist":
        mean = (0.1307, 0.1307, 0.1307)
        std = (0.3081, 0.3081, 0.3081)
    else:
        # For the following datasets:
        # imagenet, food101, flowers102, sun397, fgvc-aircraft
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    
    normalize = transforms.Normalize(mean=mean, std=std)

    val_transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.Grayscale(num_output_channels=3) if dataset_name == "mnist" else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            normalize,
        ]
    )

    if dataset_name == "cifar10":
        dataset = datasets.CIFAR10(
            root=data_folder, 
            download=False, 
            train=train, 
            transform=val_transform
        )
        print(f"CIFAR10 classes: {dataset.class_to_idx}")
    elif dataset_name == "cifar100":
        dataset = datasets.CIFAR100(
            root=data_folder, 
            download=False, 
            train=train, 
            transform=val_transform
        )
        print(f"CIFAR100 classes: {dataset.class_to_idx}")
    elif dataset_name == "imagenet":
        dataset = datasets.ImageNet(
            root="/BS/dduka/work/data/imagenet1k/",
            split="val",
            transform=val_transform,
        )
        print(f"ImageNet classes: {dataset.class_to_idx}")
    elif dataset_name == "mnist":
        dataset = datasets.MNIST(
            root=data_folder, 
            download=False, 
            train=False, 
            transform=val_transform
        )
        print(f"MNIST classes: {dataset.class_to_idx}")
    elif dataset_name == "sun397":
        dataset = datasets.SUN397(
            root="/BS/databases03",
            transform=val_transform
        )
        print(f"SUN397 classes: {dataset.class_to_idx}")
    elif dataset_name == "fgvc-aircraft":
        dataset = datasets.FGVCAircraft(
            root="/scratch/inf0/user/dduka",
            split="test",
            download=True,
            transform=val_transform,
        )
        print(f"FGVC-Aircraft classes: {dataset.class_to_idx}")
    elif dataset_name == "oxford-pets":
        dataset = datasets.OxfordIIITPet(
            root="/scratch/inf0/user/dduka", 
            download=True, 
            transform=val_transform,
            split="test"
        )
        print(f"Oxford-IIIT Pet classes: {dataset.class_to_idx}")
    elif dataset_name == "flowers102":
        dataset = datasets.Flowers102(
            root="/scratch/inf0/user/dduka", 
            download=True, 
            split="test", 
            transform=val_transform
        )
    elif dataset_name == "eurosat":
        dataset = datasets.EuroSAT(
            root="/scratch/inf0/user/dduka", 
            download=True, 
            transform=val_transform,
        )
        print(f"EuroSAT classes: {dataset.class_to_idx}")
    else:
        dataset = datasets.ImageFolder(root=data_folder, transform=val_transform)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=256, shuffle=False, num_workers=8, pin_memory=True
    )

    data_loader.num_samples = len(dataset)

    return data_loader

# RN50 Model

In [59]:
rn50_args = {
    "image_encoder": "resnet50",
    "text_encoder": "distilbert-base-uncased",
    "embed_dim": 256,
    "init_model": True,
    "world_size": 1,
    "ita_type": "clip",
    "sogclr_gamma": 0.8,
    "rho": 8.0,
    "tau_init": 0.01,
    "temp": 0.01,
    "learnable_temp": False,
    "personalized_tau": False,
    "vicreg_sim_coeff": 25.0,
    "vicreg_std_coeff": 25.0,
    "N": -1, # don't care about this
    "proto_num": 256,
    "proto_std": 10.0,
    "upper_rho_plus": 0.0,
    "proto_weight": 1.0,
    "sinkhorn_eps": 0.05,
    "swav_temp": 0.1,
    "swav_weight": 1.0,
    "total_steps": -1, # don't care about this
    "sim_based_loss_alpha": 0.1,
    "sim_blend_ratio": 0.0,
    "clip_scheduled_loss_type": "none",
    "use_per_sample_temp": False,
    "include_unimodal_loss": False,
    "disable_temo_modulation": False,
    "disable_crossmodal_minfonce": False,
    "disable_i2i_temo_loss": False,
    "disable_t2t_temo_loss": False,
    "reversed_scheduler": False,
    "enable_non_modulated_unimodal_losses": False,
}

In [60]:
tokenizer = AutoTokenizer.from_pretrained(
    "distilbert-base-uncased", 
    local_files_only=False
)

In [61]:
vit_args = rn50_args.copy()
vit_args["image_encoder"] = "vit_base_patch16_224"

In [62]:
rn50_model = CLIP(
    **rn50_args
).to("cuda")
print(rn50_model)

CLIP(
  (visual_encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (drop_block): Identity()
        (act2): ReLU(inplace=True)
        (aa): Identity()
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, 

In [63]:
rn50_checkpoint_paths = [
    "/BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_tau_0.01_lr_8e-4/checkpoint_best.pth",
    "/BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_cos_0.01_0.05_lr_8e-4/checkpoint_best.pth",
    "/BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_scheduled_clip_0.01_0.04_lr_2e-4_quad_crossmodal_and_unimodal_augmented/checkpoint_best.pth"
]

In [64]:
cifar10_val_loader = create_zeroshot_dataloader(
    "cifar10", "/BS/dduka/work/projects/old_projects/TempNet/Bimodal_CL/cifar10", image_size=224, train=False
)

cifar100_val_loader = create_zeroshot_dataloader(
    "cifar100", "/BS/dduka/work/projects/old_projects/TempNet/Bimodal_CL/cifar100", image_size=224, train=False
)

mnist_val_loader = create_zeroshot_dataloader(
    "mnist", "/BS/dduka/work/projects/old_projects/TempNet/Bimodal_CL/mnist", image_size=224, train=False
)

sun397_val_loader = create_zeroshot_dataloader(
    "sun397", "/BS/databases03", image_size=224, train=False
)

fgvc_aircraft_val_loader = create_zeroshot_dataloader(
    "fgvc-aircraft", "/BS/databases07/fgvc-aircraft", image_size=224, train=False
)

oxford_pets_val_loader = create_zeroshot_dataloader(
    "oxford-pets", "/BS/databases05/", image_size=224, train=False
)

flowers102_val_loader = create_zeroshot_dataloader(
    "flowers102", 
    "",
    image_size=224,
    train=False
)   

eurosat_val_loader = create_zeroshot_dataloader(
    "eurosat",
    "",
    image_size=224,
    train=False
)

CIFAR10 classes: {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
CIFAR100 classes: {'apple': 0, 'aquarium_fish': 1, 'baby': 2, 'bear': 3, 'beaver': 4, 'bed': 5, 'bee': 6, 'beetle': 7, 'bicycle': 8, 'bottle': 9, 'bowl': 10, 'boy': 11, 'bridge': 12, 'bus': 13, 'butterfly': 14, 'camel': 15, 'can': 16, 'castle': 17, 'caterpillar': 18, 'cattle': 19, 'chair': 20, 'chimpanzee': 21, 'clock': 22, 'cloud': 23, 'cockroach': 24, 'couch': 25, 'crab': 26, 'crocodile': 27, 'cup': 28, 'dinosaur': 29, 'dolphin': 30, 'elephant': 31, 'flatfish': 32, 'forest': 33, 'fox': 34, 'girl': 35, 'hamster': 36, 'house': 37, 'kangaroo': 38, 'keyboard': 39, 'lamp': 40, 'lawn_mower': 41, 'leopard': 42, 'lion': 43, 'lizard': 44, 'lobster': 45, 'man': 46, 'maple_tree': 47, 'motorcycle': 48, 'mountain': 49, 'mouse': 50, 'mushroom': 51, 'oak_tree': 52, 'orange': 53, 'orchid': 54, 'otter': 55, 'palm_tree': 56, 'pear': 57, 'pickup_truck': 58, 'pine_tre

In [65]:
loaders = [("cifar10", cifar10_val_loader),
           ("cifar100", cifar100_val_loader),
           ("mnist", mnist_val_loader),
           ("sun397", sun397_val_loader),
           ("fgvc-aircraft", fgvc_aircraft_val_loader),
           ("oxford-pets", oxford_pets_val_loader),
           ("flowers102", flowers102_val_loader),
           ("eurosat", eurosat_val_loader)]

In [66]:
df = pd.DataFrame(columns=["model", "dataset", "zeroshot_top1", "zeroshot_top3", "zeroshot_top5", "zeroshot_top10"])

for dataset_name, val_loader in loaders:
    print(f"Validation dataset: {dataset_name}, number of samples: {val_loader.num_samples}")
    
    for checkpoint_path in rn50_checkpoint_paths:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        state_dict = checkpoint["model"]
        rn50_model.load_state_dict(state_dict, strict=False)
        print("Load checkpoint from %s" % checkpoint_path)
        print(f"Keys in checkpoint model: {checkpoint.keys()}")
        
        results = zeroshot_transfer(
            rn50_model,
            val_loader,
            dataset_name,
            tokenizer,
            device="cuda",
        )

        row = {
            "model": checkpoint_path.split("/")[-2],
            "dataset": dataset_name,
            **results
        }
        df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)

        print(f"Results for model {checkpoint_path.split('/')[-2]} on dataset {dataset_name}: {results}")

print(df)

Validation dataset: cifar10, number of samples: 10000


  checkpoint = torch.load(checkpoint_path, map_location="cpu")


Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_tau_0.01_lr_8e-4/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for cifar10


Evaluating zeroshot transfer: 100%|██████████| 40/40 [00:05<00:00,  7.40it/s]
  df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)


Results for model r_clip_tau_0.01_lr_8e-4 on dataset cifar10: {'zeroshot_top1': 0.5473, 'zeroshot_top3': 0.8195, 'zeroshot_top5': 0.9137, 'zeroshot_top10': 1.0}
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_cos_0.01_0.05_lr_8e-4/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for cifar10


Evaluating zeroshot transfer: 100%|██████████| 40/40 [00:05<00:00,  7.45it/s]


Results for model r_clip_cos_0.01_0.05_lr_8e-4 on dataset cifar10: {'zeroshot_top1': 0.524, 'zeroshot_top3': 0.8147, 'zeroshot_top5': 0.9194, 'zeroshot_top10': 1.0}
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_scheduled_clip_0.01_0.04_lr_2e-4_quad_crossmodal_and_unimodal_augmented/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for cifar10


Evaluating zeroshot transfer: 100%|██████████| 40/40 [00:05<00:00,  7.41it/s]


Results for model r_scheduled_clip_0.01_0.04_lr_2e-4_quad_crossmodal_and_unimodal_augmented on dataset cifar10: {'zeroshot_top1': 0.6657, 'zeroshot_top3': 0.915, 'zeroshot_top5': 0.9674, 'zeroshot_top10': 1.0}
Validation dataset: cifar100, number of samples: 10000
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_tau_0.01_lr_8e-4/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for cifar100


Evaluating zeroshot transfer: 100%|██████████| 40/40 [00:05<00:00,  7.52it/s]


Results for model r_clip_tau_0.01_lr_8e-4 on dataset cifar100: {'zeroshot_top1': 0.2658, 'zeroshot_top3': 0.4348, 'zeroshot_top5': 0.5277, 'zeroshot_top10': 0.6412}
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_cos_0.01_0.05_lr_8e-4/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for cifar100


Evaluating zeroshot transfer: 100%|██████████| 40/40 [00:05<00:00,  7.29it/s]


Results for model r_clip_cos_0.01_0.05_lr_8e-4 on dataset cifar100: {'zeroshot_top1': 0.3032, 'zeroshot_top3': 0.4805, 'zeroshot_top5': 0.5586, 'zeroshot_top10': 0.6681}
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_scheduled_clip_0.01_0.04_lr_2e-4_quad_crossmodal_and_unimodal_augmented/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for cifar100


Evaluating zeroshot transfer: 100%|██████████| 40/40 [00:05<00:00,  7.53it/s]


Results for model r_scheduled_clip_0.01_0.04_lr_2e-4_quad_crossmodal_and_unimodal_augmented on dataset cifar100: {'zeroshot_top1': 0.3937, 'zeroshot_top3': 0.5999, 'zeroshot_top5': 0.6801, 'zeroshot_top10': 0.786}
Validation dataset: mnist, number of samples: 10000
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_tau_0.01_lr_8e-4/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for mnist


Evaluating zeroshot transfer: 100%|██████████| 40/40 [00:05<00:00,  7.28it/s]


Results for model r_clip_tau_0.01_lr_8e-4 on dataset mnist: {'zeroshot_top1': 0.1393, 'zeroshot_top3': 0.3501, 'zeroshot_top5': 0.5798, 'zeroshot_top10': 1.0}
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_cos_0.01_0.05_lr_8e-4/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for mnist


Evaluating zeroshot transfer: 100%|██████████| 40/40 [00:05<00:00,  7.43it/s]


Results for model r_clip_cos_0.01_0.05_lr_8e-4 on dataset mnist: {'zeroshot_top1': 0.0984, 'zeroshot_top3': 0.3192, 'zeroshot_top5': 0.5563, 'zeroshot_top10': 1.0}
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_scheduled_clip_0.01_0.04_lr_2e-4_quad_crossmodal_and_unimodal_augmented/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for mnist


Evaluating zeroshot transfer: 100%|██████████| 40/40 [00:05<00:00,  7.31it/s]


Results for model r_scheduled_clip_0.01_0.04_lr_2e-4_quad_crossmodal_and_unimodal_augmented on dataset mnist: {'zeroshot_top1': 0.1068, 'zeroshot_top3': 0.3474, 'zeroshot_top5': 0.5306, 'zeroshot_top10': 1.0}
Validation dataset: sun397, number of samples: 108754
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_tau_0.01_lr_8e-4/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for sun397


Evaluating zeroshot transfer: 100%|██████████| 425/425 [06:39<00:00,  1.06it/s]


Results for model r_clip_tau_0.01_lr_8e-4 on dataset sun397: {'zeroshot_top1': 0.4182558802434853, 'zeroshot_top3': 0.6504220534417124, 'zeroshot_top5': 0.7376924067160748, 'zeroshot_top10': 0.8293947808816228}
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_cos_0.01_0.05_lr_8e-4/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for sun397


Evaluating zeroshot transfer: 100%|██████████| 425/425 [06:33<00:00,  1.08it/s]


Results for model r_clip_cos_0.01_0.05_lr_8e-4 on dataset sun397: {'zeroshot_top1': 0.4470180407157438, 'zeroshot_top3': 0.6702374165547934, 'zeroshot_top5': 0.7530573588097909, 'zeroshot_top10': 0.8377622892031558}
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_scheduled_clip_0.01_0.04_lr_2e-4_quad_crossmodal_and_unimodal_augmented/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for sun397


Evaluating zeroshot transfer: 100%|██████████| 425/425 [06:32<00:00,  1.08it/s]


Results for model r_scheduled_clip_0.01_0.04_lr_2e-4_quad_crossmodal_and_unimodal_augmented on dataset sun397: {'zeroshot_top1': 0.445942218217261, 'zeroshot_top3': 0.6813174687827575, 'zeroshot_top5': 0.7662522757783622, 'zeroshot_top10': 0.8537249204626957}
Validation dataset: fgvc-aircraft, number of samples: 3333
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_tau_0.01_lr_8e-4/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for fgvc-aircraft


Evaluating zeroshot transfer: 100%|██████████| 14/14 [00:12<00:00,  1.11it/s]


Results for model r_clip_tau_0.01_lr_8e-4 on dataset fgvc-aircraft: {'zeroshot_top1': 0.0111011101110111, 'zeroshot_top3': 0.0351035103510351, 'zeroshot_top5': 0.056405640564056406, 'zeroshot_top10': 0.12091209120912091}
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_cos_0.01_0.05_lr_8e-4/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for fgvc-aircraft


Evaluating zeroshot transfer: 100%|██████████| 14/14 [00:06<00:00,  2.02it/s]


Results for model r_clip_cos_0.01_0.05_lr_8e-4 on dataset fgvc-aircraft: {'zeroshot_top1': 0.0144014401440144, 'zeroshot_top3': 0.0429042904290429, 'zeroshot_top5': 0.06900690069006901, 'zeroshot_top10': 0.13081308130813082}
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_scheduled_clip_0.01_0.04_lr_2e-4_quad_crossmodal_and_unimodal_augmented/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for fgvc-aircraft


Evaluating zeroshot transfer: 100%|██████████| 14/14 [00:07<00:00,  1.94it/s]


Results for model r_scheduled_clip_0.01_0.04_lr_2e-4_quad_crossmodal_and_unimodal_augmented on dataset fgvc-aircraft: {'zeroshot_top1': 0.012601260126012601, 'zeroshot_top3': 0.0408040804080408, 'zeroshot_top5': 0.06930693069306931, 'zeroshot_top10': 0.13021302130213022}
Validation dataset: oxford-pets, number of samples: 3669
Load checkpoint from /BS/dduka/work/training_metadata/bimodal_cl/dhimitrios/r_clip_tau_0.01_lr_8e-4/checkpoint_best.pth
Keys in checkpoint model: dict_keys(['model', 'optimizer', 'lr_scheduler', 'args', 'epoch'])
===> Loading zeroshot transfer config for oxford-pets


FileNotFoundError: [Errno 2] No such file or directory: 'zeroshot_transfer/oxford-pets_classes.py'

# ViT Model

In [None]:
vit50_model = CLIP(
    **vit_args
)
print(vit50_model)

In [8]:
data_folder = "/BS/scratch/inf0/user/dduka"

In [10]:
datasets.FER2013(
    root=data_folder, split="PublicTest", download=True
)

TypeError: FER2013.__init__() got an unexpected keyword argument 'download'