In [2]:
import sys

sys.path.append("..")

import torch
import torchaudio
from src.model.VAELightningModule import VAELightningModule
from src.model.AudioVAE import AudioVAE
from src.loss_fn.VAELossCalculator import VAELossCalculator
from src.loss_fn.WaveformLoss import WaveformLoss
from tqdm import tqdm
from src.dataset.MUSDB18Dataset import MUSDB18Dataset

# Import your module class
# from your_file import VAELightningModule

In [8]:
# Display the track collections behind the validation and training datasets as two sets,
# so the two splits can be compared by eye.
# NOTE(review): `val_ds` and `train_ds` are not created in any cell visible here — confirm
# they are defined earlier so this cell survives Restart Kernel -> Run All.
set(val_ds.mus.tracks), set(train_ds.mus.tracks)

({ANiMAL - Rockshow,
  Actions - One Minute Smile,
  Alexander Ross - Goodbye Bolero,
  Clara Berry And Wooldog - Waltz For My Victims,
  Fergessen - Nos Palpitants,
  James May - On The Line,
  Johnny Lokke - Promises & Lies,
  Leaf - Summerghost,
  Meaxic - Take A Step,
  Patrick Talbot - A Reason To Leave,
  Skelpolu - Human Mistakes,
  Traffic Experiment - Sirens,
  Triviul - Angelsaint,
  Young Griffo - Pennies},
 {A Classic Education - NightOwl,
  ANiMAL - Clinic A,
  ANiMAL - Easy Tiger,
  Actions - Devil's Words,
  Actions - South Of The Water,
  Aimee Norwich - Child,
  Alexander Ross - Velvet Curtain,
  Angela Thomas Wade - Milk Cow Blues,
  Atlantis Bound - It Was My Fault For Waiting,
  Auctioneer - Our Future Faces,
  AvaLuna - Waterduct,
  BigTroubles - Phantom,
  Bill Chudziak - Children Of No-one,
  Black Bloc - If You Want Success,
  Celestial Shore - Die For Us,
  Chris Durban - Celebrate,
  Clara Berry And Wooldog - Air Traffic,
  Clara Berry And Wooldog - Stella,
  

In [29]:
# Gather the `duration` attribute of every item yielded by `ds` (presumably seconds —
# TODO confirm), then show the whole list as the cell output.
durations = [item.duration for item in ds]
durations

[209.906009,
 200.327007,
 179.393991,
 219.192993,
 336.84,
 219.925011,
 254.710998,
 221.05,
 275.626009,
 250.019002,
 253.933991,
 190.281995,
 186.796009,
 273.385011,
 430.175011,
 317.080998,
 141.473991,
 205.272993,
 267.451995,
 281.411995,
 292.621995,
 227.229002,
 269.8,
 198.657007,
 246.522993,
 271.965011,
 343.513991,
 162.503991,
 75.996009,
 221.398005,
 212.046009,
 317.68,
 286.646009,
 250.896009,
 243.482993,
 177.198005,
 395.462993,
 252.96,
 312.467007,
 331.241995,
 234.483991,
 246.790998,
 305.520998,
 244.26,
 320.097007,
 252.849002,
 175.652993,
 234.91,
 207.698005,
 234.183991]

In [30]:
# Pick the item at index 1 of `ds`; the following cells inspect and reconstruct it.
# NOTE(review): `ds` is not defined in any cell visible here — confirm it exists on a
# fresh kernel before this cell runs.
track = ds[1]

In [31]:
# Display the `duration` attribute of the selected track.
track.duration

200.327007

In [32]:
# Transpose the track's audio array so it can be unpacked as (channels, samples),
# which is the layout `reconstruct_musdb_track` reads from `.shape`.
x = track.audio.T
# Take the sample rate from the track itself rather than hard-coding 44100, so the
# value cannot drift from the audio that was actually loaded.
sr = track.rate

In [33]:
# Display the `rate` attribute reported by the selected track.
track.rate

44100

In [34]:
# Load the training checkpoint onto the CPU and pull out its weight dictionary.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code — only
# load checkpoints from a trusted source (consider weights_only=True if the installed
# PyTorch version and the checkpoint contents allow it).
checkpoint = torch.load("../checkpoints/big_model/last.ckpt", map_location="cpu")
state_dict = checkpoint["state_dict"]

# Strip only the LEADING 'model.' prefix added by the LightningModule.
# str.replace("model.", "") would also delete any later "model." inside a key
# (e.g. "model.encoder.model.0.weight"), silently corrupting nested parameter names.
LIGHTNING_PREFIX = "model."
new_state_dict = {
    k[len(LIGHTNING_PREFIX):]: v
    for k, v in state_dict.items()
    if k.startswith(LIGHTNING_PREFIX)
}

# Build the bare VAE and load the un-prefixed weights; the return value of
# load_state_dict is the cell output and reports whether all keys matched.
model = AudioVAE(base_channels=128, strides=[2, 4, 4, 8])
model.load_state_dict(new_state_dict)



<All keys matched successfully>

In [35]:
def reconstruct_musdb_track(
    model,
    track,
    track_name: str,
    chunk_size=int(1.5 * 44100),
    sample_rate: int = 44100,
    overlap=1024,
    device="cuda",
    output_dir="reconstructions",
):
    """
    Reconstruct a full-length waveform with the model, window by window, and save it.

    The audio is cut into windows of ``chunk_size`` samples whose start positions
    advance by ``chunk_size - overlap``. Each window is passed through ``model`` and
    written into the output buffer; where two windows overlap, the later window
    overwrites the earlier one (no cross-fade). Windowing stops as soon as a window
    reaches the end of the audio. Both the input audio and the reconstruction are
    written as WAV files into ``output_dir``.

    Args:
        model: Callable module invoked as ``model(x)`` with ``x`` of shape
            ``(1, channels, chunk_size)``. It must return a 4-tuple whose first
            element is the reconstruction. The model is moved to ``device`` and put
            in eval mode, and is left in that state.
        track: NumPy array of shape ``(channels, samples)`` (note: an array, not a
            musdb Track object; transpose ``Track.audio`` before passing it in).
        track_name: Name used to build the output file names. Spaces and path
            separators are replaced by underscores.
        chunk_size: Window length in samples. Must be positive.
        sample_rate: Sample rate written into the saved WAV files.
        overlap: Number of samples shared by consecutive windows. Must satisfy
            ``0 <= overlap < chunk_size``.
        device: Torch device on which the model is run.
        output_dir: Directory for the WAV files; created if missing.

    Returns:
        CPU float tensor of shape ``(channels, samples)`` holding the reconstruction.

    Raises:
        ValueError: If ``chunk_size``/``overlap`` are inconsistent or ``track`` is
            not two-dimensional.
    """
    import os

    # Validate up front: a non-positive step would otherwise either raise a cryptic
    # range() error (step == 0) or silently return an all-zero reconstruction.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got overlap={overlap}, "
            f"chunk_size={chunk_size}"
        )

    original_audio = torch.from_numpy(track).float()
    if original_audio.dim() != 2:
        raise ValueError(
            f"track must have shape (channels, samples), got {tuple(original_audio.shape)}"
        )
    total_samples = original_audio.shape[1]

    os.makedirs(output_dir, exist_ok=True)

    # 1. Setup model
    model.to(device)
    model.eval()

    # 2. Output buffer, filled window by window (later windows overwrite overlaps)
    reconstructed_full = torch.zeros_like(original_audio)

    # 3. Window start positions. Stop at the first window that reaches the end:
    #    any further start would yield a window no longer than `overlap` samples,
    #    padded almost entirely with zeros, whose reconstruction would overwrite an
    #    already-reconstructed tail with lower-context output.
    step = chunk_size - overlap
    starts = []
    for start in range(0, total_samples, step):
        starts.append(start)
        if start + chunk_size >= total_samples:
            break

    print(f"Reconstructing track: {track_name}")

    # 4. Process in chunks
    with torch.no_grad():
        for start in tqdm(starts):
            end = min(start + chunk_size, total_samples)

            chunk = original_audio[:, start:end]
            actual_len = chunk.shape[1]

            # The model is always fed exactly chunk_size samples; zero-pad a short
            # final window on the right.
            if actual_len < chunk_size:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_size - actual_len))

            # Add the batch dimension: (1, channels, chunk_size)
            input_tensor = chunk.unsqueeze(0).to(device)

            # Only the reconstruction (first element of the 4-tuple) is used.
            recon, _, _, _ = model(input_tensor)

            # Drop the batch dimension and crop away the padded region.
            recon_chunk = recon.squeeze(0).cpu()
            reconstructed_full[:, start:end] = recon_chunk[:, :actual_len]

    # 5. Save files. Path separators in the track name would otherwise point into
    #    (possibly non-existent) sub-directories.
    safe_name = track_name.replace(" ", "_")
    for sep in (os.sep, os.altsep):
        if sep:
            safe_name = safe_name.replace(sep, "_")
    orig_path = f"{output_dir}/{safe_name}_original.wav"
    recon_path = f"{output_dir}/{safe_name}_reconstructed.wav"

    torchaudio.save(orig_path, original_audio, sample_rate=sample_rate)
    torchaudio.save(recon_path, reconstructed_full, sample_rate=sample_rate)

    print(f"Saved: {recon_path}")
    return reconstructed_full

In [36]:
# Run the chunked reconstruction for the selected track on the third CUDA device;
# the returned tensor is shown as the cell output.
reconstruct_musdb_track(
    model=model,
    track=x,
    track_name=track.name,
    device="cuda:2",
)

Reconstructing track: Al James - Schoolboy Facination


 11%|█         | 15/136 [00:00<00:04, 27.81it/s]

100%|██████████| 136/136 [00:04<00:00, 27.77it/s]


Saved: reconstructions/Al_James_-_Schoolboy_Facination_reconstructed.wav


tensor([[-0.0033,  0.0070,  0.0169,  ...,  0.0200,  0.0143,  0.0191],
        [ 0.0160,  0.0129,  0.0215,  ...,  0.0335,  0.0346,  0.0345]])