In [26]:
# Add the parent directory to the Python path
import sys
import os
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Import loss function
from texstat.functions import *
import texstat.torch_filterbanks.filterbanks as fb

# Import extra packages
import numpy as np
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import time
import gc

# Pick device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [27]:
def benchmark_loss_function(loss_fn, input_shape=(32, 100), dtype=torch.float32, device='cuda', iterations=10, **loss_kwargs):
    """Benchmarks a given loss function for time and GPU memory usage."""
    torch.cuda.empty_cache()
    gc.collect()  # Garbage collection to free up memory before benchmarking
    
    # Create dummy inputs and targets
    inputs  = torch.randn(*input_shape, dtype=dtype, device=device, requires_grad=True)
    targets = torch.randn(*input_shape, dtype=dtype, device=device)
    
    optimizer = optim.SGD([inputs], lr=0.01)
    
    computation_times = []
    grad_descent_times = []
    memory_usages = []
    
    for _ in range(iterations):
        torch.cuda.synchronize()
        start_mem = torch.cuda.memory_allocated(device)
        
        # Measure computation time
        start_time = time.time()
        loss = loss_fn(inputs, targets, **loss_kwargs)
        torch.cuda.synchronize()
        end_time = time.time()
        computation_time = end_time - start_time
        
        # Measure gradient descent time
        start_time = time.time()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
        end_time = time.time()
        grad_descent_time = end_time - start_time
        
        end_mem = torch.cuda.memory_allocated(device)
        
        computation_times.append(computation_time)
        grad_descent_times.append(grad_descent_time)
        memory_usages.append(end_mem - start_mem)
    
    # Compute mean and standard deviation
    mean_computation_time = np.mean(computation_times)
    std_computation_time = np.std(computation_times)
    mean_grad_descent_time = np.mean(grad_descent_times)
    std_grad_descent_time = np.std(grad_descent_times)
    mean_memory = np.mean(memory_usages) / 1e6  # Convert to MB
    std_memory = np.std(memory_usages) / 1e6  # Convert to MB
    
    print(f"Computation Time:      {mean_computation_time*1000:.3f} ms (±{std_computation_time*1000:.3f} ms)")
    print(f"Gradient Descent Time: {mean_grad_descent_time*1000:.3f} ms (±{std_grad_descent_time*1000:.6f} ms)")
    print(f"Memory Usage:          {mean_memory:.2f} MB (±{std_memory:.2f})\n")
    
    return mean_computation_time, std_computation_time, mean_grad_descent_time, std_grad_descent_time, mean_memory, std_memory

In [28]:
import torch.nn.functional as F

frame_size = 2**16

# Mean Squared Error (MSE) Loss
def mse_loss(x, y):
    return F.mse_loss(x, y)

print("MSE single computation benchmark")
benchmark_loss_function(mse_loss, input_shape=(1, frame_size), device=device)

print("MSE batch computation benchmark")
benchmark_loss_function(mse_loss, input_shape=(32, frame_size), device=device)

print("------------------------------------------\n")

# Mean Absolute Error (MAE) Loss
def mae_loss(x, y):
    return F.l1_loss(x, y)

print("MAE single computation benchmark")
benchmark_loss_function(mae_loss, input_shape=(1, frame_size), device=device)
print("MAE batch computation benchmark")
benchmark_loss_function(mae_loss, input_shape=(32, frame_size), device=device)

print("Benchmarking done.")

MSE single computation benchmark
Computation Time:      0.099 ms (±0.166 ms)
Gradient Descent Time: 0.157 ms (±0.053906 ms)
Memory Usage:          0.05 MB (±0.16)

MSE batch computation benchmark
Computation Time:      0.192 ms (±0.305 ms)
Gradient Descent Time: 0.164 ms (±0.078221 ms)
Memory Usage:          1.68 MB (±5.03)

------------------------------------------

MAE single computation benchmark
Computation Time:      0.051 ms (±0.021 ms)
Gradient Descent Time: 0.252 ms (±0.311971 ms)
Memory Usage:          0.03 MB (±0.08)

MAE batch computation benchmark
Computation Time:      0.061 ms (±0.032 ms)
Gradient Descent Time: 0.177 ms (±0.078417 ms)
Memory Usage:          0.84 MB (±2.52)

Benchmarking done.


In [29]:
# Multiscale Spectrogram Loss for comparison
def multiscale_fft(signal, scales=[8192, 4096, 2048, 1024, 512, 256, 128], overlap=.75):
    stfts = []
    for s in scales:
        S = torch.stft(
            signal,
            s,
            int(s * (1 - overlap)),
            s,
            torch.hann_window(s).to(signal),
            True,
            normalized=True,
            return_complex=True,
        ).abs()
        stfts.append(S)
    return stfts

def safe_log(x):
    return torch.log(x + 1e-7)

def multiscale_spectrogram_loss(x, x_hat):
    ori_stft = multiscale_fft(x)
    rec_stft = multiscale_fft(x_hat)
    loss = 0
    for s_x, s_y in zip(ori_stft, rec_stft):
        lin_loss = (s_x - s_y).abs().mean()
        log_loss = (safe_log(s_x) - safe_log(s_y)).abs().mean()
        loss = loss + lin_loss + log_loss
    return loss

# Parameters for TexStat
frame_size = 2**16

# Running single computation benchmark
print("MSS single computation benchmark:")
benchmark_loss_function(multiscale_spectrogram_loss, input_shape=(1, frame_size), device=device)

# Running batch computation benchmark
print("MSS batch computation benchmark:")
benchmark_loss_function(multiscale_spectrogram_loss, input_shape=(32, frame_size), device=device)

print("Benchmarking done.")

MSS single computation benchmark:
Computation Time:      1.817 ms (±0.882 ms)
Gradient Descent Time: 1.144 ms (±0.155149 ms)
Memory Usage:          0.03 MB (±0.08)

MSS batch computation benchmark:
Computation Time:      3.888 ms (±0.314 ms)
Gradient Descent Time: 8.485 ms (±0.280359 ms)
Memory Usage:          0.85 MB (±2.60)

Benchmarking done.


In [32]:
# Parameters for TexStat
sr, frame_size = 44100, 2**16
N_filter_bank = 16
M_filter_bank = 6
N_moments     = 4
alpha         = torch.tensor([10, 1, 1/10, 1/100], device=device)
beta  = torch.tensor([1, 1, 1, 1, 1], device=device)
new_sr, new_frame_size = sr // 4, frame_size // 4
downsampler = torchaudio.transforms.Resample(sr, new_sr).to(device)
coch_fb = fb.EqualRectangularBandwidth(frame_size, sr, N_filter_bank, 20, sr // 2)
mod_fb  = fb.Logarithmic(new_frame_size, new_sr, M_filter_bank, 10, new_sr // 4)

def custom_texstat_loss(x, y, coch_fb, mod_fb, downsampler, N_moments, alpha, beta):
    return texstat_loss(x, y, coch_fb, mod_fb, downsampler, N_moments, alpha, beta)

# Running single computation benchmark
print("TexStat single computation benchmark:")
benchmark_loss_function(custom_texstat_loss, input_shape=(1, frame_size), device=device,
                        coch_fb=coch_fb, mod_fb=mod_fb, downsampler=downsampler,
                        N_moments=N_moments, alpha=alpha, beta=beta)

# Running batch computation benchmark
print("TexStat batch computation benchmark:")
benchmark_loss_function(custom_texstat_loss, input_shape=(32, frame_size), device=device,
                        coch_fb=coch_fb, mod_fb=mod_fb, downsampler=downsampler,
                        N_moments=N_moments, alpha=alpha, beta=beta)

print("Benchmarking done.")

TexStat single computation benchmark:
Computation Time:      6.604 ms (±4.834 ms)
Gradient Descent Time: 6.047 ms (±0.818343 ms)
Memory Usage:          0.29 MB (±0.87)

TexStat batch computation benchmark:
Computation Time:      93.458 ms (±0.485 ms)
Gradient Descent Time: 154.615 ms (±0.401312 ms)
Memory Usage:          0.84 MB (±2.52)

Benchmarking done.
