# Simulation-Based Inference (SBI) in population genetics

Welcome to this workshop on applying neural posterior estimation (NPE) in population genetics! In this notebook, we will explore how to use our Snakemake pipeline for simulation-based inference in population genetics.

## 1. Introduction

We'll walk you through:
1. A brief overview of the SBI toolbox and NPE.
2. Setting up the required environment (ideally done in advance).
3. Reading in and exploring pre-simulated data.
4. Training a posterior using the SBI toolbox.
5. Evaluating the posterior distribution (trained on a small, insufficient dataset).
6. Loading a pre-trained posterior and evaluating it.
7. Visualising the results.

   ...more ideas?

### 1.1. A Brief Overview

Neural posterior estimation (NPE), provided by the [sbi toolbox](https://github.com/sbi-dev/sbi), uses flexible neural networks to learn the posterior distribution of parameters given observations.
- It allows us to infer complex, high-dimensional parameters without evaluating or approximating the likelihood.
- The approach is especially useful for scenarios where the likelihood function is expensive or intractable, but data simulation is feasible.
  
You can visit [sbi documentation](https://sbi-dev.github.io/sbi/latest/) for more information.
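To make the NPE recipe concrete, here is a deliberately tiny sketch that runs with NumPy only. It is not how sbi works internally: the neural network is swapped for a linear-Gaussian conditional density q(θ | x) = N(a·x + b, s²), and the prior U(0, 10) and Gaussian-noise simulator are hypothetical toys rather than a population-genetic model. The steps are the same, though: draw parameters from the prior, simulate data, fit q(θ | x) to the simulated pairs by maximum likelihood, then evaluate it at an observation without ever computing the likelihood.

```python
import numpy as np

rng = np.random.default_rng(seed=1)

# Hypothetical toy prior and simulator (not a population-genetic model)
def sample_prior(n):
    return rng.uniform(0.0, 10.0, size=n)  # theta ~ U(0, 10)

def simulator(theta):
    return theta + rng.normal(0.0, 1.0, size=theta.shape)  # x | theta ~ N(theta, 1)

# 1. Simulate training pairs; the likelihood is never evaluated
theta_sim = sample_prior(20_000)
x_sim = simulator(theta_sim)

# 2. Fit q(theta | x) = N(a * x + b, s^2) by maximum likelihood. For this
#    Gaussian family, the ML estimates of (a, b) are ordinary least squares,
#    and the ML variance is the mean squared residual.
design = np.column_stack([x_sim, np.ones_like(x_sim)])
(a, b), *_ = np.linalg.lstsq(design, theta_sim, rcond=None)
s = np.sqrt(np.mean((theta_sim - (a * x_sim + b)) ** 2))

# 3. Amortized inference: condition on a new observation without retraining
def approximate_posterior(x_obs):
    """Return (mean, sd) of the fitted Gaussian q(theta | x_obs)."""
    return a * x_obs + b, s

mean, sd = approximate_posterior(5.0)
print(f"q(theta | x=5): mean={mean:.2f}, sd={sd:.2f}")
```

Because q(θ | x) is fitted once for all x, any new observation can be conditioned on immediately; this is what "amortized" means in the amortized workflows mentioned below. A single Gaussian is too rigid in general (for example near the prior bounds), which is why sbi uses flexible density estimators such as normalizing flows.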

Based on sbi, our [Snakemake](https://snakemake.readthedocs.io/en/stable/) pipeline provides a framework for simulation-based inference in population genetics using [msprime](https://tskit.dev/msprime/docs/stable/quickstart.html). It automates data simulation (e.g., tree sequences), training of neural posterior estimators (NPEs), and plotting/visualization of inferred parameters. 

Three different workflows are provided: an amortized msprime workflow, an amortized dadi workflow, and a sequential msprime workflow. Configuration files control the number of simulations, model details, and training settings, making the workflow flexible for various population genetic scenarios.
For more information on this pipeline, please visit our [GitHub repository](https://github.com/kr-colab/popgensbi_snakemake).

- [ ] How should we present the pipeline -- DAG?

### 1.2. Prerequisites

Before we begin, ensure the following:
1. **Operating System**: Linux/macOS/Windows (with WSL2 or an equivalent environment).
2. **Hardware**: (do we want to keep everything on CPU?)
    - CPU with at least X cores (recommended).
    - GPU (optional but recommended) for faster training with PyTorch.
3. **Software**:
    - Python 3.9+ and [sbi 0.22.0](https://github.com/sbi-dev/sbi/releases/tag/v0.22.0).
    - [conda](https://docs.conda.io/en/latest/) (or `venv`) for environment management.
    - Required Python libraries for this tutorial ([requirements](https://github.com/kr-colab/popgensbi_snakemake/blob/main/requirements.yaml)).

#### Environment Setup

To run this notebook, please follow these steps:
1. Install [conda](https://docs.conda.io/en/latest/miniconda.html) if you haven’t already.
2. Clone the repository and move into it: `git clone https://github.com/kr-colab/popgensbi_snakemake.git && cd popgensbi_snakemake`
3. Create the environment: `conda env create -f requirements.yaml`
4. Activate the environment: `conda activate popgensbi_env`
5. Launch Jupyter notebook: `jupyter notebook`.
6. In the notebook, select the "popgensbi" kernel if prompted.

### 1.3. Environment Test

- [ ] Here should be a short test block

```python
# Are you ready to go?
import importlib
import sys

print(f"Python {sys.version.split()[0]}")

# List of critical packages we expect
required_packages = ["snakemake", "msprime", "dadi", "sbi", "torch"]
missing_packages = []

for pkg in required_packages:
    try:
        module = importlib.import_module(pkg)
    except ImportError:
        missing_packages.append(pkg)
    else:
        # Not every package exposes __version__, so fall back gracefully
        print(f"  {pkg}: {getattr(module, '__version__', 'version unknown')}")

if missing_packages:
    print("WARNING: The following packages are missing:", missing_packages)
    print("Please install them or switch to the conda environment that has them.")
else:
    print("All required packages found. Environment looks good!")
```




---

## 2. Explore demographic inference: Simulated data and NPE

Here we’ll load some pre-generated data in TSV format and explore it briefly. We will then infer historical effective population sizes (Ne) from summary statistics. Train your posterior quickly and play with it!

```python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Paths to the pre-simulated data (placeholders: fill in the actual files)
path_population_sizes = '... .tsv'
path_summary_statistics = '.tsv'

# Read in the data; rows are assumed to correspond to the same simulations
df_pop = pd.read_csv(path_population_sizes, sep='\t')    # parameters (Ne)
df_sum = pd.read_csv(path_summary_statistics, sep='\t')  # summary statistics

print(f"Shape of Ne: {df_pop.shape}")
print(f"Shape of summary statistics: {df_sum.shape}")
```

```python
df_pop.head()
```

```python
df_sum.head()
```
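The evaluation cells in this section refer to a testing set, which should be kept out of training so that the MSE is not optimistically biased. Below is a minimal sketch of a random split, assuming each row of `df_pop` and `df_sum` corresponds to the same simulation; the helper name `split_train_test` and the 10% test fraction are hypothetical choices.

```python
import numpy as np

def split_train_test(n_simulations, test_fraction=0.1, seed=0):
    """Return shuffled, non-overlapping (train_idx, test_idx) index arrays."""
    rng = np.random.default_rng(seed)
    permuted = rng.permutation(n_simulations)
    n_test = max(1, int(round(test_fraction * n_simulations)))
    return permuted[n_test:], permuted[:n_test]

train_idx, test_idx = split_train_test(1000, test_fraction=0.1)
print(len(train_idx), len(test_idx))  # 900 100
```

The indices can then select rows, e.g. `df_pop.iloc[train_idx]` for training and `df_pop.iloc[test_idx]` for testing (likewise for `df_sum`).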

```python
import torch
from sbi.inference import SNPE

# Convert to torch tensors (placeholder: select the Ne columns of df_pop
# and the summary-statistic columns of df_sum)
theta = torch.tensor(df_pop[...].values, dtype=torch.float32)  # parameters
x = torch.tensor(df_sum[...].values, dtype=torch.float32)      # observations

# Usually you'd define a prior (e.g. sbi.utils.BoxUniform) and pass it here.
# With prior=None, sbi falls back to an empirical prior built from theta.
inference = SNPE(prior=None)

# Train the posterior (this can take a while, especially on CPU)
density_estimator = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior(density_estimator)
```

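```python
def compute_mse(true_params, inferred_params):
    """Mean squared error between the true parameters and a point estimate."""
    return ((true_params - inferred_params) ** 2).mean().item()

# Pick one simulation from the testing set (placeholders: fill in held-out data)
test_index = 0
observed_x = ...   # testing set: summary statistics of simulation `test_index`
true_params = ...  # testing set: the Ne values that generated observed_x

# Sample from the posterior conditioned on the observation
with torch.no_grad():
    inferred_samples = posterior.sample((1000,), x=observed_x)  # get 1000 samples
inferred_mean = inferred_samples.mean(dim=0)  # posterior mean as point estimate

mse_value = compute_mse(true_params, inferred_mean)
print(f"MSE for test index {test_index} = {mse_value}")
```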
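Evaluating a single test index gives a noisy picture. The sketch below, assuming posterior draws for many test simulations have been stacked into a NumPy array, reports the per-parameter MSE of the posterior mean and the empirical coverage of central credible intervals. The function name `evaluate_posteriors` and the 90% default level are hypothetical choices; this is a simple coverage check, not sbi's own calibration tooling.

```python
import numpy as np

def evaluate_posteriors(samples, true_params, level=0.9):
    """Posterior-mean MSE and empirical coverage of central credible intervals.

    samples:     array of shape (n_test, n_samples, n_params), posterior draws
                 for each test simulation
    true_params: array of shape (n_test, n_params), the parameters that
                 generated each test observation
    Returns (mse, coverage), each an array of shape (n_params,).
    """
    samples = np.asarray(samples)
    true_params = np.asarray(true_params)

    # MSE of the posterior mean, averaged over test simulations
    posterior_mean = samples.mean(axis=1)
    mse = ((posterior_mean - true_params) ** 2).mean(axis=0)

    # Central interval, e.g. the 5% and 95% quantiles for level=0.9
    alpha = (1.0 - level) / 2.0
    lower = np.quantile(samples, alpha, axis=1)
    upper = np.quantile(samples, 1.0 - alpha, axis=1)
    inside = (true_params >= lower) & (true_params <= upper)
    coverage = inside.mean(axis=0)
    return mse, coverage
```

With an sbi posterior, the input could be built as `np.stack([posterior.sample((1000,), x=x_test[i]).numpy() for i in range(len(x_test))])`, where `x_test` (a hypothetical name) holds the held-out summary statistics. For a well-calibrated posterior, coverage should be close to `level`; values well below it suggest an overconfident posterior, values well above it an underconfident one.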


```python
import pickle

import torch

posterior_path = 'pretrained_posterior.pkl'

# Load a posterior trained in advance (only unpickle files you trust)
with open(posterior_path, 'rb') as f:
    pretrained_posterior = pickle.load(f)

# Now we can do inference with the loaded posterior.
# Note: this index comes from the training data; use a held-out simulation for a fair test.
test_index = 1
observed_x = x[test_index].unsqueeze(0)  # add a batch dimension
true_params = theta[test_index]

with torch.no_grad():
    inferred_samples = pretrained_posterior.sample((1000,), x=observed_x)
inferred_mean = inferred_samples.mean(dim=0)

mse_value = compute_mse(true_params, inferred_mean)  # defined in the previous cell
print(f"Using the pre-trained posterior, MSE = {mse_value}")
```


---

## 3. Extended scenarios

- [ ] We can let people run the Snakemake pipeline from this step, then use this notebook to work with the posterior.
- [ ] Or, we can simply try out all the steps here in the notebook, just to get people familiar with the general workflow.
- [ ] A pre-trained posterior with testing data will be provided anyway, in case something goes wrong.

### 3.1 Customize the prior for your scenario

- [ ] make the prior more flexible
- [ ] define a simulator
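One way to make the prior more flexible is to let Ne vary across several time epochs and to sample it on a log scale, since plausible population sizes span orders of magnitude. A minimal NumPy sketch follows; the function name, the number of epochs, and the bounds of 1e2 to 1e5 are hypothetical placeholders, not the pipeline's settings. In sbi, an equivalent prior could be a `BoxUniform` over log10(Ne), with the simulator exponentiating the parameters.

```python
import numpy as np

def sample_log_uniform_ne(n_draws, n_epochs, low=1e2, high=1e5, seed=None):
    """Draw piecewise-constant Ne histories with a log-uniform prior per epoch.

    Each epoch's effective population size is drawn independently and
    uniformly on the log10 scale between `low` and `high` (hypothetical bounds).
    Returns an array of shape (n_draws, n_epochs).
    """
    rng = np.random.default_rng(seed)
    log10_ne = rng.uniform(np.log10(low), np.log10(high), size=(n_draws, n_epochs))
    return 10.0 ** log10_ne

ne_histories = sample_log_uniform_ne(5, n_epochs=4, seed=0)
print(ne_histories.round())
```

Sampling on the log scale puts equal prior mass on each order of magnitude; a uniform prior on Ne itself would place most draws near the upper bound.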

### 3.2 Compute summary statistics
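A minimal sketch of a few classic summary statistics, computed from a biallelic genotype matrix with sites as rows and sampled haplotypes as columns (the layout returned by tskit's `TreeSequence.genotype_matrix()`). The function name and the choice of statistics are hypothetical, not the pipeline's actual feature set. In practice, tskit already provides `ts.allele_frequency_spectrum()`, `ts.diversity()`, `ts.segregating_sites()` and `ts.Tajimas_D()`; note that tskit's `diversity()` and `segregating_sites()` are normalised by sequence length by default, whereas the sketch below sums over sites.

```python
import numpy as np

def summary_statistics(genotypes):
    """Simple summary statistics from a biallelic 0/1 genotype matrix.

    genotypes: array of shape (n_sites, n_samples), 1 = derived allele.
    Returns the unfolded site frequency spectrum (SFS), the number of
    segregating sites S, nucleotide diversity pi (summed over sites),
    Watterson's theta and Tajima's D.
    """
    genotypes = np.asarray(genotypes)
    n = genotypes.shape[1]
    counts = genotypes.sum(axis=1)              # derived-allele count per site
    sfs = np.bincount(counts, minlength=n + 1)  # sfs[k] = sites with k derived copies
    S = int(np.count_nonzero((counts > 0) & (counts < n)))

    # pi: mean number of pairwise differences, summed over sites
    pi = float(np.sum(2.0 * counts * (n - counts)) / (n * (n - 1)))

    i = np.arange(1, n)
    a1, a2 = np.sum(1.0 / i), np.sum(1.0 / i**2)
    theta_w = S / a1

    # Tajima's D with its standard variance constants
    b1 = (n + 1) / (3.0 * (n - 1))
    b2 = 2.0 * (n**2 + n + 3) / (9.0 * n * (n - 1))
    c1 = b1 - 1.0 / a1
    c2 = b2 - (n + 2) / (a1 * n) + a2 / a1**2
    e1, e2 = c1 / a1, c2 / (a1**2 + a2)
    var_d = e1 * S + e2 * S * (S - 1)
    tajimas_d = (pi - theta_w) / np.sqrt(var_d) if S > 0 else float("nan")

    return {"sfs": sfs, "S": S, "pi": pi, "theta_w": theta_w, "tajimas_d": tajimas_d}
```

Concatenating such statistics into one vector per simulation gives the `x` that NPE is trained on.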
### 3.3 Customize the neural network architecture / work with an embedding network
### 3.4 NPE training

---

So far we have walked through the complete workflow, including data simulation and NPE training.
There are different ways to visualize the posterior distribution, using either functions built into sbi or self-defined ones.
...
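As an example of a self-defined visualisation, here is a minimal matplotlib sketch that draws a histogram of each parameter's marginal posterior and, optionally, a dashed line at the true value. sbi's built-in alternative is `pairplot` from `sbi.analysis`, which also shows pairwise marginals. The function name `plot_marginals` is hypothetical.

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_marginals(samples, true_params=None, labels=None, bins=50):
    """Histogram of each parameter's posterior marginal, with optional true values.

    samples: array of shape (n_samples, n_params) drawn from the posterior.
    labels:  optional list of parameter names.
    """
    samples = np.asarray(samples)
    n_params = samples.shape[1]
    labels = labels or [f"param {j}" for j in range(n_params)]
    fig, axes = plt.subplots(1, n_params, figsize=(3 * n_params, 3), squeeze=False)
    for j, ax in enumerate(axes[0]):
        ax.hist(samples[:, j], bins=bins, density=True, alpha=0.7)
        if true_params is not None:
            ax.axvline(true_params[j], color="red", linestyle="--", label="true")
        ax.set_xlabel(labels[j])
    axes[0, 0].set_ylabel("density")
    fig.tight_layout()
    return fig
```

Samples from an sbi posterior are torch tensors, so they can be passed after conversion, e.g. `plot_marginals(inferred_samples.numpy(), true_params=true_params.numpy())`.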

## 4. Evaluation and visualisation

Thank you for following along! We hope this tutorial helps you get started with the SBI Snakemake pipeline for population genetics.