# Export Data for ImageBind Analysis
### This notebook exports data for the ImageBind analysis done after the interactive article was accepted by VISxAI. 

In [None]:
! pip install git+https://github.com/ginihumer/Amumo.git

In [3]:
import amumo
from amumo import data as am_data
from amumo import utils as am_utils
from amumo import model as am_model

In [4]:
import os
def create_dir_if_not_exists(dir):
    if not os.path.exists(dir):
        os.mkdir(dir)
    return dir

In [5]:
export_directory = './exported_data_checkpoints/'
create_dir_if_not_exists(export_directory)

'./exported_data_checkpoints/'

### Text-Image

In [34]:


def export_data(dataset_name, images, prompts, models):

    # create folder structure
    dataset_directory = create_dir_if_not_exists(export_directory + dataset_name)
    similarities_dir = create_dir_if_not_exists(dataset_directory + '/similarities')

    # export projections and similarities
    import torch
    from sklearn.decomposition import PCA
    from openTSNE import TSNE
    from umap import UMAP
    import numpy as np
    import pandas as pd
    import json

    # if there already exists a dataset with projections from prior exports, load it
    if not os.path.exists(dataset_directory + '/projections.csv'):
        projections_df = pd.DataFrame({'emb_id': list(np.arange(0,len(images),1))+list(np.arange(0,len(prompts),1)), 'data_type':['image']*len(images)+['text']*len(prompts)})
    else:
        projections_df = pd.read_csv(dataset_directory + '/projections.csv')
    

    for model in models:
        # compute embeddings
        image_embedding_gap, text_embedding_gap, logit_scale = am_utils.get_embedding(model, dataset_name, images, prompts)
        image_embedding_nogap, text_embedding_nogap = am_utils.get_closed_modality_gap(image_embedding_gap, text_embedding_gap)
        
        for image_embedding, text_embedding, mode in [(image_embedding_gap, text_embedding_gap, ''), (image_embedding_nogap, text_embedding_nogap, '_nogap')]:
            
            # compute similarities
            similarity_image_text, similarity = am_utils.get_similarity(image_embedding, text_embedding)
            np.savetxt('%s/%s%s.csv'%(similarities_dir,model.model_name,mode), similarity, delimiter=',')
            
            # compute meta information and similarity clustering
            meta_info = {}
            meta_info['gap_distance'] = float(am_utils.get_modality_distance(image_embedding, text_embedding))
            meta_info['loss'] = float(am_utils.calculate_val_loss(image_embedding, text_embedding, logit_scale.exp()))

            idcs, clusters, clusters_unsorted = am_utils.get_cluster_sorting(similarity_image_text)
            cluster_labels = []
            cluster_sizes = []
            for c in set(clusters):
                cluster_size = int(np.count_nonzero(clusters==c))
                cluster_label = am_utils.get_textual_label_for_cluster(np.where(clusters_unsorted==c)[0], prompts)
                cluster_labels.append(cluster_label)
                cluster_sizes.append(cluster_size)

            idcs_reverse = np.argsort(idcs)
            meta_info['cluster_sort_idcs'] = idcs.tolist()
            meta_info['cluster_sort_idcs_reverse'] = idcs_reverse.tolist()
            meta_info['cluster_sizes'] = cluster_sizes
            meta_info['cluster_labels'] = cluster_labels
            # print(meta_info)

            with open("%s/%s%s_meta_info.json"%(similarities_dir, model.model_name, mode), "w") as file:
                json.dump(meta_info, file)

            # compute projections
            embedding = np.array(torch.concatenate([image_embedding, text_embedding]))

            projection_methods = {
                'PCA': PCA,
                'UMAP': UMAP,
                'TSNE': TSNE
            }
            for method in projection_methods.keys():
                if method == 'PCA':
                    proj = projection_methods[method](n_components=2)
                else:
                    proj = projection_methods[method](n_components=2, metric='cosine', random_state=31415)
                
                if method == 'TSNE':
                    low_dim_data = proj.fit(embedding)
                else:
                    low_dim_data = proj.fit_transform(embedding)
                
                projections_df['%s%s_%s_x'%(model.model_name, mode, method)] = low_dim_data[:,0]
                projections_df['%s%s_%s_y'%(model.model_name, mode, method)] = low_dim_data[:,1]


    projections_df.to_csv(dataset_directory + '/projections.csv')

In [35]:

# reuse mscoco subset from previous analysis
from PIL import Image
import numpy as np

class Custom_Dataset(am_data.DatasetInterface):
    name = 'MSCOCO-Val'

    def __init__(self, path, seed=54, batch_size=None):
        # create triplet dataset if it does not exist
        super().__init__(path, seed, batch_size)
        # path: path to the triplet dataset
        image_paths = [path + "images/%i.jpg"%i for i in range(100)]

        all_images = []
        for image_path in image_paths:
            with open(image_path, "rb") as fopen:
                image = Image.open(fopen).convert("RGB")
                all_images.append(image)

        self.all_images = np.array(all_images)
        
        with open(path + "/prompts.txt", "r") as file:
            self.all_prompts = file.read().splitlines()

mscoco_val_dataset_name = "MSCOCO-Val_size-100"
dataset_mscoco_val = Custom_Dataset(export_directory + mscoco_val_dataset_name + '/')
mscoco_val_images, mscoco_val_prompts = dataset_mscoco_val.get_data()

  self.all_images = np.array(all_images)
  self.all_images = np.array(all_images)


In [37]:
export_data(mscoco_val_dataset_name, mscoco_val_images, mscoco_val_prompts, [am_model.ImageBind_Model()])

found cached embeddings for MSCOCO-Val_size-100_ImageBind_huge


### Text-Image-Audio

In [31]:
def export_data(dataset_name, images, prompts, audios, infos, models):

    # create folder structure
    dataset_directory = create_dir_if_not_exists(export_directory + dataset_name)
    images_dir = create_dir_if_not_exists(dataset_directory + '/images')
    audios_dir = create_dir_if_not_exists(dataset_directory + '/audios')
    similarities_dir = create_dir_if_not_exists(dataset_directory + '/similarities')

    # save images
    for i in range(len(images)):
        im = images[i]
        im.resize((400,400))
        im.save('%s/%i.jpg'%(images_dir,i))
        
    # save audios
    for i in range(len(audios)):
        import soundfile as sf
        sf.write('%s/%i.wav'%(audios_dir,i), audios[i][0], audios.sample_rate, format="wav")

    # save texts
    with open(dataset_directory + "/prompts.txt", "w") as file:
        for prompt in prompts:
            file.write(prompt + "\n")

    # save infos about youtube source and labels
    infos.to_csv(dataset_directory + "/infos.csv")

    # export projections and similarities
    import torch
    from sklearn.decomposition import PCA
    from openTSNE import TSNE
    from umap import UMAP
    import numpy as np
    import pandas as pd
    import json

    projections_df = pd.DataFrame({'emb_id': list(np.arange(0,len(images),1))+list(np.arange(0,len(prompts),1))+list(np.arange(0,len(audios),1)), 'data_type':['image']*len(images)+['text']*len(prompts)+['audio']*len(prompts)})


    for model in models:
        # compute embeddings
        embeddings, logit_scale = am_utils.get_embeddings_per_modality(model, dataset_name, {"image": images, "text": prompts, "audio": audios})
        
        # compute similarities
        similarity = am_utils.get_similarities_all(embeddings)
        np.savetxt('%s/%s.csv'%(similarities_dir,model.model_name), similarity, delimiter=',')
        
        # compute meta information and similarity clustering
        # compute gap_distances and losses for each modality pair
        gap_distances = {}
        losses = {}
        for modality_1, modality_2 in [("audio", "image"), ("text", "image"), ("text", "audio")]:
            gap_distances['%s_%s'%(modality_1, modality_2)] = float(am_utils.get_modality_distance(embeddings[modality_1], embeddings[modality_2]))
            losses['%s_%s'%(modality_1, modality_2)] = float(am_utils.calculate_val_loss(embeddings[modality_1], embeddings[modality_2], logit_scale.exp()))
            
        meta_info = {}
        meta_info['gap_distance'] = gap_distances
        meta_info['loss'] = losses

        # compute cluster sorting for modality pair (including in-modal pairs)
        all_clusters = {}
        import itertools
        for modality_1, modality_2 in itertools.product(["audio", "image", "text"], ["audio", "image", "text"]):
            similarity_by_modalities = am_utils.get_similarities(torch.from_numpy(embeddings[modality_1]), torch.from_numpy(embeddings[modality_2]))
            idcs, clusters, clusters_unsorted = am_utils.get_cluster_sorting(similarity_by_modalities)
            cluster_labels = []
            cluster_sizes = []
            for c in set(clusters):
                cluster_size = int(np.count_nonzero(clusters==c))
                cluster_label = am_utils.get_textual_label_for_cluster(np.where(clusters_unsorted==c)[0], prompts)
                cluster_labels.append(cluster_label)
                cluster_sizes.append(cluster_size)

            idcs_reverse = np.argsort(idcs)
            cluster_dict = {}
            cluster_dict['cluster_sort_idcs'] = idcs.tolist()
            cluster_dict['cluster_sort_idcs_reverse'] = idcs_reverse.tolist()
            cluster_dict['cluster_sizes'] = cluster_sizes
            cluster_dict['cluster_labels'] = cluster_labels
            all_clusters['%s_%s'%(modality_1, modality_2)] = cluster_dict

        meta_info['clusters'] = all_clusters
        print(meta_info)

        with open("%s/%s_meta_info.json"%(similarities_dir, model.model_name), "w") as file:
            json.dump(meta_info, file)

        # compute projections
        embedding = np.concatenate(list(embeddings.values()))

        projection_methods = {
            'PCA': PCA,
            'UMAP': UMAP,
            'TSNE': TSNE
        }
        for method in projection_methods.keys():
            if method == 'PCA':
                proj = projection_methods[method](n_components=2)
            else:
                proj = projection_methods[method](n_components=2, metric='cosine', random_state=31415)
            
            if method == 'TSNE':
                low_dim_data = proj.fit(embedding)
            else:
                low_dim_data = proj.fit_transform(embedding)
            
            projections_df['%s_%s_x'%(model.model_name, method)] = low_dim_data[:,0]
            projections_df['%s_%s_y'%(model.model_name, method)] = low_dim_data[:,1]


    projections_df.to_csv(dataset_directory + '/projections.csv')

In [32]:
from glob import glob
from PIL import Image
import numpy as np
import torchaudio
import pandas as pd

class Triplet_Dataset(am_data.DatasetInterface):
    name='Triplet'

    def __init__(self, path, seed=31415, batch_size=100, sample_rate=16000):
        # create triplet dataset if it does not exist
        super().__init__(path, seed, batch_size)
        # path: path to the triplet dataset
        image_paths = glob(path + "image/*.jpg", recursive = True)
        audio_paths = glob(path + "audio/*.wav", recursive = True)

        self.sample_rate = sample_rate
        
        all_images = []
        for image_path in image_paths:
            with open(image_path, "rb") as fopen:
                image = Image.open(fopen).convert("RGB")
                all_images.append(image)

        all_audios = []
        for audio_path in audio_paths:
            waveform, sr = torchaudio.load(audio_path)
            if sample_rate != sr:
                waveform = torchaudio.functional.resample(
                    waveform, orig_freq=sr, new_freq=sample_rate
                )
            all_audios.append(waveform)
        
        self.all_infos = pd.read_csv(path + "info.csv", converters={"labels": lambda x: x.strip("[]").replace("'","").split(", ")})

        # TODO... load on demand with a custom loader
        self.all_images = np.array(all_images)
        self.all_prompts = np.array(self.all_infos["labels"].map(lambda x: ", ".join(x)))
        self.all_audios = np.array(all_audios)
        
    
    
    def get_data(self):
        # create a random batch
        batch_idcs = self._get_random_subsample(len(self.all_images))

        images = self.MODE1_Type(self.all_images[batch_idcs])
        texts = self.MODE2_Type(self.all_prompts[batch_idcs])
        audios = am_data.AudioType(self.all_audios[batch_idcs], self.sample_rate)
        
        return images, texts, audios, self.all_infos.iloc[batch_idcs].reset_index(drop=True)
    


In [33]:
data_path = '../../../../../Data/'
triplet_dir = data_path + "imagebind/text-audio-image/"

dataset = Triplet_Dataset(path=triplet_dir, batch_size=100)
all_images, all_prompts, all_audios, all_infos = dataset.get_data()
print(len(all_images), len(all_prompts), len(all_audios))
export_data("%s_size-%i"%(dataset.name, dataset.batch_size), all_images, all_prompts, all_audios, all_infos, [am_model.ImageBind_Model()])

  self.all_images = np.array(all_images)
  self.all_images = np.array(all_images)
  self.all_audios = np.array(all_audios)
  self.all_audios = np.array(all_audios)


100 100 100
batch 1 of 1
batch 1 of 1
batch 1 of 1


  linkage = sch.linkage(1-similarity, method='complete')


{'gap_distance': {'audio_image': 0.6844899275524431, 'text_image': 0.876292816992076, 'text_audio': 0.7458446635745786}, 'loss': {'audio_image': 3.7037110545737204, 'text_image': 3.815758442215615, 'text_audio': 3.805993391119964}, 'clusters': {'audio_audio': {'cluster_sort_idcs': [52, 2, 89, 15, 77, 33, 67, 79, 80, 85, 43, 22, 81, 27, 75, 36, 23, 18, 40, 51, 84, 94, 71, 24, 57, 88, 11, 19, 44, 69, 98, 42, 53, 54, 70, 66, 74, 72, 73, 0, 38, 5, 6, 7, 93, 9, 39, 13, 92, 17, 87, 25, 26, 29, 61, 76, 14, 60, 95, 97, 65, 49, 56, 55, 20, 50, 47, 28, 30, 35, 62, 4, 12, 16, 86, 34, 31, 41, 37, 68, 46, 58, 45, 83, 21, 3, 8, 82, 1, 96, 63, 10, 90, 64, 78, 32, 48, 59, 91, 99], 'cluster_sort_idcs_reverse': [39, 88, 1, 85, 71, 41, 42, 43, 86, 45, 91, 26, 72, 47, 56, 3, 73, 49, 17, 27, 64, 84, 11, 16, 23, 51, 52, 13, 67, 53, 68, 76, 95, 5, 75, 69, 15, 78, 40, 46, 18, 77, 31, 10, 28, 82, 80, 66, 96, 61, 65, 19, 0, 32, 33, 63, 62, 24, 81, 97, 57, 54, 70, 90, 93, 60, 35, 6, 79, 29, 34, 22, 37, 38, 36, 1