In [None]:
import math

import matplotlib as mpl
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.pyplot import cm
from scipy.cluster import hierarchy
from scipy.cluster.hierarchy import dendrogram
from sklearn import cluster, metrics, preprocessing

In [None]:
# Plot styling: the project palette on top of matplotlib's seaborn-v0_8
# "white" + "paper" styles, finished with seaborn's "paper" context.
colors = ["#9e0059", "#6da7de", "#ee266d", "#dee000", "#eb861e", "#63c5b5"]
plt.style.use(["seaborn-v0_8-white", "seaborn-v0_8-paper"])
sns.set_palette(palette=colors)
sns.set_context(context="paper")

### Convert gene names to UniProt ids

In [None]:
# Gene names found in the proteome FASTA: the third whitespace-separated
# token of every header (">") line.
# NOTE(review): `gene_names` is not used anywhere else in this notebook.
gene_names = set()
with open(
    "../data/external/crocosphaera_watsonii_wh_8501_200811.fasta"
) as f_in:
    for line in f_in:
        if line.startswith(">"):
            gene_names.add(line.split()[2])

# UniProt export: one row per entry with a space-separated list of gene
# names. Flatten it into a {gene name: UniProt entry} lookup.
gene_to_uniprot = pd.read_csv(
    "../data/external/crocosphaera_watsonii_wh_8501_uniprot_20210912.tab",
    sep="\t",
    usecols=["Entry", "Gene names"],
)
gene_to_uniprot["Gene names"] = gene_to_uniprot["Gene names"].str.split()
# Select the "Entry" column explicitly: `.squeeze()` would collapse a
# single-row table to a scalar, and `.to_dict()` would then fail.
gene_to_uniprot = (
    gene_to_uniprot.explode("Gene names")
    .set_index("Gene names")["Entry"]
    .to_dict()
)

# Enrichment background ("population"): every distinct UniProt entry in the
# lookup, one per line, no header.
population = pd.Series(gene_to_uniprot.values()).drop_duplicates()
population.to_csv(
    "../data/interim/crocosphaera_watsonii_wh_8501_population.txt",
    index=False,
    header=False,
)

### Time series preprocessing

In [None]:
# Collapse the QuickGO annotations to one row per gene product, with its GO
# terms joined by ";", and write them as a header-less tab-separated map
# (gene product id <TAB> GO terms).
quickgo_annotations = (
    pd.read_csv(
        "../data/external/crocosphaera_watsonii_wh_8501_quickgo.tsv",
        usecols=["GENE PRODUCT ID", "GO TERM"],
        sep="\t",
    )
    .groupby("GENE PRODUCT ID")["GO TERM"]
    .apply(";".join)
)
quickgo_annotations.to_csv(
    "../data/interim/crocosphaera_watsonii_wh_8501_go_map.tsv",
    header=False,
    sep="\t",
)

In [None]:
# One abundance column per sampling point, T1..T16.
data_columns = [f"151222_WH8501diel_T{t}_2ug" for t in range(1, 17)]

# Time of each sampling point relative to the first one, keyed by 0-based
# position (key 0 is T1, key 15 is T16).
time_map = {
    position: pd.Timedelta(hours=hours, minutes=minutes)
    for position, (hours, minutes) in enumerate(
        [
            (0, 0), (1, 30), (3, 0), (4, 30),
            (5, 30), (6, 30), (8, 30), (10, 30),
            (12, 30), (14, 30), (15, 30), (16, 40),
            (18, 30), (20, 30), (22, 30), (24, 0),
        ]
    )
}

# Load the abundances, attach UniProt ids (looked up from the second
# whitespace-separated token of the protein description; NaN when there is
# no match), drop proteins whose molecular weight is "?", index by UniProt
# id and relabel the abundance columns with their sampling times.
data = (
    pd.read_csv(
        "../data/processed/160214_Crocodiel_Full_rawdata_noheader_fig_may7annotation.csv",
        usecols=["Identified Proteins (1170)", "Molecular Weight", *data_columns],
        thousands=",",
    )
    .assign(
        uniprot_id=lambda frame: frame["Identified Proteins (1170)"]
        .str.split()
        .str[1]
        .map(gene_to_uniprot)
    )
    .loc[lambda frame: frame["Molecular Weight"] != "?"]
    .drop(columns=["Identified Proteins (1170)", "Molecular Weight"])
    .sort_values("uniprot_id")
    .set_index("uniprot_id")
    # NOTE(review): drop_duplicates compares the abundance values only (the
    # index is ignored), so it removes rows with identical time series rather
    # than repeated UniProt ids.
    .drop_duplicates()
    .rename(
        columns={
            column: time_map[position]
            for position, column in enumerate(data_columns)
        }
    )
)

In [None]:
# Z-score each protein's time series independently. StandardScaler scales
# columns, so scale the transposed frame (time points x proteins) and
# transpose back. Keep `data.index` so standardized rows stay labelled by
# UniProt id instead of falling back to a positional RangeIndex.
data_standardized = pd.DataFrame(
    data=preprocessing.StandardScaler().fit_transform(data.T).T,
    index=data.index,
    columns=data.columns,
)

### Time series clustering

In [None]:
def plot_dendrogram(
    clusterer,
    cluster_labels,
    distance_threshold,
    filename,
    p=4,
    distance_sort="ascending",
    colors=colors,
):
    """Plot, show and save a truncated dendrogram of a fitted clusterer.

    Parameters
    ----------
    clusterer : fitted agglomerative clusterer
        Must expose ``children_``, ``labels_`` and ``distances_``.
    cluster_labels : sequence
        One label per sample, in the order of ``clusterer.labels_``; shown as
        the text of the (possibly truncated) dendrogram leaves.
    distance_threshold : float
        Used as ``color_threshold``: links below it are colored from
        ``colors``, links above it are drawn in black.
    filename : str or path-like
        Where the figure is saved (300 dpi).
    p : int, default 4
        Number of tree levels shown (``truncate_mode="level"``).
    distance_sort : str, default "ascending"
        Child ordering passed to ``scipy.cluster.hierarchy.dendrogram``.
    colors : sequence of color specs
        Link color palette for the below-threshold clusters.

    Raises
    ------
    ValueError
        If ``cluster_labels`` does not have one entry per sample.
    """
    n_samples = len(clusterer.labels_)
    if len(cluster_labels) != n_samples:
        raise ValueError(
            f"expected {n_samples} cluster labels, got {len(cluster_labels)}"
        )

    # Build a scipy linkage matrix: one row per merge holding the two child
    # node ids, the merge distance and the number of samples below the new
    # node. Node ids >= n_samples refer to the result of an earlier merge.
    counts = np.zeros(clusterer.children_.shape[0])
    for i, merge in enumerate(clusterer.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count
    linkage = np.column_stack(
        [clusterer.children_, clusterer.distances_, counts]
    ).astype(float)

    # Give every node (samples and merges) a label so that truncated leaves
    # can be labelled too: a merged node inherits its first child's label.
    plot_labels = [None] * (2 * n_samples)
    plot_labels[:n_samples] = list(cluster_labels)
    for parent_i, child_i in enumerate(clusterer.children_[:, 0], n_samples):
        plot_labels[parent_i] = plot_labels[child_i]

    width = 3.5
    height = width / 1.618
    fig, ax = plt.subplots(figsize=(width * 2, height))

    # The link color palette is global scipy state: set it for this plot only
    # and always restore the default, even if plotting raises.
    hierarchy.set_link_color_palette(
        [mpl.colors.rgb2hex(rgb) for rgb in colors]
    )
    try:
        dendrogram(
            linkage,
            p=p,
            truncate_mode="level",
            color_threshold=distance_threshold,
            distance_sort=distance_sort,
            show_leaf_counts=False,
            leaf_rotation=0,
            leaf_label_func=lambda i: plot_labels[i],
            ax=ax,
            above_threshold_color="black",
        )

        ax.set_xlabel("Cluster label")
        ax.set_ylabel("Euclidean distance")
        sns.despine(ax=ax)

        fig.savefig(filename, dpi=300, bbox_inches="tight")
        plt.show()
    finally:
        hierarchy.set_link_color_palette(None)
        plt.close(fig)


def plot_time_series(
    data, cluster_labels, filename, colors=colors, night_span=(4, 11)
):
    """Plot each cluster's abundance profiles on a 3-column grid and save it.

    Every panel overlays all member time series of one cluster in translucent
    black, plus the per-time-point median drawn in that cluster's color.

    Parameters
    ----------
    data : pandas.DataFrame
        One row per protein, one column per time point. Column labels must
        support ``.total_seconds()`` (e.g. ``pandas.Timedelta``).
    cluster_labels : array-like
        Cluster label of each row of ``data`` (same order and length).
    filename : str or path-like
        Where the figure is saved (300 dpi).
    colors : sequence of color specs
        Median-line colors, assigned to clusters in sorted label order and
        reused cyclically when there are more clusters than colors.
    night_span : tuple of int, default (4, 11)
        Positions (into ``data.columns``) of the time points bounding the
        shaded night period.

    Raises
    ------
    ValueError
        If there are no labels or no colors.
    """
    cluster_labels = np.asarray(cluster_labels)
    labels = sorted(set(cluster_labels))
    palette = list(colors)
    if not labels:
        raise ValueError("cluster_labels is empty; nothing to plot")
    if not palette:
        raise ValueError("colors must contain at least one color")

    width = 3.5
    height = width / 1.618
    n_col = 3
    n_row = math.ceil(len(labels) / n_col)
    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(
        n_row,
        n_col,
        sharex=True,
        sharey=True,
        figsize=(width * 2, height * n_row),
        squeeze=False,
    )
    axes_flat = axes.ravel()

    # Column time offsets in seconds.
    x = [c.total_seconds() for c in data.columns]
    for i, (label, ax) in enumerate(zip(labels, axes_flat)):
        cluster_timepoints = data[cluster_labels == label].T
        ax.plot(x, cluster_timepoints.values, alpha=0.1, color="black")
        ax.plot(
            x,
            np.median(cluster_timepoints, axis=1),
            c=palette[i % len(palette)],
            lw=5,
        )

        ax.set_xlim(x[0], x[-1])

        # Highlight night period.
        ax.axvspan(x[night_span[0]], x[night_span[1]], color="lightgray")

        # Set x ticks every 3 hours, labelled in whole hours.
        ax.xaxis.set_major_locator(mticker.MultipleLocator(60 * 60 * 3))
        ax.xaxis.set_major_formatter(
            lambda seconds, _pos: int(seconds // 3600)
        )

        ax.set_title(f"Cluster {label}")
        sns.despine(ax=ax)

        # Outer-panel labels: y label on the first column, and x label on
        # the lowest used panel of each column. sharex hides x tick labels
        # above the bottom row, so re-enable them there.
        if i % n_col == 0:
            ax.set_ylabel("Standardized\nprotein abundance")
        if i + n_col >= len(labels):
            ax.set_xlabel("Hours")
            ax.xaxis.set_tick_params(labelbottom=True)

    # Hide grid cells left over when the cluster count is not a multiple of
    # n_col.
    for ax in axes_flat[len(labels):]:
        ax.set_visible(False)

    fig.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()
    plt.close(fig)

In [None]:
# Hierarchical clustering of the standardized profiles. The tree is cut at
# `distance_threshold` instead of at a fixed number of clusters, and merge
# distances are kept for the dendrogram plot.
distance_threshold = 30
clusterer = cluster.AgglomerativeClustering(
    distance_threshold=distance_threshold,
    n_clusters=None,
    compute_distances=True,
)
cluster_labels = clusterer.fit(data_standardized).labels_

In [None]:
# Remap the raw clustering labels to the "1".."6" names used in the figures.
# The mapping appears hand-picked for this particular fit, so check that it
# still covers exactly the labels produced by the clusterer.
mapper = {0: "6", 1: "5", 3: "4", 4: "3", 2: "2", 5: "1"}
assert set(mapper) == set(clusterer.labels_), (
    "mapper does not match the clustering labels; re-check the remapping"
)
# Derive the names from `clusterer.labels_` rather than from `cluster_labels`
# itself, so re-running this cell does not try to remap already-remapped
# labels (which would raise KeyError).
cluster_labels = np.asarray([mapper[i] for i in clusterer.labels_])

# Export cluster assignments.
pd.DataFrame(
    index=data.index, data=cluster_labels, columns=["cluster_id"]
).to_csv("cluster_ids.csv")

In [None]:
# Dendrogram of the full clustering (colored below `distance_threshold`) to
# judge the number of clusters, then each cluster's abundance profiles.
plot_dendrogram(
    clusterer,
    cluster_labels,
    distance_threshold,
    filename="dendrogram_all.png",
)
plot_time_series(
    data_standardized,
    cluster_labels,
    filename="clusters_all.png",
)

In [None]:
# Cluster "1" is the one associated with nitrogen fixing: re-cluster only its
# members, cutting the tree at a tighter distance threshold.
subdistance_threshold = 11
submask = cluster_labels == "1"
subdata_standardized = data_standardized[submask]
subclusterer = cluster.AgglomerativeClustering(
    distance_threshold=subdistance_threshold,
    n_clusters=None,
)
subcluster_labels = subclusterer.fit(subdata_standardized).labels_

In [None]:
# Remap the raw subcluster labels to the "1.x" names used in the figures.
# The mapping appears hand-picked for this particular fit, so check that it
# still covers exactly the labels produced by the subclusterer.
submapper = {0: "1.3", 2: "1.2", 1: "1.1", 3: "1.4"}
assert set(submapper) == set(subclusterer.labels_), (
    "submapper does not match the subclustering labels; re-check the remapping"
)
# Derive the names from `subclusterer.labels_` rather than from
# `subcluster_labels` itself, so re-running this cell does not try to remap
# already-remapped labels (which would raise KeyError).
subcluster_labels = np.asarray([submapper[i] for i in subclusterer.labels_])

# Export cluster assignments.
pd.DataFrame(
    index=data.index[submask], data=subcluster_labels, columns=["cluster_id"]
).to_csv("subcluster_ids.csv")

In [None]:
# Colors for the subcluster plots.
subcolors = ["#771145", "#9e0059", "#cf87aa", "black"]

# Dendrogram of cluster "1" (two levels, larger merges first) to judge the
# number of subclusters.
plot_dendrogram(
    subclusterer,
    subcluster_labels,
    subdistance_threshold,
    "dendrogram_sub.png",
    p=2,
    distance_sort="descending",
    colors=subcolors,
)

# Abundance profiles per subcluster; subcluster 1.4 is left out.
plot_mask = subcluster_labels != "1.4"
plot_time_series(
    subdata_standardized[plot_mask],
    subcluster_labels[plot_mask],
    "clusters_sub.png",
    colors=subcolors,
)

### GO enrichment

In [None]:
# GO enrichment for every cluster and for every subcluster of cluster "1".
# For each label: write the study set (the UniProt ids indexing `data`, one
# per line), then run find_enrichment.py against the population file and GO
# map written earlier in this notebook.
# NOTE(review): find_enrichment.py is presumably the goatools script, with
# fdr_bh meaning Benjamini-Hochberg FDR — confirm. `python` is resolved from
# PATH and may not be the kernel's interpreter.
for labeling, label_data in (
    (cluster_labels, data), (subcluster_labels, data.loc[submask])
):
    for label in sorted(set(labeling)):
        # `labeling == label` is a boolean array aligned with the rows of
        # `label_data`. Rows whose UniProt lookup failed have a NaN index and
        # are written as blank lines — TODO confirm how the script treats them.
        cluster_i = label_data.iloc[labeling == label].index.to_series()
        cluster_i.to_csv(
            f"../data/interim/crocosphaera_watsonii_wh_8501_cluster_{label}.txt",
            index=False,
            header=False,
        )
        # IPython shell escape; `{label}` is expanded from the Python variable.
        ! python find_enrichment.py \
            ../data/interim/crocosphaera_watsonii_wh_8501_cluster_{label}.txt \
            ../data/interim/crocosphaera_watsonii_wh_8501_population.txt \
            ../data/interim/crocosphaera_watsonii_wh_8501_go_map.tsv \
            --pval=0.05 \
            --method=fdr_bh \
            --pval_field=fdr_bh \
            --outfile=../data/processed/crocosphaera_watsonii_wh_8501_go_{label}.csv