<a href="https://colab.research.google.com/github/macorony/Bioinformatic_analysis/blob/main/Algorithms/Bayesian_selection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import pandas as pd
import scipy

In [48]:
class BayesianSelection:
  def __init__(self, count_table, design_matrix=None, sgrna_efficiency=None):
    self.count_table = count_table
    self.design_matrix = design_matrix
    self.sgrna_efficiency = sgrna_efficiency

    # calculate dimensions
    self.n_sgrnas = len(count_table)
    self.genes = np.unique(count_table['gene'])
    self.n_genes = len(self.genes)

    # initialize results storage
    self.normalized_counts = None
    self.fold_changes = None
    self.gene_guide_map = None

    # initialize parameter storage
    self.gene_priors = None
    self.sgrna_priors = None
    self.dispersion_priors = None

    self.gene_effect = np.zeros(self.n_genes)
    self.sgrna_efficiency = np.zeros(self.n_sgrnas)
    self.dispersion = np.zeros(self.n_genes)

  def normalize_counts(self):
    """
    Normalize sgRNA counts using median normalization.

    Returns:
    - normalized_counts (pd.DataFrame): Normalized sgRNA counts.
    """
    # calculate size factors of control and treatment samples
    control_sf = np.median(self.count_table['control'])/self.count_table['control']
    treatment_sf = np.median(self.count_table['treatment'])/self.count_table['treatment']

    # apply normalization
    self.normalized_counts = pd.DataFrame(
        {'control': self.count_table['control']*control_sf,
         'treatment': self.count_table['treatment']*treatment_sf}
        )


    # calculate log2 fold change
    self.fold_changes = pd.DataFrame({
        'log2fc': np.log2(self.normalized_counts['treatment']/self.normalized_counts['control']),
        'variance': 1/(self.normalized_counts['control'] + self.normalized_counts['treatment'])
        })
    return self.normalized_counts

  def group_sgrna_by_gene(self):
    """
    Group sgRNAs by targeting genes.

    Returns:
    - gene_guide_map (dict): Mapping of genes to sgRNA indices.
    """


    self.gene_guide_map = {}
    for gene in self.genes:
      mask = self.count_table['gene'] == gene
      self.gene_guide_map[gene] = {
          'guide_index': np.where(mask)[0],
          'n_guides':np.sum(mask)}
    return self.gene_guide_map

  def initialize_priors(self):
    """
    Initialize prior distributions for Bayesian analysis.
    """

    # 1. gene effect priors(normal distribution)
    self.gene_priors = {
        'mean': np.zeros(self.n_genes),
        'variance': np.ones(self.n_genes)
    }

    # 2.sgRNA efficiency priors
    if self.sgrna_efficiency is None:
      self.sgrna_priors = {
          'mean': np.array(list(self.sgrna_efficiency.values())),
          'variance': 0.1 * np.ones(self.n_sgrnas)
      }
    else:
      self.sgrna_priors = {
          'mean': np.ones(self.n_sgrnas),
          'variance':np.ones(self.n_sgrnas)
          }

    # 3. dispersion priors (Gamma distribution)
    self.dispersion_priors = {
        'shape': np.ones(self.n_genes),
        'scale': np.ones(self.n_genes)
        }

  def construct_likelihood(self):
    """
    Construct negative binomial likelihood function.

    Returns:
    - likelihood (function): Negative binomial log-likelihood function.
    """
    def negative_binomial_likelihood(count: np.ndarray,
                                     mean: np.ndarray,
                                     dispersion: np.ndarray):
      # log likelihood of negative binomial distribution
      r = 1/dispersion
      p = r/(r+mean)
      return scipy.stats.nbinom.logpmf(count, r, p)
    return negative_binomial_likelihood

  def update_gene_effect(self):
      """
      Update gene effects using Gibbs sampling.
      """

      # run normalize_counts if needed
      if self.fold_changes is None:
        self.normalize_counts()

      # run group_sgrna_by_gene if needed
      if self.gene_guide_map is None:
        self.group_sgrna_by_gene()

      # run initialize_priors if needed
      if any(attr is None for attr in [self.gene_priors, self.sgrna_priors, self.dispersion_priors]):
        self.initialize_priors()

      for gene_idx, gene in enumerate(self.genes):
        # get sgRNA for the gene
        gene_guides = self.gene_guide_map[self.genes[gene_idx]]['guide_index']
        # calculate statistics
        sgrna_data = self.fold_changes.loc[gene_guides, 'log2fc']
        print(f"sgRNA data: {sgrna_data}")
        sgrna_vars = 1/(self.sgrna_priors['variance'][gene_guides])
        # posterior parameters
        posterior_var = 1/(1/self.gene_priors['variance'][gene_idx] + np.sum(sgrna_vars))
        posterior_mean = posterior_var * (self.gene_priors['mean'][gene_idx]/self.gene_priors['variance'][gene_idx] + np.sum(sgrna_data*sgrna_vars))
        # sample new effect
        self.gene_effect[gene_idx] = np.random.normal(posterior_mean, np.sqrt(posterior_var))


  def update_sgrna_efficiency(self):
    """
    Update sgRNA efficiencies.
    """




    for sgrna_idx, sgrna in enumerate(self.n_sgrnas):
      # calculate statistics
      gene = self.count_table.loc[sgrna_idx, 'gene']
      gene_idx = np.where(self.genes == gene)[0][0]
      # calculate posterior
      data_contribution = self.fold_changes.loc[sgrna_idx, 'log2fc']
      prior_contribution = self.sgrna_prior['mean'][gene_idx]

      posterior_var = 1/(
          1/self.sgrna_priors['variance'][sgrna_idx] +
          1/self.gene_effects[gene_idx]**2
      )
      posterior_mean = posterior_var * (
          data_contribution/self.fold_changes.loc[sgrna_idx, 'variance'] +
          prior_contribution/self.gene_effects[gene_idx]**2
      )

      # sample the efficiency
      self.sgrna_efficiency[sgrna] = np.random.normal(posterior_mean, np.sqrt(posterior_var))

  def update_dispersion(self):
    """
    Update dispersion parameters.
    """
    for gene_idx, in enumerate(self.genes):
      gene_guides = self.gene_guide_map[self.genes[gene_idx]]['guide_index']
      # calculate statistics
      shape = self.dispersion_priors['shape'][gene_idx] + len(gene_guides)/2
      rate = self.dispersion_priors['rate'][gene_idx] + \
              np.sum((self.fold_changes.loc[gene_guides, 'log2fc'] -
                    self.gene_effects[gene_idx])**2)/2

      # Sample new dispersion
      self.dispersion[gene_idx] = np.random.gamma(shape, 1/rate)
  def estimate_parameters(self, n_iterations=1000, burn_in=100):
    """
    Estimate the parameters using EM algorithm.
    """
    self.mcmc_samples = {
        'gene_effect': np.zeros((n_iterations, self.n_genes)),
        'sgrna_efficiency': np.zeros((n_iterations, self.n_sgrnas)),
        'dispersion': np.zeros((n_iterations, self.n_genes))
    }
    # run gibbs sampling
    for iter in range(n_iterations):
      self.update_gene_effect()
      self.update_sgrna_efficiency()
      self.update_dispersion()
      # store sample (after burn-in)
      if iter >= burn_in:
        self.store_samples(iter)
      # monitor convergence
      if iter % 100 == 0:
        self.check_convergence(iter)

  def store_samples(self, iter):
    """
    Store MCMC samples.
    """
    # store gene effect estimates
    self.mcmc_samples['gene_effect'][iter, :] = self.gene_effects
    # store guide RNA efficiency estimates
    self.mcmc_samples['sgrna_efficiency'][iter, :] = self.sgrna_efficiency
    # store dispersion estimates
    self.mcmc_samples['dispersion'][iter, :] = self.dispersion




In [49]:
# a small example
count_data = pd.DataFrame(
    {'sgrna': ['sgrna1', 'sgrna2', 'sgrna3', 'sgrna4'],
     'gene': ['geneA', 'geneA', 'geneB', 'geneB'],
     'control': [100, 200, 200, 250],
     'treatment': [50, 150, 200, 400]
     })
sgrna_efficiency = {'sgrna1':0, 'sgrna2':0, 'sgrna3':0, 'sgrna4':0}
bayes_selector = BayesianSelection(count_data, sgrna_efficiency=sgrna_efficiency)

In [50]:
bayes_selector.update_gene_effect()
print(bayes_selector.gene_effect)

sgRNA data: 0   -0.192645
1   -0.192645
Name: log2fc, dtype: float64
sgRNA data: 2   -0.192645
3   -0.192645
Name: log2fc, dtype: float64
[ 0.27124378 -0.01325144]


In [19]:
for gene_idx, gene in enumerate(bayes_selector.genes):
  print(bayes_selector.gene_guide_map)

None
None


In [None]:
bayes_selector.fold_changes


In [None]:
bayes_selector.fold_changes


In [None]:
print(bayes_selector.gene_priors)
print(bayes_selector.sgrna_priors)
print(bayes_selector.dispersion_priors)
print(bayes_selector.gene_guide_map)

AttributeError: 'BayesianSelection' object has no attribute 'gene_priors'

In [None]:
def example_update():
    # Example data
    prior_mean = 0
    prior_var = 1
    guide_data = np.array([0.5, 0.7, 0.3])
    guide_vars = np.array([0.1, 0.1, 0.1])

    # Calculate posterior
    posterior_var = 1 / (1/prior_var + np.sum(1/guide_vars))
    posterior_mean = posterior_var * (
        np.sum(guide_data/guide_vars) +
        prior_mean/prior_var
    )

    # Sample
    new_effect = np.random.normal(posterior_mean, np.sqrt(posterior_var))

    return new_effect

In [None]:
example_update()

0.588096056993475

In [None]:
np.random.normal(1,4)

0.8626135969033462