From 0c00b48e43f31be11c4b0bfee1bc3ec185406170 Mon Sep 17 00:00:00 2001
From: Austin Ubuntu Windows
Date: Fri, 10 Nov 2023 13:33:46 -0500
Subject: [PATCH] Initial adapted implementation [WIP]

---
 .../bystro/ancestry/spherical_autoencoder.py | 458 ++++++++++++++++++
 1 file changed, 458 insertions(+)
 create mode 100644 python/python/bystro/ancestry/spherical_autoencoder.py

diff --git a/python/python/bystro/ancestry/spherical_autoencoder.py b/python/python/bystro/ancestry/spherical_autoencoder.py
new file mode 100644
index 000000000..aab041fa2
--- /dev/null
+++ b/python/python/bystro/ancestry/spherical_autoencoder.py
@@ -0,0 +1,458 @@
"""
Spherical (hyperspherical) variational autoencoder components. The latent
code is constrained to the unit hypersphere: the encoder outputs a mean
direction and a concentration, the posterior is a von Mises-Fisher
distribution, and the prior is the uniform distribution on the hypersphere.

Objects
-------
SphericalVAE
    High-level model wrapper (training loop still WIP).
SphericalEncoder, SphericalDecoder
    Encoder/decoder networks (layer construction still WIP).
HypersphericalUniform
    Uniform prior distribution on the unit hypersphere.
VonMisesFisher
    Reparameterizable von Mises-Fisher posterior distribution.
IveFunction, Ive
    Autograd-aware exponentially scaled modified Bessel function of the
    first kind, ive(v, z).

Methods
-------
ive_fraction_approx(v, z), ive_fraction_approx2(v, z, eps)
    Approximations to the Bessel-function ratio used by VonMisesFisher.
"""
import math
from numbers import Number

import numpy as np
import scipy.special
from numpy.typing import NDArray

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.kl import register_kl


class SphericalVAE:
    """
    Variational autoencoder with a hyperspherical latent space [WIP].
    """

    def __init__(self, n_components, model_options=None, training_options=None):
        super().__init__()

        self.n_components = n_components
        self.model_options = model_options
        self.training_options = training_options

    def fit(self, X):
        """
        Fit the model to data X [WIP: training loop not yet implemented].
        """
        # TODO: collect the trainable parameters from the encoder/decoder and
        # run the optimization loop, e.g.
        #   trainable_variables = [...]
        #   optimizer = optim.Adam(trainable_variables)
        raise NotImplementedError


class SphericalEncoder(nn.Module):
    """
    Encoder network producing a von Mises-Fisher mean direction and
    concentration [WIP].
    """

    def __init__(self, encoder_options):
        super().__init__()

        # TODO: build the layer stack from encoder_options and set
        # self.z_dim (latent dimensionality), which reparameterize() uses.
        self.layers = nn.Sequential()

    def reparameterize(self, z_mean, z_var):
        """
        Construct the posterior q(z|x), a von Mises-Fisher distribution with
        mean direction z_mean and concentration z_var, and the hyperspherical
        uniform prior p(z).
        """
        q_z = VonMisesFisher(z_mean, z_var)
        p_z = HypersphericalUniform(self.z_dim - 1)
        return q_z, p_z

    def forward(self, x):
        z_mean, z_var = self.layers(x)
        q_z, p_z = self.reparameterize(z_mean, z_var)
        # TODO: return q_z and p_z as well once the VAE wiring is complete.
        return (z_mean,)


class SphericalDecoder(nn.Module):
    """
    Decoder network mapping latent codes back to data space [WIP].
    """

    def __init__(self, decoder_options):
        super().__init__()

        # TODO: build the layer stack from decoder_options
        self.layers = nn.Sequential()

    def forward(self, z):
        x = self.layers(z)
        return x


class HypersphericalUniform(torch.distributions.Distribution):
    """Uniform distribution on the dim-dimensional unit hypersphere."""

    support = torch.distributions.constraints.real
    has_rsample = False
    _mean_carrier_measure = 0

    @property
    def dim(self):
        return self._dim

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, val):
        self._device = val if isinstance(val, torch.device) else torch.device(val)

    def __init__(self, dim, validate_args=None, device="cpu"):
        super(HypersphericalUniform, self).__init__(
            torch.Size([dim]), validate_args=validate_args
        )
        self._dim = dim
        self.device = device

    def sample(self, shape=torch.Size()):
        # Sample standard normals and project them onto the unit sphere.
        output = (
            torch.distributions.Normal(0, 1)
            .sample(
                (shape if isinstance(shape, torch.Size) else torch.Size([shape]))
                + torch.Size([self._dim + 1])
            )
            .to(self.device)
        )

        return output / output.norm(dim=-1, keepdim=True)

    def entropy(self):
        return self.__log_surface_area()

    def log_prob(self, x):
        return -torch.ones(x.shape[:-1], device=self.device) * self.__log_surface_area()

    def __log_surface_area(self):
        if torch.__version__ >= "1.0.0":
            lgamma = torch.lgamma(torch.tensor([(self._dim + 1) / 2]).to(self.device))
        else:
            lgamma = torch.lgamma(
                torch.Tensor([(self._dim + 1) / 2], device=self.device)
            )
        return math.log(2) + ((self._dim + 1) / 2) * math.log(math.pi) - lgamma


class VonMisesFisher(torch.distributions.Distribution):
    """Reparameterizable von Mises-Fisher distribution on the unit hypersphere."""

    arg_constraints = {
        "loc": torch.distributions.constraints.real,
        "scale": torch.distributions.constraints.positive,
    }
    support = torch.distributions.constraints.real
    has_rsample = True
    _mean_carrier_measure = 0

    @property
    def mean(self):
        # option 1:
        return self.loc * (
            ive(self.__m / 2, self.scale) / ive(self.__m / 2 - 1, self.scale)
        )
        # option 2:
        # return self.loc * ive_fraction_approx(torch.tensor(self.__m / 2), self.scale)
        # option 3:
        # return self.loc * ive_fraction_approx2(torch.tensor(self.__m / 2), self.scale)

    @property
    def stddev(self):
        return self.scale

    def __init__(self, loc, scale, validate_args=None, k=1):
        self.dtype = loc.dtype
        self.loc = loc
        self.scale = scale
        self.device = loc.device
        self.__m = loc.shape[-1]
        self.__e1 = (torch.Tensor([1.0] + [0] * (loc.shape[-1] - 1))).to(self.device)
        self.k = k

        super().__init__(self.loc.size(), validate_args=validate_args)

    def sample(self, shape=torch.Size()):
        with torch.no_grad():
            return self.rsample(shape)

    def rsample(self, shape=torch.Size()):
        shape = shape if isinstance(shape, torch.Size) else torch.Size([shape])

        w = (
            self.__sample_w3(shape=shape)
            if self.__m == 3
            else self.__sample_w_rej(shape=shape)
        )

        v = (
            torch.distributions.Normal(0, 1)
            .sample(shape + torch.Size(self.loc.shape))
            .to(self.device)
            .transpose(0, -1)[1:]
        ).transpose(0, -1)
        v = v / v.norm(dim=-1, keepdim=True)

        w_ = torch.sqrt(torch.clamp(1 - (w ** 2), 1e-10))
        x = torch.cat((w, w_ * v), -1)
        z = self.__householder_rotation(x)

        return z.type(self.dtype)

    def __sample_w3(self, shape):
        shape = shape + torch.Size(self.scale.shape)
        u = torch.distributions.Uniform(0, 1).sample(shape).to(self.device)
        self.__w = (
            1
            + torch.stack(
                [torch.log(u), torch.log(1 - u) - 2 * self.scale], dim=0
            ).logsumexp(0)
            / self.scale
        )
        return self.__w

    def __sample_w_rej(self, shape):
        c = torch.sqrt((4 * (self.scale ** 2)) + (self.__m - 1) ** 2)
        b_true = (-2 * self.scale + c) / (self.__m - 1)

        # use a Taylor approximation, smoothly shifted in over 10 < scale < 11,
        # to avoid numerical errors for large scale
        b_app = (self.__m - 1) / (4 * self.scale)
        s = torch.min(
            torch.max(
                torch.tensor([0.0], dtype=self.dtype, device=self.device),
                self.scale - 10,
            ),
            torch.tensor([1.0], dtype=self.dtype, device=self.device),
        )
        b = b_app * s + b_true * (1 - s)

        a = (self.__m - 1 + 2 * self.scale + c) / 4
        d = (4 * a * b) / (1 + b) - (self.__m - 1) * math.log(self.__m - 1)

        self.__b, (self.__e, self.__w) = b, self.__while_loop(b, a, d, shape, k=self.k)
        return self.__w

    def __while_loop(self, b, a, d, shape, k=20, eps=1e-20):
        # matrix while loop: draws a [batch, k] matrix of proposals at once
        # instead of looping over samples individually
        b, a, d = [
            e.repeat(*shape, *([1] * len(self.scale.shape))).reshape(-1, 1)
            for e in (b, a, d)
        ]
        w, e, bool_mask = (
            torch.zeros_like(b).to(self.device),
            torch.zeros_like(b).to(self.device),
            (torch.ones_like(b) == 1).to(self.device),
        )

        sample_shape = torch.Size([b.shape[0], k])
        shape = shape + torch.Size(self.scale.shape)

        while bool_mask.sum() != 0:
            con1 = torch.tensor((self.__m - 1) / 2, dtype=torch.float64)
            con2 = torch.tensor((self.__m - 1) / 2, dtype=torch.float64)
            e_ = (
                torch.distributions.Beta(con1, con2)
                .sample(sample_shape)
                .to(self.device)
                .type(self.dtype)
            )

            u = (
                torch.distributions.Uniform(0 + eps, 1 - eps)
                .sample(sample_shape)
                .to(self.device)
                .type(self.dtype)
            )

            w_ = (1 - (1 + b) * e_) / (1 - (1 - b) * e_)
            t = (2 * a * b) / (1 - (1 - b) * e_)

            accept = ((self.__m - 1.0) * t.log() - t + d) > torch.log(u)
            accept_idx = self.first_nonzero(accept, dim=-1, invalid_val=-1).unsqueeze(1)
            accept_idx_clamped = accept_idx.clamp(0)
            # clamp to 0 so the gather() calls below never see a -1 index;
            # rows with no accepted proposal keep accept_idx == -1 and are
            # handled by the reject mask afterwards
            w_ = w_.gather(1, accept_idx_clamped.view(-1, 1))
            e_ = e_.gather(1, accept_idx_clamped.view(-1, 1))

            reject = accept_idx < 0
            accept = ~reject if torch.__version__ >= "1.2.0" else 1 - reject

            w[bool_mask * accept] = w_[bool_mask * accept]
            e[bool_mask * accept] = e_[bool_mask * accept]

            bool_mask[bool_mask * accept] = reject[bool_mask * accept]

        return e.reshape(shape), w.reshape(shape)

    @staticmethod
    def first_nonzero(x, dim, invalid_val=-1):
        # Helper required by __while_loop above but missing from this WIP
        # patch; this version follows the standard hyperspherical-VAE
        # implementation: index of the first accepted proposal along `dim`,
        # or invalid_val for rows with no accepted proposal.
        mask = x > 0
        idx = torch.where(
            mask.any(dim=dim),
            mask.float().argmax(dim=dim).squeeze(),
            torch.tensor(invalid_val, device=x.device),
        )
        return idx

    def __householder_rotation(self, x):
        u = self.__e1 - self.loc
        u = u / (u.norm(dim=-1, keepdim=True) + 1e-5)
        z = x - 2 * (x * u).sum(-1, keepdim=True) * u
        return z

    def entropy(self):
        # option 1:
        output = (
            -self.scale
            * ive(self.__m / 2, self.scale)
            / ive((self.__m / 2) - 1, self.scale)
        )
        # option 2:
        # output = - self.scale * ive_fraction_approx(torch.tensor(self.__m / 2), self.scale)
        # option 3:
        # output = - self.scale * ive_fraction_approx2(torch.tensor(self.__m / 2), self.scale)

        return output.view(*(output.shape[:-1])) + self._log_normalization()

    def log_prob(self, x):
        return self._log_unnormalized_prob(x) - self._log_normalization()

    def _log_unnormalized_prob(self, x):
        output = self.scale * (self.loc * x).sum(-1, keepdim=True)

        return output.view(*(output.shape[:-1]))

    def _log_normalization(self):
        output = -(
            (self.__m / 2 - 1) * torch.log(self.scale)
            - (self.__m / 2) * math.log(2 * math.pi)
            - (self.scale + torch.log(ive(self.__m / 2 - 1, self.scale)))
        )

        return output.view(*(output.shape[:-1]))


@register_kl(VonMisesFisher, HypersphericalUniform)
def _kl_vmf_uniform(vmf, hyu):
    return -vmf.entropy() + hyu.entropy()


class IveFunction(torch.autograd.Function):
    """
    Custom autograd Function computing the exponentially scaled modified
    Bessel function of the first kind, ive(v, z) = I_v(z) * exp(-|z|), where
    the order v is a scalar and the argument z is a tensor. The forward pass
    evaluates the function with scipy, and the backward pass supplies the
    analytic derivative with respect to z so the function can be used inside
    autograd.
    """

    @staticmethod
    def forward(ctx, v, z):
        """
        Evaluate ive(v, z). The following steps are performed:

        - An assertion checks that v is a scalar (an instance of
          numbers.Number), not a tensor or an array.
        - The z tensor is saved on the autograd context for the backward
          pass, and v is stored as ctx.v.
        - The data of z is moved to the CPU and converted to a NumPy array
          so that scipy's special functions can be applied to it.
        - If v is close to 0 the specialized i0e is used; if close to 1,
          i1e; otherwise the general ive(v, z).
        - The result is converted back to a torch tensor on z's device.

        Parameters
        ----------
        v : Number
            Order of the Bessel function (scalar).
        z : torch.Tensor
            Argument tensor.

        Returns
        -------
        torch.Tensor
            ive(v, z) evaluated elementwise on z.
        """
        assert isinstance(v, Number), "v must be a scalar"

        ctx.save_for_backward(z)
        ctx.v = v
        z_cpu = z.data.cpu().numpy()

        if np.isclose(v, 0):
            output = scipy.special.i0e(z_cpu, dtype=z_cpu.dtype)
        elif np.isclose(v, 1):
            output = scipy.special.i1e(z_cpu, dtype=z_cpu.dtype)
        else:
            output = scipy.special.ive(v, z_cpu, dtype=z_cpu.dtype)

        return torch.Tensor(output).to(z.device)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute the gradient of the forward operation with respect to its
        inputs, given the gradient of the loss with respect to the forward
        output (grad_output). The result is a tuple of two elements:

        - The first element is None, indicating that the gradient with
          respect to v is not computed or required.
        - The second element is the gradient with respect to z, computed
          with the identity d/dz ive(v, z) = ive(v - 1, z) - ive(v, z) * (v + z) / z,
          which is exactly what the expression below evaluates.

        Parameters
        ----------
        grad_output : torch.Tensor
            Gradient of the loss with respect to the output of forward.

        Returns
        -------
        tuple
            (None, gradient of the loss with respect to z).
        """
        z = ctx.saved_tensors[-1]
        return (
            None,
            grad_output * (ive(ctx.v - 1, z) - ive(ctx.v, z) * (ctx.v + z) / z),
        )


class Ive(torch.nn.Module):
    def __init__(self, v):
        super(Ive, self).__init__()
        self.v = v

    def forward(self, z):
        return ive(self.v, z)


# Module-level alias used by VonMisesFisher and Ive above.
ive = IveFunction.apply


def ive_fraction_approx(v, z):
    """
    Approximate the Bessel ratio ive(v, z) / ive(v - 1, z) with a closed-form
    lower bound, avoiding explicit Bessel evaluations (used as "option 2" in
    VonMisesFisher.mean and VonMisesFisher.entropy).

    Parameters
    ----------
    v : torch.Tensor
        Order of the Bessel functions.
    z : torch.Tensor
        Argument (concentration) tensor.

    Returns
    -------
    torch.Tensor
        Approximation to ive(v, z) / ive(v - 1, z).
    """
    # Lower bound: I_v(z) / I_(v - 1)(z) >= z / (v - 1 + ((v + 1)^2 + z^2)^0.5)
    return z / (v - 1 + torch.pow(torch.pow(v + 1, 2) + torch.pow(z, 2), 0.5))


def ive_fraction_approx2(v, z, eps=1e-20):
    """
    Approximate the Bessel ratio ive(v, z) / ive(v - 1, z) by averaging the
    two bounds B_0 and B_2 built from the perturbed orders delta_a(0) and
    delta_a(2) (used as "option 3" in VonMisesFisher.mean and
    VonMisesFisher.entropy).

    Parameters
    ----------
    v : torch.Tensor
        Order of the Bessel functions.
    z : torch.Tensor
        Argument (concentration) tensor.
    eps : float
        Numerical floor applied inside the square roots.

    Returns
    -------
    torch.Tensor
        Approximation to ive(v, z) / ive(v - 1, z).
    """

    def delta_a(a):
        lamb = v + (a - 1.0) / 2.0
        return (v - 0.5) + lamb / (
            2 * torch.sqrt((torch.pow(lamb, 2) + torch.pow(z, 2)).clamp(eps))
        )

    delta_0 = delta_a(0.0)
    delta_2 = delta_a(2.0)
    B_0 = z / (
        delta_0 + torch.sqrt((torch.pow(delta_0, 2) + torch.pow(z, 2))).clamp(eps)
    )
    B_2 = z / (
        delta_2 + torch.sqrt((torch.pow(delta_2, 2) + torch.pow(z, 2))).clamp(eps)
    )

    return (B_0 + B_2) / 2.0
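

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the batch size, latent dimension,
# and concentration range below are arbitrary assumptions, not part of the
# adapted Bystro API). It draws reparameterized samples from a von
# Mises-Fisher posterior and evaluates the KL term registered above against
# the hyperspherical uniform prior.
if __name__ == "__main__":
    loc = nn.functional.normalize(torch.randn(8, 3), dim=-1)  # unit-norm mean directions
    scale = torch.rand(8, 1) * 10 + 1  # positive concentrations
    q_z = VonMisesFisher(loc, scale)
    p_z = HypersphericalUniform(dim=2)  # S^2 is embedded in R^3
    z = q_z.rsample()  # differentiable samples on the unit sphere
    kl = torch.distributions.kl_divergence(q_z, p_z)
    print(z.shape, kl.shape)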