In [19]:
import pyslim
import tskit
import numpy as np
import tqdm

In [20]:
def fullfill_pyslim(tree_seq, recombrate, pop_size_outcrossing, nsam):
    """Recapitate a pyslim tree sequence and reduce it to ``nsam`` haplotypes.

    The steps are, in order: recapitate, keep one randomly chosen haplotype
    per individual, draw ``nsam`` of the remaining samples at random, and
    simplify with ``reduce_to_site_topology=True``.

    Args:
        tree_seq: pyslim tree sequence providing ``recapitate``
        recombrate: passed to ``recapitate`` as ``recombination_rate``
        pop_size_outcrossing: passed to ``recapitate`` as ``Ne``
        nsam: number of sample nodes kept by the random subsampling step

    Return:
        the recapitated, subsampled and simplified tree sequence
    """
    # recapitation only; no mutations are added in this function
    recapitated = tree_seq.recapitate(
        recombination_rate=recombrate,
        Ne=pop_size_outcrossing,
    )

    # one haplotype per individual, then a random subset of nsam samples
    one_hap_per_ind = sample_1perInd(recapitated)
    subsampled = random_sample_from_treeseq(one_hap_per_ind, sample_size=nsam)

    return subsampled.simplify(reduce_to_site_topology=True)

In [21]:
def sample_1perInd(my_ts, seed=None):
    """Simplify to one randomly chosen haplotype per individual

    Intended to use on pyslim treeseq. For every individual returned by
    ``my_ts.individuals()`` one of its nodes is drawn with
    ``np.random.choice`` and the tree sequence is simplified with the drawn
    nodes as samples.

    Args:
        my_ts: tree sequence from pyslim
        seed: optional seed. When given, NumPy's global random state is
            seeded with it before drawing, so the draw is reproducible.
            When None (default), the global random state is left untouched
            and any seed set earlier by the caller stays in effect.

    Returns:
        simplified tree sequence with a single haplotype per individual
    """
    # Reseed only on request: np.random.seed(None) would re-initialise the
    # global RNG from fresh entropy and silently discard a seed set upstream.
    if seed is not None:
        np.random.seed(seed)

    # one draw per individual, in iteration order of my_ts.individuals()
    chosen_nodes = np.array(
        [np.random.choice(individual.nodes) for individual in my_ts.individuals()]
    )

    return my_ts.simplify(samples=chosen_nodes)

In [22]:
def random_sample_from_treeseq(my_ts, sample_size):
    """Subsample a tree sequence to a random set of its sample nodes

    ``sample_size`` distinct sample nodes are drawn without replacement
    using NumPy's global random state; the tree sequence is then simplified
    to the drawn nodes.

    Args:
        my_ts: a tree sequence
        sample_size: how many sample nodes to keep

    Returns:
        the tree sequence simplified to the drawn sample nodes
    """
    candidates = list(my_ts.samples())
    kept = np.random.choice(candidates, size=sample_size, replace=False)
    return my_ts.simplify(samples=kept)

In [23]:
# load every input tree sequence and reduce it with fullfill_pyslim;
# each resulting tree sequence is written out as its own chromosome below
tss = [
    fullfill_pyslim(
        pyslim.load(ts_path),
        snakemake.params.recombrate,
        snakemake.params.pop_size_outcrossing,
        snakemake.params.nsam,
    )
    for ts_path in snakemake.input.tss
]

In [26]:
# Write one tab-separated row per retained segregating site:
#   chromosome id, integer position, distance to the previously visited
#   segregating site, one character per sample.
# NOTE(review): the output name `mhs` suggests the multihetsep format — confirm.
with open(snakemake.output.mhs, "w") as outf, open(snakemake.log.std, "w") as logf:
    # the 1-based index of each tree sequence serves as its chromosome id
    for chr_identifier, ts in enumerate(tss, start=1):
        previous_site_position = 1
        nmultiallelics = 0
        for variant in tqdm.tqdm(ts.variants(), total=ts.num_sites):
            # haplotypes: allele 0 -> "A", allele 1 -> "T", anything else -> "-"
            haplotypes = "".join(
                "A" if allele == 0 else "T" if allele == 1 else "-"
                for allele in variant.genotypes
            )

            # every sample shows the same symbol: not segregating, skip
            if len(set(haplotypes)) == 1:
                continue

            # num_called_sites since the last heterozygous site
            position = int(round(variant.position, 0))
            num_called_sites = position - previous_site_position
            previous_site_position = position

            # do not allow for multiallelic sites (same rounded position as
            # the previous one). `<= 0` also rejects the negative count that
            # arises when the first site rounds to position 0, because
            # previous_site_position starts at 1; a non-positive count must
            # never be written.
            if num_called_sites <= 0:
                nmultiallelics += 1
                continue

            print(chr_identifier, position, num_called_sites, haplotypes,
                sep="\t", end="\n", file=outf)

        print(f"Chr {chr_identifier}: {nmultiallelics} multiallelic sites", file=logf)