# parallel sweeps
#### 9 Dec 2025
Creates the DataFrames `inf_all` and `null_all` (saved at the end as `inf_all_reps.csv` and `null_all_reps.csv`).

###  notes

In [1]:
# List the working directory; `vcf_stage` is the folder simulate_one exports VCF/VCZ files into.
import os
os.listdir()

['.ipynb_checkpoints', 'with_plots.ipynb', 'vcf_stage']

### get started

#### import modules

In [2]:
import os

In [3]:
from tqdm.notebook import tqdm

In [4]:
import time
# Smoke test of the tqdm notebook progress bar: 100 iterations of 0.05 s each.
for i in tqdm(range(100)):
    time.sleep(0.05) # Simulate some work

  0%|          | 0/100 [00:00<?, ?it/s]

In [5]:
import tsinfer
import tskit
import msprime
import tsdate

import numpy as np
import pandas as pd

import datetime as dt
import time

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline 
from matplotlib.colors import TwoSlopeNorm
import matplotlib.ticker as ticker

from scipy.stats import gaussian_kde
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score 

import itertools
from itertools import combinations

from Bio import SeqIO, AlignIO

import gzip
import csv

import subprocess, gzip, tempfile
import shutil

from concurrent.futures import ProcessPoolExecutor, as_completed
import warnings

#### define insert_proxy_samples

In [6]:
import logging
logger = logging.getLogger("tsinfer")

def insert_proxy_samples(
        self,
        variant_data,
        *,
        sample_ids=None,
        epsilon=None,
        keep_ancestor_times=None,
        allow_mutation=None,  # deprecated alias
        **kwargs,
):
        """
        Take a set of samples from a :class:`.VariantData` instance and create additional
        "proxy sample ancestors" from them, returning a new :class:`.AncestorData`
        instance including both the current ancestors and the additional ancestors
        at the appropriate time points.

        This is the notebook's own copy of the method; a later cell assigns it to
        ``tsinfer.AncestorData.insert_proxy_samples`` so that ``self`` is an
        :class:`.AncestorData` instance when it is called.

        A *proxy sample ancestor* is an ancestor based upon a known sample. At
        sites used in the full inference process, the haplotype of this ancestor
        is identical to that of the sample on which it is based. The time of the
        ancestor is taken to be a fraction ``epsilon`` older than the sample on
        which it is based.

        A common use of this function is to provide ancestral nodes for anchoring
        historical samples at the correct time when matching them into a tree
        sequence during the :func:`tsinfer.match_samples` stage of inference.
        For this reason, by default, the samples chosen from ``variant_data``
        are those associated with historical (i.e. non-contemporary)
        :ref:`individuals <sec_inference_data_model_individual>`. This can be
        altered by using the ``sample_ids`` parameter.

        .. note::

            The proxy sample ancestors inserted here will correspond to extra nodes
            in the inferred tree sequence. At sites which are not used in the full
            inference process (e.g. sites unique to a single historical sample),
            these proxy sample ancestor nodes may have a different genotype from
            their corresponding sample.

        :param VariantData variant_data: The `VariantData` instance
            from which to select the samples used to create extra ancestors.
        :param list(int) sample_ids: A list of sample ids in the ``variant_data``
            instance that will be selected to create the extra ancestors. If
            ``None`` (default) select all the historical samples, i.e. those
            associated with an :ref:`sec_inference_data_model_individual` whose
            time is greater than zero. The order of ids is ignored, as are
            duplicate ids.
        :param list(float) epsilon: A list of small time increments
            determining how much older each proxy sample ancestor is than the
            corresponding sample listed in ``sample_ids``. A single value is also
            allowed, in which case it is used as the time increment for all selected
            proxy sample ancestors. If None (default) find :math:`{\\delta}t`, the
            smallest time difference between the sample times and the next
            oldest ancestor in the current :class:`.AncestorData` instance, setting
            ``epsilon`` = :math:`{\\delta}t / 100` (or, if all selected samples
            are at least as old as the oldest ancestor, take :math:`{\\delta}t`
            to be the smallest non-zero time difference between existing ancestors).
        :param bool keep_ancestor_times: If ``False`` (the default), the existing
            times of the ancestors in the current :class:`.AncestorData` instance
            may be increased so that derived states in the inserted proxy samples
            can have an ancestor with a mutation to that site above them (i.e. the
            infinite sites assumption is maintained). This is useful when sites
            times have been approximated by using their frequency. Alternatively,
            if ``keep_ancestor_times`` is ``True``, existing ancestor times are
            preserved, and inserted proxy sample ancestors are allowed to
            possess derived alleles at sites where there are no pre-existing
            mutations in older ancestors. This can lead to a de-novo mutation at a
            site that also has a mutation elsewhere (i.e. breaking the infinite sites
            assumption).
        :param bool allow_mutation: Deprecated alias for `keep_ancestor_times`.
        :param \\**kwargs: Further arguments passed to the constructor when creating
            the new :class:`AncestorData` instance which will be returned.

        :return: A new :class:`.AncestorData` object.
        :rtype: AncestorData
        :raises ValueError: If both ``allow_mutation`` and ``keep_ancestor_times``
            are given, if the sequence lengths differ, if ancestor site positions
            are missing from ``variant_data``, or if the number of ``epsilon``
            values does not match the number of (unique) sample ids.
        """
        if allow_mutation is not None:
            if keep_ancestor_times is not None:
                raise ValueError(
                    "Cannot specify both `allow_mutation` and `keep_ancestor_times`"
                )
            keep_ancestor_times = allow_mutation
        self._check_finalised()
        variant_data._check_finalised()
        if self.sequence_length != variant_data.sequence_length:
            raise ValueError("variant_data does not have the correct sequence length")
        # Boolean mask over variant_data sites: True where the site is also an
        # ancestor (inference) site in this instance.
        used_sites = np.isin(variant_data.sites_position[:], self.sites_position[:])
        if np.sum(used_sites) != self.num_sites:
            raise ValueError("Genome positions in ancestors missing from variant_data")

        if sample_ids is None:
            # Default: every sample belonging to an individual with time > 0
            sample_ids = []
            for i in variant_data.individuals():
                if i.time > 0:
                    sample_ids += i.samples
        # sort by ID and make unique for quick haplotype access
        sample_ids, unique_indices = np.unique(np.array(sample_ids), return_index=True)

        # Samples without an individual keep time 0
        sample_times = np.zeros(len(sample_ids), dtype=self.ancestors_time.dtype)
        for i, s in enumerate(sample_ids):
            sample = variant_data.sample(s)
            if sample.individual != tskit.NULL:
                sample_times[i] = variant_data.individual(sample.individual).time

        if epsilon is not None:
            epsilons = np.atleast_1d(epsilon)
            if len(epsilons) == 1:
                # all get the same epsilon
                epsilons = np.repeat(epsilons, len(sample_ids))
            else:
                if len(epsilons) != len(unique_indices):
                    raise ValueError(
                        "The number of epsilon values must equal the number of "
                        f"sample_ids ({len(sample_ids)})"
                    )
                # reorder to follow the sorted, de-duplicated sample_ids
                epsilons = epsilons[unique_indices]

        else:
            anc_times = self.ancestors_time[:][::-1]  # find ascending time order
            older_index = np.searchsorted(anc_times, sample_times, side="right")
            # Don't include times older than the oldest ancestor
            allowed = older_index < self.num_ancestors
            if np.sum(allowed) > 0:
                delta_t = anc_times[older_index[allowed]] - sample_times[allowed]
            else:
                # All samples have times equal to or older than the oldest curr ancestor
                time_diffs = np.diff(anc_times)
                delta_t = np.min(time_diffs[time_diffs > 0])
            epsilons = np.repeat(np.min(delta_t) / 100.0, len(sample_ids))

        proxy_times = sample_times + epsilons
        time_sorted_indexes = np.argsort(proxy_times)
        reverse_time_sorted_indexes = time_sorted_indexes[::-1]
        # In cases where we have more than a handful of samples to use as proxies, it is
        # inefficient to access the haplotypes out of order, so we iterate and cache
        # (caution: the haplotypes list may be quite large in this case)
        haplotypes = [
            h[1] for h in variant_data.haplotypes(
                samples=sample_ids, sites=used_sites, recode_ancestral=True
            )
        ]

        new_anc_times = self.ancestors_time[:]  # this is a copy
        if not keep_ancestor_times:
            assert np.all(np.diff(self.ancestors_time) <= 0)
            # Find the youngest (max) ancestor ID constrained by each sample haplotype
            site_ancestor = -np.ones(self.num_sites, dtype=int)
            anc_min_time = np.zeros(self.num_ancestors, dtype=self.ancestors_time.dtype)
            # If (unusually) there are multiple ancestors for the same focal site, we
            # can take the youngest
            for ancestor_id, focal_sites in enumerate(self.ancestors_focal_sites):
                site_ancestor[focal_sites] = ancestor_id
            for hap_id in time_sorted_indexes:
                derived_sites = haplotypes[hap_id] > 0
                if np.sum(derived_sites) == 0:
                    root = 0  # no derived sites, so only needs to be below the root
                    for i, focal_sites in enumerate(self.ancestors_focal_sites):
                        if len(focal_sites) > 0:
                            if i > 0:
                                root = i - 1
                            anc_min_time[root] = proxy_times[hap_id] + epsilons[hap_id]
                            break
                else:
                    max_anc_id = np.max(site_ancestor[derived_sites])  # youngest ancstr
                    if max_anc_id >= 0:
                        anc_min_time[max_anc_id] = proxy_times[hap_id] + epsilons[hap_id]
            # Go from youngest to oldest, pushing up the times of the ancestors to
            # achieve compatibility with infinite sites.
            # TODO - replace with something more efficient that uses time_diffs
            for anc_id in range(self.num_ancestors - 1, -1, -1):
                current_time = new_anc_times[anc_id]
                if anc_min_time[anc_id] > current_time:
                    new_anc_times[:(anc_id + 1)] += anc_min_time[anc_id] - current_time
            assert new_anc_times[1] > np.max(sample_times)  # root ancestor

        with self.__class__(  # Create new AncestorData instance to return
            variant_data.sites_position[:][used_sites],
            variant_data.sequence_length,
            **kwargs,
        ) as other:
            mutated_sites = set()  # To check if mutations have occurred yet
            ancestors_iter = self.ancestors()
            anc = next(ancestors_iter, None)
            # Walk proxies from oldest to youngest, interleaving the existing
            # ancestors so the output stays ordered by decreasing time.
            for i in reverse_time_sorted_indexes:
                proxy_time = proxy_times[i]
                sample_id = sample_ids[i]
                haplotype = haplotypes[i]
                derived_sites = set(np.where(haplotype > 0)[0])
                while anc is not None and new_anc_times[anc.id] > proxy_time:
                    anc_time = new_anc_times[anc.id]
                    other.add_ancestor(
                        anc.start, anc.end, anc_time, anc.focal_sites, anc.haplotype)
                    mutated_sites.update(anc.focal_sites)
                    anc = next(ancestors_iter, None)
                if not derived_sites.issubset(mutated_sites):
                    # This branch used to be guarded by `assert not keep_ancestor_times`,
                    # which raised AssertionError in exactly the documented
                    # keep_ancestor_times=True case. No assertion on the flag is
                    # made here: the situation is only reported, through the
                    # module-level `logger` rather than the root logger.
                    logger.info(
                        f"Infinite sites assumption deliberately broken: {sample_id} "
                        "contains an allele which requires a novel mutation."
                    )
                logger.debug(
                    f"Inserting proxy ancestor: sample {sample_id} at time {proxy_time}"
                )
                other.add_ancestor(
                    start=0,
                    end=self.num_sites,
                    time=proxy_time,
                    focal_sites=[],
                    haplotype=haplotype,
                )
            # Add any ancestors remaining in the current instance
            while anc is not None:
                anc_time = new_anc_times[anc.id]
                other.add_ancestor(
                    anc.start, anc.end, anc_time, anc.focal_sites, anc.haplotype,
                )
                anc = next(ancestors_iter, None)

            # Carry over the provenance of this instance, then record this call
            other.clear_provenances()
            for timestamp, record in self.provenances():
                other.add_provenance(timestamp, record)
            other.record_provenance(command="insert_proxy_samples", **kwargs)

        assert other.num_ancestors == self.num_ancestors + len(sample_ids)
        return other

### simulations

Simulations of simple evolutionary dynamics in which the genomes of a well-mixed population of fixed size {Ne} evolve under reproduction, (neutral) mutation at a fixed rate {mu} per base per generation, and homologous recombination at a rate {rr} per base per generation.

Base parameters

### Determining biologically plausible parameters
In the literature, recombination is usually expressed as the ratio p/m (recombination rate / mutation rate).  
I will vary the mutation rate across a range and choose recombination rates such that I have a grid of p/m values: 0, 0.1, 0.01, 0.001, 0.3, 1, and 10. 

Sweeps parameters

In [7]:
#later: tract lengths
#gcrs = {}
#gcrls = {} 
#times = {} 

## simulation

In [8]:
gen_time_days = 1.0   # generation time in days (going by its name); not read by years_to_gen
gen_per_year = 365.0  # generations per year, the conversion factor used below


def years_to_gen(years):
    """Convert a duration in years to generations (``years * gen_per_year``)."""
    generations = years * gen_per_year
    return generations

In [9]:
 def sim(ne, L, rr, mu, seed): 

    ts = msprime.sim_ancestry(
        samples = [
            msprime.SampleSet(20, time=0, ploidy = 1), 
            msprime.SampleSet(15, time=years_to_gen(5), ploidy = 1), 
            msprime.SampleSet(15, time=years_to_gen(10), ploidy = 1), 
            msprime.SampleSet(10, time=years_to_gen(25), ploidy = 1), 
            msprime.SampleSet(5, time=years_to_gen(50), ploidy = 1)
        ],
        sequence_length=L,
        recombination_rate=rr,    
        population_size=ne,
        random_seed=seed,
    )
    
    ts = msprime.sim_mutations(ts, rate=mu, random_seed = seed)

    dated_ts = tsdate.date(ts, 
                       mutation_rate=mu, # same mutation rate used for simulation 
                       time_units="generations", # dont want to switch this or nodes and samples will have different units
                       match_segregating_sites = True,
                       rescaling_intervals = 1
                       )
    
    samples = list(dated_ts.samples()) ######## return 

    # get ancestral states
    ancestral_states = []
    
    for site in ts.sites():
        if site.ancestral_state is None:
            ancestral_states.append("N")
            #print("N")
        else:
            ancestral_states.append(str(site.ancestral_state))
            #print(site.ancestral_state)
    
    ancestral_states = np.array(ancestral_states) ######### return

    return dated_ts, samples, ancestral_states  

In [10]:
def sim_sweep(mu_vals, pm_grid, seed=None):
    """Simulate one tree sequence for every (mu, p/m) combination, serially.

    Uses the module-level ``ne`` and ``L`` for population size and sequence length.

    :param mu_vals: Mutation rates to sweep over.
    :param pm_grid: p/m ratios; the recombination rate is ``mu * pm``.
    :param seed: Optional base seed. Combination ``k`` (0-based) is simulated
        with ``seed + k``; ``None`` (default) passes ``None`` on to ``sim``.
    :return: ``(counts_df, metadata_df, sims, samples_list,
        ancestral_states_list, times)`` with one entry or row per combination.
    """
    sims = []  # each dated tree sequence
    samples_list = []
    ancestral_states_list = []
    metadata = []
    counts = []
    times = []

    total = len(mu_vals) * len(pm_grid)
    count = 0

    for mu in mu_vals:
        for pm in pm_grid:
            count += 1
            rr = mu * pm  # TODO: change this so it's not 1.0000000001e-X
            # `sim` requires a seed argument; the original call omitted it
            cell_seed = None if seed is None else seed + count - 1
            dated_ts, samples, ancestral_states = sim(ne, L, rr, mu, seed=cell_seed)

            sims.append(dated_ts)
            samples_list.append(samples)
            ancestral_states_list.append(ancestral_states)
            # times of the sample nodes (previously the hard-coded slice [0:65])
            times.append(dated_ts.nodes_time[samples])

            metadata.append({"index": count-1, "rate": rr, "mu": mu})

            counts.append({"rr": rr, "mu": mu, "num_trees": dated_ts.num_trees, "diversity": dated_ts.diversity()})

            print(f"Finished inference {count}/{total}. RR: {rr}, MU: {mu}, num trees: {dated_ts.num_trees}")

    return pd.DataFrame(counts), pd.DataFrame(metadata), sims, samples_list, ancestral_states_list, times
      

## vcf export/import

In [11]:
def fmt_sci(x):
    """Format ``x`` as ``d.ddde±N`` with one leading zero of the exponent removed.

    Examples: ``1e-9 -> "1.000e-9"``, ``1e-10 -> "1.000e-10"``,
    ``0.0 -> "0.000e+0"``. Only the exponent is edited, so the mantissa sign
    of ``-0.0`` is left intact. Values that format without an exponent
    (``nan``, ``inf``) are returned unchanged.
    """
    text = f"{x:.3e}"
    mantissa, marker, exponent = text.partition("e")
    if not marker:
        return text
    sign, digits = exponent[0], exponent[1:]
    if digits.startswith("0"):
        digits = digits[1:]
    return f"{mantissa}e{sign}{digits}"

def vcz_name(prefix, mu, rr, seed):
    """Build the ``.vcf.gz.icf.vcz`` file name that identifies one simulation."""
    mu_tag = fmt_sci(mu)
    rr_tag = fmt_sci(rr)
    return f"{prefix}_mu{mu_tag}_rr{rr_tag}_seed{seed}.vcf.gz.icf.vcz"

# export
def check_zarr_store(p):
    """Return True if ``p`` is a directory that holds a Zarr group marker file.

    Two marker names are accepted because they differ between Zarr format
    versions (``.zgroup`` in v2, ``zarr.json`` in v3).
    """
    if not os.path.isdir(p):
        return False
    marker_names = (".zgroup", "zarr.json")
    return any(os.path.exists(os.path.join(p, name)) for name in marker_names)

def _run_vcf2zarr(step, src, dst, force=False):
    """Run ``vcf2zarr <step> src dst``, appending ``--force`` when requested."""
    cmd = ["vcf2zarr", step, src, dst]
    if force:
        cmd.append("--force")
    subprocess.run(cmd, check=True)


def _write_vcf_gz(ts, vcfgz_path, workdir):
    """Write ``ts`` as a VCF in ``workdir`` and compress it to ``vcfgz_path``."""
    tmp_vcf = None
    try:
        with tempfile.NamedTemporaryFile("w", delete=False, dir=workdir) as tmp:
            tmp_vcf = tmp.name
            ts.write_vcf(tmp, position_transform=lambda x: np.fmax(1, x))
        if shutil.which("bgzip"):
            subprocess.run(["bgzip", "-f", tmp_vcf], check=True)
            # bgzip makes tmp_vcf + .gz
            os.replace(tmp_vcf + ".gz", vcfgz_path)
        else:
            # fall back to plain gzip, streaming instead of reading it all at once
            with open(tmp_vcf, "rb") as fin, gzip.open(vcfgz_path, "wb") as fout:
                shutil.copyfileobj(fin, fout)
    finally:
        # never leave temporary files behind, even when a step above failed
        if tmp_vcf is not None:
            for leftover in (tmp_vcf, tmp_vcf + ".gz"):
                if os.path.exists(leftover):
                    os.remove(leftover)


def export_sim(prefix, ts, mu, rr, seed, workdir=".", force=False):
    """Export ``ts`` as VCF -> ``.vcf.gz`` -> ``.icf`` -> ``.vcz`` and return the path.

    Intermediate files found in ``workdir`` are reused, so only the missing
    stages run. An existing ``.vcz`` is returned as-is when it passes
    ``check_zarr_store``; otherwise it is removed and re-encoded.

    :param prefix: File-name prefix (see ``vcz_name``).
    :param ts: Tree sequence to export; only written out when no ``.vcf.gz``
        or ``.icf`` for these parameters exists yet.
    :param mu, rr, seed: Parameters embedded in the file name.
    :param workdir: Directory for all files; created if missing.
    :param force: When True, ``--force`` is passed to every ``vcf2zarr`` call.
    :return: Absolute path of the ``.vcz`` store.
    :raises subprocess.CalledProcessError: If ``bgzip`` or ``vcf2zarr`` fails.
    """
    os.makedirs(workdir, exist_ok=True)

    base = os.path.join(workdir, vcz_name(prefix, mu, rr, seed))  # ...vcf.gz.icf.vcz
    # Strip suffixes by length rather than str.replace, which would also alter a
    # workdir or prefix that happens to contain ".vcz" / ".icf.vcz".
    vcfgz_path = base[: -len(".icf.vcz")]                          # ...vcf.gz
    icf_path = base[: -len(".vcz")]                                # ...vcf.gz.icf

    if os.path.exists(base):
        if check_zarr_store(base):
            return os.path.abspath(base)
        # Not a usable store: clear it so it can be re-encoded below.
        # (An unconditional early return used to make this check unreachable.)
        if os.path.isdir(base):
            shutil.rmtree(base, ignore_errors=True)
        else:
            os.remove(base)

    if not os.path.exists(icf_path):
        if not os.path.exists(vcfgz_path):
            _write_vcf_gz(ts, vcfgz_path, workdir)
        _run_vcf2zarr("explode", vcfgz_path, icf_path, force)

    _run_vcf2zarr("encode", icf_path, base, force)
    return os.path.abspath(base)


# import vcz -> variant data object
def import_sim(vcz_path, ancestral_states, individuals_time):
    """Open the ``.vcz`` store at ``vcz_path`` as a ``tsinfer.VariantData``.

    ``ancestral_states`` and ``individuals_time`` are converted with
    ``np.asarray`` and passed as ``ancestral_state`` / ``individuals_time``.
    """
    import tsinfer  # imported locally, as in the original cell

    states = np.asarray(ancestral_states)
    times = np.asarray(individuals_time)
    return tsinfer.VariantData(vcz_path, ancestral_state=states, individuals_time=times)

## inference

In [12]:
# Monkey-patch: install this notebook's insert_proxy_samples on tsinfer.AncestorData,
# so that `anc.insert_proxy_samples(vdata)` in run_proxy calls the version defined here.
tsinfer.AncestorData.insert_proxy_samples = insert_proxy_samples

In [13]:
def run_proxy(vdata, mu, rr_value=None, mm_value=None):
    """Infer and date a tree sequence from ``vdata`` using proxy sample ancestors.

    Pipeline: generate ancestors -> insert proxy samples -> match ancestors ->
    match samples -> ``tsdate.preprocess_ts`` -> ``tsdate.date``.

    :param vdata: Variant data object to infer from.
    :param mu: Mutation rate passed to ``tsdate.date``.
    :param rr_value: Recombination rate for matching, or ``None``.
    :param mm_value: Mismatch ratio; ignored (treated as ``None``) when
        ``rr_value`` is ``None``.
    :return: ``(rr, mm, dated_ts)``: the rate and ratio actually used, plus the
        dated tree sequence.
    """
    rr = rr_value
    mm = None if rr_value is None else mm_value

    ancestors = tsinfer.generate_ancestors(vdata)
    ancestors_with_proxies = ancestors.insert_proxy_samples(vdata)
    ancestors_ts = tsinfer.match_ancestors(
        vdata, ancestors_with_proxies, recombination_rate=rr, mismatch_ratio=mm
    )
    # NOTE(review): mismatch_ratio is not passed to match_samples, unlike
    # match_ancestors above; confirm this is intended.
    inferred_ts = tsinfer.match_samples(
        vdata, ancestors_ts, force_sample_times=True, recombination_rate=rr
    )
    prepared_ts = tsdate.preprocess_ts(inferred_ts, erase_flanks=False)

    dating_options = dict(
        mutation_rate=mu,
        time_units="generations",
        match_segregating_sites=True,
        rescaling_intervals=1,
    )
    dated_ts = tsdate.date(prepared_ts, **dating_options)
    return rr, mm, dated_ts
     

### get tree width

In [14]:
def get_intervals(dated_ts):
    """Return a DataFrame with one row per tree: ``tree_index``, ``left``, ``right``."""

    def _as_row(tree):
        left, right = tree.interval
        return {"tree_index": tree.index, "left": left, "right": right}

    return pd.DataFrame([_as_row(tree) for tree in dated_ts.trees()])


### get mrcas

Returns a dataframe of pairwise MRCAs for every combination of samples, for each tree within the simulated tree sequence.

In [15]:
def get_sims_times(dated_ts, samples):
    """Pairwise TMRCA of every sample pair in every tree of ``dated_ts``.

    The tree sequence is traversed once with ``trees()`` instead of seeking with
    ``at_index`` four times per (pair, tree). Row order is unchanged: pairs in
    ``combinations`` order, and within each pair trees in index order.

    :param dated_ts: Tree sequence to read.
    :param samples: Sample node ids to pair up.
    :return: DataFrame with columns ``index`` (tree index), ``sample_a``,
        ``sample_b``, ``mrca``, ``width``, ``left``, ``right``,
        ``num_trees_sim``. The columns are present even when there are no rows.
    """
    columns = ["index", "sample_a", "sample_b", "mrca", "width", "left", "right", "num_trees_sim"]
    pairs = list(combinations(samples, 2))
    num_trees = dated_ts.num_trees

    # One pass over the trees: (tree index, left, right, tmrca for each pair)
    per_tree = []
    for tree in dated_ts.trees():
        left, right = tree.interval.left, tree.interval.right
        tmrcas = [tree.tmrca(a, b) for a, b in pairs]
        per_tree.append((tree.index, left, right, tmrcas))

    rows = [
        {
            "index": tree_index,
            "sample_a": a,
            "sample_b": b,
            "mrca": tmrcas[k],
            "width": right - left,
            "left": left,
            "right": right,
            "num_trees_sim": num_trees,
        }
        for k, (a, b) in enumerate(pairs)
        for tree_index, left, right, tmrcas in per_tree
    ]
    return pd.DataFrame(rows, columns=columns)

In [16]:
def get_res_times(dated_ts, samples):
    """Pairwise TMRCA of every sample pair in every non-flanking tree.

    The first and last trees are skipped (flanking trees are not informative
    and must be trimmed). The tree sequence is traversed once with ``trees()``
    instead of seeking with ``at_index`` four times per (pair, tree). Row order
    is unchanged: pairs in ``combinations`` order, then trees in index order.

    :param dated_ts: Tree sequence to read.
    :param samples: Sample node ids to pair up.
    :return: DataFrame with columns ``index`` (tree index), ``sample_a``,
        ``sample_b``, ``mrca``, ``width``, ``left``, ``right``, ``num_trees``.
        The columns are present even when there are no rows (fewer than three
        trees), so callers can still select them.
    """
    columns = ["index", "sample_a", "sample_b", "mrca", "width", "left", "right", "num_trees"]
    pairs = list(combinations(samples, 2))
    num_trees = dated_ts.num_trees
    last_index = num_trees - 1

    # One pass over the interior trees: (tree index, left, right, tmrca per pair)
    per_tree = []
    for tree in dated_ts.trees():
        if not (0 < tree.index < last_index):
            continue
        left, right = tree.interval.left, tree.interval.right
        tmrcas = [tree.tmrca(a, b) for a, b in pairs]
        per_tree.append((tree.index, left, right, tmrcas))

    rows = [
        {
            "index": tree_index,
            "sample_a": a,
            "sample_b": b,
            "mrca": tmrcas[k],
            "width": right - left,
            "left": left,
            "right": right,
            "num_trees": num_trees,
        }
        for k, (a, b) in enumerate(pairs)
        for tree_index, left, right, tmrcas in per_tree
    ]
    return pd.DataFrame(rows, columns=columns)

### varying recombination rate and mismatch_ratio during inference

Produces an array of recombination rates and mismatch ratios, then calls `run_proxy` to infer a tree sequence for each combination. Returns counts (for indexing), metadata (tree index, rr and mm values), and seqs (a list of tree sequences).

In [17]:
def rr_mm(vdata, mu):
    """Run ``run_proxy`` over a grid of recombination rates and mismatch ratios.

    Recombination rates are ``mu * pm`` for the p/m grid below; mismatch ratios
    are ``10**-3 .. 10**0``.

    :param vdata: Variant data object to infer from.
    :param mu: Mutation rate; sets the recombination rates and is passed to
        ``run_proxy`` for dating.
    :return: ``(counts_df, metadata_df, seqs)``: a frame of
        (rate, mm, num_trees), a frame of (index, rate, mm), and the list of
        inferred tree sequences in run order.
    """
    # pm_grid = np.array([1e-3, 1e-2, 0.1, 0.3, 1.0, 10.0]) # biologically plausible r/m sweep
    pm_grid = np.array([1e-3, 1e-2, 0.1, 0.3, 1.0])
    rates = list(mu * pm_grid)
    mms = [10**x for x in range(-3, 1, 1)]
    total = len(rates) * len(mms)

    seqs = []  # each inferred tree sequence
    metadata = []
    counts = []
    count = 0

    print(f"Begin inference for simulated mu = {mu}.")
    for rr_value in rates:
        for mm_value in mms:
            count += 1

            # run_proxy's second positional parameter is `mu`; the rate and the
            # ratio go in by keyword (they used to be passed one slot too early).
            rr, mm, ip = run_proxy(vdata, mu, rr_value=rr_value, mm_value=mm_value)

            seqs.append(ip)
            metadata.append({"index": count-1, "rate": rr, "mm": mm})
            counts.append({"rate": rr, "mm": mm, "num_trees": ip.num_trees})

            print(f"Finished inference {count}/{total}. RR: {rr_value}, MM: {mm_value}, num trees: {ip.num_trees}")

    return pd.DataFrame(counts), pd.DataFrame(metadata), seqs


# core functions

In [None]:
# get rr from pm grid
def rr_for(mu, pm):
    """Return the recombination rate for mutation rate ``mu`` and ratio ``pm``."""
    recombination_rate = pm * mu
    return recombination_rate

In [None]:
def make_edges(L, bin_size):
    """Return int64 positions ``0, bin_size, 2*bin_size, ...`` below ``int(L) + bin_size``."""
    stop = int(L) + bin_size
    return np.arange(0, stop, bin_size, dtype=np.int64)

def add_bins(df, positions):
    """Tag the rows of ``df`` whose ``[left, right)`` interval contains each position.

    For position number ``i`` at coordinate ``pos``, every row with
    ``left <= pos < right`` is copied and given ``bin = i`` and
    ``position = int(pos)``. Rows containing no position are dropped.

    :param df: DataFrame with ``left`` and ``right`` columns.
    :param positions: Iterable of coordinates to look up.
    :return: The tagged rows concatenated in position order. When nothing
        matches (or ``df`` is empty), an empty frame with ``df``'s columns plus
        ``bin`` and ``position`` is returned instead of raising from
        ``pd.concat``.
    """
    pieces = []

    if not df.empty:
        left, right = df["left"], df["right"]
        for i, pos in enumerate(positions):
            mask = (left <= pos) & (pos < right)   # half-open [left, right) like tskit intervals
            if mask.any():
                pieces.append(df.loc[mask].assign(bin=i, position=int(pos)))

    if not pieces:
        no_rows = np.array([], dtype=np.int64)
        return df.iloc[0:0].assign(bin=no_rows, position=no_rows)

    return pd.concat(pieces, ignore_index=True)

def align_mrcas(sim_df, inf_df, positions):
    """Bin both MRCA tables at ``positions`` and inner-join them per pair and bin.

    Columns present in both tables get the suffix ``_inf`` (from ``inf_df``)
    or ``_sim`` (from ``sim_df``).
    """
    sim_binned = add_bins(sim_df, positions)
    inf_binned = add_bins(inf_df, positions)

    join_keys = ["sample_a", "sample_b", "bin"]
    return inf_binned.merge(sim_binned, on=join_keys, suffixes=("_inf", "_sim"))

def r2_log1p(x, y):
    """Squared Pearson correlation between ``log1p(x)`` and ``log1p(y)``."""
    corr_matrix = np.corrcoef(np.log1p(x), np.log1p(y))
    pearson_r = corr_matrix[0, 1]
    return pearson_r**2

def r2_by_bin(merged):
    """Per-bin ``r2_log1p`` of ``mrca_inf`` against ``mrca_sim``, with NaN bins dropped.

    Bins keep their order of first appearance in ``merged`` (``sort=False``).
    """

    def _bin_r2(group):
        return r2_log1p(group["mrca_inf"], group["mrca_sim"])

    per_bin = merged.groupby("bin", sort=False).apply(_bin_r2)
    return per_bin.dropna()

evaluating tree sequences 

In [18]:
# random permuations of bin order 

def random_permutations(n_bins, n_reps, seed):
    """Draw ``n_reps`` random non-identity permutations of ``range(n_bins)``.

    :param n_bins: Number of bins to permute.
    :param n_reps: Number of permutations to return.
    :param seed: Seed for ``np.random.default_rng``.
    :return: List of ``n_reps`` int32 arrays. Fixed points are allowed; only
        the identity permutation is rejected.
    :raises ValueError: If ``n_reps > 0`` and ``n_bins < 2``. No non-identity
        permutation exists in that case, so the rejection loop would never end.
    """
    if n_reps <= 0:
        return []
    if n_bins < 2:
        raise ValueError(
            f"need at least 2 bins to draw a non-identity permutation, got {n_bins}"
        )

    rng = np.random.default_rng(seed)
    identity = np.arange(n_bins, dtype=np.int32)
    perms = []

    while len(perms) < n_reps:
        candidate = rng.permutation(n_bins).astype(np.int32)
        if np.any(candidate != identity):        # allow some fixed points; switch to np.all for strict
            perms.append(candidate)

    return perms

def apply_bin_perm(merged_base, perm):
    """Return a copy of ``merged_base`` with ``mrca_inf`` taken from permuted bins.

    For every row in bin ``i``, ``mrca_inf`` is replaced by the ``mrca_inf`` of
    the same sample pair in bin ``perm[i]``; it becomes NaN when that pair has
    no row in the source bin. All other columns are untouched.
    """
    pair_and_source = ["sample_a", "sample_b", "src_bin"]

    # Lookup table: (sample pair, bin) -> inferred MRCA, keyed as the source bin
    source_values = merged_base[["sample_a", "sample_b", "bin", "mrca_inf"]].rename(
        columns={"bin": "src_bin", "mrca_inf": "mrca_inf_src"}
    )
    bin_to_source = dict(enumerate(int(j) for j in perm))

    shuffled = (
        merged_base.copy()
        .assign(src_bin=lambda frame: frame["bin"].map(bin_to_source))
        .merge(source_values, on=pair_and_source, how="left", validate="many_to_one")
    )
    shuffled["mrca_inf"] = shuffled["mrca_inf_src"].to_numpy()

    return shuffled.drop(columns=["mrca_inf_src", "src_bin"])

simulate + infer one-by-one

In [18]:
# get dated ts, export VCZ
def simulate_one(mu, rr, seed, prefix="sim"):
    """Simulate one dated tree sequence and export it to a VCZ store in ``vcf_stage``.

    Uses the module-level ``ne`` and ``L``. Returns
    ``(dated_ts, samples, ancestral_states, vcz_path)``.
    """
    sim_result = sim(ne, L, rr, mu, seed=seed)  # ground truth ts
    dated_ts, samples, ancestral_states = sim_result

    vcz_path = export_sim(prefix, dated_ts, mu, rr, seed, workdir="vcf_stage")
    return dated_ts, samples, ancestral_states, vcz_path

# use simulated ts to make mrca table
def mrca_table_sim(ts_sim):
    """Pairwise MRCA table of the simulated tree sequence, with ``mrca`` renamed ``mrca_sim``."""
    sample_nodes = list(ts_sim.samples())
    pairwise = get_sims_times(ts_sim, sample_nodes)
    return pairwise.rename(columns={"mrca": "mrca_sim"})

# re-import genotypes, run inference, compute inferred MRCA
def infer_one(vcz_path, ancestral_states, individuals_time, mu, rr, seed, mm=1):
    """Re-import genotypes from ``vcz_path``, run inference, and return inferred MRCAs.

    The result is the ``get_res_times`` table of the inferred tree sequence
    with ``mrca`` renamed ``mrca_inf``. ``seed`` is accepted but not used here.
    """
    vdata = import_sim(vcz_path, ancestral_states, individuals_time)
    _, _, dated_proxy = run_proxy(vdata, mu, rr_value=rr, mm_value=mm)

    sample_nodes = list(dated_proxy.samples())
    pairwise = get_res_times(dated_proxy, sample_nodes)
    return pairwise.rename(columns={"mrca": "mrca_inf"})

for parallelizing

In [None]:
def run_cell(mu, pm, rep, seed):
    """Simulate, infer and score one (mu, pm, replicate) cell of the sweep.

    Reads the module-level ``BIN_SIZE`` and ``N_NULL``.

    :param mu: Mutation rate.
    :param pm: p/m ratio; the recombination rate is ``pm * mu``.
    :param rep: Replicate number, recorded in the output rows.
    :param seed: Seed for the simulation and for the null permutations.
    :return: ``(inf_rows, null_df)``: per-bin r2 of inferred vs. simulated
        MRCAs (``kind == "inferred"``) and per-bin r2 after permuting bins
        (``kind == "null"``, ``N_NULL`` permutations).
    """
    rr = pm * mu

    # simulate + export VCF/VCZ
    ts, samples, ancestral_states, vcz_path = simulate_one(mu, rr, seed=seed)

    # truth table from the simulated ts
    edges = make_edges(int(ts.sequence_length), BIN_SIZE)
    sim_df = mrca_table_sim(ts)

    # per-individual times for import
    sample_nodes = np.array(list(ts.samples()), dtype=int)
    individuals_time = ts.tables.nodes.time[sample_nodes]

    # inference on the re-imported genomic data
    inf_df = infer_one(vcz_path, ancestral_states, individuals_time, mu=mu, rr=rr, seed=seed, mm=1)

    # both tables carry their tree count on every row
    n_trees_sim = int(sim_df["num_trees_sim"].iloc[0])
    n_trees_inf = int(inf_df["num_trees"].iloc[0])

    # align + R2
    merged = align_mrcas(sim_df, inf_df, edges)
    r2_inf = r2_by_bin(merged)
    r2_vals = np.asarray(r2_inf.to_numpy()).ravel()
    bins = np.asarray(r2_inf.index.to_numpy(), dtype=np.int32).ravel()
    n = len(r2_vals)

    def _constant(value, dtype):
        return np.full(n, value, dtype=dtype)

    inf_rows = pd.DataFrame({
        "mu": _constant(mu, float),
        "pm": _constant(pm, float),
        "rr": _constant(rr, float),
        "rep": _constant(rep, int),
        "bin": bins,
        "r2": r2_vals,
        "kind": _constant("inferred", object),
        "n_trees_sim": _constant(n_trees_sim, int),
        "n_trees_inf": _constant(n_trees_inf, int),
    })

    # null: reuse merged; permute bins (no re-merge)
    null_frames = []
    if N_NULL > 0:
        for perm in random_permutations(n_bins=len(edges)-1, n_reps=N_NULL, seed=seed):
            r2_null = r2_by_bin(apply_bin_perm(merged, perm))
            if len(r2_null):
                null_frames.append(pd.DataFrame({
                    "mu": mu,
                    "pm": pm,
                    "rr": rr,
                    "rep": rep,
                    "bin": r2_null.index.astype(np.int32),
                    "r2": r2_null.values,
                    "kind": "null"
                }))

    if null_frames:
        null_df = pd.concat(null_frames, ignore_index=True)
    else:
        null_df = pd.DataFrame(columns=["mu","pm","rr","rep","bin","r2","kind"])

    return inf_rows, null_df

In [19]:
# config: sweep grid, replication and binning settings, read as globals by the functions above
MUS     = [1e-10, 1e-9, 1e-8, 1e-7]   # mutation rates, per-site, per-gen (1e-6 currently left out)
PM_GRID = [1e-3, 1e-2, 0.1, 0.3, 1.0, 3.0]      # p/m ratios; run_cell uses rr = pm * mu (10.0 currently left out)
BASE_SEED  = 50  # base of the per-cell seeds derived in sim_sweep_parallel
BIN_SIZE   = 100_000  # spacing of the positions built by make_edges in run_cell
N_REPS     = 5 # replicates per (mu, pm)
N_NULL     = 10 # null permutations per replicate
#seed = 50
MAX_WORKERS = max(1, (os.cpu_count() or 2) - 1)  # all CPUs but one, never fewer than 1

ne = 5000  # population size passed to sim()
L  = int(3e6)  # sequence length passed to sim()

In [20]:
def sim_sweep_parallel(mus, pm_grid, base_seed=BASE_SEED, max_workers=MAX_WORKERS):
    """Run ``run_cell`` for every (mu, pm, replicate) in a process pool.

    Reads the module-level ``N_REPS``. A failing cell is reported as a warning
    naming its parameters and then skipped; it does not abort the sweep.

    :param mus: Mutation rates to sweep over.
    :param pm_grid: p/m ratios to sweep over.
    :param base_seed: Base for the per-cell seeds (see the formula below).
    :param max_workers: Size of the process pool.
    :return: ``(inf_all, null_all)``: the concatenated inferred and null r2
        tables of all successful cells (empty DataFrames if none succeeded).
    """
    total = len(mus) * len(pm_grid) * N_REPS
    desc = f"Running {total} parameter combos"
    inf_parts, null_parts = [], []

    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        # future -> parameters of its cell, so failures can be traced back
        cell_of = {}
        for i, mu in enumerate(mus):
            for j, pm in enumerate(pm_grid):
                for rep in range(N_REPS):
                    # seeds stay distinct while rep*3 < 100 and j*100 < 10000
                    seed = base_seed + i*10000 + j*100 + rep*3
                    cell_of[ex.submit(run_cell, mu, pm, rep, seed)] = (mu, pm, rep, seed)

        with tqdm(total=total, desc=desc, ncols=100) as pbar:
            for fut in as_completed(cell_of):
                mu, pm, rep, seed = cell_of[fut]
                try:
                    inf_rows, null_rows = fut.result()
                except Exception as e:
                    warnings.warn(
                        f"worker failed for mu={mu}, pm={pm}, rep={rep}, seed={seed}: {e!r}"
                    )
                else:
                    inf_parts.append(inf_rows)
                    null_parts.append(null_rows)
                finally:
                    pbar.update(1)

    inf_all  = pd.concat(inf_parts,  ignore_index=True) if inf_parts  else pd.DataFrame()
    null_all = pd.concat(null_parts, ignore_index=True) if null_parts else pd.DataFrame()
    return inf_all, null_all

## run all in parallel

In [21]:
# run the full sweep: every (mu, pm) in MUS x PM_GRID, N_REPS replicates each
inf_all, null_all = sim_sweep_parallel(MUS, PM_GRID)

Running 2 parameter combos:   0%|                                             | 0/2 [00:00<?, ?it/s]

    Scan: 100%|██████████| 1.00/1.00 [00:01<00:00, 1.44s/files]
    Scan: 100%|██████████| 1.00/1.00 [00:01<00:00, 1.45s/files]
 Explode: 65.0vars [00:00, 218vars/s]]
 Explode: 607vars [00:00, 1.58kvars/s]
  Encode: 100%|██████████| 15.0k/15.0k [00:00<00:00, 22.1kB/s]
  Encode: 100%|██████████| 140k/140k [00:00<00:00, 185kB/s] 
Finalise: 100%|██████████| 11.0/11.0 [00:00<00:00, 67.7array/s]
Finalise: 100%|██████████| 11.0/11.0 [00:00<00:00, 67.6array/s]


In [22]:
# save the per-bin r2 rows of the inferred tree sequences (kind == "inferred")
inf_all.to_csv('inf_all_reps.csv', index=False)

In [23]:
# save the per-bin r2 rows of the bin-permuted null (kind == "null")
null_all.to_csv('null_all_reps.csv', index=False)