# Inference example of timbre transfer with dual latent diffusion bridge

Add the main tssb folder to the system path, or navigate (cd) to that directory to ensure modules can be imported correctly.

In [None]:
import warnings
warnings.filterwarnings('ignore')

from main.module_base_latent_cond import (
    Model, 
    AudioDiffusionModel,
)


from audio_diffusion_pytorch import (
    KarrasSamplerReverse, 
    KarrasSampler,
    KarrasSampler_grad_guided,
    # KarrasSampler_grad_guidedv2,
    KarrasSchedule,
    KDistribution,
    # PitchTracker,
    NormalizedEncodec,
    plot_spec, 
    play_audio
)
import torch
import torchaudio
from IPython.display import HTML
from typing import Dict, Any
from torchaudio.prototype.transforms import ChromaSpectrogram
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Loading Distribution, Schedule, and Samplers

- **`sigma_min`**, **`sigma_max`**, and **`rho`** will define the **`diffusion_sigma_distribution`**, creating a distribution of sigmas following **Formula 1** from the reference paper.

- **`diffusion_schedule`** follows the Karras scheduling algorithm to manage the progression of sigma values during the diffusion process.

The following samplers implement the Karras sampling algorithm:
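- **`diffusion_sampler_reverse`**: Runs the sampler in reverse, adding noise to the audio (the *forward* diffusion process in standard terminology).
- **`diffusion_sampler`**: Runs the sampler in the usual direction, denoising from noise (the *reverse* diffusion process in standard terminology).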



In [None]:
# Noise-level bounds and curvature shared by the sigma distribution and the schedule.
sigma_min = 0.001
sigma_max = 100
rho = 9.0

# The distribution must receive both bounds; passing sigma_min as sigma_max
# would collapse the sigma range to a single value.
diffusion_sigma_distribution = KDistribution(sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
diffusion_schedule = KarrasSchedule(sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)

# diffusion_sampler_reverse is used below to turn audio embeddings into noise;
# diffusion_sampler is the plain (unguided) denoising sampler.
diffusion_sampler_reverse = KarrasSamplerReverse()
diffusion_sampler = KarrasSampler()

## Creating and Loading the Diffusion and PyTorch Lightning Models

In [None]:
# Violin checkpoint artifacts: the mean/std tensors (passed later to the Encodec
# encode/decode calls) and the trained model weights.
violin_mean_path = '/workspace/data/kinwai/diffusion-timbre-transfer/ckpts/mean_tensor_enc_violin.pt'
violin_std_path = '/workspace/data/kinwai/diffusion-timbre-transfer/ckpts/std_tensor_enc_violin.pt'
violin_model_weights = '/workspace/data/kinwai/diffusion-timbre-transfer/ckpts/violin.ckpt'

# Diffusion model for the violin domain, built on the sigma distribution defined above.
diffusion_model_violin = AudioDiffusionModel(diffusion_sigma_distribution=diffusion_sigma_distribution)

# PyTorch Lightning wrapper around the diffusion model.
pl_model_violin = Model(model=diffusion_model_violin, mean_path=violin_mean_path, std_path=violin_std_path)

# Restore the trained weights (strict key matching) and move the model to the target device.
ckpt_violin = torch.load(violin_model_weights, map_location=device)
pl_model_violin.load_state_dict(ckpt_violin["state_dict"], strict=True)
pl_model_violin.to(device)

## Loading Audio and Converting to Encodec Embeddings

In [None]:
# Working sample rate (Hz) used for resampling, plotting and playback.
sr = 24000
audio_path = '/workspace/data/kinwai/diffusion-timbre-transfer/audios/216001_1_violin.wav'

# Read the file at its native rate, then resample it to `sr`.
waveform_raw, orig_sr = torchaudio.load(audio_path)
resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=sr)
waveform_raw = resampler(waveform_raw)

# Add a leading dimension (presumably the batch axis) and move the tensor to the device.
waveform_ori = waveform_raw.unsqueeze(0).to(device)

# Bring the last axis to a fixed 409600 samples (409600 / 24000 ≈ 17.07 s).
# A positive pad_size right-pads; a negative one is expected to crop
# instead (per the original note: "padding or cropping as needed").
pad_size = 409600 - waveform_ori.size(-1)
waveform_ori = torch.nn.functional.pad(waveform_ori, (0, pad_size))

print(f'Input waveform shape: {waveform_ori.shape}')
plot_spec(waveform_raw.numpy(), sr, title='Real Violin')
play_audio(waveform_raw.numpy(), sr)

# Convert the input audio to Encodec embeddings, passing the violin model's mean/std.
encodec = NormalizedEncodec(device=device)
embeddings_violin = encodec.encode_latent(waveform_ori, pl_model_violin.mean, pl_model_violin.std)
print(f'Input Encodec embeddings shape: {embeddings_violin.shape}')


## Noise adding

### Apply noise to the input audio using the violin model, progressing through the forward diffusion steps to transform it into the shared latent space representation.

In [None]:
# Number of sampler steps; reused later for the guided denoising pass.
num_steps = 100

# Convert the violin embeddings to noise with the reverse sampler.
# Four noisy samples are produced, each moved to the CPU and detached.
noisy_violin_embeddings = [
    pl_model_violin.model.sample(
        noise=embeddings_violin,
        sampler=diffusion_sampler_reverse,
        sigma_schedule=diffusion_schedule,
        num_steps=num_steps,
    ).cpu().detach()
    for _ in range(4)
]


In [None]:
# Decode the first noisy embedding back to a waveform (using the violin
# model's mean/std) so the noised signal can be inspected.
noise_waveform = (
    encodec.decode_latent(
        noisy_violin_embeddings[0].to(device),
        pl_model_violin.mean,
        pl_model_violin.std,
    )
    .cpu()
    .detach()
    .squeeze(0)
    .numpy()
)
plot_spec(noise_waveform, sr, title='Noisy violin')
play_audio(noise_waveform, sr)

## Using control function

In [None]:
# Load a second recording; it serves as the reference condition for guided sampling.
audio_path = '/workspace/data/kinwai/diffusion-timbre-transfer/audios/216002_1_violin.wav'

# Read the file at its native rate, then resample it to `sr`.
waveform_raw, orig_sr = torchaudio.load(audio_path)
resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=sr)
waveform_raw = resampler(waveform_raw)

# Add a leading dimension (presumably the batch axis) and move the tensor to the device.
waveform_cond = waveform_raw.unsqueeze(0).to(device)

# Bring the last axis to a fixed 409600 samples (409600 / 24000 ≈ 17.07 s).
# A positive pad_size right-pads; a negative one is expected to crop
# instead (per the original note: "padding or cropping as needed").
pad_size = 409600 - waveform_cond.size(-1)
waveform_cond = torch.nn.functional.pad(waveform_cond, (0, pad_size))

print(f'Input waveform shape: {waveform_cond.shape}')
plot_spec(waveform_raw.numpy(), sr, title='Real Violin')
play_audio(waveform_raw.numpy(), sr)

# The control function f is defined in the next cell.
# It is given the reference condition (waveform_cond) as a raw waveform.


In [None]:
# Chroma transform used as the control function f. It does not depend on
# alpha, so it is built once rather than once per loop iteration.
f_chroma = ChromaSpectrogram(sample_rate=sr, n_fft=2048).to(device)

# Generated embeddings for every alpha, keyed by alpha, so that results from
# earlier alphas are not lost when `output_samples` is rebound.
output_samples_by_alpha = {}

for alpha in [0, 0.1, .3, .5, .7, .9, 1]:
    # Guided sampler for this alpha; the condition is the second recording.
    diffusion_sampler_grad_guided = KarrasSampler_grad_guided(
        f=f_chroma,
        c=waveform_cond,  # needs to be the raw waveform
        encodec=encodec,
        mean=pl_model_violin.mean,
        std=pl_model_violin.std,
        alpha=alpha,
    )
    output_samples = []

    # Denoise each previously generated noisy embedding with the guided sampler.
    for i, noisy_embedding in enumerate(noisy_violin_embeddings):
        generated_violin_embeddings = pl_model_violin.model.sample(
            noise=noisy_embedding.to(device),
            sampler=diffusion_sampler_grad_guided,
            sigma_schedule=diffusion_schedule,
            num_steps=num_steps,
            index=i,
        ).cpu().detach()
        output_samples.append(generated_violin_embeddings)

    output_samples_by_alpha[alpha] = output_samples

    # Optional: uncomment to decode, plot and play the samples for this alpha.
    # for generated_violin_embedding in output_samples:
    #     print(generated_violin_embedding.mean())
    #     waveform_recon = encodec.decode_latent(generated_violin_embedding.to(device), pl_model_violin.mean, pl_model_violin.std)
    #     waveform_recon = waveform_recon.cpu().detach().squeeze(0).numpy()
    #     plot_spec(waveform_recon, sr, title='Generated Violin')
    #     play_audio(waveform_recon, sr)

