# Clock Filter

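This notebook identifies collection-date outliers in GISAID SARS-CoV-2 metadata: samples whose collection date falls before the inferred date of their clade's most recent common ancestor (MRCA).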

## Setup

### Libraries

In [None]:
import pandas as pd
import copy
import numpy as np
import json
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime, timedelta
import os

### Variables

In [None]:
# Reference tree (JSON) from which clade MRCA dates are extracted
tree_path = "../data/ncov_gisaid_reference_2022-08-22.json"

# Placeholder written into empty metadata cells
NO_DATA_CHAR = "NA"
# Reference strain; it is removed from the metadata before filtering
REFERENCE_STRAIN = "Wuhan/Hu-1/2019"
# Samples collected before this date are discarded
MIN_DATE = "2020-01-01"

# Exclude a strain if any of these QC status columns is 'bad'
BAD_QUALITY_COLS = [
    "qc.missingData.status",
    "qc.mixedSites.status",
    "qc.frameShifts.status",
    "qc.stopCodons.status",
]

# Emit SVG text as text (not paths) and use white backgrounds everywhere
plt.rcParams.update({
    "svg.fonttype":      "none",
    "figure.facecolor":  "white",
    "axes.facecolor":    "white",
    "savefig.facecolor": "white",
})

#### GISAID Parameters

In [None]:
# Input metadata table and the directory that receives every output file
metadata_path = "../results/minimal.tsv"
outdir = "../results"

In [None]:
# Create the output directory (and any missing parents). exist_ok=True makes
# this a no-op when it already exists, without a separate, racy exists() check.
os.makedirs(outdir, exist_ok=True)

### Functions

In [None]:
def extract_clade_mrca(tree_data):
    """
    Collect the inferred MRCA date of every clade annotated in a tree.

    A node whose ``branch_attrs["labels"]`` contains a ``"clade"`` key
    (e.g. ``'labels': {'clade': '19A'}``) is treated as the MRCA of that clade.
    Nodes are visited in pre-order (a parent before its children, children
    left to right), the same order as a recursive depth-first walk.

    An explicit stack is used instead of recursion so very deep trees cannot
    hit Python's recursion limit. A node lacking ``branch_attrs`` or
    ``labels`` is simply treated as carrying no clade label.

    Parameters
    ----------
    tree_data : dict
        Root node of the tree (the ``"tree"`` entry of the JSON file).

    Returns
    -------
    list of list
        One ``[clade, tmrca, ci_low, ci_high]`` entry per labelled node, taken
        from the node's ``node_attrs["num_date"]`` value and confidence bounds.
    """
    tmrca_data = []
    stack = [tree_data]

    while stack:
        node = stack.pop()

        labels = node.get("branch_attrs", {}).get("labels", {})
        if "clade" in labels:
            # Dates are in node attributes
            num_date = node["node_attrs"]["num_date"]
            tmrca_ci = num_date["confidence"]
            tmrca_data.append([labels["clade"], num_date["value"], tmrca_ci[0], tmrca_ci[1]])

        # Push children reversed so the leftmost child is popped (visited) next
        stack.extend(reversed(node.get("children", [])))

    return tmrca_data

In [None]:
def decimal_date_to_datetime(decimal_date):
    """
    Convert a decimal year (e.g. 2021.5) into a datetime at midnight.

    The fractional part of the year is scaled by the length of that calendar
    year (365 or 366 days) and added to January 1st; the time of day is then
    discarded so only the calendar day remains.

    Credit: Jon Clements
    Link: https://stackoverflow.com/a/20911144
    """
    whole_year = int(decimal_date)
    fraction = decimal_date - whole_year

    jan_first = datetime(whole_year, 1, 1)
    seconds_in_year = (datetime(whole_year + 1, 1, 1) - jan_first).total_seconds()

    moment = jan_first + timedelta(seconds=seconds_in_year * fraction)

    # exclude time: rebuild the datetime from its calendar fields only
    return datetime(moment.year, moment.month, moment.day)

## Dataframes

### Metadata

In [None]:
metadata_df = pd.read_csv(metadata_path, sep="\t")

# Filter for complete dates (YYYY-MM-DD). na=False counts missing dates as
# incomplete; without it, NaN in the mask makes the boolean indexing fail.
has_full_date = metadata_df["date"].str.match("[0-9]{4}-[0-9]{2}-[0-9]{2}", na=False)

# Convert string dates to datetime objects. Incomplete dates are masked out
# first, and well-formed but impossible dates (e.g. "2021-02-30") become NaT
# via errors="coerce" instead of aborting the whole run.
collection_dates = pd.to_datetime(
    metadata_df["date"].where(has_full_date),
    format="%Y-%m-%d",
    errors="coerce",
)
keep = collection_dates.notna()

# .copy() yields an independent frame, so adding the datetime column below
# does not write into a view of the unfiltered table (SettingWithCopyWarning)
metadata_df = metadata_df[keep].copy()
metadata_df["datetime"] = collection_dates[keep]
metadata_df.fillna(NO_DATA_CHAR, inplace=True)

### Filter

#### 1. Date Range

In [None]:
# Drop samples collected before MIN_DATE
metadata_df = metadata_df.loc[metadata_df["datetime"].ge(np.datetime64(MIN_DATE))]

#### 2. Host

In [None]:
# Keep human samples only
metadata_df = metadata_df.loc[metadata_df["host"].eq("Human")]

#### 3. Reference Strain

In [None]:
# Drop the reference strain itself
metadata_df = metadata_df.loc[metadata_df["strain"].ne(REFERENCE_STRAIN)]

#### 4. Genome Quality

In [None]:
# Successively drop rows flagged 'bad' in each QC status column
for quality_col in BAD_QUALITY_COLS:
    print(f"Filter: {quality_col}")
    metadata_df = metadata_df.loc[metadata_df[quality_col].ne("bad")]

### Nextstrain MRCA

#### Parse Tree JSON

In [None]:
# JSON exchanged between systems is UTF-8 (RFC 8259); open it explicitly as
# such rather than relying on the platform's default text encoding
with open(tree_path, encoding="utf-8") as infile:
    json_data = json.load(infile)
tree_data = json_data["tree"]

# Creates a list of [clade, tmrca, ci_low, ci_high] lists, one per labelled clade
mrca_data = extract_clade_mrca(tree_data)

#### Convert to Dataframe

In [None]:
# First, convert list of lists to dictionary
mrca_dict = {
    "clade" : [],
    "inferred_date" : [],
    "ci_low" : [],
    "ci_high" : [],
}
for clade_data in mrca_data:
    clade = clade_data[0]
    
    # Convert decimal dates (2021.5) to datetime objects    
    inferred_date = decimal_date_to_datetime(clade_data[1])
    ci_low = decimal_date_to_datetime(clade_data[2])
    ci_high = decimal_date_to_datetime(clade_data[3])
    
    mrca_dict["clade"].append(clade)
    mrca_dict["inferred_date"].append(inferred_date)
    mrca_dict["ci_low"].append(ci_low)
    mrca_dict["ci_high"].append(ci_high)
    
mrca_df = pd.DataFrame(mrca_dict)
mrca_df = mrca_df.sort_values(by="clade")
outpath = os.path.join(outdir, "clade_mrca.tsv")
mrca_df.to_csv(outpath, sep="\t", index=False)

mrca_df

## Clock Filter

In [None]:
clade_data = {}

# Iterate through the clades observed in the metadata. Sorting (by string
# form) makes the processing order, and therefore the order of the exclusion
# list and of the plots, reproducible; iterating a plain set of strings can
# change between runs because of string hash randomization.
for clade in sorted(set(metadata_df["clade"]), key=str):

    # Skip over samples missing a clade assignment (i.e. couldn't be aligned)
    # Or recombinants, which is not a true clade
    if clade in (NO_DATA_CHAR, "recombinant"):
        continue

    print("Filtering clade:", clade)

    # Independent copy, so adding the date_filter column never touches metadata_df
    clade_metadata = metadata_df[metadata_df["clade"] == clade].copy()
    clade_mrca     = mrca_df[mrca_df["clade"] == clade]

    # Classify samples by collection date relative to the clade's MRCA CI:
    #   collected before ci_low              -> "Fail"
    #   collected on or after ci_high        -> "Pass"
    #   in between, or no MRCA for the clade -> "Undetermined"
    clade_metadata["date_filter"] = "Undetermined"
    if len(clade_mrca) > 0:
        mrca_ci_low  = clade_mrca["ci_low"].values[0]
        mrca_ci_high = clade_mrca["ci_high"].values[0]
        clade_metadata.loc[clade_metadata["datetime"] < mrca_ci_low, "date_filter"] = "Fail"
        clade_metadata.loc[clade_metadata["datetime"] >= mrca_ci_high, "date_filter"] = "Pass"

    # Get counts for each category, just in case we want to use this in a figure caption or title
    date_filter      = clade_metadata["date_filter"]
    num_undetermined = int((date_filter == "Undetermined").sum())
    num_fail         = int((date_filter == "Fail").sum())
    num_pass         = int((date_filter == "Pass").sum())

    # Remove duplicate points for the scatter plot. The full clade_metadata is
    # stored too, because the marginal distributions need the duplicates.
    clade_metadata_minimal = clade_metadata[["datetime", "totalSubstitutions", "date_filter"]].drop_duplicates()

    # Store the stats in the clade data dict
    clade_data[clade] = {
        "clade_metadata":         clade_metadata,
        "clade_metadata_minimal": clade_metadata_minimal,
        "clade_mrca":             clade_mrca,
        "num_undetermined":       num_undetermined,
        "num_fail":               num_fail,
        "num_pass":               num_pass,
    }

## Exclusion List

In [None]:
# Collect every strain whose collection date failed its clade's clock filter
exclude_strains = []

for clade in clade_data:
    clade_metadata = clade_data[clade]["clade_metadata"]
    failed = clade_metadata["date_filter"] == "Fail"
    # extend() in place instead of list + list, which re-copies the growing list per clade
    exclude_strains.extend(clade_metadata.loc[failed, "strain"])

outpath = os.path.join(outdir, "exclude.clock.txt")
with open(outpath, "w", encoding="utf-8") as outfile:
    # One newline-terminated strain per line. An empty list yields an empty
    # file rather than a file containing a single blank line.
    outfile.writelines(strain + "\n" for strain in exclude_strains)

## Plot

In [None]:
figsize = [12,6]
dpi = 200

# For testing: cap on rows passed to each plot (effectively unlimited by default)
num_rows=1000000000

for clade in clade_data:

    print("Plotting clade:", clade)

    # Set up a joint plot (central scatter with marginal distributions)
    plot = sns.JointGrid()

    # Plot the scatter in the central axis
    # This uses the minimal data, because we ignore scatter density
    sns.scatterplot(
        ax   = plot.ax_joint,
        data = clade_data[clade]["clade_metadata_minimal"].head(num_rows),
        x    = "datetime",
        y    = "totalSubstitutions",
        s    = 10,
        ec   = "none",
        hue  = "date_filter",
        palette = {"Pass" : "green", "Fail" : "red", "Undetermined": "grey",},
        alpha = 0.75,
        zorder = 2,
        rasterized = True,
    )

    # Collection Dates: Plot the marginal distribution
    # This uses the full data, because duplicates are important
    sns.kdeplot(
        ax   = plot.ax_marg_x,
        data = clade_data[clade]["clade_metadata"].head(num_rows),
        x    = "datetime",
        color = "black",
        fill = True,
        alpha = 0.5,
    )

    # Substitutions: Plot the marginal distribution
    # This uses the full data, because duplicates are important
    sns.kdeplot(
        ax   = plot.ax_marg_y,
        data = clade_data[clade]["clade_metadata"].head(num_rows),
        y    = "totalSubstitutions",
        color = "black",
        fill = True,
        alpha = 0.5,
    )

    # MRCA Dates: solid line for the point estimate, dashed lines for the CI.
    # axvline expects a scalar x, so read the values from the (single) MRCA
    # row instead of passing length-1 Series columns.
    clade_mrca = clade_data[clade]["clade_mrca"]
    if len(clade_mrca) > 0:
        mrca = clade_mrca.iloc[0]
        plot.ax_joint.axvline(mrca["inferred_date"], color="black", linewidth=1, label="MRCA", zorder = 1)
        plot.ax_joint.axvline(mrca["ci_low"], color="grey", linestyle="--", linewidth=1, label="MRCA (95% CI)", zorder = 1)
        plot.ax_joint.axvline(mrca["ci_high"], color="grey", linestyle="--", linewidth=1, zorder = 1)

    # Axis Labels
    plot.ax_joint.set_xlabel("Collection Date", fontweight="bold")
    plot.ax_joint.set_ylabel("Total Substitutions", fontweight="bold")

    # Axis Ticks: a major tick every other month, labels rotated vertically
    plot.ax_joint.xaxis.set_major_locator(mdates.MonthLocator(bymonth=[1,3,5,7,9,11]))
    for label in plot.ax_joint.get_xticklabels(which='major'):
        label.set(rotation=90, horizontalalignment='center')

    # Legend
    legend = plot.ax_joint.legend(title=clade)
    legend.get_frame().set_edgecolor("black")
    legend.get_title().set_fontweight("bold")

    # Dimensions and Resolution
    plot.fig.set_figwidth(figsize[0])
    plot.fig.set_figheight(figsize[1])
    plot.fig.set_dpi(dpi)
    plot.fig.tight_layout()

    # Export this grid's own figure explicitly rather than whichever figure
    # pyplot currently considers "current"
    outpath = os.path.join(outdir, "{clade}_clock_filter.png".format(
        clade=clade.replace(" ","_")
    ))
    plot.fig.savefig(outpath)

    # Close figure, so we don't get memory warnings
    plt.close("all")