# Simulation-Based Inference (SBI) in population genetics

Welcome to this workshop on applying neural posterior estimation (NPE) in population genetics! In this notebook, we will explore together how to use our Snakemake pipeline for simulation-based inference in population genetics. 

## 1. Introduction

We'll walk you through:
1. A brief overview of the SBI toolbox and NPE.
2. Setting up the required environment (ideally done before the workshop).
3. Reading in and exploring pre-simulated data.
4. Training a posterior using the SBI toolbox.
5. Evaluating the posterior distribution (trained on a deliberately small dataset).
6. Loading a pre-trained posterior with evaluation.
7. Visualisations

   ...more ideas?

### 1.1. A Brief Overview

Neural posterior estimation (NPE), as implemented in the [sbi toolbox](https://github.com/sbi-dev/sbi), uses flexible neural networks to learn the posterior distribution of parameters given observations.
- It allows us to infer complex, high-dimensional parameters without relying on approximate likelihoods.
- The approach is especially useful for scenarios where the likelihood function is expensive or intractable, but data simulation is feasible.
  
You can visit [sbi documentation](https://sbi-dev.github.io/sbi/latest/) for more information.
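
For readers who want the idea in one line: NPE fits a conditional density estimator $q_\phi(\theta \mid x)$ to pairs of simulated parameters and data. The expression below is the standard single-round NPE loss; the notation ($q_\phi$, $N$) is introduced here for illustration and is not taken from the pipeline.

$$
\mathcal{L}(\phi) = -\frac{1}{N}\sum_{i=1}^{N} \log q_\phi(\theta_i \mid x_i), \qquad \theta_i \sim p(\theta),\quad x_i \sim p(x \mid \theta_i)
$$

Minimising $\mathcal{L}(\phi)$ over many simulations pushes $q_\phi(\theta \mid x)$ towards the true posterior $p(\theta \mid x)$. Once trained, the same network can be evaluated for any new observation without retraining, which is what "amortized" refers to.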

Based on sbi, our [Snakemake](https://snakemake.readthedocs.io/en/stable/) pipeline provides a framework for simulation-based inference in population genetics using [msprime](https://tskit.dev/msprime/docs/stable/quickstart.html). It automates data simulation (e.g., tree sequences), training of neural posterior estimators (NPEs), and plotting/visualization of inferred parameters. 

Three different workflows are provided: an amortized msprime workflow, an amortized dadi workflow, and a sequential msprime workflow. Configuration files control the number of simulations, model details, and training settings, making the workflow flexible for various population genetic scenarios.
For more information on this pipeline, please visit our [GitHub repository](https://github.com/kr-colab/popgensbi_snakemake).

- [ ] Make a few slides introducing the general idea of NPE (to biologists)

### 1.2. Prerequisites

Before we begin, ensure the following:
1. **Operating System**: Linux/macOS/Windows (with WSL2 or an equivalent environment).
2. **Hardware**:
    - Only CPU is needed for this workshop.
    - [ ] GPU usage will be provided in another Notebook.
3. **Software**:
    - Python 3.9+ and [sbi 0.22.0](https://github.com/sbi-dev/sbi/releases/tag/v0.22.0).
    - [conda](https://docs.conda.io/en/latest/) (or `venv`) for environment management.
    - Required Python libraries for this tutorial ([requirements](https://github.com/kr-colab/popgensbi_snakemake/blob/main/requirements.yaml)).

#### Environment Setup

To run this notebook, please follow these steps:
1. Install [conda](https://docs.conda.io/en/latest/miniconda.html) if you haven’t already.
2. Clone the repository: `git clone https://github.com/kr-colab/popgensbi_snakemake.git`
3. Download the .zip folder `folder_name` for testing data and pre-trained neural networks.
4. Create the environment: `conda env create -f requirements.yaml`
5. Activate the environment: `conda activate popgensbi_env`
6. Launch Jupyter notebook: `jupyter notebook`.
7. In the Notebook, select the "popgensbi" kernel if prompted.

### 1.3. Environment Test

- [ ] Here should be a short test block

In [9]:
# Are you ready to go?

import importlib

# List of critical packages we expect ("allel" is scikit-allel, used in Section 3.2)
required_packages = ["snakemake", "msprime", "dadi", "sbi", "torch", "numpy", "pandas", "allel"]
missing_packages = []

for pkg in required_packages:
    try:
        importlib.import_module(pkg)
    except ImportError:
        missing_packages.append(pkg)

if missing_packages:
    print("WARNING: The following packages are missing:", missing_packages)
    print("Please install or switch to the conda environment that has them.")
else:
    print("All required packages found. Environment looks good!")



In [None]:
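# Test that NPE runs without problems.
import torch
try:
    from sbi.inference import NPE
except ImportError:
    # Older sbi releases name this class SNPE.
    from sbi.inference import SNPE as NPE

# Define a shifted Gaussian simulator.
def simulator(θ):
    return θ + torch.randn_like(θ)

# Draw parameters from a standard Gaussian prior.
prior = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
θ = prior.sample((1000,))
# Simulate data.
x = simulator(θ)

# Choose the sbi method and train, passing the prior explicitly.
inference = NPE(prior=prior)
inference.append_simulations(θ, x).train()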

# do inference given observed data
x_o = torch.ones(2)
posterior = inference.build_posterior()
samples = posterior.sample((1000,), x=x_o)
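
For this toy problem the exact posterior is known, which gives a reference to compare the NPE samples against. With a standard Gaussian prior and unit Gaussian noise, each coordinate of the posterior is Gaussian with mean `x_o / 2` and variance `1/2`. The sketch below uses only the standard library and cross-checks the formula with self-normalised importance sampling. The function names are illustrative and not part of sbi.

```python
import math
import random

def analytic_posterior(x_o, prior_var=1.0, noise_var=1.0):
    """Exact posterior mean and variance for theta ~ N(0, prior_var), x = theta + N(0, noise_var)."""
    post_var = 1.0 / (1.0 / prior_var + 1.0 / noise_var)
    post_mean = post_var * x_o / noise_var
    return post_mean, post_var

def importance_sampling_estimate(x_o, n=200_000, seed=0):
    """Monte Carlo check: draw theta from the prior and weight by the likelihood of x_o."""
    rng = random.Random(seed)
    thetas = [rng.gauss(0.0, 1.0) for _ in range(n)]
    weights = [math.exp(-0.5 * (x_o - t) ** 2) for t in thetas]
    total = sum(weights)
    mean = sum(w * t for w, t in zip(weights, thetas)) / total
    var = sum(w * (t - mean) ** 2 for w, t in zip(weights, thetas)) / total
    return mean, var

ref_mean, ref_var = analytic_posterior(1.0)  # x_o = 1 in each coordinate
mc_mean, mc_var = importance_sampling_estimate(1.0)
print(f"analytic:   mean={ref_mean:.3f}, var={ref_var:.3f}")
print(f"importance: mean={mc_mean:.3f}, var={mc_var:.3f}")
```

After running the test cell, `samples.mean(dim=0)` and `samples.var(dim=0)` should both be close to 0.5, up to training and Monte Carlo error.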

---

## 2. Explore demographic inference

In this section, we first take a look at the data set.
All data provided here come from the test set, so you can pass them to a pre-trained neural network and draw your own posterior samples.

### 2.1 Simulated data

We generated demographic scenarios by sampling at random from a prior. Each scenario consists of 21 effective population sizes, one per time window from the present into the past, with window widths that grow exponentially, plus a recombination rate.
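
The time grid can be made concrete with a short calculation. The sketch below reuses the values that appear in Section 3 of this notebook (`time_rate = 0.06`, `tmax = 130000`, 21 windows). Whether the pre-simulated data used exactly these values is an assumption, and the function name is illustrative.

```python
import math

def window_start_times(num_time_windows=21, tmax=130_000, time_rate=0.06):
    """Start time of window i: t_i = (exp(i * log(1 + a * tmax) / (n - 1)) - 1) / a."""
    scale = math.log(1 + time_rate * tmax) / (num_time_windows - 1)
    return [(math.exp(scale * i) - 1) / time_rate for i in range(num_time_windows)]

times = window_start_times()
# Width of each window between consecutive change points.
widths = [later - earlier for earlier, later in zip(times, times[1:])]

for i in (0, 1, 2, 19):
    print(f"window {i:2d}: starts at {times[i]:10.1f}, width {widths[i]:10.1f}")
print(f"last change point: {times[-1]:.1f}")
```

The first window starts at 0, the last change point falls at `tmax`, and each window is wider than the previous one by a constant factor of roughly 1.57, so recent history is resolved much more finely than ancient history.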

In [11]:
# import pandas as pd
# import numpy as np
# import matplotlib.pyplot as plt

In [None]:
# path_population_sizes = '... .tsv'
# df_pop_size = pd.read_csv(path_population_sizes, sep='\t')
# df_pop_size

### 2.2 SNP matrix

[Msprime](https://tskit.dev/msprime/docs/stable/quickstart.html) is a powerful coalescent simulator that models the ancestry of a sample of genomes under specified demographic parameters (e.g., population sizes, mutation rates, recombination rates, and population splits). It generates a [tree sequence](https://tskit.dev/tutorials/intro.html), which is essentially a record of how all sampled individuals coalesce back to their common ancestors. By placing mutations along the branches of these ancestral trees (according to the specified mutation rate), msprime outputs simulated genetic variation—ultimately yielding SNP data or variant matrices that can be used for downstream analyses, such as training neural posterior estimators in our pipeline.

In [None]:
# load SNP matrices


### 2.3 Summary statistics

We computed the site frequency spectrum (SFS) and linkage disequilibrium (LD) from the SNP matrices, which yields 68 summary statistics in total.
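
To show what these two statistics measure, here is a toy sketch in plain Python. It is not the pipeline's implementation: the code later in this notebook relies on scikit-allel (`allel.sfs`, `allel.rogers_huff_r`), whereas this example uses a hand-made matrix and the squared Pearson correlation as a stand-in for LD. The function names are illustrative.

```python
def site_frequency_spectrum(haplotypes):
    """Unfolded SFS: entry k-1 counts the sites where k of the n samples carry the derived allele."""
    n_samples = len(haplotypes)
    n_sites = len(haplotypes[0])
    spectrum = [0] * (n_samples - 1)
    for site in range(n_sites):
        derived = sum(row[site] for row in haplotypes)
        if 0 < derived < n_samples:  # skip sites that are absent or fixed
            spectrum[derived - 1] += 1
    return spectrum

def r_squared(haplotypes, site_a, site_b):
    """Squared Pearson correlation between the alleles at two biallelic sites."""
    n = len(haplotypes)
    a = [row[site_a] for row in haplotypes]
    b = [row[site_b] for row in haplotypes]
    p_a, p_b = sum(a) / n, sum(b) / n
    p_ab = sum(x * y for x, y in zip(a, b)) / n
    d = p_ab - p_a * p_b  # linkage disequilibrium coefficient D
    return d * d / (p_a * (1 - p_a) * p_b * (1 - p_b))

# Toy data: 4 haploid samples (rows) x 5 sites (columns), 1 = derived allele.
toy = [
    [1, 0, 1, 0, 1],
    [1, 0, 0, 0, 1],
    [0, 1, 0, 0, 1],
    [0, 1, 0, 0, 0],
]
print(site_frequency_spectrum(toy))  # [1, 2, 1]: counts for k = 1, 2, 3
print(r_squared(toy, 0, 1))          # 1.0: sites 0 and 1 are perfectly anti-correlated
print(r_squared(toy, 0, 2))
```

With the settings used in Section 3 (50 haploid samples and 19 distance bins), the SFS contributes 49 values and the mean r² per bin contributes 19, which adds up to 68. This assumes the pre-simulated data were produced with the same settings.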

In [None]:
# path_summary_statistics = '.tsv'
# df_sum = pd.read_csv(path_summary_statistics, sep='\t')
# df_sum

### 2.4 Pre-trained network

Give it a shot!

In [None]:
# import pickle

## This network is trained on population sizes and summary statistics. 
# posterior_path = 'pretrained_posterior.pkl'
# with open(posterior_path, 'rb') as f:
#     pretrained_posterior = pickle.load(f)

# # Now we can do inference with the loaded posterior
# test_index = 1
# observed_x = x[test_index].unsqueeze(0)
# true_params = theta[test_index]

# with torch.no_grad():
#     inferred_samples = pretrained_posterior.sample((1000,), x=observed_x)
# inferred_mean = inferred_samples.mean(dim=0)
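
Once posterior samples are available, a simple first check is whether the true parameter falls inside a central credible interval. The helper below is a hypothetical, standard-library sketch (not part of sbi) that summarises the samples of one parameter, passed as a list of floats.

```python
import statistics

def summarise(samples, true_value):
    """Mean, median and central 95% interval of 1-D posterior samples, plus a coverage flag."""
    # n=40 gives cut points every 2.5%, so the first and last are the 2.5% and 97.5% quantiles.
    cuts = statistics.quantiles(samples, n=40, method="inclusive")
    lower, upper = cuts[0], cuts[-1]
    return {
        "mean": statistics.fmean(samples),
        "median": statistics.median(samples),
        "interval": (lower, upper),
        "covers_truth": lower <= true_value <= upper,
    }

# Illustrative input only: the numbers 0..100 stand in for posterior draws.
fake_samples = [float(v) for v in range(101)]
summary = summarise(fake_samples, true_value=42.0)
print(summary)
```

With the objects from the commented cell above, one parameter could be summarised as `summarise(inferred_samples[:, 0].tolist(), true_params[0].item())`. Repeating this over many test simulations and counting how often `covers_truth` is true gives a rough calibration check: for a well-calibrated posterior it should be close to 95%.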

In [None]:
## This is how the network is trained. Uncomment to try it out. This could take some time.

# from sbi.inference import SNPE
# import torch

# # Convert to torch tensors
# theta = torch.tensor(df[...].values, dtype=torch.float32)
# x = torch.tensor(df[...].values, dtype=torch.float32)

# inference = SNPE(prior=None)  # Usually, you'd define a prior or pass a prior object.

# # Train the posterior (this can take a while, especially on CPU)
# density_estimator = inference.append_simulations(theta, x).train()
# posterior = inference.build_posterior(density_estimator)

---

## 3. Various scenarios

A population can go through many different demographic histories. Here we simulate several specific ones and see how the neural network performs on each of them. In this section, you will simulate the test data and compute the summary statistics yourself, instead of loading them directly!

### 3.1 Simulate 6 representative scenarios

- **Medium**: constant population size of about 5,000.
- **Large**: constant population size of about 50,000.
- **Decline**: population size decreasing towards the present.
- **Expansion**: population size increasing towards the present.
- **Bottleneck**: population growth followed by a single bottleneck (a sharp contraction and recovery).
- **Zigzag**: two bottlenecks.

In [None]:
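import msprime
import numpy as np


def simulate_scenario(population_size, population_time, seed, num_replicates, mutation_rate, recombination_rate,
                      segment_length, num_sample):
    """Simulate `num_replicates` SNP datasets under a piecewise-constant population size history.

    Returns a list with one 2-D array per replicate: the first row holds the distance from each SNP
    to the previous one, and each remaining row holds one sampled haplotype.
    """
    demography = msprime.Demography()
    demography.add_population(initial_size=population_size[0])

    # population_size[i] applies from population_time[i] backwards in time.
    for i in range(1, len(population_size)):
        demography.add_population_parameters_change(time=population_time[i], initial_size=population_size[i], growth_rate=0)

    # With num_replicates set, sim_ancestry returns an iterator over tree sequences.
    replicates = msprime.sim_ancestry(
        num_sample,
        random_seed=seed,
        sequence_length=segment_length,
        ploidy=1,
        num_replicates=num_replicates,
        demography=demography,
        recombination_rate=recombination_rate)
    pos = []
    snp = []

    for ts in replicates:
        # Overlay mutations on the simulated ancestry (the same seed is reused for every replicate).
        mts = msprime.sim_mutations(ts, rate=mutation_rate, random_seed=seed)
        positions = [variant.site.position for variant in mts.variants()]
        # Store distances between consecutive SNPs rather than absolute positions.
        positions = np.array(positions) - np.array([0] + positions[:-1])
        positions = positions.astype(int)
        pos.append(positions)
        # Rows are samples, columns are SNPs.
        SNPs = mts.genotype_matrix().T.astype(np.uint8)
        snp.append(SNPs)

    # Stack the distance row on top of each haplotype matrix.
    data = [np.vstack([pos[i], snp[i]]) for i in range(len(snp))]
    return data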

# Population sizes are defined on a log10 scale
scenarios = {'Medium': [3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7],   
             'Large': 4.7 * np.ones(shape=21, dtype='float'), 
             'Decline': [2.5, 2.5, 3, 3, 3, 3, 3.2, 3.4, 3.6, 3.8, 4, 4.2, 4.6, 4.6, 4.6, 4.6, 4.6, 4.6, 4.6, 4.6, 4.6], 
             'Expansion': [4.7, 4.7, 4.7, 4.6, 4.6, 4.5, 4.4, 4.3, 4, 3.7, 3.4, 3.4, 3.4, 3.4, 3.4, 3.4, 3.4, 3.4, 3.4, 3.4, 3.4], 
             'Bottleneck': [4.8, 4.8, 4.8, 4.8, 4.8, 4.8, 4.8, 4.8, 4.8, 4.5, 4.15, 3.8, 4.3, 4.8, 4.55, 4.3, 4.05, 3.8, 3.8, 3.8, 3.8], 
             'Zigzag': [4.8, 4.8, 4.8, 4.5, 4.15, 3.8, 4.15, 4.5, 4.8, 4.5, 4.15, 3.8, 4.3, 4.8, 4.55, 4.3, 4.05, 3.8, 3.8, 3.8, 3.8]}
scenarios = {k:10**np.array(scenarios[k]) for k in scenarios.keys()}

seed = 2
num_replicates = 100
mutation_rate = 1e-8
segment_length = 2e6
time_rate = 0.06
tmax = 130000
num_time_windows = 21
num_sample = 50
population_time = [(np.exp(np.log(1 + time_rate * tmax) * i /
                  (num_time_windows - 1)) - 1) / time_rate for i in
                  range(num_time_windows)]

snp_data = {}
for k in scenarios.keys():
    print(f'Simulating scenario "{k}"')
    population_size = scenarios[k]
    recombination_rate = np.random.uniform(low=1e-9, high=1e-8)
    snp_data[k] = simulate_scenario(population_size, population_time, seed, num_replicates, mutation_rate, recombination_rate, segment_length, num_sample)
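
Each replicate returned by `simulate_scenario` is a single 2-D array: the first row holds the distance from each SNP to the previous one, and the remaining rows hold one haplotype per sample. The sketch below works on a hand-made array rather than real simulation output, and shows how to split such an array and recover absolute positions with a cumulative sum. The helper name `split_replicate` is illustrative.

```python
from itertools import accumulate

def split_replicate(rep):
    """Split one replicate into absolute SNP positions and the haplotype rows."""
    distances = rep[0]    # row 0: distance to the previous SNP (first entry: distance from position 0)
    haplotypes = rep[1:]  # remaining rows: one row per sampled haplotype
    positions = list(accumulate(int(d) for d in distances))
    return positions, haplotypes

# Hand-made example with 3 haplotypes and 4 SNPs at positions 120, 450, 455 and 1300.
toy_rep = [
    [120, 330, 5, 845],  # inter-SNP distances
    [0, 1, 0, 1],
    [1, 1, 0, 0],
    [0, 0, 1, 0],
]
positions, haplotypes = split_replicate(toy_rep)
print(positions)        # [120, 450, 455, 1300]
print(len(haplotypes))  # 3
```

The same helper accepts NumPy arrays, so after the simulation cell has run it could be applied as `split_replicate(snp_data["Medium"][0])`.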

In [None]:
scenarios

In [None]:
snp_data

### 3.2 Compute summary statistics

In [None]:
import allel
import numpy as np
import pandas as pd


def LD(haplotype, pos_vec, size_chr, circular=True, distance_bins=None):
    if distance_bins is None or isinstance(distance_bins, int):
        if isinstance(distance_bins, int):
            n_bins = distance_bins - 1
        else:
            n_bins = 19
        if circular:
            distance_bins = np.logspace(2, np.log10(size_chr // 2), n_bins)
            distance_bins = np.insert(distance_bins, 0, [0])
        else:
            distance_bins = np.logspace(2, np.log10(size_chr), n_bins)
            distance_bins = np.insert(distance_bins, 0, [0])

    # Iterate through gap sizes
    n_SNP, n_samples = haplotype.shape
    gaps = (2 ** np.arange(0, np.log2(n_SNP), 1)).astype(int)

    # Initialize lists to store selected SNP pairs and LD values
    selected_snps = []
    for gap in gaps:
        snps = np.arange(0, n_SNP, gap) + np.random.randint(0, (n_SNP - 1) % gap + 1)
        # adding a random start (+1 because the upper bound of randint is exclusive)

        # non overlapping contiguous pairs
        # snps=[ 196, 1220, 2244] becomes
        # snp_pairs=[(196, 1220), (1221, 2245)]
        snp_pairs = np.unique([((snps[i] + i) % n_SNP, (snps[i + 1] + i) % n_SNP) for i in range(len(snps) - 1)],
                              axis=0)

        # If we don't have enough pairs (typically when the gap is large), we add randomly rotated pairs until we
        # have more than min(300, max_value) of them.

        if not circular:
            snp_pairs = snp_pairs[snp_pairs[:, 0] < snp_pairs[:, 1]]
        last_pair = snp_pairs[-1]

        if circular:
            max_value = n_SNP - 1
        else:
            max_value = n_SNP - gap - 1

        while len(snp_pairs) <= min(300, max_value):
            # Shift the last pair by a random offset (wrapping around n_SNP) to obtain a new candidate pair.
            random_shift = np.random.randint(1, n_SNP) % n_SNP
            new_pair = (last_pair + random_shift) % n_SNP
            snp_pairs = np.unique(np.concatenate([snp_pairs, new_pair.reshape(1, 2)]), axis=0)
            last_pair = new_pair

            if not circular:
                snp_pairs = snp_pairs[snp_pairs[:, 0] < snp_pairs[:, 1]]

        selected_snps.append(snp_pairs)

    # Functions to aggregate the values within each distance bin
    agg_bins = {"snp_dist": ["mean"], "r2": ["mean", "count", "sem"]}

    ld = pd.DataFrame()
    for i, snps_pos in enumerate(selected_snps):

        if circular:
            sd = pd.DataFrame((np.diff(pos_vec[snps_pos]) % size_chr) % (size_chr // 2),
                              columns=["snp_dist"])  # %size_chr/2 because max distance btw 2 SNP is size_chr/2
        else:
            sd = pd.DataFrame((np.diff(pos_vec[snps_pos])), columns=["snp_dist"])

        sd["dist_group"] = pd.cut(sd.snp_dist, bins=distance_bins)
        sr = [allel.rogers_huff_r(snps) ** 2 for snps in haplotype[snps_pos]]
        sd["r2"] = sr
        sd["gap_id"] = i
        ld = pd.concat([ld, sd])

    ld2 = ld.dropna().groupby("dist_group").agg(agg_bins)

    # Flatten the MultiIndex columns and rename explicitly
    ld2.columns = ['_'.join(col).strip() for col in ld2.columns.values]
    ld2 = ld2.rename(columns={
        'snp_dist_mean': 'mean_dist',
        'r2_mean': 'mean_r2',
        'r2_count': 'Count',
        'r2_sem': 'sem_r2'
    })
    # ld2 = ld2.fillna(-1)
    return ld2[['mean_dist', 'mean_r2', 'Count', 'sem_r2']]


def sfs(haplotype, ac):
    """
    Calculate the site frequency spectrum (SFS) from haplotype data and allele counts.

    Parameters
    ----------
    haplotype (numpy.ndarray): The haplotype matrix where rows represent variants and columns represent individuals.
    ac (numpy.ndarray): Allele count array where each entry represents the count of the derived allele at a site.

    Returns
    -------
    pandas.DataFrame: DataFrame containing the SFS. Each row corresponds to a frequency (number of individuals),
    with the corresponding count of SNPs that have that frequency.

    """
    nindiv = haplotype.shape[1]
    tmp_df = pd.DataFrame({"N_indiv": range(1, nindiv)})

    # getting unfolded sfs
    df_sfs = pd.DataFrame(allel.sfs(ac.T[1]), columns=["count_SNP"])
    df_sfs.index.name = "N_indiv"
    df_sfs.reset_index(inplace=True)
    df_sfs = df_sfs.merge(tmp_df, on="N_indiv", how="right").fillna(0).astype(int)

    return df_sfs

def compute_sumstat(snp_data):
    output=[]
    scenario_data = {}
    for scenario in snp_data:
        for rep in snp_data[scenario]:
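            # Row 0 holds the inter-SNP distances; the remaining rows are haplotypes.
            # Keep at most 50 haplotypes and the first 400 SNPs.
            snp = rep[1:][:50,:400]
            pos = rep[0][:400]

            # If the values are not sorted, treat them as distances and convert to absolute positions.
            if any(np.diff(pos) < 0):
                pos = np.cumsum(pos)
            # Positions given as fractions of the segment are rescaled to base pairs.
            if pos.max() <= 1:
                pos = (pos * 2e6).round().astype(int)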

            haplotype = allel.HaplotypeArray(snp.T)
            allel_count = haplotype.count_alleles()
        
            afs = sfs(haplotype, allel_count)
            afs = afs.set_index('N_indiv')
            afs['scenario'] = scenario

            ld = LD(haplotype, pos, circular=False, size_chr=2e6)
            ld["scenario"] = scenario
            ld = ld.drop(columns=['sem_r2'])

            if scenario not in scenario_data:
                scenario_data[scenario] = {"afs": [], "ld": []}
            scenario_data[scenario]["afs"].append(afs)
            scenario_data[scenario]["ld"].append(ld)

    for scenario in snp_data:
        mean_afs = pd.concat(scenario_data[scenario]["afs"]).groupby("N_indiv").mean()
        mean_afs['scenario'] = scenario
        mean_afs.reset_index(inplace=True)
        mean_ld = pd.concat(scenario_data[scenario]["ld"]).groupby("dist_group").mean()
        mean_ld['scenario']=scenario
        mean_ld.reset_index(inplace=True)
        
        df_sfs = mean_afs.set_index('N_indiv')
        df_sfs_out = df_sfs.loc[df_sfs['scenario'] == scenario]
        df_sfs_out = df_sfs_out.drop(columns=['scenario'])
        df_sfs_out = df_sfs_out.stack(dropna=False)
        df_sfs_out.index = df_sfs_out.index.map('{0[1]}_{0[0]}'.format)
        df_sfs_out = df_sfs_out.to_frame().T
        df_sfs_out = df_sfs_out.set_index([[scenario]])

        df_ld_out = mean_ld.loc[np.array(mean_ld['scenario'] == scenario)]
        df_ld_out = df_ld_out.drop(columns=['scenario'])
        df_ld_out = df_ld_out.stack(dropna=False)
        df_ld_out.index = df_ld_out.index.map('{0[1]}_{0[0]}'.format)
        df_ld_out = df_ld_out.to_frame().T
        df_ld_out = df_ld_out.set_index([[scenario]])
        df = pd.merge(df_sfs_out, df_ld_out, left_index=True, right_index=True)
        output.append(df)
    return output

sumstats = compute_sumstat(snp_data)
tt = pd.concat(sumstats)
# Keep only the SFS counts and the mean r2 per distance bin.
sumstats = tt.drop(columns=[c for c in tt.columns if c.startswith(('mean_dist', 'dist', 'Count_'))])
sumstats.index.name = 'scenario'

In [None]:
sumstats

---

So far we have walked through the complete workflow, including data simulation and NPE training.
The posterior distribution can be visualized in different ways, using either functions built into sbi or your own.
...

## 4. Evaluation and visualisation

This is a free-style section! Apart from the functions provided here, please try visualizing the results yourself.
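
One possible starting point is to draw a population size trajectory as a step function over the time windows. The sketch below only prepares the coordinates, so it runs without a plotting library. It assumes sizes are given on a log10 scale and ordered from the present into the past, as in Section 3. The function name is hypothetical.

```python
def step_coordinates(log10_sizes, change_times, t_end):
    """Return (times, sizes) tracing a piecewise-constant trajectory from change_times[0] to t_end."""
    if len(log10_sizes) != len(change_times):
        raise ValueError("need one population size per change time")
    xs, ys = [], []
    # Each size holds from its own change time until the next one (or t_end for the last window).
    right_edges = list(change_times[1:]) + [t_end]
    for start, end, log_size in zip(change_times, right_edges, log10_sizes):
        size = 10 ** log_size
        xs.extend([start, end])
        ys.extend([size, size])
    return xs, ys

xs, ys = step_coordinates([3.0, 4.0, 3.5], [0, 1_000, 10_000], t_end=50_000)
print(xs)  # [0, 1000, 1000, 10000, 10000, 50000]
print(ys)
```

If matplotlib is available, the two lists can be passed directly to `plt.plot(xs, ys)`, ideally with a logarithmic y-axis (`plt.yscale("log")`). Overlaying the true trajectory of a scenario with the posterior median gives a quick visual comparison.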

In [None]:
# visualise the summary statistics for the 6 test scenarios

Thank you for following along! We hope this tutorial helps you get started with the SBI Snakemake pipeline for population genetics.

---

## 5. More fancy applications ...

- [ ] Andy's pipeline, a brief introduction of what can be done.