# An easy-to-use notebook for beginners:
1. Define ImageBind model
2. Define Diffusion model
3. Image-conditioned audio generation
4. Image-conditioned audio editing

In [None]:
""" importing modules
"""
import torch
from PIL import Image
import numpy as np
import os
from omegaconf import OmegaConf
from easydict import EasyDict
import matplotlib.pyplot as plt

import ImageBind.data as data
from ImageBind.models import imagebind_model
from ImageBind.models.imagebind_model import ModalityType

from ldm.models.diffusion.ddpm import ImageEmbeddingConditionedLatentDiffusion
from ldm.models.diffusion.ddim import DDIMSampler

## 1. Define ImageBind model

In [None]:
class Binder:
    """ Convenience wrapper that owns one ImageBind model and embeds inputs
        of a single modality per call.
    """
    def __init__(self, pth_path, device='cuda'):
        """ pth_path: checkpoint path forwarded to imagebind_huge_pth
            device: where the model is moved and where inputs are prepared
        """
        net = imagebind_model.imagebind_huge_pth(pretrained=True, pth_path=pth_path)
        net.eval()
        net.to(device)
        self.model = net
        self.device = device

        # modality key -> loader; every loader is invoked as loader(cpaths, device)
        self.data_process_dict = {
            ModalityType.TEXT: data.load_and_transform_text,
            ModalityType.VISION: data.load_and_transform_vision_data,
            ModalityType.AUDIO: data.load_and_transform_audio_data,
        }

    def run(self, ctype, cpaths, post_process=False):
        """ ctype: modality key, must be one of the keys of data_process_dict
            cpaths: list[str], handed unchanged to the loader for that modality
            post_process: forwarded to the model call as-is
            returns: the entry of the model output stored under ctype
        """
        loader = self.data_process_dict[ctype]
        batch = {ctype: loader(cpaths, self.device)}

        # inference only, no gradients are tracked
        with torch.no_grad():
            outputs = self.model(batch, post_process=post_process)

        return outputs[ctype]

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines where CUDA is not available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
binder = Binder(pth_path="ImageBind/.checkpoints/imagebind_huge.pth", device=device)


## 2. Define Diffusion model (AudioLDM)

In [None]:


# options: everything the cells below read, reachable as opt.<name>
opt = EasyDict({
    'config': 'stablediffusion/configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml',
    'device': device,
    'ckpt': 'stablediffusion/checkpoints/sd21-unclip-h.ckpt',
    'C': 4,
    'H': 768,
    'W': 768,
    'f': 8,
    'steps': 50,
    'n_samples': 1,
    'scale': 20,
    'ddim_eta': 0,
})

config = OmegaConf.load(opt.config)
# [C, H // f, W // f]
shape = [opt.C] + [side // opt.f for side in (opt.H, opt.W)]
batch_size = opt.n_samples

# prepare diffusion model for audio
# if xtype == 'audio':
#             x = net.audioldm_decode(z)
#             x = self.mel_spectrogram_to_waveform(x)
#             return x
    # def mel_spectrogram_to_waveform(self, mel):
    #     # Mel: [bs, 1, t-steps, fbins]
    #     if len(mel.size()) == 4:
    #         mel = mel.squeeze(1)
    #     mel = mel.permute(0, 2, 1)
    #     waveform = self.net.audioldm.vocoder(mel)
    #     waveform = waveform.cpu().detach().numpy()
    #     return waveform

from audioldm import LatentDiffusion

# No normalization here
# The constructor arguments come from the `model.params` section of the config.
model = LatentDiffusion(**config.model.params)

# Deserialize the checkpoint on the CPU first; strict=False lets
# load_state_dict proceed even when the key sets do not match exactly.
checkpoint = torch.load(opt.ckpt, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"], strict=False)

# Move to the target device, then switch to inference mode.
model.to(opt.device)
model.eval()

setattr(model.cond_stage_model, "embed_mode", "text")

# -----
# model = ImageEmbeddingConditionedLatentDiffusion(**config.model['params'])
# pl_sd = torch.load(opt.ckpt, map_location="cpu")
# sd = pl_sd["state_dict"]
# model.load_state_dict(sd, strict=False)
# model.to(opt.device)
# model.eval()

# sampler = DDIMSampler(model, device=opt.device)

In [None]:
def load_img(path):
    """ Load an image file as a float tensor ready for the model.

        The image is converted to RGB, each side is rounded down to the
        nearest multiple of 64, and the image is resized (LANCZOS) to that size.

        path: image file path handed to Image.open
        returns: float32 tensor of shape (1, 3, h, w) with values in [-1, 1]
        raises: ValueError when either side is shorter than 64 pixels
    """
    # Context manager so the underlying file is closed deterministically;
    # convert() returns an independent, fully loaded image.
    with Image.open(path) as src:
        image = src.convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from {path}")
    w, h = (side - side % 64 for side in (w, h))  # round down to integer multiple of 64
    if w == 0 or h == 0:
        # Rounding a side below 64 px down gives 0, which resize() cannot handle.
        raise ValueError(
            f"image {path} is too small: both sides must be at least 64 pixels"
        )
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0   # (h, w, 3) in [0, 1]
    image = image[None].transpose(0, 3, 1, 2)            # -> (1, 3, h, w)
    image = torch.from_numpy(image)
    return 2. * image - 1.                               # [0, 1] -> [-1, 1]

def load_audio(path):
    """ Placeholder for an audio loader: not implemented yet.
        `path` is accepted but ignored, and None is returned.
    """
    # TODO: implement audio loading
    return None

## 3. Image-conditioned audio generation

In [None]:
# make_batch_for_text_to_audio is needed below and is not imported anywhere
# else in this notebook.
from audioldm.pipeline import make_batch_for_text_to_audio

prompts = ['colorful, DSLR quality, clear, vivid'] * batch_size    # you may add extra descriptions you like here

# The VISION modality must be given an image path (not the .wav clip).
c_adm = binder.run(ctype=ModalityType.VISION, cpaths=['ImageBind/.assets/bird_image.jpg'], post_process=False)
# Alternative: condition on an audio clip instead of an image.
# c_adm = binder.run(ctype='audio', cpaths=['ImageBind/.assets/bird_audio.wav'], post_process=False)
c_adm = c_adm / c_adm.norm() * 20   # a norm of 20 typically gives better results
c_adm = torch.cat([c_adm] * batch_size, dim=0)

with torch.no_grad(), torch.autocast('cuda'):

    # NOTE(review): `model` is built from audioldm's LatentDiffusion in the
    # cell above -- confirm it exposes `noise_augmentor` and accepts this
    # unCLIP-style conditioning before relying on the next lines.
    c_adm, noise_level_emb = model.noise_augmentor(c_adm, noise_level=torch.zeros(batch_size).long().to(c_adm.device))
    c_adm = torch.cat((c_adm, noise_level_emb), 1)

    uc = model.get_learned_conditioning(batch_size * ["text, watermark, blurry, number"])    # negative prompts
    uc = {"c_crossattn": [uc], "c_adm": torch.zeros_like(c_adm)}
    c = {"c_crossattn": [model.get_learned_conditioning(prompts)], "c_adm": c_adm}
    # NOTE(review): `c` and `uc` are built here but are not passed to
    # generate_sample below; only the commented-out sampler call used them.

    # samples, _ = sampler.sample(S=opt.steps,
    #                             conditioning=c,
    #                             batch_size=batch_size,
    #                             shape=shape,
    #                             verbose=False,
    #                             unconditional_guidance_scale=opt.scale,
    #                             unconditional_conditioning=uc,
    #                             eta=opt.ddim_eta,
    #                             x_T=None)
    # --- generate waveform ---
    # another usage is transfer_style
    ddim_steps=200
    duration=10
    batchsize=1
    guidance_scale=2.5
    n_candidate_gen_per_text=3
    waveform = None
    # `text` had no definition in this notebook; reuse the prompt from above.
    # NOTE(review): confirm this is the text the audio batch should carry.
    text = prompts[0]
    batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize)

    waveform = model.generate_sample(
            [batch],
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            duration=duration,
        )

# x_samples = model.decode_first_stage(samples)
# x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
# plt.imshow(x_samples[0].permute(1,2,0).cpu().numpy())

In [None]:
def text_to_audio(
    audioldm,
    text,
    original_audio_file_path = None,
    seed=42,
    ddim_steps=200,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    config=None,
):
    """ Generate audio with an AudioLDM model, conditioned on text or on an
        existing audio file.

        audioldm: model object; must provide generate_sample and accept a
            latent_t_size attribute
        text: prompt handed to make_batch_for_text_to_audio
        original_audio_file_path: optional wav file; when given, the model is
            switched to audio conditioning, otherwise to text conditioning
        seed: converted with int() and passed to seed_everything
        ddim_steps, guidance_scale, n_candidate_gen_per_text, duration:
            forwarded to audioldm.generate_sample
        batchsize: forwarded to make_batch_for_text_to_audio
        config: unused; kept so existing callers keep working
        returns: whatever audioldm.generate_sample returns
    """
    # These helpers are not imported at the top of the notebook, so bring
    # them into scope here.
    from audioldm import seed_everything
    from audioldm.audio import read_wav_file
    from audioldm.pipeline import (
        duration_to_latent_t_size,
        make_batch_for_text_to_audio,
        set_cond_audio,
        set_cond_text,
    )

    seed_everything(int(seed))
    waveform = None
    if original_audio_file_path is not None:
        # second argument is the target length handed to read_wav_file
        waveform = read_wav_file(original_audio_file_path, int(duration * 102.4) * 160)

    batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize)

    audioldm.latent_t_size = duration_to_latent_t_size(duration)

    if waveform is not None:
        print(f"Generate audio that has similar content as {original_audio_file_path}")
        audioldm = set_cond_audio(audioldm)
    else:
        # f-string instead of "%s" % text, which breaks when text is a tuple
        print(f"Generate audio using text {text}")
        audioldm = set_cond_text(audioldm)

    with torch.no_grad():
        waveform = audioldm.generate_sample(
            [batch],
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            duration=duration,
        )
    return waveform

# Build the AudioLDM model and generate a waveform from it.
# NOTE(review): confirm that `build_model`, `args`, `text` and `random_seed`
# are defined before this cell runs -- none is imported at the top of the notebook.
audioldm = build_model(model_name=args.model_name)
waveform = text_to_audio(
    audioldm,
    text=text,
    original_audio_file_path=args.file_path,
    seed=random_seed,
    duration=duration,
    guidance_scale=guidance_scale,
    ddim_steps=args.ddim_steps,
    n_candidate_gen_per_text=n_candidate_gen_per_text,
    batchsize=args.batchsize,
)