# Objectives

1. Plot Locus Coverage of Whole Genome.
1. Add Locus Coverage to Dataframe.
1. Plot Locus Coverage: Distribution.
1. Plot Locus Coverage: Phylogeny.
1. Identify Interesting Samples.

---
# Setup

## Imports

In [1]:
import os
import subprocess
from Bio import Phylo
import copy
import math
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import lines, patches, colors, gridspec, ticker
import seaborn as sns
import scipy.stats
from functions import *

from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_samples, silhouette_score

import matplotlib.cm as cm
import numpy as np


## Variables

In [7]:
# Pipeline wildcards, in order: reads origin, locus name, prune mode and
# missing-data filter level. They select which augur output directory is read.
#WILDCARDS = ["all", "chromosome", "full", "30"]
WILDCARDS = ["all", "chromosome", "full", "5"]
READS_ORIGIN, LOCUS_NAME, PRUNE, MISSING_DATA = WILDCARDS

# Project root; the commented paths are alternative analyses.
#project_dir = "/mnt/c/Users/ktmea/Projects/plague-phylogeography-projects/denmark/"
#project_dir = "/mnt/c/Users/ktmea/Projects/plague-phylogeography-projects/pla/"
project_dir = "/mnt/c/Users/ktmea/Projects/plague-phylogeography-projects/main/"
results_dir = project_dir

# Placeholder written into metadata cells that have no value.
NO_DATA_CHAR = "NA"

# Notebook display limits and plotting defaults.
pd.set_option('display.max_rows', 100, 'display.max_columns', 10)
plt.rcParams['lines.linewidth'] = 0.5

# Outlier marker style shared by the box plots.
flierprops = {
    "marker": 'o',
    "markerfacecolor": 'black',
    "markersize": 1,
    "markeredgecolor": 'none',
}

# First five colours of the D3 "category10" palette.
D3_COL_PAL = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"]

## Paths

In [8]:
# Augur/BEAST outputs for the selected wildcard combination.
augur_beast_dir = os.path.join(
    results_dir,
    "augur/{}/{}/{}/filter{}/beast".format(READS_ORIGIN, LOCUS_NAME, PRUNE, MISSING_DATA),
)
colors_path = os.path.join(augur_beast_dir, "colors.tsv")
tree_path = os.path.join(augur_beast_dir, "all.timetree.nwk")

# Sample metadata. An alternative is the BEAST-specific table:
#   os.path.join(augur_beast_dir, "metadata.tsv")
metadata_path = os.path.join(results_dir, "metadata/all/metadata.tsv")

# Per-sample locus coverage and depth tables collected for this reads origin.
locus_dir = os.path.join(results_dir, "locus_coverage_collect/{}/".format(READS_ORIGIN))
cov_df_path = os.path.join(locus_dir, "locus_coverage.txt")
dep_df_path = os.path.join(locus_dir, "locus_depth.txt")

# Reference annotation (GFF) for the reference genome assembly.
ref_accession = "GCA_000009065.1_ASM906v1_genomic"
ref_gff_path = os.path.join(results_dir, "data/reference", ref_accession, ref_accession + ".gff")

# ------------------------------------------
# Output: figures and tables are written next to the locus tables.
out_dir = locus_dir
print(out_dir)

/mnt/c/Users/ktmea/Projects/plague-phylogeography-projects/main/locus_coverage_collect/all/


## Import Tree

In [9]:
"""tree = Phylo.read(tree_path, format="newick")
tree.ladderize(reverse=True)"""

'tree = Phylo.read(tree_path, format="newick")\ntree.ladderize(reverse=True)'

## Import Metadata

In [10]:
# Sample metadata, indexed by its first column; empty cells become NO_DATA_CHAR.
raw_metadata_df = pd.read_csv(metadata_path, sep='\t')
metadata_df = raw_metadata_df.set_index(raw_metadata_df.columns[0]).fillna(NO_DATA_CHAR)

#display(metadata_df)

## Import Locus Dataframes

In [11]:
# Per-sample locus tables, one row per "Sample":
# fraction covered (cov_df) and sequencing depth (dep_df).
cov_df = pd.read_csv(cov_df_path, sep='\t', index_col="Sample")
#display(cov_df)

dep_df = pd.read_csv(dep_df_path, sep='\t', index_col="Sample")
#display(dep_df)

FileNotFoundError: [Errno 2] File /mnt/c/Users/ktmea/Projects/plague-phylogeography-projects/main/locus_coverage_collect/all/locus_coverage.txt does not exist: '/mnt/c/Users/ktmea/Projects/plague-phylogeography-projects/main/locus_coverage_collect/all/locus_coverage.txt'

In [None]:
# Remove samples that have no metadata record (ex. outgroup), so both locus
# tables only contain samples that can be annotated downstream.
# Filtering with a boolean mask avoids one drop (and one table copy) per
# sample. Filtering dep_df by the same rule, rather than by cov_df's labels,
# means a sample missing from dep_df can no longer raise a KeyError.
missing_metadata = ~cov_df.index.isin(metadata_df.index)
for sample in cov_df.index[missing_metadata]:
    print("Dropping:", sample)
cov_df = cov_df[~missing_metadata]
dep_df = dep_df[dep_df.index.isin(metadata_df.index)]

# exception

#sample = "SAMN01991268"
#if sample in metadata_df.index:
#    cov_df.drop(sample, inplace=True)
#    dep_df.drop(sample, inplace=True)

## Import Reference GFF

In [None]:
# Column names of the nine GFF3 fields plus an extra "id" column derived
# from the first attribute.
ref_gff_columns = [
        "seqname",
        "source",
        "feature",
        "start", #sequence numbering starting at 1.
        "end",
        "score",
        "strand",
        "frame",
        "attribute",
        "id",
]

# Flatten the GFF into a headerless TSV, appending the value of the first
# attribute (e.g. "ID=gene-YPPCP1.07;..." -> "gene-YPPCP1.07") as a new column.
ref_gff_tsv = ref_gff_path + ".tsv"

# Opening with "w" truncates any previous output, which replaces the external
# `rm -f` + append-mode combination.
with open(ref_gff_path) as infile, open(ref_gff_tsv, "w") as outfile:
    for line in infile:
        # Drop only the line terminator (including a stray "\r") so it is not
        # embedded before the appended id column.
        line = line.rstrip("\r\n")
        if line.startswith("#"):
            continue
        attributes = line.strip().split("\t")[-1]
        # partition keeps the whole value even if it contains "=" itself.
        _, sep, attr_id = attributes.split(";")[0].partition("=")
        # Skip blank lines and records whose first attribute is not key=value.
        if not sep:
            continue
        outfile.write("{}\t{}\n".format(line, attr_id))

ref_df = pd.read_csv(ref_gff_tsv, sep='\t', header=None)
ref_df.columns = ref_gff_columns
ref_df.set_index("id", inplace=True)
#display(ref_df)

## Create Dataframes for replicons and genes

In [None]:
# Split the reference annotation by feature type:
# whole replicons ("region" rows) and individual genes ("gene" rows).
ref_rep_df, ref_gene_df = (
    ref_df.loc[ref_df["feature"] == feature_type] for feature_type in ("region", "gene")
)
#display(ref_rep_df)
#display(ref_gene_df)

## Separate Dataframe by Type

In [None]:
# Coverage-table columns that are not replicons / not genes in the reference.
non_replicon_loci = [locus for locus in cov_df.columns if locus not in ref_rep_df.index]
non_gene_loci = [locus for locus in cov_df.columns if locus not in ref_gene_df.index]

# drop() without inplace returns a new frame, so the source tables stay intact.
# Replicon-level tables:
cov_rep_df = cov_df.drop(columns=non_replicon_loci)
display(cov_rep_df)
dep_rep_df = dep_df.drop(columns=non_replicon_loci)
#display(dep_rep_df)

# Gene-level tables:
cov_gene_df = cov_df.drop(columns=non_gene_loci)
#display(cov_gene_df)
dep_gene_df = dep_df.drop(columns=non_gene_loci)
#display(dep_gene_df)

---
# pPCP1 Analysis

In [None]:
# Target replicon (pPCP1 plasmid) and gene of interest.
target_seqname = "AL109969"
target_locus = "AL109969.1:1..9612" # pPCP1
target_gene = "gene-YPPCP1.07" # pla

# Danish exclude
# exclude_samples = ["P187","P212","P387",]

# Exclude IS100 and the IS100 transposase
# NOTE(review): exclude_genes is not referenced by any visible cell.
exclude_genes = ["gene-YPPCP1.01", "gene-YPPCP1.02"]

# Display label, [start, end] coordinates and plot colour for each pPCP1 gene,
# listed by ascending start coordinate (dict insertion order is preserved).
gene_labels = {
    gene_id: {"label": label, "coord": coord, "color": color}
    for gene_id, label, coord, color in (
        ("gene-YPPCP1.01", "IS100", [87, 1109], "#c4c4c4"),
        ("gene-YPPCP1.02", "IS100", [1106, 1888], "#c4c4c4"),
        ("gene-YPPCP1.03", "rop", [2925, 3119], "#1f77b4"),
        ("gene-YPPCP1.04", "pim", [4355, 4780], "#1f77b4"),
        ("gene-YPPCP1.05c", "pst", [4815, 5888], "#1f77b4"),
        ("gene-YPPCP1.06", "Hypothetical protein", [6006, 6422], "#1f77b4"),
        ("gene-YPPCP1.07", "pla", [6665, 7603], "#ff7f0e"),
        ("gene-YPPCP1.08c", "Probable transcriptional regulator", [7790, 8089], "#ff7f0e"),
        ("gene-YPPCP1.09c", "Hypothetical protein", [8089, 8436], "#ff7f0e"),
    )
}

In [None]:
# Select the plasmid genes from the reference
ppcp1_genes = list(ref_gene_df[ref_gene_df["seqname"] == target_seqname].index)

# Select the depth
ppcp1_genes_df = dep_gene_df[ppcp1_genes]

display(ppcp1_genes_df)

---
## Calculate Depletion Ratio

In [None]:
# Gene ID
depleted_gene = "gene-YPPCP1.07" # pla
baseline_gene = "gene-YPPCP1.05c" #pst

# Gene Label
depleted_gene_label = gene_labels[depleted_gene]["label"]
baseline_gene_label = gene_labels[baseline_gene]["label"]

# Gene Depth
baseline_dep = list(ppcp1_genes_df[baseline_gene])
depleted_dep = list(ppcp1_genes_df[depleted_gene])

# Accessory plotting variables
sample = list(ppcp1_genes_df.index)
#timetree_num_date = metadata_df["timetree_num_date"][ppcp1_genes_df.index]
timetree_num_date = metadata_df["date_mean"][ppcp1_genes_df.index]
strain = metadata_df["strain"][ppcp1_genes_df.index]
#root_rtt_dist = metadata_df["root_rtt_dist"][ppcp1_genes_df.index]
country = metadata_df["country"][ppcp1_genes_df.index]
ratio = []
population = metadata_df["population"][ppcp1_genes_df.index]

for depleted,baseline in zip(depleted_dep, baseline_dep):
    if baseline == 0:
        ratio.append(0)
    else:
        ratio.append(depleted/baseline)

data = {
    "baseline": baseline_dep, 
    "depleted": depleted_dep, 
    "ratio": ratio, 
    "sample" : sample,
    "timetree_num_date": timetree_num_date,
    "strain" : strain,
    "country" : country,
    "population" : population,
    #"root_rtt_dist" : root_rtt_dist,
    }

ratio_df = pd.DataFrame(data)
ratio_df.set_index("sample", inplace=True)
ratio_df.sort_values(by="ratio", inplace=True)
#ratio_df.sort_values(by="baseline", inplace=True)

display(ratio_df)

## Filter Genomes

In [None]:
# Keep only high cov samples
ppcp1_genes_orig_df = copy.deepcopy(ppcp1_genes_df)
ratio_orig_df = copy.deepcopy(ratio_df)


high_cov_samples = list(dep_rep_df[dep_rep_df[target_locus] >= 10].index)
ppcp1_genes_df = ppcp1_genes_df.loc[high_cov_samples]

ratio_df =ratio_df.loc[high_cov_samples]

# Exclude Danish samples
#ppcp1_genes_df.drop(index=exclude_samples, inplace=True)
#ratio_df.drop(index=exclude_samples, inplace=True)

display(ppcp1_genes_df)
display(ratio_df)


## Plot Gene coverage across plasmid

In [None]:
# ---------------------------------------
# PLOT SETUP
TARGET_RES = [480, 480]
DPI=200
FIGSIZE=[TARGET_RES[0] / DPI, TARGET_RES[1] / DPI]
FONTSIZE=5
LOCUS_DEP = 10
RATIO = 0.7
plt.rc('font', size=FONTSIZE)

fig, axes = plt.subplots(1,1,figsize=FIGSIZE, dpi=DPI)

ax = axes
# Box plot of per-gene depth across samples. Only samples whose rop
# (gene-YPPCP1.03) depth exceeds LOCUS_DEP are included; reset_index gives a
# plain positional index.
df = ppcp1_genes_df[ppcp1_genes_df["gene-YPPCP1.03"] > LOCUS_DEP].reset_index(drop=True)

# Colour and tick label are looked up per plotted column. The column order
# comes from the reference GFF, not from gene_labels' key order. Genes with
# no gene_labels entry fall back to grey and their raw ID.
# Named box_colors so the imported matplotlib `colors` module is not shadowed.
box_colors = [gene_labels.get(gene, {}).get("color", "#c4c4c4") for gene in df.columns]
xtick_labels = [
    "{} {}".format(gene_labels[gene]["label"], gene_labels[gene]["coord"])
    if gene in gene_labels else gene
    for gene in df.columns
]
sns.boxplot(
    ax=ax,
    data=df,
    flierprops=flierprops,
    palette=box_colors,
)
ax.set_xticklabels(xtick_labels)
plt.xticks(rotation=45, ha="right")
ax.set_xlabel("Gene")
ax.set_ylabel("Average Depth (X)")
plt.title("Sequencing Depth of pPCP1 Genes (N={})".format(len(df)))

out_path = os.path.join(out_dir, "ppcp1_gene_depth")
plt.savefig(out_path + ".png", dpi=DPI, bbox_inches = "tight", facecolor="white")
plt.savefig(out_path + ".svg", dpi=DPI, bbox_inches = "tight")

---
# Detect K Cluster

In [None]:
# Single-feature matrix of shape (n_samples, 1) holding the depletion ratio, for KMeans.
X = np.array(ratio_df["ratio"])[:, np.newaxis]

### Elbow Method

In [None]:
# calculate distortion (KMeans inertia) for a range of number of clusters
distortions = []

# Exclusive upper bound on the cluster counts tried. Renamed from `iter`,
# which shadowed the builtin for the rest of the notebook.
max_clusters = 10
# KMeans cannot fit more clusters than there are samples.
n_cluster_range = range(1, min(max_clusters, len(X) + 1))
for i in n_cluster_range:
    km = KMeans(
        n_clusters=i,
        init='random',
        n_init=10,
        max_iter=300,
        tol=1e-04,
        random_state=0
    )
    km.fit(X)
    distortions.append(km.inertia_)

# ---------------------------------------
# PLOT SETUP
TARGET_RES = [480, 480]
DPI=200
FIGSIZE=[TARGET_RES[0] / DPI, TARGET_RES[1] / DPI]
FONTSIZE=5
LOCUS_DEP = 10
RATIO = 0.7
plt.rc('font', size=FONTSIZE)

fig, ax = plt.subplots(1,1,figsize=FIGSIZE, dpi=DPI)

sns.lineplot(
    ax=ax,
    x=n_cluster_range,
    y=distortions,
)
sns.scatterplot(
    ax=ax,
    x=n_cluster_range,
    y=distortions,
    s=20,
    ec="black",
    lw=0.5,
)

ax.set_xlabel('Number of Clusters')
ax.set_ylabel('Distortion')
fig.suptitle("Cluster Detection by Elbow Method")

out_path = os.path.join(out_dir, "elbow_{}_{}".format(depleted_gene_label, baseline_gene_label))
plt.savefig(out_path + ".png", bbox_inches = "tight", facecolor="white")
plt.savefig(out_path + ".svg", bbox_inches = "tight")

## Silhouette Clusters

In [None]:
range_n_clusters = [2, 3, 4, 5, 6]

# Silhouette scores lie in [-1, 1]. Start at -inf so a best k is recorded even
# when every score is negative. highest_cluster stays 0 only if every k is skipped.
highest_score = float("-inf")
highest_cluster = 0

for n_clusters in range_n_clusters:
    # silhouette_score requires 2 <= n_labels <= n_samples - 1.
    if n_clusters >= len(X):
        print("Skipping n_clusters =", n_clusters, "(too few samples)")
        continue

    clusterer = KMeans(n_clusters=n_clusters, random_state=10)
    cluster_labels = clusterer.fit_predict(X)

    silhouette_avg = silhouette_score(X, cluster_labels)
    print("For n_clusters =", n_clusters,
          "The average silhouette_score is :", silhouette_avg)

    if silhouette_avg > highest_score:
        highest_cluster = n_clusters
        highest_score = silhouette_avg

print(highest_cluster)
# To Plot see: https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_silhouette_analysis.html
# To Plot see: https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_silhouette_analysis.html

---
# Cluster

### Manual Override

In [None]:
# Manual override of the silhouette-selected value: the clustering cell below
# fits KMeans with this many clusters, and the plot/legend loops use it too.
highest_cluster = 1

In [None]:
# Cluster samples on their depletion ratio.
km = KMeans(
    n_clusters=highest_cluster,
    init='random',
    n_init=10,
    max_iter=300,
    tol=1e-04,
    random_state=0,
)
y_km = km.fit_predict(X)

# Largest ratio observed within each raw KMeans label.
cluster_max_ratios = {
    label: X[y_km == label].max() for label in range(highest_cluster)
}

# Raw labels ordered by ascending max ratio, so the relabelled cluster 0 holds
# the lowest (most depleted) ratios.
cluster_order = [
    label for label, _ in sorted(cluster_max_ratios.items(), key=lambda item: item[1])
]

# Store the relabelled cluster per sample. The first .at on the new column
# creates it as float.
for sample, cluster in zip(ratio_df.index, y_km):
    ratio_df.at[sample, "cluster"] = cluster_order.index(cluster)

display(ratio_df)

---
## Plot

In [None]:
# ---------------------------------------
# PLOT SETUP
TARGET_RES = [2400, 720]
DPI=400
FIGSIZE=[TARGET_RES[0] / DPI, TARGET_RES[1] / DPI]
FONTSIZE=5
LOCUS_DEP = 10
RATIO = 0.7
plt.rc('font', size=FONTSIZE)

# Mathtext labels use raw strings: "\i" is not a valid Python escape and emits
# a SyntaxWarning on Python 3.12+.
ratio_axis_label = (
    r"$\it{" + "{}".format(depleted_gene_label)
    + r"}$ Depth (X)  / $\it{" + "{}".format(baseline_gene_label)
    + r"}$ Depth (X)"
)

# One fixed colour per country, in sorted order, shared by all three panels and
# the figure legend. Iterating a raw set gave an order that changes between runs
# (string hashing is randomised per process).
country_order = sorted(set(ratio_df["country"]), key=str)
country_palette = {
    name: D3_COL_PAL[i % len(D3_COL_PAL)] for i, name in enumerate(country_order)
}

# 2 x2
#fig, axes = plt.subplots(2,2,figsize=FIGSIZE, dpi=DPI)
# 1 x 3
fig, axes = plt.subplots(1,3,figsize=FIGSIZE, dpi=DPI)

fig.subplots_adjust(wspace=0.4, hspace=0.4)

#------------------------------------------------
# Regression: depleted-gene depth vs baseline-gene depth, one fit per country
#ax = axes[0][0]
ax = axes[0]

ax.set_title("A", fontsize=FONTSIZE * 1.5, fontweight="bold")

#for cluster in range(0, highest_cluster):
for country in country_order:
    df = ratio_df[ratio_df["country"] == country]
    sns.regplot(
        ax=ax,
        data=df,
        x="baseline",
        y="depleted",
        color=country_palette[country],
        #ci=None,
        scatter_kws={"s": 5, "ec":"black", "lw": 0.25, "zorder":1,},
        line_kws={"zorder":0, },
    )

ax.set_xlabel(r"$\it{" + "{}".format(baseline_gene_label) + r"}$ Depth (X)")
ax.set_ylabel(r"$\it{" + "{}".format(depleted_gene_label) + r"}$ Depth (X)")
ax.set_xlim(0, max(ratio_df["baseline"] + 10))

# Legend: R^2 and p-value of a linear fit of depleted on baseline depth, per cluster.
reg_handles = []
for cluster in range(highest_cluster):
    cluster_df = ratio_df[ratio_df["cluster"] == cluster]
    # linregress needs at least two points with distinct x values.
    if cluster_df["baseline"].nunique() < 2:
        continue
    _, _, r_value, p_value, _ = scipy.stats.linregress(
        list(cluster_df["baseline"]), list(cluster_df["depleted"])
    )
    p_sig = "*" if p_value < 0.05 else ""
    r_squared_pretty = str(round(r_value * r_value, 2))
    reg_handles.append(
        lines.Line2D(
            [0], [0],
            label=(
                "R$^{2}$ : " + "{}".format(r_squared_pretty)
                + "\np   : {:0.2e}{}".format(p_value, p_sig)
            ),
            color=D3_COL_PAL[cluster % len(D3_COL_PAL)],
        )
    )

legend = ax.legend(handles=reg_handles, loc=2, fontsize=FONTSIZE * 0.75, edgecolor="black")
legend.get_frame().set_linewidth(0.25)

#------------------------------------------------
# Histogram of the depletion ratio, coloured by country
#ax = axes[0][1]
ax = axes[1]

ax.set_title("B", fontsize=FONTSIZE * 1.5, fontweight="bold")
sns.histplot(
    ax=ax,
    data=ratio_df,
    x="ratio",
    bins=len(ratio_df),
    #hue="cluster",
    hue="country",
    hue_order=country_order,
    palette=country_palette,
    #alpha=0.75,
    zorder=0,
)
sns.kdeplot(
    ax=ax,
    data=ratio_df,
    x="ratio",
    #hue="cluster",
    hue="country",
    hue_order=country_order,
    palette=country_palette,
    #alpha=1.0,
    fill=True,
    zorder=1,
)

ax.set_xlim(0)
ax.set_ylabel("Number of Samples")
ax.set_xlabel(ratio_axis_label)
ax.legend().remove()

#------------------------------------------------
# Boxplot (alternative 2 x 2 layout, currently disabled)

r"""
ax = axes[1][0]
ax.set_title("C", fontsize=FONTSIZE * 1.5, fontweight="bold")

sns.boxplot(
    ax=ax,
    data=ratio_df,
    x="cluster",
    y="ratio",
    flierprops=flierprops,
)
ax.set_xlabel("Cluster")
ax.set_xticklabels(["Depleted", "Normal"])
ax.set_ylabel(
    "$\it{"
    + "{}".format(depleted_gene_label)
    + "}$ Depth (X)  / $\it{"
    + "{}".format(baseline_gene_label)
    + "}$ Depth (X)")
"""

# --------------------------------------
# TIMELINE: depletion ratio against sample date
#ax = axes[1][1]
ax = axes[2]
ax.set_title("C", fontsize=FONTSIZE * 1.5, fontweight="bold")
sns.scatterplot(
    ax=ax,
    data=ratio_df,
    x="timetree_num_date",
    y="ratio",
    s=10,
    ec="black",
    #palette = D3_COL_PAL[0:highest_cluster],
    #hue="cluster",
    hue="country",
    hue_order=country_order,
    palette=country_palette,
)
ax.legend().remove()

#------------------------------------------------
# Date Confidences (disabled)
"""for sample in ratio_df.index:
    conf = [float(c) for c in metadata_df["timetree_num_date_confidence"][sample].strip("[]").split(",")]
    ratio = ratio_df["ratio"][sample]
    cluster = int(ratio_df["cluster"][sample])
    color = D3_COL_PAL[cluster]
    ax.add_patch(
        patches.Rectangle(
            (conf[0], ratio), conf[1] - conf[0], 0.02, linewidth=0, facecolor=color, alpha=0.20, zorder=0))"""


#ax.set_xlim(1150,1850)
ax.set_xlabel("Date")
ax.set_ylabel(ratio_axis_label)

#------------------------------------------------
# Legend: one patch per country, built from the same mapping as the panels
# rather than hard-coded country labels.
#normal_patch = patches.Patch(color=D3_COL_PAL[1], label=r'Normal $\it{pla}$')
#low_patch = patches.Patch(color=D3_COL_PAL[0], label=r'Depleted $\it{pla}$')
country_handles = [
    patches.Patch(color=country_palette[name], label=str(name)) for name in country_order
]
#legend = fig.legend(handles=[normal_patch, low_patch], bbox_to_anchor=(0.6,0.98), edgecolor="black")
legend = fig.legend(handles=country_handles, bbox_to_anchor=(0.575,1.2), edgecolor="black")
legend.get_frame().set_linewidth(0.5)

r"""fig.suptitle(
    "Relative depletion of the plasminogen activator ($\it{pla}$) virulence factor",
    x=0.5,
    y=1.3,
    fontsize=FONTSIZE * 2,
)"""

out_path = os.path.join(out_dir, "depletion_{}".format(depleted_gene_label))
plt.savefig(out_path + ".png", bbox_inches = "tight", facecolor="white")
plt.savefig(out_path + ".svg", bbox_inches = "tight")

## Add Depletion Statistics to Metadata

In [None]:
# Metadata column names for the per-sample depletion statistics.
baseline_col = "baseline_{}".format(baseline_gene_label)
depleted_col = "depleted_{}".format(depleted_gene_label)
ratio_col = "ratio_{}_{}".format(depleted_gene_label, baseline_gene_label)
cluster_col = "cluster_{}_{}".format(depleted_gene_label, baseline_gene_label)

# Every sample starts with the no-data placeholder.
for new_col in (baseline_col, depleted_col, ratio_col, cluster_col):
    metadata_df[new_col] = [NO_DATA_CHAR] * len(metadata_df)

# Depths and ratio come from the unfiltered table (all samples). A cluster
# exists only for samples that passed the depth filter (ratio_df).
#for rec in ratio_df.iterrows():
for sample in ratio_orig_df.index:
    metadata_df.at[sample, baseline_col] = ratio_orig_df.at[sample, "baseline"]
    metadata_df.at[sample, depleted_col] = ratio_orig_df.at[sample, "depleted"]
    metadata_df.at[sample, ratio_col] = ratio_orig_df.at[sample, "ratio"]
    if sample in ratio_df.index:
        metadata_df.at[sample, cluster_col] = ratio_df.at[sample, "cluster"]

display(metadata_df)

## Draw Tree

In [None]:
# Colour each tip by its depletion cluster; tips without a cluster are grey.
# NOTE(review): with highest_cluster = 1 every clustered sample is cluster 0
# and is drawn with the "Depleted" colour below.
tree_cluster_col = "cluster_{}_{}".format(depleted_gene_label, baseline_gene_label)
for c in tree.get_terminals():
    sample = c.name
    cluster = metadata_df[tree_cluster_col][sample]
    if cluster == NO_DATA_CHAR:
        color = "#c4c4c4"
    else:
        color = D3_COL_PAL[int(cluster)]
    c.color = color

# ---------------------------------------
# PLOT SETUP
TARGET_RES = [960, 960]
DPI=200
FIGSIZE=[TARGET_RES[0] / DPI, TARGET_RES[1] / DPI]
FONTSIZE=4
LOCUS_DEP = 10
RATIO = 0.7
plt.rc('font', size=FONTSIZE)

fig, ax = plt.subplots(1,figsize=FIGSIZE, dpi=DPI)

Phylo.draw(
    tree,
    axes=ax,
    show_confidence=False,
    label_func = lambda x: metadata_df["country_date_strain"][x.name] if x.is_terminal() else "",
    do_show=False
)

ax.axis("off")
#ax.set_ylabel("")
#ax.set_yticklabels([])
#ax.set_xlabel("Years")

#------------------------------------------------
# Legend (raw strings: "\i" is not a valid Python escape sequence)
normal_patch = patches.Patch(color=D3_COL_PAL[1], label=r'Normal $\it{pla}$')
low_patch = patches.Patch(color=D3_COL_PAL[0], label=r'Depleted $\it{pla}$')
unknown_patch = patches.Patch(color="#c4c4c4", label=r'Unknown $\it{pla}$')
legend = fig.legend(handles=[normal_patch, low_patch, unknown_patch], loc=6, edgecolor="black")
legend.get_frame().set_linewidth(0.5)

out_path = os.path.join(out_dir, "depletion_tree_{}".format(depleted_gene_label))
plt.savefig(out_path + ".png", bbox_inches = "tight", facecolor="white", dpi=DPI)
plt.savefig(out_path + ".svg", bbox_inches = "tight")

# Export

## Metadata

In [None]:
# Write the annotated metadata (including the depletion columns) as a TSV.
out_path_metadata = os.path.join(out_dir, "metadata.tsv")
metadata_df.to_csv(out_path_metadata, index=True, sep="\t")

## Timetree

In [None]:
# Annotate a deep copy of the tree (leaving `tree` untouched) using the project
# helper metadata_to_comment, then write it in NEXUS format.
out_path_tree_nex = os.path.join(out_dir, "all.timetree.nex")
out_tree = copy.deepcopy(tree)
metadata_to_comment(out_tree, metadata_df)
Phylo.write(out_tree, out_path_tree_nex, "nexus")

## Prior Distribution

In [None]:
# ---------------------------------------
# PLOT SETUP
TARGET_RES = [480, 480]
DPI=200
FIGSIZE=[TARGET_RES[0] / DPI, TARGET_RES[1] / DPI]
FONTSIZE=4
LOCUS_DEP = 10
RATIO = 0.7
plt.rc('font', size=FONTSIZE)

# Collection-date prior per site: "mu" is the mean site date and "range" the
# half-width of the site date range (drawn as mu +/- range).
priors = {
    "Denmark": {"mu": 1330, "range": 230},
    "Rostov2033": {"mu": 1767.5, "range": 2.75},
    "Azov38": {"mu": 2021-471.0, "range": 75.0},
    "CHE1": {"mu": 2021-371.0, "range": 75.0},
    "BED": {"mu": 2021-423.5, "range": 18.75},
}

for prior, spec in priors.items():
    fig, ax = plt.subplots(1, figsize=FIGSIZE, dpi=DPI)

    mu, date_range = spec["mu"], spec["range"]
    # The site date range is treated as +/- 2 standard deviations.
    sigma = date_range / 2
    # Evaluate the normal density over +/- 3 sigma.
    x = np.linspace(mu - 3*sigma, mu + 3*sigma, 100)
    y = scipy.stats.norm.pdf(x, mu, sigma)

    # Shaded density with its outline drawn on top.
    ax.fill_between(x, y, alpha=0.25)
    sns.lineplot(ax=ax, x=x, y=y)

    label_size = FONTSIZE * 1.5
    ax.set_title("{} Collection Date Prior".format(prior), fontsize=label_size, y=1.1)
    ax.set_xlabel("Date", fontsize=label_size)
    ax.set_ylabel("Density", fontsize=label_size)

    # Mean date (solid) and the edges of the site date range (dashed).
    ax.axvline(x=mu, color="black", label="Mean Site Date")
    ax.axvline(x=mu - date_range, color="black", linestyle="--", label="Site Date Range")
    ax.axvline(x=mu + date_range, color="black", linestyle="--",)
    prior_legend = ax.legend(bbox_to_anchor=(0.30,1), edgecolor="black")
    prior_legend.get_frame().set_linewidth(0.5)


#out_path = os.path.join(out_dir, "prior_denmark")
#plt.savefig(out_path + ".png", bbox_inches = "tight", facecolor="white")
#plt.savefig(out_path + ".svg", bbox_inches = "tight")