# Setup

To access material for this workbook please execute the two notebook cells immediately below (e.g. use the shortcut <shift>+<return>). The first cell can be skipped if you are running this notebook locally and have already installed all the necessary packages. The second cell should print out "Your notebook is ready to go!"

In [None]:
JUPYTERLITE = False
if 'pyodide_kernel' in str(get_ipython()):  # specify packages to install under JupyterLite
    JUPYTERLITE = True
    %pip install -q -r jlite-requirements.txt
elif 'google.colab' in str(get_ipython()):  # specify package location for loading in Colab
    from google.colab import drive
    drive.mount('/content/drive')
    %run /content/drive/MyDrive/GARG_workshop/Notebooks/add_module_path.py
else:  # install packages on your local machine (-q = "quiet": don't print out installation steps)
    # (NB: you can probably ignore any message about restarting the kernel)
    !pip install -q -r https://github.com/ebp-nor/GARG/raw/main/jlite/requirements.txt

In [None]:
# Load questions etc for this workbook
import ARG_workshop
workbook = ARG_workshop.Workbook1F()
display(workbook.setup)
import warnings
warnings.filterwarnings('ignore')

# Workbook 1-F: Simulations with stdpopsim II

In the previous workbook we were introduced to stdpopsim, a package for making population genetic simulations. Here, we build upon what we learned and make some more advanced models: we look at a selective sweep model, and we simulate data from a non-model organism. The sweep model we use, [msprime.SweepGenicSelection](https://tskit.dev/msprime/docs/stable/api.html#msprime.SweepGenicSelection), can be simulated directly with msprime, so we do not need to rely on SLiM in this particular instance.

## Setting up a *Pan troglodytes* model


Import the relevant libraries

In [None]:
import numpy as np
import pandas as pd
import msprime
import stdpopsim
import tskit

and get a chimpanzee demographic model which includes a ghost population. We plot the demographic history with demesdraw as before to get an overview of the model.

In [None]:
species = stdpopsim.get_species("PanTro")
model = species.get_demographic_model("BonoboGhost_4K19")
contig = species.get_contig("chr3", mutation_rate=model.mutation_rate)
sample_sizes = {"bonobo": 10, "central": 10, "western": 10}

In [None]:
import demesdraw
demesgraph = model.model.to_demes()
demesdraw.tubes(demesgraph);

We specify a chromosome length and a time $G$ ago in generations, from which point we assume populations are isolated and of constant size. Furthermore, we define two loci that are subject to selective sweeps.

In [None]:
L = 10e6  # Simulate 10 Mb
G = 4000  # A time ago in generations: we assume populations from time 0..G are isolated and of constant size

sweep_params = {
    "bonobo": {"position": L//3, "s": 0.05},
    "western": {"position": (2 * L)//3, "s": 0.1},
}

Here we define a chromosome of length 10Mb with different sweep locations and selection parameters for populations `bonobo` and `western`.
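
Before simulating, it can be useful to check roughly how long such sweeps take compared with $G$. The sketch below is a back-of-the-envelope estimate only: it assumes a deterministic logistic frequency trajectory with additive fitnesses $1$, $1 + s/2$, $1 + s$ (our assumption about how `s` is parametrized), and it uses a hypothetical population size `Ne` rather than the sizes defined in the stdpopsim model. The function name `sweep_duration` is ours, not part of msprime.

```python
import math


def sweep_duration(s, Ne):
    """Approximate generations for an allele to go from 1/(2Ne) to 1 - 1/(2Ne).

    Assumes deterministic logistic growth dx/dt = (s/2) x (1 - x); this is an
    approximation, not the stochastic trajectory that msprime simulates.
    """
    return (4 / s) * math.log(2 * Ne - 1)


Ne = 10_000  # hypothetical population size, for illustration only
G = 4000     # same value as in the cell above
for name, s in {"bonobo": 0.05, "western": 0.1}.items():
    t = sweep_duration(s, Ne)
    print(f"{name}: roughly {t:.0f} generations (shorter than G: {t < G})")
```

Under these assumptions both sweeps last well under $G$ generations, and the stronger sweep (larger `s`) is the faster one.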

## Make independent populations, some with selective sweeps

Now that we have defined sweep parameters for the `bonobo` and `western` populations, we simulate sampled individuals from each population, with sample sizes taken from `sample_sizes`. Briefly, we loop through the populations defined by the model and initialize a population if it is listed in `sample_sizes`. We use different simulation models (`SweepGenicSelection` followed by `StandardCoalescent`, or `StandardCoalescent` alone) depending on whether the population has an entry in `sweep_params`. Finally, we add the independent simulations to `independent_pop_ts` and combine the resulting tree sequence objects into `combined_ts`. Note that here we only simulate the ancestries (trees); mutations will be added on top later on.

In [None]:
# Make independent populations, some with selective sweeps
independent_pop_ts = []
for name, pop in model.model.items():
    if name in sample_sizes:
        Ne = pop.initial_size
        demog = msprime.Demography()
        demog.add_population(name=name, initial_size=Ne)
        if name in sweep_params:
            sweep_model = msprime.SweepGenicSelection(
                start_frequency=1.0 / (2 * Ne),
                end_frequency=1.0 - (1.0 / (2 * Ne)),
                dt=1/(40 * Ne),
                **sweep_params[name],
            )
            models = (sweep_model, msprime.StandardCoalescent())
            print(f"Adding {name} population to demographic model, sweep at {int(sweep_params[name]['position'])}bp, selection coefficient s={sweep_params[name]['s']}")
        else:
            models = msprime.StandardCoalescent()
            print(f"Adding {name} population to demographic model, neutral")
        independent_pop_ts.append(
            msprime.sim_ancestry(
                sample_sizes[name],
                model=models,
                demography=demog,
                recombination_rate=contig.recombination_map.slice(right=L, trim=True),
                sequence_length=L,
                end_time=G,  # Stop simulation early
                random_seed=123,
            )
        )
combined_ts = independent_pop_ts[0]
for ts in independent_pop_ts[1:]:
    combined_ts = combined_ts.union(ts, node_mapping=np.full(ts.num_nodes, tskit.NULL))

As noted, the assumption above was that populations were isolated. We used the parameter $G$ to [stop the simulation early](https://tskit.dev/msprime/docs/stable/ancestry.html#stopping-simulations-early); this means we simulate backwards $G$ generations, but the populations do not coalesce. We end by adding history backwards in time to the three populations until coalescence with the function `msprime.sim_ancestry` using the stdpopsim model specifications in `model` (which, as you will see if you print the `model.model` object, has mass migration events defined 4003 and 6202 generations ago). Finally, we add mutations with `msprime.sim_mutations`.

In [None]:
# Now recapitate: initial_state uses the population names in the combined_ts to figure out which are which
final_ts = msprime.sim_mutations(
    msprime.sim_ancestry(initial_state=combined_ts, demography=model.model, random_seed=123).simplify(),
    rate=model.mutation_rate,
    random_seed=123,
)
print(f"Simulated {final_ts.num_sites} sites for {final_ts.num_samples} genomes")

## Making sense of the simulations

Before proceeding, let's make some sanity checks to confirm that the results make sense. We expect the central population to have undergone neutral evolution, whereas bonobo and western have selective sweep signatures at different loci (at 3.33Mb and 6.67Mb, respectively). We can investigate the effect on nucleotide diversity by making windowed plots along the chromosome.

Import matplotlib, set up color schemes, and plot diversity in 50 kb windows (200 windows across the 10 Mb chromosome).

In [None]:
import matplotlib.pyplot as plt 
import matplotlib.colors as mcolors
mpl_colors = mcolors.TABLEAU_COLORS 
pop_color_map = {  # Use same colors as demesdraw
    'All samples': mpl_colors["tab:olive"],
    'bonobo': mpl_colors["tab:orange"],
    'central': mpl_colors["tab:green"],
    'western': mpl_colors["tab:red"],
}

In [None]:
num_windows = 200
ts = final_ts
tspop = {p.metadata["name"]:p.id for p in ts.populations()}
window_size = ts.sequence_length / num_windows
sample_sets = [ts.samples(), ts.samples(population=tspop["bonobo"]), ts.samples(population=tspop["central"]), ts.samples(population=tspop["western"])]
pi_win = ts.diversity(
    sample_sets=sample_sets,
    windows=np.linspace(0, ts.sequence_length, num_windows + 1)
)
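
The windows passed to `ts.diversity` above are equally sized, so a genomic position can be converted to a window index with simple arithmetic. The following is a minimal, plain-Python sketch of that bookkeeping; the helper names `window_edges` and `window_index` are hypothetical and not part of tskit.

```python
def window_edges(sequence_length, num_windows):
    """Equally spaced breakpoints, like np.linspace(0, sequence_length, num_windows + 1)."""
    return [sequence_length * i / num_windows for i in range(num_windows + 1)]


def window_index(position, sequence_length, num_windows):
    """Index of the half-open window [left, right) containing `position`.

    Only valid for 0 <= position < sequence_length.
    """
    return int(position * num_windows // sequence_length)


L = 10e6           # chromosome length used in this workbook
num_windows = 200  # number of windows used for the overview plot
edges = window_edges(L, num_windows)
print(f"window size: {edges[1] - edges[0]:.0f} bp")
print("bonobo sweep window:", window_index(L // 3, L, num_windows))
print("western sweep window:", window_index((2 * L) // 3, L, num_windows))
```

These indices are the positions along the x-axis (in units of windows) at which we expect to see the drops in diversity.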

In [None]:
for i, key in enumerate(pop_color_map):
    plt.plot(pi_win[:, i], label=key, color=pop_color_map[key])
plt.axvline(len(pi_win[:, 0]) // 3, color=pop_color_map["bonobo"], linestyle="dashed")
plt.axvline(2 * len(pi_win[:, 0]) // 3, color=pop_color_map["western"], linestyle="dashed")
plt.legend(bbox_to_anchor=(1, 1), loc="upper left")
plt.xlabel("Genomic window")
plt.ylabel("Diversity")
plt.show()

For clarity, we added the locations of sweeps (dashed lines) for both populations. There is a clear drop in diversity at both sites (and for the correct populations too!), so the simulations make intuitive sense. We could also zoom in on one of the sweep locations for increased resolution:

In [None]:
num_windows = 2_000
pi_win_zoom = ts.diversity(sample_sets=sample_sets, windows=np.linspace(0, ts.sequence_length, num_windows + 1))

In [None]:
pad = 0.03
for i, key in enumerate(pop_color_map):
    plt.plot(pi_win_zoom[:, i], label=key, color=pop_color_map[key])
plt.axvline(len(pi_win_zoom[:, 0]) // 3, color=pop_color_map["bonobo"], linestyle="dashed")
plt.axvline(2 * len(pi_win_zoom[:, 0]) // 3, color=pop_color_map["western"], linestyle="dashed")
plt.xlim(2 * num_windows / 3 - num_windows * pad, 2 * num_windows / 3 + num_windows * pad)
plt.legend(bbox_to_anchor=(1, 1), loc="upper left")
plt.xlabel("Genomic window")
plt.ylabel("Diversity")
plt.show()

Here the pattern of reduced diversity is more prominent for the western population; in particular, there is a stretch of windows where it is consistently low. For comparison, we reproduce the plot using branch-length statistics (`mode="branch"`), which are computed from the trees themselves rather than from the mutations.

In [None]:
pi_win_branch = ts.diversity(sample_sets=sample_sets, windows=np.linspace(0, ts.sequence_length, num_windows + 1), mode="branch")

In [None]:
pad = 0.05
for i, key in enumerate(pop_color_map):
    plt.plot(pi_win_branch[:, i], label=key, color=pop_color_map[key])
plt.axvline(len(pi_win_branch[:, 0]) // 3, color=pop_color_map["bonobo"], linestyle="dashed")
plt.axvline(2 * len(pi_win_branch[:, 0]) // 3, color=pop_color_map["western"], linestyle="dashed")
plt.xlim(2 * num_windows / 3 - num_windows * pad, 2 * num_windows / 3 + num_windows * pad)
plt.legend(bbox_to_anchor=(1, 1), loc="upper left")
plt.xlabel("Genomic window")
plt.ylabel("Diversity")
plt.show()

As we alluded to earlier, the plots are now smoother since the stochastic noise added by mutations has been eliminated.
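
To make the branch-mode statistic concrete, here is a small plain-Python sketch that computes branch diversity for a single hand-written tree: the average, over pairs of samples, of the total branch length separating the two samples. The tree, node times and mutation rate are hypothetical values chosen for illustration; tskit additionally averages over all trees in each window, weighted by span, which this sketch leaves out.

```python
from itertools import combinations

# Toy tree: child -> parent, plus node times in generations (hypothetical values)
parent = {0: 4, 1: 4, 2: 5, 3: 5, 4: 6, 5: 6}
time = {0: 0, 1: 0, 2: 0, 3: 0, 4: 100, 5: 300, 6: 1000}


def ancestors(u):
    """Return u followed by all of its ancestors up to the root."""
    path = [u]
    while path[-1] in parent:
        path.append(parent[path[-1]])
    return path


def mrca(u, v):
    """Most recent common ancestor of nodes u and v."""
    seen = set(ancestors(u))
    return next(a for a in ancestors(v) if a in seen)


def branch_diversity(samples):
    """Mean branch length separating two distinct samples."""
    pairs = list(combinations(samples, 2))
    total = sum(2 * time[mrca(u, v)] - time[u] - time[v] for u, v in pairs)
    return total / len(pairs)


pi_branch = branch_diversity([0, 1, 2, 3])
mu = 1e-8  # hypothetical per-site, per-generation mutation rate
print(f"branch diversity: {pi_branch:.1f} generations")
print(f"expected site diversity: {mu * pi_branch:.2e}")
```

Multiplying branch diversity by the mutation rate gives the expected site diversity, which is why the two kinds of plot have the same shape apart from the mutational noise.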

## Trees at sweep locations


Since sweeps are caused by mutations that spread quickly through a population, we would expect the samples in a sweep population to coalesce in a short period of time. Let's investigate whether this is true by plotting trees around the sweep location. We first load the necessary Python library and define CSS styles that apply the relevant colors to the populations.


In [None]:
from IPython.display import SVG

In [None]:
styles = []
for pname in list(pop_color_map)[1:]:
    popid = tspop[pname]
    p = ts.population(popid)
    color = pop_color_map[pname]
    s = f".node.p{p.id} > .sym " + "{" + f"fill: {color}" + "}"
    styles.append(s)
    print(f'"{s}" applies to nodes from population {p.metadata["name"]} (id {p.id})')
css_string = " ".join(styles)
print(f'CSS string applied:\n    "{css_string}"')

First we plot a tree at a locus that is not under selection, just to get an idea of what a general tree would look like. You can plot a tree at a specific genomic coordinate using the [ts.at](https://tskit.dev/tskit/docs/stable/python-api.html#tskit.TreeSequence.at) syntax.

In [None]:
pos = L / 10 
SVG(ts.at(pos).draw_svg(size=(1500, 500), y_axis=True, style=css_string, node_labels={}, x_axis=True, y_ticks=[0, 50_000, 100_000]))

The genealogical tree for this chromosome interval (`999629-1000346`) reflects the demographic history of the model. The bonobo population (orange) forms its own distinct clade, whereas western (red) and central (green) share a common ancestor and show little sign of admixture. From around 40,000 generations ago towards the present there are three distinct groups, all of which show similar distributions of branch lengths and coalescence patterns.

We now draw the tree at the sweep location:

In [None]:
sweep_pos = L / 3
SVG(ts.at(sweep_pos).draw_svg(size=(1500, 500), y_axis=True, style=css_string, node_labels={}, x_axis=True, y_ticks=[0, 4_000, 50_000, 100_000]))

The plot clearly shows how samples belonging to the `bonobo` population (orange) coalesce much faster than other samples. In fact, all bonobo samples have coalesced within the last 4000 generations. If we calculate the total branch length of the tree for each population separately, we see that it is much smaller for bonobo:

In [None]:
for pop in list(ts.populations())[:-1]:  # Omit ghost population
    # Subset to the samples of this population, then get the total branch length of the tree at the sweep position
    branch_length = ts.simplify(ts.samples(population=pop.id)).at(sweep_pos).total_branch_length
    print(f"{pop.metadata['name']} total branch length: {branch_length}")

With [ts.simplify](https://tskit.dev/tskit/docs/stable/python-api.html#tskit.TreeSequence.simplify) we can subset a tree sequence to a list of samples (here corresponding to the samples of each population), and then we, once again, look at the tree at the sweep position and retrieve the total branch length of that tree.

For illustration purposes, we also make a boxplot of branch lengths.

In [None]:
x = []
for pop in list(ts.populations())[:-1]:
    tstmp = ts.simplify(ts.samples(population=pop.id)).at(sweep_pos)  # Get a subtree at sweep position for samples in pop.id only
    x.append([tstmp.branch_length(u) for u in tstmp.nodes()])  # Add all branch lengths for the nodes of the subtree

In [None]:
labels = [pop.metadata["name"] for pop in list(ts.populations())[:-1]]
colors = [pop_color_map[pop.metadata["name"]] for pop in list(ts.populations())[:-1]]
fig, ax = plt.subplots()
ax.set_ylabel('branch length')
bplot = ax.boxplot(x,
                   patch_artist=True,  # fill with color
                   tick_labels=labels)  # will be used to label x-ticks
for patch, col in zip(bplot['boxes'], colors):
    patch.set_facecolor(col)
plt.show()

This hints at a way to design a test for sweeps based on branch lengths being much shorter than the background around the sweep locus.
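
As a rough illustration of that idea, the sketch below flags windows whose total branch length is much smaller than the median across all windows. It is only a toy outlier scan, not an established test: the function name `flag_short_windows`, the cutoff of 0.2 and the input values are all hypothetical.

```python
from statistics import median


def flag_short_windows(total_branch_lengths, threshold=0.2):
    """Indices of windows whose total branch length is below
    `threshold` times the median over all windows (an arbitrary cutoff)."""
    background = median(total_branch_lengths)
    return [i for i, b in enumerate(total_branch_lengths) if b < threshold * background]


# Hypothetical per-window total branch lengths (generations) for one population
lengths = [41_000, 39_500, 42_300, 6_100, 1_200, 5_800, 40_200, 38_900]
print(flag_short_windows(lengths))
```

In practice the cutoff would need to be calibrated, for instance against neutral simulations under the same demographic model.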

We can also illustrate what happens to trees as we move away from the sweep locus. The intuition is that as recombination breaks down linkage to the sweep locus, genomic segments further away will evolve neutrally and the coalescence times of the samples will increase, leading to trees with longer terminal branches and larger total branch length.

In [None]:
from IPython.display import HTML

ts = final_ts
delta = 50_000  # Look this distance left and right
poslist = np.linspace(sweep_pos - delta, sweep_pos + delta, 7)
# Trees for the bonobo samples only (the population with a sweep at sweep_pos)
tslist = [ts.simplify(ts.samples(population=tspop["bonobo"]), filter_nodes=False).at(pos) for pos in poslist]
max_time = 25_000  # Find suitable max time to rescale plots
ticks = np.linspace(0, max_time, num=6)
ticks = {t: f"{t/1_000:.0f}k" for t in ticks}
HTML("".join(
    tree.draw_svg(
        max_time=max_time,
        size=(200, 350),
        node_labels={},
        symbol_size=1,
        y_axis=True,
        y_ticks=ticks,
        x_axis=True,
        style=css_string,
        mutation_labels={},
        omit_sites=True,
        root_svg_attributes={"style":"display: inline-block"},
    )
    for tree in tslist
))

There are still a lot of short terminal branches, but clearly the height of the tree increases as we move away from the sweep locus.

## Haplotype GNNs

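We conclude this demonstration by plotting the genealogical nearest neighbours (GNN) of a single haplotype in windows along the chromosome, using the helper function `ARG_workshop.haplotype_gnn`.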

In [None]:
ts.sequence_length
n = 500
windows = np.linspace(start=0.0, stop=ts.sequence_length, num=n)
window_size = windows[1] - windows[0]

In [None]:
group_sample_sets = []
for pop in list(ts.populations())[:-1]:
    group_sample_sets.append(ts.samples(pop.id))

In [None]:
focal_node = 10
haplotype = 1
df = ARG_workshop.haplotype_gnn(ts, focal_node, windows, group_sample_sets)
df.columns = list(tspop.keys())[:-1]
df = df[df.index.get_level_values("haplotype") == (haplotype - 1)]

In [None]:
ind_id = ts.individual(ts.node(focal_node).individual).id
focal_pop = ts.population(ts.individual_populations[ind_id]).metadata["name"]

In [None]:
fig, ax = plt.subplots()
fig.set_figwidth(15)
fig.set_figheight(1)
bottom = np.zeros(df.shape[0])

for pop in df.columns:
    p = ax.bar(df.index.get_level_values("start"), df[pop], width=window_size, label=pop, bottom=bottom, color=pop_color_map[pop])
    bottom += df[pop]

ax.set_title(f"GNN plot for haplotype {haplotype}, focal node {focal_node}, individual {ind_id}, focal population '{focal_pop}'")
ax.legend(bbox_to_anchor=(1, 1), loc="upper left")

plt.show()
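
To clarify what the stacked bars represent, here is a plain-Python sketch of GNN proportions on a single toy tree. It follows the definition used by tskit: walk up from the focal sample to the nearest ancestor that has other reference samples below it, then report the fraction of those samples that belong to each reference set. We assume, without having checked its source, that `ARG_workshop.haplotype_gnn` computes the same quantity per window; the tree and the reference sets `A` and `B` are hypothetical.

```python
def samples_below(node, children, sample_set):
    """All members of `sample_set` in the subtree rooted at `node`."""
    stack, found = [node], set()
    while stack:
        u = stack.pop()
        if u in sample_set:
            found.add(u)
        stack.extend(children.get(u, []))
    return found


def gnn(focal, parent, reference_sets):
    """GNN proportions of `focal` with respect to each reference set."""
    children = {}
    for child, par in parent.items():
        children.setdefault(par, []).append(child)
    all_refs = set().union(*reference_sets.values())
    u = focal
    while u in parent:  # climb until an ancestor has other reference samples
        u = parent[u]
        neighbours = samples_below(u, children, all_refs) - {focal}
        if neighbours:
            return {
                name: len(neighbours & set(members)) / len(neighbours)
                for name, members in reference_sets.items()
            }
    return {name: 0.0 for name in reference_sets}


# Toy tree ((0,(1,3)),(2,(4,5))) written as child -> parent
parent = {1: 6, 3: 6, 0: 7, 6: 7, 4: 8, 5: 8, 2: 9, 8: 9, 7: 10, 9: 10}
reference_sets = {"A": [0, 1, 2], "B": [3, 4, 5]}
for focal in (0, 1, 2):
    print(focal, gnn(focal, parent, reference_sets))
```

A haplotype whose nearest neighbours all come from its own population gives a bar of a single color, whereas mixed colors indicate windows where its closest relatives are found in other populations.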