# Cocoa fermentation data analysis
This notebook contains the analysis of the 14 samples from the original study by [Almeida & De Martinis](https://doi.org/10.1128/aem.00584-21). All samples were processed with the complete MOSHPIT MAG reconstruction pipeline available in QIIME 2. Here, we focus on the downstream analysis of the resulting feature tables.

In [None]:
# silence pandas' warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import numpy as np
import qiime2 as q2
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import os
import re

from matplotlib.colors import LinearSegmentedColormap
from qiime2.plugins import taxa, feature_table
from typing import List, Dict

%matplotlib inline

In [None]:
def plot_diversity(
    fp: str,
    metadata: pd.DataFrame,
    x_col: str,
    hue: str,
    ax,
    x_label: str,
    y_label: str,
    title: str
):
    """
    Draw a line plot of one diversity vector against a metadata column.

    The vector artifact at ``fp`` is loaded as a pandas Series, joined to
    ``metadata`` on the index (inner join), and plotted with seaborn using the
    Series name as the y column.

    Parameters:
        - fp (str): File path to the diversity metric artifact.
        - metadata (pd.DataFrame): Metadata joined to the vector on the index.
        - x_col (str): Column of the joined frame used for the x-axis.
        - hue (str): Column of the joined frame used for color encoding.
        - ax (matplotlib.axes.Axes): Axes to draw on.
        - x_label (str): Label for the x-axis.
        - y_label (str): Label for the y-axis.
        - title (str): Title of the plot.

    Returns:
        - ax_new (matplotlib.axes.Axes): The axes returned by ``sns.lineplot``.
        - data (pd.DataFrame): The joined vector + metadata frame.
    """
    vector = q2.Artifact.load(fp).view(pd.Series)
    joined = pd.merge(vector.to_frame(), metadata, left_index=True, right_index=True)
    drawn_ax = sns.lineplot(data=joined, x=x_col, y=vector.name, hue=hue, ax=ax)

    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    return drawn_ax, joined


def plot_diversity_together(
    fps: Dict,
    metadata: pd.DataFrame,
    seed: str,
    metric: str,
    x_col: str,
    hue: str,
    ax,
    x_label: str,
    y_label: str,
    title: str,
    seed_col: str = "seed",
):
    """
    Plots diversity metrics from multiple files together using a line plot.

    Each artifact in ``fps`` is loaded as a pandas Series and stored under the
    column ``metric``. It is tagged with its dictionary key in a ``'classifier'``
    column. The stacked frames are joined to ``metadata`` on the index (inner
    join) and restricted to rows whose ``seed_col`` equals ``seed``.

    Parameters:
        - fps (Dict): Dictionary where keys are labels (stored in the 'classifier' column)
          and values are file paths to the diversity metric artifacts.
        - metadata (pd.DataFrame): DataFrame containing metadata to merge with the diversity data.
        - seed (str): Value of ``seed_col`` to keep.
        - metric (str): Column name under which every vector is stored and which is plotted on the y-axis.
        - x_col (str): Column name in the DataFrame to use for the x-axis.
        - hue (str): Column name in the DataFrame to use for color encoding.
        - ax (matplotlib.axes.Axes): The axes onto which the plot will be drawn.
        - x_label (str): Label for the x-axis.
        - y_label (str): Label for the y-axis.
        - title (str): Title of the plot.
        - seed_col (str): Metadata column compared against ``seed``. Defaults to "seed".

    Returns:
        - ax_new (matplotlib.axes.Axes): The axes object with the plot.
        - data (pd.DataFrame): Merged DataFrame containing the combined diversity data and metadata.

    Raises:
        - ValueError: If ``fps`` is empty.
    """
    if not fps:
        raise ValueError("fps must contain at least one artifact file path")

    frames = []
    for label, fp in fps.items():
        vector = q2.Artifact.load(fp).view(pd.Series)
        # Store every vector under the same column name. Otherwise vectors whose
        # Series names differ would not stack into one column and would leave
        # NaN-filled side-by-side columns instead.
        frame = vector.rename(metric).to_frame()
        frame['classifier'] = label
        frames.append(frame)

    data = pd.concat(frames, axis=0)
    data = data.merge(metadata, left_index=True, right_index=True)
    data = data[data[seed_col] == seed]

    ax_new = sns.lineplot(data=data, x=x_col, y=metric, hue=hue, ax=ax)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)

    return ax_new, data


def generate_colors(
    columns: list, keyword: str, base_color_idx: int, fade_end: float = 1.6
) -> dict:
    """
    Generates a dictionary of colors for a list of columns based on a keyword and a base color index.

    Columns whose name contains ``keyword`` receive successive shades of one
    base hue, fading from the base hue towards white. Every other column gets
    its own hue from a HUSL palette from which the base hue is excluded.

    Parameters:
        - columns (list): List of column names for which colors need to be generated.
        - keyword (str): Substring identifying the columns that share the sequential (base-hue) shades.
        - base_color_idx (int): Index of the base color in the HUSL palette; negative values count from the end.
        - fade_end (float): Last colormap position sampled for the sequential shades. Positions run
          from 0 to ``fade_end`` in ``len(columns)`` evenly spaced steps and are consumed in column
          order by the keyword columns. Positions >= 1.0 are clipped by the colormap to its top
          color (pure white). Defaults to 1.6.

    Returns:
        - colors (dict): A dictionary mapping each column name to its corresponding color.

    Raises:
        - IndexError: If ``base_color_idx`` does not index the generated palette.
    """
    if not columns:
        return {}

    n_other = sum(keyword not in col for col in columns)
    # One palette entry is reserved for the base hue. If no column contains the
    # keyword, a palette of len(columns) hues would leave the last non-keyword
    # column without a color, so reserve one extra hue in that case only.
    n_hues = max(len(columns), n_other + 1)
    husl_palette = sns.husl_palette(n_hues, h=0.5)

    if not -n_hues <= base_color_idx < n_hues:
        raise IndexError(
            f"base_color_idx {base_color_idx} is out of range for a palette of {n_hues} colors"
        )
    # Normalize negative indices so the base hue is actually excluded below.
    base_idx = base_color_idx % n_hues
    base_color = husl_palette[base_idx]

    sequential_cmap = LinearSegmentedColormap.from_list("custom_cmap", [base_color, "white"])
    sequential_colors = iter([sequential_cmap(pos) for pos in np.linspace(0, fade_end, len(columns))])
    qualitative_colors = iter([color for i, color in enumerate(husl_palette) if i != base_idx])

    return {
        col: next(sequential_colors) if keyword in col else next(qualitative_colors)
        for col in columns
    }

In [None]:
data_dir = "./data"
metadata_fp = os.path.join(data_dir, "cocoa-metadata.tsv")

# The first TSV column becomes the index; later cells join feature tables and
# diversity vectors to this frame on that index.
metadata = pd.read_csv(metadata_fp, index_col=0, sep="\t")

## Read-based analysis

### Shannon diversity: different taxonomic classifiers
Let's compare Shannon diversity estimates obtained with three different taxonomic classifiers: Kraken 2 (+ Bracken), Kaiju and mOTUs 3.

In [None]:
# Shannon diversity vectors produced from each read classifier's output.
bracken_shannon_fp, kaiju_shannon_fp, motus_shannon_fp = (
    os.path.join(data_dir, "bracken_shannon_vector.qza"),
    os.path.join(data_dir, "kaiju_shannon_vector.qza"),
    os.path.join(data_dir, "motus_shannon_vector.qza"),
)

In [None]:
fig1, axes1 = plt.subplots(1, 3, figsize=(15, 4))

# One panel per classifier: (vector path, panel axis, x-axis label, title).
# Only the middle panel carries the "Time" x-axis label.
classifier_panels = [
    (bracken_shannon_fp, axes1[0], None, "Kraken 2 + Bracken"),
    (kaiju_shannon_fp, axes1[1], "Time", "Kaiju"),
    (motus_shannon_fp, axes1[2], None, "mOTU"),
]
bracken_df, kaiju_df, motus_df = [
    plot_diversity(
        panel_fp, metadata, "timepoint", "seed", panel_ax, panel_xlabel,
        "Shannon diversity", panel_title,
    )[1]
    for panel_fp, panel_ax, panel_xlabel, panel_title in classifier_panels
]

# Keep the y-axis label on the first panel only.
for ax in axes1[1:]:
    ax.set_ylabel(None)

In [None]:
fig1a, axes1a = plt.subplots(figsize=(5, 5))

# Overlay the three classifiers' Shannon curves for the Forasteiro seed.
shannon_fps = {
    "bracken": bracken_shannon_fp,
    "kaiju": kaiju_shannon_fp,
    "motus": motus_shannon_fp,
}
_, _ = plot_diversity_together(
    shannon_fps,
    metadata,
    seed="Forasteiro",
    metric="shannon_entropy",
    x_col="timepoint",
    hue="classifier",
    ax=axes1a,
    x_label=None,
    y_label="Shannon diversity",
    title="Forasteiro",
)

In [None]:
# Export both read-based diversity figures as SVG.
for figure, filename in ((fig1, "figure1.svg"), (fig1a, "figure1a.svg")):
    figure.savefig(os.path.join(data_dir, filename), dpi=300)

### Relative abundances: time course
We can visualize how the abundances of different taxa change over time using stacked taxa bar plots. We start by removing features that were unclassified or assigned to human (*Homo*), and then discard low-abundance, low-prevalence features.

In [None]:
# Load the rarefied Bracken feature table together with its taxonomy.
bracken_ft = q2.Artifact.load(os.path.join(data_dir, "bracken_ft_filtered_rarefied.qza"))
bracken_taxonomy = q2.Artifact.load(os.path.join(data_dir, "bracken_taxonomy.qza"))

# Drop unclassified features and human hits.
# NOTE(review): q2-taxa filter_table matches exclusion terms as substrings by
# default, so "homo" might also remove other lineages containing that string
# — confirm this is intended.
(bracken_ft_filtered,) = taxa.methods.filter_table(
    taxonomy=bracken_taxonomy,
    table=bracken_ft,
    exclude="Unclassified,homo",
)

In [None]:
# Keep features reaching a relative abundance of at least 0.5% in at least
# 10% of samples (the abundance/prevalence thresholds of q2-feature-table's
# filter_features_conditionally).
(bracken_ft_filtered,) = feature_table.methods.filter_features_conditionally(
    prevalence=0.1,
    abundance=0.005,
    table=bracken_ft_filtered,
)

Collapse the feature table to the species level.

In [None]:
# Collapse to taxonomy level 8 and work with the result as a DataFrame.
(bracken_ft_collapsed,) = taxa.methods.collapse(
    level=8,
    taxonomy=bracken_taxonomy,
    table=bracken_ft_filtered,
)
bracken_ft_collapsed = bracken_ft_collapsed.view(pd.DataFrame)

We want to look at each *Acetobacter* species separately; all other taxa are collapsed to the order level (the `o__` rank of the lineage).

In [None]:
# Acetobacter species keep their species name; every other feature is labelled
# by its order (the "o__" rank), or "Unclassified" when no capitalized order
# name is present. Several features can therefore share one label.
order_pattern = re.compile(r".*(o__([A-Z]\w+))")


def _short_label(lineage):
    if "s__Acetobacter" in lineage:
        # Last rank of the lineage without its 3-character "x__" prefix.
        return lineage.split(";")[-1][3:]
    found = order_pattern.search(lineage)
    return found.group(2) if found else "Unclassified"


cols_new = {col: _short_label(col) for col in bracken_ft_collapsed.columns}
bracken_ft_collapsed = bracken_ft_collapsed.rename(columns=cols_new, inplace=False)

In [None]:
# Attach sample metadata (inner join on the sample index).
bracken_grouped = pd.merge(
    bracken_ft_collapsed, metadata, left_index=True, right_index=True
)
bracken_grouped.head()

In [None]:
# Several features share a label after relabelling: sum columns with the same
# name. DataFrame.groupby(axis=1) is deprecated since pandas 2.1 (and its
# FutureWarning is silenced at the top of this notebook), so collapse via a
# transpose instead. Only abundance columns are summed; metadata is re-attached.
id_cols = ["seed", "timepoint"]
abundances = bracken_grouped.drop(columns=id_cols)
abundances = abundances.T.groupby(level=0).sum().T
# Sort columns by name, matching the sorted group keys of the old groupby.
bracken_grouped = abundances.join(bracken_grouped[id_cols]).sort_index(axis=1)

value_cols = sorted(col for col in bracken_grouped.columns if col not in id_cols)
# Convert counts to per-sample relative abundances.
bracken_grouped[value_cols] = bracken_grouped[value_cols].div(
    bracken_grouped[value_cols].sum(axis=1), axis=0
)

In [None]:
# Display the collapsed column labels (DataFrame.keys() returns the columns).
bracken_grouped.keys()

In [None]:
fig2, axes2 = plt.subplots(1, 2, figsize=(7, 4), sharey=True)

# Acetobacter columns get shades of one hue; all other taxa get distinct hues.
colors = generate_colors(value_cols, "Acetobacter", 2)
for ax, seed in zip(axes2, bracken_grouped['seed'].unique()):
    df_filtered = bracken_grouped[bracken_grouped['seed'] == seed]
    df_melted = df_filtered.melt(id_vars=['timepoint'], value_vars=value_cols, var_name='Category', value_name='Value')
    pivot_table = df_melted.pivot_table(index='timepoint', columns='Category', values='Value', aggfunc='sum')

    # Reverse the column order; pandas stacks the first column at the bottom,
    # so the alphabetically first taxon ends up on top of each bar.
    pivot_table = pivot_table[value_cols[::-1]]

    column_colors = [colors[col] for col in pivot_table.columns]
    pivot_table.plot(kind='bar', stacked=True, ax=ax, color=column_colors)
    ax.set_title(f'{seed}')
    ax.set_xlabel('Timepoint')
    ax.set_ylabel('Normalized Value')

# Replace the per-axes legends with a single figure-level legend. Guard against
# axes that received no plot (fewer seeds than axes) and thus have no legend.
for ax in axes2:
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

handles, labels = axes2[0].get_legend_handles_labels()

# Sort the legend entries alphabetically by label.
sorted_labels_handles = sorted(zip(labels, handles), key=lambda x: x[0])
sorted_labels, sorted_handles = zip(*sorted_labels_handles)

fig2.legend(sorted_handles, sorted_labels, title='Category', bbox_to_anchor=(1.4, 0.9))

plt.tight_layout()
plt.show()

In [None]:
# The figure legend is anchored outside the canvas (bbox_to_anchor x = 1.4);
# a tight bounding box keeps it from being clipped out of the SVG.
fig2.savefig(os.path.join(data_dir, "figure2.svg"), dpi=300, bbox_inches="tight")

## MAG-based analysis
Here, we will look at the Shannon diversity of the samples based on taxonomic assignments of the recovered, dereplicated MAGs. Moreover, we will look at how the diversity of CAZyme genes identified in those genomes changes over time.

In [None]:
# Shannon diversity of MAG taxonomy and of CAZyme annotations.
mag_shannon_fp, caz_shannon_fp = (
    os.path.join(data_dir, "mag_shannon_vector.qza"),
    os.path.join(data_dir, "caz_shannon_vector.qza"),
)

In [None]:
fig3, axes3 = plt.subplots(1, 2, figsize=(10, 4))

# Arguments shared by both panels.
shared_kwargs = dict(
    metadata=metadata,
    x_col="timepoint",
    hue="seed",
    x_label="Time",
    y_label="Shannon diversity",
)
_, mag_df = plot_diversity(mag_shannon_fp, ax=axes3[0], title="Kraken 2", **shared_kwargs)
_, caz_df = plot_diversity(caz_shannon_fp, ax=axes3[1], title="CAZymes [EggNOG]", **shared_kwargs)

In [None]:
figure3_path = os.path.join(data_dir, "figure3.svg")
fig3.savefig(figure3_path, dpi=300)

In [None]:
# combine the two taxonomy and CAZymes
mags_combined = mag_df.copy()
mags_combined.rename(columns={"shannon_entropy": "shannon_taxonomy"}, inplace=True)
mags_combined["shannon_caz"] = caz_df["shannon_entropy"]
mags_combined.head()

In [None]:
fig3a, axes3a = plt.subplots(figsize=(5, 5))

# Overlay taxonomy- and CAZyme-based Shannon curves for the Forasteiro seed.
mag_caz_fps = {"taxonomy": mag_shannon_fp, "caz": caz_shannon_fp}
_, _ = plot_diversity_together(
    mag_caz_fps,
    metadata,
    seed="Forasteiro",
    metric="shannon_entropy",
    x_col="timepoint",
    hue="classifier",
    ax=axes3a,
    x_label=None,
    y_label="Shannon diversity",
    title="Forasteiro",
)

In [None]:
figure3a_path = os.path.join(data_dir, "figure3a.svg")
fig3a.savefig(figure3a_path, dpi=300)

### CAZyme relative abundances: time course

In [None]:
caz_table_path = os.path.join(data_dir, "caz_ft_rarefied.qza")
caz_ft = q2.Artifact.load(caz_table_path).view(pd.DataFrame)
caz_ft.head()

In [None]:
# Attach sample metadata (inner join on the sample index).
caz_ft_merged = pd.merge(caz_ft, metadata, left_index=True, right_index=True)

Normalize the CAZyme counts to per-sample relative abundances and prepare the color maps.

In [None]:
# Convert CAZyme counts to per-sample relative abundances (metadata excluded).
value_cols = [col for col in sorted(caz_ft_merged.columns) if col not in ('seed', 'timepoint')]
caz_ft_merged[value_cols] = caz_ft_merged[value_cols].div(caz_ft_merged[value_cols].sum(axis=1), axis=0)

# CAZy classes in the order in which they are assigned HUSL hues.
cazy_classes = ["GH", "GT", "CE", "PL", "CBM", "AA"]
class_hues = sns.husl_palette(len(cazy_classes), h=0.5)

# One white -> class-hue colormap per class. Colormap inputs above 1.0 are
# clipped to the top color, so spacing values above 1.0 give the full hue.
sequential_cmaps = [
    LinearSegmentedColormap.from_list("custom_cmap", ["white", hue])
    for hue in class_hues
]

# Define custom linspace parameters for each category
category_spacing = {
    "AA": (0.9, 1.6),  # (start, end) values for np.linspace
    "CBM": (0.4, 1.0),
    "CE": (0.6, 1.0),
    "GH": (0.2, 1.6),
    "GT": (0.2, 1.6),
    "PL": (0.9, 1.6)
}

# Group columns by CAZy class (column name with digits removed), keeping
# value_cols order within each class.
category_groups = {}
for col in value_cols:
    category_groups.setdefault(re.sub(r'\d', '', col), []).append(col)

# Sample each class's colormap at evenly spaced points across its custom
# (start, end) range: one shade per column of that class.
maps = {}
for category, cols in category_groups.items():
    if category not in cazy_classes:
        raise ValueError(
            f"Unexpected CAZy class {category!r} derived from columns {cols}; "
            f"expected one of {cazy_classes}"
        )
    start, end = category_spacing.get(category, (0.1, 1.6))
    color_range = np.linspace(start, end, len(cols))
    cmap = sequential_cmaps[cazy_classes.index(category)]
    maps[category] = [cmap(i) for i in color_range]

In [None]:
fig4, axes4 = plt.subplots(1, 2, figsize=(7, 4), sharey=True)

# Give each CAZyme column the next unused shade of its class (the column name
# with digits removed). Shades are taken from `maps` in value_cols order,
# matching the order in which each class's shade list was built. A
# per-class counter makes this independent of how the columns are sorted.
category_colors = []
shades_used = {}
for col in value_cols:
    curr_cat = re.sub(r'\d', '', col)
    shade_idx = shades_used.get(curr_cat, 0)
    category_colors.append(maps[curr_cat][shade_idx])
    shades_used[curr_cat] = shade_idx + 1

for ax, seed in zip(axes4, caz_ft_merged['seed'].unique()):
    df_filtered = caz_ft_merged[caz_ft_merged['seed'] == seed]
    df_melted = df_filtered.melt(id_vars=['timepoint'], value_vars=value_cols, var_name='Category', value_name='Value')
    pivot_table = df_melted.pivot_table(index='timepoint', columns='Category', values='Value', aggfunc='sum')

    # Put the columns in value_cols order so they line up with category_colors
    pivot_table = pivot_table[value_cols]

    pivot_table.plot(kind='bar', stacked=True, ax=ax, color=category_colors)
    ax.set_title(f'{seed.capitalize()}')
    ax.set_xlabel('Timepoint')
    ax.set_ylabel('Normalized Value')
    ax.set_ylim(0, 1.05)

# Replace the per-axes legends with a single figure-level legend. Guard against
# axes that received no plot (fewer seeds than axes) and thus have no legend.
for ax in axes4:
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

handles, labels = axes4[0].get_legend_handles_labels()

# Sort the legend entries alphabetically by label.
sorted_labels_handles = sorted(zip(labels, handles), key=lambda x: x[0])
sorted_labels, sorted_handles = zip(*sorted_labels_handles)

fig4.legend(sorted_handles, sorted_labels, title='Category', bbox_to_anchor=(1.2, 0.9))

plt.tight_layout()
plt.show()

In [None]:
# The figure legend is anchored outside the canvas (bbox_to_anchor x = 1.2);
# a tight bounding box keeps it from being clipped out of the SVG.
fig4.savefig(os.path.join(data_dir, "figure4.svg"), dpi=300, bbox_inches="tight")

### Relative abundances of MAGs: time course

In [None]:
# Load the rarefied MAG feature table, then its taxonomy.
mags_ft, mags_taxonomy = (
    q2.Artifact.load(os.path.join(data_dir, artifact_name))
    for artifact_name in ("mags_ft_rarefied.qza", "mags_taxonomy.qza")
)

In [None]:
# Remove features whose taxonomy matches "Unclassified".
(mags_ft_filtered,) = taxa.methods.filter_table(
    exclude="Unclassified",
    taxonomy=mags_taxonomy,
    table=mags_ft,
)

In [None]:
# Collapse to taxonomy level 8 and work with the result as a DataFrame.
(mags_ft_collapsed,) = taxa.methods.collapse(
    level=8,
    taxonomy=mags_taxonomy,
    table=mags_ft_filtered,
)
mags_ft_collapsed = mags_ft_collapsed.view(pd.DataFrame)

In [None]:
# Attach sample metadata (inner join on the sample index).
mags_grouped = pd.merge(
    mags_ft_collapsed, metadata, left_index=True, right_index=True
)
mags_grouped.head()

In [None]:
# Abundance columns only (metadata excluded), sorted so the taxon order is
# deterministic and consistent with the read-based and CAZyme sections.
value_cols = sorted(col for col in mags_grouped.columns if col not in ('seed', 'timepoint'))
# Convert counts to per-sample relative abundances.
mags_grouped[value_cols] = mags_grouped[value_cols].div(mags_grouped[value_cols].sum(axis=1), axis=0)

In [None]:
fig5, axes5 = plt.subplots(1, 2, figsize=(10, 5), sharey=True)

for ax, seed in zip(axes5, mags_grouped['seed'].unique()):
    df_filtered = mags_grouped[mags_grouped['seed'] == seed]
    df_melted = df_filtered.melt(id_vars=['timepoint'], value_vars=value_cols, var_name='Category', value_name='Value')
    pivot_table = df_melted.pivot_table(index='timepoint', columns='Category', values='Value', aggfunc='sum')

    # Reverse the column order; pandas stacks the first column at the bottom,
    # so the first taxon in value_cols ends up on top of each bar.
    pivot_table = pivot_table[value_cols[::-1]]

    pivot_table.plot(kind='bar', stacked=True, ax=ax)
    ax.set_title(f'Stacked Bar Plot for Seed {seed}')
    ax.set_xlabel('Timepoint')
    ax.set_ylabel('Normalized Value')

# Replace the per-axes legends with a single figure-level legend. Guard against
# axes that received no plot (fewer seeds than axes) and thus have no legend.
for ax in axes5:
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

handles, labels = axes5[0].get_legend_handles_labels()

# Undo the column reversal so the legend lists taxa in value_cols order.
handles = handles[::-1]
labels = labels[::-1]

fig5.legend(handles, labels, title='Category', bbox_to_anchor=(1.1, -0.05))

plt.tight_layout()
plt.show()

In [None]:
# The figure legend is anchored outside the canvas (bbox_to_anchor (1.1, -0.05));
# a tight bounding box keeps it from being clipped out of the SVG.
fig5.savefig(os.path.join(data_dir, "figure5.svg"), dpi=300, bbox_inches="tight")