In [None]:
import math

import floweaver
import numpy as np
import pandas as pd

### GFOP sample type metadata

In [None]:
def get_sample_types(gfop_metadata, simple_complex=None):
    """Extract the sample type hierarchy of each file from the GFOP metadata.

    Parameters
    ----------
    gfop_metadata : pd.DataFrame
        Metadata with a ``filename`` column and the columns
        ``sample_type_group1`` ... ``sample_type_group6``; the
        ``simple_complex`` column is read only when filtering.
    simple_complex : str, optional
        When given, keep only rows whose ``simple_complex`` value equals it.

    Returns
    -------
    pd.DataFrame
        The six ``sample_type_group`` columns, indexed by ``filename``.
    """
    if simple_complex is None:
        selected = gfop_metadata
    else:
        selected = gfop_metadata[
            gfop_metadata['simple_complex'] == simple_complex]
    hierarchy_columns = ['sample_type_group%d' % level
                         for level in range(1, 7)]
    return selected[['filename'] + hierarchy_columns].set_index('filename')

In [None]:
# Load the tab-separated GFOP metadata. The first data row (index 0) is
# empty and is dropped; leading and trailing whitespace is stripped from
# every text (object) column.
gfop_metadata = (
    pd.read_csv('../data/11442_foodomics_multiproject_metadata.txt',
                sep='\t')
    .drop(index=0)
    .apply(lambda col: (col.str.strip() if col.dtype == 'object'
                        else col)))

### Food type at different metadata levels

In [None]:
def _get_flows(gnps_network, sample_types, groups_included,
               filenames_included, max_hierarchy_level):
    # Select GNPS job groups.
    groups = {f'G{i}' for i in range(1, 7)}
    groups_excluded = groups - set(groups_included)
    df_selected = gnps_network[
        (gnps_network[groups_included] > 0).all(axis=1) &
        (gnps_network[groups_excluded] == 0).all(axis=1)].copy()
    df_selected = df_selected[
        df_selected['UniqueFileSources'].apply(lambda cluster_fn:
            any(fn in cluster_fn for fn in filenames_included))]
    filenames = (df_selected['UniqueFileSources'].str.split('|')
                 .explode())
    # Select food hierarchy levels.
    sample_types = sample_types[[
        f'sample_type_group{i}' for i in range(1, max_hierarchy_level + 1)]]
    # Match the GNPS job results to the food sample types.
    sample_types_selected = sample_types.reindex(filenames)
    sample_types_selected = sample_types_selected.dropna()
    # Discard samples that occur less frequent than water (blank).
    water_count = ((sample_types_selected['sample_type_group1'] == 'water')
                   .sum())
    sample_counts = sample_types_selected[
        f'sample_type_group{max_hierarchy_level}'].value_counts()
    sample_counts_valid = sample_counts.index[sample_counts > water_count]
    sample_types_selected = sample_types_selected[sample_types_selected[
        f'sample_type_group{max_hierarchy_level}'].isin(sample_counts_valid)]
    # Get the flows between consecutive food hierarchy levels.
    flows, processes = [], []
    for i in range(1, max_hierarchy_level):
        g1, g2 = f'sample_type_group{i}', f'sample_type_group{i + 1}'
        flow = (sample_types_selected.groupby([g1, g2]).size()
                .reset_index().rename(columns={g1: 'source', g2: 'target',
                                               0: 'value'}))
        flow['source'] = flow['source'] + f'_{i}'
        flow['target'] = flow['target'] + f'_{i + 1}'
        flow['type'] = flow['target']
        flows.append(flow)
        process = pd.concat([flow['source'], flow['target']],
                            ignore_index=True).to_frame()
        process['level'] = [*np.repeat(i, len(flow['source'])),
                            *np.repeat(i + 1, len(flow['target']))]
        processes.append(process)
    return (pd.concat(flows, ignore_index=True),
            pd.concat(processes, ignore_index=True).drop_duplicates()
            .rename(columns={0: 'id'}).set_index('id'))

In [None]:
def plot_flows(gnps_network, sample_types, groups_included,
               filenames_included, sample_type_hierarchy,
               max_hierarchy_level=4, filename=None):
    """Build a floweaver Sankey diagram of the food hierarchy flows.

    Parameters
    ----------
    gnps_network, sample_types, groups_included, filenames_included, \
max_hierarchy_level
        Forwarded to ``_get_flows`` to compute the flows between
        consecutive ``sample_type_group`` levels.
    sample_type_hierarchy : pd.DataFrame
        Indexed by node label, with an ``order_num`` column used to order
        the nodes and a ``color_code`` column used as the colour palette.
        NOTE(review): labels are assumed to carry the ``_<level>`` suffix
        produced by ``_get_flows`` - confirm against the CSV.
    filename : str, optional
        Unused; kept for backward compatibility (callers save the widget
        themselves).

    Returns
    -------
    tuple
        The result of ``floweaver.weave`` and a Series of file counts per
        deepest-level food (index ``food``, name ``count``).
    """
    flows, processes = _get_flows(
        gnps_network, sample_types, groups_included, filenames_included,
        max_hierarchy_level)
    dataset = floweaver.Dataset(flows, dim_process=processes)

    # Counts of the deepest-level foods. Select the column explicitly:
    # `.squeeze()` would collapse a single food into a scalar.
    is_deepest = flows['target'].str.endswith(f'_{max_hierarchy_level}')
    food_counts = (flows.loc[is_deepest, ['target', 'value']]
                   .rename(columns={'target': 'food', 'value': 'count'})
                   .set_index('food')['count'])

    # Order the node labels by `order_num`. Sorting the label set first keeps
    # labels missing from the hierarchy (NaN order_num) in a reproducible
    # order instead of the run-dependent iteration order of a set.
    labels = (sample_type_hierarchy
              .reindex(sorted(set(flows['source']) | set(flows['target'])))
              .sort_values('order_num').index)
    levels = processes['level'].unique()
    nodes, ordering, bundles = {}, [], []
    for level in levels:
        node_name = f'level {level}'
        nodes[node_name] = floweaver.ProcessGroup(f'level == {level}')
        # This level's labels, in descending `order_num` order.
        nodes[node_name].partition = floweaver.Partition.Simple(
            'process', labels[labels.str.endswith(f'_{level}')][::-1])
        ordering.append([node_name])
        # Connect each level to the next one when it exists.
        if level + 1 in levels:
            bundles.append(floweaver.Bundle(node_name,
                                            f'level {level + 1}'))

    sdd = floweaver.SankeyDefinition(
        nodes, bundles, ordering, flow_partition=dataset.partition('type'))
    palette = sample_type_hierarchy['color_code'].dropna().to_dict()
    return floweaver.weave(sdd, dataset, palette=palette), food_counts

In [None]:
# Sample type hierarchies of the simple and of the complex food samples.
sample_types_simple, sample_types_complex = (
    get_sample_types(gfop_metadata, kind) for kind in ('simple', 'complex'))

In [None]:
# Food hierarchy node descriptors, sorted by their `order_num`.
sample_type_hierarchy = pd.read_csv('../data/sample_type_hierarchy.csv')
sample_type_hierarchy = (sample_type_hierarchy.set_index('descriptor')
                         .sort_values('order_num'))

In [None]:
# RA study metadata; collect the unique filenames of each diet
# intervention phase.
ra_metadata = pd.read_csv(
    '../data/'
    '12_26_RA fecal - plasma - food - FoodOmics 3500 FDR 0.01 tol 0.01 min 2/'
    'ra_qiime2_metadata.tsv', sep='\t')
filenames_pre, filenames_post = (
    ra_metadata.loc[ra_metadata['ATTRIBUTE_Intervention'] == phase,
                    'filename'].unique()
    for phase in ('1_pre', '2_post'))

In [None]:
# GNPS molecular networking cluster table (tab-separated).
gnps_network = pd.read_csv(
    '../data/12_26_RA fecal - plasma - food - FoodOmics 3500 FDR 0.01 tol 0.01 min 2/'
    'METABOLOMICS-SNETS-V2-0794151f-view_all_clusters_withID_beta-main.tsv',
    delimiter='\t')

In [None]:
# Analysis configuration.
dataset = 'folder12'  # Tag used in the output filenames.
simple_complex = 'simple'  # Food samples to use: 'simple' or 'complex'.
groups = ('1', '3', '4')  # GNPS groups G1, G3, G4 must all be present.
max_level = 4  # Deepest sample type hierarchy level.

if simple_complex not in ('simple', 'complex'):
    raise ValueError('Unknown sample type')
sample_types = (sample_types_simple if simple_complex == 'simple'
                else sample_types_complex)

# Sankey widget width and height.
width, height = 1200, 1800

In [None]:
# Flows before the diet intervention; the widget is displayed as the cell
# output and saved as a PNG.
sankey_data, food_counts_pre = plot_flows(
    gnps_network, sample_types, [f'G{g}' for g in groups], filenames_pre,
    sample_type_hierarchy, max_hierarchy_level=max_level)
sankey_data.to_widget(
    width=width, height=height,
    margins=dict(left=150, right=150, top=-50, bottom=-50),
).auto_save_png('flow_{}_g{}_level{}_{}_pre.png'.format(
    dataset, ''.join(groups), max_level, simple_complex))

In [None]:
# Flows after the diet intervention; the widget is displayed as the cell
# output and saved as a PNG.
sankey_data, food_counts_post = plot_flows(
    gnps_network, sample_types, [f'G{g}' for g in groups], filenames_post,
    sample_type_hierarchy, max_hierarchy_level=max_level)
sankey_data.to_widget(
    width=width, height=height,
    margins=dict(left=150, right=150, top=-50, bottom=-50),
).auto_save_png('flow_{}_g{}_level{}_{}_post.png'.format(
    dataset, ''.join(groups), max_level, simple_complex))

In [None]:
# Compare the deepest-level food counts before and after the intervention;
# a food seen in only one phase gets a count of 0 for the other phase.
food_counts_diff = (
    pd.merge(food_counts_pre.rename('count_pre'),
             food_counts_post.rename('count_post'),
             'outer', left_index=True, right_index=True)
    .fillna(0)
    .astype({'count_pre': int, 'count_post': int}))
# Foods absent before the intervention get an infinite ratio.
food_counts_diff['ratio'] = food_counts_diff['count_post'].div(
    food_counts_diff['count_pre'])

In [None]:
# Export the pre/post food counts, sorted by the post/pre ratio.
food_counts_diff.sort_values('ratio').to_csv(
    'ra_intervention_g{}_level{}_{}.csv'.format(
        ''.join(groups), max_level, simple_complex))