## Here, I will try to apply Evan's KL analysis and convert it to a more friendly framework

## Helper functions to include

In [2]:
import logging
import math
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import seaborn as sns
import torch
import torch.distributions.transforms as transforms
from pyro.nn import PyroModule
from tqdm import tqdm

In [3]:
## Transformations of the SFS
####################################################
def multinomial_trans(sfs_probs, offset=None):
    """Convert SFS probabilities to multinomial-logit coefficients.

    Each entry after the first along the last axis is expressed as a log-ratio
    against the first entry of that axis: ``log(p_k) - log(p_0)`` for ``k >= 1``.

    Args:
        sfs_probs: array-like of probabilities; the last axis indexes the bins
            and its first entry is used as the reference category.
        offset: optional scalar or array (broadcastable against the result)
            that is subtracted from every coefficient. ``None`` means no offset.

    Returns:
        numpy array with the same leading shape as ``sfs_probs`` and one fewer
        entry along the last axis.
    """
    sfs_probs = np.asarray(sfs_probs)
    # Slice with ``:1`` so the reference column keeps its axis and broadcasts.
    betas = np.log(sfs_probs[..., 1:]) - np.log(sfs_probs[..., :1])
    # Compare against None explicitly: ``if offset:`` raises for multi-element
    # arrays (ambiguous truth value) and would skip a legitimate offset of 0.
    if offset is not None:
        betas = betas - offset
    return betas

def multinomial_trans_torch(sfs_probs):
    """Torch version of the multinomial-logit transform.

    Returns ``log(p_k) - log(p_0)`` for every ``k >= 1`` along the last axis
    of ``sfs_probs``, so the result has one fewer entry on that axis.
    """
    log_probs = torch.log(sfs_probs)
    reference = log_probs[..., :1]  # log of the first bin, axis kept for broadcasting
    return log_probs[..., 1:] - reference

def KL_fw(neut_probs, sel_probs):
    """Forward KL divergence KL(neut || sel), summed over the last axis."""
    log_ratio = np.log(neut_probs) - np.log(sel_probs)
    weighted = neut_probs * log_ratio
    return np.sum(weighted, axis=-1)

def KL_rv(neut_probs, sel_probs):
    """Reverse KL divergence KL(sel || neut), summed over the last axis."""
    log_ratio = np.log(sel_probs) - np.log(neut_probs)
    weighted = sel_probs * log_ratio
    return np.sum(weighted, axis=-1)

## Change Evan's Code to class format

In [None]:
class KL_inference(PyroModule):
    """Pyro model of per-site SFS bins with gene-level and covariate effects.

    The constructor only stores configuration. Every ``pyro.param`` and
    ``pyro.sample`` statement lives in ``forward`` so that SVI and the
    autoguide trace them on each step; statements executed in ``__init__``
    run once outside any trace and are lost when the param store is cleared.

    Args:
        neut_sfs_full: neutral SFS probabilities (torch tensor); the last axis
            indexes bins. Presumably rows index mutation-rate classes, since
            ``forward`` indexes that axis with ``mu_vals`` -- TODO confirm.
        n_genes: size of the "genes" plate.
        n_covs: size of the "covariates" plate.
        n_bins: number of non-reference SFS bins (length of each beta vector).
        mu_ref: tensor of mutation rates that scales the selection/mutation
            interaction term.
        sample_sfs: accepted for signature compatibility; currently unused
            (observations are passed to ``forward`` as ``y``).
        n_mix: number of mixture components in the prior over gene betas.
        cov_sigma_prior: scale of the HalfCauchy prior on covariate betas.
        trans: transform applied to the latent gene betas before the
            cumulative sum; one of "abs", "logabs", "relu", "logrelu".
        pdist: family of the gene-beta prior, "t" or "normal".

    Raises:
        ValueError: if ``trans`` or ``pdist`` is not one of the supported names.
    """

    def __init__(self, neut_sfs_full, n_genes, n_covs, n_bins, mu_ref,
                 sample_sfs=None, n_mix=2,
                 cov_sigma_prior=torch.tensor(0.1, dtype=torch.float32),
                 trans="abs", pdist="t"):
        super().__init__()

        # Fail early instead of leaving variables undefined further down.
        if pdist not in ("t", "normal"):
            raise ValueError(f"pdist must be 't' or 'normal', got {pdist!r}")
        if trans not in ("abs", "logabs", "relu", "logrelu"):
            raise ValueError(
                f"trans must be 'abs', 'logabs', 'relu' or 'logrelu', got {trans!r}")

        ## define useful transformations
        self.pad = torch.nn.ConstantPad1d((1, 0), 0.)   # prepend a 0 along the last dimension
        self.softmax = torch.nn.Softmax(-1)             # softmax along the last dimension

        # Multinomial coefficients of the neutral SFS.
        # NOTE(review): the original selected row `ref_mu_ii`, a name that is
        # not defined in this scope. The full matrix is kept here because
        # forward() indexes the mutation-rate axis with `mu_vals` -- confirm.
        self.beta_neut = multinomial_trans_torch(neut_sfs_full)

        self.mu_ref = mu_ref
        self.n_genes = n_genes
        self.n_covs = n_covs
        self.n_bins = n_bins
        self.n_mix = n_mix
        self.cov_sigma_prior = cov_sigma_prior
        self.trans = trans
        self.pdist = pdist

    def _sample_beta_sel(self):
        """Declare the gene-prior parameters and sample one beta vector per gene."""
        n_bins, n_mix = self.n_bins, self.n_mix

        # Prior parameters are pyro.param sites: point estimates, no posterior.
        # Initial values are passed as callables so they are only evaluated
        # the first time the parameter is created.
        mix_probs = pyro.param(
            "mix_probs",
            lambda: torch.ones(n_mix, dtype=torch.float32) / n_mix,
            constraint=constraints.simplex)

        if self.pdist == "t":
            # t-distribution can modulate covariance (L) and kurtosis (df);
            # the last mixture component is a fixed near-point-mass at zero.
            n_free = n_mix - 1
            beta_prior_mean = pyro.param(
                "beta_prior_mean",
                lambda: torch.randn((n_free, n_bins)),
                constraint=constraints.real)
            beta_prior_L = pyro.param(
                "beta_prior_L",
                lambda: torch.linalg.cholesky(
                    0.01 * torch.diag(torch.ones(n_bins, dtype=torch.float32))
                ).expand(n_free, n_bins, n_bins),
                constraint=constraints.lower_cholesky)
            beta_prior_df = pyro.param(
                "beta_prior_df",
                lambda: torch.tensor([10] * n_free, dtype=torch.float32),
                constraint=constraints.positive)

            zero_loc = torch.zeros((1, n_bins), dtype=torch.float32)
            zero_L = torch.linalg.cholesky(
                torch.diag(1e-8 * torch.ones(n_bins, dtype=torch.float32))
            ).expand(1, n_bins, n_bins)
            components = dist.MultivariateStudentT(
                df=torch.cat((beta_prior_df, torch.tensor([1000], dtype=torch.float32))),
                loc=torch.cat((beta_prior_mean, zero_loc)),
                scale_tril=torch.cat((beta_prior_L, zero_L)))
        else:
            # normal model has zero covariance, a different variance for each bin though
            beta_prior_loc = pyro.param(
                "beta_prior_loc",
                lambda: torch.randn((n_mix, n_bins)),
                constraint=constraints.real)
            beta_prior_scale = pyro.param(
                "beta_prior_scale",
                lambda: torch.rand((n_mix, n_bins)),
                constraint=constraints.positive)
            components = dist.Normal(beta_prior_loc, beta_prior_scale).to_event(1)

        with pyro.plate("genes", self.n_genes):
            return pyro.sample(
                "beta_sel",
                dist.MixtureSameFamily(dist.Categorical(mix_probs), components))

    def _transform_betas(self, beta_sel):
        """Map latent betas to non-negative increments and accumulate over bins."""
        if self.trans == "abs":
            increments = torch.abs(beta_sel)
        elif self.trans == "logabs":
            increments = torch.log(torch.abs(beta_sel) + 1)
        elif self.trans == "relu":
            increments = torch.relu(beta_sel)
        else:  # "logrelu" (validated in __init__)
            increments = torch.log(torch.relu(beta_sel) + 1)
        return torch.cumsum(increments, dim=-1)

    def forward(self, mu_vals, gene_ids, covariates, y=None):
        """Run the generative model for a set of sites.

        Args:
            mu_vals: index tensor selecting, per site, an entry along the
                mutation-rate axis of the coefficient tensor.
            gene_ids: index tensor selecting, per site, an entry along the
                gene axis. Presumably values lie in ``0..n_genes-1`` -- verify
                against the caller.
            covariates: tensor whose last axis has length ``n_covs``; it is
                matrix-multiplied with the cumulative covariate betas.
            y: optional observed bin index per site; ``None`` samples from
                the model instead of conditioning.

        Returns:
            Tensor of per-site SFS bin probabilities (softmax output).
        """
        n_bins = self.n_bins

        beta_sel_prior_b = pyro.param(
            "beta_prior_b",
            lambda: torch.tensor([0.001] * n_bins, dtype=torch.float32),
            constraint=constraints.positive)  # interaction between gene-based selection and mutation rate

        # Each covariate has a vector of betas, one for each bin. Shape is
        # (n_covs, n_bins) so that matmul(covariates, .) is well defined and
        # the cumulative sum below runs over bins.
        with pyro.plate("covariates", self.n_covs):
            beta_cov = pyro.sample(
                "beta_cov",
                dist.HalfCauchy(self.cov_sigma_prior).expand([n_bins]).to_event(1))

        beta_trans = self._transform_betas(self._sample_beta_sel())

        # calculate the multinomial coefficients for each gene and each mutation rate
        mu_adj = (self.mu_ref[..., None]
                  * torch.cumsum(beta_sel_prior_b, -1)
                  * beta_trans[..., None, :])
        mn_sfs = self.beta_neut - beta_trans[..., None, :] - mu_adj

        # convert to probabilities per-site and adjust for covariates
        cov_effect = torch.matmul(covariates, torch.cumsum(beta_cov, -1))
        linear = self.pad(mn_sfs[..., gene_ids, mu_vals, :] - cov_effect)
        sfs = self.softmax(linear)

        # The second-to-last axis of `sfs` is the per-site axis produced by
        # the advanced indexing above.
        observed = None if y is None else y.squeeze()
        with pyro.plate("data", sfs.shape[-2]):
            pyro.sample("obs", dist.Categorical(probs=sfs), obs=observed)
        return sfs

## inference portion

In [None]:
#define variables
# NOTE(review): neut_sfs_full, mu_vals, gene_ids, covariates, sample_sfs,
# n_bins, mu_ref, lr, num_particles, n_steps and plot_losses are not defined
# anywhere in the visible notebook -- they must be provided before this cell runs.
n_covs = covariates.shape[-1]          # number of covariates included
# presumably gene_ids are contiguous integers 0..n_genes-1, since the model
# uses them as indices -- verify against the data preparation step
n_genes = len(torch.unique(gene_ids))  # number of genes

# start from an empty parameter store before anything registers parameters
pyro.clear_param_store()

#define model and guide
model = KL_inference(neut_sfs_full, n_genes, n_covs, n_bins, mu_ref)
guide = pyro.infer.autoguide.AutoNormal(model)

# run SVI
adam = pyro.optim.Adam({"lr": lr})
elbo = pyro.infer.Trace_ELBO(num_particles=num_particles, vectorize_particles=True)
# optimise the model defined above (the previous name, mu_sfs_sitewise_regr_cov,
# is not defined in this notebook)
svi = pyro.infer.SVI(model, guide, adam, elbo)
losses = []
for step in tqdm(range(n_steps)):  # tqdm is just a progress bar thing
    loss = svi.step(mu_vals, gene_ids, covariates, sample_sfs)
    losses.append(loss)
fig, ax = plot_losses(losses)

## post inference portion (same as raklette_daniel.py)

In [None]:
# NOTE(review): pdist, trans, post_samps, ref_mu_ii, n_bins, mu_ref,
# neut_sfs_full and calc_KL_genewise are not defined in the visible notebook;
# pdist/trans must match the values the model was built with -- confirm.

# helper transforms used below (these names were previously undefined at notebook level)
pad = torch.nn.ConstantPad1d((1, 0), 0.)   # prepend a 0 along the last dimension
softmax = torch.nn.Softmax(-1)             # softmax along the last dimension
relu = torch.nn.ReLU()                     # map everything < 0 -> 0

# multinomial coefficients of the neutral SFS, and the row for the reference mutation rate
beta_neut_full = multinomial_trans_torch(neut_sfs_full)
beta_neut = beta_neut_full[ref_mu_ii,:]

# grab gene-DFE prior parameter point estimates
mix_probs = pyro.param("mix_probs")
if pdist=="t":
    beta_prior_df = pyro.param("beta_prior_df")
    beta_prior_mean = pyro.param("beta_prior_mean")
    beta_prior_L = pyro.param("beta_prior_L")
elif pdist=="normal":
    beta_prior_loc = pyro.param("beta_prior_loc")
    beta_prior_scale = pyro.param("beta_prior_scale")

beta_prior_b = pyro.param("beta_prior_b")

# Sample betas from the DFE prior, representing the fit distribution across genes
if pdist=="t":
    # the last component is the fixed near-point-mass at zero
    prior_components = dist.MultivariateStudentT(
        df=torch.cat((beta_prior_df, torch.tensor([1000], dtype=float))),
        loc=torch.cat((beta_prior_mean,
                       torch.tensor([0]*n_bins, dtype=float).expand((1, n_bins)))),
        scale_tril=torch.cat((beta_prior_L,
                              torch.linalg.cholesky(torch.diag(1e-8*torch.ones(n_bins, dtype=float))).expand(1, n_bins, n_bins))))
elif pdist=="normal":
    prior_components = dist.Normal(beta_prior_loc, beta_prior_scale).to_event(1)
prior_dist = dist.MixtureSameFamily(dist.Categorical(mix_probs), prior_components)
prior_samps = prior_dist.sample((post_samps,))

# apply the same transform as the model, then accumulate over bins
if trans == "abs":
    prior_trans = torch.cumsum(torch.abs(prior_samps), dim=-1)
elif trans=="logabs":
    prior_trans = torch.cumsum(torch.log(torch.abs(prior_samps)+1), dim=-1)
elif trans=="relu":
    prior_trans = torch.cumsum(relu(prior_samps), dim=-1)
elif trans=="logrelu":
    prior_trans = torch.cumsum(torch.log(relu(prior_samps)+1), dim=-1)

## Prior SFS probabilities for gene effects in the absence of covariates
prior_probs = softmax(pad(beta_neut - prior_trans -
                          mu_ref[ref_mu_ii]*torch.cumsum(beta_prior_b, -1)*prior_trans
                         )
                     ).detach().numpy()

# take samples from the posterior distribution on all betas
with pyro.plate("samples", post_samps, dim=-2):
    post_samples = guide()

# entries shared by both prior families
result = {"neut_sfs_full":neut_sfs_full, "beta_neut_full":beta_neut_full, "ref_mu_ii":ref_mu_ii,
          "mix_probs":mix_probs,
          "beta_prior_b":beta_prior_b, "trans":trans,
          "prior_probs":prior_probs, "post_samples":post_samples, "mu_ref":mu_ref}
# family-specific prior parameters
if pdist=="t":
    result.update({"beta_prior_df":beta_prior_df, "beta_prior_mean":beta_prior_mean,
                   "beta_prior_L":beta_prior_L})
elif pdist=="normal":
    result.update({"beta_prior_scale":beta_prior_scale, "beta_prior_loc":beta_prior_loc})

# calculate the posterior distribution on KL for each gene
result = calc_KL_genewise(result, pdist=pdist)

## Then calculate the posteriors for covariate betas
result["post_beta_cov"] = torch.cumsum(post_samples['beta_cov'], -1)
result["losses"] = losses
result["fig"] = (fig, ax)