In [None]:
from ddsp_textures.dataset.makers    import read_wavs_from_folder
from ddsp_textures.auxiliar.features import *

# The `...` placeholders below must be replaced with real values before this
# cell can run.
audio_path    = ...
sampling_rate = ...
frame_size    = ...
hop_size      = ...
audios_list   = read_wavs_from_folder(audio_path, sampling_rate)

# Each entry of `data` is [segment, category]. The category is the index of
# the source recording in `audios_list`, so every frame cut from the same
# recording shares one label and frames from different recordings never do.
# (Labelling by the frame index instead would group unrelated recordings by
# time offset and split a single recording into many classes.)
data = []
for audio_index, audio in enumerate(audios_list):
    size = len(audio)
    # Count every full frame that fits, including the one ending exactly on
    # the last sample. Recordings shorter than one frame yield no segments.
    number_of_segments = max(0, (size - frame_size) // hop_size + 1)
    print(f"Number of segments: {number_of_segments}")
    for i in range(number_of_segments):
        start = i * hop_size
        segment = audio[start : start + frame_size]
        segment = audio_improver(segment, sampling_rate, 4)
        segment = signal_normalizer(segment)
        data.append([segment, audio_index])

In [None]:
import torch
from ddsp_textures.loss.functions import *

from torch.utils.data import DataLoader, TensorDataset

# Define constants (replace these with actual values)
N_filter_bank = ...
M_filter_bank = ...
erb_bank      = ...
log_bank      = ...
downsampler   = ...

# Split the [segment, category] pairs collected in `data` into parallel lists.
signals    = [item[0] for item in data]
categories = [item[1] for item in data]

if not signals:
    raise ValueError("No segments available: `data` is empty.")

# Stack the segments along a new leading batch dimension. torch.tensor()
# cannot build a tensor from a list of multi-element tensors and is very slow
# on a list of NumPy arrays; as_tensor + stack handles both and keeps dtype.
# All segments must have the same length for the stack to succeed.
signals_tensor    = torch.stack([torch.as_tensor(signal) for signal in signals])
categories_tensor = torch.tensor(categories)

# Create a DataLoader for batching
batch_size = 16  # Choose a batch size based on available memory
dataset    = TensorDataset(signals_tensor, categories_tensor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Unconstrained parameters, randomly initialised; a softmax is applied to them
# wherever the actual weights are needed.
alpha = torch.randn(5, requires_grad=True)

# Optimizer
optimizer = torch.optim.Adam([alpha], lr=0.01)

# Training loop: learn the weights `alpha` so that statistics_loss is small
# between segments of the same category and large between segments of
# different categories.
num_epochs = 1000
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch_signals, batch_categories in dataloader:
        # Use a local name here so the configured `batch_size` is not
        # overwritten by the size of whichever batch happened to come last.
        n_items = len(batch_signals)

        # A batch with a single sample contains no pairs, so there is nothing
        # to differentiate; skipping it avoids calling .backward() on the
        # plain integer 0.
        if n_items < 2:
            continue

        optimizer.zero_grad()

        # Softmax keeps the weights positive and summing to one.
        normalized_alpha = torch.softmax(alpha, dim=0)

        # Accumulate the signed loss over every unordered pair in the batch.
        batch_loss = 0
        for i in range(n_items):
            for j in range(i + 1, n_items):
                pair_loss = statistics_loss(
                    batch_signals[i], batch_signals[j],
                    N_filter_bank, M_filter_bank,
                    erb_bank, log_bank, downsampler, normalized_alpha
                )
                if batch_categories[i] == batch_categories[j]:
                    # Minimize distance for same class
                    batch_loss = batch_loss + pair_loss
                else:
                    # Maximize distance for different classes
                    batch_loss = batch_loss - pair_loss

        # Backpropagate and optimize
        batch_loss.backward()
        optimizer.step()

        # Track total loss for reporting
        total_loss += batch_loss.item()

    # Optionally print progress
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss}")

# Final optimized parameters
final_alpha = torch.softmax(alpha, dim=0).detach()
print("Optimized Parameters:", final_alpha)