Commit

Add workflow to make variant prevalence plot
corneliusroemer committed Oct 3, 2023
1 parent 3790aa3 commit bb3d0f6
Showing 4 changed files with 156 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,4 +1,4 @@
-.snakemake/*
+.snakemake/
builds/*
auspice/*
data/
@@ -7,3 +7,4 @@ deploy/*
.DS_Store
figures/
.swp
+results/
34 changes: 34 additions & 0 deletions variant-prevalence/Snakefile
@@ -0,0 +1,34 @@
rule all:
    input:
        "results/weekly-variant-counts-south-africa.tsv",


rule download_metadata:
    output:
        "results/metadata.tsv.zst",
    shell:
        """
        aws s3 cp s3://nextstrain-ncov-private/metadata.tsv.zst {output}
        """


# Requires zstd and tsv-utils (tsv-filter)
rule filter_south_africa:
    input:
        "results/metadata.tsv.zst",
    output:
        "results/metadata-south-africa.tsv",
    shell:
        """
        zstdcat {input} | tsv-filter -H --str-in-fld country:"South Africa" > {output}
        """


rule classify_variants:
    input:
        "results/metadata-south-africa.tsv",
    output:
        "results/weekly-variant-counts-south-africa.tsv",
    shell:
        """
        # Input and output paths are hard-coded inside the script
        python3 scripts/classify_variants.py
        """
6 changes: 6 additions & 0 deletions variant-prevalence/profiles/default/config.yaml
@@ -0,0 +1,6 @@
cores: all
reason: true
printshellcmds: true
keep-going: true
rerun-incomplete: true

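With a profile directory like this, the workflow would typically be launched from the variant-prevalence directory with:

snakemake --profile profiles/default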
114 changes: 114 additions & 0 deletions variant-prevalence/scripts/classify_variants.py
@@ -0,0 +1,114 @@
#%%
"""
Reads in TSV with each row being singl observation
- Classifies each sample as belonging to a variant
- Aggregates into weekly counts
"""
import numpy as np
import pandas as pd
from pango_aliasor.aliasor import Aliasor
aliasor = Aliasor()
#%%
df = pd.read_csv('results/metadata-south-africa.tsv', sep='\t')
df

#%%
# Parse clock_deviation as a float; non-numeric values become NaN
def safe_float(x):
    try:
        return float(x)
    except (TypeError, ValueError):
        return np.nan

df["clock_deviation"] = df["clock_deviation"].apply(safe_float)

# Remove samples with NaN in Nextclade_pango, clade_nextstrain or clock_deviation
df.dropna(subset=["Nextclade_pango", "clade_nextstrain", "clock_deviation"], inplace=True)
df

#%%
# Keep samples whose clock deviation is between -10 and 20 (inclusive)
df = df[(df.clock_deviation >= -10) & (df.clock_deviation <= 20)]

#%%
# Overwrite Nextclade_pango with the unaliased (uncompressed) Pango lineage
df["Nextclade_pango"] = df["Nextclade_pango"].apply(aliasor.uncompress)
df
#%%
# Variant mapping
VARIANT_MAP = {
    "Beta": df.clade_nextstrain.isin(["20H"]),
    "Alpha": df.clade_nextstrain.isin(["20I"]),
    "Delta": df.clade_nextstrain.isin(["21A", "21I", "21J"]),
    "BA.1": df.clade_nextstrain.isin(["21K", "21M"]),
    "BA.2": df.clade_nextstrain.isin(["21L", "22C"]),
    "BA.2.75": df.clade_nextstrain.isin(["22D", "23C"]),
    "BA.4/5": df.clade_nextstrain.isin(["22A", "22B", "22E"]),
    "XBB": df.Nextclade_pango.str.startswith("XBB"),
    "BA.2.86": df.Nextclade_pango.str.startswith("B.1.1.529.2.86"),
}
df["variant"] = "Other"
for variant, mask in VARIANT_MAP.items():
    df.loc[mask, "variant"] = variant
df
# %%

# Bucket each date into a weekly period (anchored on Monday, "W-MON") and format as YYYY-MM-DD
df["week"] = pd.to_datetime(df.date).dt.to_period("W-MON").dt.strftime("%Y-%m-%d")
df
# %%
# Aggregate counts by week in wide format
wide = df.groupby(["week", "variant"]).size().unstack().fillna(0).astype(int)
#%%
# Add total column
wide["total"] = wide.sum(axis=1)

# Add name of most common variant column
wide["most_common"] = wide.drop(columns=["total"]).idxmax(axis=1)

# Add count of most common variant column
wide["most_common_count"] = wide.drop(columns=["total"]).max(axis=1)

# Add percentage of most common variant column
wide["most_common_pct"] = 100*wide["most_common_count"] / wide["total"]

# Add per-variant percentage columns
for variant in chain(VARIANT_MAP.keys(), ["Other"]):
    wide[f"{variant}_pct"] = wide[variant] / wide["total"] * 100


# %%
COLUMNS = [
    "total",
    "most_common",
    "most_common_count",
    "most_common_pct",
    "Other",
    "Beta",
    "Alpha",
    "Delta",
    "BA.1",
    "BA.2",
    "BA.4/5",
    "BA.2.75",
    "XBB",
    "BA.2.86",
    "Other_pct",
    "Beta_pct",
    "Alpha_pct",
    "Delta_pct",
    "BA.1_pct",
    "BA.2_pct",
    "BA.4/5_pct",
    "BA.2.75_pct",
    "XBB_pct",
    "BA.2.86_pct",
]
# Write weekly counts, formatting floats with one decimal place
wide.to_csv("results/weekly-variant-counts-south-africa.tsv", sep="\t", columns=COLUMNS, float_format="%.1f")

# %%
# Make plot of all the pct columns
# wide[[col for col in wide.columns if (col.endswith("_pct") and col != "most_common_pct")]].plot.line(figsize=(12, 6), title="SARS-CoV-2 variant proportions in South Africa")
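# A sketch for saving the figure to disk rather than only displaying it interactively
# (assumes matplotlib is installed; the output filename below is illustrative):
# ax = wide[[col for col in wide.columns if col.endswith("_pct") and col != "most_common_pct"]].plot.line(
#     figsize=(12, 6), title="SARS-CoV-2 variant proportions in South Africa"
# )
# ax.figure.savefig("results/variant-prevalence-south-africa.png", dpi=200)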

# %%
