In [12]:
import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3
from IPython import display
from matplotlib import animation, cm

import ott
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.tools import plot

import dataclasses
from types import MappingProxyType
from typing import Any, Dict, Iterator, Literal, Mapping, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import sklearn.datasets
from functools import partial
from jax import jit

import optax
from flax import linen as nn

from matplotlib import pyplot as plt

from ott import datasets
from ott.geometry import costs, pointcloud
from ott.neural.methods import monge_gap
from ott.neural.networks import potentials
from ott.solvers.linear import acceleration
from ott.tools import sinkhorn_divergence
from ott.problems.linear import potentials

from scipy.stats import multivariate_normal
import plotly.graph_objects as go
import random

# Functions

In [13]:
def generate_multivariate_normal_points(key, N, mean, cov):
    """Generate N points from a multivariate normal distribution.

    The dimensionality d is taken from the length of `mean`.

    Args:
    key: A PRNGKey used for random number generation. It is consumed by this
        call; split it beforehand if the caller needs more randomness later.
    N: Number of points to generate.
    mean: Mean vector of the normal distribution, length d.
    cov: Covariance matrix of the normal distribution, shape (d, d).

    Returns:
    A jax.numpy array of shape (N, d), where each row is a point sampled from the distribution.
    """
    return jax.random.multivariate_normal(key, mean, cov, shape=(N,))

In [14]:
def plot_3D(points_jax, title="3D Scatter Plot of Multivariate Normal Distribution", color='blue'):
    """Show a single 3-D point cloud as an interactive Plotly scatter plot.

    Args:
    points_jax: Array-like (JAX or NumPy) of shape (n, >=3); only the first
        three columns are plotted.
    title: Figure title.
    color: Marker color.
    """
    # Convert JAX arrays to host NumPy arrays before handing them to Plotly.
    points = np.array(points_jax)

    fig = go.Figure(data=[go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(size=2, color=color, opacity=0.5),
    )])
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='X Axis',
            yaxis_title='Y Axis',
            zaxis_title='Z Axis'
        ),
        autosize=False,
        width=700,
        height=700,
        margin=dict(l=65, r=50, b=65, t=90)
    )
    fig.show()

In [15]:
def plot_3D_two_sources(points_jax1, points_jax2, name1='Set 1', name2='Set 2',
                        title="3D Scatter Plot of Two Sets of Points"):
    """Overlay two 3-D point clouds in one interactive Plotly scatter plot.

    The first set is drawn in blue and the second in red.

    Args:
    points_jax1: Array-like (JAX or NumPy) of shape (n1, >=3); only the first
        three columns are plotted.
    points_jax2: Array-like (JAX or NumPy) of shape (n2, >=3); only the first
        three columns are plotted.
    name1: Legend label for the first set.
    name2: Legend label for the second set.
    title: Figure title.
    """
    fig = go.Figure()

    # Both traces share the same styling except for color and legend label.
    for points_jax, color, name in ((points_jax1, 'blue', name1),
                                    (points_jax2, 'red', name2)):
        # Convert JAX arrays to host NumPy arrays before handing them to Plotly.
        points = np.array(points_jax)
        fig.add_trace(go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            marker=dict(size=2, color=color, opacity=0.5),
            name=name,
        ))

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='X Axis',
            yaxis_title='Y Axis',
            zaxis_title='Z Axis'
        ),
        autosize=False,
        width=700,
        height=700,
        margin=dict(l=65, r=50, b=65, t=90)
    )

    fig.show()

# Computations

In [16]:
# Parameters of the source distribution
N = 10000  # number of samples per point cloud

mean_source = np.array([0.0, 0.0, 0.0])
# Unit variances, with covariance 0.2 between the 2nd and 3rd coordinates.
cov_source = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.2], [0.0, 0.2, 1.0]])

# JAX random key. Split before sampling so the root `key` is never consumed
# directly and can be split again safely by later cells.
key = jax.random.PRNGKey(0)
key, key_source = jax.random.split(key)

# Generate points
points_source = generate_multivariate_normal_points(key_source, N, mean_source, cov_source)

In [17]:
def apply_function(points, scale=2, offset=-30):
    """Map points element-wise through the quadratic ``scale * x**2 + offset``.

    Used here to build the non-linear target cloud from the source samples.

    Args:
    points: NumPy or JAX array of any shape.
    scale: Multiplier applied to the squared coordinates (default 2).
    offset: Constant added after scaling (default -30).

    Returns:
    An array with the same shape as `points`.
    """
    return scale * points**2 + offset
# Target cloud: the source samples pushed through the fixed quadratic map.
points_target = apply_function(points=points_source)

In [18]:
# Source samples (blue) vs. quadratic-map target (red).
plot_3D_two_sources(points_jax1=points_source, points_jax2=points_target)

In [19]:
import warnings

# Entropic OT between the source and target point clouds.
geom = ott.geometry.pointcloud.PointCloud(points_source, points_target)
problem = ott.problems.linear.linear_problem.LinearProblem(geom)
solver = ott.solvers.linear.sinkhorn.Sinkhorn()

out = solver(problem)
# Potentials from an unconverged run give an unreliable transport map, so make it visible.
if not bool(out.converged):
    warnings.warn(
        "Sinkhorn did not converge; the entropic potentials (and the "
        "transport map built from them) may be inaccurate.")

f = out.f
g = out.g

# Fully qualified: the bare name `potentials` is bound twice in the import cell
# (ott.neural.networks and ott.problems.linear), so spell out the intended module.
entropic_potentials = ott.problems.linear.potentials.EntropicPotentials(f, g, problem)

In [20]:
# Draw a fresh sample from the source distribution. Sample with the subkey and
# keep `key` as the carry, so `key` is never both consumed and split later.
key, key_2 = jax.random.split(key, num=2)
X_p = generate_multivariate_normal_points(key_2, N, mean_source, cov_source)
# Map the new samples with the transport induced by the entropic potentials.
transported_samples = entropic_potentials.transport(X_p)

In [21]:
# Compare the true target cloud (blue) with the transported fresh samples (red).
plot_3D_two_sources(points_jax1=points_target, points_jax2=transported_samples)