In [1]:
import torch
import torch.nn.functional as F
import typing as tp
import matplotlib.pyplot as plt
import random

In [2]:
pan_bias = 30  # max random pan offset from centre, in degrees on a 0-360 scale (180 = centre)
S = 4          # number of tracks
C = 2          # channels per track (stereo)
T = 1024       # time steps per channel
SEED = 0
random.seed(SEED)  # inconsistency_pan draws its pan positions from the global `random` RNG

# Deterministic ramp test signal: each value equals its flat index, which makes the
# per-track / per-channel gains easy to read off. Size is derived from S*C*T so that
# changing any dimension keeps the `.view` valid.
arr = torch.arange(0.0, S * C * T, 1.0).view(S, C, T)

In [3]:
def inconsistency_pan(
    mix_segment: torch.Tensor,
    max_offset_deg: tp.Optional[int] = None,
    rng: tp.Optional[random.Random] = None,
) -> torch.Tensor:
    """Randomly re-pan every track of a stereo multitrack segment.

    Each track gets an independent integer pan position drawn uniformly from
    ``[180 - max_offset_deg, 180 + max_offset_deg]`` on a 0-360 scale, where
    0 is hard left, 180 is centre and 360 is hard right. Gains follow the
    sine (constant-power) pan law, scaled by sqrt(2) so that a centred track
    keeps unit gain on both channels and ``left**2 + right**2 == 2`` always.

    Args:
        mix_segment: Tensor of shape ``(tracks, 2, time)``. Channel 0 is
            scaled by the left gain, channel 1 by the right gain. The tensor
            is NOT modified.
        max_offset_deg: Maximum pan offset from centre, in degrees. Defaults
            to the notebook-level ``pan_bias`` setting.
        rng: Object with a ``randint`` method (e.g. ``random.Random(seed)``)
            for reproducible pans. Defaults to the global ``random`` module.

    Returns:
        A new tensor of the same shape holding the panned audio. Floating
        inputs keep their dtype; integer inputs are promoted to float32.

    Raises:
        AssertionError: if the input is not 3-D stereo, if the gains violate
            the power law, or if the result contains NaNs.
    """
    assert mix_segment.dim() == 3, (
        f"expected (tracks, channels, time), got shape {tuple(mix_segment.shape)}"
    )
    assert mix_segment.shape[1] == 2  # Ensure stereo channels

    # Only evaluate the notebook global when no explicit offset was given.
    offset = max_offset_deg if max_offset_deg is not None else pan_bias
    rng = random if rng is None else rng
    num_tracks = mix_segment.shape[0]

    # Random pan position per track, mapped from degrees to x in [0, 1].
    deg = torch.tensor(
        [rng.randint(180 - offset, 180 + offset) for _ in range(num_tracks)],
        dtype=torch.float32,
    )
    x = deg / 360.0

    # sqrt(2): gives unit gain on both channels for a centred (x == 0.5) track.
    norm_power = 2 / (2 ** 0.5)
    right_amps = torch.sin(x * torch.pi * 0.5) * norm_power
    left_amps = torch.sin((1 - x) * torch.pi * 0.5) * norm_power

    # Sanity-check the pan law on the float32 gains before applying them.
    sum_of_squares = right_amps ** 2 + left_amps ** 2
    tolerance = 1e-6
    assert torch.all(torch.isclose(sum_of_squares, torch.tensor(2.0), atol=tolerance)), (
        f"Assertion failed: not all elements are close to 2 within tolerance {tolerance}. Values: {sum_of_squares}"
    )

    # Build (tracks, 2, 1) gains on the input's device so they broadcast over time.
    # Multiplying out-of-place leaves the caller's tensor untouched, which keeps
    # notebook re-runs idempotent.
    out_dtype = mix_segment.dtype if mix_segment.is_floating_point() else torch.float32
    gains = torch.stack([left_amps, right_amps], dim=1).unsqueeze(-1)
    gains = gains.to(device=mix_segment.device, dtype=out_dtype)
    panned = mix_segment.to(out_dtype) * gains

    assert not torch.isnan(panned).any()
    return panned

In [4]:
# Pan a copy so `arr` stays the untouched ramp signal and re-running this cell is idempotent.
out = inconsistency_pan(arr.clone())