In [None]:
# ! pip install git+https://github.com/facebookresearch/ImageBind
! pip install soundfile
! pip install librosa
! pip install "imagebind @ git+https://github.com/facebookresearch/ImageBind@c6a47d6dc2b53eced51d398c181d57049ca59286"

In [None]:
import torch

# Prefer the first GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Root folder holding all datasets, relative to this notebook.
data_path = "../../../../../Data/"

# Create the text-audio-image triplet dataset (AudioSet audio clips, YouTube thumbnails, labels)

In [None]:

from datasets import load_dataset
import soundfile as sf
import requests
import pandas as pd
import os

In [None]:
triplet_dir = data_path + "imagebind/text-audio-image/"
# info.csv is written last, so its presence means a previous run completed.
# (Checking only the directory skipped generation forever after a failed run.)
if not os.path.exists(triplet_dir + "info.csv"):
    # sf.write() / open() do not create folders, so create them up front.
    os.makedirs(triplet_dir + "audio", exist_ok=True)
    os.makedirs(triplet_dir + "image", exist_ok=True)

    # https://huggingface.co/datasets/agkphysics/AudioSet
    # dataset = load_dataset("agkphysics/AudioSet", "bal", split="test")
    # Reads the locally cached Arrow shards of the "bal" test split.
    dataset = load_dataset("arrow", split="test", data_files={"test":"../../../../../.cache/huggingface/datasets/agkphysics___audio_set/bal/audio_set-test-000**-of-00018.arrow"})#, split="test[0:10]")
    # https://huggingface.co/docs/datasets/v2.16.1/en/package_reference/main_classes#datasets.Dataset.filter
    dataset_it = dataset.to_iterable_dataset()

    animals = ["Bird", "Cat", "Dog", "Horse"]
    n_per_animal = 25  # number of triplets saved per animal

    infos = []
    idx = 0  # running id; also the zero-padded file name of the audio/image pair
    for animal in animals:
        n_saved = 0
        # The filtered iterator is fully consumed within this loop iteration,
        # so the lambda always sees the current `animal`.
        for example in dataset_it.filter(lambda x: animal in x["human_labels"]):
            # Fetch the thumbnail first and skip the clip if the request fails,
            # so every saved audio file has a matching image and info row.
            # https://stackoverflow.com/questions/2068344/how-do-i-get-a-youtube-video-thumbnail-from-the-youtube-api
            url = 'https://img.youtube.com/vi/%s/hqdefault.jpg' % example["video_id"]
            response = requests.get(url, timeout=30)
            if not response.ok:
                continue

            # save audiofile
            sf.write(triplet_dir + "audio/%03d.wav" % idx, example["audio"]["array"], example["audio"]["sampling_rate"], format="wav")

            # save image (this was previously written via the undefined name `data_dir`)
            with open(triplet_dir + "image/%03d.jpg" % idx, 'wb') as f:
                f.write(response.content)

            # save info
            infos.append({"id": idx, "video_id": example["video_id"], "labels": example["human_labels"], "animal": animal})

            idx += 1
            n_saved += 1
            if n_saved >= n_per_animal:
                break

    pd.DataFrame(infos).to_csv(triplet_dir + "info.csv", index=False)

# Image-Text

In [None]:
from amumo import model as am_model
from amumo import data as am_data
from amumo import utils as am_utils
from amumo import widgets as am_widgets

In [None]:
# Data Helpers
def get_data_helper(dataset, filters=None, method=any):
    """Fetch (optionally filtered) images and prompts and derive a descriptive name.

    Args:
        dataset: an amumo dataset exposing ``name`` and ``get_filtered_data``.
        filters: substrings used to filter the prompts; ``None`` or empty means
            no filtering.
        method: ``any`` or ``all``; how the filter substrings are combined.

    Returns:
        ``(all_images, all_prompts, dataset_name)``. ``dataset_name`` is
        ``<name>_filter-<method>_<f1-f2-...>`` when filters are given, otherwise
        ``<name>_size-<number of images>``.
    """
    # None instead of a shared mutable [] default; empty means "no filtering".
    filters = [] if filters is None else list(filters)
    all_images, all_prompts = dataset.get_filtered_data(filters, method=method)
    print(len(all_images))

    if filters:
        dataset_name = dataset.name + '_filter-' + method.__name__ + '_' + '-'.join(filters)
    else:
        dataset_name = dataset.name + '_size-%i' % len(all_images)

    return all_images, all_prompts, dataset_name

# Subset of the MS-COCO validation data, loaded through amumo with batch_size=100.
dataset_mscoco_val = am_data.MSCOCO_Val_Dataset(
    path=data_path + 'mscoco/validation',
    batch_size=100,
)
mscoco_val_images, mscoco_val_prompts, mscoco_val_dataset_name = get_data_helper(
    dataset_mscoco_val, filters=[], method=any
)
mscoco_val_dataset_name

In [None]:
# Image/caption embeddings of the MS-COCO subset for the ImageBind and CLIP models.
am_widgets.CLIPExplorerWidget(
    mscoco_val_dataset_name,
    all_data={"image": mscoco_val_images, "text": mscoco_val_prompts},
    models=[am_model.ImageBind_Model(), "CLIP"],
)

In [None]:
# Compare ImageBind against CLIP on the same MS-COCO images and prompts.
am_widgets.CLIPComparerWidget(
    mscoco_val_dataset_name,
    all_images=mscoco_val_images,
    all_prompts=mscoco_val_prompts,
    models=[am_model.ImageBind_Model(), "CLIP"],
)

# Image-Text-Audio

In [None]:
from amumo import model as am_model
from amumo import data as am_data
from amumo import utils as am_utils
from amumo import widgets as am_widgets
# ! pip install torch
# ! pip install torchaudio
import torch
# Prefer the first GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
from glob import glob
from PIL import Image
import numpy as np
import torchaudio
import pandas as pd

class Triplet_Dataset(am_data.DatasetInterface):
    """Text/audio/image triplets written by the "Create triplet dataset" cell.

    ``path`` must contain ``info.csv`` (with at least the columns ``id`` and
    ``labels``) plus ``image/%03d.jpg`` and ``audio/%03d.wav`` for every ``id``
    listed there. Row ``i`` of ``all_images``, ``all_prompts`` and ``all_audios``
    always refers to the same clip.
    """
    name='Triplet'

    def __init__(self, path, seed=31415, batch_size=100, sample_rate=16000):
        # path: path to an existing triplet dataset (this class does not create it)
        # sample_rate: audio is resampled to this rate on load when it differs
        super().__init__(path, seed, batch_size)
        self.sample_rate = sample_rate

        # info.csv defines the order. Files are located through their id instead
        # of glob(), whose result order is arbitrary and mis-paired images, audio
        # and labels (and could also pick up stale files from an earlier run).
        self.all_infos = pd.read_csv(path + "info.csv", converters={"labels": self._parse_labels})
        ids = self.all_infos["id"].tolist()
        all_images = [self._load_image(path + "image/%03d.jpg" % i) for i in ids]
        all_audios = [self._load_audio(path + "audio/%03d.wav" % i) for i in ids]

        # TODO... load on demand with a custom loader
        # NOTE(review): np.array over PIL images / tensors may stack them into one
        # numeric array when all shapes agree; confirm that is what amumo expects.
        self.all_images = np.array(all_images)
        self.all_prompts = np.array(self.all_infos["labels"].map(lambda x: ", ".join(x)))
        self.all_audios = np.array(all_audios)

    @staticmethod
    def _parse_labels(cell):
        """Turn a stringified label list from info.csv (e.g. "['Dog', 'Animal']") into a list of str."""
        import ast  # local import keeps this cell self-contained
        try:
            labels = ast.literal_eval(cell)
        except (ValueError, SyntaxError):
            # Not a Python literal: fall back to the previous naive parsing.
            return cell.strip("[]").replace("'", "").split(", ")
        # literal_eval keeps labels containing ", " or apostrophes intact,
        # e.g. "Domestic animals, pets".
        if isinstance(labels, (list, tuple)):
            return [str(label) for label in labels]
        return [str(labels)]

    @staticmethod
    def _load_image(image_path):
        """Load one image file as an RGB PIL image."""
        with open(image_path, "rb") as fopen:
            return Image.open(fopen).convert("RGB")

    def _load_audio(self, audio_path):
        """Load one audio file, resampled to ``self.sample_rate`` if needed."""
        waveform, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sr, new_freq=self.sample_rate
            )
        return waveform

    def _make_batch(self, idcs=None):
        """Wrap the items at ``idcs`` (all items if None) in amumo's modality types."""
        if idcs is None:
            images, prompts, audios = self.all_images, self.all_prompts, self.all_audios
        else:
            images, prompts, audios = self.all_images[idcs], self.all_prompts[idcs], self.all_audios[idcs]
        return (
            self.MODE1_Type(images),
            self.MODE2_Type(prompts),
            am_data.AudioType(audios, self.sample_rate),
        )

    def get_data(self):
        """Return ``(images, texts, audios)``: all items if batch_size is None, else a random batch."""
        if self.batch_size is None:
            return self._make_batch()
        # create a random batch
        return self._make_batch(self._get_random_subsample(len(self.all_images)))

    def get_filtered_data(self, filter_list, method=any):
        """Return a random batch of the items whose prompt matches ``filter_list``.

        filter_list: substrings used for filtering (matched case-insensitively);
            None or empty falls back to ``get_data()``.
        method: ``any`` -> any substring must be present; ``all`` -> all must be.
        Returns ``([], [], [])`` when nothing matches.
        """
        if filter_list is None or len(filter_list) <= 0:
            return self.get_data()

        # Prompts are lowercased, so lowercase the substrings too; otherwise
        # filters such as "Dog" could never match.
        needles = [substring.lower() for substring in filter_list]
        subset_ids = np.array([
            i for i, prompt in enumerate(self.all_prompts)
            if method(needle in prompt.lower() for needle in needles)
        ])
        if len(subset_ids) <= 0:
            print("no filter matches found")
            return [], [], []

        # create a random batch among the matches
        batch_idcs = self._get_random_subsample(len(subset_ids))
        return self._make_batch(subset_ids[batch_idcs])


In [None]:
# Load the triplets from disk and draw a random batch of (image, text, audio) items.
dataset = Triplet_Dataset(path=triplet_dir, batch_size=100)
all_images, all_prompts, all_audios = dataset.get_data()
print(*(len(items) for items in (all_images, all_prompts, all_audios)))


In [None]:
# ImageBind embeddings of the text and audio modalities, projected with PCA.
my_widget = am_widgets.CLIPExplorerWidget(
    "test_audio",
    all_data={"text": all_prompts, "audio": all_audios},
    models=[am_model.ImageBind_Model()],
)
my_widget.scatter_widget.select_projection_method.value = "PCA"
my_widget

In [None]:
# Same as above, with the image modality added.
# NOTE(review): reuses the dataset name "test_audio" with a different set of
# modalities; if amumo caches results by this name, confirm they do not collide.
my_widget = am_widgets.CLIPExplorerWidget(
    "test_audio",
    all_data={"text": all_prompts, "audio": all_audios, "image": all_images},
    models=[am_model.ImageBind_Model()],
)
my_widget.scatter_widget.select_projection_method.value = "PCA"
my_widget

# Image-Thermal

In [None]:
from glob import glob
from PIL import Image
import numpy as np


class ThermalType(am_data.ImageType):
    """Thermal (infrared) images: an ImageType that differs only in its modality name."""

    name = "Thermal"

    def __init__(self, data) -> None:
        # Nothing thermal-specific to set up; defer entirely to ImageType.
        super(ThermalType, self).__init__(data)


class LLVIP_Dataset(am_data.DatasetInterface):
    """Visible/infrared image pairs from the LLVIP test split.

    Download the dataset from https://bupt-ai-cz.github.io/LLVIP/. ``path`` must
    contain ``visible/test/*.jpg`` and ``infrared/test/*.jpg``. A visible and an
    infrared image form a pair when they share the same file name (assumed LLVIP
    convention; TODO confirm). The thermal images take the place of the second
    ("prompt") modality.
    """
    name='LLVIP'

    def __init__(self, path, seed=31415, batch_size = 100):
        import os  # local import keeps this cell self-contained

        super().__init__(path, seed, batch_size)
        # path: path to the LLVIP dataset
        image_paths = sorted(glob(path + "visible/test/*.jpg", recursive = True))
        thermal_by_name = {
            os.path.basename(p): p
            for p in glob(path + "infrared/test/*.jpg", recursive = True)
        }

        # glob() order is arbitrary, so pair visible and infrared images by file
        # name rather than by list position (positional pairing mixed up scenes).
        pairs = [
            (p, thermal_by_name[os.path.basename(p)])
            for p in image_paths
            if os.path.basename(p) in thermal_by_name
        ]
        if not pairs:
            raise FileNotFoundError("no matching visible/infrared image pairs found under %s" % path)

        batch_idcs = self._get_random_subsample(len(pairs))
        image_paths = np.array([image_path for image_path, _ in pairs])[batch_idcs]
        thermal_paths = np.array([thermal_path for _, thermal_path in pairs])[batch_idcs]

        all_images = [self._load(image_path, "RGB") for image_path in image_paths]
        all_thermals = [self._load(thermal_path, "L") for thermal_path in thermal_paths]

        self.MODE2_Type = ThermalType

        # TODO... load images and thermals on demand with a custom loader
        self.all_images = np.array(all_images)
        self.all_prompts = np.array(all_thermals)

    @staticmethod
    def _load(image_path, mode):
        """Open an image file and convert it to the given PIL mode ("RGB" or "L")."""
        with open(image_path, "rb") as fopen:
            return Image.open(fopen).convert(mode)


In [None]:
# LLVIP visible/infrared pairs; download from https://bupt-ai-cz.github.io/LLVIP/
dataset = LLVIP_Dataset(path=data_path + "imagebind/LLVIP/", batch_size=100)
all_images, all_thermals = dataset.get_data()
print(*(len(items) for items in (all_images, all_thermals)))

am_widgets.CLIPExplorerWidget(
    "test_thermal",
    all_data={"image": all_images, "thermal": all_thermals},
    models=[am_model.ImageBind_Model()],
)

# Image-Depth

In [None]:
from glob import glob
from PIL import Image
import numpy as np
import torch
import io
import matplotlib.pyplot as plt


class DepthType(am_data.ImageType):
    """Depth-map modality; each map is rendered as a grayscale PNG for display."""
    name = "Depth"

    def __init__(self, data) -> None:
        # data: sequence whose items are indexed as item[0] -> 2-D depth map.
        # SUNRGBD_Dataset passes (1, H, W) maps at their loaded resolution;
        # other callers: TODO confirm.
        super().__init__(data)

    def getVisItem(self, idx):
        """Return the depth map at ``idx`` as encoded PNG bytes for the widget."""
        output_img = io.BytesIO()
        # imsave alone encodes the array. The former extra plt.savefig() appended
        # a second, unrelated (implicit/empty) figure as JPEG to the same buffer
        # and leaked a new figure on every call.
        plt.imsave(output_img, self.data[idx][0], cmap='gray', format='png')
        return {"displayType": am_data.DisplayTypes.IMAGE, "value": output_img.getvalue()}
    
    

class SUNRGBD_Dataset(am_data.DatasetInterface):
    """RGB/depth pairs from the NYU part of SUN RGB-D V1.

    Download "SUNRGBD_V1" from https://rgbd.cs.princeton.edu/. Under ``path``
    the RGB images ``kv1/NYUdata/NYU*/fullres/*.jpg`` and the depth maps
    ``kv1/NYUdata/NYU*/fullres/*.png`` are read. An RGB image is paired with the
    depth map in the same folder. Depth maps take the place of the second
    ("prompt") modality.
    """
    name='SUNRGBD_NYU'

    def __init__(self, path, seed=31415, batch_size = 100):
        import os  # local import keeps this cell self-contained

        super().__init__(path, seed, batch_size)
        # path: path to the SUNRGBD dataset
        image_paths = sorted(glob(path + "kv1/NYUdata/NYU*/fullres/*.jpg", recursive = True))
        depth_by_dir = {
            os.path.dirname(p): p
            for p in sorted(glob(path + "kv1/NYUdata/NYU*/fullres/*.png", recursive = True))
        }

        # glob() order is arbitrary, so pair RGB and depth by folder rather than
        # by list position (positional pairing mixed up unrelated scenes).
        pairs = [
            (p, depth_by_dir[os.path.dirname(p)])
            for p in image_paths
            if os.path.dirname(p) in depth_by_dir
        ]
        if not pairs:
            raise FileNotFoundError("no matching RGB/depth pairs found under %s" % path)

        batch_idcs = self._get_random_subsample(len(pairs))
        image_paths = np.array([image_path for image_path, _ in pairs])[batch_idcs]
        depth_paths = np.array([depth_path for _, depth_path in pairs])[batch_idcs]

        all_images = [self._load_rgb(image_path) for image_path in image_paths]
        all_depths = [self._load_depth(depth_path) for depth_path in depth_paths]

        self.MODE2_Type = DepthType

        # TODO... load images and depths on demand with a custom loader
        self.all_images = np.array(all_images)
        self.all_prompts = np.array(all_depths)

    @staticmethod
    def _load_rgb(image_path):
        """Load one image file as an RGB PIL image."""
        with open(image_path, "rb") as fopen:
            return Image.open(fopen).convert("RGB")

    @staticmethod
    def _load_depth(depth_path):
        """Load a depth PNG as a float32 tensor of shape (1, H, W), divided by its max."""
        with open(depth_path, "rb") as fopen:
            depth = np.array(Image.open(fopen), dtype=int).astype(np.float32)
        max_depth = depth.max()
        if max_depth > 0:  # an all-zero map would otherwise become all NaN (0/0)
            depth /= max_depth  # TODO: need to normalize?
        # Add a channel dimension: (H, W) -> (1, H, W). No resizing happens here.
        return torch.from_numpy(depth).unsqueeze(0)



In [None]:
# SUN RGB-D (NYU) RGB/depth pairs; download "SUNRGBD_V1" from https://rgbd.cs.princeton.edu/
dataset = SUNRGBD_Dataset(path=data_path + "imagebind/SUNRGBD/", batch_size=100)
all_images, all_depths = dataset.get_data()
print(*(len(items) for items in (all_images, all_depths)))

am_widgets.CLIPExplorerWidget(
    "test_depth",
    all_data={"image": all_images, "depth": all_depths},
    models=[am_model.ImageBind_Model()],
)