In [None]:
import argparse
import torch
import os

from PIL import Image

import os
import requests
from PIL import Image
from io import BytesIO
import torchvision.transforms as T
from tqdm.notebook import tqdm

# Tensor -> PIL image converter built from torchvision.transforms (imported as T above).
# NOTE(review): `transform` is not used in the visible cells, and `T` is rebound to
# torchaudio.transforms in a later cell — confirm whether this is still needed.
transform = T.ToPILImage()
import torch.optim as optim
import json
import numpy as np
import torch.nn.functional as F
import torchaudio
import torch.nn as nn
import imagebind.data as data
from IPython.display import Audio
import torchvision
from torchvision.transforms import transforms


In [None]:
class UnNormalize(object):
    """Invert ``transforms.Normalize``: map a normalized image back to its original range.

    Applies ``t * std + mean`` channel by channel, which undoes Normalize's
    ``(t - mean) / std``.
    """

    def __init__(self, mean, std):
        """
        Args:
            mean (sequence): Per-channel means that Normalize subtracted.
            std (sequence): Per-channel standard deviations that Normalize divided by.
        """
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Normalized tensor image of size (C, H, W).
        Returns:
            Tensor: A new, un-normalized image. The input tensor is left unchanged.
        """
        tensor = tensor.clone()
        # Iterating the clone yields per-channel views, so the in-place ops below
        # write straight into the copy (never into the caller's tensor).
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
        return tensor

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


In [None]:

import torch
import torchaudio
import torchaudio.transforms as T
from IPython.display import Audio


# Spectrogram parameters shared by the mel-inversion helpers below.
# NOTE(review): these presumably need to match the feature extraction performed by
# data.load_and_transform_audio_data — confirm.
sample_rate = 16000
num_mel_bins = 128
num_frames = 204
n_fft = 400
win_length = n_fft
hop_length = win_length // 4  # a quarter of the window

def create_mel_inversion_matrix(sr, n_fft, n_mels, fmin=0.0, fmax=None, verbose=False):
    """Return the Moore-Penrose pseudo-inverse of a torchaudio mel filter bank.

    Args:
        sr: Sample rate forwarded to ``T.MelScale``.
        n_fft: FFT size; the filter bank spans ``n_fft // 2 + 1`` frequency bins.
        n_mels: Number of mel bands.
        fmin: Lowest filter-bank frequency.
        fmax: Highest filter-bank frequency (``None`` lets torchaudio choose its default).
        verbose: If True, print the shape of the returned matrix.

    Returns:
        Float tensor of shape ``(n_mels, n_fft // 2 + 1)`` (pseudo-inverse of ``mel_fb.fb``,
        which torchaudio lays out as ``(n_stft, n_mels)``).
    """
    mel_fb = T.MelScale(n_mels, sr, f_min=fmin, f_max=fmax, n_stft=n_fft // 2 + 1, norm=None)
    # ``fb`` is already a tensor: copy it with detach/clone instead of ``torch.tensor(...)``,
    # which emits a copy-construct UserWarning on tensor input.
    mel_fb_tensor = mel_fb.fb.detach().clone().to(torch.float)
    inversion_matrix = torch.pinverse(mel_fb_tensor)
    if verbose:
        print(inversion_matrix.shape)
    return inversion_matrix

def inverse_it(mel_spectrogram):
    """Reconstruct a waveform from a mel spectrogram via InverseMelScale + Griffin-Lim.

    Args:
        mel_spectrogram: 2-D tensor laid out as ``(time, n_mels)``. It is transposed to
            ``(n_mels, time)`` before ``T.InverseMelScale``; the callers in this notebook
            pass transposed ``inverse_normalize(...)`` output in that layout.

    Returns:
        The waveform tensor produced by ``T.GriffinLim``. Wrap it in ``Audio(...)`` at
        the call site to listen to it.

    NOTE(review): InverseMelScale and GriffinLim expect a linear-scale (power)
    spectrogram. If the input is a *log* mel spectrogram it should be exponentiated
    first, and n_fft / hop_length should match the feature extractor — confirm.
    """
    # InverseMelScale performs the mel -> linear inversion itself, so no separate
    # pseudo-inverse matrix is needed here.
    inverse_mel_scale_transform = T.InverseMelScale(
        n_stft=n_fft // 2 + 1,
        n_mels=num_mel_bins,
        sample_rate=sample_rate,
        f_min=0.0,
        f_max=sample_rate // 2,
        norm=None,
    )
    spectrogram = inverse_mel_scale_transform(mel_spectrogram.T)

    # Griffin-Lim estimates the phase missing from the magnitude spectrogram.
    griffin_lim = T.GriffinLim(n_fft=n_fft, n_iter=32, win_length=win_length, hop_length=hop_length)
    return griffin_lim(spectrogram)

def inverse_normalize(melspec, mean=-4.268, std=9.138):
    """Undo ``(x - mean) / std`` standardisation of a mel spectrogram.

    Args:
        melspec: Standardised values (tensor, array or scalar).
        mean: Mean that was subtracted during normalisation.
        std: Standard deviation that was divided out during normalisation.

    Returns:
        ``melspec * std + mean``.
    """
    rescaled = melspec * std
    return rescaled + mean


def combine_results(audio, trims=(5000, 5000, 2000)):
    """Invert the first ``len(trims)`` clips of a batched mel tensor and trim each tail.

    Args:
        audio: Tensor indexed as ``audio[0][clip][0]``, giving one 2-D mel spectrogram per
            clip (presumably ``(n_mels, time)`` — confirm). Only batch item 0 and channel 0
            are used.
        trims: Number of trailing samples dropped from each reconstructed clip, one entry
            per returned clip. The default reproduces the previously hard-coded trims.

    Returns:
        list of waveforms, one per entry in ``trims``.

    Raises:
        ValueError: if ``audio`` holds fewer clips than ``trims`` has entries.
    """
    # Convert the input once rather than cloning/copying the whole tensor for every clip.
    clips = audio.detach().clone().cpu().float()[0]
    if clips.shape[0] < len(trims):
        raise ValueError(f"expected at least {len(trims)} clips, got {clips.shape[0]}")

    results = []
    # zip stops after len(trims) clips, so clips that would be discarded are never inverted.
    for clip, trim in zip(clips, trims):
        waveform = inverse_it(inverse_normalize(clip[0]).T)
        # ``w[:-0]`` would be empty, so a zero trim keeps the whole waveform.
        results.append(waveform[:-trim] if trim else waveform)
    return results
    
from pydub import AudioSegment
import os

def get_results(audio_tensor, n_frames=198):
    """Invert all clips of a batched mel tensor into one continuous waveform.

    The clips ``audio_tensor[0, :, 0]`` are cropped to their first ``n_frames`` time
    frames, concatenated along time, un-normalised and passed to ``inverse_it``.

    Args:
        audio_tensor: 5-D tensor indexed as ``[batch, clip, channel, mel, time]`` —
            TODO confirm layout. Only batch item 0 and channel 0 are used.
        n_frames: Frames kept from each clip (default 198, the previously hard-coded crop).

    Returns:
        The waveform produced by ``inverse_it``.
    """
    clips = audio_tensor[0, :, 0, :, :n_frames].detach().clone().cpu().float()
    # Iterating over dim 0 yields one 2-D slice per clip; stitch them along time.
    stitched = torch.cat(tuple(clips), dim=-1)
    return inverse_it(inverse_normalize(stitched.T))


def convert_mp3_to_wav(input_path, output_path, bitrate, duration, shift=0):
    """Cut a window out of an MP3 file, resample it and write it as WAV.

    Args:
        input_path: Path of the source MP3 file.
        output_path: Destination path for the WAV file.
        bitrate: Despite its name, the target *frame rate* (sample rate, Hz) passed to
            ``AudioSegment.set_frame_rate``.
        duration: Length of the kept window, in seconds.
        shift: Start offset of the window, in milliseconds.

    Returns:
        The trimmed, resampled ``AudioSegment``.
    """
    audio = AudioSegment.from_mp3(input_path)
    # pydub slices segments in milliseconds: ``duration`` is seconds, ``shift`` is already ms.
    audio = audio[shift:duration * 1000 + shift]
    audio = audio.set_frame_rate(bitrate)
    # ``export`` returns the open output file handle; close it so the data is flushed and
    # the descriptor is not leaked.
    audio.export(output_path, format='wav').close()
    return audio

 

In [None]:
def custom_loader(path):
    """Load one image file with ImageBind's vision preprocessing, on the CPU.

    Used as the ``loader`` of ``torchvision.datasets.ImageNet``. Returns whatever
    ``data.load_and_transform_vision_data`` yields for a single-element path list.
    """
    return data.load_and_transform_vision_data([path], 'cpu')

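## Prerequisite: the ImageNet validation set must be available under `./data/imagenet/` (torchvision's `ImageNet` dataset reads the official validation archive and devkit from that directory)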

In [None]:
# ImageNet validation split; each image file is read through `custom_loader`.
dataset = torchvision.datasets.ImageNet(
    root='./data/imagenet/',
    split='val',
    loader=custom_loader,
)

In [None]:
# The usual ImageNet per-channel mean/std, and an eval-time transform built on them.
# NOTE(review): `test_transform` is not used by the visible cells (the dataset goes through
# `custom_loader`); confirm whether it is still needed.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
test_transform = transforms.Compose(
    [
        transforms.Resize(132),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        normalize,
    ]
)

In [None]:
# Inverse of `normalize`; reuse its statistics so the two can never drift apart.
unnorm = UnNormalize(mean=normalize.mean, std=normalize.std)


In [None]:
# The embedding loop and the per-class split below assume the full 50 000-image validation set.
assert len(dataset) == 50000, f"expected 50000 validation images, found {len(dataset)}"
len(dataset)

In [None]:
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

# Use the third GPU when it exists; otherwise fall back so the notebook still runs.
if torch.cuda.device_count() > 2:
    device = 'cuda:2'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Instantiate the pretrained model in eval mode on the chosen device.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device);

In [None]:
# Accumulator for the per-image vision embeddings.
all_embeds = []

### Embed the whole validation set

In [None]:
# Embed every validation image. The list is re-initialised so re-running this cell does
# not append a second copy of the embeddings.
all_embeds = []
with torch.no_grad():
    for i in tqdm(range(len(dataset))):
        image_tensor, _ = dataset[i]
        embed = model({'vision': image_tensor.to(device)})
        all_embeds.append(embed['vision'].cpu())

In [None]:
# Stack the per-image embeddings row-wise (default dim 0) and cache them on disk.
catted_embeds = torch.cat(all_embeds)
torch.save(catted_embeds, 'embeds.pt')

In [None]:
# Sanity check: one embedding per image, with the 1024-d size assumed by later cells.
assert catted_embeds.shape == (len(dataset), 1024), catted_embeds.shape
catted_embeds.shape

In [None]:
# NOTE(review): allocates a dense 50000 x 50000 float16 matrix (~5 GB) that no visible cell
# uses; it is deleted again in the next cell. Consider removing both cells.
cos_sim =torch.zeros([50000, 50000], dtype=torch.float16)

In [None]:
# Release the large similarity buffer allocated in the previous cell.
del cos_sim

In [None]:
# Two small slices of the embedding matrix for quick interactive checks.
s, z = catted_embeds[:10], catted_embeds[10:20]

In [None]:
# Print every class whose name tuple contains 'sheep' (each class printed at most once).
for i, x in enumerate(dataset.classes):
    if any('sheep' in entry for entry in x):
        print(i, x)

In [None]:
# NOTE(review): repeats the previous cell's 'sheep' search; one of the two can be dropped.
for i, x in enumerate(dataset.classes):
    first_match = next((entry for entry in x if 'sheep' in entry), None)
    if first_match is not None:
        print(i, x)

In [None]:
# List the dataset indices of every image labelled with class 348.
for i, lbl in enumerate(dataset.targets):
    if lbl != 348:
        continue
    print(i)

In [None]:
# Name tuple of class 348, the label whose image indices were listed above.
dataset.classes[348]

In [None]:
# Buffers for per-class mean embeddings and a held-out split of 5 images per class.
# Sizes come from the data rather than magic numbers; labels are class indices, so they
# use an integer dtype (-1 marks "not filled yet").
embed_dim = catted_embeds.shape[1]
mean_embeds = torch.zeros([len(dataset.classes), embed_dim])
test_embeds = torch.zeros([5 * len(dataset.classes), embed_dim])
test_labels = torch.full([5 * len(dataset.classes)], -1, dtype=torch.long)

In [None]:
# Inspect the last two embeddings of one class block (50 images per class). Use an explicit
# class index instead of whatever `i` an earlier loop happened to leave behind.
class_idx = 0
catted_embeds[50 * (class_idx + 1) - 2 : 50 * (class_idx + 1)].shape

In [None]:
# Split each class's 50 validation images: the first 45 form the class prototype (mean
# embedding), the last 5 are held out as test samples.
IMAGES_PER_CLASS = 50
HELD_OUT_PER_CLASS = 5
targets = torch.as_tensor(dataset.targets)
for i in range(len(dataset.classes)):
    start = IMAGES_PER_CLASS * i
    stop = start + IMAGES_PER_CLASS
    # The fixed-stride slicing is only valid if samples are grouped by class, in class order.
    assert bool((targets[start:stop] == i).all()), f"samples of class {i} are not at [{start}, {stop})"
    mean_embeds[i] = catted_embeds[start:stop - HELD_OUT_PER_CLASS].mean(dim=0)
    test_slice = slice(HELD_OUT_PER_CLASS * i, HELD_OUT_PER_CLASS * (i + 1))
    test_embeds[test_slice] = catted_embeds[stop - HELD_OUT_PER_CLASS:stop]
    test_labels[test_slice] = i
        

In [None]:
# Look up the name tuple of class index 270.
dataset.classes[270]

In [None]:
# dataset.classes has one entry per class (far fewer than 17439), so indexing it with 17439
# raises IndexError. 17439 is presumably a sample index (see the class-348 index listing
# above); look up that sample's class instead — TODO confirm intent.
dataset.classes[dataset.targets[17439]]

In [None]:

# Embed a single audio clip (with normalize=False) for later comparison.
audio_paths = ["all_assets/wolves.wav"]

# Preprocess once and reuse the same batch for the forward pass.
inputs = {
    ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
}

with torch.no_grad():
    audio_embed = model.forward(inputs, normalize=False)[ModalityType.AUDIO]


In [None]:
# Reload the clip: the raw waveform (for reference) plus the preprocessed model input on the CPU.
path = "./all_assets/wolves.wav"
orig_waveform, sr = torchaudio.load(path)
print(sr)
audio_tensor = data.load_and_transform_audio_data([path], 'cpu')

In [None]:
# Move the clip to the device, then start the perturbation as small uniform noise
# (scale 1e-4) created directly on that device, as a leaf that requires grad.
audio_tensor = audio_tensor.to(device)
X = (0.0001 * torch.rand_like(audio_tensor)).requires_grad_(True)

In [None]:
epochs = 20000
optimizer = optim.SGD([X], lr=0.005)
# Cosine-anneal the step size from 0.005 down to 1e-5 over `epochs` iterations.
# NOTE(review): the attack loop updates X by hand and never calls optimizer.step();
# the optimizer only exists to drive this schedule.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)

In [None]:
pbar = tqdm(range(epochs))
saved_dict = dict()

# NOTE(review): `ideal_embed` (the target embedding the attack pulls towards) is not defined in
# any visible cell; define it before running this cell — TODO confirm the intended target.
if 'ideal_embed' not in globals():
    raise NameError("`ideal_embed` is not defined; set the target embedding before running the attack loop")

for i in pbar:
    lr = scheduler.get_last_lr()[0]

    embeds = model.forward({'audio': X + audio_tensor}, normalize=True)
    # Push the perturbed clip's embedding towards the target (maximise cosine similarity).
    loss = 1 - F.cosine_similarity(embeds['audio'], ideal_embed, dim=1).mean()
    grads = torch.autograd.grad(outputs=loss, inputs=X)

    # Signed-gradient step, then clip X into [-0.05, 0.05]. Updating under no_grad and then
    # re-enabling grad keeps X a fresh leaf each iteration, so the autograd graph does not
    # accumulate a chain of update nodes across iterations.
    with torch.no_grad():
        X = (X - lr * grads[0].sign()).clamp_(min=-0.05, max=0.05)
    X.requires_grad_(True)

    pbar.set_postfix({'loss': loss.item(), 'lr': lr, 'norm': X.detach().norm().item(), 'saved': list(saved_dict.keys())})

    scheduler.step()

    del grads, embeds, loss
    

In [None]:
# Best-matching image embedding for the perturbed clip; no_grad avoids building an autograd graph.
with torch.no_grad():
    adv_audio_embed = model.forward({'audio': audio_tensor + X}, normalize=True)['audio'].cpu()
(adv_audio_embed @ catted_embeds.T).max(dim=1)

In [None]:
# Best-matching class prototype for the clean clip; no_grad avoids building an autograd graph.
with torch.no_grad():
    clean_audio_embed = model.forward({'audio': audio_tensor.to(device)}, normalize=True)['audio'].cpu()
(clean_audio_embed @ mean_embeds.T).max(dim=1)

In [None]:
import matplotlib.pyplot as plt
import librosa
import librosa.display

In [None]:
# Full preprocessed audio tensor as a NumPy array; squeeze(0) drops the leading axis if it has size 1.
# NOTE(review): the next cell immediately rebinds `mel_spectrogram_np` to a cropped torch
# tensor, so this array is overwritten unless cells are run out of order.
mel_spectrogram_np = audio_tensor.detach().cpu().squeeze(0).numpy()


In [None]:
# Batch item 0, channel 0, first 198 frames: one 2-D (mel bands x frames, presumably) slice per clip.
mel_spectrogram_np = audio_tensor[0, :, 0, :, :198].detach().cpu()
channel_1, channel_2, channel_3 = mel_spectrogram_np.unbind(0)

In [None]:
# Plot each of the 3 mel-spectrograms
fig, ax = plt.subplots(1, 1, figsize=(10, 5))

ax.imshow(torch.cat(tuple(mel_spectrogram_np), dim=-1), aspect='auto', origin='lower', cmap='viridis')
ax.set_title(f'Mel-frequency spectrogram')
ax.set_ylabel('Mel bands')
# ax[i].colorbar()

ax.set_xlabel('Time frames')
plt.tight_layout()
plt.show()

In [None]:
# Listen to the Griffin-Lim reconstruction of the clean clip, at the notebook's sample rate.
Audio(get_results(audio_tensor), rate=sample_rate)

In [None]:
# Listen to the reconstruction of the perturbation alone.
# NOTE(review): use `audio_tensor + X` instead to hear the perturbed clip — confirm intent.
Audio(get_results(X), rate=sample_rate)

In [None]:
# Plot each clip's mel spectrogram in its own row.
fig, ax = plt.subplots(3, 1, figsize=(15, 10))

for i in range(len(ax)):
    # squeeze() drops a singleton channel axis if present, so this works both for the
    # NumPy array with a size-1 axis and for the cropped 3-D tensor built in earlier cells.
    ax[i].imshow(mel_spectrogram_np[i].squeeze(), aspect='auto', origin='lower', cmap='viridis')
    ax[i].set_title(f'Mel-frequency spectrogram - Channel {i+1}')
    ax[i].set_ylabel('Mel bands')

ax[-1].set_xlabel('Time frames')
plt.tight_layout()
plt.show()