In [1]:
import os
import sys
import warnings

import torch
import yaml

# Silence noisy library warnings. The filters are installed before the project
# imports below so that warnings raised while importing `tart` are suppressed too.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Make the parent of the notebook's working directory importable so the local
# `tart` package resolves without installation. Guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
REPO_ROOT = os.path.dirname(os.getcwd())
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

from tart.tart_modules import Tart  # noqa: E402  (requires REPO_ROOT on sys.path)
from tart.registry import DATASET_REGISTRY  # noqa: E402

In [2]:
#### CUSTOMIZE AS NEEDED ####
# Root of the local TART checkout. Override via the TART_ROOT environment variable
# instead of editing the absolute paths below; the default keeps the original layout.
TART_ROOT = os.environ.get("TART_ROOT", "/home/blee/code/TART")

# Directory holding the pretrained TART head; `model_chkpt` is the checkpoint file inside it.
path_tart_weights = f"{TART_ROOT}/tart_heads/head_d16_s258_m10_bs128_bd32_vno_false/"
model_chkpt = "model_24000.pt"
cache_dir = f"{TART_ROOT}/cache"
path_tart_config = 'tart_conf.yaml'  # if you are using the pre-trained module above, don't change this!
data_dir_path = f"{TART_ROOT}/data"

In [3]:
# One embedding model per modality. Per the `evaluate_modality_idx` comment used when
# building the Tart module, index 0 is the image modality and index 1 the text modality.
# For a single-modality run, keep only one entry here (e.g. ["EleutherAI/gpt-neo-125m"]
# or ["google/vit-large-patch16-224-in21k"]) and shorten EMBED_METHODS to match.
EMBED_MODELS = ["google/vit-large-patch16-224-in21k", "EleutherAI/gpt-neo-125m"]
EMBED_METHODS = ["stream", "stream"]  # one embedding method per entry in EMBED_MODELS
assert len(EMBED_METHODS) == len(EMBED_MODELS), "need exactly one embed method per embed model"

# Build the checkpoint path from the configured file name rather than a hard-coded copy,
# so changing `model_chkpt` takes effect; os.path.join also avoids a '//' from the
# trailing slash on `path_tart_weights`.
PATH_TO_PRETRAINED_HEAD = os.path.join(path_tart_weights, model_chkpt)

# Close the config file deterministically instead of leaking the handle.
# NOTE(review): FullLoader is kept for compatibility; switch to yaml.safe_load if the
# config could ever come from an untrusted source.
with open(path_tart_config, "r") as config_file:
    TART_CONFIG = yaml.load(config_file, Loader=yaml.FullLoader)

# NOTE(review): the `- 2` offset on the head's n_positions is inherited as-is;
# confirm why two positions are reserved before changing it.
TOTAL_TRAIN_SAMPLES = TART_CONFIG['n_positions'] - 2
PATH_TO_FINETUNED_EMBED_MODEL = None
CACHE_DIR = cache_dir
NUM_PCA_COMPONENTS = 16

In [4]:
# Assemble the multimodal TART module from the configuration cells: the per-modality
# embedding models/methods, combined by averaging, feeding the pretrained TART head
# found at PATH_TO_PRETRAINED_HEAD.
tart_module = Tart(
    # --- embedding side ---
    embed_model_names=EMBED_MODELS,
    embed_methods=EMBED_METHODS,
    path_to_finetuned_embed_model=PATH_TO_FINETUNED_EMBED_MODEL,
    combination_method="average",
    num_pca_components=NUM_PCA_COMPONENTS,
    cache_dir=CACHE_DIR,
    # --- TART head ---
    path_to_pretrained_head=PATH_TO_PRETRAINED_HEAD,
    tart_head_config=TART_CONFIG,
    # Modality to evaluate: 0 for image, 1 for text, None for both.
    evaluate_modality_idx=None,
)

loading model


In [5]:
# Which registered dataset to load: registry domain first, then the dataset name.
DOMAIN = "multi_image_text"
DATASET_NAME = "red_caps"

# In-context example counts to evaluate at.
k_range = [18, 32, 48, 64]
# NOTE(review): not referenced by any later cell in this notebook.
max_eval_samples = 1000

# Class indices passed to the dataset as the positive / negative class.
pos_class, neg_class = 0, 1

In [6]:
# Look up the dataset class registered under DOMAIN / DATASET_NAME and instantiate it.
dataset_cls = DATASET_REGISTRY[DOMAIN][DATASET_NAME]
dataset = dataset_cls(
    data_dir_path=data_dir_path,
    cache_dir=CACHE_DIR,
    pos_class=pos_class,
    neg_class=neg_class,
    total_train_samples=TOTAL_TRAIN_SAMPLES,
    k_range=k_range,
    seed=0,
)

# `get_dataset` is read as an attribute (no call), so it is presumably a property;
# it unpacks into the train/test splits consumed by the evaluation cell.
X_train, y_train, X_test, y_test = dataset.get_dataset

Found cached dataset red_caps (/home/blee/code/TART/cache/red_caps/blacksmith/1.0.0/d0d70a901e22f5e3b9a7af1f96f31c6243589705a5ab782b9ac69fcf727d97be)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset red_caps (/home/blee/code/TART/cache/red_caps/dogpictures/1.0.0/d0d70a901e22f5e3b9a7af1f96f31c6243589705a5ab782b9ac69fcf727d97be)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /home/blee/code/TART/cache/red_caps/blacksmith/1.0.0/d0d70a901e22f5e3b9a7af1f96f31c6243589705a5ab782b9ac69fcf727d97be/cache-c06d30dbf1df2152.arrow
Loading cached shuffled indices for dataset at /home/blee/code/TART/cache/red_caps/dogpictures/1.0.0/d0d70a901e22f5e3b9a7af1f96f31c6243589705a5ab782b9ac69fcf727d97be/cache-99c89bd6d71725a6.arrow


Downloading RedCaps images (this may take a while)...)


100%|██████████████████████████| 800/800 [00:19<00:00, 40.34it/s]


In [7]:
# Evaluate the TART module once per in-context budget k and keep each result dict.
# Autograd is disabled because this loop only runs inference.
# (A `text_threshold=1000` kwarg was previously tried here; it is left at its default.)
results_at_k = {}
with torch.no_grad():
    for k in k_range:
        result = tart_module.evaluate(
            k=k,
            seed=0,
            X_train=X_train,
            y_train=y_train,
            X_test=X_test,
            y_test=y_test,
        )
        results_at_k[k] = result
        accuracy = result["accuracy"]
        print(f"Accuracy at {k} samples: {accuracy}")


Embedding ICL examples...
Embedding modality 0


100%|████████████████████████████| 18/18 [00:00<00:00, 27.50it/s]
100%|██████████████████████████| 591/591 [00:06<00:00, 89.36it/s]


Embedding modality 1
Combining embeddings
Predicting labels...


591it [00:03, 170.93it/s]


Accuracy at 18 samples: 0.7411167512690355
Embedding ICL examples...
Embedding modality 0


100%|████████████████████████████| 32/32 [00:00<00:00, 90.28it/s]
100%|██████████████████████████| 591/591 [00:06<00:00, 89.13it/s]


Embedding modality 1
Combining embeddings
Predicting labels...


591it [00:03, 170.34it/s]


Accuracy at 32 samples: 0.7749576988155669
Embedding ICL examples...
Embedding modality 0


100%|████████████████████████████| 48/48 [00:00<00:00, 90.42it/s]
100%|██████████████████████████| 591/591 [00:06<00:00, 89.12it/s]


Embedding modality 1
Combining embeddings
Predicting labels...


591it [00:03, 172.39it/s]


Accuracy at 48 samples: 0.8206429780033841
Embedding ICL examples...
Embedding modality 0


100%|████████████████████████████| 64/64 [00:00<00:00, 90.35it/s]
100%|██████████████████████████| 591/591 [00:06<00:00, 89.08it/s]


Embedding modality 1
Combining embeddings
Predicting labels...


591it [00:03, 172.64it/s]

Accuracy at 64 samples: 0.8815566835871405



