In [1]:
import html
from collections import OrderedDict
from pathlib import Path
from urllib.parse import quote

import altair as alt
import numpy as np
import pandas as pd

import qiime2
from q2_composition import DataLoafPackageDirFmt

In [2]:
def _y_axis_labels(feature_ids):
    """Return one unique, human-readable y-axis label per feature ID.

    Each label is the last ';'-separated field of the ID that is not the bare
    '__' placeholder. Labels that were already used get the row position
    appended (e.g. 'g__Streptococcus (4)') so every bar keeps its own tick.
    """
    labels = []
    seen = set()
    for position, raw_id in enumerate(feature_ids):
        feature_id = str(raw_id)
        fields = [field for field in feature_id.split(';') if field != '__']
        # An ID consisting only of '__' placeholders has no informative
        # field; fall back to the whole ID instead of raising IndexError.
        label = fields[-1] if fields else feature_id
        if label in seen:
            label = f"{label} ({position})"
            # The suffixed label could itself already exist; keep suffixing
            # so no two rows ever share a label.
            while label in seen:
                label = f"{label} ({position})"
        seen.add(label)
        labels.append(label)
    return labels


def plot_differentials(
        output_dir,
        df,
        category,
        q_threshold=1.0,
        lfc_threshold=0,
        feature_ids=None):
    """Save an altair bar chart of log fold changes (with +/- se rules).

    Parameters
    ----------
    output_dir : str or path-like
        Directory the HTML figure is written to; created if missing.
    df : pd.DataFrame
        One row per feature with columns 'feature-id', 'lfc', 'se' and
        'q_val'. The caller's frame is not modified.
    category : str
        Used to build the output file name.
    q_threshold : float
        Keep only rows with q_val <= q_threshold (NaN q-values are dropped,
        since comparisons with NaN are False).
    lfc_threshold : float
        Keep only rows with |lfc| >= lfc_threshold.
    feature_ids : iterable of str, object with an ``ids`` attribute, or None
        If given, keep only rows whose 'feature-id' is among these IDs.
        Objects exposing ``.ids`` (such as the qiime2.Metadata forwarded by
        da_barplot) contribute their ``ids``.

    Returns
    -------
    pathlib.Path
        Path of the saved '<category>-ancombc-barplot.html' file.

    Raises
    ------
    ValueError
        If ``df`` is empty, or no rows survive the filters.
    """
    if len(df) == 0:
        raise ValueError("No features present in input.")

    if feature_ids is not None:
        # NOTE(review): assumes a Metadata's index IDs are feature IDs.
        ids = getattr(feature_ids, 'ids', feature_ids)
        df = df[df['feature-id'].isin(list(ids))]

    df = df[(df['q_val'] <= q_threshold)
            & (np.abs(df['lfc']) >= lfc_threshold)]

    if len(df) == 0:
        raise ValueError("No features remaining after applying filters.")

    # Boolean filtering yields a frame pandas treats as a possible view;
    # copy it before adding columns to avoid SettingWithCopyWarning.
    df = df.copy()

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Whitespace runs become single dashes; '/' is replaced as well so the
    # figure always lands directly inside output_dir.
    fig_fn = '-'.join(f'{category}-ancombc-barplot.html'.split())
    fig_fn = fig_fn.replace('/', '-')
    fig_fp = output_dir / fig_fn

    # For readability, the y-axis shows only the most specific named field
    # of each feature ID; the full ID is in the tooltip. Those names need not
    # be unique, so repeats get their row position appended. Separate
    # per-tick labels would avoid this, but that doesn't seem straightforward
    # in altair (see https://github.com/altair-viz/altair/issues/938).
    df['y_label'] = _y_axis_labels(df['feature-id'])

    df['feature'] = [str(id_).replace(';', ' ') for id_ in df['feature-id']]
    df['enriched'] = np.where(df['lfc'] > 0, "enriched", "depleted")

    df['error-upper'] = df['lfc'] + df['se']
    df['error-lower'] = df['lfc'] - df['se']

    # The standard-error rules are a separate chart layered over the bars so
    # they can be drawn in black and stand out from the colored bars. Both
    # layers share this y encoding so their rows line up.
    shared_y = alt.Y("y_label", title="Feature ID (most specific, if taxonomic)",
                     sort=alt.EncodingSortField(field="lfc", op="min", order="descending"))

    bars = alt.Chart(df).mark_bar().encode(
        x=alt.X('lfc', title="Log Fold Change (LFC)"),
        y=shared_y,
        tooltip=["feature-id", "lfc", "q_val", "se", "error-lower", "error-upper"],
        color=alt.Color('enriched', title="Relative to reference", sort="descending")
    )

    error = alt.Chart(df).mark_rule(color='black').encode(
        x='error-lower',
        x2='error-upper',
        y=shared_y,
    )

    chart = bars + error
    chart.save(fig_fp)
    return fig_fp

def da_barplot(output_dir: str,
               slices: DataLoafPackageDirFmt,
               q_threshold: float = 1.0,
               lfc_threshold: float = 0,
               feature_ids: qiime2.Metadata = None):
    """Write one LFC bar plot per category plus an ``index.html`` linking them.

    Parameters
    ----------
    output_dir : str
        Directory for ``index.html`` and the per-category figures; created
        if it does not exist.
    slices : DataLoafPackageDirFmt
        Collection whose ``data_slices.iter_views(pd.DataFrame)`` yields
        ``(name, DataFrame)`` pairs. Slices named ``lfc_slice.csv``,
        ``se_slice.csv`` and ``q_val_slice.csv`` must be present; each has an
        ``id`` column plus one column per category.
    q_threshold, lfc_threshold, feature_ids
        Forwarded unchanged to ``plot_differentials`` for every category.

    Categories for which ``plot_differentials`` raises ``ValueError`` get a
    short explanatory line in the index instead of a link.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Key each slice by its file name without the suffix,
    # e.g. 'lfc_slice.csv' -> 'lfc'.
    data = {}
    for entry in slices.data_slices.iter_views(pd.DataFrame):
        data[str(entry[0]).replace('_slice.csv', '')] = entry[1]

    # 'id' holds the feature IDs (it becomes 'feature-id' below) and the
    # '(Intercept)' term is not plotted. Compare by exact name: a substring
    # test against '(Intercept)' would also drop columns such as 'e'.
    non_categories = {'id', '(Intercept)'}
    categories = [name for name in data['lfc'].columns
                  if name not in non_categories]

    index_fp = output_dir / 'index.html'
    with open(index_fp, 'w') as index_f:
        index_f.write('<html><body>\n')
        for category in categories:
            df = pd.concat([data['lfc']['id'],
                            data['lfc'][category],
                            data['se'][category],
                            data['q_val'][category]],
                           keys=['feature-id', 'lfc', 'se', 'q_val'], axis=1)
            # Category names and error text are data, not markup: escape
            # them, and percent-encode the file name used in the link.
            label = html.escape(str(category))
            try:
                figure_fp = plot_differentials(output_dir, df, category,
                                               q_threshold=q_threshold,
                                               lfc_threshold=lfc_threshold,
                                               feature_ids=feature_ids)
            except ValueError as err:
                index_f.write(f"Category {label} contains no data to plot: "
                              f"{html.escape(str(err))} <hr>\n")
            else:
                href = quote(figure_fp.name)
                index_f.write(f'<a href="./{href}">{label}</a><hr>\n')
        index_f.write('</body></html>')

In [3]:
# Configuration for the example run: input artifact, output directory and
# the q-value threshold used to filter features.
ANCOMBC_QZA = './ancombc-body_habitat.qza'
OUTPUT_DIR = './example-output/'
Q_THRESHOLD = 0.001

# Example subset of taxa. Assign it (or any other list of feature IDs) to
# FEATURE_IDS to plot only those features.
EXAMPLE_FEATURE_IDS = [
    'd__Bacteria;p__Firmicutes;c__Bacilli;o__Lactobacillales;f__Lactobacillaceae;g__Lentilactobacillus',
    'd__Bacteria;p__Firmicutes;c__Bacilli;o__Staphylococcales;f__Staphylococcaceae;g__Staphylococcus',
    'd__Bacteria;p__Firmicutes;c__Bacilli;o__Lactobacillales;f__Streptococcaceae;g__Streptococcus',
]
FEATURE_IDS = None  # None: plot every feature that passes the thresholds

loaf = qiime2.Artifact.load(ANCOMBC_QZA)
slices = loaf.view(DataLoafPackageDirFmt)
da_barplot(OUTPUT_DIR, slices, q_threshold=Q_THRESHOLD, feature_ids=FEATURE_IDS)