In [None]:
import sys, os

# sys.path.insert(0, os.path.abspath("../local_diffusers/src"))

from typing import Callable, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
import soundfile as sf

# from diffusers.schedulers.scheduling_ddim import DDIMScheduler
# from diffusers.pipelines.stable_audio.pipeline_stable_audio import StableAudioPipeline

from diffusers import StableAudioPipeline, DDIMScheduler
from tqdm import tqdm
import matplotlib.pyplot as plt


class ZstarAudioPipeline(StableAudioPipeline):
    """Stable Audio pipeline with a hand-written DDIM sampler.

    The sampler treats the transformer output as a v-prediction, supports
    per-step unconditional embeddings (as produced by null-text optimisation)
    and can re-inject reference latent trajectories at every step.
    """

    # Factor applied to VAE latents in `audio2latent` and undone in
    # `latent2audio`.
    _VAE_SCALE_FACTOR = 0.3704

    def __init__(
        self, vae, text_encoder, projection_model, tokenizer, transformer, scheduler
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            projection_model=projection_model,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )

    def step(
        self,
        model_output: torch.Tensor,  # this is v_t
        timestep: int,
        sample: torch.Tensor,        # this is x_t
    ):
        """Take one deterministic DDIM step from ``timestep`` towards t = 0.

        Args:
            model_output: v-prediction of the transformer at ``timestep``.
            timestep: current train timestep (int or 0-dim tensor).
            sample: current noisy latent x_t.

        Returns:
            ``(prev_sample, x0)``: the latent at the previous timestep and the
            clean-latent estimate derived from the v-prediction.

        Raises:
            ValueError: if ``scheduler.set_timesteps`` has not been called.
        """
        scheduler = self.scheduler
        num_steps = scheduler.num_inference_steps
        if num_steps is None:
            raise ValueError(
                "scheduler.set_timesteps() must be called before stepping."
            )

        # Stride between consecutive inference timesteps, taken from the
        # scheduler so it always agrees with the timesteps being iterated.
        dt = scheduler.config.num_train_timesteps // num_steps
        t = int(timestep)
        # Clamped at 0, so the index below is always valid.
        t_prev = max(t - dt, 0)

        alpha_t = scheduler.alphas_cumprod[t]
        alpha_t_prev = scheduler.alphas_cumprod[t_prev]
        beta_t = 1 - alpha_t

        # Recover x0 and eps from the v-prediction.
        x0 = alpha_t**0.5 * sample - beta_t**0.5 * model_output
        eps = alpha_t**0.5 * model_output + beta_t**0.5 * sample

        # Re-noise x0 to the previous timestep's noise level.
        prev_sample = alpha_t_prev**0.5 * x0 + (1 - alpha_t_prev)**0.5 * eps

        return prev_sample, x0

    @torch.no_grad()
    def latent2audio(self, latents: torch.FloatTensor, return_type: str = "np"):
        """
        Decode a latent back into a waveform.
        – latents: (B, C_latent, T_latent)
        – returns either a torch.Tensor (B, C_audio, T_audio)
          or a numpy array (T_audio, C_audio) for the first batch item.
        """
        # Undo the scaling applied by `audio2latent`, then decode.
        latents = latents.detach() / self._VAE_SCALE_FACTOR
        audio = self.vae.decode(latents)["sample"]
        if return_type == "np":
            # Clamp into [-1, 1], move the channel axis last, keep batch item 0.
            audio = audio.clamp(-1, 1)
            return audio.cpu().permute(0, 2, 1).numpy()[0]
        return audio

    @torch.no_grad()
    def audio2latent(self, audio: Union[np.ndarray, torch.Tensor]):
        """
        Encode a waveform into VAE latents.
        – audio: either a np.ndarray (T, C) or torch.Tensor (C, T) / (B, C, T)
        – returns a torch.FloatTensor of latents (B, C_latent, T_latent)
        """
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
            if audio.dim() == 2:
                # (T, C) -> (C, T), the layout used for tensors.
                audio = audio.transpose(0, 1)
        if audio.dim() == 2:
            audio = audio.unsqueeze(0)    # -> (1, C, T)
        audio = audio.to(self.device).float()
        # Use the posterior mean so encoding is deterministic.
        latents = self.vae.encode(audio)["latent_dist"].mean
        return latents * self._VAE_SCALE_FACTOR

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        batch_size=1,
        num_inference_steps=100,
        guidance_scale=7.5,
        eta=0.0,
        latents=None,
        unconditioning=None,
        uncond_embeddings=None,
        neg_prompt=None,
        ref_intermediate_latents=None,
        return_intermediates=False,
        **kwds,
    ):
        """Run the DDIM sampler and decode the final latents to audio.

        Args:
            prompt: a string or a list of strings; a list fixes the batch size.
            batch_size: batch size used when ``prompt`` is a single string.
            num_inference_steps: number of DDIM steps.
            guidance_scale: classifier-free guidance weight; guidance is only
                applied when it is greater than 1.0.
            eta: unused, kept for interface compatibility.
            latents: starting latents (B, C, T). When ``None``, Gaussian noise
                of length ``transformer.config.sample_size`` is drawn.
            unconditioning: optional list of per-step tensors that replace the
                first half of the context.
            uncond_embeddings: optional list of per-step unconditional
                embeddings (e.g. from null-text optimisation).
            neg_prompt: optional negative prompt used for the unconditional
                branch when ``uncond_embeddings`` is not given.
            ref_intermediate_latents: optional ``[content_traj, style_traj]``
                latent trajectories, indexed from the end at every step; the
                batch is expected to hold a style row and a content row.
            return_intermediates: also return decoded x0 estimates and decoded
                intermediate latents.
            **kwds: ``dir`` (float) shifts the last text embedding along the
                principal direction between the last two embeddings.

        Returns:
            The decoded audio tensor, or ``(audio, pred_x0_list, latents_list)``
            when ``return_intermediates`` is true.
        """
        device = self.device
        do_cfg = guidance_scale > 1.0

        if isinstance(prompt, list):
            batch_size = len(prompt)
        elif isinstance(prompt, str) and batch_size > 1:
            prompt = [prompt] * batch_size

        # text embeddings, we pass 64 manually because we changed it in the
        # config from 128 to 64
        text_input = self.tokenizer(
            prompt, padding="max_length", max_length=64, return_tensors="pt"
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]

        direction_scale = kwds.get("dir")
        if direction_scale:
            direction = text_embeddings[-2] - text_embeddings[-1]
            _, _, v = torch.pca_lowrank(
                direction.transpose(-1, -2), q=1, center=True)
            text_embeddings[-1] = text_embeddings[-1] + direction_scale * v

        # define initial latents
        num_channels = self.transformer.config.in_channels
        if latents is None:
            latents_shape = (
                batch_size,
                num_channels,
                int(self.transformer.config.sample_size),
            )
            latents = torch.randn(latents_shape, device=device)
        else:
            latents_shape = (batch_size, num_channels, latents.shape[2])
            assert (
                latents.shape == latents_shape
            ), f"The shape of input latent tensor {latents.shape} should equal to predefined one."

        # Unconditional embedding for classifier-free guidance. It is only
        # built here when no per-step embeddings were supplied.
        static_uncond = None
        if do_cfg and uncond_embeddings is None:
            if not neg_prompt:
                uncond_texts = [""] * batch_size
            elif isinstance(neg_prompt, str):
                uncond_texts = [neg_prompt] * batch_size
            else:
                uncond_texts = list(neg_prompt)
            uncond_input = self.tokenizer(
                uncond_texts,
                padding="max_length",
                max_length=text_input.input_ids.shape[-1],
                return_tensors="pt",
            )
            static_uncond = self.text_encoder(
                uncond_input.input_ids.to(device)
            )[0]

        # iterative sampling
        self.scheduler.set_timesteps(num_inference_steps)
        latents_list = [latents]
        pred_x0_list = [latents]

        # The duration conditioning does not change between steps, so it is
        # computed once.
        # NOTE(review): the duration divides the *latent* length by the audio
        # sampling rate, whereas inversion elsewhere divides the waveform
        # length; confirm which one is intended.
        duration_s = latents.shape[-1] / self.vae.sampling_rate
        ss, es = self.encode_duration(
            0.0, duration_s, device, do_classifier_free_guidance=False, batch_size=1
        )
        base_global_states = self.transformer.global_proj(
            torch.cat([ss, es], dim=-1)
        )

        for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")):
            if not do_cfg:
                context = text_embeddings
            elif static_uncond is None:
                context = torch.cat(
                    [
                        uncond_embeddings[i].expand(*text_embeddings.shape),
                        text_embeddings,
                    ]
                )
            else:
                context = torch.cat([static_uncond, text_embeddings])

            if ref_intermediate_latents is not None:
                # note that the batch_size >= 2
                style_latents_ref = ref_intermediate_latents[1][-1 - i]
                _, content_latents_cur = latents.chunk(2)
                # Pull the content row slightly towards its inversion trajectory.
                content_latents_cur = (
                    ref_intermediate_latents[0][-1 - i] * 0.01
                    + content_latents_cur * 0.99
                )
                latents = torch.cat([style_latents_ref, content_latents_cur])

            model_inputs = torch.cat([latents] * 2) if do_cfg else latents

            if unconditioning is not None and isinstance(unconditioning, list):
                _, context = context.chunk(2)
                context = torch.cat(
                    [unconditioning[i].expand(*context.shape), context])

            # One row of timestep / global conditioning per model input row.
            timestep = t.expand(model_inputs.shape[0]).to(device)
            self.global_states = base_global_states.repeat(
                model_inputs.shape[0], 1, 1
            )

            # predict the noise (v-prediction)
            noise_pred = self.transformer(
                model_inputs,
                timestep,
                encoder_hidden_states=context,
                global_hidden_states=self.global_states,
            )[0]

            if do_cfg:
                noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
                noise_pred = noise_pred_uncon + guidance_scale * (
                    noise_pred_con - noise_pred_uncon
                )

            # compute the previous noise sample x_t -> x_t-1
            latents, pred_x0 = self.step(noise_pred, t, latents)
            latents_list.append(latents)
            pred_x0_list.append(pred_x0)

        audio = self.latent2audio(latents, return_type="pt")
        if return_intermediates:
            pred_x0_list = [
                self.latent2audio(aud, return_type="pt") for aud in pred_x0_list
            ]
            latents_list = [
                self.latent2audio(aud, return_type="pt") for aud in latents_list
            ]
            return audio, pred_x0_list, latents_list
        return audio
        
    


In [None]:
import os, sys

# sys.path.insert(0, "local_diffusers/src")

import argparse
import warnings
import pickle

import torch
import torch.nn.functional as F
import numpy as np
import soundfile as sf
from tqdm.auto import tqdm
# from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers import DDIMScheduler, DDIMInverseScheduler
# from diffusers.schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
from pytorch_lightning import seed_everything
from torch.optim import Adam

from IPython.display import display, Markdown

from typing import Optional, Union, Tuple, List, Callable, Dict

# from Zstar_test.d_utils import ZstarAudioPipeline
# from model import ZstarAudioPipeline
from Zstar.Zstar_utils import AttentionBase, register_attention_editor_audio, AttentionStore
from Zstar.Zstar import ReweightCrossAttentionControl, ReweightAndStoreAttentionControl
# from attention import ReweightCrossAttentionControl
# from attention_register import register_attention_editor_diffusers

import ptp_utils

from helpers import load_audio_to_numpy, load_audio, get_audio_files, load_and_view_audio

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)


In [14]:
# ---- Audio and sampling configuration ----
TARGET_SR = 44100
CLIP_DURATION_SECONDS = 5.0
TARGET_SAMPLES = int(CLIP_DURATION_SECONDS * TARGET_SR)
TOTAL_STEPS = 100
NUM_DDIM_STEPS = 100
GUIDANCE_SCALE = 1.1
AUDIO_MODEL_PATH = "./stable-audio-open-1.0"
SEED = 9999
START_STEP = 15
END_STEP = 50
LAYER_INDEX = list(range(10, 22))

# ---- Reproducibility and device selection ----
seed_everything(SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- DDIM scheduler configured for v-prediction ----
scheduler = DDIMScheduler.from_pretrained(
    AUDIO_MODEL_PATH,
    subfolder="scheduler",
    prediction_type="v_prediction",
    clip_sample=False,
    set_alpha_to_one=False,
)

# ---- Pipeline ----
model = ZstarAudioPipeline.from_pretrained(AUDIO_MODEL_PATH, scheduler=scheduler)
model = model.to(device)

# Prepare the timestep table and keep the scheduler tensors on the model's
# device so they can be indexed alongside the latents.
model.scheduler.set_timesteps(TOTAL_STEPS)
model.scheduler.timesteps = model.scheduler.timesteps.to(device)
model.scheduler.alphas_cumprod = model.scheduler.alphas_cumprod.to(device)

# Run everything in full precision.
model.to(torch.float32)

Seed set to 9999
Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 21.22it/s]


ZstarAudioPipeline {
  "_class_name": "ZstarAudioPipeline",
  "_diffusers_version": "0.35.0.dev0",
  "_name_or_path": "./stable-audio-open-1.0",
  "projection_model": [
    "stable_audio",
    "StableAudioProjectionModel"
  ],
  "scheduler": [
    "diffusers",
    "DDIMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "T5EncoderModel"
  ],
  "tokenizer": [
    "transformers",
    "T5TokenizerFast"
  ],
  "transformer": [
    "diffusers",
    "StableAudioDiTModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderOobleck"
  ]
}

In [15]:
import torch
import numpy as np
from tqdm.auto import tqdm
from torch.optim import Adam
import torch.nn.functional as F
from diffusers import DDIMScheduler, DDIMInverseScheduler

# Your helper functions
# from helpers import load_audio_to_numpy, load_audio
# import ptp_utils

class NullInversion:
    """DDIM inversion and null-text optimisation around an audio pipeline.

    The wrapped ``model`` must expose ``tokenizer``, ``text_encoder``, ``vae``,
    ``transformer``, ``scheduler``, ``encode_duration`` and ``device``. The
    transformer output is treated as a v-prediction throughout.
    """

    def __init__(self, model, guidance_scale=None):
        """
        Args:
            model: the pipeline to invert with.
            guidance_scale: classifier-free guidance weight used by the
                denoising direction of `get_noise_pred` and by
                `null_optimization`. ``None`` falls back to the notebook-level
                ``GUIDANCE_SCALE`` constant, read at call time.
        """
        self.model = model
        self.tokenizer = model.tokenizer
        self.prompt = None
        self.context = None
        self.global_states = None
        self.cond_embed = None
        self.device = model.device
        self.vae_scale_factor = 0.3704
        self._guidance_scale = guidance_scale

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _resolve_guidance_scale(self):
        """Return the explicit guidance scale, or the notebook constant."""
        if self._guidance_scale is None:
            return GUIDANCE_SCALE
        return self._guidance_scale

    def _timestep_stride(self):
        """Return the train-timestep distance between two inference steps."""
        scheduler = self.model.scheduler
        num_steps = scheduler.num_inference_steps
        if num_steps is None:
            raise ValueError(
                "scheduler.set_timesteps() must be called before stepping."
            )
        return scheduler.config.num_train_timesteps // num_steps

    def _ddim_transfer(self, model_output, sample, t_from, t_to):
        """Move ``sample`` from noise level ``t_from`` to ``t_to``.

        Returns ``(moved_sample, x0)`` where ``x0`` is the clean-latent
        estimate obtained from the v-prediction ``model_output``.
        """
        alphas = self.model.scheduler.alphas_cumprod
        alpha_from = alphas[t_from]
        alpha_to = alphas[t_to]
        beta_from = 1 - alpha_from

        # Recover x0 and eps from v.
        x0 = alpha_from**0.5 * sample - beta_from**0.5 * model_output
        eps = alpha_from**0.5 * model_output + beta_from**0.5 * sample

        moved = alpha_to**0.5 * x0 + (1 - alpha_to)**0.5 * eps
        return moved, x0

    def _duration_global_states(self, audio):
        """Project the clip duration of ``audio`` into global conditioning.

        The duration is the last dimension of ``audio`` divided by the VAE
        sampling rate. The result has batch size 1.
        NOTE(review): stock diffusers' StableAudioDiTModel.forward also
        applies ``global_proj`` to ``global_hidden_states``; confirm that
        projecting here as well is intended.
        """
        duration_s = audio.shape[-1] / self.model.vae.sampling_rate
        ss, es = self.model.encode_duration(
            0.0,
            duration_s,
            self.model.device,
            do_classifier_free_guidance=False,
            batch_size=1,
        )
        return self.model.transformer.global_proj(torch.cat([ss, es], dim=-1))

    def _global_states_for(self, batch):
        """Return the cached global states with ``batch`` rows.

        Assumes every row of ``self.global_states`` is identical, which holds
        for states produced by `_duration_global_states`.
        """
        states = self.global_states
        if states.shape[0] == batch:
            return states
        return states[:1].repeat(batch, 1, 1)

    # ------------------------------------------------------------------
    # Prompt handling and model evaluation
    # ------------------------------------------------------------------
    @torch.no_grad()
    def init_prompt(self, prompt: str):
        """Encode ``prompt`` and the empty prompt into ``self.context``.

        ``self.context`` is the concatenation ``[uncond, cond]`` along the
        batch axis.
        """
        tokenizer = self.model.tokenizer
        uncond_input = tokenizer(
            [""],
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        )
        uncond_embeddings = self.model.text_encoder(
            uncond_input.input_ids.to(self.model.device)
        )[0]
        text_input = tokenizer(
            [prompt],
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(
            text_input.input_ids.to(self.model.device)
        )[0]
        self.context = torch.cat([uncond_embeddings, text_embeddings])
        self.prompt = prompt

    def get_v_prediction(self, latents, t, context, global_states):
        """Return the transformer's v-prediction for ``latents`` at ``t``."""
        # One timestep entry per latent row, on the same device as the latents.
        timestep = t.expand(latents.shape[0]).to(latents.device)
        v_pred = self.model.transformer(
            latents,
            timestep,
            encoder_hidden_states=context,
            global_hidden_states=global_states,
            return_dict=False,
        )[0]
        return v_pred

    def get_noise_pred(self, latents, t, is_forward=True, context=None):
        """Run one guided step and return the updated latents.

        Args:
            latents: current latents.
            t: current timestep (0-dim tensor).
            is_forward: ``True`` steps towards more noise (inversion) with a
                guidance weight of 1; ``False`` steps towards less noise with
                the configured guidance scale.
            context: ``[uncond, cond]`` embeddings; defaults to
                ``self.context``.
        """
        latents_input = torch.cat([latents] * 2)
        # Build the doubled conditioning locally so repeated calls do not
        # keep growing ``self.global_states``.
        global_states = self._global_states_for(latents_input.shape[0])
        if context is None:
            context = self.context
        guidance_scale = 1 if is_forward else self._resolve_guidance_scale()
        timestep = t.expand(latents_input.shape[0]).to(latents_input.device)
        noise_pred = self.model.transformer(
            latents_input,
            timestep,
            encoder_hidden_states=context,
            global_hidden_states=global_states,
        )[0]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_prediction_text - noise_pred_uncond
        )
        if is_forward:
            return self.next_step(noise_pred, t, latents)
        return self.prev_step(noise_pred, t, latents)

    # ------------------------------------------------------------------
    # DDIM steps
    # ------------------------------------------------------------------
    def next_step(
        self,
        model_output: torch.Tensor,  # this is v_t
        timestep: int,
        sample: torch.Tensor,        # this is x_t
        return_orig: bool = False,
    ):
        """Take one DDIM inversion step (towards more noise).

        The target timestep is clamped to the last train timestep. Returns the
        next sample, or ``(next_sample, x0)`` when ``return_orig`` is true.
        """
        scheduler = self.model.scheduler
        max_step = scheduler.config.num_train_timesteps - 1

        t = int(timestep)
        t_next = min(t + self._timestep_stride(), max_step)

        next_sample, x0 = self._ddim_transfer(model_output, sample, t, t_next)
        if return_orig:
            return next_sample, x0
        return next_sample

    def prev_step(
        self,
        model_output: torch.Tensor,  # this is v_t
        timestep: int,
        sample: torch.Tensor,        # this is x_t
    ):
        """Take one DDIM denoising step (towards less noise).

        The target timestep is clamped at 0, so it always indexes
        ``alphas_cumprod``.
        """
        t = int(timestep)
        t_prev = max(t - self._timestep_stride(), 0)
        prev_sample, _ = self._ddim_transfer(model_output, sample, t, t_prev)
        return prev_sample

    # ------------------------------------------------------------------
    # VAE round trip
    # ------------------------------------------------------------------
    @torch.no_grad()
    def latent2audio(self, latents: torch.FloatTensor, return_type: str = "np"):
        """
        Decode a latent back into a waveform.
        – latents: (B, C_latent, T_latent)
        – returns either a torch.Tensor (B, C_audio, T_audio)
          or a numpy array (T_audio, C_audio) for the first batch item.
        """
        # Undo the scaling applied by `audio2latent`, then decode.
        latents = latents.detach() / self.vae_scale_factor
        audio = self.model.vae.decode(latents)["sample"]
        if return_type == "np":
            # Clamp into [-1, 1], move the channel axis last, keep batch item 0.
            audio = audio.clamp(-1, 1)
            return audio.cpu().permute(0, 2, 1).numpy()[0]
        return audio

    @torch.no_grad()
    def audio2latent(self, audio: Union[np.ndarray, torch.Tensor]):
        """
        Encode a waveform into VAE latents.
        – audio: either a np.ndarray (T, C) or torch.Tensor (C, T) / (B, C, T)
        – returns a torch.FloatTensor of latents (B, C_latent, T_latent)
        """
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
            if audio.dim() == 2:
                # (T, C) -> (C, T), the layout used for tensors.
                audio = audio.transpose(0, 1)
        if audio.dim() == 2:
            audio = audio.unsqueeze(0)    # -> (1, C, T)
        audio = audio.to(self.device).float()
        # Use the posterior mean so encoding is deterministic.
        latents = self.model.vae.encode(audio)["latent_dist"].mean
        return latents * self.vae_scale_factor

    # ------------------------------------------------------------------
    # Inversion
    # ------------------------------------------------------------------
    @torch.no_grad()
    def ddim_loop(self, latent, audio_tensor):
        """Invert ``latent`` along the scheduler's timesteps.

        Uses the conditional half of ``self.context`` only. Also refreshes
        ``self.global_states`` from the duration of ``audio_tensor``.

        Returns:
            The list of latents, starting with the input latent and followed
            by one entry per scheduler timestep.
        """
        _, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()

        # The transformer needs the clip length as global conditioning.
        self.global_states = self._duration_global_states(audio_tensor)

        # Walk the timestep table from the least to the most noisy step.
        for t in reversed(self.model.scheduler.timesteps):
            noise_pred = self.get_v_prediction(
                latent, t, cond_embeddings, self.global_states
            )
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent

    @torch.no_grad()
    def ddim_inversion(self, audio_tensor):
        """Encode ``audio_tensor`` and invert it.

        Returns:
            ``(audio_rec, ddim_latents)``: the VAE round-trip reconstruction
            and the latent trajectory from `ddim_loop`.
        """
        latent = self.audio2latent(audio_tensor)
        audio_rec = self.latent2audio(latent)
        ddim_latents = self.ddim_loop(latent, audio_tensor)
        return audio_rec, ddim_latents

    def null_optimization(self, latents, num_inner_steps, epsilon):
        """Optimise one unconditional embedding per scheduler timestep.

        For every timestep the unconditional embedding is tuned with Adam so
        that a guided denoising step from the current latent lands on the
        matching latent of the inversion trajectory ``latents``.

        Args:
            latents: trajectory returned by `ddim_loop` (clean latent first).
            num_inner_steps: maximum Adam steps per timestep.
            epsilon: base early-stopping threshold on the MSE loss; the
                threshold grows by 2e-5 per timestep index.

        Returns:
            A list with one detached unconditional embedding per timestep.
        """
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        guidance_scale = self._resolve_guidance_scale()
        timesteps = self.model.scheduler.timesteps
        uncond_embeddings_list = []
        latent_cur = latents[-1]
        for i in tqdm(range(len(timesteps))):
            uncond_embeddings = uncond_embeddings.clone().detach()
            uncond_embeddings.requires_grad = True
            # Learning rate decays linearly with the step index.
            optimizer = Adam([uncond_embeddings], lr=1e-2 * (1.0 - i / 100.0))
            latent_prev = latents[len(latents) - i - 2]
            t = timesteps[i]
            global_states = self._global_states_for(latent_cur.shape[0])
            with torch.no_grad():
                noise_pred_cond = self.get_v_prediction(
                    latent_cur, t, cond_embeddings, global_states
                )
            for _ in range(num_inner_steps):
                noise_pred_uncond = self.get_v_prediction(
                    latent_cur, t, uncond_embeddings, global_states
                )
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
                latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                loss = F.mse_loss(latents_prev_rec, latent_prev)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if loss.item() < epsilon + i * 2e-5:
                    break
            uncond_embeddings_list.append(uncond_embeddings[:1].detach())
            with torch.no_grad():
                # Advance with the freshly optimised unconditional embedding.
                context = torch.cat([uncond_embeddings, cond_embeddings])
                latent_cur = self.get_noise_pred(latent_cur, t, False, context)
        return uncond_embeddings_list

    def invert(
        self,
        audio_path: str,
        prompt: str,
        num_inner_steps=10,
        early_stop_epsilon=1e-5,
        verbose=False
    ):
        """Load an audio file, invert it and run null-text optimisation.

        Returns:
            ``((audio_gt, audio_rec), ddim_latents, ddim_latents[-1],
            uncond_embeddings)``.
        """
        self.init_prompt(prompt)
        ptp_utils.register_attention_control(self.model, None)
        audio_gt = load_audio(audio_path) # Downsampling audio with 2 channels
        print(f"audio after loading to numpy {audio_gt.shape}")

        if verbose:
            print("DDIM inversion...")
        audio_rec, ddim_latents = self.ddim_inversion(audio_gt)

        if verbose:
            print("Null-text optimization...")
        uncond_embeddings = self.null_optimization(
            ddim_latents, num_inner_steps, early_stop_epsilon
        )
        return (audio_gt, audio_rec), ddim_latents, ddim_latents[-1], uncond_embeddings

    @torch.no_grad()
    def invert_style(
        self,
        audio: torch.Tensor,
        prompt,
        num_inference_steps=100,
        guidance_scale=7.0,
        eta=0.0,
        return_intermediates=False,
        **kwds,
    ):
        """
        Invert a real audio clip into a noise map with deterministic DDIM
        inversion.

        Args:
            audio: waveform tensor, (C, T) or (B, C, T).
            prompt: a string or a list of strings. A single clip is expanded
                to match a list of prompts; a single string is repeated to
                match a batch of clips.
            num_inference_steps: number of DDIM steps.
            guidance_scale: classifier-free guidance weight; guidance is only
                applied when it is greater than 1.0.
            eta: unused, kept for interface compatibility.
            return_intermediates: return the whole latent trajectory instead
                of the starting latents.

        Returns:
            ``(latents, latents_list)`` when ``return_intermediates`` is true,
            otherwise ``(latents, start_latents)``.
        """
        device = self.model.device

        if audio.dim() == 2:
            audio = audio.unsqueeze(0)    # -> (1, C, T)
        if isinstance(prompt, list):
            if audio.shape[0] == 1:
                audio = audio.expand(len(prompt), -1, -1)
        elif isinstance(prompt, str):
            if audio.shape[0] > 1:
                prompt = [prompt] * audio.shape[0]
        # Taken after the expansion so the unconditional batch matches.
        batch_size = audio.shape[0]

        # text embeddings
        text_input = self.tokenizer(
            prompt, padding="max_length", max_length=64, return_tensors="pt"
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(device))[0]

        # define initial latents
        latents = self.audio2latent(audio)
        start_latents = latents

        # unconditional embedding for classifier free guidance
        do_cfg = guidance_scale > 1.0
        if do_cfg:
            unconditional_input = self.tokenizer(
                [""] * batch_size,
                padding="max_length",
                max_length=64,
                return_tensors="pt",
            )
            unconditional_embeddings = self.model.text_encoder(
                unconditional_input.input_ids.to(device)
            )[0]
            text_embeddings = torch.cat(
                [unconditional_embeddings, text_embeddings], dim=0
            )

        # iterative sampling
        self.model.scheduler.set_timesteps(num_inference_steps)
        # The duration conditioning is the same for every step.
        self.global_states = self._duration_global_states(audio)
        latents_list = [latents]
        for t in tqdm(reversed(self.model.scheduler.timesteps), desc="DDIM Inversion"):
            model_inputs = torch.cat([latents] * 2) if do_cfg else latents

            # One row of timestep / global conditioning per model input row.
            global_states = self._global_states_for(model_inputs.shape[0])
            timestep = t.expand(model_inputs.shape[0]).to(device)

            # predict the noise (v-prediction)
            noise_pred = self.model.transformer(
                model_inputs,
                timestep,
                encoder_hidden_states=text_embeddings,
                global_hidden_states=global_states,
            )[0]

            if do_cfg:
                noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
                noise_pred = noise_pred_uncon + guidance_scale * (
                    noise_pred_con - noise_pred_uncon
                )
            # compute the next noise sample x_t-1 -> x_t
            latents = self.next_step(noise_pred, t, latents)
            latents_list.append(latents)

        if return_intermediates:
            # return the intermediate latents during inversion
            return latents, latents_list
        return latents, start_latents


In [16]:
# Inversion helper bound to the pipeline loaded above.
null_inversion = NullInversion(model)

In [7]:
# editor = ReweightCrossAttentionControl(
#     # start_step=int(0.2*TOTAL_STEPS),
#     # end_step=int(0.8*TOTAL_STEPS),
#     # start_layer=0,
#     # end_layer=20,
#     total_steps=TOTAL_STEPS,
#     # style_scale = 5.0
# )
# register_attention_editor_diffusers(model, editor)

In [None]:
content_files = get_audio_files("c")
style_files = get_audio_files("s")
import sys
import torch
import gc

# Free cached GPU memory left over from earlier cells before the heavy loop.
torch.cuda.empty_cache()
gc.collect()
for cpath in content_files:
    print(f"Processing content: {os.path.basename(cpath)}")
    # The inversion result is cached next to the content clip.
    pkl = os.path.splitext(cpath)[0] + ".pkl"

    source_prompt = ""
    target_prompt = ""
    prompts = [source_prompt, target_prompt]

    if os.path.isfile(pkl):
        # NOTE: unpickling runs arbitrary code; only load caches written by
        # this notebook.
        with open(pkl, "rb") as cache_file:
            (audio_gt, audio_rec), cont_trajectory, content_latent_vector, u_c = pickle.load(cache_file)
    else:
        # `invert` takes a single prompt string, so pass the source prompt
        # rather than the [source, target] pair.
        (audio_gt, audio_rec), cont_trajectory, content_latent_vector, u_c = null_inversion.invert(
            cpath, source_prompt, verbose=True
        )
        with open(pkl, "wb") as cache_file:
            pickle.dump(((audio_gt, audio_rec), cont_trajectory, content_latent_vector, u_c), cache_file)
        print(f"Saved inversion cache: {pkl}")

    sf.write('content_recon.wav', audio_rec.squeeze(), TARGET_SR)
    print("Saved content reconstruction: content_recon.wav")

    print("Ground Truth")
    load_and_view_audio(cpath)
    print("\n# ********************\n")
    print("Reconstruction")
    # Decode the last latent of the inversion trajectory for inspection.
    test_rec = null_inversion.latent2audio(content_latent_vector)
    wav_noise = test_rec.squeeze()
    sf.write('last_recon.wav', wav_noise, TARGET_SR)
    load_and_view_audio('last_recon.wav')

    # One copy of the inverted content latent per prompt.
    start_code_content = content_latent_vector.expand(len(prompts), -1, -1)
    for style_path in style_files:
        style_audio = load_audio(style_path)
        final_latent, style_latent_list = null_inversion.invert_style(
            style_audio,
            source_prompt,
            guidance_scale=GUIDANCE_SCALE,
            num_inference_steps=TOTAL_STEPS,
            return_intermediates=True,
        )
        test_rec_style = null_inversion.latent2audio(final_latent)
        wav_noise_style = test_rec_style.squeeze()
        sf.write('last_recon_style.wav', wav_noise_style, TARGET_SR)
        load_and_view_audio('last_recon_style.wav')

        # Basically test this 
        # editor = ReweightCrossAttentionControl(
        #     # start_step=int(0.2*TOTAL_STEPS),
        #     # end_step=int(0.8*TOTAL_STEPS),
        #     # start_layer=0,
        #     # end_layer=20,
        #     total_steps=TOTAL_STEPS,
        #     # style_scale = 5.0
        # )
        # register_attention_editor_diffusers(model, editor)

        # Audio Mix :
        aud_stylized = model(
            prompts,
            latents=start_code_content,
            guidance_scale=GUIDANCE_SCALE,
            uncond_embeddings=u_c,
            num_inference_steps=TOTAL_STEPS,
            ref_intermediate_latents=[cont_trajectory, style_latent_list]
         )

        # Batch row 0 is written as the style output, row 1 as the mix;
        # transpose each to (T, C) for soundfile.
        stylized_np = aud_stylized.detach().cpu().numpy()
        audio_np = stylized_np[1].T
        audio_np_style = stylized_np[0].T
        sf.write("mixed.wav", audio_np, TARGET_SR)
        sf.write("style_from_model.wav", audio_np_style, TARGET_SR)

        
        load_and_view_audio("mixed.wav")
        print("****************** ************************")
        load_and_view_audio("style_from_model.wav")

