In [1]:
import html
import urllib.parse
from pathlib import Path

import altair as alt
import numpy as np
import pandas as pd

import qiime2
from q2_composition import DataLoafPackageDirFmt

In [16]:
def plot_differentials(
        output_dir,
        df,
        category,
        q_threshold=1.0,
        lfc_abs_threshold=0,
        sort_by='lfc',
        feature_ids=None):
    """Save an interactive bar chart of per-feature log fold changes as HTML.

    Parameters
    ----------
    output_dir : str or pathlib.Path
        Directory the chart is written to; created if it does not exist.
    df : pandas.DataFrame
        One row per feature with columns ``'feature-id'`` (str), ``'lfc'``,
        ``'se'`` and ``'q_val'``. The caller's frame is not modified.
    category : str
        Used to build the output file name; spaces become ``'-'``.
    q_threshold : float, default 1.0
        Only features with ``q_val <= q_threshold`` are plotted.
    lfc_abs_threshold : float, default 0
        Only features with ``abs(lfc) >= lfc_abs_threshold`` are plotted.
    sort_by : str, default 'lfc'
        Column the retained features are sorted by (ascending) before they
        are numbered in their axis labels.
    feature_ids : collection of str, optional
        When given, only features whose ``'feature-id'`` is in it are plotted.

    Returns
    -------
    pathlib.Path
        Path of the written ``<category>-ancombc-barplot.html`` file.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    fig_fn = f'{category}-ancombc-barplot.html'.replace(' ', '-')
    fig_fp = output_dir / fig_fn

    keep = (df['q_val'] <= q_threshold) & (df['lfc'].abs() >= lfc_abs_threshold)
    if feature_ids is not None:
        # Series.isin does element-wise membership (`series in ids` does not);
        # list() also accepts one-shot iterables such as generators.
        keep &= df['feature-id'].isin(list(feature_ids))
    # .copy() gives an independent frame, so adding the label/tooltip columns
    # below neither warns (SettingWithCopyWarning) nor touches the caller's df.
    df = df[keep].sort_values(sort_by).copy()

    # Axis label: sorted position plus the last ';'-separated field of the id
    # (the most specific rank when the id is a taxonomy string).
    df['y_label'] = [
        f"{i}: {feature_id.split(';')[-1]}"
        for i, feature_id in enumerate(df['feature-id'])
    ]

    tooltips = []
    for feature_id, lfc, q, se in zip(df['feature-id'], df['lfc'], df['q_val'], df['se']):
        feature_label = '\n'.join(feature_id.split(';'))
        tooltips.append(f"{feature_label} LFC: {lfc:.4f} q-value: {q:.2e} std err: {se:.4f}")
    df['tooltip'] = tooltips

    bars = alt.Chart(df).mark_bar().encode(
        x=alt.X('lfc', title="Log Fold Change (LFC)"),
        y=alt.Y('y_label', title="Feature ID (most specific, if taxonomic)", sort="-x"),
        tooltip="tooltip",
    )
    # TODO: draw lfc +/- se error bars (se is currently only shown in the tooltip), see
    # https://altair-viz.github.io/gallery/simple_scatter_with_errorbars.html
    chart = bars.properties(height=900)
    # Pass a str so the '.html' output format can be inferred from the file
    # name even by Altair versions that only recognize str paths.
    chart.save(str(fig_fp))
    return fig_fp

def da_barplot(output_dir: str,
               slices: DataLoafPackageDirFmt):
    """Render one ANCOM-BC LFC bar chart per model term plus an index linking them.

    Parameters
    ----------
    output_dir : str
        Directory that receives ``index.html`` and one chart file per term
        (written by ``plot_differentials``); created if missing.
    slices : DataLoafPackageDirFmt
        ANCOM-BC output. Its ``data_slices`` must include files that, with
        the ``'_slice.csv'`` suffix removed, are named ``'lfc'``, ``'se'``
        and ``'q_val'``. The ``'lfc'`` table has an ``'id'`` column of
        feature ids; ``'se'`` and ``'q_val'`` must carry the same term
        columns as ``'lfc'``.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Key each slice by its file name minus the suffix, e.g. 'lfc_slice.csv' -> 'lfc'.
    data = {
        str(relpath).replace('_slice.csv', ''): view
        for relpath, view in slices.data_slices.iter_views(pd.DataFrame)
    }
    lfc, se, q_val = data['lfc'], data['se'], data['q_val']

    # 'id' holds the feature ids and '(Intercept)' is not a contrast to plot.
    # Exact-name exclusion: a substring test against '(Intercept)' would also
    # drop unrelated columns whose names happen to occur inside it.
    categories = [c for c in lfc.columns if c not in ('id', '(Intercept)')]

    index_fp = output_dir / 'index.html'
    with open(index_fp, 'w', encoding='utf-8') as index_f:
        index_f.write('<html><body>\n')
        for category in categories:
            df = pd.concat([lfc['id'], lfc[category], se[category], q_val[category]],
                           keys=['feature-id', 'lfc', 'se', 'q_val'], axis=1)
            # Only strongly supported differentials (q <= 0.001) are plotted.
            figure_fp = plot_differentials(output_dir, df, category, q_threshold=0.001)
            # Term names can contain URL/HTML-significant characters
            # (spaces, '#', '&', '<'), so percent-encode the link target
            # and escape both the attribute value and the link text.
            href = html.escape(urllib.parse.quote(figure_fp.name), quote=True)
            label = html.escape(str(category))
            index_f.write(f'<a href="./{href}">{label}</a><hr>\n')
        index_f.write('</body></html>')

In [17]:
# Inputs/outputs for this run: the ANCOM-BC result artifact to read and the
# directory the per-term bar charts plus index.html are written into.
ANCOMBC_ARTIFACT_FP = './ancombc-body_habitat.qza'
DA_PLOTS_DIR = './da-plots/'

loaf = qiime2.Artifact.load(ANCOMBC_ARTIFACT_FP)
slices = loaf.view(DataLoafPackageDirFmt)
da_barplot(DA_PLOTS_DIR, slices)