In [1]:
from pytorch_lightning.trainer.trainer import Trainer
import itertools
from functools import partial

import sys
# Make the parent repo and the `deep-learning-base` sub-packages importable
# without installing them; presumably the non-stdlib imports that follow
# (training, architectures, data_modules, ...) are resolved via these entries.
paths_to_add = ['..', '../deep-learning-base'] + [
    '../deep-learning-base/' + sub_package
    for sub_package in ('training', 'datasets', 'architectures',
                        'attack', 'self_supervised')
]
# Append only entries that are missing so re-running the cell does not keep
# growing sys.path with duplicates.
for extra_path in paths_to_add:
    if extra_path in sys.path:
        continue
    sys.path.append(extra_path)

from training import LitProgressBar, NicerModelCheckpointing
import training.finetuning as ft
import architectures as arch
from architectures.callbacks import LightningWrapper, MultimodalEvalWrapper
from training.trainer_callback import ZeroShotCallback
from data_modules import DATA_MODULES
import dataset_metadata as dsmd
from partially_inverted_reps import DATA_PATH_IMAGENET, DATA_PATH

In [2]:
# --- Hardware / throughput ----------------------------------------------------
ACCELERATOR = "gpu"
DEVICES = [1]        # list form = explicit device indices in Lightning (GPU 1)
BATCH_SIZE = 512

# --- Zero-shot evaluation setup -----------------------------------------------
# Dataset entry whose metadata supplies normalization mean/std and test-time
# transforms for every evaluation dataset.
BASE_DATASET = 'clip'
# Datasets evaluated zero-shot.
EVAL_DATASETS = ['cifar10', 'cifar100', 'oxford-iiit-pets', 'flowers', 'stl10']
# Prompt templates passed to the model via `_set_class_prompt`; note the
# trailing spaces — presumably the class name is appended directly (confirm).
CLASS_PROMPTS = [
    'a photo of a ',
    'this is a photo of a ',
    'this is a photo of ',
    'the following is a photo of ',
    'this is a ',
]
# CLIP backbones to sweep over.
MODELS = ['resnet50', 'resnet101', 'vit_base_patch32_224', 'vit_base_patch16_224']

In [3]:
# Zero-shot evaluation sweep: for every CLIP backbone in MODELS, run
# `Trainer.predict` on the val and test loaders of every evaluation dataset,
# once per class prompt in CLASS_PROMPTS.
for model in MODELS:
    # Multimodal CLIP model wrapped for evaluation; normalization and
    # transforms come from the BASE_DATASET metadata entry.
    m1 = arch.create_model(model, BASE_DATASET,
                           pretrained=True, checkpoint_path='',
                           num_classes=dsmd.DATASET_PARAMS[BASE_DATASET]['num_classes'],
                           callback=partial(MultimodalEvalWrapper,
                                            dataset_name=BASE_DATASET),
                           multimodal_clip=True)
    for eval_dataset in EVAL_DATASETS:
        # The data module depends only on the dataset, not on the prompt, so
        # build it once per dataset rather than once per (dataset, prompt) pair.
        # Both splits use the test-time transform because this is evaluation only.
        dm = DATA_MODULES[eval_dataset](
            data_dir=DATA_PATH_IMAGENET if 'imagenet' in eval_dataset else DATA_PATH,
            transform_train=dsmd.DATASET_PARAMS[BASE_DATASET]['transform_test'],
            transform_test=dsmd.DATASET_PARAMS[BASE_DATASET]['transform_test'],
            batch_size=BATCH_SIZE)
        dm.init_remaining_attrs(BASE_DATASET)

        for class_prompt in CLASS_PROMPTS:
            # Keep the original per-run call order (classes, then prompt): the
            # wrapper internals are not visible here, so whether one setter
            # depends on state set by the other cannot be ruled out.
            m1._set_classes(dm.train_ds.classes)
            m1._set_class_prompt(class_prompt)

            # Use the configured hardware instead of hard-coded values:
            # DEVICES = [1] selects GPU index 1, whereas `devices=1` means
            # "use one device" (whichever Lightning picks first).
            t = Trainer(accelerator=ACCELERATOR, devices=DEVICES,
                        deterministic=True, num_sanity_val_steps=0)
            # Predict with the model built above. The previous code passed
            # `self`, which is undefined at notebook top level (NameError on a
            # fresh kernel, or a stale object from earlier kernel state).
            out = t.predict(m1,
                            dataloaders=[dm.val_dataloader(),
                                         dm.test_dataloader()])
            # With two dataloaders Lightning returns one prediction list per loader.
            print(model, eval_dataset, repr(class_prompt), len(out))

Global seed set to 0


KeyError: 'loss'