In [1]:
import json
import numpy as np
import os
from tqdm import tqdm
import copy

In [2]:
PROMPT_PATH = "../prompts"

## Create Subprompts

In [3]:
datasets = {
    "dtd": [2, 4, 8, 16, 32, 46, 60], 
    "fgvc_aircraft": [2, 4, 8, 14, 20], 
    "sun397": [2, 4, 8, 16, 23, 30], 
    "flowers": [2, 4, 8, 14, 20],
    "imagenet1k": [16, 32, 64, 100],
}

n_seeds = 10

In [4]:
for dataset in ["imagenet1k"]:
    with open(f"../prompts/{dataset}/{dataset}_llama3_prompts_full.json", "r") as f:
        prompt_dict = json.load(f)[dataset]

    for sample_size in datasets[dataset]:
        for seed in tqdm(range(n_seeds)):
            # output = {classname.replace("(", "").replace(")", ""): [] for classname in prompt_dict}
            output = {classname: [] for classname in prompt_dict}
            for c, classname in enumerate(prompt_dict):
                idx = np.random.RandomState(seed + c).choice(datasets[dataset][-1], sample_size, replace=False)
                for i in idx:
                    # output[classname.replace("(", "").replace(")", "")].append(prompt_dict[classname][i])
                    output[classname].append(prompt_dict[classname][i])
            fp = os.path.join(PROMPT_PATH, f"{dataset}/sample_size_{sample_size:02d}_seed_{seed}.json") 
            with open(fp, 'w') as file:
                json.dump({dataset: output}, file, indent=4)

100%|██████████| 10/10 [00:02<00:00,  4.94it/s]
100%|██████████| 10/10 [00:02<00:00,  4.55it/s]
100%|██████████| 10/10 [00:02<00:00,  3.68it/s]
100%|██████████| 10/10 [00:03<00:00,  3.05it/s]


## Create ImageNet-1k Prompts

In [13]:
class_path = os.path.join(PROMPT_PATH, "imagenet1k/classes")
prompt_files = [f for f in os.listdir(class_path) if os.path.isfile(os.path.join(class_path, f))]
classnames = [f.split("imagenet1k_llama3_prompts_")[1].split(".json")[0] for f in prompt_files]
classnames

['coral_reef',
 'Bedlington Terrier',
 'water buffalo',
 'forklift',
 'african grey parrot',
 'chambered nautilus',
 'shipwreck',
 'diaper',
 'Bullmastiff',
 'greenhouse',
 'reflex_camera',
 'Pembroke Welsh Corgi',
 'chain_mail',
 'slide_rule',
 'tights',
 'cardoon',
 'hammerhead shark',
 'hot_tub',
 'harvestman',
 'shower_curtain',
 'measuring_cup',
 'grasshopper',
 'croquet_ball',
 'vacuum_cleaner',
 'Ibizan Hound',
 'tent',
 'gas_pump',
 'European garden spider',
 'mitten',
 'Basset Hound',
 'Shetland Sheepdog',
 'mosque',
 'German Shepherd Dog',
 'water snake',
 'spiral_or_coil',
 'analog clock',
 'sea snake',
 'laptop_computer',
 'hornbill',
 'hand-held_computer',
 'drilling_rig',
 'mashed_potatoes',
 'Standard Poodle',
 'snowplow',
 'ladle',
 'barbell',
 'eraser',
 'pretzel',
 'kite bird of prey',
 'overskirt',
 'submarine',
 'scarf',
 'hair_clip',
 'Bouvier des Flandres dog',
 'home_theater',
 'eggnog',
 'brain coral',
 'beach',
 'brass_memorial_plaque',
 'bagel',
 'sturgeon',
 

In [14]:
output = {"imagenet1k": {class_: [] for class_ in classnames}}

In [None]:
for class_ in classnames:
    with open(os.path.join(class_path, f"imagenet1k_llama3_prompts_{class_}.json"), 'r') as f:
        prompts = json.load(f)[class_]
    output["imagenet1k"][class_] = prompts
    

In [17]:
with open(os.path.join(PROMPT_PATH, f"imagenet1k/imagenet1k_llama3_prompts_full.json"), 'w') as f:
    json.dump(output, f)

## Match ImageNet-Prompts

In [3]:
with open(os.path.join(PROMPT_PATH, f"imagenet1k/imagenet1k_prompts_full.json"), 'r') as f:
    prompts_open_clip = json.load(f)["imagenet1k"]

with open(os.path.join(PROMPT_PATH, f"imagenet1k/imagenet1k_llama3_prompts_full.json"), 'r') as f:
    prompts_llama = json.load(f)["imagenet1k"]

# prompts_llama_temp = copy.deepcopy(prompts_llama)

In [6]:
list_llama = sorted(list(prompts_llama.keys()))
list_llama

['Affenpinscher',
 'Afghan Hound',
 'African bush elephant',
 'African rock python',
 'African wild dog',
 'Airedale Terrier',
 'Alaskan Malamute',
 'Alaskan tundra wolf',
 'Alpine ibex',
 'American Staffordshire Terrier',
 'American alligator',
 'American black bear',
 'American bullfrog',
 'American coot',
 'American dipper',
 'American lobster',
 'American robin',
 'Angora rabbit',
 'Appenzeller Sennenhund',
 'Arctic fox',
 'Asian elephant',
 'Australian Kelpie',
 'Australian Silky Terrier',
 'Australian Terrier',
 'Band-Aid',
 'Basenji',
 'Basset Hound',
 'Beagle',
 'Bedlington Terrier',
 'Bernese Mountain Dog',
 'Black and Tan Coonhound',
 'Bloodhound',
 'Bluetick Coonhound',
 'Border Collie',
 'Border Terrier',
 'Boston Terrier',
 'Bouvier des Flandres dog',
 'Boxer',
 'Briard',
 'Brittany dog',
 'Bullmastiff',
 'CD player',
 'CRT monitor',
 'Cairn Terrier',
 'Cardigan Welsh Corgi',
 'Carolina anole',
 'Chesapeake Bay Retriever',
 'Chihuahua',
 'Chow Chow',
 'Christmas stocking',

In [16]:
list_open_clip = sorted(list(prompts_open_clip.keys()))
list_llama = sorted(list(prompts_llama.keys()))

for p_true, p_curr in zip(list_open_clip, list_llama):
    if p_true != p_curr:
        prompts_llama[p_true] = prompts_llama.pop(p_curr)

In [17]:
list_llama_new = sorted(list(prompts_llama.keys()))

count = 0
for p_true, p_curr in zip(list_open_clip, list_llama_new):
    if p_true != p_curr:
        count += 1

count

0

In [18]:
with open(os.path.join(PROMPT_PATH, f"imagenet1k/imagenet1k_llama3_prompts_full.json"), 'w') as f:
    json.dump({"imagenet1k": prompts_llama}, f, indent=2)