Skip to content

Commit

Permalink
script: add breakpoints as a plot and report slide
Browse files Browse the repository at this point in the history
  • Loading branch information
Katherine Eaton authored and ktmeaton committed Jun 28, 2022
1 parent e988f25 commit 869f3b4
Show file tree
Hide file tree
Showing 2 changed files with 340 additions and 2 deletions.
321 changes: 320 additions & 1 deletion scripts/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import pandas as pd
import numpy as np
import epiweeks
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import patches, colors
from matplotlib import patches, colors, lines
from datetime import datetime, timedelta
import sys
import copy
Expand All @@ -21,6 +22,27 @@
DPI = 96 * 2
FIGSIZE = [6.75, 5.33]

# Breakpoint Plotting
GENOME_LENGTH = 29903
BREAKPOINT_COLOR = "lightgrey"
X_BUFF = 1000
BREAKPOINT_COLOR = "lightgrey"

# Select and rename columns from linelist
LINEAGES_COLS = [
"cluster_id",
"status",
"lineage",
"parents",
"breakpoints",
"issue",
"subtree",
"sequences",
"growth_score",
"earliest_date",
"latest_date",
]


def categorical_cmap(nc, nsc, cmap="tab20", continuous=False):
"""
Expand Down Expand Up @@ -444,6 +466,303 @@ def main(
plt.savefig(out_path + ".png")
plt.savefig(out_path + ".svg")

# -------------------------------------------------------------------------
# Breakpoints
# -------------------------------------------------------------------------

# Create the lineages dataframe
lineages_data = {col: [] for col in LINEAGES_COLS}
lineages_data[geo] = []

for cluster_id in set(df["cluster_id"]):

match_df = df[df["cluster_id"] == cluster_id]

earliest_date = min(match_df["datetime"])
latest_date = max(match_df["datetime"])
sequences = len(match_df)

# TBD majority vote on disagreement
lineages_data["cluster_id"].append(cluster_id)
lineages_data["status"].append(match_df["status"].values[0])
lineages_data["lineage"].append(match_df["lineage"].values[0])
lineages_data["parents"].append(match_df["parents"].values[0])
lineages_data["breakpoints"].append(match_df["breakpoints"].values[0])
lineages_data["issue"].append(match_df["issue"].values[0])
lineages_data["subtree"].append(match_df["subtree"].values[0])

lineages_data["sequences"].append(sequences)
lineages_data["earliest_date"].append(earliest_date)
lineages_data["latest_date"].append(latest_date)

geo_list = list(set(match_df[geo]))
geo_list.sort()
geo_counts = []
for loc in geo_list:
loc_df = match_df[match_df[geo] == loc]
num_sequences = len(loc_df)
geo_counts.append("{} ({})".format(loc, num_sequences))

lineages_data[geo].append(", ".join(geo_counts))

# Growth Calculation
growth_score = 0
duration = (latest_date - earliest_date).days + 1
growth_score = round(sequences / duration, 2)
lineages_data["growth_score"].append(growth_score)

lineages_df = pd.DataFrame(lineages_data)
lineages_df.sort_values(by="sequences", ascending=False, inplace=True)

parent_colors = {}

# Create a dataframe to hold plot data
# Lineage (y axis, categorical)
#

breakpoints_data = {
"lineage": [],
"parent": [],
"start": [],
"end": [],
}

# Store data for plotting breakpoint distributions
breakpoints_dist_data = {
"coordinate": [],
"parent": [],
}

# -------------------------------------------------------------------------
# Create a dataframe to hold plot data

for rec in lineages_df.iterrows():
lineage = rec[1]["lineage"]
cluster_id = rec[1]["cluster_id"]

lineage = "{} {}".format(lineage, cluster_id)

parents = rec[1]["parents"]
breakpoints = rec[1]["breakpoints"]

parents_split = parents.split(",")
breakpoints_split = breakpoints.split(",")

prev_start_coord = 0

for i in range(0, len(parents_split)):
parent = parents_split[i]
if parent not in parent_colors:
parent_colors[parent] = ""

if i < (len(parents_split) - 1):
breakpoint = breakpoints_split[i]
breakpoint_start_coord = int(breakpoint.split(":")[0])
breakpoint_end_coord = int(breakpoint.split(":")[1])
breakpoint_mean_coord = round(
(breakpoint_start_coord + breakpoint_end_coord) / 2
)

# Give this coordinate to both parents
parent_next = parents_split[i + 1]
breakpoints_dist_data["parent"].append(parent)
breakpoints_dist_data["parent"].append(parent_next)
breakpoints_dist_data["coordinate"].append(breakpoint_mean_coord)
breakpoints_dist_data["coordinate"].append(breakpoint_mean_coord)

start_coord = prev_start_coord
end_coord = int(breakpoint.split(":")[0]) - 1
# Update start coord
prev_start_coord = int(breakpoint.split(":")[1]) + 1

# Add record for breakpoint
breakpoints_data["lineage"].append(lineage)
breakpoints_data["parent"].append("breakpoint")
breakpoints_data["start"].append(breakpoint_start_coord)
breakpoints_data["end"].append(breakpoint_end_coord)

else:
start_coord = prev_start_coord
end_coord = GENOME_LENGTH

# Add record for parent
breakpoints_data["lineage"].append(lineage)
breakpoints_data["parent"].append(parent)
breakpoints_data["start"].append(start_coord)
breakpoints_data["end"].append(end_coord)

# Convert the dictionary to a dataframe
breakpoints_df = pd.DataFrame(breakpoints_data)

# Sort by coordinates
breakpoints_df.sort_values(by=["parent", "start", "end"], inplace=True)

# -------------------------------------------------------------------------
# Colors

# tab10/Set1 should be a safe palette for now

i = 0
for parent in parent_colors:
color_rgb = plt.cm.Set1.colors[i]
color = colors.to_hex(color_rgb)
i += 1

parent_colors[parent] = color

parent_colors["breakpoint"] = BREAKPOINT_COLOR

# -------------------------------------------------------------------------
# Plot Setup

fig, axes = plt.subplots(
2,
1,
dpi=DPI,
figsize=FIGSIZE,
gridspec_kw={"height_ratios": [1, 5]},
sharex=True,
)

# -------------------------------------------------------------------------
# Plot Breakpoint Distribution

ax = axes[0]

breakpoints_dist_df = pd.DataFrame(breakpoints_dist_data)

sns.kdeplot(
ax=ax,
data=breakpoints_dist_df,
x="coordinate",
bw_adjust=0.3,
hue="parent",
palette=parent_colors,
multiple="stack",
fill=True,
)

ax.set_yticks([])
ax.set_ylabel("")
ax.legend().remove()

for spine in ax.spines:
ax.spines[spine].set_visible(False)

# -------------------------------------------------------------------------
# Plot Breakpoint Regions

ax = axes[1]

rect_height = 1
start_y = 0
y_buff = 1
y = start_y
y_increment = rect_height + y_buff
y_tick_locs = []
y_tick_labs_lineage = []
y_tick_labs_cluster = []

num_lineages = len(set(breakpoints_df["lineage"]))
lineages_seen = []

# Iterate through lineages to plot
for rec in breakpoints_df.iterrows():
lineage = rec[1]["lineage"]
if lineage in lineages_seen:
continue
lineages_seen.append(lineage)

y_tick_locs.append(y + (rect_height / 2))
lineage_label = lineage.split(" ")[0]
cluster_id_label = lineage.split(" ")[1]
y_tick_labs_lineage.append(lineage_label)
y_tick_labs_cluster.append(cluster_id_label)

lineage_df = breakpoints_df[breakpoints_df["lineage"] == lineage]

# Iterate through regions to plot
for rec in lineage_df.iterrows():
parent = rec[1]["parent"]
start = rec[1]["start"]
end = rec[1]["end"]

color = parent_colors[parent]

region_rect = patches.Rectangle(
xy=[start, y],
width=end - start,
height=rect_height,
linewidth=1,
edgecolor="none",
facecolor=color,
)
ax.add_patch(region_rect)

# Jump to the next y coordinate
y -= y_increment

# Axis Limits
ax.set_xlim(0 - X_BUFF, GENOME_LENGTH + X_BUFF)
ax.set_ylim(
0 - ((num_lineages * y_increment) - (y_increment / 2)), 0 + (rect_height * 2)
)

# This is the default fontisze to use
y_tick_fontsize = 10
if num_lineages >= 20:
y_tick_fontsize = 8
if num_lineages >= 30:
y_tick_fontsize = 6
if num_lineages >= 40:
y_tick_fontsize = 4

# Axis ticks
ax.set_yticks(y_tick_locs)
ax.set_yticklabels(y_tick_labs_lineage, fontsize=y_tick_fontsize)

# ax2 = ax.twinx()
# ax2.set_yticks(y_tick_locs)
# ax2.set_yticklabels(y_tick_labs_cluster, fontsize=y_tick_fontsize)

# Axis Labels
ax.set_ylabel("Lineage")
ax.set_xlabel("Genomic Coordinate")

# -------------------------------------------------------------------------
# Manually create legend

legend_handles = []
legend_labels = []

for parent in parent_colors:
handle = lines.Line2D([0], [0], color=parent_colors[parent], lw=4)
label = parent.title()
legend_handles.append(handle)
legend_labels.append(label)

legend = ax.legend(
handles=legend_handles,
labels=legend_labels,
title="Parent",
bbox_to_anchor=[1.01, 1.01],
)
frame = legend.get_frame()
frame.set_linewidth(1)
frame.set_edgecolor("black")
frame.set_boxstyle("Square", pad=0.2)

# -------------------------------------------------------------------------
# Export

plt.suptitle("Recombination Breakpoints by Lineage")
plt.tight_layout()
plt.subplots_adjust(hspace=0)
outpath = os.path.join(outdir, "breakpoints")
breakpoints_df.to_csv(outpath + ".tsv", sep="\t", index=False)
plt.savefig(outpath + ".png")
plt.savefig(outpath + ".svg")


if __name__ == "__main__":
main()
21 changes: 20 additions & 1 deletion scripts/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ def main(
plot_dict[label]["plot_path"] = os.path.join(plot_dir, label + plot_suffix)
plot_dict[label]["df_path"] = os.path.join(plot_dir, label + df_suffix)
plot_dict[label]["df"] = pd.read_csv(plot_dict[label]["df_path"], sep="\t")
plot_dict[label]["df"].index = plot_dict[label]["df"]["epiweek"]

# Breakpoints df isn't over time, but by lineage
if "epiweek" in plot_dict[label]["df"].columns:
plot_dict[label]["df"].index = plot_dict[label]["df"]["epiweek"]

# Largest is special, as it takes the form largest_<lineage>_<cluster_id>.*
if label.startswith("largest_"):
Expand Down Expand Up @@ -435,6 +438,22 @@ def main(
for run in paragraph.runs:
run.font.size = pptx.util.Pt(14)

# ---------------------------------------------------------------------
# Breakpoints Summary

plot_path = plot_dict["breakpoints"]["plot_path"]

graph_slide_layout = presentation.slide_layouts[8]
slide = presentation.slides.add_slide(graph_slide_layout)
title = slide.shapes.title

title.text_frame.text = "Breakpoints"
title.text_frame.paragraphs[0].font.bold = True

chart_placeholder = slide.placeholders[1]
chart_placeholder.insert_picture(plot_path)
body = slide.placeholders[2]

# ---------------------------------------------------------------------
# Changelog
text_slide_layout = presentation.slide_layouts[1]
Expand Down

0 comments on commit 869f3b4

Please sign in to comment.