# Export Data for CLOOM Analysis
### This notebook exports the data for the CLOOM analysis that was carried out after the interactive article was accepted at VISxAI.

In [None]:
# Install Amumo from its GitHub repository; the next cell imports amumo.data,
# amumo.utils and amumo.model from it.
! pip install git+https://github.com/ginihumer/Amumo.git

In [None]:
import amumo
from amumo import data as am_data
from amumo import utils as am_utils
from amumo import model as am_model

In [None]:
# Data Helpers
def get_data_helper(dataset, filters=None, method=any):
    """Load (optionally filtered) data from an Amumo dataset and derive an export name.

    Args:
        dataset: Dataset object exposing a ``name`` attribute and
            ``get_filtered_data(filters, method=...)``.
        filters: Filter terms forwarded to ``dataset.get_filtered_data``.
            ``None`` (the default) means "no filters".
        method: Combinator forwarded to ``get_filtered_data`` (defaults to the
            builtin ``any``); its ``__name__`` is embedded in the dataset name
            when filters are given.

    Returns:
        Tuple ``(images, prompts, dataset_name)`` where ``dataset_name`` is
        ``<name>_filter-<method>_<f1-f2-...>`` when filters are given and
        ``<name>_size-<number of images>`` otherwise.
    """
    # A fresh list per call: a mutable default ([]) would be one shared object,
    # so any mutation of it would leak into every later call.
    if filters is None:
        filters = []

    all_images, all_prompts = dataset.get_filtered_data(filters, method=method)
    print(len(all_images))

    if filters:
        dataset_name = f"{dataset.name}_filter-{method.__name__}_{'-'.join(filters)}"
    else:
        dataset_name = f"{dataset.name}_size-{len(all_images)}"

    return all_images, all_prompts, dataset_name

In [None]:
import os
def create_dir_if_not_exists(dir):
    """Ensure directory ``dir`` exists, creating missing parents, and return ``dir``.

    Args:
        dir: Directory path to create.

    Returns:
        The same ``dir`` value, so calls can be chained into assignments.

    Raises:
        FileExistsError: if ``dir`` exists but is not a directory.
    """
    # makedirs also creates intermediate directories, and exist_ok=True avoids
    # the race between an os.path.exists() check and a later os.mkdir().
    os.makedirs(dir, exist_ok=True)
    return dir

In [None]:
# Root folder for every exported dataset. It is created up front, and the
# helper hands back its argument, so the path can be bound in one statement.
export_directory = create_dir_if_not_exists('./exported_data_checkpoints/')

def _save_images(images, images_dir, size=(400, 400)):
    """Save each image as ``<images_dir>/<index>.jpg`` after resizing it to ``size``."""
    for i, im in enumerate(images):
        # PIL's Image.resize returns a NEW image; its result must be kept or
        # the image is saved at its source resolution.
        im = im.resize(size)
        # JPEG cannot store alpha/palette modes (e.g. RGBA, P).
        if im.mode not in ('RGB', 'L'):
            im = im.convert('RGB')
        im.save('%s/%i.jpg' % (images_dir, i))


def _save_prompts(prompts, path):
    """Write one prompt per line (UTF-8) so that line i matches text embedding i."""
    with open(path, 'w', encoding='utf-8') as file:
        for prompt in prompts:
            # Line breaks inside a prompt would shift every following line out
            # of alignment with the text-embedding indices.
            file.write(' '.join(prompt.splitlines()) + '\n')


def _cluster_meta_info(image_embedding, text_embedding, logit_scale, similarity_image_text, prompts):
    """Return modality-gap distance, validation loss and similarity-cluster info."""
    import numpy as np

    meta_info = {
        'gap_distance': float(am_utils.get_modality_distance(image_embedding, text_embedding)),
        'loss': float(am_utils.calculate_val_loss(image_embedding, text_embedding, logit_scale.exp())),
    }

    idcs, clusters, clusters_unsorted = am_utils.get_cluster_sorting(similarity_image_text)

    # np.unique gives the cluster ids in deterministic ascending order together
    # with their sizes; iterating a set only happens to be ordered for small
    # non-negative ints.
    cluster_ids, counts = np.unique(clusters, return_counts=True)
    cluster_sizes = [int(count) for count in counts]
    # Labels are computed from the member indices in the *unsorted* order.
    # (am_utils.get_textual_label_for_cluster is an alternative labeller.)
    cluster_labels = [
        prompts.getMinSummary(np.where(clusters_unsorted == c)[0])
        for c in cluster_ids
    ]

    meta_info['cluster_sort_idcs'] = idcs.tolist()
    meta_info['cluster_sort_idcs_reverse'] = np.argsort(idcs).tolist()
    meta_info['cluster_sizes'] = cluster_sizes
    meta_info['cluster_labels'] = cluster_labels
    return meta_info


def _project_2d(embedding, method):
    """Project ``embedding`` to 2-D with 'PCA', 'UMAP' or 'TSNE' (openTSNE)."""
    from sklearn.decomposition import PCA
    from openTSNE import TSNE
    from umap import UMAP

    if method == 'PCA':
        return PCA(n_components=2).fit_transform(embedding)
    if method == 'UMAP':
        return UMAP(n_components=2, metric='cosine', random_state=31415).fit_transform(embedding)
    if method == 'TSNE':
        # openTSNE's fit() returns the embedding array itself.
        return TSNE(n_components=2, metric='cosine', random_state=31415).fit(embedding)
    raise ValueError('Unknown projection method: %s' % method)


def export_data(dataset_name, images, prompts, models):
    """Export images, prompts, similarities, cluster metadata and 2-D projections.

    Everything is written below ``export_directory + dataset_name``
    (``export_directory`` is the module-level export root):

    * ``images/<i>.jpg``: each image, resized to 400x400.
    * ``prompts.txt``: one prompt per line, in the order of ``prompts``.
    * ``similarities/<model><mode>.csv``: the similarity matrix returned by
      ``am_utils.get_similarity``.
    * ``similarities/<model><mode>_meta_info.json``: modality-gap distance,
      validation loss, and cluster sorting indices/sizes/labels.
    * ``projections.csv``: PCA, UMAP and t-SNE coordinates of the joint
      image+text embedding, with one row per image followed by one row per
      prompt.

    ``<mode>`` is ``''`` for the raw embeddings and ``'_nogap'`` for the
    embeddings returned by ``am_utils.get_closed_modality_gap``.

    Args:
        dataset_name: Sub-directory name under ``export_directory``.
        images: Sequence of images; assumed to be PIL images (``resize``,
            ``convert``, ``mode`` and ``save`` are used).
        prompts: Amumo prompt collection; iterated as strings and queried via
            ``getMinSummary`` for cluster labels.
        models: Iterable of Amumo models exposing ``model_name``.
    """
    import json
    import numpy as np
    import pandas as pd
    import torch

    # Folder structure.
    dataset_directory = create_dir_if_not_exists(export_directory + dataset_name)
    images_dir = create_dir_if_not_exists(dataset_directory + '/images')
    similarities_dir = create_dir_if_not_exists(dataset_directory + '/similarities')

    _save_images(images, images_dir)
    _save_prompts(prompts, dataset_directory + '/prompts.txt')

    n_images, n_prompts = len(images), len(prompts)
    projections_df = pd.DataFrame({
        'emb_id': list(range(n_images)) + list(range(n_prompts)),
        'data_type': ['image'] * n_images + ['text'] * n_prompts,
    })

    for model in models:
        image_embedding_gap, text_embedding_gap, logit_scale = am_utils.get_embedding(model, dataset_name, images, prompts)
        image_embedding_nogap, text_embedding_nogap = am_utils.get_closed_modality_gap(image_embedding_gap, text_embedding_gap)

        variants = (
            (image_embedding_gap, text_embedding_gap, ''),
            (image_embedding_nogap, text_embedding_nogap, '_nogap'),
        )
        for image_embedding, text_embedding, mode in variants:
            file_prefix = '%s/%s%s' % (similarities_dir, model.model_name, mode)

            # Similarities.
            similarity_image_text, similarity = am_utils.get_similarity(image_embedding, text_embedding)
            np.savetxt(file_prefix + '.csv', similarity, delimiter=',')

            # Meta information and similarity clustering.
            meta_info = _cluster_meta_info(image_embedding, text_embedding, logit_scale, similarity_image_text, prompts)
            with open(file_prefix + '_meta_info.json', 'w', encoding='utf-8') as file:
                json.dump(meta_info, file)

            # Projections of the joint embedding. detach/cpu makes the
            # conversion work for GPU tensors and tensors that require grad.
            embedding = torch.cat([image_embedding, text_embedding]).detach().cpu().numpy()
            for method in ('PCA', 'UMAP', 'TSNE'):
                low_dim_data = _project_2d(embedding, method)
                column = '%s%s_%s' % (model.model_name, mode, method)
                projections_df[column + '_x'] = low_dim_data[:, 0]
                projections_df[column + '_y'] = low_dim_data[:, 1]

    projections_df.to_csv(dataset_directory + '/projections.csv')

In [None]:
# Local data root. NOTE(review): not referenced by any cell shown here;
# presumably kept for datasets loaded from disk. Confirm or remove.
data_path = '../../../../../Data/'

# TODO: update to your dataset
# Load the DiffusionDB subset "2m_first_1k" (the name suggests the first 1k
# samples of DiffusionDB 2M; verify) without filters. With no filters,
# get_data_helper names the export "<dataset name>_size-<n images>".
dataset_diffusiondb = am_data.DiffusionDB_Dataset(path="2m_first_1k", batch_size=100)
diffusiondb_images, diffusiondb_prompts, diffusiondb_dataset_name = get_data_helper(dataset_diffusiondb)
# TODO: update to your model
# Export images, prompts, similarities, meta info and projections for each model in the list.
export_data(diffusiondb_dataset_name, diffusiondb_images, diffusiondb_prompts, [am_model.CLIPModel()])
