In [None]:
# Make the directory one level above the working directory importable
# (presumably where the `texstat` package lives — verify against the repo layout).
import os
import sys

parent_dir = os.path.abspath(os.pardir)
if parent_dir not in sys.path:
    # Only append once so re-running this cell does not grow sys.path.
    sys.path.append(parent_dir)

# Import loss function
from texstat.functions import *
import texstat.torch_filterbanks.filterbanks as fb

# Import extra packages
import numpy as np
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import time
import gc

# Select the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
def benchmark_loss_function(loss_fn, input_shape=(32, 100), dtype=torch.float32, device='cuda', iterations=10, **loss_kwargs):
    """Benchmark a loss function for forward time, backward/step time and GPU memory.

    Random ``inputs`` (with ``requires_grad=True``) and ``targets`` of shape
    ``input_shape`` are created on ``device``. Each iteration evaluates
    ``loss_fn(inputs, targets, **loss_kwargs)``, back-propagates the result and
    applies one SGD step (lr=0.01) to ``inputs``. A summary is printed and
    also returned.

    Args:
        loss_fn: Callable invoked as ``loss_fn(inputs, targets, **loss_kwargs)``.
            Its return value must support ``.backward()``.
        input_shape: Shape of the random input and target tensors.
        dtype: Dtype of the random tensors.
        device: Device on which the tensors are created. CUDA-specific
            synchronisation and memory statistics are only used when this is
            a CUDA device; on any other device the memory figures are 0.
        iterations: Number of timed forward/backward iterations.
        **loss_kwargs: Extra keyword arguments forwarded to ``loss_fn``.

    Returns:
        Tuple ``(mean_computation_time, std_computation_time,
        mean_grad_descent_time, std_grad_descent_time, mean_memory,
        std_memory)``. Times are in seconds. Memory is in MB and is the peak
        CUDA allocation reached during an iteration, relative to the
        allocation at the start of that iteration.
    """
    target_device = torch.device(device)
    # torch.cuda.* calls raise on machines without CUDA, so they must be
    # skipped when benchmarking on the CPU fallback.
    use_cuda = target_device.type == "cuda"

    def _sync():
        # CUDA kernels run asynchronously; wait for them so that the timers
        # measure the actual work rather than just the kernel launches.
        if use_cuda:
            torch.cuda.synchronize(target_device)

    # Collect Python garbage first so that the freed tensors' blocks can
    # actually be released from the CUDA caching allocator afterwards.
    gc.collect()
    if use_cuda:
        torch.cuda.empty_cache()

    # Create dummy inputs and targets
    inputs = torch.randn(*input_shape, dtype=dtype, device=target_device, requires_grad=True)
    targets = torch.randn(*input_shape, dtype=dtype, device=target_device)

    optimizer = optim.SGD([inputs], lr=0.01)

    computation_times = []
    grad_descent_times = []
    memory_usages = []

    for _ in range(iterations):
        _sync()
        if use_cuda:
            # Track the peak allocation of this iteration only. A plain
            # before/after difference misses activations that are freed
            # again by backward().
            torch.cuda.reset_peak_memory_stats(target_device)
            start_mem = torch.cuda.memory_allocated(target_device)

        # Measure computation (forward) time
        start_time = time.perf_counter()
        loss = loss_fn(inputs, targets, **loss_kwargs)
        _sync()
        computation_times.append(time.perf_counter() - start_time)

        # Measure gradient descent (zero_grad + backward + step) time
        start_time = time.perf_counter()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        _sync()
        grad_descent_times.append(time.perf_counter() - start_time)

        if use_cuda:
            memory_usages.append(torch.cuda.max_memory_allocated(target_device) - start_mem)
        else:
            # No CUDA allocator statistics are available off-GPU.
            memory_usages.append(0)

    # Compute mean and standard deviation
    mean_computation_time = np.mean(computation_times)
    std_computation_time = np.std(computation_times)
    mean_grad_descent_time = np.mean(grad_descent_times)
    std_grad_descent_time = np.std(grad_descent_times)
    mean_memory = np.mean(memory_usages) / 1e6  # Convert bytes to MB
    std_memory = np.std(memory_usages) / 1e6  # Convert bytes to MB

    print(f"Computation Time:      {mean_computation_time:.6f} sec (±{std_computation_time:.6f})")
    print(f"Gradient Descent Time: {mean_grad_descent_time:.6f} sec (±{std_grad_descent_time:.6f})")
    print(f"Memory Usage:          {mean_memory:.2f} MB (±{std_memory:.2f})\n")

    return mean_computation_time, std_computation_time, mean_grad_descent_time, std_grad_descent_time, mean_memory, std_memory

In [None]:
# --- TexStat configuration ---------------------------------------------------
# Sample rate and length (in samples) of each benchmark input
sr = 44100
frame_size = 2 ** 16

# Filter-bank sizes and number of moments handed to the loss
N_filter_bank = 16
M_filter_bank = 6
N_moments = 4

# Weight tensors forwarded to texstat_loss
# NOTE(review): their exact role is defined inside texstat.functions — confirm there.
alpha = torch.tensor([100, 1, 1 / 10, 1 / 100], device=device)
beta = torch.tensor([1] * 5, device=device)

# Resampler from sr down to a quarter of the rate, plus the matching frame length
new_sr = sr // 4
new_frame_size = frame_size // 4
downsampler = torchaudio.transforms.Resample(sr, new_sr).to(device)

# Filter banks: one built for the full-rate frame, one for the downsampled frame.
# NOTE(review): the two trailing numeric arguments look like lower/upper frequency
# limits — confirm against texstat.torch_filterbanks.
coch_fb = fb.EqualRectangularBandwidth(frame_size, sr, N_filter_bank, 20, sr // 2)
mod_fb = fb.Logarithmic(new_frame_size, new_sr, M_filter_bank, 10, new_sr // 4)

def custom_texstat_loss(x, y, coch_fb, mod_fb, downsampler, N_moments, alpha, beta):
    """Forward all arguments, in the same order, to ``texstat_loss``.

    Exists so the loss can be handed to ``benchmark_loss_function`` with its
    configuration supplied as keyword arguments, while ``texstat_loss`` itself
    is still called positionally.

    Args:
        x, y: The two tensors to compare (first and second positional arguments).
        coch_fb, mod_fb, downsampler, N_moments, alpha, beta: Passed through
            unchanged; their meaning is defined by ``texstat_loss``.

    Returns:
        Whatever ``texstat_loss`` returns.
    """
    texstat_config = (coch_fb, mod_fb, downsampler, N_moments, alpha, beta)
    return texstat_loss(x, y, *texstat_config)

# Benchmark the loss on a single frame first, then on a batch of 32 frames.
# Both runs share the same TexStat configuration and differ only in batch size.
texstat_kwargs = dict(
    coch_fb=coch_fb,
    mod_fb=mod_fb,
    downsampler=downsampler,
    N_moments=N_moments,
    alpha=alpha,
    beta=beta,
)
benchmark_runs = [
    ("TexStat single computation benchmark:", 1),
    ("TexStat batch computation benchmark:", 32),
]
for banner, batch_size in benchmark_runs:
    print(banner)
    benchmark_loss_function(
        custom_texstat_loss,
        input_shape=(batch_size, frame_size),
        device=device,
        **texstat_kwargs,
    )

print("Benchmarking done.")