In [None]:
import locale

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.cluster.hierarchy import dendrogram
from sklearn import cluster, metrics, preprocessing

In [None]:
# locale.atof is used further down to parse the abundance columns, so the
# numeric conventions are fixed to en_US here (presumably because the raw
# values use "," as thousands separator -- confirm against the CSV).
locale.setlocale(locale.LC_ALL, "en_US.UTF-8")
# Plot styling.
# Matplotlib 3.6 renamed its bundled seaborn style sheets to "seaborn-v0_8-*"
# and Matplotlib 3.8 removed the old names, so use whichever spelling this
# installation provides.
plt.style.use([
    style if style in plt.style.available
    else style.replace("seaborn-", "seaborn-v0_8-", 1)
    for style in ("seaborn-white", "seaborn-paper")])
plt.rc('font', family='serif')
sns.set_palette(['#9e0059', '#6da7de', '#ee266d', '#dee000', '#eb861e', '#63c5b5'])
sns.set_context('paper', font_scale=1.3)    # Single-column figure.

### Convert gene names to UniProt IDs

In [None]:
# Gene names found in the FASTA file: the third whitespace-separated token of
# every header line (lines starting with ">").
with open("../data/external/crocosphaera_watsonii_wh_8501_200811.fasta") as fasta:
    gene_names = {header.split()[2] for header in fasta
                  if header.startswith(">")}

# Lookup from gene name to UniProt entry. A UniProt row can list several
# whitespace-separated gene names; each of them becomes its own key.
uniprot_table = pd.read_csv(
    "../data/external/crocosphaera_watsonii_wh_8501_uniprot_20210912.tab",
    usecols=["Entry", "Gene names"], sep="\t")
uniprot_table["Gene names"] = uniprot_table["Gene names"].str.split()
entry_by_gene = (uniprot_table.explode("Gene names")
                 .set_index("Gene names")
                 .squeeze())
gene_to_uniprot = entry_by_gene.to_dict()

# The distinct UniProt entries form the background population that is handed
# to the GO enrichment script at the end of the notebook.
population = pd.Series(gene_to_uniprot.values()).drop_duplicates()
population.to_csv(
    "../data/interim/crocosphaera_watsonii_wh_8501_population.txt",
    index=False, header=False)

### Time series preprocessing

In [None]:
# Collapse the QuickGO export to one row per gene product, with all of its GO
# terms joined by ";", and store it as a header-less two-column TSV for the
# GO enrichment script.
go_table = pd.read_csv(
    "../data/external/crocosphaera_watsonii_wh_8501_quickgo.tsv",
    usecols=["GENE PRODUCT ID", "GO TERM"], sep="\t")
go_terms_by_product = go_table.groupby("GENE PRODUCT ID")["GO TERM"]
quickgo_annotations = go_terms_by_product.apply(";".join)
quickgo_annotations.to_csv(
    "../data/interim/crocosphaera_watsonii_wh_8501_go_map.tsv",
    header=None, sep="\t")

In [None]:
# One abundance column per time point, T1 through T16.
data_columns = [f"151222_WH8501diel_T{t}_2ug" for t in range(1, 17)]
# Read the abundance columns as text: locale.atof only accepts strings, and
# read_csv would otherwise parse any column that happens to contain no
# thousands separator as numeric, making the conversion below raise.
data = pd.read_csv(
    "../data/processed/160214_Crocodiel_Full_rawdata_noheader_fig_may7annotation.csv",
    usecols=["Identified Proteins (1170)", "Molecular Weight", *data_columns],
    dtype={col: str for col in data_columns})
# Locale-aware parsing of the abundance values (the locale is configured at
# the top of the notebook).
for col in data_columns:
    data[col] = data[col].apply(locale.atof)
# The second whitespace-separated token of the protein description is looked
# up in gene_to_uniprot; descriptions without a match end up as NaN.
data["uniprot_id"] = (data["Identified Proteins (1170)"].str.split().str[1]
                      .map(gene_to_uniprot))
# Discard rows whose molecular weight is "?", keep only the abundance columns
# indexed by UniProt id, and drop rows with identical abundance profiles
# (drop_duplicates compares the column values, not the index).
data = (data[data["Molecular Weight"] != "?"]
        .drop(columns=["Identified Proteins (1170)", "Molecular Weight"])
        .sort_values("uniprot_id")
        .set_index("uniprot_id")
        .drop_duplicates())

In [None]:
# Standardize every protein's profile across the time points. StandardScaler
# works per column, so the table is transposed to put the proteins in the
# columns and the result is transposed back to proteins x time points.
scaler = preprocessing.StandardScaler()
timepoints_by_protein = data[data_columns].T
data_standardized = scaler.fit_transform(timepoints_by_protein).T

In [None]:
# Distance between every pair of time points (rows of the transposed matrix),
# using the default metric of pairwise_distances and all available cores.
timepoint_profiles = data_standardized.T
pairwise_timepoints = metrics.pairwise_distances(timepoint_profiles, n_jobs=-1)

In [None]:
# Heatmap of the time point x time point distance matrix.
width = 7
height = width / 1.618
fig, ax = plt.subplots(figsize=(width, height))

sns.heatmap(pairwise_timepoints, ax=ax, square=True)
ax.set(xlabel='Timepoint', ylabel='Timepoint')

fig.savefig("pairwise_timepoints.png", dpi=300, bbox_inches="tight")
plt.show()
plt.close()

### Time series clustering

In [None]:
# Number of clusters to cut the hierarchy into.
n_clusters = 6
# compute_distances=True makes the fitted model expose the merge distances
# (distances_), which the dendrogram cell below needs.
clusterer = cluster.AgglomerativeClustering(
    compute_distances=True,
    n_clusters=n_clusters,
)

In [None]:
# Fit the hierarchy on the standardized profiles; one cluster label per row.
cluster_labels = clusterer.fit(data_standardized).labels_

In [None]:
# One panel per cluster: every member profile in faint black with the
# cluster median drawn on top.
width = 7
height = width / 1.618
n_row = 2
# Ceiling division, so that every cluster gets a panel even when n_clusters
# is not a multiple of n_row (floor division would drop the last cluster).
n_col = -(-n_clusters // n_row)
# squeeze=False keeps `axes` two-dimensional even for a single-column grid,
# which the axes[:, 0] indexing below relies on.
fig, axes = plt.subplots(n_row, n_col, sharex=True, sharey=True,
                         squeeze=False,
                         figsize=(width * n_col, height * n_row))

for i, ax in enumerate(axes.ravel()):
    if i >= n_clusters:
        # Surplus panel of the grid: there is no cluster to draw.
        ax.set_visible(False)
        continue

    # Time points in the rows, one column per protein of this cluster.
    cluster_timepoints = data_standardized[cluster_labels == i].T
    ax.plot(cluster_timepoints, alpha=0.1, color="black")
    ax.plot(np.median(cluster_timepoints, axis=1), c="#9e0059", lw=5)

    # Highlight night period.
    ax.axvspan(4, 11, color="lightgray")

    ax.set_title(f"Cluster {i}")

    sns.despine(ax=ax)

# Axis labels only on the outer panels (bottom row and first column).
for ax in axes[-1]:
    ax.set_xlabel("Timepoint")
for ax in axes[:, 0]:
    ax.set_ylabel("Standardized abundance")

plt.tight_layout()

plt.savefig("clusters.png", dpi=300, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Dendrogram of the fitted hierarchy, truncated to its top levels.
width = 7
height = width / 1.618
fig, ax = plt.subplots(figsize=(width, height))

# Number of original samples below every merge. A child index smaller than
# n_samples is a single sample; a larger index refers to the earlier merge
# at position `index - n_samples`.
n_samples = len(clusterer.labels_)
merges = clusterer.children_
counts = np.zeros(merges.shape[0])
for merge_idx, children in enumerate(merges):
    counts[merge_idx] = sum(
        1 if child < n_samples else counts[child - n_samples]
        for child in children)

# Linkage matrix in the layout scipy's dendrogram expects:
# both children, the merge distance, and the sample count.
linkage = np.column_stack(
    [merges, clusterer.distances_, counts]).astype(float)

# One label per dendrogram node: samples carry their cluster label and every
# merge node takes over the label of its first child.
plot_labels = np.full(2 * n_samples - 1, -1, dtype=np.int8)
plot_labels[:n_samples] = cluster_labels
for offset, first_child in enumerate(merges[:, 0]):
    plot_labels[n_samples + offset] = plot_labels[first_child]


def node_label(node_id):
    """Return the cluster label stored for the given dendrogram node."""
    return plot_labels[node_id]


dendrogram(linkage, p=4, truncate_mode="level", color_threshold=30,
           distance_sort=True, show_leaf_counts=False, leaf_rotation=0,
           leaf_label_func=node_label,
           ax=ax, above_threshold_color="black")

ax.set(xlabel="Cluster label", ylabel="Euclidean distance")

sns.despine()

fig.savefig("dendrogram.png", dpi=300, bbox_inches="tight")
plt.show()
plt.close()

### GO enrichment

In [None]:
for i in range(n_clusters):
    (data.iloc[cluster_labels == i].index.to_series()
     .to_csv(f"../data/interim/crocosphaera_watsonii_wh_8501_cluster{i}.txt",
             index=False, header=False))
    ! python find_enrichment.py \
        ../data/interim/crocosphaera_watsonii_wh_8501_cluster{i}.txt \
        ../data/interim/crocosphaera_watsonii_wh_8501_population.txt \
        ../data/interim/crocosphaera_watsonii_wh_8501_go_map.tsv \
        --pval=0.05 \
        --method=fdr_bh \
        --pval_field=fdr_bh \
        --outfile=../data/processed/crocosphaera_watsonii_wh_8501_go{i}.csv