In [3]:
import os
import json
import matplotlib.pyplot as plt
from scipy import stats
import pickle

from matplotlib import rcParams
import numpy as np
import seaborn as sns

# Notebook-wide matplotlib styling: serif font, dotted grid lines, inward
# ticks, and the font sizes used for legends, axis labels, titles and ticks.
rcParams.update({
    "font.family": "serif",
    "grid.linestyle": ':',
    "xtick.direction": 'in',
    "ytick.direction": 'in',
    "legend.fontsize": 9,
    "axes.labelsize": 20,
    "axes.titlesize": 20,
    "xtick.labelsize": 15,
    "ytick.labelsize": 15,
})

In [4]:
# Input locations, relative to the notebook's working directory:
#   results_dir        -- JSON files with per-model evaluation results
#   search_counts_dir  -- JSON files with per-class search counts
results_dir, search_counts_dir = '../t2i_evaluations', '../search_counts'

In [5]:
# Every entry of the results directory whose name contains '.json'.
t2i_results = []
for entry in os.listdir(results_dir):
    if '.json' in entry:
        t2i_results.append(os.path.join(results_dir, entry))

# Every entry of the search-counts directory whose name mentions both
# 'integrated' and 't2i'.
search_counts = []
for entry in os.listdir(search_counts_dir):
    if 'integrated' in entry and 't2i' in entry:
        search_counts.append(os.path.join(search_counts_dir, entry))

In [26]:
def save_plotting_info(curr_model, pretrained_dataset, prompt_type):

    datasets_included = [
        't2i'
    ]
    
    lemmatized_text_counts = []
    accuracies = []
    classnames = []
    
    for downstream_dataset in datasets_included:

        lemmatized_path = os.path.join(search_counts_dir, '{}_{}_integrated_tlemmatized_i0.7_search_counts.json'.format(downstream_dataset, pretrained_dataset))
    
        dataset = lemmatized_path.split('/')[-1].split('_')[0]
    
        res_path = os.path.join(results_dir, '{}_{}.json'.format(dataset, prompt_type))

        # load results
        with open(res_path, 'r') as f:
            zs_results = json.load(f)

        # for now just take the first model that satisfies the pt_dataset we have
        req_model = curr_model
        model_results = zs_results[req_model]['classwise']
    
        # load lowercased and lemmatized counts
        with open(lemmatized_path, 'r') as f:    
            lemmatized_json = json.load(f)

        intersecting_keys = set(list(lemmatized_json.keys())).intersection(list(set(model_results.keys())))

        for key in sorted(intersecting_keys):
            assert key in lemmatized_json, 'Key mismatch {} for {}'.format(key, dataset)
            if(len(model_results[key])<=5):
                continue
            classnames.append(key)
            lemmatized_text_counts.append(lemmatized_json[key])
            accuracies.append(np.mean(model_results[key]))

    lemmatized_text_counts_positive = [x if x > 0 else 1 for x in lemmatized_text_counts]

    x_vals = np.log(lemmatized_text_counts_positive)

    nb = 7
    
    bins = np.linspace(min(x_vals), max(x_vals), num=nb)
    assigned_bins = np.digitize(x_vals, bins, right=True)
    
    cumsums = [0]*len(bins)
    cumcounts = [0]*len(bins)
    cumarrs = {ab:[] for ab in assigned_bins}
    for acc, xv, ab in zip(accuracies, x_vals, assigned_bins):
        cumsums[ab] += acc
        cumcounts[ab] += 1
        cumarrs[ab].append(acc)
    cumaccs = [s/c if c > 0 else 0 for s, c in zip(cumsums, cumcounts)]

    cummeans = np.zeros(nb)
    cumstds = np.zeros(nb)
    
    for key in cumarrs:
        cummeans[key] = np.mean(cumarrs[key])
        cumstds[key] = np.std(cumarrs[key])
    
    os.makedirs('./plots', exist_ok=True)
    pickle.dump({'exp_bins': np.exp(bins), 'cum_means': cummeans, 'cum_stds': cumstds}, open('./plots/log-linear-all-datasets-plot-counttype_{}_prompttype_{}_ptdataset_{}_model_{}_t2i.pkl'.format('integrated_rampp0.7', prompt_type, pretrained_dataset, curr_model), 'wb'))

In [None]:
# Model identifiers to process; each one is passed to save_plotting_info as
# ``curr_model``.
combinations = [
    'huggingface_openjourney-v1-0',
    'DeepFloyd_IF-I-L-v1.0',
    'huggingface_promptist-stable-diffusion-v1-4',
    'craiyon_dalle-mini',
    'huggingface_openjourney-v2-0',
    'adobe_giga-gan',
    'huggingface_vintedois-diffusion-v0-1',
    'huggingface_stable-diffusion-safe-strong',
    'DeepFloyd_IF-I-M-v1.0',
    'huggingface_stable-diffusion-safe-weak',
    'huggingface_stable-diffusion-v1-5',
    'huggingface_dreamlike-diffusion-v1-0',
    'huggingface_stable-diffusion-safe-medium',
    'huggingface_redshift-diffusion',
    'craiyon_dalle-mega',
    'lexica_search-stable-diffusion-1.5',
    'AlephAlpha_m-vader',
    'huggingface_stable-diffusion-v2-base',
    'DeepFloyd_IF-I-XL-v1.0',
    'huggingface_stable-diffusion-safe-max',
    'huggingface_stable-diffusion-v1-4',
    'kakaobrain_mindall-e',
    'huggingface_stable-diffusion-v2-1-base',
    'huggingface_dreamlike-photoreal-v2-0',
]

# Metric names; each one is passed to save_plotting_info as ``prompt_type``.
metrics = [
    'exp_aesthetics',
    'exp_clip',
    'human_aesthetics',
    'human_align',
    'max_clip',
    'max_aesthetics',
]

# Write one pickle per (model, metric) pair, always against the
# 'laion_aesthetics' pretraining dataset.
for model_name in combinations:
    for metric_name in metrics:
        save_plotting_info(
            curr_model=model_name,
            pretrained_dataset='laion_aesthetics',
            prompt_type=metric_name,
        )