In [None]:
# Make the parent of the current working directory importable, so that
# packages living one level up can be imported without being installed
import os
import sys

parent_dir = os.path.abspath('..')
if sys.path.count(parent_dir) == 0:
    # Append only once so re-running this cell does not grow sys.path
    sys.path.append(parent_dir)

# Import loss function
from texstat.functions import *
import texstat.torch_filterbanks.filterbanks as fb

# Import extra packages
import numpy as np
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio
import torch
import torchaudio
import time

# Choose the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# Texture generation (for testing)
def texture_generator(sr, duration, num_sounds, rng=None):
    """Synthesize a peak-normalized test texture as a sum of random sinusoids.

    Each of the `num_sounds` partials gets a frequency of
    ``110 * 2**(7*u + 1)`` Hz with ``u`` uniform in [0, 1) (i.e. 220 Hz up
    to 28160 Hz) and an initial phase uniform in [0, 2*pi).

    NOTE: partials drawn above ``sr / 2`` alias when sampled; this matches
    the original frequency range and is left unchanged on purpose.

    Parameters
    ----------
    sr : int
        Sampling rate in Hz.
    duration : int or float
        Length in seconds; the output has ``round(duration * sr)`` samples.
    num_sounds : int
        Number of sinusoids to sum.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness (anything exposing ``.random()``). When omitted,
        NumPy's global random state is used, as before.

    Returns
    -------
    numpy.ndarray
        1-D float array whose maximum absolute value is 1. If there is
        nothing to normalize (no partials, or no samples) the all-zero /
        empty array is returned instead of dividing by zero.
    """
    source = np.random if rng is None else rng
    n_samples = int(round(duration * sr))

    # Sample instants spaced exactly 1/sr apart. An endpoint-inclusive
    # linspace over [0, duration] would stretch the spacing to
    # duration/(n_samples - 1) and slightly detune every partial.
    t = np.arange(n_samples) / sr

    sound = np.zeros(n_samples)
    for _ in range(num_sounds):
        # Draw order (frequency first, then phase) is kept from the original
        frequency = 110 * 2 ** (7 * source.random() + 1)
        phase = 2 * np.pi * source.random()
        sound += np.sin(2 * np.pi * frequency * t + phase)

    # `initial` keeps the reduction defined for an empty signal
    peak = np.max(np.abs(sound), initial=0.0)
    if peak == 0:
        return sound
    return sound / peak

# Generate two sounds
# Seed NumPy's global RNG first so that a fresh "Restart & Run All" produces
# the same two textures (and therefore the same statistics and loss) each run.
# Change or remove the seed to get different random textures.
np.random.seed(0)
sr = 44100
sound_1 = texture_generator(sr, 5, 150)  # 5 seconds, 150 random sinusoids
sound_2 = texture_generator(sr, 5, 150)  # second, independent draw

# # Uncomment to load proper audio files 
# sound_1_path  = "your_sound_1.wav"
# sound_2_path  = "your_sound_2.wav"
# sr     = 44100
# sound_1, _  = librosa.load(sound_1_path, sr=sr, mono=True)
# sound_1     = sound_1/np.max(np.abs(sound_1))
# sound_2, _ = librosa.load(sound_2_path, sr=sr, mono=True)
# sound_2    = sound_2/np.max(np.abs(sound_2))

# # white noise generation for testing
# sr = 44100
# sound_1 = np.random.normal(0, 1, 5*sr)
# sound_2 = np.random.normal(0, 1, 5*sr)

# Render an inline audio player for each sound (sound_1 first, then sound_2)
display(
    Audio(sound_1, rate=sr),
    Audio(sound_2, rate=sr),
)

# Analysis parameters
# NOTE(review): alpha and beta are created on the CPU while the signals are
# placed on `device`; confirm texstat handles that when running on CUDA.
frame_size    = 1 << 16   # 65536 samples kept from each sound
N_filter_bank = 16        # handed to fb.EqualRectangularBandwidth (presumably its band count)
M_filter_bank = 6         # handed to fb.Logarithmic (presumably its band count)
N_moments     = 4         # passed to statistics_mcds / texstat_loss
# Four entries, the same count as N_moments -- presumably one weight per moment; confirm
alpha         = torch.tensor([100.0, 1.0, 0.1, 0.01])
# Five entries, the same count as the statistic groups unpacked from
# statistics_mcds -- presumably one weight per group; confirm
beta          = torch.tensor([1, 1, 1, 1, 1])

# Keep the first `frame_size` samples of each sound as a tensor on `device`.
# The dtype follows the NumPy input.
sound_1_segment = torch.tensor(sound_1[:frame_size], device=device)
sound_2_segment = torch.tensor(sound_2[:frame_size], device=device)

# Build a batch by stacking `batch_size` copies of each segment along a new
# leading dimension
batch_size = 8
sound_1_batch = torch.stack(
    [sound_1_segment for _ in range(batch_size)]
).to(device)
sound_2_batch = torch.stack(
    [sound_2_segment for _ in range(batch_size)]
).to(device)

# Resampler and filterbanks
# The downsampled path uses a quarter of the sampling rate and frame length
new_sr = sr // 4
new_frame_size = frame_size // 4
downsampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_sr).to(device)

# Filterbank over the full-rate frame. The trailing 20 and sr // 2 are
# presumably the low/high frequency limits in Hz -- confirm against
# texstat.torch_filterbanks.
coch_fb = fb.EqualRectangularBandwidth(
    frame_size, sr, N_filter_bank, 20, sr // 2
)
# Filterbank over the downsampled frame. The trailing 10 and new_sr // 4 are
# presumably its frequency limits -- confirm.
mod_fb = fb.Logarithmic(
    new_frame_size, new_sr, M_filter_bank, 10, new_sr // 4
)

In [None]:
# Summary statistics for the single segment of sound_1.
# The result unpacks into five groups, printed one per line below.
stats_sound_1 = statistics_mcds(
    sound_1_segment, coch_fb, mod_fb, downsampler, N_moments, alpha
)
(stats_1_sound_1,
 stats_2_sound_1,
 stats_3_sound_1,
 stats_4_sound_1,
 stats_5_sound_1) = stats_sound_1

print("Sound 1 summary statistics: ")
for _label, _stat in zip(
    ("Stats_1_sound_1:", "Stats_2_sound_1:", "Stats_3_sound_1:",
     "Stats_4_sound_1:", "Stats_5_sound_1:"),
    (stats_1_sound_1, stats_2_sound_1, stats_3_sound_1,
     stats_4_sound_1, stats_5_sound_1),
):
    print(_label, _stat)

In [None]:
# Loss between the single segments of sound_1 and sound_2
loss = texstat_loss(
    sound_1_segment, sound_2_segment,
    coch_fb, mod_fb, downsampler,
    N_moments, alpha, beta,
)
print("Loss:", loss)

In [None]:
# Summary statistics for the batch built from sound_1.
# Same call as for a single segment, but fed the stacked batch tensor.
stats_sound_1_batch = statistics_mcds(
    sound_1_batch, coch_fb, mod_fb, downsampler, N_moments, alpha
)
(stats_1_sound_1_batch,
 stats_2_sound_1_batch,
 stats_3_sound_1_batch,
 stats_4_sound_1_batch,
 stats_5_sound_1_batch) = stats_sound_1_batch

print("Sound 1 summary statistics: ")
for _label, _stat in zip(
    ("Stats_1_sound_1_batch:", "Stats_2_sound_1_batch:", "Stats_3_sound_1_batch:",
     "Stats_4_sound_1_batch:", "Stats_5_sound_1_batch:"),
    (stats_1_sound_1_batch, stats_2_sound_1_batch, stats_3_sound_1_batch,
     stats_4_sound_1_batch, stats_5_sound_1_batch),
):
    print(_label, _stat)

In [None]:
# Loss between the batches built from sound_1 and sound_2
loss_batch = texstat_loss(
    sound_1_batch, sound_2_batch,
    coch_fb, mod_fb, downsampler,
    N_moments, alpha, beta,
)
print("Loss batch:", loss_batch)