In [4]:
import math

import numpy as np
import ternary
import torch
from matplotlib import animation
from matplotlib.animation import FuncAnimation
from scipy.special import expit as logistic
from scipy.special import logit
from scipy.special import softmax
from scipy.stats import multivariate_normal

In [None]:
# Linear noise-rate schedule beta(s) = min_b + s * (max_b - min_b) for s in [0, 1].
min_b = 0.1   # beta(0)
max_b = 20.0  # beta(1)


def b(t):
    """Return the integrated noise rate B(t) = int_0^t beta(s) ds.

    For the linear schedule this is ``min_b * t + (max_b - min_b) * t**2 / 2``.
    ``t`` may be a float or a NumPy array (every operation is elementwise).
    ``min_b`` and ``max_b`` are read from module scope at call time.
    """
    growth = max_b - min_b
    return min_b * t + 0.5 * t ** 2 * growth

# Forward-process marginal at t = 0.9 starting from the one-hot x_0 = e_0 in R^3:
#   x_t ~ N(exp(-b(t)/2) * x_0, (1 - exp(-b(t))) * I)
# NOTE(review): plot_simplex is defined in the *next* cell, so on a fresh kernel
# this cell raises NameError; run the definition cell first (or move this below it).
x_0 = np.eye(3)[0]
t = 0.9
mu = x_0 * np.exp(-0.5 * b(t))
var = 1 - np.exp(-b(t))
plot_simplex(mu, var)

In [None]:
def plot_simplex(mu, var, scale=30):
    """Draw a heatmap of a logit-space Gaussian over the ternary simplex.

    For every simplex point ``p`` supplied by ``tax.heatmapf``, the elementwise
    ``logit(p)`` is scored under ``N(mu, var * I)``.

    Args:
        mu: Mean of the Gaussian in logit space (assumed length 3, one entry per
            simplex coordinate — confirm with callers).
        var: Scalar variance; ``multivariate_normal`` treats a scalar covariance
            as ``var * I``.
        scale: Ternary figure scale, i.e. how finely the heatmap is subdivided.
            Defaults to 30.

    NOTE(review): this is the Gaussian pdf evaluated at ``logit(p)`` without the
    Jacobian of the logit map, so it shows the shape of the density rather than a
    normalized density on the simplex.
    """
    mvn = multivariate_normal(mu, var)
    eps = 1e-6  # keeps logit finite if a coordinate is exactly 0 or 1

    def logit_normal_pdf(p):
        # Map the simplex point into Gaussian (logit) space. expit would go the
        # wrong way: it squeezes [0, 1] into [0.5, 0.73] and flattens the heatmap.
        x = logit(np.clip(np.asarray(p, dtype=float), eps, 1 - eps))
        return mvn.pdf(x)

    _fig, tax = ternary.figure(scale=scale)
    tax.heatmapf(logit_normal_pdf, boundary=False, style="triangular")
    tax.boundary(linewidth=2.0)
    tax.set_title("Logits-normal")
    tax.show()

In [None]:
# plot_simplex (as defined in the preceding cell) takes a Gaussian mean and
# variance, not a time, so build them for the end of the schedule, t = 1.0.
plot_simplex(np.exp(-1/2*b(1.0)) * np.eye(3)[0], 1 - np.exp(-b(1.0)))

In [68]:
# Residue-type vocabulary: the 20 standard amino-acid one-letter codes, in the
# fixed order that defines their integer indices 0..19.
restypes = list('ARNDCQEGHILKMFPSTWYV')

In [69]:
# Sanity check: the vocabulary must hold exactly 20 residue types. The assert makes
# Restart & Run All fail loudly if it changes; the last line still displays the count.
assert len(restypes) == 20, f"expected 20 residue types, got {len(restypes)}"
len(restypes)

20

In [70]:
# Inspect the residue type stored at index 19.
restypes[19]

'V'

In [7]:
# Linear noise-rate schedule beta(s) = min_b + s * (max_b - min_b) for s in [0, 1].
min_b = 0.1   # beta(0)
max_b = 15.0  # beta(1)


def b(t):
    """Return the integrated noise rate B(t) = int_0^t beta(s) ds.

    For the linear schedule this is ``min_b * t + (max_b - min_b) * t**2 / 2``.
    ``t`` may be a float or a NumPy array (every operation is elementwise).
    ``min_b`` and ``max_b`` are read from module scope at call time.
    """
    growth = max_b - min_b
    return min_b * t + 0.5 * t ** 2 * growth

# Animation time grid: num_t evenly spaced times in [1e-3, 1]. t = 0 is presumably
# skipped because b(0) = 0 makes the variance 1 - exp(-b(0)) exactly 0.
num_t = 100
ts = np.linspace(1e-3, 1.0, num=num_t)

# Shared ternary canvas that every animation frame draws into.
scale = 30
fig, tax = ternary.figure(scale=scale)

def plot_simplex(i):
    """Draw animation frame ``i``: the forward-process density at time ``ts[i]``.

    Uses the module-level ``ts`` (time grid), ``b`` (integrated schedule) and
    ``tax`` (shared ternary axes). Starting from the one-hot ``x_0 = e_0`` in R^3,
    the noised state is ``N(exp(-b(t)/2) * x_0, (1 - exp(-b(t))) * I)``; each
    simplex point ``p`` is coloured by that Gaussian's pdf at ``logit(p)``.
    """
    t = ts[i]
    x_0 = np.eye(3)[0]
    mu = np.exp(-1/2*b(t)) * x_0
    var = 1 - np.exp(-b(t))
    mvn = multivariate_normal(mu, var)
    # With boundary=True, heatmapf presumably evaluates points with coordinates of
    # exactly 0 or 1, where logit is infinite and the pdf can come out NaN; clip so
    # it stays finite.
    eps = 1e-6

    def logit_normal_pdf(p):
        # Map the simplex point into Gaussian (logit) space; expit would go the
        # wrong way (it squeezes [0, 1] into [0.5, 0.73]).
        x = logit(np.clip(np.asarray(p, dtype=float), eps, 1 - eps))
        return mvn.pdf(x)

    tax.heatmapf(
        logit_normal_pdf,
        boundary=True,
        style="triangular",
        colorbar=False)
    tax.boundary(linewidth=2.0)
    tax.set_title(f'Simplex diffusion t={t:.4f}')

def update(frame):
    """FuncAnimation callback: redraw the heatmap for animation index ``frame``.

    NOTE(review): confirm what ``tax.close()`` does here; if it does not remove the
    previous frame's heatmap, artists accumulate from frame to frame.
    """
    tax.close()
    plot_simplex(frame)

# Draw frame 0 eagerly, then let FuncAnimation redraw frames 1 .. num_t - 1.
# NOTE(review): with no init_func, FuncAnimation initializes from the first entry of
# `frames` (index 1), so ts[0] is probably not part of the saved GIF — confirm.
plot_simplex(0)
anim = FuncAnimation(
    fig,
    update,
    frames=list(range(1, num_t)),
    interval=10,  # ms between frames for live display; the GIF rate comes from fps
    blit=False,
)

# Encode with Pillow; the file name records the schedule endpoints.
save_path = f'simplex_diffusion_linear_minb_{min_b}_maxb_{max_b}.gif'
writergif = animation.PillowWriter(fps=30)
anim.save(save_path, writer=writergif)

### Simplex diffuser class (prototype: sample the forward process in $\mathbb{R}^{20}$ and map it onto the simplex)

In [56]:
# Initial state for the 20-dimensional forward process: the one-hot vector e_0.
x_0 = np.eye(20)[0]

In [59]:
# Display the initial state x_0.
x_0

array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.])

In [65]:
# Linear noise-rate schedule beta(s) = min_b + s * (max_b - min_b) for s in [0, 1].
min_b = 0.1   # beta(0)
max_b = 20.0  # beta(1)


def b(t):
    """Return the integrated noise rate B(t) = int_0^t beta(s) ds.

    For the linear schedule this is ``min_b * t + (max_b - min_b) * t**2 / 2``.
    ``t`` may be a float or a NumPy array (every operation is elementwise).
    ``min_b`` and ``max_b`` are read from module scope at call time.
    """
    growth = max_b - min_b
    return min_b * t + 0.5 * t ** 2 * growth

# Sample the forward process at the end of the schedule (t = 1):
#   x_t ~ N(exp(-b(t)/2) * x_0, (1 - exp(-b(t))) * I)
t = 1.0
rng = np.random.default_rng(0)  # seeded so Restart & Run All reproduces x_t
x_t = rng.normal(
    loc=np.exp(-1/2*b(t)) * x_0,
    scale=np.sqrt(1 - np.exp(-b(t))),
)

# Additive logistic transform R^d -> (d+1)-component simplex: append a reference
# logit of 0 and take softmax, i.e. [exp(x_t), 1] / (1 + sum(exp(x_t))). softmax
# shifts by the max before exponentiating, so large |x_t| cannot overflow np.exp.
exp_x_t = softmax(np.append(x_t, 0.0))
exp_x_t

array([0.02066836, 0.01662281, 0.01078764, 0.05670665, 0.14920168,
       0.01738793, 0.07572645, 0.05901159, 0.01951077, 0.08604732,
       0.01008881, 0.04045653, 0.02898527, 0.01607936, 0.18563185,
       0.04814904, 0.06741189, 0.02365585, 0.01553152, 0.02339176,
       0.02894691])

In [66]:
# Recompute the simplex point from x_t with the additive logistic transform:
# softmax over [x_t, 0] equals [exp(x_t), 1] / (1 + sum(exp(x_t))), but is
# evaluated with a max shift so large |x_t| cannot overflow.
exp_x_t = softmax(np.append(x_t, 0.0))

In [22]:
# Display the simplex point derived from x_t.
# NOTE(review): the stored output below has 4 entries, but the cells above produce
# 21; it predates the 20-dimensional x_0 — re-run to refresh.
exp_x_t

array([0.50675705, 0.18315034, 0.14362313, 0.16646947])

In [54]:
# One-hot encode class indices 1, 4 and 2 over 21 classes: one row per index.
np.eye(21)[[1, 4, 2]].shape

(3, 21)