# Runtime evaluation similar to [1]
References:
[1] Speeding Up Permutation Invariant Training for Source Separation

In [None]:
from scipy.optimize import linear_sum_assignment

import paderbox as pb
import itertools
import timeit
from collections import defaultdict
import torch
import numpy as np
import padertorch as pt
import paderbox as pb
from tqdm.notebook import tqdm
import graph_pit
import torch
import padertorch as pt
import functools

In [None]:
# General settings
# Runtime limit in s: once the mean runtime of a method exceeds it, larger configurations are skipped for that method
max_time = 1
# Device on which the losses / loss matrices are computed; the permutation solver always runs on the CPU
device = 'cpu'
# Timing runs per configuration: more runs give smoother curves, but large values are impractical in an interactive notebook
number = 3

In [None]:
# Utility functions

def plot_timings(timings, xrange, xlabel, logx=False, ymax=None):
    """Plot one runtime curve per method with a min/max band over repeated runs.

    Args:
        timings: Mapping from method name to a sequence of timing arrays, one
            array per measured configuration. Each array holds the runtimes
            (in s) of the repeated runs. Methods that were skipped early may
            have fewer entries than ``xrange``.
        xrange: x-axis values of the configurations (e.g. number of speakers).
        xlabel: Label of the x-axis.
        logx: If True, both axes are logarithmic; otherwise only the y-axis.
        ymax: Upper y-axis limit in s. Defaults to the notebook-wide
            ``max_time`` (the threshold above which methods are skipped).
    """
    if ymax is None:
        ymax = max_time
    with pb.visualization.axes_context() as ac:
        for idx, (key, values) in enumerate(timings.items()):
            values = np.asarray(values)
            # Methods that exceeded the time limit stopped early, so only plot
            # against the configurations that were actually measured.
            x = xrange[:len(values)]
            color = f'C{idx}'
            pb.visualization.plot.line(x, values.mean(axis=-1), label=key, ax=ac.last, color=color)
            # Shaded band spans the fastest and slowest of the repeated runs.
            ac.last.fill_between(x, values.min(axis=-1), values.max(axis=-1), color=color, alpha=0.3)
        if logx:
            ac.last.loglog()
        else:
            ac.last.semilogy()
        ac.last.set_xlabel(xlabel)
        ac.last.set_ylabel('runtime in s')
        ac.last.set_ylim([ac.last.get_ylim()[0], ymax])
        ac.last.set_xlim([xrange[0], xrange[-1]])

## uPIT

In [None]:
from padertorch.ops.losses.source_separation import pit_loss_from_loss_matrix, compute_pairwise_losses
from torch.nn.functional import mse_loss

In [None]:
# Define the uPIT loss functions

def upit_sa_sdr_decomp_dot(estimate, target, algorithm='hungarian'):
    """
    sa-SDR decomposed with dot product, eq. (13)/(14)

    Only the cross term between estimates and targets depends on the
    assignment, so the permutation problem is solved on the negative pairwise
    dot products. The permutation-independent energy terms are added
    afterwards to form the error energy.

    Assumes 2D inputs of shape (num_speakers, num_samples), as used in this
    notebook.
    """
    # Entry (i, j) is the dot product of estimate i and target j.
    similarity = torch.matmul(estimate, target.T)
    best_cross_term = pit_loss_from_loss_matrix(
        -similarity, reduction='sum', algorithm=algorithm
    )
    target_energy = torch.sum(target**2)
    # ||e - t||^2 = ||t||^2 + ||e||^2 - 2 <e, t>, and best_cross_term is -<e, t>.
    error_energy = target_energy + torch.sum(estimate**2) + 2*best_cross_term
    return -10*(torch.log10(target_energy) - torch.log10(error_energy))

def upit_sa_sdr_decomp_mse(estimate, target, algorithm='hungarian'):
    """
    sa-SDR decomposed with MSE, eq. (11)/(12)

    Builds the matrix of summed squared errors between every estimate/target
    pair. The assignment solver then picks the permutation whose summed
    errors form the denominator of the sa-SDR.
    """
    sum_squared_error = functools.partial(mse_loss, reduction='sum')
    pairwise_errors = compute_pairwise_losses(
        estimate, target, axis=0, loss_fn=sum_squared_error
    )
    error_energy = pit_loss_from_loss_matrix(
        pairwise_errors, reduction='sum', algorithm=algorithm
    )
    target_energy = torch.sum(target**2)
    return -10*(torch.log10(target_energy) - torch.log10(error_energy))

def upit_sa_sdr_naive_brute_force(estimate, target):
    """
    Brute-force sa-SDR, eq. (5)

    Reference implementation: delegates to ``pt.pit_loss`` with the
    source-aggregated SDR loss, permuting along the speaker axis (axis 0).
    """
    speaker_axis = 0
    return pt.pit_loss(estimate, target, speaker_axis, pt.source_aggregated_sdr_loss)

def upit_a_sdr_naive_brute_force(estimate, target):
    """
    Brute-force a-SDR

    Reference implementation: delegates to ``pt.pit_loss`` with the
    per-source SDR loss, permuting along the speaker axis (axis 0).
    """
    speaker_axis = 0
    return pt.pit_loss(estimate, target, speaker_axis, pt.sdr_loss)

def upit_a_sdr_decomp(estimate, target, algorithm='hungarian'):
    """
    Decomposed a-SDR

    Computes the SDR loss for every estimate/target pair once. The assignment
    solver then averages the entries selected by the best permutation.
    """
    pairwise_sdr_losses = compute_pairwise_losses(
        estimate, target, axis=0, loss_fn=pt.sdr_loss
    )
    return pit_loss_from_loss_matrix(
        pairwise_sdr_losses, reduction='mean', algorithm=algorithm
    )

In [None]:
# Check that every decomposed variant reproduces the brute-force reference value,
# both with the Hungarian solver (default) and with the brute-force solver
estimate = torch.randn(3, 32000)
target = torch.randn(3, 32000)

equivalent_variants = [
    (upit_sa_sdr_naive_brute_force, [upit_sa_sdr_decomp_dot, upit_sa_sdr_decomp_mse]),
    (upit_a_sdr_naive_brute_force, [upit_a_sdr_decomp]),
]
for reference_fn, decomposed_fns in equivalent_variants:
    ref = reference_fn(estimate, target)
    for decomposed_fn in decomposed_fns:
        for solver_kwargs in ({}, {'algorithm': 'brute_force'}):
            np.testing.assert_allclose(
                ref, decomposed_fn(estimate, target, **solver_kwargs), rtol=1e-5
            )

In [None]:
# Define all loss functions whose runtime we want to compare.
# The keys are used as legend labels when plotting the timings.
losses = {
    'sa_sdr naive brute_force': upit_sa_sdr_naive_brute_force,
    'sa_sdr brute_force decomp mse': functools.partial(upit_sa_sdr_decomp_mse, algorithm='brute_force'),
    'sa_sdr brute_force decomp dot': functools.partial(upit_sa_sdr_decomp_dot, algorithm='brute_force'),
    'sa_sdr hungarian decomp mse': upit_sa_sdr_decomp_mse,
    'sa_sdr hungarian decomp dot': upit_sa_sdr_decomp_dot,
    'a_sdr naive brute_force': upit_a_sdr_naive_brute_force,
    'a_sdr decomp brute_force': functools.partial(upit_a_sdr_decomp, algorithm='brute_force'),
    'a_sdr decomp hungarian': upit_a_sdr_decomp,
}

In [None]:
# Settings for uPIT
num_speakers_range = list(range(2, 100))  # numbers of speakers to evaluate
T = 32000  # number of samples per signal passed to time_loss

In [None]:
def time_loss(loss, num_speakers=3, T=8000 * 4, number=10, device='cuda'):
    """Time ``loss(estimates, targets)`` on random Gaussian signals.

    Args:
        loss: Callable taking ``(estimates, targets)`` and returning a tensor
            that can be converted to a Python float.
        num_speakers: Number of rows (speakers) of the random signals.
        T: Number of samples per speaker.
        number: Number of timed repetitions (one call each).
        device: Torch device the signals are moved to.

    Returns:
        NumPy array with ``number`` runtimes in seconds.
    """
    signal_shape = (num_speakers, T)
    # Targets are drawn before estimates; the random signals are shared by all repetitions.
    namespace = {
        'loss': loss,
        'targets': torch.tensor(np.random.randn(*signal_shape)).to(device),
        'estimates': torch.tensor(np.random.randn(*signal_shape)).to(device),
    }
    # Moving the result to the CPU and converting it to float forces any
    # pending device computation to finish inside the timed statement.
    runtimes = timeit.repeat(
        'float(loss(estimates, targets).cpu())',
        globals=namespace,
        repeat=number,
        number=1,
    )
    return np.asarray(runtimes)

upit_timings = defaultdict(list)  # loss name -> list of timing arrays, one per speaker count
skip = defaultdict(lambda: False)  # loss name -> True once its mean runtime exceeded max_time

for num_speakers in tqdm(num_speakers_range):
    # Losses that already exceeded the time limit are not evaluated for larger speaker counts
    active_losses = [(name, fn) for name, fn in losses.items() if not skip[name]]
    for loss_name, loss_fn in active_losses:
        timing = time_loss(
            loss_fn, num_speakers=num_speakers, number=number, device=device, T=T
        )
        upit_timings[loss_name].append(timing)
        if np.mean(timing) > max_time:
            skip[loss_name] = True

In [None]:
plot_timings(upit_timings, num_speakers_range, xlabel='#speakers', logx=True)  # log-log plot of uPIT runtimes

- Brute force becomes impractical already for a small number of speakers (fewer than 10)
- The Hungarian algorithm can be used even for large numbers of speakers with negligible runtime
- The dot-product decomposition is the fastest variant here. A low-level implementation could, however, probably make the MSE decomposition faster than the dot-product decomposition

## Graph-PIT assignment algorithms

In [None]:
# Graph-PIT variants to compare: the unoptimized reference implementation and the
# optimized sa-SDR loss combined with different assignment solvers
_OptimizedSaSdrLoss = graph_pit.loss.optimized.OptimizedGraphPITSourceAggregatedSDRLossModule
graph_pit_losses = {
    'naive brute-force': graph_pit.loss.unoptimized.GraphPITLossModule(pt.source_aggregated_sdr_loss),
}
for _label, _solver in [
    ('decomp brute-force', 'optimal_brute_force'),
    ('decomp branch-and-bound', 'optimal_branch_and_bound'),
    ('decomp dfs', 'dfs'),
    ('decomp dynamic programming', 'optimal_dynamic_programming'),
]:
    graph_pit_losses[_label] = _OptimizedSaSdrLoss(assignment_solver=_solver)

In [None]:
# Settings for the Graph-PIT evaluation
num_utterances_range = list(range(2, 30))  # numbers of utterances (segments) to evaluate
utterance_length = 8000  # utterance length in samples, passed to time_alg
overlap = 500  # overlap in samples, passed to time_alg

In [None]:
def time_alg(loss, num_segments, num_estimates=3, number=10, device='cpu',
             utterance_length=2*8000, overlap=500):
    """Time a Graph-PIT loss on a synthetic chain of overlapping utterances.

    The scenario consists of ``num_segments`` utterances. Each is exactly
    ``utterance_length`` samples long and overlaps its predecessor by
    ``overlap`` samples. Targets and estimates are uniform random noise,
    drawn fresh for every repetition.

    Args:
        loss: Object with a ``get_loss_object(estimate, targets,
            segment_boundaries)`` method whose result exposes a ``.loss``
            tensor.
        num_segments: Number of utterances (target segments).
        num_estimates: Number of estimated output channels.
        number: Number of timed repetitions.
        device: Torch device for the random signals.
        utterance_length: Length of every utterance in samples.
        overlap: Overlap between consecutive utterances in samples.

    Returns:
        NumPy array with ``number`` runtimes in seconds.
    """
    stride = utterance_length - overlap
    # Utterance k starts ``stride`` samples after utterance k-1 and has a
    # fixed length, so consecutive utterances overlap by exactly ``overlap``.
    segment_boundaries = [
        (k * stride, k * stride + utterance_length)
        for k in range(num_segments)
    ]
    # Pad the mixture a little beyond the end of the last utterance.
    num_samples = max(stop for _, stop in segment_boundaries) + 100

    timings = []
    for _ in range(number):
        targets = [torch.rand(stop - start).to(device) for start, stop in segment_boundaries]
        estimate = torch.rand(num_estimates, num_samples).to(device)

        # The loss object is built in ``setup``, so only evaluating ``.loss``
        # (presumably where the assignment is solved) is timed.
        timings.append(timeit.timeit(
            setup='l = loss.get_loss_object(estimate, targets, segment_boundaries)',
            stmt='float(l.loss.cpu())',
            globals={
                'loss': loss,
                'estimate': estimate,
                'targets': targets,
                'segment_boundaries': segment_boundaries,
            },
            number=1,
        ))
    return np.asarray(timings)

graph_pit_timings = defaultdict(list)  # loss name -> list of timing arrays, one per utterance count
skip = defaultdict(lambda: False)  # loss name -> True once its mean runtime exceeded max_time

for num_segments in tqdm(num_utterances_range):
    for loss_name, loss_fn in graph_pit_losses.items():
        if not skip[loss_name]:
            # NOTE(review): device is fixed to 'cpu' here, independent of the
            # global `device` setting — confirm this is intended.
            timing = time_alg(
                loss_fn,
                num_segments=num_segments,
                number=number,
                device='cpu',
                utterance_length=utterance_length,
                overlap=overlap,
            )
            graph_pit_timings[loss_name].append(timing)
            if np.mean(timing) > max_time:
                skip[loss_name] = True

In [None]:
plot_timings(graph_pit_timings, num_utterances_range, xlabel='#utterances')  # semi-log plot of Graph-PIT runtimes

- The brute-force variants quickly become impractical for training a network
- The branch-and-bound algorithm has a much larger variance in its runtime than all other algorithms
- The dynamic programming algorithm has a runtime similar to that of the DFS algorithm, but it always finds the optimal coloring
- DFS and dynamic programming have a runtime that is negligible compared to common network architectures