Here, we implement Wasserstein regression using PyTorch. We use PyTorch so that we can utilize the AutoGrad framework to calculate gradients since I do not know a closed-form solution for Wasserstein regression similar to scalar linear regression.

If this does not work since some of the operations do not let the gradients pass through such as `torch.unique` or `torch.searchsorted`, we will use SPSA to estimate the gradients.

In [3]:
import torch # will use for AutoGrad
import numpy as np
import pandas as pd

In [9]:
torch.cat([torch.tensor(0).reshape(1), torch.diff(torch.tensor([1,2,3,4,5]))])

tensor([0, 1, 1, 1, 1])

In [17]:
def expectation(barycenter):
    # This is just the integral from 0 to 1 of the quantile function
    x,y = barycenter
    x_diff = torch.diff(x)
    return torch.sum(torch.mul(x_diff, y[1:])) + torch.mul(x[0], y[0])

In [43]:
x = torch.tensor([0, 0.25, 0.5, 0.75, 1]).requires_grad_(True)
y = torch.tensor([0, 0.25, 0.5, 0.75, 1]).requires_grad_(True)

mean = expectation((x,y))
mean

tensor(0.6250, grad_fn=<AddBackward0>)

In [44]:
mean.backward()

In [124]:
import torch

def empirical_quantile_function(samples):
    """
    Returns a function that computes the empirical quantile function for the given 1D samples.

    Args:
        samples (torch.Tensor): A 1D tensor of samples from a distribution. Assumes is sorted.

    Returns:
        function: A function that takes a tensor of quantiles (q) and returns the corresponding quantile values.
    """
    def quantile_function(q):
        """
        Computes the empirical quantile for the given quantiles.

        Args:
            q (torch.Tensor): A tensor of quantiles (values between 0 and 1).

        Returns:
            torch.Tensor: The corresponding quantile values.
        """
        # Compute the empirical CDF values
        n = len(samples)
        cdf = torch.arange(1, n + 1, dtype=torch.float32) / n
        # Use broadcasting to calculate the Heaviside contributions
        heaviside_matrix = torch.heaviside(q.unsqueeze(1) - cdf.unsqueeze(0), torch.tensor(1.0))
        # Compute quantile values by summing contributions
        quantile_values = heaviside_matrix @ samples

        return quantile_values

    return quantile_function

# Example usage
samples = torch.tensor([1.0, 2.0]).requires_grad_(True)
quantile_fn = empirical_quantile_function(samples)
quantiles = torch.tensor([0.1, 0.5, 1.1])
result = quantile_fn(quantiles)
print(result)

tensor([0., 1., 3.], grad_fn=<MvBackward0>)


In [41]:
def step_function(x_points, y_points, x):
    # Find the interval that x falls into
    idx = torch.searchsorted(x_points, x, right = True) - 1
    # If x is before the first point, return the first y
    if idx < 0:
        return y_points[0]
    # If x is beyond the last point, return the last y
    elif idx >= len(y_points):
        return y_points[-1]
    # Otherwise, return the y corresponding to the interval
    return y_points[idx]

val = step_function(x,y, 0.5)
val

tensor(0.5000, grad_fn=<SelectBackward0>)

In [57]:
val.backward()

In [131]:
samples1 = torch.tensor([1.0, 2.0]).requires_grad_(True)
samples2 = torch.tensor([2.0, 3.0]).requires_grad_(True)
samples3 = torch.tensor([3.0, 4.0]).requires_grad_(True)

q1 = empirical_quantile_function(samples1)
q2 = empirical_quantile_function(samples2)
q3 = empirical_quantile_function(samples3)

def linear_combination(quantile_fns, weights):
    def lin_comb_fn(q):
        # Compute the quantile values for each function
        quantile_values = torch.stack([fn(q) for fn in quantile_fns])
        # Compute the weighted sum of quantile values
        lin_comb_values = torch.sum(weights.unsqueeze(1) * quantile_values, dim=0)

        return lin_comb_values

    return lin_comb_fn

# Example usage
quantile_fns = [q1, q2, q3]
weights = torch.tensor([1/3, 1/3, 1/3])
barycenter_fn = linear_combination(quantile_fns, weights)

quantiles = torch.tensor([0.1, 0.5, 1.0])
barycenter_fn(quantiles)

tensor([0., 2., 5.], grad_fn=<SumBackward1>)

In [132]:
torch.sum(barycenter_fn(quantiles)).backward()

In [69]:
def squared_difference_integral(x1, y1, x2, y2):
    # Combine and sort all x points from both arrays, ensuring no duplicates
    x_points = torch.unique_consecutive(torch.cat([x1, x2]))
    total_integral = 0.0

    for i in range(len(x_points) - 1):
        # Define the interval [x_points[i], x_points[i+1]]
        x_left = x_points[i]
        x_right = x_points[i + 1]
        interval_length = x_right - x_left

        # Find the y-values for this interval
        # Last y-value from each array that is <= x_left
        y1_value = step_function(x1, y1, x_left)
        y2_value = step_function(x2, y2, x_left)

        # Compute squared difference and add to total integral
        squared_difference = (y1_value - y2_value) ** 2
        total_integral += squared_difference * interval_length

    return total_integral

def squared_difference_integral_approx(x1, y1, x2, y2, N = 1000):
    x_points = torch.linspace(0,1,N)
    step1 = torch.zeros_like(x_points)
    step2 = torch.zeros_like(x_points)
    
    for i in range(len(x_points)):
        step1[i] = step_function(x1, y1, x_points[i])
        step2[i] = step_function(x2, y2, x_points[i])
        
    return torch.sum((step1 - step2) ** 2) / N

x1 = torch.tensor([0, 0.25, 0.5, 0.75, 1]).requires_grad_(True)
y1 = torch.tensor([0, 0.25, 0.5, 0.75, 1]).requires_grad_(True)
x2 = torch.tensor([0, 0.5, 1]).requires_grad_(True)
y2 = torch.tensor([0, 0.5, 1]).requires_grad_(True)

diff = squared_difference_integral(x1, y1, x2, y2)
print(diff)

x1 = torch.tensor([0, 0.25, 0.5, 0.75, 1]).requires_grad_(True)
y1 = torch.tensor([0, 0.25, 0.5, 0.75, 1]).requires_grad_(True)
x2 = torch.tensor([0, 0.5, 1]).requires_grad_(True)
y2 = torch.tensor([0, 0.5, 1]).requires_grad_(True)

diff2 = squared_difference_integral_approx(x1, y1, x2, y2)
print(diff2)

tensor(0.0312, grad_fn=<AddBackward0>)
tensor(0.0312, grad_fn=<DivBackward0>)


In [68]:
diff2.backward() # works

In [47]:
mean = expectation(barycenter)
mean

tensor(1.3125, grad_fn=<AddBackward0>)

In [48]:
mean.backward()

NotImplementedError: the derivative for 'unique_consecutive' is not implemented.

Instead of calculating the average directly, we could just approximate it with a lot of points.

In [None]:
def average_step_function_approx(step_functions):
    x = torch.linspace(0,1,10000)

    # Loop through each interval
    for i in range(len(all_x) - 1):
        x_left = all_x[i]
        
        # For each step function, get the y-value at the start of this interval
        y_values = torch.zeros(len(step_functions))
        for j, (x, y) in enumerate(step_functions):
            y_val = step_function(x, y, x_left)
            y_values[j] += y_val
        # Compute the average y-value for this interval
        avg_y[i] += torch.mean(y_values)

    # Add the final y-value after the last x-point
    y_values = torch.zeros(len(step_functions))
    for j, (x, y) in enumerate(step_functions):
        y_val = y[torch.searchsorted(x, all_x[-1], right=True) - 1]
        y_values[j] += y_val
    # Compute the average y-value for this interval
    avg_y[-1] += torch.mean(y_values)

    return all_x, avg_y

In [49]:
x = torch.linspace(0,1,10000)
y = torch.sin(2 * np.pi * x)

Now, we will try using SPSA to estimate the gradients since the gradients are not passing through the `torch.unique` and `torch.searchsorted` operations.