# All the imports
- This cell holds all the imports.

In [None]:
import os
import argparse
import warnings
import pickle

import torch
import torch.nn.functional as F
import numpy as np
import soundfile as sf
from tqdm.auto import tqdm
from diffusers import DDIMScheduler
from pytorch_lightning import seed_everything
from torch.optim import Adam

from typing import Optional, Union, Tuple, List, Callable, Dict

from Zstar.d_utils import ZstarAudioPipeline

# import ptp_utils

from helpers import load_audio_to_numpy, load_audio, get_audio_files, load_and_view_audio

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)


# CONSTANTS

-   Constant declarations.
-   This is also where I replace the scheduler with DDIM.

In [None]:
# Constants
# Sample rate (Hz) used to compute clip durations and to write reconstructions.
TARGET_SR = 44100
CLIP_DURATION_SECONDS = 5.0
# Number of samples in one clip at TARGET_SR.
TARGET_SAMPLES = int(TARGET_SR * CLIP_DURATION_SECONDS)
# Number of DDIM steps used for both the content and the style inversion below.
TOTAL_STEPS = 100
# Classifier-free guidance weight used by NullInversion.null_optimization.
GUIDANCE_SCALE = 7.5
# Local directory the scheduler config and the pipeline weights are loaded from.
AUDIO_MODEL_PATH = "./stable-audio-open-1.0"
SEED = 9999
# The three settings below are currently disabled and unused in this section.
# START_STEP = 5
# END_STEP = 15
# LAYER_INDEX = [14, 16, 18, 20, 22, 24]

# Setup
# Seed first so that everything created afterwards is reproducible.
seed_everything(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Scheduler
# Build a DDIM scheduler from the model's scheduler config, overriding the
# prediction type, sample clipping and final-alpha behaviour.
inversion_scheduler = DDIMScheduler.from_pretrained(
    AUDIO_MODEL_PATH,
    subfolder="scheduler",
    prediction_type="v_prediction",
    clip_sample=False,
    set_alpha_to_one=False,
)

# Load model
# The same DDIM scheduler object is passed in as the pipeline's scheduler.
model = ZstarAudioPipeline.from_pretrained(
    AUDIO_MODEL_PATH, scheduler=inversion_scheduler
).to(device)
# Replace inversion scheduler
model.inversion_scheduler = inversion_scheduler
# Keep the scheduler's lookup tensors on the same device as the model so they
# can be indexed with on-device timesteps.
model.inversion_scheduler.timesteps = model.inversion_scheduler.timesteps.to(device)
model.inversion_scheduler.alphas_cumprod = model.inversion_scheduler.alphas_cumprod.to(device)


# Null Inversion Class 

The content audio is passed to the class below. Its `invert` method first calls `ddim_inversion` and then `null_optimization`; the optimised null embeddings are used for null prompting later.

This was taken from the Zero shot paper implementation and adapted to audio.

In [None]:
class NullInversion:
    """DDIM inversion plus null-text optimisation for one audio clip.

    Typical use is ``NullInversion(model, num_ddim_steps).invert(path, prompt)``:
    the clip is encoded to a latent, pushed along a deterministic DDIM
    trajectory towards noise, and then one unconditional ("null") embedding is
    optimised per timestep so that guided denoising retraces that trajectory.

    NOTE(review): the private scheduler is loaded without the
    ``prediction_type`` / ``set_alpha_to_one`` overrides used for the
    module-level ``inversion_scheduler``, and ``prev_step`` / ``next_step``
    treat the transformer output as a noise (epsilon) prediction. Confirm this
    matches the model's actual prediction type.
    """

    def __init__(self, model, num_ddim_steps=50):
        """Store the pipeline and build a private DDIM scheduler.

        Args:
            model: Pipeline object; this class uses its ``tokenizer``,
                ``text_encoder``, ``vae``, ``transformer``, ``encode_duration``
                and ``device`` attributes.
            num_ddim_steps: Number of DDIM steps for inversion and optimisation.
        """
        self.model = model
        self.tokenizer = model.tokenizer
        self.prompt = None
        self.context = None
        # Filled in by ddim_inversion(); defined here so the attribute always exists.
        self.global_states = None
        self._ddim_scheduler = DDIMScheduler.from_pretrained(AUDIO_MODEL_PATH, subfolder="scheduler")
        self._ddim_scheduler.set_timesteps(num_ddim_steps)
        # Keep the lookup tensors on the model's device so that on-device
        # timesteps can index them directly.
        self._ddim_scheduler.timesteps = self._ddim_scheduler.timesteps.to(self.model.device)
        self._ddim_scheduler.alphas_cumprod = self._ddim_scheduler.alphas_cumprod.to(self.model.device)

    def _stride(self):
        """Return the distance between consecutive inference timesteps."""
        sched = self._ddim_scheduler
        return sched.config.num_train_timesteps // sched.num_inference_steps

    def prev_step(self, noise_pred: torch.FloatTensor, timestep: int, latents: torch.FloatTensor):
        """Deterministic DDIM denoising step from ``timestep`` to ``timestep - stride``.

        Args:
            noise_pred: Model output at ``timestep``, used as the noise estimate.
            timestep: Timestep the incoming ``latents`` live at.
            latents: Latent at ``timestep``.

        Returns:
            Tuple ``(z_prev, pred_x0)``: the latent one stride earlier and the
            clean-sample estimate it was built from.
        """
        sched = self._ddim_scheduler
        prev_t = timestep - self._stride()
        alpha_t = sched.alphas_cumprod[timestep]
        # Below the first timestep the scheduler's clean-end alpha applies.
        alpha_prev = sched.alphas_cumprod[prev_t] if prev_t >= 0 else sched.final_alpha_cumprod
        beta_t = 1 - alpha_t
        pred_x0 = (latents - beta_t**0.5 * noise_pred) / alpha_t**0.5
        dir_prev = (1 - alpha_prev)**0.5 * noise_pred
        z_prev = alpha_prev**0.5 * pred_x0 + dir_prev
        return z_prev, pred_x0

    def next_step(self, noise_pred: torch.FloatTensor, t: int, latents: torch.FloatTensor):
        """Deterministic DDIM inversion step from ``t - stride`` to ``t``.

        This is the exact inverse of ``prev_step`` for the same ``t`` and
        ``noise_pred``. The incoming latent is taken to live one stride *below*
        ``t`` (the clean sample for the first step), so the trajectory ends at
        the scheduler's largest timestep and never indexes past the end of
        ``alphas_cumprod``.

        Stepping from ``t`` to ``t + stride`` instead would force the last step
        to fall back to ``final_alpha_cumprod``, which is the clean-end alpha,
        and would collapse the noisiest latent back towards the clean estimate.

        Args:
            noise_pred: Model output evaluated at ``t``, used as the noise estimate.
            t: Timestep being stepped *to*.
            latents: Latent at ``t - stride``.

        Returns:
            The latent at timestep ``t``.
        """
        sched = self._ddim_scheduler
        cur_t = t - self._stride()
        alpha_cur = sched.alphas_cumprod[cur_t] if cur_t >= 0 else sched.final_alpha_cumprod
        alpha_next = sched.alphas_cumprod[t]
        pred_x0 = (latents - (1 - alpha_cur)**0.5 * noise_pred) / alpha_cur**0.5
        dir_next = (1 - alpha_next)**0.5 * noise_pred
        return alpha_next**0.5 * pred_x0 + dir_next

    def get_noise_pred_single(self, latents: torch.Tensor, t: Union[int, torch.Tensor], context: torch.Tensor, global_states: Optional[torch.Tensor] = None):
        """Run the transformer once and return its first output.

        Args:
            latents: Latent passed to the transformer.
            t: Timestep, either a Python number or a tensor. Scalars are turned
                into a one-element tensor on the model's device.
            context: Text embedding passed as ``encoder_hidden_states``.
            global_states: Optional global conditioning; defaults to the value
                cached on the instance by ``ddim_inversion``.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor([t], device=self.model.device, dtype=torch.long)
        elif t.dim() == 0:
            t = t.unsqueeze(0).to(self.model.device)
        gs = global_states if global_states is not None else getattr(self, "global_states", None)
        noise_out = self.model.transformer(
            latents,
            t,
            encoder_hidden_states=context,
            global_hidden_states=gs,
            return_dict=False,
        )[0]
        return noise_out

    @torch.no_grad()
    def audio2latent(self, audio: Union[torch.Tensor, np.ndarray]):
        """Encode audio with the VAE and return the posterior mean as float32.

        NumPy input gets a leading dimension and is moved to the model's
        device; 2-D input gets an extra dimension inserted at position 1.
        """
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio).unsqueeze(0).to(self.model.device, dtype=torch.float32)
        if audio.dim() == 2:
            audio = audio.unsqueeze(1)

        # Use the distribution mean (not a sample) so the encoding is deterministic.
        vae_output = self.model.vae.encode(audio)
        latents = vae_output.latent_dist.mean.float()
        return latents

    @torch.no_grad()
    def latent2audio(self, latents: torch.Tensor, return_type: str = "pt"):
        """Decode latents to a waveform clamped to ``[-1, 1]``.

        Args:
            latents: Latents scaled by 0.18215, as produced in ``ddim_inversion``.
            return_type: ``"pt"`` returns the decoded tensor; any other value
                returns element ``[0, 0]`` of it as a NumPy array.
        """
        # Undo the 0.18215 scaling applied to the encoder output in ddim_inversion.
        wav = self.model.vae.decode(latents / 0.18215)["sample"].clamp(-1, 1)
        return wav if return_type == "pt" else wav[0, 0].cpu().numpy()

    @torch.no_grad()
    def init_prompt(self, prompt: str):
        """Encode the empty prompt and ``prompt`` and cache them in ``self.context``.

        ``self.context`` holds the unconditional embedding first and the
        conditional one second, concatenated along dim 0. Encoding runs without
        autograd because nothing backpropagates into the text encoder;
        ``null_optimization`` optimises a detached copy of the embedding.
        """
        uc = self.model.tokenizer([""], padding="max_length", max_length=self.model.tokenizer.model_max_length, return_tensors="pt")
        uc_emb = self.model.text_encoder(uc.input_ids.to(self.model.device))[0]
        tc = self.model.tokenizer([prompt], padding="max_length", max_length=self.model.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        cond_emb = self.model.text_encoder(tc.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uc_emb, cond_emb], dim=0)
        self.prompt = prompt

    @torch.no_grad()
    def ddim_inversion(self, audio: np.ndarray):
        """Invert a clip into a DDIM latent trajectory.

        Requires ``init_prompt`` to have been called. Also computes and caches
        ``self.global_states`` from the clip duration.

        Args:
            audio: Waveform array; its two axes are swapped before encoding and
                ``audio.shape[0] / TARGET_SR`` is used as the duration in
                seconds (i.e. presumably ``(samples, channels)`` - confirm
                against ``load_audio_to_numpy``).

        Returns:
            Tuple ``(audio_rec, traj)``: the VAE-only reconstruction of the
            clean latent, and the list of latents starting with the clean
            latent followed by one entry per scheduler timestep.
        """
        wav_t = torch.from_numpy(audio).permute(1, 0).unsqueeze(0).to(self.model.device, dtype=torch.float32)
        cond_emb = self.context.chunk(2, dim=0)[1]
        z0 = self.audio2latent(wav_t) * 0.18215

        duration_s = audio.shape[0] / TARGET_SR
        ss, es = self.model.encode_duration(0.0, duration_s, self.model.device, do_classifier_free_guidance=False, batch_size=1)
        global_states_unprojected = torch.cat([ss, es], dim=-1)
        self.global_states = self.model.transformer.global_proj(global_states_unprojected)

        traj = [z0]
        z = z0.clone()
        # Walk the timesteps in ascending order; next_step lands the latent on each one.
        for t in tqdm(reversed(self._ddim_scheduler.timesteps), desc="DDIM Inversion"):
            noise_pred = self.get_noise_pred_single(z, t, cond_emb, self.global_states)
            z = self.next_step(noise_pred, t, z)
            traj.append(z)
        return self.latent2audio(z0, return_type="pt"), traj

    def null_optimization(self, latents: list[torch.Tensor], num_inner_steps: int, epsilon: float):
        """Optimise one unconditional embedding per timestep.

        Starting from the last latent of the trajectory, each timestep
        optimises a fresh copy of the unconditional embedding so that one
        guided ``prev_step`` matches the next-cleaner latent of the trajectory.

        Args:
            latents: Trajectory returned by ``ddim_inversion`` (clean latent first).
            num_inner_steps: Maximum number of Adam steps per timestep.
            epsilon: Loss threshold below which the inner loop stops early.

        Returns:
            List of detached embeddings, one per scheduler timestep, in the
            scheduler's timestep order.
        """
        print("null optimisation")
        uncond_emb, cond_emb = self.context.chunk(2, dim=0)
        optimized = []
        z_cur = latents[-1].clone()
        for i, t in enumerate(tqdm(self._ddim_scheduler.timesteps, desc="Null-text Optimization")):
            u = uncond_emb.clone().detach().requires_grad_(True)
            # Learning rate decays linearly with the step index.
            optimizer = Adam([u], lr=1e-2 * (1 - i / len(latents)))
            z_prev_target = latents[len(latents) - i - 2]

            # The conditional branch does not depend on u, so compute it once.
            with torch.no_grad():
                noise_cond = self.get_noise_pred_single(z_cur, t, cond_emb, self.global_states)
            for _ in range(num_inner_steps):
                noise_uncond = self.get_noise_pred_single(z_cur, t, u, self.global_states)
                noise = noise_uncond + GUIDANCE_SCALE * (noise_cond - noise_uncond)
                z_prev_est, _ = self.prev_step(noise, t, z_cur)
                loss = F.mse_loss(z_prev_est, z_prev_target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if loss.item() < epsilon:
                    break

            optimized.append(u.detach())

            # Advance with the optimised embedding so later steps see the
            # latent that guided sampling would actually produce.
            with torch.no_grad():
                noise_uncond_final = self.get_noise_pred_single(z_cur, t, u, self.global_states)
                noise_final = noise_uncond_final + GUIDANCE_SCALE * (noise_cond - noise_uncond_final)
                z_cur, _ = self.prev_step(noise_final, t, z_cur)
        return optimized

    def invert(self, audio_path: str, prompt: str, num_inner_steps: int = 10, early_stop_epsilon: float = 1e-5, verbose: bool = False):
        """Run the full pipeline: prompt encoding, DDIM inversion, null-text optimisation.

        Args:
            audio_path: Path handed to ``load_audio_to_numpy``.
            prompt: Text prompt used for the conditional branch.
            num_inner_steps: Maximum optimisation steps per timestep.
            early_stop_epsilon: Early-stop loss threshold per timestep.
            verbose: Print a message before each stage.

        Returns:
            ``((wav_gt, audio_rec), traj, traj[-1], uncond_list)``: the loaded
            waveform and its VAE reconstruction, the latent trajectory, its
            final (noisiest) latent, and the optimised null embeddings.
        """
        self.init_prompt(prompt)

        # Not needed for reconstruction, but to be used later
        # ptp_utils.register_attention_control(self.model, None)

        if verbose:
            print("Running DDIM inversion…")
        wav_gt = load_audio_to_numpy(audio_path)
        audio_rec, traj = self.ddim_inversion(wav_gt)
        if verbose:
            print("Running null‐text optimization…")
        uncond_list = self.null_optimization(traj, num_inner_steps, early_stop_epsilon)
        return (wav_gt, audio_rec), traj, traj[-1], uncond_list

In [None]:
# Inverter used for the content clips; it runs TOTAL_STEPS DDIM steps.
null_inversion = NullInversion(model, num_ddim_steps=TOTAL_STEPS)

# Main Loop

- This is where the reconstruction happens; I compare it with the ground truth.

- This is also where the d_utils class is used; for reconstructions we only use its `invert` method and apply it to the style data.

- The Null Inversion class is used solely for content, and the model's `invert` for style.

In [None]:
# Reconstruct every content clip (cached per clip in a .pkl next to it) and,
# for each of them, every style clip, showing ground truth and reconstruction.
content_files = get_audio_files("content")
style_files = get_audio_files("style")

for cpath in content_files:
    print(f"Processing content: {os.path.basename(cpath)}")
    # splitext only strips a real file extension, so a dot in a directory
    # name cannot truncate the cache path.
    pkl = os.path.splitext(cpath)[0] + ".pkl"

    if os.path.isfile(pkl):
        # Reuse the cached inversion for this clip.
        # NOTE: unpickling executes arbitrary code; only load caches written
        # by this notebook.
        with open(pkl, "rb") as cache_file:
            cont_latents, zT_c, uncond_embeds, (wav_gt, audio_rec_c) = pickle.load(cache_file)
    else:
        (wav_gt, audio_rec_c), cont_latents, zT_c, uncond_embeds = null_inversion.invert(
            cpath, prompt="", verbose=True
        )
        with open(pkl, "wb") as cache_file:
            pickle.dump((cont_latents, zT_c, uncond_embeds, (wav_gt, audio_rec_c)), cache_file)

    # Always rewrite the reconstruction so the file shown below belongs to
    # this clip, including when the inversion came from the cache.
    sf.write('content_recon.wav', audio_rec_c.squeeze().cpu().numpy().T, TARGET_SR)
    print("Saved content reconstruction: content_recon.wav")

    print("Ground Truth")
    load_and_view_audio(cpath)
    print("\n# ********************\n")
    print("Reconstruction")
    load_and_view_audio("content_recon.wav")

    for style_path in style_files:
        print(f"Processing style: {os.path.basename(style_path)}")

        zT_s, style_latents, z0_s = model.invert(
            load_audio(style_path),
            prompt="",
            num_inference_steps=TOTAL_STEPS,
            guidance_scale=7.0,
            return_intermediates=True,
        )

        wav_s_recon = model.latent2audio(z0_s)

        # Style reconstruction first, then the style clip it was made from.
        sf.write("style_recon.wav", wav_s_recon.squeeze().cpu().numpy().T, TARGET_SR)
        load_and_view_audio("style_recon.wav")

        print("\n# ********************\n")

        load_and_view_audio(style_path)