In [1]:
import torch
import math
import numpy as np
import gpytorch
from torch.nn import Module
import matplotlib.pyplot as plt

## Test Sampler

In [2]:
class ComplexModel:
    """Pair a simplicial complex's connectivity with its vertex coordinates.

    Attributes:
        simplexes: the ``simplex_dict`` argument, stored by reference (not
            copied); maps a simplex key to the vertex indices it spans.
        vertices: tensor whose columns are vertices, i.e. ``vertices[:, j]``
            holds the coordinates of vertex ``j``.
        n_vert: number of vertices, taken from the last dimension of
            ``vertices``.
    """

    def __init__(self, simplex_dict, vertices):
        vertex_count = vertices.shape[-1]
        self.vertices = vertices
        self.simplexes = simplex_dict
        self.n_vert = vertex_count

In [3]:
def temp_volume(model, ind):
    """Return the volume of simplex ``ind`` of ``model`` via the Cayley-Menger determinant.

    For a simplex with ``n_vert`` vertices (dimension ``k = n_vert - 1``) the
    squared volume is ``|det(CM)| / ((k!)**2 * 2**k)``. ``CM`` is the
    ``(n_vert + 1) x (n_vert + 1)`` matrix of pairwise squared distances,
    bordered by ones, with a zero in the bottom-right corner.

    Args:
        model: object exposing ``simplexes`` (mapping key -> sequence of
            vertex indices) and ``vertices`` (torch tensor whose columns are
            the vertex coordinates).
        ind: key of the simplex in ``model.simplexes``.

    Returns:
        0-d tensor holding the non-negative simplex volume, in the default
        floating dtype. A single-vertex simplex returns 1.
    """
    vert_ids = torch.as_tensor(model.simplexes[ind], dtype=torch.long)
    n_vert = vert_ids.numel()

    # Gather the simplex's vertices as rows: (n_vert, n_par). Cast because the
    # vertices may be an integer tensor and det() needs floating point.
    par_vecs = model.vertices[:, vert_ids].t().to(torch.get_default_dtype())

    # Squared Euclidean distances computed directly (no sqrt-then-square round
    # trip, and smooth at coincident points).
    diff = par_vecs.unsqueeze(1) - par_vecs.unsqueeze(0)  # (n_vert, n_vert, n_par)
    sq_dist = diff.pow(2).sum(-1)

    # Cayley-Menger matrix: ones everywhere except a zero diagonal, with the
    # squared distances in the top-left n_vert x n_vert block.
    cm = torch.ones(n_vert + 1, n_vert + 1) - torch.eye(n_vert + 1)
    cm[:n_vert, :n_vert] = sq_dist

    k = n_vert - 1
    norm = (math.factorial(k) ** 2) * (2. ** k)
    # abs() absorbs the (-1)**(k+1) sign of the determinant as well as tiny
    # negative round-off for (near-)degenerate simplexes.
    return torch.abs(torch.det(cm)).div(norm).pow(0.5)

In [4]:
class SimplicialComplex(Module):
    """Sample convex vertex weights from a simplicial complex.

    Each call to ``forward``:

    1. picks one simplex with probability proportional to its volume (as
       computed by ``temp_volume``); and
    2. draws weights uniformly from that simplex (a flat Dirichlet), leaving
       every other vertex at weight 0.
    """

    def __init__(self, n_simplex):
        """
        Args:
            n_simplex: number of simplexes to sample from; ``forward`` reads
                keys ``0 .. n_simplex - 1`` of ``complex_model.simplexes``.
        """
        super().__init__()
        self.n_simplex = n_simplex

    def forward(self, complex_model):
        """Draw one set of vertex weights.

        Args:
            complex_model: object with ``simplexes`` (mapping int -> sequence of
                vertex indices), ``n_vert`` (total vertex count) and whatever
                ``temp_volume`` reads from it.

        Returns:
            list of length ``complex_model.n_vert``. Entries for the chosen
            simplex's vertices are non-negative floats summing to 1; every
            other entry is 0.

        Raises:
            ValueError: if the simplex volumes do not have a positive, finite sum.
        """
        ## first pick a simplex, with probability proportional to its volume ##
        vols = torch.tensor(
            [float(temp_volume(complex_model, ii)) for ii in range(self.n_simplex)],
            dtype=torch.float64,
        )
        total_vol = vols.sum()
        if not torch.isfinite(total_vol) or total_vol <= 0:
            raise ValueError("simplex volumes must have a positive, finite sum to sample from")
        # multinomial normalises the weights itself, so there is no cumulative
        # sum whose round-off could leave a uniform draw past the last bin.
        simp_ind = int(torch.multinomial(vols, 1))
        simplex = complex_model.simplexes[simp_ind]

        ## sample weights for the simplex ##
        # Normalised i.i.d. Exp(1) draws are uniform on the simplex. torch.rand is
        # in [0, 1), so -log1p(-u) = -log(1 - u) is always finite, whereas
        # -log(u) is inf when u == 0.
        exps = -torch.log1p(-torch.rand(len(simplex)))
        exps = exps / exps.sum()

        ## now assign vertex weights out ##
        vert_weights = [0] * complex_model.n_vert
        for weight, vert in zip(exps.tolist(), simplex):
            vert_weights[vert] = weight

        return vert_weights

In [5]:
# Example 1: a complex made of an edge (vertices 0-1) and a triangle (0-2-3).
# Each column of `vertices` holds the 2-D coordinates of one vertex.
simplexes = {
    0: [0, 1],
    1: [0, 2, 3],
}
vertices = torch.tensor([
    [0, -1, 1, 0],   # x-coordinates
    [0, -1, 0, 2],   # y-coordinates
])

In [6]:
# Bundle connectivity and coordinates. `model.simplexes` is the same dict object
# as `simplexes`, so in-place edits to `simplexes` are visible through `model`.
model = ComplexModel(simplex_dict=simplexes, vertices=vertices)

In [None]:
# The sampler considers simplex keys 0 .. n_simplex - 1, i.e. every simplex here.
sampler = SimplicialComplex(n_simplex=len(model.simplexes))

In [None]:
# Draw n_pts weight vectors and map each to a point: pts[:, i] = vertices @ weights.
n_pts = 100
vert_coords = vertices.float()  # loop-invariant cast, hoisted out of the loop
# Size the buffer from n_pts and the coordinate dimension instead of hard-coding
# (2, 100), so changing n_pts (or the vertex dimension) cannot break indexing.
pts = torch.zeros(vert_coords.shape[0], n_pts)
for ii in range(n_pts):
    pts[:, ii] = vert_coords.matmul(torch.tensor(sampler(model)))

In [None]:
# Overlay the sampled points on the complex's vertices. An equal aspect ratio
# keeps areas undistorted, so uniformity over the simplexes can be judged by eye.
fig, ax = plt.subplots()
ax.scatter(vertices[0, :], vertices[1, :], label="vertices")
ax.scatter(pts[0, :], pts[1, :], alpha=0.5, label="samples")
ax.set(xlabel="x", ylabel="y", title="Samples from the simplicial complex")
ax.set_aspect("equal")
ax.legend();

In [None]:
# Example 2: a single triangle with vertices (0, 0), (1, 0) and (0, 2).
# Each column of `vertices` holds the 2-D coordinates of one vertex.
simplexes = {
    0: [0, 1, 2],
}
vertices = torch.tensor([
    [0, 1, 0],   # x-coordinates
    [0, 0, 2],   # y-coordinates
])

In [None]:
# Bundle connectivity and coordinates. `model.simplexes` is the same dict object
# as `simplexes`, so in-place edits to `simplexes` are visible through `model`.
model = ComplexModel(simplex_dict=simplexes, vertices=vertices)

In [None]:
# The sampler considers simplex keys 0 .. n_simplex - 1, i.e. every simplex here.
sampler = SimplicialComplex(n_simplex=len(model.simplexes))

In [None]:
# Draw n_pts weight vectors and map each to a point: pts[:, i] = vertices @ weights.
n_pts = 100
vert_coords = vertices.float()  # loop-invariant cast, hoisted out of the loop
# Size the buffer from n_pts and the coordinate dimension instead of hard-coding
# (2, 100), so changing n_pts (or the vertex dimension) cannot break indexing.
pts = torch.zeros(vert_coords.shape[0], n_pts)
for ii in range(n_pts):
    pts[:, ii] = vert_coords.matmul(torch.tensor(sampler(model)))

In [None]:
# Overlay the sampled points on the triangle's vertices. An equal aspect ratio
# keeps areas undistorted, so uniformity over the triangle can be judged by eye.
fig, ax = plt.subplots()
ax.scatter(vertices[0, :], vertices[1, :], label="vertices")
ax.scatter(pts[0, :], pts[1, :], alpha=0.5, label="samples")
ax.set(xlabel="x", ylabel="y", title="Samples from a single triangle")
ax.set_aspect("equal")
ax.legend();

In [None]:
# Add vertex NEW_VERT to every simplex.
# - Iterate over the keys that actually exist: the previous hard-coded [0, 1]
#   raised KeyError here, because `simplexes` currently only has key 0.
# - Build a new dict instead of appending in place. `model.simplexes` is the same
#   dict object (ComplexModel stores it by reference) and is therefore left untouched.
# - Skip simplexes that already contain the vertex, so re-running the cell is harmless.
# NOTE(review): `vertices` has no column 4 in either example above, so a model
# built from these simplexes would need an extra vertex column; confirm intent.
NEW_VERT = 4
simplexes = {
    key: verts if NEW_VERT in verts else list(verts) + [NEW_VERT]
    for key, verts in simplexes.items()
}

In [None]:
# Display the simplex dictionary as it stands after the previous cell.
simplexes

In [None]:
# Number of simplexes in the dictionary; the sampler cells construct
# SimplicialComplex with this kind of count (len(model.simplexes)).
len(simplexes)