In [None]:
%load_ext autoreload
%autoreload 2

%pylab inline

# Analysis of User-defined Metabolite Sets

This notebook demonstrates how PALS can be used to analyse user-defined metabolite sets. Here PALS is used to find differentially expressed metabolite sets among Molecular Families from GNPS and among Mass2Motifs from MS2LDA.

In [None]:
import os
import sys
import pathlib
import pickle

sys.path.append('..')

In [None]:
import zipfile

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.display import display, HTML
from loguru import logger

In [None]:
from pals.feature_extraction import DataSource
from pals.loader import GNPSLoader
from pals.PLAGE import PLAGE
from pals.common import *

## GNPS Molecular Family Analysis

### Load GNPS results using the loader

The loader is used to retrieve molecular networking results from GNPS and extract the necessary data for PALS analysis. 

Input:
- URL to FBMN GNPS results
- A metadata CSV

Provide the link to your Feature-Based Molecular Networking (FBMN) GNPS job results below.

In [None]:
# GNPS FBMN job to pull results from, and the PALS database type to build from it
gnps_url = 'https://gnps.ucsd.edu/ProteoSAFe/status.jsp?task=0a8432b5891a48d7ad8459ba4a89969f'
database_name = DATABASE_GNPS_MOLECULAR_FAMILY

Provide a metadata CSV file describing the sample groups.

In [None]:
# Sample metadata from the bundled test data; the loader uses it to assign
# samples to the case/control groups.
metadata_file = os.path.join(
    'test_data',
    'AGP',
    'AG_Plants_extremes_metadata_df.csv',
)
metadata_df = pd.read_csv(metadata_file)
metadata_df

Define case and control groups, and create a comparison.

In [None]:
# Group labels as they appear in the metadata; `comp_name` also prefixes the
# p-value column produced by PLAGE.
case = 'More than 30'
control = 'Less than 10'
comp_name = 'more_plants/no_plants'
comparisons = [dict(case=case, control=control, name=comp_name)]
comparisons

Instantiate the loader object, and load the data into a `database` object.

In [None]:
# Build a PALS database of molecular families from the GNPS job at `gnps_url`,
# for the sample groups in `metadata_df` and the comparisons defined above.
# NOTE(review): this presumably downloads the job results from GNPS, so it needs network access.
loader = GNPSLoader(database_name, gnps_url, metadata_df, comparisons)
database = loader.load_data()

Extract the inputs for the data source. For the GNPS database, the *measurement_df*, *annotation_df* and *experimental_design* parameters are taken from *database* and then passed to the *DataSource* constructor.

In [None]:
# The GNPS loader returns the DataSource inputs inside `database.extra_data`.
measurement_df, annotation_df, experimental_design = (
    database.extra_data[key]
    for key in ('measurement_df', 'annotation_df', 'experimental_design')
)

In [None]:
# Preview the first rows of the measurement table
measurement_df.head(n=5)

In [None]:
# Preview the first rows of the annotation table
annotation_df.head(n=5)

### Create data source and run PLAGE analysis

Create a PALS data source and run PLAGE analysis.

In [None]:
# Wrap the GNPS tables in a PALS DataSource.
# NOTE(review): the fourth positional argument is passed as None; its meaning is
# defined by DataSource — confirm. `min_replace=SMALL` presumably sets the value
# substituted for missing/zero intensities — TODO confirm in pals.feature_extraction.
gnps_ds = DataSource(measurement_df, annotation_df, experimental_design, None, database=database, min_replace=SMALL)

In [None]:
# Run PLAGE over every molecular family. The sorting/filtering cells below read
# `pathway_df`, so the result must be bound to that name (the MS2LDA section
# below uses the same `get_pathway_df` call).
# NOTE(review): `num_resamples` suggests a random resampling step and no seed is
# set, so p-values may vary slightly between runs — confirm and seed if needed.
plage = PLAGE(gnps_ds, num_resamples=1000)
pathway_df = plage.get_pathway_df(standardize=True)

In [None]:
# Most significant sets first; ties broken by the larger `unq_pw_F` value
# (presumably the number of measured members, as it is compared against
# `min_members` below).
# Reassign rather than sort in place so re-running the cell is harmless.
p_value_col = '%s p-value' % comp_name
count_col = 'unq_pw_F'
pathway_df = pathway_df.sort_values([p_value_col, count_col], ascending=[True, False])
pathway_df

### Checking results

In [None]:
# Show long text cells (e.g. GNPS links) in full, and use seaborn's default theme.
pd.set_option('display.max_colwidth', None)
sns.set()

In [None]:
# Diverging colour map for the heatmaps, which are centred at 0.
# Other maps tried: 'RdBu_r', 'jet'.
cmap = 'vlag'

Filter significant molecular families by p-value. 

In [None]:
# Keep only molecular families whose comparison p-value is below the threshold.
pval_threshold = 0.05
is_significant = pathway_df[p_value_col] < pval_threshold
df = pathway_df.loc[is_significant]
df

Count how many significant molecular families have at least 10 members.

In [None]:
# Minimum set size considered for plotting; the shape shows how many significant sets qualify.
min_members = 10
df.loc[df[count_col].ge(min_members)].shape

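Plot heatmaps of the significant molecular families that have at least `min_members` members (only the families selected for the manuscript are drawn).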

In [None]:
# Flatten the experimental design into a sample order (grouped by group) and a
# parallel list holding each sample's group label.
all_samples = []
all_groups = []
for group, samples in experimental_design['groups'].items():
    all_samples += samples
    all_groups += [group] * len(samples)

In [None]:
# Per-member metadata lookup (indexed by row id when building member tables)
# and the intensity matrix the heatmaps are drawn from.
# NOTE(review): standardize_intensity_df presumably returns standardized
# intensities with row ids as the index and samples as columns — confirm.
entity_dict = gnps_ds.entity_dict
intensities_df = gnps_ds.standardize_intensity_df()

In [None]:
# Heatmaps and member tables for selected significant molecular families.
# Only families with at least `min_members` members that are listed in `to_plot`
# are drawn. Each figure is saved as '<name>.pdf' and its member table as
# '<name>.csv' in the working directory.
to_plot = ['Molecular Family #148']  # families shown in the manuscript
titles = {
    'Molecular Family #127': 'Cinnamic Acid-related Molecular Family',
    'Molecular Family #148': 'Steroid-related Molecular Family',
}

# Sample-group colour strip, built once. dict.fromkeys keeps first-seen order, so
# colours and legend order are stable across kernel restarts (set() order of
# strings changes with hash randomisation). Keys are the raw group values, the
# same values used for the colour lookup and the legend below.
used_groups = list(dict.fromkeys(all_groups))
group_pal = sns.husl_palette(len(used_groups), s=.45)
group_lut = dict(zip(used_groups, group_pal))

for idx, row in df.iterrows():
    members = gnps_ds.dataset_pathways_to_row_ids[idx]
    if len(members) < min_members:
        continue

    pw_name = row['pw_name']
    print(pw_name)
    if pw_name not in to_plot:
        continue
    print(row)

    # intensities of the members, with columns in the grouped sample order
    group_intensities = intensities_df.loc[members, all_samples]

    # one record per member for the exported table
    records = []
    for member in members:
        info = entity_dict[member]
        records.append([info['unique_id'], info['LibraryID'], info['mass'], info['RT'],
                        info['SumPeakIntensity'], info['number of spectra'],
                        info['GNPSLinkout_Network']])
    member_df = pd.DataFrame(records, columns=['id', 'LibraryID', 'Precursor m/z', 'RTConsensus',
                                               'PrecursorInt', 'no_spectra', 'link']).set_index('id')

    # colour strip above the heatmap marking each sample's group
    group_colours = pd.Series(all_groups, index=group_intensities.columns).map(group_lut)
    group_colours.name = 'groups'

    g = sns.clustermap(group_intensities, center=0, cmap=cmap, col_colors=group_colours,
                       col_cluster=False, linewidths=0.75, figsize=(10, 10),
                       cbar_pos=(1.0, 0.3, 0.05, 0.5))
    plt.suptitle(titles.get(pw_name, pw_name), fontsize=24, y=0.9)

    # zero-size bars register one legend entry per group
    for group in used_groups:
        g.ax_col_dendrogram.bar(0, 0, color=group_lut[group], label=group, linewidth=0)
    g.ax_col_dendrogram.legend(loc="right")

    # Highlight library-annotated members. Tick labels are strings, so compare
    # against string ids (a raw numeric index would never match).
    annotated_df = member_df[member_df['LibraryID'].notnull()]
    annotated_peaks = set(map(str, annotated_df.index))
    for label in g.ax_heatmap.get_yticklabels():
        if label.get_text() in annotated_peaks:
            label.set_weight("bold")
            label.set_color("green")
    plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0)

    # a '/' in the name would otherwise be treated as a directory separator
    file_stem = pw_name.replace('/', '_')
    plt.savefig('%s.pdf' % file_stem, dpi=300)
    plt.show()

    display(member_df)
    member_df.drop('link', axis=1).round(4).to_csv('%s.csv' % file_stem, index=True)

## GNPS-MS2LDA Analysis

As above, but here we also provide a link to a GNPS-MS2LDA analysis, which the loader uses to retrieve Mass2Motifs.

In [None]:
# GNPS-MS2LDA job whose Mass2Motifs are used as the metabolite sets below
gnps_ms2lda_url = 'https://gnps.ucsd.edu/ProteoSAFe/status.jsp?task=7c34badae00e43bc87b195a706cf1f43'

In [None]:
# Same FBMN job (`gnps_url`), metadata and comparisons as before, but the metabolite
# sets are now the Mass2Motifs retrieved from the GNPS-MS2LDA job.
database_name = DATABASE_GNPS_MS2LDA
loader = GNPSLoader(database_name, gnps_url, metadata_df, comparisons, gnps_ms2lda_url=gnps_ms2lda_url)
database = loader.load_data()

In [None]:
# Unpack the DataSource inputs for the Mass2Motif database and wrap them in a PALS DataSource.
measurement_df, annotation_df, experimental_design = (
    database.extra_data[key]
    for key in ('measurement_df', 'annotation_df', 'experimental_design')
)
gnps_ds = DataSource(measurement_df, annotation_df, experimental_design, None,
                     database=database, min_replace=SMALL)

In [None]:
# Run PLAGE over the Mass2Motif sets.
# NOTE(review): `num_resamples` suggests a random resampling step and no seed is
# set, so p-values may vary slightly between runs — confirm and seed if needed.
plage = PLAGE(gnps_ds, num_resamples=1000)
pathway_df = plage.get_pathway_df(standardize=True)

In [None]:
# Most significant Mass2Motifs first; ties broken by the larger `unq_pw_F` value
# (presumably the number of measured members).
# Reassign rather than sort in place so re-running the cell is harmless.
p_value_col = '%s p-value' % comp_name
count_col = 'unq_pw_F'
pathway_df = pathway_df.sort_values([p_value_col, count_col], ascending=[True, False])
pathway_df

In [None]:
# Same significance threshold as the molecular-family analysis (no duplicated magic number).
df = pathway_df[pathway_df[p_value_col] < pval_threshold]
df

In [None]:
# Member metadata and intensity matrix from the Mass2Motif data source (these
# replace the molecular-family versions).
# NOTE(review): standardize_intensity_df presumably returns standardized
# intensities with row ids as the index and samples as columns — confirm.
entity_dict = gnps_ds.entity_dict
intensities_df = gnps_ds.standardize_intensity_df()

In [None]:
# Heatmap and member table for the selected significant Mass2Motif.
# Only motifs with at least `min_members` members whose name contains
# `target_motif` are drawn. Outputs are saved as '<name>.pdf' and '<name>.csv'
# in the working directory.
target_motif = 'gnps_motif_54.m2m'
motif_title = 'Ferulic-acid-related Mass2Motif'

# Rebuild the sample order from *this* experimental design instead of reusing
# the lists computed for the molecular-family analysis (hidden state).
all_samples = []
all_groups = []
for group, samples in experimental_design['groups'].items():
    all_samples.extend(samples)
    all_groups.extend([group] * len(samples))

# Sample-group colour strip, built once, in a stable first-seen order
# (set() order of strings changes with hash randomisation).
used_groups = list(dict.fromkeys(all_groups))
group_pal = sns.husl_palette(len(used_groups), s=.45)
group_lut = dict(zip(used_groups, group_pal))

for idx, row in df.iterrows():
    members = gnps_ds.dataset_pathways_to_row_ids[idx]
    if len(members) < min_members:
        continue

    pw_name = row['pw_name']
    if target_motif not in pw_name:
        continue
    print(pw_name)

    # intensities of the members, with columns in the grouped sample order
    group_intensities = intensities_df.loc[members, all_samples]

    # one record per member for the exported table
    records = []
    for member in members:
        info = entity_dict[member]
        records.append([info['unique_id'], info['LibraryID'], info['mass'], info['RT'],
                        info['SumPeakIntensity'], info['number of spectra'],
                        info['GNPSLinkout_Network']])
    member_df = pd.DataFrame(records, columns=['id', 'LibraryID', 'Precursor m/z', 'RTConsensus',
                                               'PrecursorInt', 'no_spectra', 'link']).set_index('id')

    # colour strip above the heatmap marking each sample's group
    group_colours = pd.Series(all_groups, index=group_intensities.columns).map(group_lut)
    group_colours.name = 'groups'

    g = sns.clustermap(group_intensities, center=0, cmap=cmap, col_colors=group_colours,
                       col_cluster=False, linewidths=0.75, figsize=(10, 10),
                       cbar_pos=(1.0, 0.3, 0.05, 0.5))
    # every name reaching this point contains `target_motif`
    plt.suptitle(motif_title, fontsize=24, y=0.9)

    # zero-size bars register one legend entry per group
    for group in used_groups:
        g.ax_col_dendrogram.bar(0, 0, color=group_lut[group], label=group, linewidth=0)
    g.ax_col_dendrogram.legend(loc="right")

    # Highlight library-annotated members. Tick labels are strings, so compare
    # against string ids (a raw numeric index would never match).
    annotated_df = member_df[member_df['LibraryID'].notnull()]
    annotated_peaks = set(map(str, annotated_df.index))
    for label in g.ax_heatmap.get_yticklabels():
        if label.get_text() in annotated_peaks:
            label.set_weight("bold")
            label.set_color("green")
    plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0)

    # a '/' in the name would otherwise be treated as a directory separator
    file_stem = pw_name.replace('/', '_')
    plt.savefig('%s.pdf' % file_stem, dpi=300)
    plt.show()

    display(member_df)
    member_df.drop('link', axis=1).round(4).to_csv('%s.csv' % file_stem, index=True)