In [None]:
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
# Load the tab-separated metadata passed in as the Snakemake input "metadata";
# "timepoint" is parsed to datetimes so it can be used as a date axis later.
metadata_path = snakemake.input.metadata
df = pd.read_csv(metadata_path, sep="\t", parse_dates=["timepoint"])

In [None]:
# Preview the first five rows of the loaded metadata.
df.head(n=5)

In [None]:
# Number of non-null "strain" values per (delay_type, timepoint) pair, in long
# format with columns delay_type, timepoint and count.
count_df = (
    df.groupby(["delay_type", "timepoint"])["strain"]
    .count()
    .rename("count")
    .reset_index()
)

In [None]:
# Preview the long-format counts.
count_df.head(n=5)

In [None]:
# Reshape to wide format: one row per timepoint, one count column per delay type.
count_by_timepoint = count_df.pivot(
    values="count",
    index="timepoint",
    columns=["delay_type"],
)

In [None]:
# For each delayed scenario, the count divided by the undelayed ("none") count
# at the same timepoint.
for proportion_column, delayed_type in (
    ("proportion_ideal", "ideal"),
    ("proportion_realistic", "realistic"),
):
    count_by_timepoint[proportion_column] = (
        count_by_timepoint[delayed_type] / count_by_timepoint["none"]
    )

In [None]:
# Preview the wide counts together with the derived proportion columns.
count_by_timepoint.head(n=5)

In [None]:
# Long format for plotting: one row per (timepoint, proportion column) with the
# value in "proportion".
# var_name is set explicitly instead of relying on melt() inheriting the
# columns-axis name from the earlier pivot; if that name were ever lost the
# column would silently become "variable" and downstream code keyed on
# "delay_type" would fail.
proportions = (
    count_by_timepoint
    .loc[:, ["proportion_ideal", "proportion_realistic"]]
    .melt(
        var_name="delay_type",
        value_name="proportion",
        ignore_index=False,
    )
    .reset_index()
)

In [None]:
# Turn "proportion_ideal"/"proportion_realistic" into "ideal"/"realistic" so the
# labels match the delay types used in count_df.
proportions = proportions.assign(
    delay_type=proportions["delay_type"].str.replace("proportion_", "", regex=False)
)

In [None]:
# Preview the long-format proportions that feed panel B.
proportions.head(n=5)

In [None]:
fig, (ax_count, ax_proportion) = plt.subplots(2, 1, figsize=(8, 6), dpi=200)

# One color per delay type, shared by both panels so a series keeps its color.
color_by_delay_type = {
    "none": "C0",
    "ideal": "C1",
    "realistic": "C2",
}


def _plot_by_delay_type(ax, frame, value_column, colors):
    """Draw one line per delay type onto ``ax`` from a long-format frame.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    frame : DataFrame with "delay_type" and "timepoint" columns plus
        ``value_column``.
    value_column : name of the column plotted on the y-axis.
    colors : mapping of delay type -> matplotlib color; its iteration order
        is the drawing order.

    Delay types with no rows in ``frame`` are skipped, so no empty line (and
    no empty legend entry) is created for them.
    """
    for delay_type, color in colors.items():
        delay_df = frame[frame["delay_type"] == delay_type]
        if delay_df.empty:
            continue

        ax.plot(
            delay_df["timepoint"],
            delay_df[value_column],
            "-",
            color=color,
            label=delay_type,
        )


# Panel A: number of sequences per timepoint for each delay type.
_plot_by_delay_type(ax_count, count_df, "count", color_by_delay_type)

ax_count.legend(
    title="Delay type",
    frameon=False,
)
ax_count.set_xlabel("Date")
ax_count.set_ylabel("Number of sequences")
ax_count.set_ylim(bottom=0)

# Panel B: proportion of the undelayed count available under each delay type.
# `proportions` only holds "ideal" and "realistic", so "none" is skipped by the
# helper instead of being drawn as an empty line.
_plot_by_delay_type(ax_proportion, proportions, "proportion", color_by_delay_type)

ax_proportion.set_xlabel("Date")
ax_proportion.set_ylabel("Proportion of undelayed\nsequences available at delay")
ax_proportion.set_ylim(bottom=0, top=1)

# Bold panel letters placed in figure coordinates (0-1) near each panel's top-left.
panel_labels_dict = {
    "weight": "bold",
    "size": 14,
}
fig.text(0.01, 0.97, "A", **panel_labels_dict)
fig.text(0.01, 0.47, "B", **panel_labels_dict)

sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(snakemake.output.figure)