In [None]:
import torch, math


from IPython.display import display
from IPython.display import Audio as AudioPlayer

In [None]:
# Device used for the models and tensors in this notebook.
# Prefer the GPU, but fall back to the CPU so the notebook can still run on a
# machine without CUDA instead of failing at the first `.to(device)` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
from senhance.models.codec.dac import DescriptAudioCodec

# Build the audio codec from its weights file on disk, then move it to `device`.
codec = DescriptAudioCodec('/data/models/dac/weights_24khz_8kbps_0.0.4.pth')
codec = codec.to(device, non_blocking=True)

In [None]:
from senhance.models.checkpoint import Checkpoint
from senhance.models.cfm.cfm import ConditionalFlowMatcher
from senhance.models.unet.unet import UNET1d, UNET1dDims

# Restore the flow-matching model from a saved training checkpoint.
checkpoint = Checkpoint.load('/data/experiments/test2/checkpoint.25000.pt')

# The UNet's first dimension is taken from the codec (`codec.dim`).
dims = UNET1dDims(codec.dim, 1024, 1024)
unet = UNET1d(dims)
# Optional compilation, left disabled:
# unet = torch.compile(unet)

# Wrap the UNet in the flow matcher and load the trained weights into it.
cfm = ConditionalFlowMatcher(unet)
cfm.load_state_dict(checkpoint.model)
cfm = cfm.to(device, non_blocking=True)

# This notebook only runs inference. `torch.inference_mode()` disables autograd
# but does not touch the modules' training flag, so switch to eval mode
# explicitly to turn off training-only behaviour (e.g. dropout).
cfm.eval()

In [None]:
from senhance.data.audio import Audio

# Load a single utterance from disk and resample it with a target of 24_000.
x = Audio("/data/denoising/speech/ljspeech/LJSpeech-1.1/wavs/LJ001-0001.wav").resample(24_000)

# Excerpt length handed to `salient_excerpt`: 64 divided by the codec's `resolution_hz`
# (presumably 64 latent frames expressed in seconds — TODO confirm).
seq_length = 64 / codec.resolution_hz

# Normalize with -24.0 and cut out a salient excerpt of that length.
excerpt = (
    x
    .normalize(-24.0)
    .salient_excerpt(seq_length)
)

# Inline player for the clean excerpt.
AudioPlayer(excerpt.waveform.numpy(), rate=excerpt.sample_rate)

In [None]:
from senhance.data.augmentations.default import get_default_augmentation

# Build the default augmentation from the "test" split, applied with p=1.0.
augment = get_default_augmentation(
    noise_folder='/data/denoising/noise/',
    sample_rate=x.sample_rate,
    sequence_length_s=1,
    split="test",
    p=1.0,
)

# Draw one set of augmentation parameters for this excerpt, then apply them to a
# copy of the waveform with a leading axis added (`[None]`).
augment_params = augment.sample_parameters(excerpt)
clean_batch = excerpt.waveform[None].clone()
augmented = augment.augment(clean_batch, augment_params)

# Inline player for the first (only) item of the augmented batch.
AudioPlayer(augmented[0], rate=excerpt.sample_rate)

In [None]:
# Denoise the augmented excerpt: encode -> sample with the flow matcher -> decode.
with torch.inference_mode():
    augmented = augmented.to(device)

    # 20 sampling times: the log of evenly spaced points on [1, e], so they run
    # from 0 to 1 with a non-uniform spacing.
    timesteps = torch.log(torch.linspace(1, math.exp(1), 20))

    # Noisy waveform -> normalized codec latents.
    z_nsy = codec.encode(augmented.clone())
    z_nsy = codec.normalize(z_nsy)

    # Run the sampler starting from a copy of the noisy latents.
    z_hat = cfm.sample(z_nsy.clone(), timesteps.tolist())

    # Latents -> waveform, brought back to the CPU as a NumPy array for playback.
    denoised = codec.unnormalize(z_hat.clone())
    denoised = codec.decode(denoised).cpu().numpy()

AudioPlayer(denoised[0], rate=excerpt.sample_rate)

In [None]:
# Encode the clean excerpt into normalized codec latents, as a reference for
# comparing against the denoised latents.
with torch.inference_mode():
    # Use the notebook-wide `device` rather than a hard-coded 'cuda' so this cell
    # always lands on the same device as the codec.
    z_cln = codec.normalize(codec.encode(excerpt.waveform[None].clone().to(device)))

In [None]:
# Mean signed difference between the clean and denoised latents.
torch.mean(z_cln - z_hat)

In [None]:
import matplotlib.pyplot as plt

# Histogram (50 bins) of the element-wise difference clean - denoised over all latent entries.
latent_error = (z_cln - z_hat).detach().cpu()
plt.hist(latent_error.view(-1), bins=50)

In [None]:
# Show the clean, noisy and denoised latents of the first batch item as three
# stacked images on one shared, symmetric colour scale [-v, v].
v = 2.5
fig = plt.figure(figsize=(10, 15))
axs = fig.subplots(3)

# One loop instead of three copy-pasted imshow/set_title pairs, so every panel
# is guaranteed to use identical display settings.
panels = (('clean', z_cln), ('noisy', z_nsy), ('denoised', z_hat))
for ax, (title, latents) in zip(axs, panels):
    ax.imshow(latents[0].detach().cpu(), aspect='auto', interpolation='none', vmin=-v, vmax=v)
    ax.set_title(title)

In [None]:
# Heat-map of the absolute latent error |denoised - clean| for the first batch item.
plt.figure(figsize=(10, 15))
# The plotted quantity is an absolute value and therefore never negative, so the
# colour scale starts at 0; a lower limit of -v would leave half of the colormap
# (and of the colorbar) unused.
plt.imshow((z_hat - z_cln).abs()[0].detach().cpu(), aspect='auto', interpolation='none', vmin=0, vmax=v)
plt.colorbar()

In [None]:
# Histogram (50 bins) of every value in the noisy latents.
noisy_values = z_nsy.detach().cpu().view(-1)
plt.hist(noisy_values, bins=50)