Skip to content

Commit

Permalink
workflow: first successful high-throughput run
Browse files Browse the repository at this point in the history
  • Loading branch information
Katherine Eaton authored and ktmeaton committed Jun 21, 2022
1 parent cd741a1 commit d1ccca2
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 26 deletions.
2 changes: 2 additions & 0 deletions resources/breakpoints.tsv
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,5 @@ proposed676 676 10448:11287 Omicron/21K,Omicron/21L NA
proposed701 701 NA NA NA
proposed709 709 6516:8392 Omicron/21L,Omicron/21K NA
proposed759 759 5387:8392 Omicron/21K,Omicron/21L NA
proposed757 757 17413:19953 Omicron/21K,Omicron/21L NA
proposed771 771 18163:19862 Omicron/21L,Omicron/22B NA
1 change: 1 addition & 0 deletions resources/issue_to_lineage.tsv
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
439 proposed439
422 proposed422
210 proposed210
771 proposed771
759 proposed759
757 proposed757
646 proposed646
Expand Down
68 changes: 58 additions & 10 deletions scripts/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import numpy as np
import epiweeks
import matplotlib.pyplot as plt
from matplotlib import patches, cm
from matplotlib import patches, colors
from datetime import datetime, timedelta
import sys
import copy
import math

NO_DATA_CHAR = "NA"
ALPHA_LAG = 0.25
Expand All @@ -21,6 +22,29 @@
FIGSIZE = [6.75, 5.33]


def categorical_cmap(nc, nsc, cmap="tab20", continuous=False):
    """
    Build a ListedColormap of nc base colors, each expanded into nsc shades.

    Author: ImportanceOfBeingEarnest
    Link: https://stackoverflow.com/a/47232942

    Base colors are read from the matplotlib colormap named by `cmap`:
    sampled at nc evenly spaced points over [0, 1] when `continuous` is
    true, otherwise looked up by the integer indices 0..nc-1. For every
    base color, nsc shades are generated in HSV space: saturation runs
    from the base saturation to 0.25 and value runs from the base value
    to 1. The resulting colormap holds nc * nsc RGB rows, grouped by
    base color.

    Raises ValueError when nc exceeds the number of colors in `cmap`.
    """
    base_cmap = plt.get_cmap(cmap)
    if nc > base_cmap.N:
        raise ValueError("Too many categories for the specified colormap.")

    samples = np.linspace(0, 1, nc) if continuous else np.arange(nc, dtype=int)
    base_colors = base_cmap(samples)

    shades = np.zeros((nc * nsc, 3))
    for idx, rgba in enumerate(base_colors):
        # Drop the alpha channel before converting to HSV
        base_hsv = colors.rgb_to_hsv(rgba[:3])
        # One HSV row per shade, each starting as a copy of the base color
        shade_hsv = np.tile(base_hsv, nsc).reshape(nsc, 3)
        shade_hsv[:, 1] = np.linspace(base_hsv[1], 0.25, nsc)
        shade_hsv[:, 2] = np.linspace(base_hsv[2], 1, nsc)
        start = idx * nsc
        shades[start : start + nsc, :] = colors.hsv_to_rgb(shade_hsv)

    return colors.ListedColormap(shades)


@click.command()
@click.option("--input", help="Recombinant sequences (TSV)", required=True)
@click.option("--outdir", help="Output directory", required=False, default=".")
Expand Down Expand Up @@ -87,6 +111,7 @@ def main(
min_datetime = datetime.strptime(min_date, "%Y-%m-%d")
min_epiweek = epiweeks.Week.fromdate(min_datetime, system="iso").startdate()
elif weeks:
weeks = int(weeks)
min_epiweek = max_epiweek - timedelta(weeks=(weeks - 1))
else:
min_epiweek = epiweeks.Week.fromdate(
Expand Down Expand Up @@ -119,8 +144,9 @@ def main(
if cluster_size >= largest_cluster_size:
largest_cluster_id = cluster_id
largest_lineage = cluster_df["lineage"].values[0]
largest_cluster_size = cluster_size

if not cluster_size == 1:
if cluster_size == 1:
for i in cluster_df.index:
drop_singleton_ids.append(i)

Expand Down Expand Up @@ -265,18 +291,35 @@ def main(
plot_df.to_csv(out_path + ".tsv", sep="\t", index=False)

# Use the tab10 color palette (first 9 colors), adding shades when there are more categories
if len(plot_df.columns) > 20:
num_cat = len(plot_df.columns)

legend_ncol = 1
if num_cat > 10:
legend_ncol = 2

pal = "tab10"
# Exclude the last color in tab10, which is a light blue
pal_num_cat = 9
pal_num_sub_cat = 1

custom_cmap_i = np.linspace(0.0, 1.0, pal_num_cat)

if num_cat > pal_num_cat:
print(
"WARNING: {} dataframe has more than 20 categories".format(label),
"WARNING: {} dataframe has more than {} categories".format(
label, pal_num_cat
),
file=sys.stderr,
)
# Determine subcategories
pal_num_sub_cat = math.ceil(num_cat / pal_num_cat)
custom_cmap_i = np.linspace(0.0, 1.0, num_cat)

legend_ncol = 1
if len(plot_df.columns) > 10:
legend_ncol = 2
df_cmap = categorical_cmap(pal_num_cat, pal_num_sub_cat, cmap=pal)(
custom_cmap_i
)

custom_cmap_i = np.linspace(0.0, 1.0, 20)
df_cmap = cm.get_cmap("tab20")(custom_cmap_i)
# df_cmap = cm.get_cmap(pal)(custom_cmap_i)

# The df is sorted by time (epiweek)
# But we want colors to be sorted by number of sequences
Expand Down Expand Up @@ -374,12 +417,17 @@ def main(
else:
ax.set_ylim(0, round(max_epiweek_sequences * EPIWEEK_MAX_BUFF_FACTOR, 1))

# small df: upper right
# large df: upper left
legend_loc = "upper right"
if len(plot_df) > 16:
legend_loc = "upper left"
legend = ax.legend(
title=legend_title.title(),
edgecolor="black",
fontsize=8,
ncol=legend_ncol,
loc="upper right",
loc=legend_loc,
)

legend.get_frame().set_linewidth(1)
Expand Down
60 changes: 53 additions & 7 deletions scripts/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@
@click.option(
"--changelog", help="Markdown changelog", required=False, default="CHANGELOG.md"
)
@click.option(
"--singletons",
help="Whether singletons were included in plots",
is_flag=True,
)
@click.option(
"--template",
help="Powerpoint template",
Expand All @@ -65,6 +70,7 @@ def main(
geo,
changelog,
output,
singletons,
):
"""Create a report of powerpoint slides"""

Expand Down Expand Up @@ -163,7 +169,6 @@ def main(
if status.lower() in RECOMBINANT_STATUS:
regex = RECOMBINANT_STATUS[status.lower()]
for lineage in lineages:
print(status, regex, lineage, re.match(regex, lineage))
if re.match(regex, lineage):

seq_count = sum(lineage_df[lineage].dropna())
Expand All @@ -186,9 +191,15 @@ def main(
body = slide.placeholders[2]

summary = "\n"
summary += "There are {num_lineages} recombinant lineages.\n".format(
num_lineages=num_lineages
)
if singletons:
summary += "There are {num_lineages} recombinant lineages.\n".format(
num_lineages=num_lineages
)
else:
summary += "There are {num_lineages} recombinant lineages*.\n".format(
num_lineages=num_lineages
)

for status in RECOMBINANT_STATUS:
if status in status_counts:
count = status_counts[status]["lineages"]
Expand All @@ -198,9 +209,16 @@ def main(
lineages=count, status=status
)
summary += "\n"
summary += "There are {num_sequences} recombinant sequences.\n".format(
num_sequences=num_sequences
)
# Whether we need a footnote for singletons
if singletons:
summary += "There are {num_sequences} recombinant sequences.\n".format(
num_sequences=num_sequences
)
else:
summary += "There are {num_sequences} recombinant sequences*.\n".format(
num_sequences=num_sequences
)

for status in RECOMBINANT_STATUS:
if status in status_counts:
count = status_counts[status]["sequences"]
Expand All @@ -209,6 +227,9 @@ def main(
summary += " - {sequences} sequences are {status}.\n".format(
sequences=count, status=status
)
if not singletons:
summary += "\n"
summary += "*Excluding singleton lineages (N=1)"

body.text_frame.text = summary

Expand All @@ -224,6 +245,10 @@ def main(
geo_df = plot_dict["geography"]["df"]
geos = list(geo_df.columns)
geos.remove("epiweek")
# Order columns
geos_counts = {region: int(sum(geo_df[region])) for region in geos}
geos = sorted(geos_counts, key=geos_counts.get, reverse=True)

num_geos = len(geos)

graph_slide_layout = presentation.slide_layouts[8]
Expand Down Expand Up @@ -263,6 +288,14 @@ def main(

designated_lineages = list(designated_df.columns)
designated_lineages.remove("epiweek")

# Order columns
designated_counts = {
lineage: int(sum(designated_df[lineage])) for lineage in designated_lineages
}
designated_lineages = sorted(
designated_counts, key=designated_counts.get, reverse=True
)
num_designated = len(designated_lineages)

graph_slide_layout = presentation.slide_layouts[8]
Expand Down Expand Up @@ -302,6 +335,15 @@ def main(

largest_geos = list(largest_df.columns)
largest_geos.remove("epiweek")

# Order columns
largest_geos_counts = {
region: int(sum(largest_df[region])) for region in largest_geos
}
largest_geos = sorted(
largest_geos_counts, key=largest_geos_counts.get, reverse=True
)

num_geos = len(largest_geos)

largest_lineage = plot_dict["largest"]["lineage"]
Expand Down Expand Up @@ -358,6 +400,10 @@ def main(

parents = list(parents_df.columns)
parents.remove("epiweek")

parents_counts = {p: int(sum(parents_df[p])) for p in parents}
parents = sorted(parents_counts, key=parents_counts.get, reverse=True)

num_parents = len(parents)

graph_slide_layout = presentation.slide_layouts[8]
Expand Down
12 changes: 6 additions & 6 deletions scripts/usher_collapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,16 @@ def json_get_strains(json_tree):
@click.option("--indir", help="Input directory of subtrees.", required=True)
@click.option("--outdir", help="Output directory for collapsed trees.", required=True)
@click.option("--log", help="Logfile.", required=False)
@click.option(
"--duplicate-col",
help="Label duplicate sequences based on the ID in this column.",
required=False,
)
# @click.option(
# "--duplicate-col",
# help="Label duplicate sequences based on the ID in this column.",
# required=False,
# )
def main(
indir,
outdir,
log,
duplicate_col,
# duplicate_col,
):
"""Collect and condense UShER subtrees"""

Expand Down
8 changes: 5 additions & 3 deletions workflow/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ rule usher_subtree_collapse:
"logs/{rule}/{{build}}_{today}.log".format(today=today, rule=rule_name),
shell:
"""
python3 scripts/usher_collapse.py --indir {input.subtrees_dir} --outdir {output.collapse_dir} --log {log} >> {log} 2>&1 ;
python3 scripts/usher_collapse.py --indir {input.subtrees_dir} --outdir {output.collapse_dir} --log {log};
"""

# -----------------------------------------------------------------------------#
Expand Down Expand Up @@ -1249,6 +1249,7 @@ rule report:
geo = lambda wildcards: _params_report(wildcards.build)["geo"],
changelog = lambda wildcards: _params_report(wildcards.build)["changelog"],
template = lambda wildcards: _params_report(wildcards.build)["template"],
singletons = lambda wildcards: _params_plot(wildcards.build)["singletons"],

threads: 1
resources:
Expand All @@ -1263,7 +1264,7 @@ rule report:
csvtk csv2xlsx -t -o {output.xlsx} {input.tables} > {log} 2>&1;
# Create the powerpoint slides
python3 scripts/report.py --plot-dir {input.plots} --output {output.pptx} {params.geo} {params.changelog} {params.template} >> {log} 2>&1;
python3 scripts/report.py --plot-dir {input.plots} --output {output.pptx} {params.geo} {params.changelog} {params.template} {params.singletons} >> {log} 2>&1;
"""

rule report_historical:
Expand All @@ -1283,6 +1284,7 @@ rule report_historical:
geo = lambda wildcards: _params_report(wildcards.build)["geo"],
changelog = lambda wildcards: _params_report(wildcards.build)["changelog"],
template = lambda wildcards: _params_report(wildcards.build)["template"],
singletons = lambda wildcards: _params_plot(wildcards.build)["singletons"],

threads: 1
resources:
Expand All @@ -1294,5 +1296,5 @@ rule report_historical:
shell:
"""
# Create the powerpoint slides
python3 scripts/report.py --plot-dir {input.plots} --output {output.pptx} {params.geo} {params.changelog} {params.template} > {log} 2>&1;
python3 scripts/report.py --plot-dir {input.plots} --output {output.pptx} {params.geo} {params.changelog} {params.template} {params.singletons} > {log} 2>&1;
"""

0 comments on commit d1ccca2

Please sign in to comment.