# Downselect Full-Sized C18 Atlas Based on Experimental Features

In [None]:
import pandas as pd
import numpy as np
import glob
import os

import matchms as mms
from matchms.similarity import CosineGreedy

import sys
sys.path.insert(0,'/global/homes/t/tharwood/repos/')
from metatlas.io import feature_tools as ft

from tqdm.notebook import tqdm

In [None]:
def normalize_spectra(row):
    """Trim a spectrum at the precursor and order its peaks by row 0.

    Parameters
    ----------
    row : mapping or pandas.Series
        Must provide ``'spectrum'`` (a 2-D array whose row 0 is compared with
        the precursor m/z and whose row 1 is carried along) and
        ``'precursor_mz'``.

    Returns
    -------
    numpy.ndarray
        Two-row array containing only the columns whose row-0 value is below
        ``precursor_mz + 0.20``, sorted by ascending row-0 value.
    """
    spectrum = row['spectrum']

    # Anything at or above precursor m/z + 0.20 is discarded.
    below_cutoff = spectrum[0] < (row['precursor_mz'] + 0.20)
    kept = spectrum[:, below_cutoff]

    ascending = np.argsort(kept[0])
    return np.array([kept[0][ascending], kept[1][ascending]])

def evaluate_score(score):
    """Return True when a match clears both MS/MS acceptance thresholds.

    ``score`` must support ``score['score']`` and ``score['matches']``. The
    two values are compared (inclusive) against the notebook-level settings
    ``msms_score`` and ``msms_matches``.
    """
    score_ok = score['score'] >= msms_score
    matches_ok = score['matches'] >= msms_matches
    return bool(score_ok & matches_ok)

def calculate_ms2_summary(df):
    """Build one spectrum per (label, rt) pair from in-feature MS2 points.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``in_feature``, ``label``, ``rt``, ``mz``, ``i``,
        ``precursor_MZ`` and ``precursor_intensity``.

    Returns
    -------
    pandas.DataFrame
        One row per label/rt group with the columns ``label``, ``spectrum``
        (array of ``[mz values, i values]``), ``rt``, ``precursor_mz`` (median
        of ``precursor_MZ``) and ``precursor_peak_height`` (median of
        ``precursor_intensity``).
    """
    labels = []
    spectrum_arrays = []
    rts = []
    precursor_mzs = []
    precursor_heights = []

    # Only points flagged as falling inside a feature contribute.
    in_feature_points = df[df['in_feature'] == True]

    for label, label_points in in_feature_points.groupby('label'):
        for rt, scan in label_points.groupby('rt'):
            mz_values = np.array(scan['mz'].tolist())
            intensities = np.array(scan['i'].tolist())

            labels.append(label)
            spectrum_arrays.append(np.array([mz_values, intensities]))
            rts.append(rt)
            precursor_mzs.append(scan['precursor_MZ'].median())
            precursor_heights.append(scan['precursor_intensity'].median())

    return pd.DataFrame({'label': labels,
                         'spectrum': spectrum_arrays,
                         'rt': rts,
                         'precursor_mz': precursor_mzs,
                         'precursor_peak_height': precursor_heights})

## Set Pre-Filter Parameters

In [None]:
# Raw LC-MS runs are globbed as <raw_data_dir>/<experiment>/*.h5.
raw_data_dir = '/global/cfs/cdirs/metatlas/raw_data/jgi'
experiment = '20231103_JGI_MW_507961_Char_final-dil_EXP120B_C18_USDAY81385'

# Tab-separated MS/MS reference library and the directory holding the
# full-sized C18 base atlases (C18_standards_<polarity>.tsv).
msms_refs_path = '/global/cfs/cdirs/metatlas/projects/spectral_libraries/20240222_labeled-addition_msms_refs.tab'
c18_base_atlas_dir = '/global/homes/t/tharwood/repos/metatlas-data/C18/'

# Polarity to process; selects the atlas files, the reference spectra and the
# POS/NEG filename token used to pick runs.
polarity = 'positive'

# Hit generation: both values are passed to ft.setup_file_slicing_parameters.
ppm_tolerance = 5
extra_time = 0.5

# Hit filtering.
# rt_window is the half-width placed around the aligned rt_peak to set
# rt_min / rt_max; peak_height and num_points are inclusive minimums on the
# MS1 summary columns 'peak_height' and 'num_datapoints'.
rt_window = 0.5
peak_height = 5e5
num_points = 4.0

# MS/MS filtering: msms_score and msms_matches are the inclusive thresholds
# used by evaluate_score; frag_tolerance is the CosineGreedy tolerance.
msms_filter = True
msms_score = 0.65
msms_matches = 4
frag_tolerance = 0.02

# Use a polynomial regression for RT alignment; otherwise use the median offset.
rt_regression = False

# Degree of the RT alignment polynomial (only used when rt_regression is True).
model_degree = 1

In [None]:
# Filename tokens for the polarity being processed and for the opposite one.
file_polarity, filter_polarity = ('POS', 'NEG') if polarity == 'positive' else ('NEG', 'POS')

all_files = glob.glob(os.path.join(raw_data_dir, experiment, '*.h5'))

# Sample runs: the 10th underscore-separated filename field matches the
# requested polarity and the path does not contain 'QC'.
files_subset = [
    path for path in all_files
    if os.path.basename(path).split('_')[9] == file_polarity and 'QC' not in path
]

# QC runs: any run whose polarity field is not the opposite polarity and whose
# path contains 'QC_'.
qc_files = [
    path for path in all_files
    if os.path.basename(path).split('_')[9] != filter_polarity and 'QC_' in path
]

In [None]:
# Sanity check: number of .h5 runs found for the experiment.
len(all_files)

## Generate RT Adjusted Atlas

In [None]:
# Small atlas used only to estimate the RT shift from the QC runs.
c18_adjustment_atlas = pd.read_csv(
    f'/global/homes/t/tharwood/c18_metatlas_pre-filter/rt-adjustment_atlases/C18_rt-adjustment_{polarity}.tsv',
    sep='\t',
)

# Full-sized atlas that is aligned and down-selected further below.
c18_base_atlas = pd.read_csv(
    os.path.join(c18_base_atlas_dir, f'C18_standards_{polarity}.tsv'),
    sep='\t',
)

In [None]:
# Display the RT-adjustment atlas for inspection.
c18_adjustment_atlas

In [None]:
%%time
# Extract the RT-adjustment atlas compounds from every QC run.
experiment_input = ft.setup_file_slicing_parameters(
    c18_adjustment_atlas,
    qc_files,
    base_dir=os.getcwd(),
    ppm_tolerance=ppm_tolerance,
    extra_time=extra_time,
    polarity=polarity,
)

ms1_data = []
for file_input in experiment_input:
    data = ft.get_data(file_input, save_file=False, return_data=True)

    # Tag each summary with the run it was observed in before stacking.
    ms1_summary = data['ms1_summary']
    ms1_summary['lcmsrun_observed'] = file_input['lcmsrun']
    ms1_data.append(ms1_summary)

ms1_data = pd.concat(ms1_data)

In [None]:
# Median observed RT per atlas label across the QC runs, ignoring peaks with
# a height below 1e4.
median_data = (
    ms1_data[ms1_data['peak_height'] >= 1e4]
    .groupby('label')['rt_peak']
    .median()
)

In [None]:
# Pair atlas RTs (rt_peak_x) with the observed QC medians (rt_peak_y) by label.
rt_merged = c18_adjustment_atlas[['label', 'rt_peak']].merge(median_data, on='label')

In [None]:
# x: atlas RT, y: median RT observed in the QC runs.
x = rt_merged['rt_peak_x']
y = rt_merged['rt_peak_y']

if not rt_regression:
    # Constant shift: the median difference between observed and atlas RT.
    median_offset = (y - x).median()
    rt_merged['rt_peak_predicted'] = rt_merged['rt_peak_x'] + median_offset
else:
    # Polynomial of degree model_degree mapping atlas RT to observed RT.
    rt_alignment_model = np.polyfit(x, y, model_degree)
    rt_merged['rt_peak_predicted'] = rt_merged['rt_peak_x'].apply(
        lambda atlas_rt: np.polyval(rt_alignment_model, atlas_rt)
    )

In [None]:
# Work on a copy of the base atlas and keep the original RT for reference.
c18_aligned_atlas = c18_base_atlas.copy()
c18_aligned_atlas['rt_peak_unaligned'] = c18_aligned_atlas['rt_peak']

# Apply whichever alignment was fitted on the QC runs.
if not rt_regression:
    c18_aligned_atlas['rt_peak'] = c18_aligned_atlas['rt_peak_unaligned'] + median_offset
else:
    c18_aligned_atlas['rt_peak'] = c18_aligned_atlas['rt_peak_unaligned'].apply(
        lambda atlas_rt: np.polyval(rt_alignment_model, atlas_rt)
    )

# The RT bounds are a symmetric window around the aligned peak.
c18_aligned_atlas['rt_max'] = c18_aligned_atlas['rt_peak'] + rt_window
c18_aligned_atlas['rt_min'] = c18_aligned_atlas['rt_peak'] - rt_window

## Collect MS1 and MS2 Data

In [None]:
%%time
# Extract the aligned atlas compounds from every sample run.
experiment_input = ft.setup_file_slicing_parameters(
    c18_aligned_atlas,
    files_subset,
    base_dir=os.getcwd(),
    ppm_tolerance=ppm_tolerance,
    extra_time=extra_time,
    polarity=polarity,
)

ms1_data, ms2_data = [], []

for file_input in tqdm(experiment_input):
    data = ft.get_data(file_input, save_file=False, return_data=True)

    # Tag both summaries with the run they came from before stacking.
    data['ms1_summary']['lcmsrun_observed'] = file_input['lcmsrun']
    ms1_data.append(data['ms1_summary'])

    ms2_summary = calculate_ms2_summary(data['ms2_data'])
    ms2_summary['lcmsrun_observed'] = file_input['lcmsrun']
    ms2_data.append(ms2_summary)

ms1_data = pd.concat(ms1_data)
ms2_data = pd.concat(ms2_data)

In [None]:
# Trim each experimental spectrum at its precursor and sort it by m/z
# (row-wise, because normalize_spectra needs both 'spectrum' and 'precursor_mz').
ms2_data['spectrum'] = ms2_data.apply(normalize_spectra, axis=1)

In [None]:
# ms1_data['lcmsrun_basename'] = ms1_data['lcmsrun_observed'].apply(lambda x: os.path.basename(x))
# ms2_data['lcmsrun_basename'] = ms2_data['lcmsrun_observed'].apply(lambda x: os.path.basename(x))

In [None]:
# ms1_data[ms1_data['lcmsrun_basename'].str.contains('ExCtrl')]

In [None]:
# ms2_data[ms2_data['lcmsrun_basename'].str.contains('ExCtrl')]

## Get MSMS Hits

In [None]:
# Load the tab-separated MS/MS reference library.
msms_refs_df = pd.read_csv(msms_refs_path, sep='\t')

In [None]:
# InChI keys with a 'metatlas' reference spectrum, and InChI keys in the base
# atlas. NOTE(review): neither set is used in the cells visible here.
refs_compounds = set(
    msms_refs_df.loc[msms_refs_df['database'] == 'metatlas', 'inchi_key'].tolist()
)
atlas_compounds = set(c18_base_atlas['inchi_key'].tolist())

In [None]:
# Attach the atlas InChI key to every experimental spectrum via its label.
ms2_data_annotated = ms2_data.merge(c18_aligned_atlas[['label', 'inchi_key']], on='label')

In [None]:
# Keep only 'metatlas' reference spectra of the current polarity for compounds
# that actually have an experimental MS2 spectrum.
msms_refs_filtered = msms_refs_df[(msms_refs_df['inchi_key'].isin(ms2_data_annotated['inchi_key'].tolist())) & 
                                  (msms_refs_df['database'] == 'metatlas') & 
                                  (msms_refs_df['polarity'] == polarity)].copy()

# The 'spectrum' column is stored as text and parsed into an array here.
# WARNING(review): eval() executes arbitrary code from the reference file; only
# run this on a trusted library. If the strings are plain nested lists,
# ast.literal_eval would be a safer parser -- confirm the format first.
msms_refs_filtered['spectrum'] = msms_refs_filtered['spectrum'].apply(lambda x: np.asarray(eval(x)))
# Apply the same precursor trimming / m/z ordering used for the experimental spectra.
msms_refs_filtered['spectrum'] = msms_refs_filtered.apply(normalize_spectra, axis=1)

In [None]:
# Pair every reference spectrum with every experimental spectrum of the same
# compound. After the merge 'spectrum_x' is the reference spectrum and
# 'spectrum_y' the experimental one.
ms2_data_annotated = msms_refs_filtered[['id', 'inchi_key', 'spectrum']].merge(
    ms2_data_annotated, on='inchi_key'
)

In [None]:
# Wrap both spectra of each pair as matchms Spectrum objects. Both use the
# experimental precursor m/z from the 'precursor_mz' column as metadata.
ms2_data_annotated['mms_spectrum_x'] = ms2_data_annotated.apply(
    lambda pair: mms.Spectrum(pair.spectrum_x[0],
                              pair.spectrum_x[1],
                              metadata={'precursor_mz':pair.precursor_mz}),
    axis=1,
)
ms2_data_annotated['mms_spectrum_y'] = ms2_data_annotated.apply(
    lambda pair: mms.Spectrum(pair.spectrum_y[0],
                              pair.spectrum_y[1],
                              metadata={'precursor_mz':pair.precursor_mz}),
    axis=1,
)

In [None]:
# Cosine similarity scorer; fragment peaks are matched within frag_tolerance.
cosine_greedy = CosineGreedy(tolerance=frag_tolerance)

In [None]:
# Score each reference (x) / experimental (y) pair. evaluate_score later reads
# the 'score' and 'matches' fields of each result.
ms2_data_annotated['score'] = ms2_data_annotated.apply(lambda x: cosine_greedy.pair(x.mms_spectrum_x, x.mms_spectrum_y), axis=1)

In [None]:
# Flag pairs that reach both the msms_score and msms_matches thresholds.
ms2_data_annotated['keep'] = ms2_data_annotated['score'].apply(evaluate_score)

## Filter Collected Data & Generate Reduced Atlas

In [None]:
# Keep MS1 hits that are intense enough and have enough data points, then
# collect the atlas labels that have at least one such hit.
ms1_data_filtered = ms1_data[
    (ms1_data['peak_height'] >= peak_height)
    & (ms1_data['num_datapoints'] >= num_points)
]
ms1_reduced_labels = set(ms1_data_filtered['label'].tolist())

In [None]:
# Labels that pass the MS/MS criterion.
if msms_filter:
    # Only spectrum pairs flagged keep=True count. The label set must be built
    # from the filtered frame; building it from ms2_data_annotated would let
    # every label with any reference-paired spectrum through, regardless of
    # its score or number of matching fragments.
    ms2_data_filtered = ms2_data_annotated[ms2_data_annotated['keep']]
    ms2_reduced_labels = set(ms2_data_filtered['label'].tolist())

else:
    # No MS/MS requirement: reuse the MS1 label set so the later intersection
    # leaves the MS1 result unchanged.
    ms2_reduced_labels = ms1_reduced_labels

In [None]:
# A label is retained only if it passes both the MS1 and the MS/MS criteria.
reduced_labels = ms1_reduced_labels & ms2_reduced_labels

In [None]:
# Down-select the aligned atlas to the retained labels. The explicit copy
# makes the result an independent frame, so the in-place 'label' edits that
# follow do not act on a slice of c18_aligned_atlas (which would raise
# pandas' SettingWithCopyWarning).
c18_reduced_atlas = c18_aligned_atlas[c18_aligned_atlas['label'].isin(reduced_labels)].copy()

In [None]:
# Number of rows and columns left after down-selection.
c18_reduced_atlas.shape

In [None]:
# Display the reduced atlas for inspection.
c18_reduced_atlas

In [None]:
# Keep only the text between the first ': ' and the following ' [' of each label.
c18_reduced_atlas['label'] = c18_reduced_atlas['label'].map(
    lambda full_label: full_label.split(': ')[1].split(' [')[0]
)

## Sanitize & Save Reduced Atlas

In [None]:
# Drop non-ASCII characters from the labels and lower-case them.
c18_reduced_atlas['label'] = c18_reduced_atlas['label'].map(
    lambda name: name.encode('ascii','ignore').decode("ascii").lower()
)

In [None]:
# Columns written to the reduced atlas file, in output order.
atlas_cols = ['label', 'adduct', 'polarity', 'mz', 'rt_peak', 'rt_min', 'rt_max', 'inchi_key']

# The output goes into a directory named after the experiment, relative to the
# current working directory. exist_ok=True replaces the check-then-create
# pattern, which can fail if the directory appears between the two calls.
os.makedirs(experiment, exist_ok=True)

c18_reduced_atlas[atlas_cols].to_csv(
    os.path.join(experiment, f'c18_{polarity}_reduced_atlas.csv'),
    index=False,
)