In [None]:
import math
import os

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

In [None]:
def get_sample_types(gfop_metadata, simple_complex=None):
    """Return the six-level food sample-type hierarchy per file.

    Parameters
    ----------
    gfop_metadata : pd.DataFrame
        Metadata with a ``filename`` column and the hierarchy columns
        ``sample_type_group1`` ... ``sample_type_group6`` (plus a
        ``simple_complex`` column when `simple_complex` is given).
    simple_complex : str, optional
        When given, keep only rows whose ``simple_complex`` equals it.

    Returns
    -------
    pd.DataFrame
        The ``sample_type_group*`` columns, indexed by ``filename``.
    """
    hierarchy_columns = ['sample_type_group%d' % depth for depth in range(1, 7)]
    if simple_complex is None:
        subset = gfop_metadata
    else:
        subset = gfop_metadata[gfop_metadata['simple_complex'] == simple_complex]
    return subset.set_index('filename')[hierarchy_columns]

In [None]:
# GFOP food metadata. The first data row is empty, so it is dropped.
gfop_metadata = (
    pd.read_csv('../data/11442_foodomics_multiproject_metadata.txt', sep='\t')
    .drop(index=0)
)
# Strip leading and trailing whitespace from every text (object) column.
gfop_metadata = gfop_metadata.apply(
    lambda column: column.str.strip() if column.dtype == 'object' else column)

In [None]:
def get_file_food_counts(gnps_network, sample_types, groups_included,
                         filenames_included, level):
    """Count food sample types that are networked with the given MS files.

    A cluster is selected when it has a positive value in every GNPS job
    group in `groups_included`, a zero value in every other group of G1-G6,
    and its ``UniqueFileSources`` mentions at least one of
    `filenames_included`. Every file contributing to a selected cluster is
    then mapped to its food sample type at hierarchy `level`.

    Parameters
    ----------
    gnps_network : pd.DataFrame
        GNPS cluster table with columns ``G1`` ... ``G6`` and a
        ``UniqueFileSources`` column of '|'-separated file names.
    sample_types : pd.DataFrame
        Sample types indexed by file name, with a
        ``sample_type_group{level}`` column.
    groups_included : iterable of str
        GNPS job group names (e.g. 'G2') a cluster must be present in.
    filenames_included : list of str
        Files of interest. Matched as substrings of ``UniqueFileSources``.
    level : int
        Food hierarchy level to report.

    Returns
    -------
    pd.Series
        Count per food sample type. Only sample types that occur strictly
        more often than 'water' (blank) are kept. Empty if nothing matches.
    """
    # Select GNPS job groups. pandas rejects a set as a column indexer
    # (TypeError since pandas 2.0), so both selections use ordered lists.
    groups_included = list(groups_included)
    groups = {f'G{i}' for i in range(1, 7)}
    groups_excluded = sorted(groups - set(groups_included))
    df_selected = gnps_network[
        (gnps_network[groups_included] > 0).all(axis=1) &
        (gnps_network[groups_excluded] == 0).all(axis=1)]
    # Keep clusters that contain at least one requested file (substring
    # match against the '|'-joined file list). The explicit bool cast keeps
    # this a row mask when no cluster survived the group filter: an empty
    # `apply` result is object-typed, and pandas would otherwise treat it as
    # an (empty) column selection.
    has_file = df_selected['UniqueFileSources'].apply(
        lambda cluster_fn: any(fn in cluster_fn for fn in filenames_included))
    df_selected = df_selected[has_file.astype(bool)]
    filenames = (df_selected['UniqueFileSources'].str.split('|')
                 .explode())
    # Select food hierarchy levels.
    sample_types = sample_types[f'sample_type_group{level}']
    # Match the GNPS job results to the food sample types. Files without a
    # known sample type are dropped.
    sample_types_selected = sample_types.reindex(filenames).dropna()
    # Discard sample types that do not occur more often than water (blank).
    # Because the comparison is strict, water itself is discarded as well.
    water_count = (sample_types_selected == 'water').sum()
    sample_counts = sample_types_selected.value_counts()
    sample_counts_valid = sample_counts.index[sample_counts > water_count]
    sample_types_selected = sample_types_selected[
        sample_types_selected.isin(sample_counts_valid)]
    # Get sample counts at the specified level.
    return sample_types_selected.value_counts()

In [None]:
# Food sample-type hierarchy for all GFOP samples and for the simple and
# complex subsets.
sample_types, sample_types_simple, sample_types_complex = (
    get_sample_types(gfop_metadata, subset)
    for subset in (None, 'simple', 'complex'))

In [None]:
# Directory holding the RA fecal/plasma/food FoodOmics job outputs.
data_dir = os.path.join(
    '..',
    'data',
    '12_26_RA fecal - plasma - food - FoodOmics 3500 FDR 0.01 tol 0.01 min 2',
)

In [None]:
# Per-sample study metadata (tab-separated).
metadata_path = os.path.join(data_dir, 'ra_qiime2_metadata.tsv')
metadata = pd.read_csv(metadata_path, sep='\t')

In [None]:
# GNPS molecular-networking cluster table (tab-separated).
gnps_network_path = os.path.join(
    data_dir,
    'METABOLOMICS-SNETS-V2-0794151f-view_all_clusters_withID_beta-main.tsv')
gnps_network = pd.read_csv(gnps_network_path, sep='\t')

In [None]:
# Calculate the number of matches to food categories per MS file at food
# hierarchy `level`. Each entry pairs a sample type with the GNPS job groups
# its clusters must be found in; ('stool', ['G1', 'G4']) is currently left
# out of the analysis. Files without any food match are skipped.
level = 4
food_counts, filenames = [], []
sample_type_groups = [('plasma', ['G2', 'G4'])]
for sample_type, groups in sample_type_groups:
    metadata_group = metadata.loc[
        metadata['ATTRIBUTE_SampleTypeSub1'] == sample_type]
    for filename in metadata_group['filename']:
        file_food_counts = get_file_food_counts(
            gnps_network, sample_types, groups, [filename], level)
        if file_food_counts.empty:
            continue
        food_counts.append(file_food_counts)
        filenames.append(filename)

In [None]:
# One row per MS file, one column per food category; categories a file did
# not match are filled with zero.
food_counts = (
    pd.concat(food_counts, axis=1, keys=filenames, sort=True)
    .fillna(0)
    .astype(int)
    .T
    .rename_axis('filename')
    .sort_index()
)

In [None]:
# Map GFOP foods to foods specified in the diet diary.
food_map = pd.read_csv(os.path.join(data_dir, 'ra_diary_gfop_map.csv'))
# A diary entry can match several level-5 foods separated by ';'. Split them
# into one row per food, and strip each token so that e.g. "a; b" still
# matches the whitespace-stripped GFOP metadata labels.
food_map['STG5'] = food_map['STG5'].str.split(';')
food_map = food_map.explode('STG5')
food_map['STG5'] = food_map['STG5'].str.strip()
# Add level 4 foods from their level 5 successors.
map_level45 = (sample_types[['sample_type_group4', 'sample_type_group5']]
               .reset_index(drop=True).drop_duplicates())
# Select the level-4 column explicitly: `squeeze()` would collapse a
# single matching row into a scalar and break `.to_dict()`. If a level-5
# food has several level-4 parents, the last one wins.
map_level45 = (map_level45[map_level45['sample_type_group5']
                           .isin(food_map['STG5'])]
               .set_index('sample_type_group5')['sample_type_group4']
               .to_dict())
# Force map complex as it can map to a lot of different things.
map_level45['complex'] = 'complex'
# Diary entries without a GFOP counterpart.
map_level45['not represented'] = 'not represented'
food_map['STG4'] = food_map['STG5'].map(map_level45)
food_map = food_map.sort_values(['STG4', 'STG5'])

In [None]:
# Self-reported diet diary, transposed to one row per diary column and one
# presence/absence column per diary food category.
# Keyword arguments are required: positional `dropna`/`sort_index` arguments
# raise TypeError since pandas 2.0.
diary = (pd.read_csv(os.path.join(data_dir, 'ra_diet_diary.csv'),
                     index_col='Diary_category')
         .dropna(axis='columns', how='all')
         .replace({'yes': True, 'no': False}).T
         .rename_axis(columns=None))
# NOTE(review): assumes diary labels carry the study id in characters 1-4
# and an integer time point from character 6 onward — confirm.
diary['study_id'] = diary.index.str[1:5]
diary['time'] = diary.index.str[6:].astype(int)
diary = diary.set_index(['study_id', 'time'])
# Diary category -> GFOP food at the analysis level. Selecting the column
# explicitly avoids `squeeze()` collapsing a one-row table into a scalar.
column_rename = (food_map.set_index('Diary_category')[f'STG{level}']
                 .to_dict())
# Combine diary entries that match to multiple foods by aggregating their
# absence/presence values. Grouping the transposed frame by row replaces the
# deprecated `groupby(axis='columns')`; groupby sorts the food names.
diary = (diary.rename(columns=column_rename)
         .drop(columns='not represented')
         .T.groupby(level=0).any().T)

In [None]:
# Lookup table between study identifiers and patient run names. Both keys
# are truncated to their first four characters, then duplicate pairs are
# removed.
patient_map = (
    pd.read_csv(os.path.join(data_dir, 'ra_patient_map.csv'))
    .assign(study_id=lambda frame: frame['study_id'].str[:4],
            patient=lambda frame: frame['patient'].str[:4])
    .drop_duplicates()
)

In [None]:
# Re-key the food counts from MS file names to (study_id, time).
# This cell rewrites `food_counts.index` in place, so a second run would
# silently mangle the keys; fail loudly instead.
assert food_counts.index.name == 'filename', (
    'food_counts was already re-keyed; re-run the cells above first')
# NOTE(review): assumes the first character and a 6-character suffix
# (presumably the file extension) are not part of the key — confirm.
food_counts.index = food_counts.index.str[1:-6]
# Time point from the name ending: T1 -> -14, T3 -> 14, anything else -> 0.
food_counts['time'] = 0
food_counts.loc[food_counts.index.str.endswith('T1'), 'time'] = -14
food_counts.loc[food_counts.index.str.endswith('T3'), 'time'] = 14
# Drop the last three characters (presumably the time-point suffix) and
# translate the patient run name into its study id. The column is selected
# explicitly: `squeeze()` would yield a dict of dicts if the map had extra
# columns, or a scalar for a single row. Unmapped patients become NaN.
patient_to_study = patient_map.set_index('patient')['study_id'].to_dict()
food_counts.index = (food_counts.index.str[:-3].map(patient_to_study)
                     .rename('study_id'))
food_counts = food_counts.set_index('time', append=True)

In [None]:
# Keep only food categories present in both the MS counts and the diary, in
# the MS column order, so both tables line up column by column. 'complex' is
# dropped because it can map to many different foods. `Index & Index` is an
# element-wise logical operation since pandas 2.0, so use `intersection`.
shared_foods = (food_counts.columns.intersection(diary.columns)
                .drop('complex', errors='ignore'))
food_counts = food_counts[shared_foods]
diary = diary[shared_foods]

In [None]:
# Kendall's tau between the food profile of every MS time point and every
# diary time point, computed within each study_id present in both tables.
# `intersection` is required: `Index & Index` is element-wise since pandas 2.0.
statistics = []
study_ids = (food_counts.index.get_level_values('study_id').unique()
             .intersection(diary.index.get_level_values('study_id').unique()))
for study_id in study_ids:
    food_counts_study = food_counts[
        food_counts.index.get_level_values('study_id') == study_id]
    diary_study = diary[
        diary.index.get_level_values('study_id') == study_id]
    for food_counts_time, food_counts_time_study in food_counts_study.iterrows():
        for diary_time, diary_time_study in diary_study.iterrows():
            # kendalltau pairs values by position, so this relies on both
            # tables having identical column order.
            statistic, _ = stats.kendalltau(food_counts_time_study,
                                            diary_time_study)
            # food_counts_time is (study_id, MS time); diary_time is
            # (study_id, diary time).
            statistics.append((*food_counts_time, diary_time[1], statistic))

In [None]:
# Tabulate the (study, MS time, diary time, tau) records.
statistics = pd.DataFrame.from_records(
    statistics,
    columns=['study_id', 'MS timepoint', 'Diary timepoint', "Kendall's tau"],
)

In [None]:
# Distribution of per-study Kendall's tau for each diary time point, with one
# panel per MS time point.
grid = sns.catplot(data=statistics, x='Diary timepoint', y="Kendall's tau",
                   col='MS timepoint', kind='box', height=6, aspect=1.5)
grid.savefig('ra_diet_diary.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Relative food composition of every sample as stacked bars.
width = 14
height = width / 1.618  # golden-ratio aspect
fig, ax = plt.subplots(figsize=(width, height))

# Normalize every sample (row) to fractions, then stack foods from most to
# least abundant summed over all samples.
row_totals = food_counts.sum(axis=1)
food_counts_norm = food_counts.divide(row_totals, axis=0)
order = (food_counts_norm.sum(axis=0)
         .sort_values(ascending=False)
         .index)
with sns.color_palette('tab20'):
    food_counts_norm[order].plot.bar(ax=ax, stacked=True)

tick_labels = ['P{}T{}'.format(study_id, time)
               for study_id, time in food_counts.index]
ax.set_xticklabels(tick_labels, rotation=90)
ax.yaxis.set_major_formatter(mticker.PercentFormatter(1))
ax.set(xlabel='Patient at timepoint', ylabel='Relative food count')
ax.legend(loc='center left', bbox_to_anchor=(1.05, 0.5), ncol=2,
          frameon=False)
sns.despine()

fig.savefig('ra_individual_food_count.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# Correlation between food categories across the samples at time point 14,
# after dropping food columns that are exactly zero in all of those samples.
food_counts_norm_intervention = food_counts_norm.xs(14, level='time')
is_observed = (food_counts_norm_intervention != 0).any(axis=0)
food_counts_norm_intervention = food_counts_norm_intervention.loc[:, is_observed]
grid = sns.clustermap(food_counts_norm_intervention.corr(), figsize=(10, 10),
                      vmin=-1, vmax=1)
grid.savefig('ra_food_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()