In [1]:
import os
import json
import matplotlib.pyplot as plt
from scipy import stats
import pickle

from matplotlib import rcParams
import numpy as np
import seaborn as sns

# Global matplotlib styling for every figure in this notebook:
# serif fonts, dotted grid lines, inward-pointing ticks, enlarged labels/titles.
rcParams.update({
    "font.family": "serif",
    "grid.linestyle": ':',
    "xtick.direction": 'in',
    "ytick.direction": 'in',
    "legend.fontsize": 9,
    "axes.labelsize": 20,
    "axes.titlesize": 20,
    "xtick.labelsize": 15,
    "ytick.labelsize": 15,
})

## Locate input files

In [2]:
# Input directories (relative paths, resolved against the kernel's working directory):
# per-dataset retrieval results, and per-dataset concept search counts.
results_dir, search_counts_dir = '../retrieval_evaluations', '../search_counts'

In [3]:
# Inventory of result / count files on disk. os.listdir returns entries in arbitrary order,
# so both lists are sorted to keep the notebook's output reproducible across machines and runs.
# Note: save_plotting_info below builds its own file paths and does not read these lists.
ret_results = sorted(
    os.path.join(results_dir, name)
    for name in os.listdir(results_dir)
    if '_results.json' in name
)
rampp_07_search_counts = sorted(
    os.path.join(search_counts_dir, name)
    for name in os.listdir(search_counts_dir)
    if '0.7_' in name and 'integrated' in name
)

## Bin retrieval accuracy by (log) concept frequency

In [4]:
def _load_concept_accuracies(curr_model, pretrained_dataset, prompt_type,
                             downstream_datasets, counts_dir, res_dir):
    """Collect parallel lists of per-concept search counts and retrieval accuracies.

    For each downstream dataset this reads
    ``{counts_dir}/{dataset}_{pretrained_dataset}_integrated_tlemmatized_i0.7_search_counts.json``
    (a concept -> count mapping) and ``{res_dir}/{dataset}_{prompt_type}_results.json``.
    It selects the first results entry whose key contains both ``pretrained_dataset`` and
    ``curr_model``, then pairs every counted concept with that entry's ``'conceptwise'`` value.

    Returns:
        (counts, accuracies): parallel lists, ordered by dataset and then by sorted concept name.

    Raises:
        KeyError: if no results entry matches the model / pretraining dataset, or if a counted
            concept has no value in the matched entry.
    """
    counts, accuracies = [], []
    for dataset in downstream_datasets:
        res_path = os.path.join(res_dir, '{}_{}_results.json'.format(dataset, prompt_type))
        counts_path = os.path.join(
            counts_dir,
            '{}_{}_integrated_tlemmatized_i0.7_search_counts.json'.format(dataset, pretrained_dataset))

        with open(res_path, 'r') as f:
            zs_results = json.load(f)
        # First key (in file order) containing both identifiers as substrings wins. The model is
        # resolved afresh per dataset, so a miss can never silently reuse a previous dataset's model.
        req_model = next(
            (k for k in zs_results if pretrained_dataset in k and curr_model in k), None)
        if req_model is None:
            raise KeyError('No results entry for model {} / pretraining dataset {} in {}'.format(
                curr_model, pretrained_dataset, res_path))
        model_results = zs_results[req_model]['conceptwise']

        with open(counts_path, 'r') as f:
            concept_counts = json.load(f)
        for key in sorted(concept_counts):
            if key not in model_results:
                raise KeyError('Key mismatch {} for {}'.format(key, dataset))
            counts.append(concept_counts[key])
            accuracies.append(model_results[key])
    return counts, accuracies


def _bin_accuracies_by_log_count(counts, accuracies, num_bins):
    """Group accuracies into bins over log(count) and return per-bin mean and std.

    Returns:
        (edges, means, stds): ``edges`` holds ``num_bins`` evenly spaced points from the minimum
        to the maximum log-count. Binning uses ``np.digitize(..., right=True)``:
        - entry 0 holds concepts whose log-count equals the minimum;
        - entry i > 0 holds concepts with log-count in ``(edges[i-1], edges[i]]``.
        Bins that receive no concepts get NaN for both mean and std.
    """
    # A zero count would give log(0) = -inf; treat it as a count of 1 (log-count 0).
    log_counts = np.log([c if c > 0 else 1 for c in counts])
    edges = np.linspace(log_counts.min(), log_counts.max(), num=num_bins)
    bin_idx = np.digitize(log_counts, edges, right=True)
    accs = np.asarray(accuracies, dtype=float)

    # Size by the number of edges (not by the number of non-empty bins) so every index that
    # digitize can return (0 .. num_bins-1) is in range; empty bins stay NaN.
    means = np.full(len(edges), np.nan)
    stds = np.full(len(edges), np.nan)
    for b in np.unique(bin_idx):
        members = accs[bin_idx == b]
        means[b] = members.mean()
        stds[b] = members.std()
    return edges, means, stds


def save_plotting_info(curr_model, pretrained_dataset, prompt_type,
                       downstream_datasets=('coco', 'flickr'), num_bins=7,
                       counts_dir=None, res_dir=None, out_dir='./plots'):
    """Bin retrieval accuracy by log concept frequency and pickle the result for plotting.

    Concepts from all ``downstream_datasets`` are pooled before binning.

    Args:
        curr_model: architecture substring matched against results keys, e.g. ``'RN50'``.
        pretrained_dataset: pretraining-dataset substring, e.g. ``'cc3m'``.
        prompt_type: results-file tag / retrieval metric, e.g. ``'i2t_k=1'``.
        downstream_datasets: downstream datasets whose concepts are pooled.
        num_bins: number of evenly spaced bin edges over log(count); must be >= 2.
        counts_dir: directory of search-count JSONs; defaults to the notebook's ``search_counts_dir``.
        res_dir: directory of result JSONs; defaults to the notebook's ``results_dir``.
        out_dir: directory the pickle is written to (created if missing).

    Writes a pickle named
    ``log-linear-all-datasets-plot-counttype_integrated_rampp0.7_prompttype_{prompt_type}_ptdataset_{pretrained_dataset}_model_{curr_model}_retrieval.pkl``
    inside ``out_dir``, containing ``{'exp_bins': exp(edges), 'cum_means': means, 'cum_stds': stds}``.

    Returns:
        The path of the written pickle.

    Raises:
        ValueError: if ``num_bins < 2`` or no concepts were found.
        KeyError: if a model entry or a concept's value is missing from a results file.
    """
    if num_bins < 2:
        # With a single edge the maximum log-count would fall outside every bin.
        raise ValueError('num_bins must be >= 2, got {}'.format(num_bins))
    # Fall back to the notebook-level directories configured earlier.
    counts_dir = search_counts_dir if counts_dir is None else counts_dir
    res_dir = results_dir if res_dir is None else res_dir

    counts, accuracies = _load_concept_accuracies(
        curr_model, pretrained_dataset, prompt_type, downstream_datasets, counts_dir, res_dir)
    if not counts:
        raise ValueError('No concepts found for model {} / pretraining dataset {} / prompt type {}'.format(
            curr_model, pretrained_dataset, prompt_type))

    edges, means, stds = _bin_accuracies_by_log_count(counts, accuracies, num_bins)

    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(
        out_dir,
        'log-linear-all-datasets-plot-counttype_{}_prompttype_{}_ptdataset_{}_model_{}_retrieval.pkl'.format(
            'integrated_rampp0.7', prompt_type, pretrained_dataset, curr_model))
    # The context manager guarantees the file is flushed and closed.
    with open(out_path, 'wb') as f:
        pickle.dump({'exp_bins': np.exp(edges), 'cum_means': means, 'cum_stds': stds}, f)
    return out_path

In [None]:
# Architectures evaluated for each pretraining dataset; flattened below into the ordered
# list of (architecture, pretraining dataset) pairs.
models_per_pt_dataset = {
    'cc3m': ['RN50', 'ViT-B-16'],
    'cc12m': ['RN50', 'ViT-B-16'],
    'yfcc15m': ['RN50', 'RN101', 'ViT-B-16'],
    'synthci30m': ['ViT-B-16'],
    'laion200m_train_test_sim_normalized': ['ViT-B-32'],
    'laion400m': ['ViT-B-32', 'ViT-B-16', 'ViT-L-14'],
}
combinations = [
    (model, pt_dataset)
    for pt_dataset, models in models_per_pt_dataset.items()
    for model in models
]

# Metric tags (image->text and text->image, k in 1/5/10); each must match a
# '{dataset}_{tag}_results.json' file in the results directory.
metrics = ['{}_k={}'.format(direction, k) for direction in ('i2t', 't2i') for k in (1, 5, 10)]

for model_name, pt_dataset in combinations:
    for metric in metrics:
        save_plotting_info(model_name, pt_dataset, metric)