In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import math
import torch
import torchaudio
from tqdm import tqdm
from model.codec import CODEC
from utils.tools import batch_wdt

def load_model(checkpoint_path, device='cuda', sample_rate=16000):
    """
    Load the trained CODEC model from a checkpoint and prepare it for inference.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        device (str): Device the checkpoint tensors are mapped to and the model
            runs on ('cuda' or 'cpu').
        sample_rate (int): Forwarded to ``CODEC.load_from_checkpoint`` as the
            ``sample_rate`` keyword (presumably overriding the stored
            hyperparameter — confirm against the CODEC class).

    Returns:
        CODEC: The model in eval mode, moved to ``device``.
    """
    # Map stored tensors directly onto the requested device. A fixed 'cuda:0'
    # makes loading fail on CPU-only machines even when device='cpu'.
    model = CODEC.load_from_checkpoint(checkpoint_path, map_location=device, sample_rate=sample_rate)
    model.eval()
    model.to(device)
    return model

def process_audio_in_chunks(file_path, sample_rate, model, chunk_size, device='cuda'):
    """
    Encode, quantize and decode an audio file chunk by chunk to bound memory use.

    The file is resampled to ``sample_rate`` if needed and passed through
    ``model.preprocess``. Only the first channel is then cut into chunks of
    ``chunk_size`` samples; the last chunk may be shorter.

    Args:
        file_path (str): Path to the input audio file.
        sample_rate (int): Target sample rate for the model.
        model (CODEC): Trained CODEC model.
        chunk_size (int): Chunk size in samples.
        device (str): Device to run the model ('cuda' or 'cpu').

    Returns:
        torch.Tensor: Processed input chunks concatenated along time (on CPU).
        torch.Tensor: Decoder outputs concatenated along time (on CPU).
        torch.Tensor: Quantizer codes of all chunks concatenated along dim 2.

    Raises:
        ValueError: If ``chunk_size`` is not positive or the audio has no samples.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Load and resample to the model's rate when necessary.
    audio, sr = torchaudio.load(file_path)
    if sr != sample_rate:
        audio = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)(audio)

    audio = model.preprocess(audio, sample_rate).to(device)
    num_samples = audio.shape[-1]
    if num_samples == 0:
        raise ValueError(f"{file_path}: no audio samples to process")

    original_chunks = []
    reconstructed_chunks = []
    codes_list = []

    with torch.no_grad():
        for start in range(0, num_samples, chunk_size):
            # First channel only, with a batch axis added in front. Slicing past
            # the end is clamped, so the final chunk is simply shorter.
            chunk = audio[0:1, start:start + chunk_size].unsqueeze(dim=0)
            audio_data = chunk[:, 0:1, :]

            z = model.encoder(audio_data)
            # Only the quantized latents and codes are needed; the remaining
            # quantizer outputs (latents and losses) are discarded.
            z_q, codes, *_ = model.quantizer(z)
            reconstructed_chunk = model.decoder(z_q)

            codes_list.append(codes)
            original_chunks.append(chunk.cpu())
            reconstructed_chunks.append(reconstructed_chunk.cpu())

    original_audio = torch.cat(original_chunks, dim=-1)
    reconstructed_audio = torch.cat(reconstructed_chunks, dim=-1)
    codes = torch.cat(codes_list, dim=2)  # Concatenate along the time dimension

    return original_audio, reconstructed_audio, codes

def reconstruct_from_codes(codes, model, device='cuda'):
    """
    Decode audio from quantized codes using the quantizer's ``from_codes``.

    The codes are turned back into quantized latents ``z_q`` with
    ``model.quantizer.from_codes``, and ``z_q`` is passed through the decoder.

    Args:
        codes (torch.Tensor): Quantized codes (B x N x T).
        model (CODEC): Trained CODEC model.
        device (str): Device to run the model ('cuda' or 'cpu').

    Returns:
        torch.Tensor: The decoder output (reconstructed audio), moved to CPU.
    """
    with torch.no_grad():
        # from_codes returns a 3-tuple; only the quantized latents are decoded.
        z_q, _, _ = model.quantizer.from_codes(codes.to(device))
        reconstructed_audio = model.decoder(z_q)
    return reconstructed_audio.cpu()

def save_audio(audio_tensor, sample_rate, output_path):
    """
    Write ``audio_tensor`` to ``output_path`` at ``sample_rate`` using torchaudio.
    """
    torchaudio.save(output_path, audio_tensor, sample_rate=sample_rate)

def process_folder(input_folder, output_original_folder, output_reconstructed_folder, model, sample_rate, chunk_size, device='cuda', extensions=(".flac",)):
    """
    Run the codec over every audio file under a folder and save the results.

    ``input_folder`` is walked recursively. Every file whose name ends with one
    of ``extensions`` (``.flac`` by default) goes through
    ``process_audio_in_chunks``. The processed original and its reconstruction
    are trimmed to a common length and written as ``<stem>_original.wav`` and
    ``<stem>_reconstructed.wav`` in the two output folders. Files that share a
    stem in different subfolders overwrite each other's outputs.

    Args:
        input_folder (str): Root folder searched recursively for input audio.
        output_original_folder (str): Destination for the original audio
            (created if missing).
        output_reconstructed_folder (str): Destination for the reconstructions
            (created if missing).
        model (CODEC): Trained CODEC model.
        sample_rate (int): Sample rate used for processing and for saving.
        chunk_size (int): Chunk size in samples.
        device (str): Device to run the model ('cuda' or 'cpu').
        extensions (str | tuple[str, ...]): Filename suffix(es) to process.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_original_folder, exist_ok=True)
    os.makedirs(output_reconstructed_folder, exist_ok=True)

    # Sort so files are processed in a deterministic order; os.walk's listing
    # order is arbitrary.
    all_files = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(input_folder)
        for file in files
        if file.endswith(extensions)
    )

    for input_file_path in tqdm(all_files, desc="Processing files"):
        file_name = os.path.splitext(os.path.basename(input_file_path))[0]

        original_audio, reconstructed_audio, _ = process_audio_in_chunks(
            input_file_path, sample_rate, model, chunk_size, device=device
        )

        # Trim both signals to the shorter length so they align sample by sample.
        min_length = min(original_audio.shape[-1], reconstructed_audio.shape[-1])
        original_audio = original_audio[..., :min_length]
        reconstructed_audio = reconstructed_audio[..., :min_length]

        output_original_path = os.path.join(output_original_folder, f"{file_name}_original.wav")
        output_reconstructed_path = os.path.join(output_reconstructed_folder, f"{file_name}_reconstructed.wav")
        # [0] drops the leading batch axis before saving.
        save_audio(original_audio[0], sample_rate, output_original_path)
        save_audio(reconstructed_audio[0], sample_rate, output_reconstructed_path)

if __name__ == "__main__":
    # Paths and parameters
    checkpoint_path = "/workspace/wptcodec/daccodec-speech/checkpoints/epoch-epoch=61-loss-train_loss=0.00.ckpt"  # Replace with the path to your trained model checkpoint
    input_folder = "/datasets/LibriSpeech/test-clean/"  # Folder searched recursively for input .flac files
    output_original_folder = "/workspace/wptcodec/daccodec-speech/original"  # Folder to save original audio
    output_reconstructed_folder = "/workspace/wptcodec/daccodec-speech/reconstructed"  # Folder to save reconstructed audio
    sample_rate = 16000  # Model's expected sample rate
    chunk_size = 16384  # Samples per chunk (~1.02 s at 16 kHz)

    # Use the GPU when available, otherwise fall back to CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_model(checkpoint_path, device=device)

    # Process all files in the folder
    process_folder(
        input_folder, output_original_folder, output_reconstructed_folder,
        model, sample_rate, chunk_size, device=device
    )

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import math
import torch
import torchaudio
from tqdm import tqdm
from model.wptcodec import CODEC
from utils.tools import batch_wdt

def load_model(checkpoint_path, device='cuda', sample_rate=16000):
    """
    Load the trained WPT CODEC model from a checkpoint and prepare it for inference.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        device (str): Device the checkpoint tensors are mapped to and the model
            runs on ('cuda' or 'cpu').
        sample_rate (int): Forwarded to ``CODEC.load_from_checkpoint`` as the
            ``sample_rate`` keyword (presumably overriding the stored
            hyperparameter — confirm against the CODEC class).

    Returns:
        CODEC: The model in eval mode, moved to ``device``.
    """
    # Map stored tensors directly onto the requested device. A fixed 'cuda:0'
    # makes loading fail on CPU-only machines even when device='cpu'.
    model = CODEC.load_from_checkpoint(checkpoint_path, map_location=device, sample_rate=sample_rate)
    model.eval()
    model.to(device)
    return model

def process_audio_in_chunks(file_path, sample_rate, model, chunk_size, device='cuda'):
    """
    Encode, quantize and decode an audio file chunk by chunk through the WPT codec.

    The file is resampled to ``sample_rate`` if needed and passed through
    ``model.preprocess``. Only the first channel is then cut into chunks of
    ``chunk_size`` samples. Each chunk is decomposed with
    ``batch_wdt(..., max_level=4)`` and the resulting nodes are fed to the
    encoder. Processing stops at the first chunk shorter than 512 samples, so
    such a trailing remainder is absent from both returned signals.

    Args:
        file_path (str): Path to the input audio file.
        sample_rate (int): Target sample rate for the model.
        model (CODEC): Trained CODEC model.
        chunk_size (int): Chunk size in samples.
        device (str): Device to run the model ('cuda' or 'cpu').

    Returns:
        torch.Tensor: Processed input chunks concatenated along time (on CPU).
        torch.Tensor: Decoder outputs concatenated along time (on CPU).
        torch.Tensor: Quantizer codes of all chunks concatenated along dim 2.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or no chunk of at least
            512 samples could be processed (e.g. a very short file).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Load and resample to the model's rate when necessary.
    audio, sr = torchaudio.load(file_path)
    if sr != sample_rate:
        audio = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)(audio)

    audio = model.preprocess(audio, sample_rate).to(device)
    num_samples = audio.shape[-1]

    original_chunks = []
    reconstructed_chunks = []
    codes_list = []

    with torch.no_grad():
        for start in range(0, num_samples, chunk_size):
            # First channel only, with a batch axis added in front.
            chunk = audio[0:1, start:start + chunk_size].unsqueeze(dim=0)

            # NOTE(review): 512 presumably is the minimum length the wavelet
            # decomposition / encoder accepts — confirm.
            if chunk.shape[2] < 512:
                break

            audio_data = chunk[:, 0:1, :]
            # The encoder is fed the raw wavelet-packet nodes. A
            # model.preprocess(nodes, ...) call whose result was immediately
            # overwritten has been removed.
            nodes = batch_wdt(audio_data, max_level=4)
            z = model.encoder(nodes)
            # Only the quantized latents and codes are needed; the remaining
            # quantizer outputs (latents and losses) are discarded.
            z_q, codes, *_ = model.quantizer(z)
            reconstructed_chunk = model.decoder(z_q)

            codes_list.append(codes)
            original_chunks.append(chunk.cpu())
            reconstructed_chunks.append(reconstructed_chunk.cpu())

    if not original_chunks:
        raise ValueError(f"{file_path}: no chunk of at least 512 samples to process")

    original_audio = torch.cat(original_chunks, dim=-1)
    reconstructed_audio = torch.cat(reconstructed_chunks, dim=-1)
    codes = torch.cat(codes_list, dim=2)  # Concatenate along the time dimension

    return original_audio, reconstructed_audio, codes

def reconstruct_from_codes(codes, model, device='cuda'):
    """
    Decode audio from quantized codes using the quantizer's ``from_codes``.

    The codes are turned back into quantized latents ``z_q`` with
    ``model.quantizer.from_codes``, and ``z_q`` is passed through the decoder.

    Args:
        codes (torch.Tensor): Quantized codes (B x N x T).
        model (CODEC): Trained CODEC model.
        device (str): Device to run the model ('cuda' or 'cpu').

    Returns:
        torch.Tensor: The decoder output (reconstructed audio), moved to CPU.
    """
    with torch.no_grad():
        # from_codes returns a 3-tuple; only the quantized latents are decoded.
        z_q, _, _ = model.quantizer.from_codes(codes.to(device))
        reconstructed_audio = model.decoder(z_q)
    return reconstructed_audio.cpu()

def save_audio(audio_tensor, sample_rate, output_path):
    """
    Write ``audio_tensor`` to ``output_path`` at ``sample_rate`` using torchaudio.
    """
    torchaudio.save(output_path, audio_tensor, sample_rate=sample_rate)

def process_folder(input_folder, output_original_folder, output_reconstructed_folder, model, sample_rate, chunk_size, device='cuda', extensions=(".flac",)):
    """
    Run the codec over every audio file under a folder and save the results.

    ``input_folder`` is walked recursively. Every file whose name ends with one
    of ``extensions`` (``.flac`` by default) goes through
    ``process_audio_in_chunks``. The processed original and its reconstruction
    are trimmed to a common length and written as ``<stem>_original.wav`` and
    ``<stem>_reconstructed.wav`` in the two output folders. Files that share a
    stem in different subfolders overwrite each other's outputs.

    Args:
        input_folder (str): Root folder searched recursively for input audio.
        output_original_folder (str): Destination for the original audio
            (created if missing).
        output_reconstructed_folder (str): Destination for the reconstructions
            (created if missing).
        model (CODEC): Trained CODEC model.
        sample_rate (int): Sample rate used for processing and for saving.
        chunk_size (int): Chunk size in samples.
        device (str): Device to run the model ('cuda' or 'cpu').
        extensions (str | tuple[str, ...]): Filename suffix(es) to process.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_original_folder, exist_ok=True)
    os.makedirs(output_reconstructed_folder, exist_ok=True)

    # Sort so files are processed in a deterministic order; os.walk's listing
    # order is arbitrary.
    all_files = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(input_folder)
        for file in files
        if file.endswith(extensions)
    )

    for input_file_path in tqdm(all_files, desc="Processing files"):
        file_name = os.path.splitext(os.path.basename(input_file_path))[0]

        original_audio, reconstructed_audio, _ = process_audio_in_chunks(
            input_file_path, sample_rate, model, chunk_size, device=device
        )

        # Trim both signals to the shorter length so they align sample by sample.
        min_length = min(original_audio.shape[-1], reconstructed_audio.shape[-1])
        original_audio = original_audio[..., :min_length]
        reconstructed_audio = reconstructed_audio[..., :min_length]

        output_original_path = os.path.join(output_original_folder, f"{file_name}_original.wav")
        output_reconstructed_path = os.path.join(output_reconstructed_folder, f"{file_name}_reconstructed.wav")
        # [0] drops the leading batch axis before saving.
        save_audio(original_audio[0], sample_rate, output_original_path)
        save_audio(reconstructed_audio[0], sample_rate, output_reconstructed_path)

if __name__ == "__main__":
    # Paths and parameters
    checkpoint_path = "/workspace/wptcodec/wptcodec-speech-4/checkpoints/epoch-epoch=61-loss-train_loss=0.00.ckpt"  # Replace with the path to your trained model checkpoint
    input_folder = "/datasets/LibriSpeech/test-clean/"  # Folder searched recursively for input .flac files
    output_original_folder = "/workspace/wptcodec/wptcodec-speech-4/original"  # Folder to save original audio
    output_reconstructed_folder = "/workspace/wptcodec/wptcodec-speech-4/reconstructed"  # Folder to save reconstructed audio
    sample_rate = 16000  # Model's expected sample rate
    chunk_size = 16384  # Samples per chunk (~1.02 s at 16 kHz)

    # Use the GPU when available, otherwise fall back to CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_model(checkpoint_path, device=device)

    # Process all files in the folder
    process_folder(
        input_folder, output_original_folder, output_reconstructed_folder,
        model, sample_rate, chunk_size, device=device
    )


In [None]:
import os
import torch
from pathlib import Path
from nn.loss import mel_spectrogram_loss, multi_scale_stft_loss,sisdr_loss  # Replace with actual import path
import torchaudio


def compute_metrics(original: torch.Tensor, reconstructed: torch.Tensor, sample_rate: int) -> dict:
    """
    Compute mel, multi-scale STFT and SI-SDR losses between two audio signals.

    Args:
        original (torch.Tensor): Original audio with a leading batch axis (the
            caller in this notebook passes shape (1, C, T)).
        reconstructed (torch.Tensor): Reconstructed audio, same shape as ``original``.
        sample_rate (int): Sampling rate, forwarded to the mel-spectrogram loss.

    Returns:
        dict: ``{"mel_loss", "stft_loss", "sisdr_loss"}`` mapped to Python floats.
    """
    return {
        "mel_loss": mel_spectrogram_loss(original, reconstructed, sample_rate=sample_rate).item(),
        "stft_loss": multi_scale_stft_loss(original, reconstructed).item(),
        # The SI-SDR loss is given the signals without the leading batch axis.
        # float() converts a one-element tensor (or a plain number) so it
        # matches the other metrics instead of leaking a tensor into the averages.
        "sisdr_loss": float(sisdr_loss(original.squeeze(dim=0), reconstructed.squeeze(dim=0))),
    }


def process_audio_files(original_dir: str, reconstructed_dir: str, sample_rate: int) -> dict:
    """
    Average the codec metrics over all paired files in two directories.

    The ``*.wav`` files of each directory are sorted by path and paired in that
    order. With the ``<stem>_original.wav`` / ``<stem>_reconstructed.wav``
    naming used by the export cells, this pairs files with the same stem. The
    files' own sample rates are not checked; ``sample_rate`` is passed to the
    metrics as-is.

    Args:
        original_dir (str): Directory containing the original audio files.
        reconstructed_dir (str): Directory containing the reconstructed audio files.
        sample_rate (int): Sampling rate handed to ``compute_metrics``.

    Returns:
        dict: The mean of each metric over all file pairs.

    Raises:
        ValueError: If the directories hold different numbers of ``.wav``
            files, or no ``.wav`` files at all.
    """
    original_files = sorted(Path(original_dir).glob("*.wav"))
    reconstructed_files = sorted(Path(reconstructed_dir).glob("*.wav"))

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(original_files) != len(reconstructed_files):
        raise ValueError(
            "Mismatch between the number of original and reconstructed files: "
            f"{len(original_files)} vs {len(reconstructed_files)}."
        )
    if not original_files:
        raise ValueError(f"No .wav files found in {original_dir!r}.")

    total_metrics = {"mel_loss": 0.0, "stft_loss": 0.0, "sisdr_loss": 0.0}

    for original_path, reconstructed_path in zip(original_files, reconstructed_files):
        print(f"Processing: {original_path.name}")
        original_audio, _ = torchaudio.load(original_path)
        reconstructed_audio, _ = torchaudio.load(reconstructed_path)

        # Add a leading batch axis of size 1 before computing the metrics.
        metrics = compute_metrics(original_audio.unsqueeze(dim=0), reconstructed_audio.unsqueeze(dim=0), sample_rate=sample_rate)
        for key in total_metrics:
            total_metrics[key] += metrics[key]

    num_files = len(original_files)
    return {key: total / num_files for key, total in total_metrics.items()}


if __name__ == "__main__":
    # Directories produced by the export cell, plus the rate handed to the metrics.
    original_dir = "/workspace/wptcodec/daccodec-speech/original"
    reconstructed_dir = "/workspace/wptcodec/daccodec-speech/reconstructed/"
    sample_rate = 16000

    avg_metrics = process_audio_files(original_dir, reconstructed_dir, sample_rate)

    # Report every averaged metric with four decimal places.
    print("\nAverage Metrics:")
    for metric_name, metric_value in avg_metrics.items():
        print(f"{metric_name}: {metric_value:.4f}")
