In [28]:
import sys

sys.path.append("..")

from pathlib import Path
import torch
import random
import torchaudio
from src.model.VAELightningModule import VAELightningModule
from src.model.AudioVAE import AudioVAE
from src.loss_fn.VAELossCalculator import VAELossCalculator
from src.loss_fn.WaveformLoss import WaveformLoss
from tqdm import tqdm
from src.dataset.GTZANDataset import GTZANDataset

# Import your module class
# from your_file import VAELightningModule

In [29]:
# Load the "valid" split of GTZAN; the root path is kept in one named constant
# so it is easy to spot and change on another machine.
GTZAN_ROOT = "/data1/midi_generation_datasets/gtzan"
ds = GTZANDataset(split="valid", root=GTZAN_ROOT)

In [30]:
# Pick a random track from the dataset. `random.randrange(n)` excludes `n`,
# whereas `random.randint(0, n)` includes it and could index one past the end
# of `ds.audio_files` (IndexError).
idx = random.randrange(len(ds.audio_files))
track_name = Path(ds.audio_files[idx]).name
print(track_name)
audio, sr = torchaudio.load(ds.audio_files[idx])
# Duplicate a mono signal into two identical channels so the tensor is
# (2, samples). Files that already have more than one channel are left as-is
# instead of being doubled to 4+ channels.
if audio.shape[0] == 1:
    audio = torch.cat([audio, audio], dim=0)
audio.shape

blues.00081.wav


torch.Size([2, 661794])

In [None]:
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(
    "../checkpoints/audio_vae_with_warmup/last.ckpt", map_location="cpu"
)
state_dict = checkpoint["state_dict"]

# Keep only the entries stored under the 'model.' prefix (added by the
# LightningModule wrapper) and strip that prefix from the FRONT of the key only.
# `str.replace` would also delete any later "model." substring inside a key
# (e.g. "model.enc.submodel.w" -> "enc.subw") and corrupt the parameter name.
prefix = "model."
new_state_dict = {
    k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
}

model = AudioVAE(
    base_channels=64,
    strides=[2, 4, 4, 4, 4],
    channel_mults=[1, 2, 4, 8, 8],
    latent_dim=128,
)
model.load_state_dict(new_state_dict)



<All keys matched successfully>

In [32]:
def reconstruct_musdb_track(
    model,
    track,
    track_name: str,
    chunk_size=int(1.5 * 22050),
    sample_rate: int = 22050,
    overlap=1024,
    device="cuda",
    output_dir="reconstructions",
):
    """
    Reconstruct a whole audio track with the VAE, chunk by chunk, and save it.

    The track is cut into windows of ``chunk_size`` samples whose starts are
    ``chunk_size - overlap`` samples apart. Each window is passed through
    ``model`` and written into the output buffer; in overlapping regions the
    later chunk simply overwrites the earlier one (no cross-fade).

    Args:
        model: Module whose forward pass returns a 4-tuple with the
            reconstruction as the first element. It is moved to ``device``
            and switched to eval mode (both persist after the call).
        track: Audio as a ``torch.Tensor`` or NumPy array laid out as
            (channels, samples). No transpose is performed, so
            (samples, channels) data must be transposed by the caller.
        track_name: Name used for log output and for the output file names
            (spaces are replaced by underscores).
        chunk_size: Number of samples fed to the model per forward pass.
        sample_rate: Sample rate written to the saved WAV files.
        overlap: Number of samples shared by consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size``.
        device: Device on which the model is run.
        output_dir: Directory for the output files (created if missing).

    Returns:
        CPU float tensor with the same shape as the input, holding the
        reconstruction. If the model returns fewer samples than a chunk
        needs, the uncovered positions keep their previous content (zeros or
        the previous chunk's overlap).

    Raises:
        ValueError: If ``overlap`` is outside ``[0, chunk_size)`` or the
            audio is not 2-D.
    """
    import os

    # Validate before touching the filesystem or the model. A non-positive
    # step would either crash inside range() or silently produce all zeros.
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, "
            f"got overlap={overlap}, chunk_size={chunk_size}"
        )

    # 1. Get audio as a float CPU tensor of shape (channels, samples).
    if isinstance(track, torch.Tensor):
        # detach/cpu so grad-tracking or GPU inputs can still be saved to disk.
        original_audio = track.detach().float().cpu()
    else:
        original_audio = torch.from_numpy(track).float()
    if original_audio.dim() != 2:
        raise ValueError(
            f"track must be 2-D (channels, samples), "
            f"got shape {tuple(original_audio.shape)}"
        )
    total_samples = original_audio.shape[1]

    os.makedirs(output_dir, exist_ok=True)

    # 2. Setup model
    model.to(device)
    model.eval()

    # 3. Initialize reconstruction buffer
    reconstructed_full = torch.zeros_like(original_audio)

    # 4. Process in chunks. Stop at the first chunk that reaches the end of
    # the track: any later start would lie entirely inside samples that are
    # already reconstructed and would overwrite them with the output of a
    # mostly zero-padded chunk.
    step = chunk_size - overlap
    if total_samples == 0:
        num_chunks = 0
    else:
        remaining = max(0, total_samples - chunk_size)
        num_chunks = (remaining + step - 1) // step + 1
    chunk_starts = range(0, num_chunks * step, step)
    print(f"Reconstructing track: {track_name}")

    with torch.no_grad():
        for start in tqdm(chunk_starts):
            end = min(start + chunk_size, total_samples)

            # Extract chunk and zero-pad it if it is the last, short chunk
            chunk = original_audio[:, start:end]
            actual_len = chunk.shape[1]

            if actual_len < chunk_size:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_size - actual_len))

            # Prepare for VAE (Batch, Channels, Length)
            input_tensor = chunk.unsqueeze(0).to(device)

            # Only the reconstruction (first element) of the 4-tuple is used
            recon, _, _, _ = model(input_tensor)

            # Remove batch dim; drop the padding and tolerate a model output
            # that is shorter than its input.
            recon_chunk = recon.squeeze(0).cpu()
            usable = min(actual_len, recon_chunk.shape[-1])

            # Simple stitching: later chunks overwrite the overlap region
            reconstructed_full[:, start : start + usable] = recon_chunk[:, :usable]

    # 5. Save files
    safe_name = track_name.replace(" ", "_")
    orig_path = f"{output_dir}/{safe_name}_original.wav"
    recon_path = f"{output_dir}/{safe_name}_reconstructed.wav"

    torchaudio.save(orig_path, original_audio, sample_rate=sample_rate)
    torchaudio.save(recon_path, reconstructed_full, sample_rate=sample_rate)

    print(f"Saved: {recon_path}")
    return reconstructed_full

In [33]:
# Pass the sample rate reported by torchaudio.load so the saved WAV files are
# written at the rate of the loaded track instead of the function's 22050 Hz
# default.
reconstruct_musdb_track(
    model=model,
    track=audio,
    track_name=track_name,
    sample_rate=sr,
    device="cuda:3",
)

Reconstructing track: blues.00081.wav


100%|██████████| 21/21 [00:00<00:00, 42.33it/s]


Saved: reconstructions/blues.00081.wav_reconstructed.wav


tensor([[0.1885, 0.2607, 0.2502,  ..., 0.0069, 0.0073, 0.0123],
        [0.1878, 0.2601, 0.2496,  ..., 0.0061, 0.0065, 0.0114]])