In [None]:
import csv
import io
from itertools import product
from pathlib import Path

import palettable
import demes
import demesdraw
import matplotlib as mpl
import matplotlib.pyplot as plt
import moments
import msprime
import numpy as np
import polars as pl
import tskit
from spatial import ld_decay, ld_decay_two_way

In [None]:
def midpoint(bins):
    """Return the midpoint of every pair of consecutive entries in ``bins``.

    The result has one element fewer than ``bins``.
    """
    upper, lower = bins[1:], bins[:-1]
    return (upper + lower) / 2

In [None]:
def simpson(edge, mid):
    """Combine edge and midpoint values with Simpson weights (1, 4, 1) / 6.

    ``edge`` holds values at the bin edges and ``mid`` values at the bin
    midpoints, so ``mid`` must have one element fewer than ``edge``. One
    weighted value per bin is returned.
    """
    left = edge[:-1]
    right = edge[1:]
    return (left + 4 * mid + right) / 6


def gather_moments_data_demog(rho, theta, bins, demog, sampling_time):
    """Compute per-bin expected D2 and pi2 between demes A and B with moments.

    ``moments.Demes.LD`` is evaluated twice, once at ``rho * bins`` (the bin
    edges) and once at ``rho * midpoint(bins)`` (the bin midpoints), with both
    demes sampled at ``sampling_time``. The edge and midpoint values are then
    combined per bin with ``simpson``.

    Returns:
        ``(D2, pi2)`` where ``D2`` is built from the ``DD_0_1`` statistic and
        ``pi2`` from ``sqrt(pi2_0_0_0_0) * sqrt(pi2_1_1_1_1)``.
    """

    def _expected_ld(rho_values):
        # Same model and sampling scheme for both evaluations; only rho differs.
        return moments.Demes.LD(
            demog,
            sampled_demes=["A", "B"],
            sample_times=[sampling_time, sampling_time],
            rho=rho_values,
            theta=theta,
        )

    def _pick_stats(result):
        # All but the last entry of the result are stacked into a 2-D table.
        # NOTE(review): presumably one row per rho value, with the last entry
        # holding the non-LD statistics - confirm against the moments docs.
        table = np.vstack(result[:-1])
        stat_names = result.names()[0]
        return tuple(
            table[:, stat_names.index(key)]
            for key in ("DD_0_1", "pi2_0_0_0_0", "pi2_1_1_1_1")
        )

    edges_result = _expected_ld(rho * bins)
    mids_result = _expected_ld(rho * midpoint(bins))

    mids_D2_cross, mids_pi2_1, mids_pi2_2 = _pick_stats(mids_result)
    edges_D2_cross, edges_pi2_1, edges_pi2_2 = _pick_stats(edges_result)

    D2 = simpson(edges_D2_cross, mids_D2_cross)
    pi2_at_edges = np.sqrt(edges_pi2_1) * np.sqrt(edges_pi2_2)
    pi2_at_mids = np.sqrt(mids_pi2_1) * np.sqrt(mids_pi2_2)
    pi2 = simpson(pi2_at_edges, pi2_at_mids)

    return D2, pi2

In [None]:
def plot_moments(
    sampling_times, bins, moments_results, title_add="", clip=np.inf, ax=None
):
    """Plot D2 / pi2 against scaled distance, one line per sampling time.

    Args:
        sampling_times: Labels for the lines, paired positionally with
            ``moments_results``.
        bins: Bin edges; the x values are ``midpoint(bins)`` multiplied by the
            notebook-level ``rho``.
        moments_results: Iterable of ``(D2, pi2)`` pairs, one per sampling time.
        title_add: Optional axes title.
        clip: Only bins whose midpoint is below this value are drawn.
        ax: Axes to draw on. When omitted, the current pyplot axes are used.

    Previously the ``ax is None`` path ignored ``title_add`` and never drew a
    legend even though labels were passed; both paths now behave the same.
    """
    if ax is None:
        ax = plt.gca()

    x = midpoint(bins)
    mask = x < clip
    for (D2, pi2), st in zip(moments_results, sampling_times):
        ax.plot(x[mask] * rho, (D2 / pi2)[mask], label=st)

    if title_add:
        ax.set_title(title_add)
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.legend()


def plot_moments_groups(
    sampling_times, bins, data, labels, suptitle="", supx="", supy=""
):
    """Draw one ``plot_moments`` panel per dataset, side by side.

    Args:
        sampling_times: Passed through to ``plot_moments``.
        bins: Passed through to ``plot_moments``.
        data: Sequence of ``moments_results`` collections, one per panel.
        labels: Panel titles; must be the same length as ``data``.
        suptitle: Optional figure title.
        supx: Optional shared x label.
        supy: Optional shared y label.
    """
    assert len(data) == len(labels)
    # squeeze=False keeps `axes` two-dimensional, so a single dataset still
    # yields an iterable row instead of a bare Axes object.
    fig, axes = plt.subplots(
        1,
        len(data),
        figsize=(4.5 * len(data), 4.5),
        sharey="col",
        layout="constrained",
        squeeze=False,
    )
    for panel_data, panel_label, panel_ax in zip(data, labels, axes[0]):
        plot_moments(
            sampling_times, bins, panel_data, title_add=panel_label, clip=1e4, ax=panel_ax
        )
    if suptitle:
        fig.suptitle(suptitle)
    if supx:
        fig.supxlabel(supx)
    if supy:
        fig.supylabel(supy)

In [None]:
# NOTE: this repeats the `midpoint` definition from an earlier cell; both
# compute the same thing, and whichever cell ran last is the one in effect.
def midpoint(bins):
    """Return the midpoints between consecutive entries of ``bins``."""
    following = bins[1:]
    preceding = bins[:-1]
    return (following + preceding) / 2


def plot_decays(times, decays, ax=None, clip=np.inf, title_add=""):
    """Plot the replicate-averaged decay curves on log-log axes.

    Args:
        times: Labels for the curves, paired positionally with the rows
            returned by ``avg_decay(decays)``.
        decays: Per-replicate decay results, as accepted by ``avg_decay``.
        ax: Axes to draw on. When omitted, a new single-axes figure is created
            and given figure-level axis labels.
        clip: Only bins whose midpoint is below this value are drawn.
        title_add: Optional axes title.

    The original created a 1x2 axes array when ``ax`` was omitted and then
    called ``ax.plot`` on the array, and its figure-label branch re-tested
    ``ax is None`` after ``ax`` had been reassigned, so it could never run.
    """
    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 6))

    for t, (b, d) in zip(times, zip(*avg_decay(decays))):
        x = midpoint(b)
        mask = x < clip
        ax.plot(x[mask] * rho, d[mask], label=f"time={t}")

    ax.legend()
    ax.set_xscale("log")
    ax.set_yscale("log")
    if title_add:
        ax.set_title(title_add)

    # Only label the figure when this function created it.
    if fig is not None:
        fig.supxlabel(r"$\rho$")
        fig.supylabel("$r_x r_y$")

In [None]:
def plot_group(sampling_times, data, labels, suptitle="", supx="", supy=""):
    """Draw one ``plot_decays`` panel per dataset, side by side.

    Args:
        sampling_times: Passed through to ``plot_decays`` as its ``times``.
        data: Sequence of decay collections, one per panel.
        labels: Panel titles; must be the same length as ``data``.
        suptitle: Optional figure title.
        supx: Optional shared x label.
        supy: Optional shared y label.
    """
    assert len(data) == len(labels)
    # squeeze=False keeps `axes` two-dimensional, so a single dataset still
    # yields an iterable row instead of a bare Axes object.
    fig, axes = plt.subplots(
        1,
        len(data),
        figsize=(4.5 * len(data), 4.5),
        sharey="col",
        layout="constrained",
        squeeze=False,
    )
    for panel_data, panel_label, panel_ax in zip(data, labels, axes[0]):
        plot_decays(
            sampling_times, panel_data, title_add=panel_label, clip=1e4, ax=panel_ax
        )
    if suptitle:
        fig.suptitle(suptitle)
    if supx:
        fig.supxlabel(supx)
    if supy:
        fig.supylabel(supy)

In [None]:
def run_msprime(sampling_times, mut_rate, demog):
    """Simulate 10 ancestry replicates under ``demog`` and add mutations.

    40 individuals are requested from population "A" and from population "B"
    at every time in ``sampling_times``. Sequence length, recombination rate
    and seed come from the notebook-level ``L``, ``r`` and ``SEED``; the same
    ``SEED`` is also passed to every ``sim_mutations`` call.

    Returns:
        A list with one mutated tree sequence per replicate.
    """
    times = tuple(sampling_times)
    sample_sets = []
    for pop_name in ("A", "B"):
        for sample_time in times:
            sample_sets.append(
                msprime.SampleSet(40, population=pop_name, time=sample_time)
            )

    replicates = msprime.sim_ancestry(
        samples=sample_sets,
        sequence_length=L,
        recombination_rate=r,
        demography=msprime.Demography.from_demes(demog),
        random_seed=SEED,
        num_replicates=10,
    )

    mutated = []
    for ancestry_ts in replicates:
        mutated.append(
            msprime.sim_mutations(ancestry_ts, rate=mut_rate, random_seed=SEED)
        )
    return mutated


def compute_decay_two_way(tss, sampling_times, stat="r2"):
    """Yield, per tree sequence, the two-population LD decay at each time.

    For every tree sequence the samples of populations 1 and 2 at each
    sampling time are paired and handed to ``ld_decay_two_way``.
    NOTE(review): population ids 1 and 2 presumably correspond to demes A and
    B - confirm against the demography used for the simulation.

    Yields:
        A list with one ``ld_decay_two_way`` result per sampling time.
    """
    settings = dict(
        max_dist=100_000,
        win_size=100,
        chunk_size=100,
        n_threads=18,
        stat=stat,
    )
    for tree_seq in tss:
        pop_a_sets = [tree_seq.samples(1, time=when) for when in sampling_times]
        pop_b_sets = [tree_seq.samples(2, time=when) for when in sampling_times]
        yield [
            ld_decay_two_way(tree_seq, sample_sets=[set_a, set_b], **settings)
            for set_a, set_b in zip(pop_a_sets, pop_b_sets)
        ]


def compute_decay(tss, sampling_times, stat="r2"):
    """Yield, per tree sequence, the single-population LD decay at each time.

    Only the samples of population 1 are used; each sampling time gets its
    own ``ld_decay`` call.
    NOTE(review): population id 1 presumably corresponds to deme A - confirm
    against the demography used for the simulation.

    Yields:
        A list with one ``ld_decay`` result per sampling time.
    """
    settings = dict(
        max_dist=100_000,
        win_size=100,
        chunk_size=100,
        n_threads=18,
        stat=stat,
    )
    for tree_seq in tss:
        pop_a_sets = [tree_seq.samples(1, time=when) for when in sampling_times]
        yield [
            ld_decay(tree_seq, sample_sets=[set_a], **settings)
            for set_a in pop_a_sets
        ]


def avg_decay(decays):
    """Average decay curves over replicates.

    Args:
        decays: Iterable of replicates. Each replicate is a sequence with one
            ``(bins, count, decay)`` triple per sampling time. Generators
            (such as those produced by ``compute_decay``) are accepted.

    Returns:
        ``(bins, mean)``. ``mean`` has one row per sampling time, averaged
        over replicates. ``bins`` has one row per *replicate*, taken from the
        first sampling time of each replicate.
    """
    # Materialise first: the input is traversed twice below, which would
    # leave nothing for the second pass if it were a generator.
    decays = list(decays)
    mean = np.dstack([[d for _, _, d in rep] for rep in decays]).mean(2)
    # NOTE(review): rows of `bins` are per replicate while rows of `mean` are
    # per sampling time. Callers that zip the two together get at most
    # min(n_replicates, n_times) pairs - confirm this is intended.
    bins = np.vstack([rep[0][0] for rep in decays])
    return bins, mean

In [None]:
def save_result(filename, sampling_times, decay):
    """Collect decay results into a wide polars DataFrame.

    For replicate ``rep`` and sampling time ``t`` two columns are produced:
    ``bins_{rep}_{t}`` and ``decay_{rep}_{t}``. The decay values get a NaN
    prepended (presumably so they have as many entries as the bin edges).

    Note that ``filename`` is accepted but not used here: nothing is written
    to disk, the DataFrame is only returned.
    """
    import polars as pl

    columns = {}
    for rep, per_time in enumerate(decay):
        for t_i, (b, _count, d) in enumerate(per_time):
            padded = np.insert(d, 0, np.nan)
            t = sampling_times[t_i]
            columns[f"bins_{rep}_{t}"] = b
            columns[f"decay_{rep}_{t}"] = padded
    return pl.DataFrame(columns)

In [None]:
def get_demes(Ne, mig):
    """Build the demes graph used by the simulations.

    The model has a deme X with an epoch ending at time 5000 and two demes,
    A and B, that both list X as their ancestor. Every epoch defaults to a
    start size of ``Ne``. When ``mig`` is non-zero a migration entry over
    demes ``[A, B]`` with that rate is appended.

    Raises:
        Exception: if ``mig`` is neither zero nor positive.
    """
    # Phrased as "not (zero or positive)" so that every value failing both
    # comparisons is rejected.
    if not (mig == 0 or mig > 0):
        raise Exception("migration rate cannot be negative")

    migrations = ""
    if mig != 0:
        migrations = f"""\
migrations:
  - demes: [A, B]
    rate: {mig}
            """

    model_yaml = f"""
time_units: generations
defaults:
  epoch:
    start_size: {Ne}
demes:
  - name: X
    epochs:
      - end_time: 5000
  - name: A
    ancestors: [X]
  - name: B
    ancestors: [X]
{migrations}
            """
    return demes.load(io.StringIO(model_yaml))

In [None]:
# Notebook-wide simulation settings.
SEED = 23  # random seed handed to msprime
Ne = 4_000  # deme size given to get_demes
r = 1e-8  # recombination rate handed to msprime
mu = 1e-8
L = 1e8  # sequence length handed to msprime
rho = 4 * Ne * r  # factor used to rescale distances / bins

# 20 sampling times, starting at 4999 and stepping back by 200.
SAMPLING_TIMES = np.arange(4999, 4999 - 200 * 20 + 1, -200)

In [None]:
def gen_params(filename):
    """Write the simulation parameter sweep to ``filename`` as header-less CSV.

    Each row is ``[index, sampling_times, stat, reps, mig]`` where:

    * ``index`` is a running row number starting at 0,
    * ``sampling_times`` is a ``:``-joined group of times (26 times from 4999
      downwards in steps of 200, negative values clipped to 0, split into 7
      groups),
    * ``stat`` is one of ``D2_unbiased``, ``pi2_unbiased``, ``r2``,
    * ``reps`` is always 15,
    * ``mig`` is one of 0, 1e-1, 1e-2, 1e-3, 1e-4.

    Rows are ordered by stat, then migration rate, then sampling-time group.
    """
    reps = 15
    times = np.clip(
        np.arange(4999, 4999 - 200 * 26 + 1, -200), a_min=0, a_max=np.inf
    ).astype(np.int64)
    samp_times = [":".join(map(str, chunk)) for chunk in np.array_split(times, 7)]
    sweep = product(
        ["D2_unbiased", "pi2_unbiased", "r2"],
        [0, 1e-1, 1e-2, 1e-3, 1e-4],
        samp_times,
    )
    # The csv module expects files opened with newline="" so that it alone
    # controls the line terminators (otherwise some platforms emit blank rows).
    with open(filename, "w", newline="") as fp:
        writer = csv.writer(fp)
        for i, (stat, mig, s) in enumerate(sweep):
            writer.writerow([i, s, stat, reps, mig])
# gen_params('params.csv')

In [None]:
# Directory holding one parquet file per simulation run.
RESULT_DIR = Path(
    "/home/lkirk/simulation-outputs/spatial-ld-final/msprime-ld-sims/output-1748688/result"
)

# Combine all result files side by side into one wide frame. Every column is
# prefixed with the part of the file name before the first "-" so that the
# columns of different files stay distinguishable.
df = pl.DataFrame()
# glob() yields files in arbitrary order; sorting keeps the column order of
# `df` (which later cells index positionally) the same from run to run.
for p in sorted(RESULT_DIR.glob("*.parquet")):
    prefix = p.stem.split("-")[0]
    df.hstack(
        pl.read_parquet(p).rename(lambda s: prefix + "-" + s),
        in_place=True,
    )

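Column names have the form `<sim num>-<data name>_<rep>_<t>-<num pops>`, i.e. they encode: sim num, data name (e.g. bins or decay), rep, t, num pops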

In [None]:
# Quick look at how the column names split on "-" into their three parts.
# (The trailing argument-less `.with_columns()` added nothing and was dropped.)
(
    pl.Series(df.columns)
    .str.split("-")
    .list.to_struct(fields=["sim_num", "info", "num_pops"])
    .struct.unnest()
)

In [None]:
# One metadata row per column of `df`: the column name is split on "-" into
# (sim_num, info, num_pops), `info` is split on "_" into (dname, rep, time),
# and the simulation parameters from params.csv are joined on sim_num.
meta = (
    pl.DataFrame({"col": df.columns})
    .with_columns(
        pl.col("col")
        .str.split_exact("-", 2)
        .struct.rename_fields(["sim_num", "info", "num_pops"])
        .alias("split")
    )
    .unnest("split")
    .with_columns(
        pl.col("info")
        .str.split_exact("_", 2)
        .struct.rename_fields(["dname", "rep", "time"])
    )
    .unnest("info")
    .with_columns(
        pl.col("sim_num").cast(pl.Int64),
        pl.col("rep", "time", "num_pops").cast(pl.Int32),
    )
    .join(
        pl.read_csv(
            "params.csv",
            has_header=False,
            new_columns=["sim_num", "sampling_times", "stat", "reps", "mig"],
        ),
        on="sim_num",
    )
    .drop("sampling_times", "reps")
)

In [None]:
# Back-of-the-envelope counts to compare with `df.shape`.
# NOTE(review): presumably 15 reps x 26 sampling times x 5 migration rates
# (the values used in gen_params); the remaining factors (x2, then x3 x2)
# are not explained anywhere in the notebook - TODO confirm what they stand for.
15 * 2 * 26 * 5, 15 * 2 * 26 * 5 * 3 * 2

In [None]:
# (rows, columns) of the combined results frame.
df.shape

In [None]:
# Number of "decay" columns available for each statistic.
(
    meta.filter(pl.col("dname") == "decay")
    .group_by("stat")
    .agg(pl.col("sim_num").count())
)

In [None]:
# Metadata rows for the r2 statistic computed with two populations.
meta.filter((pl.col("stat") == "r2") & (pl.col("num_pops") == 2))

In [None]:
# Inspect one bins/decay column pair of simulation 16.
df.select("16-bins_5_2799-2", "16-decay_5_2799-2")

In [None]:
# All metadata rows belonging to simulation 16.
meta.filter(pl.col("sim_num") == 16)

In [None]:
# df[:, [c for c in df.columns if "bins" in c]].transpose().null_count()

In [None]:
# Bin midpoints shared by all curves. The bin edges are read from the bins
# column that belongs to the same run as the reference decay column, rather
# than from whichever column happens to be first in `df` (that depends on the
# order in which the result files were loaded).
bins = midpoint(
    df.filter(pl.col("16-decay_10_3199-2").is_not_null())["16-bins_10_3199-2"]
).rename("bins")

In [None]:
# NOTE(review): scratch value with no visible use; presumably a candidate
# distance cutoff - TODO remove if no longer needed.
1.5e4

In [None]:
# NOTE(review): scratch value with no visible use; it matches the clip=1e4
# passed by the grouped plotting helpers - TODO remove if no longer needed.
1e4

In [None]:
# All decay columns, restricted to the rows where the reference decay column
# is populated; the first of those rows is then dropped.
decays = (
    df.filter(pl.col("16-decay_10_3199-2").is_not_null())
    .select([name for name in df.columns if "decay" in name])
    .slice(1)
)

In [None]:
# Metadata for the decay columns, restricted to 20 sampling times.
# The original referred to `decay_meta` inside its own definition, which only
# works when a previous value is still in the kernel; the times are now taken
# from `meta` directly. `unique()` gives no ordering guarantee, so the values
# are sorted before slicing.
# NOTE(review): this keeps the 20 smallest times - confirm that is the
# intended subset.
_decay_times = (
    meta.filter(pl.col.dname == "decay")["time"].unique().sort()[0:20].to_list()
)
decay_meta = meta.filter(
    pl.col.dname == "decay", pl.col.time.is_in(_decay_times)
).drop("dname")

In [None]:
# decay_meta.sort("num_pops").partition_by("num_pops", as_dict=True, maintain_order=True)[(1,)].filter(pl.col.stat == "r2", pl.col.time == 3799).partition_by("mig")

In [None]:
# get_demes compares `mig` numerically, so the string 'low' raised a
# TypeError. NOTE(review): 1e-4 is the smallest non-zero rate in the
# gen_params sweep and is used as a stand-in for "low" - confirm the value.
demesdraw.tubes(get_demes(Ne, 1e-4))

In [None]:
def plot_stat(ax, _meta, cmap, norm, stat):
    """Draw one curve per sampling time for ``stat`` on ``ax``.

    Relies on the notebook-level ``decays``, ``bins``, ``mask`` and ``rho``.

    Args:
        ax: Axes to draw on.
        _meta: Metadata rows (with ``time``, ``stat`` and ``col`` columns)
            selecting which columns of ``decays`` to use.
        cmap: Colormap; each curve is coloured with ``cmap(norm(time))``.
        norm: Normalisation applied to the time before the colormap lookup.
        stat: ``"r2"`` plots the row-wise mean of the r2 columns;
            ``"sigmad2"`` plots the row-wise mean of D2 / pi2.

    Raises:
        ValueError: for any other ``stat`` (previously nothing was drawn).
    """
    if stat not in ("r2", "sigmad2"):
        raise ValueError(f"unknown stat {stat!r}; expected 'r2' or 'sigmad2'")

    by_time = _meta.sort("time").partition_by(
        "time", maintain_order=True, as_dict=True
    )
    for (t,), m_df in by_time.items():
        if stat == "r2":
            s = decays[:, m_df.filter(pl.col.stat == "r2")["col"]].mean_horizontal()
        else:
            # NOTE(review): gen_params writes the stat names "D2_unbiased" and
            # "pi2_unbiased"; if params.csv was produced by it, these two
            # filters select no columns - confirm the names in params.csv.
            d2_cols = decays[:, m_df.filter(pl.col.stat == "D2")["col"]]
            pi2_cols = decays[:, m_df.filter(pl.col.stat == "pi2")["col"]]
            s = (d2_cols / pi2_cols).mean_horizontal()
        ax.plot(bins.filter(mask) * rho, s.filter(mask), color=cmap(norm(t)))

def plot_all():
    """Draw the 2x3 grid of r2 decay curves with a shared time colorbar.

    Rows follow the sorted ``num_pops`` values in ``decay_meta``; columns
    follow the sorted migration rates, with rates 0.0 and 0.001 left out.
    Each panel is drawn by ``plot_stat`` and curves are coloured by sampling
    time.
    """
    # Sorted explicitly: the values are used below as colorbar tick labels in
    # ascending order, and `unique()` alone does not guarantee any order.
    times = decay_meta["time"].unique().sort()
    cmap = mpl.colors.ListedColormap(palettable.scientific.diverging.Roma_20.mpl_colors)
    bounds = np.linspace(times.min(), times.max(), len(times) + 1)
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
    plot_size_x, plot_size_y = (6, 6)
    fig, axes = plt.subplots(
        2,
        3,
        figsize=(plot_size_x * 3, plot_size_y * 2),
        layout="constrained",
        sharex=True,
        sharey=True,
    )

    by_pops = decay_meta.sort("num_pops").partition_by(
        "num_pops", as_dict=True, maintain_order=True
    )
    for i, (((num_pops,), _m), stat) in enumerate(product(by_pops.items(), ["r2"])):
        by_mig = (
            _m.sort("mig")
            .filter(~pl.col.mig.is_in([0.0, 0.001]))
            .partition_by("mig", as_dict=True, maintain_order=True)
        )
        for j, ((mig,), m) in enumerate(by_mig.items()):
            ax = axes[i, j]
            plot_stat(ax, m, cmap, norm, stat)
            ax.set_title(f"mig: {mig} stat: {stat} n_pops: {num_pops}")

    cbar = plt.colorbar(
        plt.cm.ScalarMappable(norm=norm, cmap=cmap),
        ax=axes[:, -1],
        location="right",
    )
    # One tick in the middle of every colour band, labelled with its time.
    cbar.ax.set_yticks(midpoint(bounds))
    cbar.ax.set_yticklabels(times.to_list())

    # Set the scales on every panel explicitly instead of relying on the
    # pyplot "current axes" plus axis sharing.
    for panel in axes.flat:
        panel.set_xscale("log")
        panel.set_yscale("log")

# mask = bins < 2e4
# Boolean selector over `bins` that plot_stat reads as a global; comparing
# against infinity applies no finite distance cutoff.
mask = bins < np.inf
# Render the full figure.
plot_all()