In [1]:
import torch
from denoiser.data.audio import Audio
from denoiser.models.codec.mimi import MimiCodec
from denoiser.models.codec.dac import DescriptAudioCodec

from IPython.display import Audio as AudioPlayer

In [21]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Which codec to load: 'mimi' or 'dac'. Kept separate from `codec` so the
# selector string and the loaded model never share a name.
codec_name = 'dac'

# NOTE(review): the checkpoint paths below are absolute and machine-specific;
# consider moving them to a configurable models directory.
if codec_name == 'mimi':
    codec = MimiCodec(
        "/home/lucas/models/moshi/tokenizer-e351c8d8-checkpoint125.safetensors"
    )
elif codec_name == 'dac':
    codec = DescriptAudioCodec('/data/models/dac/weights_24khz_8kbps_0.0.4.pth')
else:
    # Fail early with a clear message instead of an AttributeError on a str.
    raise ValueError(
        f"Unknown codec {codec_name!r}; expected 'mimi' or 'dac'."
    )

# Switch to evaluation mode and move the codec to the selected device.
codec = codec.eval().to(device)

In [None]:
# Load one clean speech recording and bring it to the codec's sample rate.
audio = Audio("/data/denoising/speech/daps/clean/f10_script1_clean.wav")
# NOTE(review): the return value is not assigned, so later cells rely on
# `resample` modifying `audio` in place — confirm against `Audio.resample`.
audio.resample(codec.sample_rate)

In [23]:
# Round-trip the clip through the codec (encode, then decode) with autograd
# disabled. `[None]` prepends a leading axis to the waveform (presumably the
# batch axis the codec expects — TODO confirm).
with torch.inference_mode():
    reconstructed = codec.decode(codec.encode(audio.waveform[None].to(device)))
    # NOTE(review): stale alternative kept for reference; `mimi` is not defined
    # in the visible cells, so this line would fail if uncommented as-is.
    # reconstructed = mimi.reconstruct(audio.waveform[None].to(device))

In [None]:
# Playback widget for the input clip, using the codec's sample rate.
AudioPlayer(audio.waveform, rate=codec.sample_rate)

In [None]:
# Playback widget for the codec round-trip: take the first item of the
# reconstructed output and move it to the CPU before handing it to the player.
AudioPlayer(reconstructed[0].cpu(), rate=codec.sample_rate)

In [None]:
from pathlib import Path
from denoiser.data.source import AudioSource
from denoiser.data.dataset import AudioDataset
from torch.utils.data import DataLoader
from denoiser.data.collate import collate
from tqdm import tqdm

# Encode every batch of the speech dataset with the codec and concatenate the
# per-batch outputs into a single CPU tensor, `features`.

# Tunable settings. The two 64s are unrelated quantities, so name them apart.
FRAMES_PER_SEQUENCE = 64  # numerator of the sequence-length computation below
BATCH_SIZE = 64
NUM_WORKERS = 8

sr = codec.sample_rate
# Sequence duration in seconds. Presumably `resolution_hz` is the codec frame
# rate, making each sequence FRAMES_PER_SEQUENCE frames long — TODO confirm.
sequence_length_s = FRAMES_PER_SEQUENCE / codec.resolution_hz
# NOTE(review): absolute, machine-specific path; consider a configurable root.
speech_folder = Path('/data/denoising/speech/daps')

train_audio_source = AudioSource(
    speech_folder / "index.json",
    sequence_length_s=sequence_length_s,
)
# No `augmentation` argument is passed to the dataset here.
train_dataset = AudioDataset(
    train_audio_source,
    sample_rate=sr,
)
dloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate,
    num_workers=NUM_WORKERS,
)

# Collect per-batch outputs under their own name so `features` is only ever
# bound to the final concatenated tensor.
feature_batches = []
# Enter inference mode once around the whole loop rather than once per batch.
with torch.inference_mode():
    for batch in tqdm(dloader):
        batch = batch.to(device)
        # Move each result to the CPU right away so outputs do not pile up on
        # the compute device.
        feature_batches.append(codec.encode(batch.waveforms).cpu())
features = torch.cat(feature_batches, dim=0)

In [None]:
import matplotlib.pyplot as plt

# Heat map of the encoded output for the first sequence in `features`.
# Uses the explicit figure/axes API so the colorbar is tied to this image.
fig, ax = plt.subplots()
image = ax.imshow(features[0], aspect='auto')
ax.set_title("Codec features: features[0]")
# Trailing semicolon keeps the Colorbar repr out of the cell output.
fig.colorbar(image, ax=ax);

In [None]:
# Summary statistics over all encoded features: (min, max, mean, std).
features.min(), features.max(), features.mean(), features.std()