In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import sc2ts


In [None]:
# Directory holding the ARG post-processing outputs (relative to this notebook).
data_dir = Path("..") / "arg_postprocessing"

In [None]:
# TODO(NB): replace this with the final path to the recombinants csv file
recomb_file = data_dir.joinpath(
    "sc2ts_v1_2023-02-21_pp_dated_remapped_bps_pango_recombinants_rebar_matches_pangonet_nsl.csv"
)
# One row per recombinant, indexed by its sample_id column.
recomb_df = (
    pd.read_csv(recomb_file)
    .set_index("sample_id")
)


In [None]:
# Keep only rows flagged by the boolean column net_min_supporting_loci_lft_rgt_ge_4
# (presumably >= 4 supporting loci on both the left and right sides — confirm upstream).
df_hq = recomb_df.loc[recomb_df["net_min_supporting_loci_lft_rgt_ge_4"]].copy()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import adjustText


def draw_averted_scatterplot(df, ax, jitter_seed=2, jitter_width=0.25):
    """Plot, per recombination node, mutation counts without vs. with recombination.

    Each row of ``df`` becomes one point.
    x is ``num_mutations_k1000`` (no recombination); y is
    ``num_mutations_k4`` (recombination, k=4). Points are coloured by
    ``is_rebar_recombinant``, sized by ``log(num_descendant_samples + 1)`` and
    uniformly jittered. Both axes are log-scaled with counts shifted by +1 so
    that zero can be drawn; the tick labels undo that shift. The y axis is
    inverted and axis labels/ticks are moved to the top and right sides.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``num_mutations_k1000``, ``num_mutations_k4``,
        ``num_descendant_samples``, ``is_rebar_recombinant`` and
        ``sample_pango`` columns.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    jitter_seed : int, optional
        Seed for the jitter, so the layout is reproducible.
    jitter_width : float, optional
        Half-width of the uniform jitter applied to both coordinates.
    """
    # A private RandomState yields exactly the same draws as
    # np.random.seed(jitter_seed) followed by np.random.uniform(...), but does
    # not reset NumPy's global RNG for every later cell in the notebook.
    rng = np.random.RandomState(jitter_seed)
    j1 = rng.uniform(-jitter_width, jitter_width, len(df))
    j2 = rng.uniform(-jitter_width, jitter_width, len(df))

    colorpal = ["gray", "darkviolet"]  # index 0: non-recombinant, 1: recombinant
    patches = [
        mpatches.Patch(color=colorpal[0], label="Non-recombinant"),
        mpatches.Patch(color=colorpal[1], label="Recombinant"),
    ]

    x = df.num_mutations_k1000
    y = df.num_mutations_k4

    ax.scatter(
        (1 + x) + j1,
        (1 + y) + j2,
        alpha=0.5,
        s=10 * (np.log(df.num_descendant_samples + 1)),
        c=[colorpal[int(is_re)] for is_re in df.is_rebar_recombinant],
        edgecolors=None,
    )
    ax.set_xlabel("Mutations (no recombination)", fontsize=16)
    ax.set_ylabel("Mutations (recombination, $k=4$)", fontsize=16)

    ax.legend(
        handles=patches,
        title="Rebar classification",
        title_fontsize=16,
        fontsize=14,
        loc="upper right",
        frameon=False,
    )

    texts = []
    for i, (xx, yy, row) in enumerate(zip(x, y, df.itertuples())):
        label = str(row.sample_pango)
        if label == "B.1.617.2":
            label = "Delta"
        # Named Pango recombinants (X*) are always labelled.
        is_pango_recombinant = label.startswith("X")
        # Outlying nodes (yy >= 4 or xx > 18) are labelled only with more than one
        # descendant sample: the Pango label would be misleading if there is only
        # 1 non-identified descendant (e.g. BA.1).
        is_labelled_outlier = (yy >= 4 or xx > 18) and row.num_descendant_samples > 1
        if is_pango_recombinant or is_labelled_outlier:
            texts.append(ax.text((1 + xx) + j1[i], (1 + yy) + j2[i], label, fontsize=14))

    # y = x reference line: points on it need the same number of mutations
    # with and without recombination.
    ax.plot([0.1, 10, 100], [0.1, 10, 100], c="gray")
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlim(3.9, 42)
    ax.set_ylim(0.6, 42)
    ax.spines['bottom'].set_visible(False)
    ax.xaxis.set_label_position('top')
    ax.yaxis.set_label_position('right')
    ax.invert_yaxis()
    ax.minorticks_off()
    # Tick positions are in shifted (count + 1) coordinates; labels show raw counts.
    xticks = [5, 6, 7, 8, 9, 10, 11, 21, 31, 41]
    yticks = [1, 2, 3, 4, 5, 6, 7, 9, 11, 21, 31, 41]
    ax.set_xticks(ticks=xticks, labels=[t - 1 for t in xticks], fontsize=14)
    ax.set_yticks(ticks=yticks, labels=[t - 1 for t in yticks], fontsize=14)
    # Ticks and tick labels only on the top (x axis) and right (y axis) sides.
    ax.tick_params(axis='x', top=True, labeltop=True, bottom=False, labelbottom=False)
    ax.tick_params(axis='y', right=True, labelright=True, left=False, labelleft=False)

    if texts:  # nothing to reposition when no node qualified for a label
        adjustText.adjust_text(
            texts,
            arrowprops=dict(
                relpos=(0.5, 0.0),
                arrowstyle="-",
                color="gray",
                shrinkA=10,
                lw=1,
            ),
            ax=ax,
        )


# Quick visual check of the scatter panel on its own, before composing the full figure.
draw_averted_scatterplot(ax=plt.gca(), df=df_hq);

In [None]:
from PIL import Image
import imgkit  # To convert the HTML table to a PNG. Also needs wkhtmltox to be installed
import io
import tszip

ts = tszip.load("../data/sc2ts_v1_2023-02-21_pp_dated_remapped_bps_pango_mmps.trees.tsz")

# Recombination (RE) nodes whose copying pattern is shown, keyed by panel label.
copying_images = {
    node_label: {'id': node_id}
    for node_label, node_id in (
        ("XA", 122444),
        ("XBB", 1396207),
        ("XZ+", 964555),
        ("Delta", 200039),
        ("BA.2", 822854),
    )
}

# Render each copying table to HTML, rasterise it to PNG with imgkit (this is
# the slow part), and keep the decoded pixels as an array for imshow later.
for entry in copying_images.values():
    table_html = sc2ts.info.CopyingTable(ts, entry['id']).html(hide_extra_rows=True, hide_labels=True)
    # Passing False as the output path makes imgkit return the PNG bytes instead of writing a file.
    png_bytes = imgkit.from_string(table_html, False, options={"width": 2000})
    entry['img'] = np.asarray(Image.open(io.BytesIO(png_bytes)))

In [None]:
def draw_copying_patterns(copying_images, ax, x_scale=1.45):
    """Tile pre-rendered copying-pattern images as insets of ``ax``.

    Parameters
    ----------
    copying_images : dict
        Maps a panel label to a dict holding an ``"img"`` entry (anything
        accepted by ``imshow``). Labels must be among "XA", "XBB", "XZ+",
        "Delta" and "BA.2", whose inset positions are hard-coded below; any
        subset of them may be given. The dict is not modified.
    ax : matplotlib.axes.Axes
        Host axes. Its own frame is hidden, and the images are placed at
        axes-fraction coordinates via ``inset_axes``.
    x_scale : float, optional
        Width of each inset, as a fraction of ``ax``'s width.
    """
    row_pos = [0.3, 0, -0.3]
    # [x0, y0, width, height] per label. The x offsets of XBB and Delta (second
    # image in their rows) were adjusted manually.
    positions = {
        "XA": [0.0, row_pos[0], x_scale, 1],
        "XBB": [0.53, row_pos[0], x_scale, 1],
        "XZ+": [0.0, row_pos[1], x_scale, 1],
        "Delta": [0.71, row_pos[1], x_scale, 1],
        "BA.2": [0.0, row_pos[2], x_scale, 1],
    }

    ax.axis("off")
    for label, val in copying_images.items():
        ax_image = ax.inset_axes(positions[label])
        ax_image.imshow(val["img"])
        ax_image.axis("off")  # Remove axis of the image
        # Data coordinates here are image pixel indices; the label ends at x=5
        # (right-aligned), i.e. just at the image's left edge, 40 px from the top.
        ax_image.text(5, 40, label, fontsize=15, ha="right")

In [None]:
# Derived per-node statistics, added to df_hq in place.
df_hq["num_mutations_averted"] = df_hq["num_mutations_k1000"] - df_hq["num_mutations_k4"]
# parent_mrca_time - recombinant_time
# (cf. ts.nodes_time[df_hq.parent_mrca] - ts.nodes_time[df_hq.recombinant])
df_hq["tmrca"] = df_hq["parent_mrca_time"] - df_hq["recombinant_time"]

# Split each statistic by rebar classification: first non-recombinant (nr), then recombinant (re).
rebar_nr_averted_muts, rebar_re_averted_muts = (
    df_hq[mask]["num_mutations_averted"]
    for mask in (~df_hq["is_rebar_recombinant"], df_hq["is_rebar_recombinant"])
)

rebar_nr_tmrca, rebar_re_tmrca = (
    df_hq[mask]["tmrca"]
    for mask in (~df_hq["is_rebar_recombinant"], df_hq["is_rebar_recombinant"])
)

rebar_nr_pango_dist, rebar_re_pango_dist = (
    df_hq[mask]["parent_pangonet_distance"]
    for mask in (~df_hq["is_rebar_recombinant"], df_hq["is_rebar_recombinant"])
)

In [None]:
def draw_stacked_histogram(
    ax,
    a,
    b,
    *,
    alegend,
    blegend,
    acolor,
    bcolor,
    xlabel,
    ylabel,
    xlim,
    bin_size,
    bar_width=0.8,
    show_legend=True,
    show_top_spline=True,
    show_right_spline=True,
):
    """Draw a two-series stacked-bar histogram on ``ax``.

    With ``edges = np.arange(xlim[0], xlim[1], bin_size)``, one bar is drawn per
    half-open interval ``[edges[i], edges[i + 1])``. Values outside
    ``[edges[0], edges[-1])`` are not counted. Series ``a`` forms the bottom of
    each bar and ``b`` is stacked on top.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into.
    a, b : array-like
        Values histogrammed for the bottom and top series respectively.
    alegend, blegend : str
        Legend labels for ``a`` and ``b``.
    acolor, bcolor : color
        Bar colours for ``a`` and ``b``.
    xlabel, ylabel : str
        Axis labels.
    xlim : sequence of two numbers
        Start and (exclusive) stop passed to ``np.arange`` to build the bin edges.
    bin_size : number
        Bin width.
    bar_width : float, optional
        Width of the drawn bars.
    show_legend : bool, optional
        Whether to draw a frameless legend.
    show_top_spline, show_right_spline : bool, optional
        When False, hide the top / right spine (the "spline" spelling is kept
        for backward compatibility with existing callers).
    """
    bin_edges = np.arange(xlim[0], xlim[1], bin_size)
    # np.histogram closes its LAST bin on the right, which would fold values
    # equal to bin_edges[-1] into the previous bar. Count against one extra edge
    # and discard that overflow bin so every drawn bar is half-open.
    count_edges = np.append(bin_edges, bin_edges[-1] + bin_size)
    hist_a = np.histogram(a, bins=count_edges)[0][:-1]
    hist_b = np.histogram(b, bins=count_edges)[0][:-1]
    # Shift by -0.5 so that, with bin_size == 1, each bar sits on the left edge
    # of its bin, i.e. centred on the integer value it counts.
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 - 0.5
    ax.bar(
        bin_centers,
        hist_a,
        width=bar_width,
        label=alegend,
        color=acolor,
    )
    ax.bar(
        bin_centers,
        hist_b,
        bottom=hist_a,
        width=bar_width,
        label=blegend,
        color=bcolor,
    )
    ax.set_xticks(bin_centers.astype(int))
    ax.set_xlabel(xlabel, fontsize=15)
    ax.set_ylabel(ylabel, fontsize=15)
    if not show_top_spline:
        ax.spines['top'].set_visible(False)
    if not show_right_spline:
        ax.spines['right'].set_visible(False)
    if show_legend:
        ax.legend(frameon=False, fontsize=14)

In [None]:
from matplotlib.gridspec import GridSpec


colorpal = ["gray", "darkviolet"]  # [non-recombinant, recombinant], as in draw_averted_scatterplot

# Layout: scatter plot over the top two rows on the left, the two histograms
# stacked on the right, and the copying patterns across the full bottom row.
fig = plt.figure(figsize=(18, 12))
gs = GridSpec(3, 2, figure=fig, wspace=0.25, width_ratios=(2, 1))

scatter_ax = fig.add_subplot(gs[0:2, 0])
copypattern_ax = fig.add_subplot(gs[2, :])
# Transparent inset covering the lower half of the scatter plot's axes.
inset_hist_averted_muts_ax = scatter_ax.inset_axes([0, 0, 1.0, 0.5])
inset_hist_averted_muts_ax.set_facecolor("none")
hist_tmrca_ax = fig.add_subplot(gs[0, 1])
hist_pango_dist_ax = fig.add_subplot(gs[1, 1])

draw_averted_scatterplot(df_hq, scatter_ax, jitter_seed=2)
draw_copying_patterns(copying_images, copypattern_ax)

# Settings shared by all three histograms: rebar recombinants at the bottom,
# non-recombinants stacked on top, no legend.
hist_kwargs = dict(
    alegend='Recombinant',
    blegend='Non-recombinant',
    acolor=colorpal[1],
    bcolor=colorpal[0],
    ylabel='Recombination nodes',
    show_legend=False,
)

draw_stacked_histogram(
    ax=inset_hist_averted_muts_ax,
    a=rebar_re_averted_muts,
    b=rebar_nr_averted_muts,
    xlabel='Mutations averted by recombination',
    xlim=[4, 25],
    bin_size=1,
    show_top_spline=False,
    show_right_spline=False,
    **hist_kwargs,
)
inset_hist_averted_muts_ax.tick_params(axis='both', which='major', labelsize=14)

draw_stacked_histogram(
    ax=hist_tmrca_ax,
    a=rebar_re_tmrca,
    b=rebar_nr_tmrca,
    xlabel='Time to the MRCA of sc2ts parents',
    xlim=[0, 800],
    bin_size=25,
    bar_width=25,
    **hist_kwargs,
)
hist_tmrca_ax.tick_params(axis='both', which='major', labelsize=14)
xticks = list(range(0, 900, 100))
hist_tmrca_ax.set_xticks(ticks=xticks, labels=xticks, fontsize=14)

draw_stacked_histogram(
    ax=hist_pango_dist_ax,
    a=rebar_re_pango_dist,
    b=rebar_nr_pango_dist,
    xlabel='Pango distance between sc2ts parents',
    xlim=[0, 20],
    bin_size=1,
    **hist_kwargs,
)
hist_pango_dist_ax.tick_params(axis='both', which='major', labelsize=14)
# Ticks 0-9 plus one at the largest observed distance. Series.max() skips NaN,
# whereas the builtin max() returns an order-dependent result when NaN is
# present (and int(nan) would then raise).
xticks = list(range(10)) + [int(df_hq.parent_pangonet_distance.max())]
hist_pango_dist_ax.set_xticks(ticks=xticks, labels=xticks, fontsize=14)

# Save this figure explicitly rather than whatever pyplot considers "current".
fig.savefig("fig_rebar.pdf", format='pdf', dpi=600, transparent=True);