In [None]:
from denoiser.data.source import AudioSource
from denoiser.data.dataset import AudioDataset
from denoiser.data.collate import collate
from denoiser.data.augmentations.default import get_default_augmentation

In [None]:
from pathlib import Path

# Folder holding the noise recordings (the path points at a 48k DEMAND directory).
noise_folder = Path("/data/denoising/noise/records/DEMAND/48k")

# Default augmentation pipeline for the training split.
train_augments = get_default_augmentation(
    split="train",
    sequence_length_s=0.0,
    p=1.0,
)

In [None]:
from torch.utils.data import DataLoader

# Sample rate handed to the dataset and reused for playback further down.
sr = 24_000
speech_folder = Path("/data/denoising/speech/daps/clean")

# Audio source built from the training index file, cut into 64/75 s sequences.
train_audio_source = AudioSource(
    speech_folder.joinpath("index.train.json"),
    sequence_length_s=64 / 75,
)

# Dataset that pairs the source with the augmentation pipeline.
train_dataset = AudioDataset(
    train_audio_source,
    augmentation=train_augments,
    sample_rate=sr,
)

# Single-item, shuffled batches loaded in the main process (no workers).
train_dloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=1,
    num_workers=0,
    collate_fn=collate,
)

In [None]:
# Create an iterator over the training loader so batches can be pulled one at a time with next().
dloader = iter(train_dloader)

In [None]:
# Pull the next batch and keep its waveforms as the clean reference.
batch = next(dloader)
clean = batch.waveforms
print(batch.augmentation_params)

# Apply the augmentation to the clean waveforms using the parameters carried by the batch.
noisy = train_augments.augment(
    clean,
    parameters=batch.augmentation_params,
)

In [None]:
# Import the IPython player under an alias: `Audio` is also the name of the
# project's own audio class imported further down in this notebook, and sharing
# the name makes this cell break when it is re-run after that import.
from IPython.display import display, Audio as AudioPlayer

# Listen to the first clean item of the batch and its augmented counterpart.
display(AudioPlayer(clean[0].numpy(), rate=sr))
display(AudioPlayer(noisy[0].numpy(), rate=sr))

In [None]:
import torch
from denoiser.data.audio import Audio
from denoiser.models.codec.dac import DescriptAudioCodec

from IPython.display import Audio as AudioPlayer

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Selector for the codec to load. Kept separate from `codec` so that name only
# ever refers to the model object, never to a string.
codec_name = "dac"

if codec_name == "mimi":
    # NOTE(review): MimiCodec is not imported in the visible cells of this
    # notebook, so this branch raises NameError until the import is added.
    codec = MimiCodec(
        "/home/lucas/models/moshi/tokenizer-e351c8d8-checkpoint125.safetensors"
    )
elif codec_name == "dac":
    codec = DescriptAudioCodec("/data/models/dac/weights_24khz_8kbps_0.0.4.pth")
else:
    # Fail loudly instead of calling .eval() on a string further down.
    raise ValueError(f"Unknown codec {codec_name!r}; expected 'mimi' or 'dac'.")

# Switch to evaluation mode and move the model to the selected device.
codec = codec.eval()
codec = codec.to(device)

In [None]:
# Load one clean speech recording with the project's Audio class.
audio = Audio("/data/denoising/speech/daps/clean/f10_script1_clean.wav")
# Resample to the codec's sample rate.
# NOTE(review): later cells read audio.waveform directly, so this presumably resamples in place — confirm.
audio.resample(codec.sample_rate)

In [None]:
# Round-trip the recording through the codec under inference mode:
# add a leading axis, move to the target device, encode, then decode.
with torch.inference_mode():
    reconstructed = codec.decode(
        codec.encode(audio.waveform.unsqueeze(0).to(device))
    )

In [None]:
# Player for the original (resampled) recording.
AudioPlayer(data=audio.waveform, rate=codec.sample_rate)

In [None]:
# Player for the codec reconstruction (first item, moved back to the CPU).
AudioPlayer(data=reconstructed[0].cpu(), rate=codec.sample_rate)

In [None]:
# Encode the recording and pass the result through the codec's normalize step.
with torch.inference_mode():
    encoded = codec.encode(audio.waveform.unsqueeze(0).to(device))
    nencoded = codec.normalize(encoded)

In [None]:
import matplotlib.pyplot as plt

# Standard deviation of the normalized latents over the last axis, scaled by 0.5.
plt.plot(0.5 * nencoded[0].std(dim=-1).cpu())

In [None]:
import torch.nn as nn
from denoiser.models.cfm.cfm import ConditionalFlowMatcher

# Wrap a no-op module in the flow matcher; the visible cells only read its sigma_0 / sigma_t values.
cfm = ConditionalFlowMatcher(nn.Identity())
print(cfm.sigma_0)

In [None]:
# Standard deviation of the normalized latents along the last axis.
nencoded.std(dim=-1)

In [None]:
from torch.distributions import Normal

# Standard normal used to score every interpolated sample. It is defined here
# because the loop below needs `N`, which the notebook otherwise only creates in
# a later cell (NameError on Restart & Run All).
N = Normal(0, 1)

# Gaussian noise with the same shape as the normalized latents, scaled by 0.5.
x_0 = 0.5 * torch.randn_like(nencoded)

# Per-step accumulators: mixed samples, their densities under N, and sigma_t values.
xs, lls, ss = [], [], []
for t in torch.linspace(0, 1, 100):
    with torch.inference_mode():
        sigma_t = cfm.sigma_t(t)
        # Mix the noise sample and the normalized latents at time t.
        x_t = sigma_t * x_0 + t * nencoded
        xs.append(x_t.cpu())
        # exp(log_prob) is the density value under N(0, 1), not a log-likelihood.
        lls.append(N.log_prob(x_t).exp())
        ss.append(sigma_t)

# Concatenate all steps along the first axis.
xs = torch.cat(xs)

In [None]:
xs = (255 / 5 * xs).int()
xs.min(), xs.max(), xs.numpy().dtype

In [None]:
from PIL import Image
import numpy as np
gif = [Image.fromarray(x) for x in xs.cpu().numpy().astype(np.int8)]

In [None]:
# Write every frame into one animated GIF (duration=15 per frame, loop=0).
gif[0].save(
    "array.gif",
    save_all=True,
    append_images=gif[1:],
    loop=0,
    duration=15,
)

In [None]:
# Save the first frame on its own as a still image.
gif[0].save("gif.png")

In [None]:
# Show the first interpolation step as an image, with a colour scale.
plt.imshow(xs[0].cpu().numpy(), aspect="auto")
plt.colorbar()

In [None]:
# Shape of the stacked interpolation steps after the same int8 cast used for the GIF frames.
xs.cpu().numpy().astype(np.int8).shape

![Animation of x_t across the interpolation steps](array.gif "x_t interpolation")

In [None]:
# Per-step standard deviation over the last axis, shown as an image.
# Cast to float first: torch.std only supports floating-point (and complex)
# tensors, so this keeps the cell working even when `xs` holds integers.
plt.imshow(xs.float().std(-1).T, aspect='auto')

In [None]:
# Per-step standard deviation over the last axis, drawn as curves.
# Cast to float first: torch.std only supports floating-point (and complex)
# tensors, so this keeps the cell working even when `xs` holds integers.
_ = plt.plot(xs.float().std(-1))
plt.show()

In [None]:
from torch.distributions import Normal

# Standard normal distribution (zero mean, unit scale) used to score samples.
N = Normal(loc=0, scale=1)

In [None]:
# Density of the normalized latents under N, averaged over the last axis.
N.log_prob(nencoded).exp().mean(dim=-1)

In [None]:
# Same density statistic for pure Gaussian noise shaped like one step of `xs`.
# Request a float dtype explicitly: randn_like cannot sample into an integer
# tensor, so this works whether `xs` holds floats or integers.
N.log_prob(torch.randn_like(xs[0], dtype=torch.float32)).exp().mean(-1)

In [None]:
# Density values collected in the loop, averaged over the last axis and plotted per step.
_ = plt.plot(torch.cat(lls).cpu().mean(dim=-1))

In [None]:
# Shape of the per-step density tensors once stacked along a new first axis.
torch.stack(lls, dim=0).shape

In [None]:
# First interpolation step as an image, with a colour scale.
plt.imshow(xs[0], aspect="auto")
plt.colorbar()

In [None]:
# Normalized latents (first item) as an image, with a colour scale.
plt.imshow(nencoded[0].cpu(), aspect="auto")
plt.colorbar()

In [None]:
# sigma_t values collected in the loop, one per step.
plt.plot(torch.stack(ss, dim=0))

In [None]:
# Overall mean and standard deviation of the normalized latents and of the noise sample.
(
    nencoded.mean(),
    nencoded.std(),
    x_0.mean(),
    x_0.std(),
)