## Run KL analysis for intergenic regions, footprinting regions, and DHS sites

In [None]:
## write function for model

def model(beta_neut, mu_vals, gene_ids, covariates, n_bins, mu_ref,
          sample_sfs=None, n_mix=2, cov_sigma_prior=torch.tensor(0.1, dtype=torch.float32),
          trans="abs", pdist="t"):
    """
    Pyro sampling model for a gene-based DFE with covariates.

    Per-gene latent betas (``beta_sel``) are drawn from a learnable mixture
    prior. They are passed through a non-negative transform and cumulatively
    summed over bins. They are then combined with ``beta_neut``, a
    mutation-rate interaction term and per-site covariate effects. Finally they
    are turned into per-site categorical probabilities that score ``sample_sfs``.

    Args:
        beta_neut: neutral baseline term, broadcast against a
            ``[gene, mutation-rate, bin]`` tensor (exact shape TODO confirm).
        mu_vals: per-site integer indices into the mutation-rate axis
            (they are used for indexing). Their length sets the number of sites.
        gene_ids: per-site integer gene indices. The gene plate size is the
            number of *unique* ids, so ids are presumably dense
            ``0..n_genes-1``. TODO confirm against callers.
        covariates: covariate matrix whose last dim is the number of
            covariates. Presumably ``(n_sites, n_covs)``.
        n_bins: number of bins modelled by each beta vector.
        mu_ref: reference mutation rates. ``mu_ref[..., None]`` is broadcast
            against the bin axis.
        sample_sfs: observed per-site category indices, or ``None`` to sample.
        n_mix: number of mixture components in the gene-level prior. For
            ``pdist="t"``, one of them is a fixed, near-point mass at zero.
        cov_sigma_prior: scale of the HalfCauchy prior on covariate betas.
        trans: transform applied to ``beta_sel`` before the cumulative sum.
            One of ``"abs"``, ``"logabs"``, ``"relu"``, ``"logrelu"``.
        pdist: family of the gene-level prior, ``"t"`` or ``"normal"``.

    Raises:
        ValueError: if ``trans`` or ``pdist`` is not a supported option. This is
            checked before any Pyro param/sample site is touched.
    """
    # Fail fast: otherwise an unknown option leaves beta_sel/beta_trans unbound
    # and surfaces much later as an opaque UnboundLocalError.
    if pdist not in ("t", "normal"):
        raise ValueError(f"pdist must be 't' or 'normal', got {pdist!r}")
    if trans not in ("abs", "logabs", "relu", "logrelu"):
        raise ValueError(
            f"trans must be one of 'abs', 'logabs', 'relu', 'logrelu', got {trans!r}"
        )

    n_covs = covariates.shape[-1]          # number of covariates included
    n_sites = len(mu_vals)                 # number of sites (potential mutations) we are modeling
    n_genes = len(torch.unique(gene_ids))  # number of genes (size of the "genes" plate)

    ## Setup flexible prior
    # parameters describing the prior over genes are set as pyro.param, meaning they
    # get point estimates (no posterior). Initialisation order is kept fixed so that
    # seeded runs draw the same random initial values.
    if pdist == "t":
        # t-distribution can modulate covariance (L) and kurtosis (df).
        # A fixed "point mass" at zero is appended as one of the mixture components
        # below (n_mix - 1 learnable components + 1 fixed); not sure if this should be kept.
        beta_prior_mean = pyro.param("beta_prior_mean", torch.randn((n_mix-1, n_bins)),
                                     constraint=constraints.real)
        beta_prior_L = pyro.param(
            "beta_prior_L",
            torch.linalg.cholesky(
                0.01*torch.diag(torch.ones(n_bins, dtype=torch.float32))
            ).expand(n_mix-1, n_bins, n_bins),
            constraint=constraints.lower_cholesky,
        )
        beta_prior_df = pyro.param("beta_prior_df", torch.tensor([10]*(n_mix-1), dtype=torch.float32),
                                   constraint=constraints.positive)
        mix_probs = pyro.param("mix_probs", torch.ones(n_mix, dtype=torch.float32)/n_mix,
                               constraint=constraints.simplex)
    else:  # pdist == "normal" (validated above)
        # normal model has zero covariance, a different variance for each bin though
        mix_probs = pyro.param("mix_probs", torch.ones(n_mix, dtype=torch.float32)/n_mix,
                               constraint=constraints.simplex)
        beta_prior_loc = pyro.param("beta_prior_loc", torch.randn((n_mix, n_bins)),
                                    constraint=constraints.real)
        beta_prior_scale = pyro.param("beta_prior_scale", torch.rand((n_mix, n_bins)),
                                      constraint=constraints.positive)

    # interaction term between gene-based selection and mutation rate
    beta_prior_b = pyro.param("beta_prior_b", torch.tensor([0.001]*n_bins, dtype=torch.float32),
                              constraint=constraints.positive)

    # Each covariate has a vector of betas, one for each bin, maybe think about different prior here?
    with pyro.plate("covariates", n_covs):
        beta_cov = pyro.sample("beta_cov", dist.HalfCauchy(cov_sigma_prior).expand([n_bins]).to_event(1))

    with pyro.plate("genes", n_genes):
        # sample latent betas from either t or normal distribution
        if pdist == "t":
            # Fixed near-degenerate component at zero: df=1000, loc=0,
            # scale_tril = chol(1e-8 * I). It is concatenated after the learnable components.
            point_df = torch.tensor([1000], dtype=torch.float32)
            point_loc = torch.zeros((1, n_bins), dtype=torch.float32)
            point_L = torch.linalg.cholesky(
                torch.diag(1e-8*torch.ones(n_bins, dtype=torch.float32))
            ).expand(1, n_bins, n_bins)
            components = dist.MultivariateStudentT(
                df=torch.cat((beta_prior_df, point_df)),
                loc=torch.cat((beta_prior_mean, point_loc)),
                scale_tril=torch.cat((beta_prior_L, point_L)),
            )
            beta_sel = pyro.sample("beta_sel",
                                   dist.MixtureSameFamily(dist.Categorical(mix_probs), components))
        else:  # pdist == "normal"
            components = dist.Normal(beta_prior_loc, beta_prior_scale).to_event(1)
            beta_sel = pyro.sample("beta_sel",
                                   dist.MixtureSameFamily(dist.Categorical(mix_probs), components))

        # apply a non-negative transform to the latent betas, then cumsum over bins
        if trans == "abs":
            beta_trans = torch.cumsum(torch.abs(beta_sel), dim=-1)
        elif trans == "logabs":
            beta_trans = torch.cumsum(torch.log(torch.abs(beta_sel)+1), dim=-1)
        elif trans == "relu":
            beta_trans = torch.cumsum(relu(beta_sel), dim=-1)
        else:  # trans == "logrelu" (validated above)
            beta_trans = torch.cumsum(torch.log(relu(beta_sel)+1), dim=-1)

    # calculate the multinomial coefficients for each gene and each mutation rate.
    # beta_trans[..., None, :] inserts a mutation-rate axis, so mn_sfs is laid out
    # as [gene, mutation-rate, bin] (plus any leading batch dims).
    mu_adj = mu_ref[..., None] * torch.cumsum(beta_prior_b, -1) * beta_trans[..., None, :]
    mn_sfs = beta_neut - beta_trans[..., None, :] - mu_adj
    # convert to probabilities per-site and adjust for covariates.
    # `pad` and `softmax` are defined/imported elsewhere in the notebook.
    sfs = softmax(pad(mn_sfs[..., gene_ids, mu_vals, :] - torch.matmul(covariates, torch.cumsum(beta_cov, -1))))

    with pyro.plate("sites", n_sites):
        pyro.sample("obs", dist.Categorical(sfs), obs=sample_sfs)