In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

from util import NSimplex, define_style


define_style()
device = torch.device("cpu")

# Seed every RNG the notebook draws from -- Python's `random`, NumPy's global
# generator and torch (tensor sampling, weight init, DataLoader shuffling) --
# so Restart & Run All reproduces the same datasets, training run and plots.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

Prep
===

Models
---

In [None]:

class SinusoidalEmbedding(nn.Module):
    """Transformer-style sinusoidal embedding of a scalar input.

    ``x`` is multiplied by ``scale`` and then by ``size // 2`` geometrically spaced
    frequencies (from 1 down to 1/10000 when ``size // 2 >= 2``). The sines and
    cosines are concatenated, so the trailing axis has width ``2 * (size // 2)``,
    which is what ``len(self)`` reports.
    """

    def __init__(self, size: int, scale: float = 1.0):
        super().__init__()
        self.size = size
        self.scale = scale

    def forward(self, x: torch.Tensor):
        x = x * self.scale
        half_size = self.size // 2
        # With half_size == 1 the spacing would divide by zero (inf * 0 -> NaN);
        # clamp the denominator so size 2/3 embeddings use the single frequency 1.
        emb = torch.log(torch.Tensor([10000.0]).to(x.device)) / max(half_size - 1, 1)
        emb = torch.exp(-emb * torch.arange(half_size).to(x.device))
        emb = x.unsqueeze(-1) * emb.unsqueeze(0)
        return torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)

    def __len__(self):
        # Width actually produced by forward (sin + cos of size // 2 frequencies);
        # for odd ``size`` this is size - 1, and callers size their layers from it.
        return 2 * (self.size // 2)


class LinearEmbedding(nn.Module):
    """Scalar embedding that rescales ``x`` by ``scale / size`` and adds a width-1 feature axis."""

    def __init__(self, size: int, scale: float = 1.0):
        super().__init__()
        self.size, self.scale = size, scale

    def forward(self, x: torch.Tensor):
        rescaled = x / self.size * self.scale
        return rescaled[..., None]

    def __len__(self):
        # Always a single feature column, independent of ``size``.
        return 1


class LearnableEmbedding(nn.Module):
    """Learned affine embedding of a scalar: ``Linear(1, size)`` applied to ``x / size`` (as float)."""

    def __init__(self, size: int):
        super().__init__()
        self.size = size
        self.linear = nn.Linear(1, size)

    def forward(self, x: torch.Tensor):
        normalized = x[..., None].float() / self.size
        return self.linear(normalized)

    def __len__(self):
        return self.size


class IdentityEmbedding(nn.Module):
    """Pass-through embedding: returns ``x`` with an extra trailing axis of width 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        return x[..., None]

    def __len__(self):
        return 1


class ZeroEmbedding(nn.Module):
    """Embedding that multiplies ``x`` by zero (after adding a width-1 trailing axis).

    Because it is a multiplication, non-finite inputs (NaN/inf) still yield NaN.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        return torch.mul(x[..., None], 0)

    def __len__(self):
        return 1


class PositionalEmbedding(nn.Module):
    """Select one of the embedding modules by name and store it as ``self.layer``.

    ``type`` must be "sinusoidal", "linear", "learnable", "zero" or "identity";
    anything else raises ``ValueError``. Extra ``kwargs`` (e.g. ``scale``) are
    forwarded only to the sinusoidal and linear variants and ignored otherwise.
    """

    def __init__(self, size: int, type: str, **kwargs):
        super().__init__()

        factories = (
            ("sinusoidal", lambda: SinusoidalEmbedding(size, **kwargs)),
            ("linear", lambda: LinearEmbedding(size, **kwargs)),
            ("learnable", lambda: LearnableEmbedding(size)),
            ("zero", lambda: ZeroEmbedding()),
            ("identity", lambda: IdentityEmbedding()),
        )
        for name, make in factories:
            if type == name:
                self.layer = make()
                break
        else:
            raise ValueError(f"Unknown positional embedding type: {type}")

    def forward(self, x: torch.Tensor):
        # Pure delegation to the selected embedding.
        return self.layer(x)


class Block(nn.Module):
    """Residual feed-forward block ``x + GELU(Linear(inp))`` with optional time conditioning.

    ``inp`` is ``x`` alone, or ``[x, t_emb]`` concatenated on the last axis when
    ``concat_t_emb`` is set (the linear layer then takes ``size + t_emb_size``
    inputs). When ``add_t_emb`` is set, ``t_emb`` is also added to the output,
    so its width must then equal ``size``.
    """

    def __init__(
        self, size: int, t_emb_size: int = 0, add_t_emb=False, concat_t_emb=False
    ):
        super().__init__()

        extra_inputs = t_emb_size if concat_t_emb else 0
        self.ff = nn.Linear(size + extra_inputs, size)
        self.act = nn.GELU()

        self.add_t_emb, self.concat_t_emb = add_t_emb, concat_t_emb

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
        if self.concat_t_emb:
            ff_input = torch.cat([x, t_emb], dim=-1)
        else:
            ff_input = x
        result = x + self.act(self.ff(ff_input))
        if not self.add_t_emb:
            return result
        return result + t_emb


class MyMLP(nn.Module):
    """Time-conditioned MLP whose output rows sum to zero.

    Each input coordinate and the time are embedded separately with
    ``PositionalEmbedding`` and concatenated. The result goes through a linear
    layer, ``hidden_layers`` residual ``Block``s and a linear head with ``out_dim``
    outputs. ``forward`` then appends one extra coordinate equal to minus the sum
    of the head's outputs, so it returns ``out_dim + 1`` columns summing to zero.

    Args:
        hidden_size: width of the hidden layers.
        hidden_layers: number of residual ``Block``s.
        emb_size: size passed to every ``PositionalEmbedding``.
        out_dim: number of free output coordinates.
        time_emb: embedding type name for the time input.
        input_emb: embedding type name for each input coordinate (built with scale 25.0).
        add_t_emb: add the time embedding after the first layer and inside every
            block; requires the time-embedding width to equal ``hidden_size``.
        concat_t_emb: concatenate the time embedding to the input of every block
            and of the final head.
        input_dim: number of input coordinates (one embedding module each).
        energy_function: accepted for interface compatibility; unused.
    """

    def __init__(
        self,
        hidden_size: int = 128,
        hidden_layers: int = 3,
        emb_size: int = 128,
        out_dim: int = 2,
        time_emb: str = "sinusoidal",
        input_emb: str = "sinusoidal",
        add_t_emb: bool = False,
        concat_t_emb: bool = False,
        input_dim: int = 2,
        energy_function=None,
    ):
        super().__init__()

        self.add_t_emb = add_t_emb
        self.concat_t_emb = concat_t_emb

        self.time_mlp = PositionalEmbedding(emb_size, time_emb)

        positional_embeddings = []
        for i in range(input_dim):
            embedding = PositionalEmbedding(emb_size, input_emb, scale=25.0)
            # Registered under a fixed name so forward can fetch it by index.
            self.add_module(f"input_mlp{i}", embedding)
            positional_embeddings.append(embedding)

        self.channels = 1
        self.self_condition = False

        # Real width of the time embedding (1 for linear/identity/zero types),
        # used wherever the embedding is concatenated to a hidden state.
        t_emb_size = len(self.time_mlp.layer)
        concat_size = t_emb_size + sum(len(e.layer) for e in positional_embeddings)

        layers = [nn.Linear(concat_size, hidden_size)]
        for _ in range(hidden_layers):
            layers.append(Block(hidden_size, t_emb_size, add_t_emb, concat_t_emb))

        # The head consumes the hidden state (width hidden_size), plus the time
        # embedding when concat_t_emb is set.
        head_in = hidden_size + t_emb_size if concat_t_emb else hidden_size
        layers.append(nn.Linear(head_in, out_dim))

        self.layers = layers
        self.joint_mlp = nn.Sequential(*layers)  # registers the layers' parameters

    def forward(self, x, t, x_self_cond=False):
        """Evaluate the field at points ``x`` (one coordinate per column) and times ``t``.

        ``x_self_cond`` is accepted for interface compatibility and ignored.
        """
        positional_embs = [
            self.get_submodule(f"input_mlp{i}")(x[:, i]) for i in range(x.shape[-1])
        ]

        # reshape(-1) rather than squeeze(): keeps the batch axis for a batch of one.
        t_emb = self.time_mlp(t.reshape(-1))
        h = torch.cat((*positional_embs, t_emb), dim=-1)

        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            if i == 0:
                h = nn.functional.gelu(layer(h))
                if self.add_t_emb:
                    h = h + t_emb
            elif i == last:
                if self.concat_t_emb:
                    h = torch.cat([h, t_emb], dim=-1)
                h = layer(h)
            else:
                h = layer(h, t_emb)

        # Append the coordinate that makes every row sum to zero.
        return torch.cat([h, -h.sum(dim=-1, keepdim=True)], dim=-1)

In [None]:
class MLP(nn.Module):
    """Plain ReLU MLP on ``[x, t, cos(t**1), ..., cos(t**time_fts)]`` with a softmax output.

    Args:
        dim: data dimension (input and output width).
        depth: number of linear layers; must be >= 1.
        hidden: width of the hidden layers.
        batch_norm: insert ``BatchNorm1d`` before each hidden ReLU.
        time_fts: number of extra cosine time features.

    Raises:
        ValueError: if ``depth < 1`` (the net would be empty and return the wrong width).
    """

    def __init__(self, dim: int, depth: int, hidden: int, batch_norm: bool = False, time_fts: int = 0):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        self.time_fts = time_fts
        layers = []
        for i in range(depth):
            is_last = i == depth - 1
            # The first layer sees x, the raw time and the time_fts cosine features.
            in_features = dim + 1 + time_fts if i == 0 else hidden
            out_features = dim if is_last else hidden
            layers.append(nn.Linear(in_features, out_features))
            if not is_last:
                if batch_norm:
                    layers.append(nn.BatchNorm1d(out_features))
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        # Powers start at 1: t ** 0 == 1 would make the first feature the constant cos(1).
        time_features = [torch.cos(t ** k) for k in range(1, self.time_fts + 1)]
        h = torch.cat([x, t, *time_features], dim=-1)
        return torch.softmax(self.net(h), dim=-1)

Dataset
---

In [None]:
def plot_dirichlet_3d(points):
    v_a = torch.Tensor([[0, 1.0]])
    v_b = torch.Tensor([[-0.5, 0]])
    v_c = torch.Tensor([[0.5, 0]])
    points = points[:, 0].unsqueeze(-1) * v_a + points[:, 1].unsqueeze(-1) * v_b + points[:, 2].unsqueeze(-1) * v_c
    plt.scatter(points[:, 0], points[:, 1])
    plt.show()

def plot(points, points_b):

    v_a = torch.Tensor([[0, 1.0]])
    v_b = torch.Tensor([[-0.5, 0]])
    v_c = torch.Tensor([[0.5, 0]])
    points = points[:, 0].unsqueeze(-1) * v_a + points[:, 1].unsqueeze(-1) * v_b + points[:, 2].unsqueeze(-1) * v_c
    plt.scatter(points[:, 0], points[:, 1])


    v_a = torch.Tensor([[0, 1.0]])
    v_b = torch.Tensor([[-0.5, 0]])
    v_c = torch.Tensor([[0.5, 0]])
    points_b = points_b[:, 0].unsqueeze(-1) * v_a + points_b[:, 1].unsqueeze(-1) * v_b + points_b[:, 2].unsqueeze(-1) * v_c
    plt.scatter(points_b[:, 0], points_b[:, 1])
    plt.show()

In [None]:
def generate_simple_dirichlet(points: int, alpha):
    """Draw ``points`` i.i.d. samples from Dirichlet(``alpha``) using NumPy's global RNG.

    Returns:
        float32 tensor of shape ``(points, len(alpha))``; ``points == 0`` gives an
        empty ``(0, len(alpha))`` tensor instead of failing.
    """
    # One vectorized draw instead of a Python loop over samples.
    samples = np.random.dirichlet(alpha, size=points)
    return torch.from_numpy(samples).float()

In [None]:
def generate_dirichlet_mixture(points: int, *alphas):
    """Draw ``points`` samples from an equal-weight mixture of Dirichlet(``alpha``) components.

    Each sample picks one of ``alphas`` uniformly at random and is then drawn from
    that Dirichlet. All randomness comes from NumPy's global RNG, so
    ``np.random.seed`` alone makes the output reproducible.

    Returns:
        float32 tensor of shape ``(points, len(alphas[0]))`` (empty when ``points == 0``).

    Raises:
        ValueError: if no ``alphas`` are given.
    """
    if not alphas:
        raise ValueError("generate_dirichlet_mixture needs at least one alpha vector")
    components = np.random.randint(len(alphas), size=points)
    samples = np.empty((points, len(alphas[0])))
    for k, alpha in enumerate(alphas):
        mask = components == k
        samples[mask] = np.random.dirichlet(alpha, size=int(mask.sum()))
    return torch.from_numpy(samples).float()

In [None]:
# Visual sanity check: 1000 draws from each single Dirichlet (alpha = 50 on one
# coordinate, 1 on the other two), then 1000 draws from their equal-weight mixture.
plot_dirichlet_3d(generate_simple_dirichlet(1000, [50, 1, 1]))
plot_dirichlet_3d(generate_simple_dirichlet(1000, [1, 50, 1]))
plot_dirichlet_3d(generate_simple_dirichlet(1000, [1, 1, 50]))
plot_dirichlet_3d(generate_dirichlet_mixture(1000, [50, 1, 1], [1, 50, 1], [1, 1, 50]))

In [None]:
# Target distribution: mixture of three Dirichlets, each with alpha = 50 on a
# different coordinate. 10k training samples; 1k separately drawn test samples.
raw_dataset = generate_dirichlet_mixture(10000, [50, 1, 1], [1, 50, 1], [1, 1, 50])
dataset = TensorDataset(raw_dataset)
test_dataset = TensorDataset(generate_dirichlet_mixture(1000, [50, 1, 1], [1, 50, 1], [1, 1, 50]))
# Batch size 128; only the training loader is shuffled.
train_loader = DataLoader(dataset, 128, shuffle=True)
test_loader = DataLoader(test_dataset, 128, shuffle=False)

Train
---

In [None]:
def _sample_uniform_simplex(n: int, dim: int) -> torch.Tensor:
    """Return ``n`` float32 samples from the flat Dirichlet(1, ..., 1) over ``dim`` coordinates."""
    return torch.from_numpy(np.random.dirichlet(np.ones(dim), size=n)).float()


def _flow_matching_loss(model: nn.Module, manifold, x_1: torch.Tensor, time_eps: float) -> torch.Tensor:
    """Mean flow-matching loss for one batch of target points ``x_1``.

    Each target is paired with a fresh flat-Dirichlet source ``x_0``. The model is
    evaluated at ``x_t = manifold.geodesic_interpolant(x_0, x_1, t)`` and compared
    (via ``square_norm_at(x_t, .)``) with ``log_map(x_0, x_1)`` parallel-transported
    from ``x_0`` to ``x_t``. Times ``t`` are uniform on ``[time_eps, 1)``.
    """
    x_1 = x_1.to(device)
    times = (torch.rand((x_1.size(0), 1)) * (1.0 - time_eps) + time_eps).to(device)
    x_0 = _sample_uniform_simplex(x_1.size(0), x_1.size(1)).to(device)
    x_t = manifold.geodesic_interpolant(x_0, x_1, times)
    target = manifold.parallel_transport(x_0, x_t, manifold.log_map(x_0, x_1))
    out = model(x_t, times)
    return manifold.square_norm_at(x_t, out - target).mean()


@torch.no_grad()
def _eval_wasserstein(model: nn.Module, manifold, steps: int = 100) -> float:
    """Distance between the test set and samples generated from flat-Dirichlet noise.

    Generation uses ``manifold.tangent_euler`` with ``steps`` steps; the distance is
    ``manifold.wasserstein_dist(test, samples, power=2)``. NOTE(review): it is
    printed as "W1" -- confirm which order ``power=2`` corresponds to in NSimplex.
    """
    model.eval()
    test = torch.cat(test_dataset.tensors).to(device)
    x_0 = _sample_uniform_simplex(test.size(0), test.size(1)).to(device)
    samples = manifold.tangent_euler(x_0, model, steps)
    return float(manifold.wasserstein_dist(test, samples, power=2))


def train(model: nn.Module, epochs: int, lr: float = 1e-3, time_eps: float = 0.0, eval_every: int = 10):
    """Train ``model`` with Riemannian flow matching on the notebook's data loaders.

    Reads the notebook-level ``train_loader``, ``test_loader`` and ``test_dataset``.
    Every ``eval_every`` epochs (starting at epoch 0) and once more after the last
    epoch, the Wasserstein distance between generated samples and the test set is
    printed; ``eval_every=0`` disables it. Train/test loss curves and the distance
    curve are plotted at the end.

    Returns:
        The trained model (moved to ``device``).
    """
    manifold = NSimplex()
    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=lr)
    per_epoch, per_epoch_test, w1s = [], [], []

    def evaluate_distance():
        w1 = _eval_wasserstein(model, manifold)
        w1s.append(w1)
        print(f"W1 distance: {w1:.5f}")

    for epoch in range(epochs):
        if eval_every and epoch % eval_every == 0:
            evaluate_distance()

        model.train()
        losses = []
        for (x_1,) in train_loader:
            optimizer.zero_grad()
            loss = _flow_matching_loss(model, manifold, x_1, time_eps)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        per_epoch.append(np.mean(losses))
        print(f"--- Epoch {epoch+1:03d}/{epochs:03d}: mean loss {per_epoch[-1]:.5f}")

        model.eval()
        with torch.no_grad():
            test_loss = [
                _flow_matching_loss(model, manifold, x_1, time_eps).item()
                for (x_1,) in test_loader
            ]
        per_epoch_test.append(np.mean(test_loss))
        print(f"Test loss {per_epoch_test[-1]:.5f}")

    # The periodic check runs *before* an epoch, so also measure the final model.
    if eval_every:
        evaluate_distance()

    plt.plot(per_epoch, label="Train")
    plt.plot(per_epoch_test, label="Test")
    plt.legend()
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()
    if w1s:
        plt.plot(w1s, marker="o")
        plt.xlabel("Evaluation")
        plt.ylabel("W1 distance")
        plt.show()
    return model

# Train MyMLP (sinusoidal input/time embeddings by default) for 50 epochs. It
# predicts 2 coordinates and appends a 3rd so each output row sums to zero.
# Earlier experiment, disabled: train(MLP(3, 4, 64, False), 50) -- note MLP ends
# in a softmax, so its outputs sum to one rather than zero.
model = train(MyMLP(input_dim=3, out_dim=2, add_t_emb=True), 50)

In [None]:
import seaborn as sns
from matplotlib import tri


@torch.no_grad()
def euler(model, x_0, steps: int = 1000):
    """Explicit Euler integration of ``model`` in flat (Euclidean) space from t = 0.

    ``model(x, t)`` is evaluated at the left end of each step, with ``t`` of shape
    ``(batch, 1)`` equal to ``i * (1 / steps)``; returns the state after ``steps``
    updates of size ``1 / steps``.
    """
    step = 1.0 / steps
    ones = torch.ones((x_0.size(0), 1))
    state = x_0
    for k in range(steps):
        state = state + step * model(state, ones * step * k)
    return state


def get_points(points):
    """Map (N, 3) simplex points to (N, 2) plane coordinates for scatter plots.

    Component 0 weights the vertex (0, 1), component 1 (-0.5, 0) and component 2 (0.5, 0).
    """
    vertices = (
        torch.Tensor([[0, 1.0]]),
        torch.Tensor([[-0.5, 0]]),
        torch.Tensor([[0.5, 0]]),
    )
    xy = points[:, 0].unsqueeze(-1) * vertices[0]
    for k in (1, 2):
        xy = xy + points[:, k].unsqueeze(-1) * vertices[k]
    return xy


@torch.no_grad()
def tangent_euler(model, x_0, steps: int = 100, evol: bool = False, every: int = 100):
    """Integrate ``model`` on the simplex with explicit Euler steps via ``NSimplex.exp_map``.

    Step ``i`` evaluates ``model(x, t)`` at the current time ``t = i / steps``
    (shape ``(batch, 1)``) and moves with ``exp_map(x, model(x, t) * dt)``.

    Args:
        model: callable ``(x, t) -> update direction``.
        x_0: starting points, one per row.
        steps: number of Euler steps over ``[0, 1]``.
        evol: if True, plot a snapshot of the trajectory every ``every`` steps.
        every: snapshot interval in steps (only used when ``evol`` is True).

    Returns:
        The points after all ``steps`` steps.

    Raises:
        ValueError: if ``evol`` is True and ``every < 1``.
    """
    if evol and every < 1:
        raise ValueError(f"every must be >= 1, got {every}")
    dt = 1.0 / steps
    manifold = NSimplex()
    x = x_0
    snapshots = []
    for i in range(steps):
        # Explicit Euler: evaluate the field at the time of the current state.
        t = torch.ones((x.size(0), 1)) * dt * i
        x = manifold.exp_map(x, model(x, t) * dt)
        if evol and (i + 1) % every == 0:
            snapshots.append(x)
    if evol and snapshots:
        ncols = 4
        nrows = (len(snapshots) + ncols - 1) // ncols
        # squeeze=False keeps `axs` 2-D even with a single row of subplots.
        _, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 15), squeeze=False)
        # Distinct loop name so the final state `x` is what gets returned.
        for k, snap in enumerate(snapshots):
            pts = get_points(snap)
            ax = axs[k // ncols, k % ncols]
            ax.scatter(pts[:, 0], pts[:, 1])
            ax.set_title(f"step {(k + 1) * every}/{steps}")
        plt.show()
    return x


class Dirichlet(object):
    """Dirichlet distribution with concentration vector ``alpha`` (density only).

    The normalising constant is computed in log space with ``math.lgamma``;
    ``math.gamma`` raises ``OverflowError`` once ``sum(alpha)`` exceeds about 171,
    e.g. for a sharply concentrated fitted distribution.
    """

    def __init__(self, alpha):
        from math import lgamma
        self._alpha = np.array(alpha)
        # log(1 / B(alpha)) = lgamma(sum(alpha)) - sum_i lgamma(alpha_i)
        self._log_coef = lgamma(np.sum(self._alpha)) - sum(lgamma(a) for a in self._alpha)
        with np.errstate(over="ignore"):
            self._coef = np.exp(self._log_coef)

    def pdf(self, x):
        '''Returns pdf value for `x`.

        ``x`` is an iterable of simplex coordinates paired position-wise with
        ``alpha`` (entries beyond ``len(alpha)`` are ignored, as with ``zip``).
        '''
        log_p = self._log_coef
        with np.errstate(divide="ignore"):
            for xx, aa in zip(x, self._alpha):
                # xx ** 0 == 1 even at xx == 0; skipping avoids 0 * log(0) = NaN.
                if aa != 1:
                    log_p = log_p + (aa - 1) * np.log(xx)
        return np.exp(log_p)


def viz_model(model, n_points: int = 1000):
    """Visualise samples generated by ``model`` on the 2-simplex.

    Integrates ``n_points`` flat-Dirichlet samples with ``tangent_euler``, fits a
    Dirichlet to the results with ``dirichlet.mle`` (printing the fitted alphas),
    draws the fitted density as a filled contour plot, and then scatters the
    generated points against the starting noise and against ``raw_dataset``.
    """
    import dirichlet  # third-party package providing `mle`, imported lazily

    points = torch.from_numpy(np.random.dirichlet([1, 1, 1], size=n_points)).float()
    dest = tangent_euler(model, points)

    # Equilateral triangle for the contour plot. NOTE: its vertex order (component
    # 0 at (0, 0)) differs from the scatter plots (component 0 at the top), so the
    # two figures are not drawn in the same orientation.
    corners = np.array([[0, 0], [1, 0], [0.5, 0.75**0.5]])
    triangle = tri.Triangulation(corners[:, 0], corners[:, 1])
    trimesh = tri.UniformTriRefiner(triangle).refine_triangulation(subdiv=2)
    area = 0.5 * 1 * 0.75**0.5
    # pairs[i] holds the two corners other than corners[i]: one entry per vertex.
    pairs = [corners[np.roll(range(3), -i)[1:]] for i in range(3)]

    def xy2bc(xy, tol=1.e-4):
        '''Converts 2D Cartesian coordinates to barycentric.'''
        def sub_area(pair):
            (ax, ay), (bx, by) = pair - xy
            # Explicit 2-D cross product (np.cross on 2-vectors is deprecated).
            return 0.5 * abs(ax * by - ay * bx)
        coords = np.array([sub_area(p) for p in pairs]) / area
        return np.clip(coords, tol, 1.0 - tol)

    alphas = dirichlet.mle(dest.cpu().numpy(), tol=1e-5)
    print(alphas)
    pdf = Dirichlet(alphas)
    pvals = [pdf.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]

    plt.tricontourf(trimesh, pvals, 100, cmap="jet")
    plt.show()
    plot(points, dest)
    plot(raw_dataset, dest)


# Integrate the trained model from flat-Dirichlet noise and plot the results.
viz_model(model)