In [None]:
# import libraries

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import csv
import pickle
import re
import itertools
import time
from collections import Counter

import numpy as np
from scipy.special import softmax
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import GridSearchCV

import tensorflow as tf
print("tf.__version__ =", tf.__version__)
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras import regularizers
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

from spektral.layers import GCNConv, GlobalSumPool
from spektral.data import Dataset
from spektral.data import Graph
from spektral.data import BatchLoader

import matplotlib
from matplotlib import pyplot
%matplotlib inline
from IPython import display
display.set_matplotlib_formats('svg')

import glypy
from glypy.io.nomenclature import synonyms, identity
from glypy.plot import plot as glypy_plot
from glypy.io import glycoct as glypy_glycoct



In [None]:
#########################################################################
# I/O FUNCTIONS and DATA PRE-PROCESSING
#########################################################################

In [None]:
# Prepare input data

# input data folder includes the following main files:
#   "fraction_[1-5].mgf": MS2 spectra, exported from PEAKS GlycanFinder
#   "fraction_[1-5].labeled.csv": glycoPSM from PEAKS GlycanFinder database search
#   "glycans.txt": glycan database

# data_folder = "data.training/glycan/mouse_brain/"
# num_fractions = 5
# mode = 'training'
data_folder = "data.training/glycan/Demo_IgG_Orbitrap/"
num_fractions = 3
mode = 'evaluation'
print("data_folder =", data_folder)
print("num_fractions =", num_fractions)
print("mode =", mode)
print()

fraction_id_list = list(range(1, 1 + num_fractions))


def merge_mgf_file(input_file_list, fraction_list, output_file):
    """Merge multiple mgf files into one, adding fraction ID to scan ID.

    Every line containing "SCANS=<scan>" is rewritten as "SCANS=F<fraction>:<scan>"
    so scan IDs stay unique across fractions; all other lines are copied verbatim.

    The merge is written to "<output_file>.tmp" first and only renamed to
    `output_file` once every input has been copied. If anything fails part-way
    (e.g. a missing fraction file), the temp file is removed and the error is
    re-raised. This prevents a truncated `output_file` from being silently reused
    by the existence check on the next run.

        Usage:
            folder_path = "data.training/aa.hla.bassani.nature_2016.mel_16.class_1/"
            fraction_list = range(0, 10+1)
            merge_mgf_file(
                input_file_list=[folder_path + "export_" + str(i) + ".mgf" for i in fraction_list],
                fraction_list=fraction_list,
                output_file=folder_path + "spectrum.mgf")
    """

    print("merge_mgf_file()")

    tmp_file = output_file + ".tmp"
    counter = 0
    try:
        with open(tmp_file, mode="w") as output_handle:
            for input_file, fraction in zip(input_file_list, fraction_list):
                print("input_file = ", input_file)
                with open(input_file, mode="r") as input_handle:
                    for line in input_handle:
                        if "SCANS=" in line:  # a spectrum found
                            counter += 1
                            scan = re.split('=|\n|\r', line)[1]
                            # re-number scan id
                            output_handle.write("SCANS=F{0}:{1}\n".format(fraction, scan))
                        else:
                            output_handle.write(line)
    except BaseException:
        # never leave a partial merge behind; the caller sees the original error
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise
    # atomic rename: output_file either does not exist or is complete
    os.replace(tmp_file, output_file)
    print("output_file = {0:s}".format(output_file))
    print("counter = {0:d}".format(counter))
    print()


# read mgf files
mgf_files = [data_folder + 'fraction_' + str(x) + '.mgf' for x in fraction_id_list]
input_spectrum_file = data_folder + "spectrum.mgf"
print("Prepare input_spectrum_file =", input_spectrum_file)
if os.path.exists(input_spectrum_file):
    print("input_spectrum_file exists!")
    print()
else:
    merge_mgf_file(mgf_files, fraction_id_list, input_spectrum_file)
print()

# store spectrum_location_dict to quickly retrieve spectrum from its scan id (copied from deepnovo_worker_io:get_location)
#   spectrum_location_dict: scan id -> file offset of that spectrum's "BEGIN IONS" line
#   spectrum_rtinseconds_dict: scan id -> RTINSECONDS value of that spectrum
# The index is cached next to the mgf file as a pickle written by this notebook;
# only load cache files you created yourself (pickle can execute arbitrary code).
spectrum_location_file = input_spectrum_file + '.locations.pkl'
if os.path.exists(spectrum_location_file):
    with open(spectrum_location_file, 'rb') as fr:
        print("WorkerIO: read cached spectrum locations")
        data = pickle.load(fr)
        spectrum_location_dict, spectrum_rtinseconds_dict, spectrum_count = data
else:
    print("WorkerIO: build spectrum location from scratch")
    spectrum_location_dict = {}
    spectrum_rtinseconds_dict = {}
    # `with` guarantees the handle is closed even if a header line fails to parse
    with open(input_spectrum_file, 'r') as input_spectrum_handle:
        spectrum_location = None  # offset of the current spectrum's BEGIN IONS line
        scan = None               # scan id of the current spectrum, once seen
        rtinseconds = None        # retention time of the current spectrum, once seen
        while True:
            # take the offset BEFORE reading so it points at the start of the line
            current_location = input_spectrum_handle.tell()
            line = input_spectrum_handle.readline()
            if not line:  # end of file
                break
            if "BEGIN IONS" in line:
                # start of a new spectrum: forget the previous spectrum's headers so a
                # missing/out-of-order header is never credited to the wrong scan
                spectrum_location = current_location
                scan = None
                rtinseconds = None
            elif "SCANS=" in line:
                scan = re.split('=|\r|\n', line)[1]
                spectrum_location_dict[scan] = spectrum_location
                if rtinseconds is not None:  # RTINSECONDS appeared before SCANS
                    spectrum_rtinseconds_dict[scan] = rtinseconds
            elif "RTINSECONDS=" in line:
                rtinseconds = float(re.split('=|\r|\n', line)[1])
                if scan is not None:
                    spectrum_rtinseconds_dict[scan] = rtinseconds
    spectrum_count = len(spectrum_location_dict)
    with open(spectrum_location_file, 'wb') as fw:
        pickle.dump((spectrum_location_dict, spectrum_rtinseconds_dict, spectrum_count), fw)
print("len(spectrum_location_dict) =", len(spectrum_location_dict))
print()
# function to retrieve a spectrum from its scan id
def get_spectrum(input_spectrum_handle, spectrum_location_dict, scan):
    """Read the fragment ions of one spectrum from an open mgf file.

    Args:
        input_spectrum_handle: text-mode handle of the mgf file, opened for reading.
        spectrum_location_dict: maps scan id -> file offset of that spectrum's
            "BEGIN IONS" line.
        scan: scan id to look up (e.g. "F1:1234" in the merged spectrum file).

    Returns:
        (mz_list, intensity_list): two lists of floats, one entry per peak line.
        The first two whitespace-separated fields of a peak line are used; any
        further fields are ignored, and blank lines are skipped.

    Raises:
        KeyError: if `scan` is not in `spectrum_location_dict`.
        ValueError: if the spectrum has no "RTINSECONDS=" header, or the file ends
            before "END IONS". (Previously a missing RTINSECONDS made the header loop
            read on into the NEXT spectrum's peaks, and at end of file it looped forever.)
    """
    handle = input_spectrum_handle
    handle.seek(spectrum_location_dict[scan])

    # parse header lines; the peak list starts right after the RTINSECONDS= line
    line = handle.readline()
    assert "BEGIN IONS" in line, "Error: wrong input BEGIN IONS"
    while "RTINSECONDS=" not in line:
        line = handle.readline()
        if not line or "END IONS" in line:
            raise ValueError("Error: no RTINSECONDS= header for scan {}".format(scan))

    # parse fragment ions
    mz_list = []
    intensity_list = []
    line = handle.readline()
    while "END IONS" not in line:
        if not line:
            raise ValueError("Error: missing END IONS for scan {}".format(scan))
        fields = line.split()  # any whitespace: spaces, tabs, \r\n
        if fields:
            mz_list.append(float(fields[0]))
            intensity_list.append(float(fields[1]))
        line = handle.readline()

    return mz_list, intensity_list


# read csv files
# 'prediction' reads unlabeled PSM exports; 'evaluation'/'training' read labeled exports
# and keep only PSMs whose 'Glycan Score' is at least glycan_score_cutoff
if mode == 'prediction':
    csv_files = [data_folder + 'fraction_' + str(i) + '.csv' for i in fraction_id_list]
elif mode == 'evaluation' or mode == 'training':
    csv_files = [data_folder + 'fraction_' + str(i) + '.labeled.csv' for i in fraction_id_list]
    glycan_score_cutoff = 1.
    print("glycan_score_cutoff =", glycan_score_cutoff)
else:
    # fail here with a clear message instead of a NameError on csv_files below
    raise ValueError("Unknown mode {!r}; expected 'prediction', 'evaluation' or 'training'".format(mode))
print("Prepare csv_files =", csv_files)
glycan_psm = {x: [] for x in fraction_id_list}
for fraction_id, csv_path in zip(fraction_id_list, csv_files):
    # newline='' lets the csv module handle quoted fields containing line breaks
    with open(csv_path, 'r', newline='') as csv_handle:
        csvreader = csv.DictReader(csv_handle)
        for row in csvreader:
            if mode == 'evaluation' or mode == 'training':
                glycan_score = float(row['Glycan Score'])
                if glycan_score < glycan_score_cutoff:
                    continue
            row['fraction_id'] = fraction_id
            glycan_psm[fraction_id].append(row)
total_psm = 0
for fraction_id, psm_list in glycan_psm.items():
    print("fraction_id =", fraction_id, ",", "len(psm_list) =", len(psm_list))
    total_psm += len(psm_list)
print("total_psm =", total_psm)
print()

# for evaluation or training, read glycan database
# glycans.txt holds records terminated by 'GLYCAN END'. Inside a record, lines that
# contain 'GLYCAN' and '=' are metadata (e.g. GLYCANID=...), and all other lines form
# the GlycoCT text parsed by glypy.
if mode == 'evaluation' or mode == 'training':
    glycan_db_file = data_folder + 'glycans.txt'
    print("Prepare glycan_db_file =", glycan_db_file)
    glycan_dict = {}
    with open(glycan_db_file, 'r') as handle:
        text = handle.read().strip()
        # anything after the last 'GLYCAN END' marker is not a complete record
        text = text.split('GLYCAN END')[:-1]
        text = [x.strip() for x in text]
        for block in text:
            lines = block.split('\n')
            glycan = {}
            res_lin = []
            for line in lines:
                line = line.strip()
                if 'GLYCAN' not in line:
                    res_lin.append(line)
                elif '=' in line:
                    # split on the first '=' only, so values may themselves contain '='
                    k, v = line.split('=', 1)
                    glycan[k] = v
            res_lin = '\n'.join(res_lin)
            glycan['GLYCAN'] = glypy_glycoct.loads(res_lin)
            glycan_dict[glycan['GLYCANID']] = glycan
    print("len(glycan_dict) =", len(glycan_dict))
    print()


In [None]:
# Define the sugar classes for the classification task

# they can be loaded from the file "sugar_classes.pkl"
sugar_classes_file = "sugar_classes.pkl"
if os.path.exists(sugar_classes_file):
    with open(sugar_classes_file, 'rb') as fr:
        print("Load sugar classes  from", sugar_classes_file)
        sugar_classes = pickle.load(fr)
# or they can be derived from the glycan_psm of the training data
elif mode == 'training':
    print("Collect sugar classes from glycan_psm")
    sugar_name_set = set()
    for psm_list in glycan_psm.values():
        for psm in psm_list:
            # read sugar names
            # many glycoct have the same shape and mass, but different colors
            # we can use sugar names (like super-classes) to reduce the number of classes
            # e.g. "Hex(5)HexNAc(4)" -> ['Hex', 'HexNAc'] (counts are dropped)
            sugar_name_list = re.split(r'\(|\)', psm['Glycan'])
            sugar_name_list = [x for x in sugar_name_list if x and not x.isdigit()]
            sugar_name_set.update(sugar_name_list)
    # make sure that all sugar names exist in glypy
    #sugar_name_set.remove('HexA')
    #sugar_name_set.add('GlcA')
    #sugar_name_set.add('KDN')
    #sugar_name_set.add('Xyl')
    unknown_sugars = sorted(x for x in sugar_name_set if x not in glypy.monosaccharides)
    if unknown_sugars:
        raise ValueError("Sugar names not found in glypy.monosaccharides: {}".format(unknown_sugars))
    sugar_classes = sorted(sugar_name_set)
    with open(sugar_classes_file, 'wb') as fw:
        pickle.dump(sugar_classes, fw)
else:
    # sugar_classes.pkl is only written in 'training' mode; without it the class list
    # the model was trained with is unknown
    raise FileNotFoundError(
        "{} not found; it is created in 'training' mode (current mode = {!r})".format(sugar_classes_file, mode))
num_sugars = len(sugar_classes)
print("sugar_classes = ", sugar_classes)
print("num_sugars = ", num_sugars)
for name in sugar_classes:
    print(name, glypy.monosaccharides[name].mass(), sep='\t')
print()

# function to retrieve the class of a sugar node in a glycan tree based on its mass
def get_class_name(node, sugar_mass_tolerance=0.05):
    """Return the first entry of `sugar_classes` whose glypy mass matches `node`.

    Args:
        node: glypy monosaccharide node (anything exposing a `.mass()` method).
        sugar_mass_tolerance: maximum absolute difference between `node.mass()` and
            a class's `glypy.monosaccharides[name].mass()` for a match.

    Returns:
        The matching name from the module-level `sugar_classes`; classes are tried in
        list order and the first match wins.

    Raises:
        ValueError: if no sugar class matches. (Previously the function printed the
            node and crashed on purpose via `print(stop)`, a NameError.)
    """
    node_mass = node.mass()
    for name in sugar_classes:
        if abs(node_mass - glypy.monosaccharides[name].mass()) < sugar_mass_tolerance:
            return name
    raise ValueError("Unknown sugar: {} (mass = {})".format(node, node_mass))


In [None]:
# Clone n_link_core & n_link_core_fuc of glycan trees

# n_link_core: a private copy of glypy's built-in 'N-Linked Core' glycan,
# plotted with node labels and its node count printed for inspection
n_link_core = glypy.glycans['N-Linked Core'].clone()
glypy_plot(n_link_core, label=True)
print("len(n_link_core.index) =", len(n_link_core.index))
print()

# n_link_core_fuc: a copy of the core with one Fuc attached to the root node
# (parent position 6, child position 1), then re-indexed breadth-first
sugar_fuc = glypy.monosaccharides['Fuc'].clone()
n_link_core_fuc = n_link_core.clone()
n_link_core_fuc.root.add_monosaccharide(sugar_fuc, position=6, child_position=1)
n_link_core_fuc.reindex(method='bfs')
glypy_plot(n_link_core_fuc, label=True)
print("len(n_link_core_fuc.index) =", len(n_link_core_fuc.index))
# print every node of the tree's index (not only leaves, despite the loop name)
for leaf in n_link_core_fuc.index:
    print(leaf)
print()


In [None]:
# Essential functions: compute B & Y ion masses, glycoPSM scores, and glycan-level accuracy
# NOTE: scoring currently uses only Y ions (see compute_glycopsm_score)

def get_b_y_set(glycan, resolution):
    """Return integer-scaled B-ion and Y-ion masses of a glycan's fragments.

    The fragments yielded by `glycan.clone().fragments()` are grouped by consecutive
    equal `link_ids.keys()`; each group must contain exactly two fragments, taken as
    (Y, B) in that order. Masses are multiplied by `resolution` and rounded to ints so
    they can be compared as set members. Y-ion masses have the free reducing end mass
    (18.0105546) subtracted first.

    Returns:
        (b_mass_set, y_mass_set): two sets of ints.
    """
    free_reducing_end_mass = 18.0105546
    b_masses = set()
    y_masses = set()
    fragment_groups = itertools.groupby(glycan.clone().fragments(),
                                        key=lambda fragment: fragment.link_ids.keys())
    for _link_keys, pair in fragment_groups:
        y_fragment, b_fragment = list(pair)
        b_masses.add(int(round(b_fragment.mass * resolution)))
        y_masses.add(int(round((y_fragment.mass - free_reducing_end_mass) * resolution)))
    return b_masses, y_masses


def compute_glycopsm_score(glycan, peptide_only_mass, mz1_list, intensity_list):
    """Score how well a spectrum supports each theoretical Y ion of a glycopeptide.

    Args:
        glycan: glycan tree; its Y ions come from `get_b_y_set(glycan, 1e3)`.
        peptide_only_mass: peptide mass without the glycan; added to every glycan
            Y-ion mass to obtain the glycopeptide Y-ion mass.
        mz1_list: fragment m/z values, all treated as charge 1.
        intensity_list: fragment intensities, same length as `mz1_list`.

    Returns:
        (glycan_y_list, glycopsm): sorted integer glycan Y-ion masses (scaled by 1e3)
        and a float array with one log-score per Y ion, in the same order.
        If the spectrum is empty or has no positive intensity, every ion gets the
        floor score log(0.0001 / 0.005) instead of raising.
    """
    resolution = 1e3
    _, glycan_y_set = get_b_y_set(glycan, resolution)
    glycan_y_list = sorted(glycan_y_set)
    # only use y-ions: theoretical glycopeptide Y-ion masses
    glycopeptide_y_array = np.array(glycan_y_list, dtype=float) / resolution + peptide_only_mass

    intensity_max = max(intensity_list) if intensity_list else 0.0
    if intensity_max <= 0:
        # no signal: the summed evidence is 0 for every ion, giving the floor score
        return glycan_y_list, np.log((np.zeros(len(glycan_y_list)) + 0.0001) / 0.005)

    # neutral masses of singly charged fragments, and intensities normalized to max 1
    # NOTE(review): 1.0078 is the hydrogen-atom mass; a proton is 1.00728. The
    # difference is small next to the 1/C_const mass scale — confirm it is intended.
    charge = 1.0
    mass_H = 1.0078
    mz0_array = np.asarray(mz1_list, dtype=float) - charge * mass_H
    intensity_array = np.asarray(intensity_list, dtype=float) / intensity_max

    # calculate glycopsm as following:
    # intensity is weighted by mass error sigma (similar to Rui's PointNovo paper)
    # softmax along the mz dimension is used to select the fragment ion closest to the theoretical ion
    # add penalty if best peak has intensity < 0.5%
    # arrays below have shape (num_peaks, num_y_ions)
    C_const = 10.0
    delta = np.abs(mz0_array[:, None] - glycopeptide_y_array[None, :])
    delta_C = -delta * C_const
    sigma = np.exp(delta_C)
    sigma_softmax = softmax(delta_C, axis=0)
    glycopsm = np.sum(sigma_softmax * sigma * intensity_array[:, None], axis=0)
    glycopsm = np.log((glycopsm + 0.0001) / 0.005)

    return glycan_y_list, glycopsm


def test_glycan_accuracy(target_glycans, predict_glycans, top=1):
    """Compare predicted glycans with target glycans through their Y-ion mass sets.

    For each (target, predict) pair, the candidates in `predict[:top]` (each a
    (candidate_glycan, score) tuple; a falsy candidate means "no prediction") are
    compared with the target. The best candidate is the one sharing the most Y ions
    with the target. On a tie, an exact Y-ion-set match is preferred. Previously the
    first tied candidate always won, so in top-k mode an exact match ranked after a
    superset prediction was never counted as correct. A candidate sharing no Y ion is
    never selected.

    Per-pair results are written to 'test_glycan_accuracy.csv', and summary metrics
    are printed.

    Returns:
        dict of the printed summary metrics. A ratio whose denominator is zero is
        reported as nan instead of raising ZeroDivisionError.
    """
    print("test_glycan_accuracy()")

    def _ratio(numerator, denominator):
        # nan rather than ZeroDivisionError, e.g. when nothing was predicted at all
        return numerator / denominator if denominator else float('nan')

    resolution = 1e3
    num_targets = float(len(target_glycans))
    num_predicts = float(len([x for x in predict_glycans if x]))
    num_target_y = 0.
    num_predict_y = 0.
    num_correct_y = 0.
    num_correct_glycans = 0.
    with open('test_glycan_accuracy.csv', 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile, delimiter=',')
        csvwriter.writerow(['best_predict_y', 'best_correct_y', 'best_score'])
        for target, predict in zip(target_glycans, predict_glycans):
            _, target_y_set = get_b_y_set(target, resolution)
            num_target_y += len(target_y_set)
            best_predict_y = 0.
            best_correct_y = 0.
            best_correct_glycan = 0.
            best_score = None
            best_key = (0, 0)
            for candidate, score in predict[:top]:
                _, predict_y_set = get_b_y_set(candidate, resolution) if candidate else (set(), set())
                num_common_y = len(target_y_set & predict_y_set)
                correct_glycan = 1 if target_y_set == predict_y_set else 0
                # rank by shared Y ions first, then prefer an exact match
                key = (num_common_y, correct_glycan)
                if num_common_y > 0 and key > best_key:
                    best_key = key
                    best_predict_y = len(predict_y_set)
                    best_correct_y = num_common_y
                    best_correct_glycan = correct_glycan
                    best_score = score

            csvwriter.writerow([best_predict_y, best_correct_y, best_score])
            num_predict_y += best_predict_y
            num_correct_y += best_correct_y
            num_correct_glycans += best_correct_glycan

    sensitivity_y = _ratio(num_correct_y, num_target_y)
    sensitivity_glycan = _ratio(num_correct_glycans, num_targets)
    precision_y = _ratio(num_correct_y, num_predict_y)

    print("num_targets = ", num_targets)
    print("num_predicts = ", num_predicts)
    print("num_correct_glycans = ", num_correct_glycans)
    print("sensitivity_glycan = {:.2f}".format(sensitivity_glycan))
    print("num_target_y = ", num_target_y)
    print("num_predict_y = ", num_predict_y)
    print("num_correct_y = ", num_correct_y)
    print("sensitivity_y = {:.2f}".format(sensitivity_y))
    print("precision_y = {:.2f}".format(precision_y))
    print()

    return {
        'num_targets': num_targets,
        'num_predicts': num_predicts,
        'num_correct_glycans': num_correct_glycans,
        'sensitivity_glycan': sensitivity_glycan,
        'num_target_y': num_target_y,
        'num_predict_y': num_predict_y,
        'num_correct_y': num_correct_y,
        'sensitivity_y': sensitivity_y,
        'precision_y': precision_y,
    }


In [None]:
#########################################################################
# MODEL PREPARATION
#########################################################################

In [None]:
# Define Graph Neural Network model

class GnnModel(Model):
    """Sugar-class classifier combining glycopsm features with graph embeddings.

    Args:
        n_hidden: number of output channels of the GCNConv layer.
        version: 'linear' feeds only inputs[0] to the output layer; 'gnn'
            concatenates inputs[0] with the pooled graph embeddings of inputs[1:].

    call() returns unnormalized logits whose last dimension is num_sugars
    (dense_last has no activation), so pair it with a from_logits=True loss.

    Raises:
        ValueError: if `version` is not 'linear' or 'gnn'. Previously an unknown
            version only failed later inside call() with an UnboundLocalError.
    """

    def __init__(self, n_hidden, version):
        if version not in ('linear', 'gnn'):
            raise ValueError("version must be 'linear' or 'gnn', got {!r}".format(version))
        super().__init__()
        self.n_hidden = n_hidden
        self.version = version
        self.graph_conv = GCNConv(n_hidden, name='gcn_conv', kernel_regularizer=regularizers.l2(0.01))
        #self.dropout = Dropout(0.5, name='dropout')
        self.pool = GlobalSumPool(name='global_sum_pool')
        #self.dense_1 = Dense(n_hidden, name='dense_1', kernel_regularizer=regularizers.l2(0.01))
        self.dense_last = Dense(num_sugars,
                                kernel_regularizer=regularizers.l2(0.01),
#                                 activation='softmax',
                                name='dense_last')

    def call(self, inputs):

        # inputs[0]: glycan_y_superset features, passed through unchanged
        outputs = [inputs[0]]

        # inputs[1:]: one graph input per candidate. They all go through the SAME
        # graph_conv and pool layers (shared weights), not separate models. The layers
        # also run in 'linear' mode, where their output is simply not used.
        for input_ in inputs[1:]:
            out = self.graph_conv(input_)
            #out = self.dropout(out)
            out = self.pool(out)
            #out = self.dense_1(out)
            outputs.append(out)

        # only use glycan_y_superset model or combine two models
        if self.version == 'linear':
            out = outputs[0]
        elif self.version == 'gnn':
            out = tf.concat(outputs, axis=1)
        else:
            raise ValueError("version must be 'linear' or 'gnn', got {!r}".format(self.version))

        # last logit layer
        return self.dense_last(out)



In [None]:
# Function to convert glycan trees to graphs

def tree_to_graph(tree):
    """Convert a glypy glycan tree into a (node features, adjacency) pair.

    Args:
        tree: glypy glycan whose `index` lists its nodes and whose `link_index`
            lists its links (each with `.parent` and `.child`).

    Returns:
        nodes_onehot: array of shape (max(20, num_nodes), num_sugars). Row i one-hot
            encodes the sugar class of node i. Rows beyond the real nodes are padded
            with the 'Hex' class.
        adjacency_matrix: array of shape (max(20, num_nodes), max(20, num_nodes)).
            Entry [parent, child] is 1 for every link; it is directed, with no
            self-loops and no normalization.

    Node i is the i-th node of `tree.index` visited in REVERSE order.
    Raises ValueError (from list.index) if 'Hex' or a node's class name is not in
    `sugar_classes`.

    NOTE(review): spektral's GCNConv documents a preprocessed (normalized)
    adjacency as its expected input; this raw adjacency is passed as-is —
    confirm intended.
    """

    nodes = []
    node_id_to_index ={}
    # Work on a clone, because nodes are detached from their parents below. Visiting
    # the index in reverse presumably detaches children before their parents, so each
    # node has no remaining links when its mass is measured — depends on glypy's
    # index order, confirm.
    for node in tree.clone().index[::-1]:
        node_index = len(nodes)
        node_id = node.id
        parents = node.parents()
        if parents:
            # detach from the parent, presumably so node.mass() in get_class_name()
            # matches a free monosaccharide mass — confirm
            node.drop_monosaccharide(parents[0][0])
        node_name = get_class_name(node)
        node_sugar_index = sugar_classes.index(node_name)
        nodes.append({'index': node_index, 'id': node_id, 'name': node_name, 'sugar_index': node_sugar_index})
        node_id_to_index[node_id] = node_index
    num_nodes = len(nodes)
    
    # pad to at least 20 rows; larger trees keep their real size
    num_nodes_max = max(20, num_nodes)
    nodes_onehot = np.zeros((num_nodes_max, num_sugars)) # np.zeros((num_nodes, num_sugars))
    nodes_onehot[np.arange(num_nodes), np.array([node['sugar_index'] for node in nodes])] = 1
    # padding 'Hex' up to 20 nodes
    if num_nodes < num_nodes_max:
        nodes_onehot[num_nodes:num_nodes_max, sugar_classes.index('Hex')] = 1
    
    # Links come from the ORIGINAL tree and are looked up by node id, which assumes
    # clone() preserves node ids — confirm. Padded rows/columns stay 0.
    adjacency_matrix = np.zeros((num_nodes_max, num_nodes_max)) # np.zeros((num_nodes, num_nodes))
    for link in tree.link_index:
        parent_index = node_id_to_index[link.parent.id]
        child_index = node_id_to_index[link.child.id]
        adjacency_matrix[parent_index, child_index] = 1
    
    return nodes_onehot, adjacency_matrix

class Trees_to_Graphs(Dataset):
    """In-memory spektral Dataset built from (glycan tree, label) pairs.

    Args:
        trees_labels: iterable of (tree, label) pairs. Each tree is converted with
            `tree_to_graph` and stored as Graph(x=node_features, a=adjacency, y=label).
        pseudo_graph: optional Graph placed first in the dataset; pass None to omit it.
        **kwargs: forwarded to spektral's Dataset.
    """

    def __init__(self, trees_labels, pseudo_graph, **kwargs):
        # set attributes before Dataset.__init__, which calls read()
        self.trees_labels = trees_labels
        self.pseudo_graph = pseudo_graph

        super().__init__(**kwargs)

    # The `download()` method is automatically called if the path returned by
    # `Dataset.path` does not exists (default `~/.spektral/datasets/ClassName/`).
    def download(self):
        # Nothing to download: all graphs are built in memory by read().
        pass

    def read(self):
        # We must return a list of Graph objects
        output = []
        if self.pseudo_graph is not None:
            output.append(self.pseudo_graph)
        for tree, label in self.trees_labels:
            graph_x, graph_a = tree_to_graph(tree)
            output.append(Graph(x=graph_x, a=graph_a, y=label))

        return output



In [None]:
# Functions to prepare training/testing data

# prepare training and testing samples from the input data
def prepare_training_samples(input_spectrum_file, glycan_psm, fraction_id_list):
    """Build per-PSM leaf-classification samples by peeling nodes off each target glycan.

    For every PSM of the given fractions, the target glycan (looked up in
    `glycan_dict`) is cloned twice: once with the default node index and once with
    BFS indexing. For each copy, every non-root node is visited in reverse index
    order and permanently detached from its parent, so the tree shrinks as the loop
    goes on. Then, for every sugar in `sugar_classes`, that sugar is temporarily
    attached to the former parent, the tree is scored against the PSM's spectrum with
    `compute_glycopsm_score`, and a clone of it is kept.

    Args:
        input_spectrum_file: merged mgf file indexed by `spectrum_location_dict`.
        glycan_psm: fraction id -> list of PSM rows (dicts from the csv exports).
        fraction_id_list: fractions to process.

    Returns:
        One list per PSM of (leaf_name, candidate_glycopsm, candidate_tree) tuples.
        candidate_glycopsm and candidate_tree hold one entry per sugar class, in
        `sugar_classes` order.
    """
    sample_list = []
    # `with` closes the spectrum file even if parsing or scoring raises
    with open(input_spectrum_file, 'r') as input_spectrum_handle:
        for fraction_id in fraction_id_list:
            psm_list = glycan_psm[fraction_id]

            print("fraction_id = {0:d}, len(psm_list) = {1:d}".format(fraction_id, len(psm_list)))
            for index, psm in enumerate(psm_list):

                if ((index+1) % 1000 == 0):
                    print("Processed {:d} PSMs...".format(index+1))

                # read peptide and glycan
                peptide_mass = float(psm['Mass'])
                target_glycan_id = psm['Glycan ID']
                target_glycan_mass = float(psm['Glycan Mass'])
                peptide_only_mass = peptide_mass - target_glycan_mass
                # two copies differing only in node indexing (default vs BFS), so
                # nodes are presumably peeled off in two different orders
                target_glycan_2idx = [glycan_dict[target_glycan_id]['GLYCAN'].clone(),
                                      glycan_dict[target_glycan_id]['GLYCAN'].clone(index_method='bfs')]

                # read spectrum
                scan = 'F' + str(fraction_id) + ':' + psm['Scan']
                mz1_list, intensity_list = get_spectrum(input_spectrum_handle, spectrum_location_dict, scan)

                # recursively partition the glycan tree into tree_glycopsm_list
                tree_glycopsm_list = []
                for target_glycan in target_glycan_2idx:
                    for leaf in target_glycan.index[::-1]:
                        parents = leaf.parents()
                        if not parents:
                            continue  # the root has no parent to attach candidates to
                        # permanent removal: later iterations see the smaller tree
                        leaf.drop_monosaccharide(parents[0][0])
                        leaf_name = get_class_name(leaf)
                        candidate_glycopsm = []
                        candidate_tree = []
                        for name in sugar_classes:
                            # temporarily attach a candidate sugar to the removed node's parent
                            sugar = glypy.monosaccharides[name].clone()
                            parents[0][1].add_monosaccharide(sugar)
                            glycan_y_list, glycopsm = compute_glycopsm_score(
                                target_glycan, peptide_only_mass, mz1_list, intensity_list)
                            candidate_glycopsm.append((glycan_y_list, glycopsm))
                            candidate_tree.append(target_glycan.clone())
                            sugar.drop_monosaccharide(-1)  # detach the candidate again
                        tree_glycopsm_list.append((leaf_name, candidate_glycopsm, candidate_tree))

                sample_list.append(tree_glycopsm_list)

    return sample_list


# convert leaf, glycopsm, trees of training/testing samples into np arrays
def prepare_np_arrays(sample_list, glycan_y_superset):
    """Convert samples from prepare_training_samples() into model input arrays.

    Args:
        sample_list: output of prepare_training_samples().
        glycan_y_superset: list of Y-ion masses defining the feature columns; Y ions
            not in it are dropped.

    Returns:
        x_array: shape (num_samples*num_leaves, num_sugars*len(glycan_y_superset));
            row-block `candidate` holds that candidate's glycopsm scores.
        y_array: shape (num_samples*num_leaves,); `sugar_classes` index of each
            removed leaf.
        x_n_array: shape (num_samples*num_leaves, num_sugars, n_max, n_node_features);
            node features of every candidate tree, zero-padded by BatchLoader.
        x_a_array: shape (num_samples*num_leaves, num_sugars, n_max, n_max);
            adjacency matrices of every candidate tree.
    """
    # O(1) column lookup. setdefault keeps the first occurrence, matching list.index(),
    # and replaces the O(len(superset)) `in`/`.index()` scans of the inner loop.
    y_position = {}
    for position, y_mass in enumerate(glycan_y_superset):
        y_position.setdefault(y_mass, position)
    x_len = len(glycan_y_superset)

    x_array = []
    y_array = []
    x_trees_labels = []
    for sample in sample_list:
        for leaf_name, candidate_glycopsm, candidate_tree in sample:
            x_subarray = np.zeros((num_sugars, x_len))
            for candidate, (glycan_y_list, glycopsm) in enumerate(candidate_glycopsm):
                for y, score in zip(glycan_y_list, glycopsm):
                    idx = y_position.get(y)
                    if idx is not None:
                        x_subarray[candidate, idx] = score
            x_array.append(x_subarray)
            y_array.append(sugar_classes.index(leaf_name))

            # graph labels are placeholders (always 0); only x and a are used
            x_trees_labels.extend((tree, 0) for tree in candidate_tree)

    x_array = np.array(x_array) # shape (num_samples*num_leaves, num_sugars, x_len)
    x_array = np.reshape(x_array, (-1, num_sugars * x_len)) # shape (num_samples*num_leaves, num_sugars*x_len)
    y_array = np.array(y_array) # shape (num_samples*num_leaves,)

    x_graphs = Trees_to_Graphs(x_trees_labels, pseudo_graph=None) # shape (num_samples*num_leaves*num_sugars,)
    # use BatchLoader to do zero padding for all graphs to have the same size
    # put all graphs into 1 batch
    loader = BatchLoader(x_graphs, batch_size=len(x_graphs), shuffle=False)
    x_n_array = []
    x_a_array = []
    for batch in loader.load():
        # each batch is a tuple (inputs, labels)
        # inputs is a tuple containing:
        # 0: node attributes of shape [batch, n_max, n_node_features];
        # 1: adjacency matrices of shape [batch, n_max, n_max];
        x_n_array.append(batch[0][0]) # node features
        x_a_array.append(batch[0][1]) # adjacency matrix
        # because batch_size=len(x_graphs), one iteration is enough
        break
    x_n_array = np.array(x_n_array) # shape (1, num_samples*num_leaves*num_sugars, n_max, n_node_features)
    x_a_array = np.array(x_a_array) # shape (1, num_samples*num_leaves*num_sugars, n_max, n_max)
    # reshape to (num_samples*num_leaves, num_sugars, n_max, ...)
    n_shape = list(x_n_array.shape)
    a_shape = list(x_a_array.shape)
    n_shape[0] = -1
    n_shape[1] = num_sugars
    a_shape[0] = -1
    a_shape[1] = num_sugars
    x_n_array = x_n_array.reshape(n_shape)
    x_a_array = x_a_array.reshape(a_shape)

    return x_array, y_array, x_n_array, x_a_array



In [None]:
#############################################################################################
# MODEL TRAINING
# (if testing, skip this section and go straight to MODEL TESTING below)
# (if denovo sequencing, skip this section and go straight to MODEL DENOVO SEQUENCING below)
#############################################################################################

In [None]:
# Prepare training data

# prepare training_samples and glycan_y_superset
# since they are too large, we save/load them from hard disk to reuse
if os.path.exists("training_arrays/x_training.npy"):
    print("Load training data")
    x_training = np.load("training_arrays/x_training.npy")
    y_training = np.load("training_arrays/y_training.npy")
    x_n_training = np.load("training_arrays/x_n_training.npy")
    x_a_training = np.load("training_arrays/x_a_training.npy")
else:
    print("Prepare training samples")
    # Fractions 2-5 are used for training; fraction 1 is held out for MODEL TESTING.
    # This assumes a 5-fraction dataset (e.g. the mouse_brain configuration above):
    # with fewer fractions, glycan_psm has no key 4/5 and this raises KeyError.
    fraction_id_list = [2, 3, 4, 5]
    training_samples = prepare_training_samples(input_spectrum_file, glycan_psm, fraction_id_list)
    print("len(training_samples) = ", len(training_samples))
    # every Y-ion mass of every candidate of every removed leaf of every sample
    glycan_y_all = [y_mass
                    for sample in training_samples
                    for _, candidate_glycopsm, _ in sample
                    for glycan_y_list, _ in candidate_glycopsm
                    for y_mass in glycan_y_list]
    print("len(glycan_y_superset) = ", len(glycan_y_all))
    glycan_y_superset = sorted(set(glycan_y_all))
    print("len(glycan_y_superset) = ", len(glycan_y_superset))
    # np.save / open() do not create missing directories
    os.makedirs("training_arrays", exist_ok=True)
    with open("training_arrays/glycan_y_superset.pkl", 'wb') as f:
        pickle.dump(glycan_y_superset, f)
    print()

    # convert training_samples into np arrays
    # since they are too large, we save/load them from hard disk to reuse
    x_training, y_training, x_n_training, x_a_training = prepare_np_arrays(training_samples, glycan_y_superset)
    np.save("training_arrays/x_training.npy", x_training)
    np.save("training_arrays/y_training.npy", y_training)
    np.save("training_arrays/x_n_training.npy", x_n_training)
    np.save("training_arrays/x_a_training.npy", x_a_training)

# check the shapes of the arrays
print("x_training.shape = ", x_training.shape) # shape (num_samples*num_leaves, num_sugars*x_len)
print("y_training.shape = ", y_training.shape) # shape (num_samples*num_leaves,)
print("x_training[0] = ", x_training[0])
print("y_training[0] = ", y_training[0])
print("x_n_training.shape = ", x_n_training.shape) #  (num_samples*num_leaves, num_sugars, n_max, n_node_features)
print("x_a_training.shape = ", x_a_training.shape) #  (num_samples*num_leaves, num_sugars, n_max, n_max)
print()

# train_test_split and normalization
x_train, x_test, x_n_train, x_n_test, x_a_train, x_a_test, y_train, y_test = train_test_split(
    x_training,
    x_n_training,
    x_a_training,
    y_training, test_size=0.1, random_state=99)
# center features with the mean of the training split only; no std scaling,
# because the array is sparse (too many 0)
x_training_train_mean = np.mean(x_train, axis=0)
np.save("training_arrays/x_training_train_mean.npy", x_training_train_mean)
x_train_norm = (x_train - x_training_train_mean)
x_test_norm = (x_test - x_training_train_mean)
print("x_train_norm.shape = ", x_train_norm.shape)
print("x_test_norm.shape = ", x_test_norm.shape)
print("x_n_train.shape = ", x_n_train.shape)
print("x_n_test.shape = ", x_n_test.shape)
print("x_a_train.shape = ", x_a_train.shape)
print("x_a_test.shape = ", x_a_test.shape)
print()

# move the candidate axis first, then group x, x_n, x_a into x_inputs:
# [glycopsm features] + one (node features, adjacency) pair per candidate,
# i.e. 1 + num_sugars inputs in total
x_n_train, x_n_test, x_a_train, x_a_test = [np.transpose(x, axes=[1,0,2,3]) for x in [x_n_train, x_n_test, x_a_train, x_a_test]]
print("x_train_norm.shape = ", x_train_norm.shape)
print("x_test_norm.shape = ", x_test_norm.shape)
print("x_n_train.shape = ", x_n_train.shape)
print("x_n_test.shape = ", x_n_test.shape)
print("x_a_train.shape = ", x_a_train.shape)
print("x_a_test.shape = ", x_a_test.shape)
x_inputs_train = [x_train_norm] + [(n, a) for n, a in zip(x_n_train, x_a_train)]
x_inputs_test = [x_test_norm] + [(n, a) for n, a in zip(x_n_test, x_a_test)]
print("len(x_inputs_train) = ", len(x_inputs_train))
print("len(x_inputs_test) = ", len(x_inputs_test))
print()


In [None]:
# delete training samples to free up memory
# they are no longer needed after train_test_split
del x_training, x_n_training, x_a_training, y_training
# training_samples only exists when the arrays were built from scratch; after a
# cached load from training_arrays/ it was never defined, and an unconditional del
# raised NameError
try:
    del training_samples
except NameError:
    pass


In [None]:
# Create and train model

# create model
version = 'gnn'
model = GnnModel(n_hidden=16, version=version)
model._name = version + '_model'
checkpoint_path = "model_copy/cp.ckpt"
print("checkpoint_path =", checkpoint_path)
# train model
# dense_last in GnnModel has no activation, so the loss is computed from logits
loss_func = SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='Adam', loss=loss_func, metrics=['accuracy'])
# keep only the weights of the epoch with the lowest validation loss
model_checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss', save_best_only=True, save_weights_only=True)
history = model.fit(x_inputs_train, y_train,
                    epochs=10,
                    validation_split=0.1,
                    callbacks=[model_checkpoint],
                    verbose=0,
                   )

# summary of model training
# model.summary() prints the table itself and returns None; print(model.summary())
# used to emit a stray "None"
model.summary()
loss = history.history['loss']
val_loss = history.history['val_loss']
accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']
epochs = range(1, len(loss)+1)
print("min(loss) = {:.2f}".format(min(loss)))
print("min(val_loss) = {:.2f}".format(min(val_loss)))
print("max(accuracy) = {:.2f}".format(max(accuracy)))
print("max(val_accuracy) = {:.2f}".format(max(val_accuracy)))

# plot train/valid loss and accuracy
# pyplot.subplot(1, 2, 1)
# pyplot.plot(epochs, loss, label='loss')
# pyplot.plot(epochs, val_loss, label='val_loss')
# pyplot.xlabel('Epochs')
# pyplot.ylabel('Loss')
# pyplot.legend()

# pyplot.subplot(1, 2, 2)
# pyplot.plot(epochs, accuracy, label='accuracy')
# pyplot.plot(epochs, val_accuracy, label='val_accuracy')
# pyplot.xlabel('Epochs')
# pyplot.ylabel('Accuracy')
# pyplot.legend()


In [None]:
# Delete training data to free up memory; it is no longer needed after
# training.
del x_train, x_test, x_n_train, x_n_test, x_a_train, x_a_test, y_train, y_test
# x_inputs_train / x_inputs_test hold the normalized arrays and views of the
# transposed graph arrays (np.transpose returns a view). Unless they are
# deleted too, the underlying buffers stay alive and nothing is freed.
del x_train_norm, x_test_norm, x_inputs_train, x_inputs_test


In [None]:
#########################################################################
# MODEL TESTING
# (if denovo sequencing, skip this section and go straight to MODEL DENOVO SEQUENCING below)
#########################################################################

In [None]:
# Prepare testing data

# Build testing_samples and load glycan_y_superset. The resulting arrays are too
# large to rebuild on every run, so they are cached under training_arrays/ and
# reused when present.
# NOTE(review): the cache is not keyed on fraction_id_list; delete the cached
# .npy files after changing the testing fractions.
testing_array_files = [
    "training_arrays/x_testing.npy",
    "training_arrays/y_testing.npy",
    "training_arrays/x_n_testing.npy",
    "training_arrays/x_a_testing.npy",
]
# Require every cached file. Checking only x_testing.npy made a partially
# written cache crash in np.load instead of being rebuilt.
if all(os.path.exists(cache_path) for cache_path in testing_array_files):
    print("Load testing data")
    x_testing, y_testing, x_n_testing, x_a_testing = [np.load(cache_path) for cache_path in testing_array_files]
else:
    print("Prepare testing samples")
    fraction_id_list = [1]
    testing_samples = prepare_training_samples(input_spectrum_file, glycan_psm, fraction_id_list)
    print("len(testing_samples) = ", len(testing_samples))
    with open("training_arrays/glycan_y_superset.pkl", 'rb') as f:
        glycan_y_superset = pickle.load(f)
    print("len(glycan_y_superset) = ", len(glycan_y_superset))
    print()

    # convert testing_samples into np arrays and cache them for later runs
    x_testing, y_testing, x_n_testing, x_a_testing = prepare_np_arrays(testing_samples, glycan_y_superset)
    for cache_path, cached_array in zip(testing_array_files, [x_testing, y_testing, x_n_testing, x_a_testing]):
        np.save(cache_path, cached_array)

# check the shapes of the arrays
print("x_testing.shape = ", x_testing.shape)
print("y_testing.shape = ", y_testing.shape)
print("x_n_testing.shape = ", x_n_testing.shape)
print("x_a_testing.shape = ", x_a_testing.shape)
print()

# normalization: subtract the feature mean saved from the training split
x_training_train_mean = np.load("training_arrays/x_training_train_mean.npy")
x_testing_norm = (x_testing - x_training_train_mean)
print("x_testing_norm.shape = ", x_testing_norm.shape)
print()

# Move axis 1 of the graph arrays to the front, then group everything into the
# model's input list: the normalized feature array followed by one
# (node features, adjacency) pair per entry of that axis.
# NOTE(review): presumably 1 + num_sugars inputs -- confirm.
x_n_testing, x_a_testing = [np.transpose(x, axes=[1, 0, 2, 3]) for x in [x_n_testing, x_a_testing]]
print("x_testing_norm.shape = ", x_testing_norm.shape)
print("x_n_testing.shape = ", x_n_testing.shape)
print("x_a_testing.shape = ", x_a_testing.shape)
x_inputs_testing = [x_testing_norm] + [(n, a) for n, a in zip(x_n_testing, x_a_testing)]
print("len(x_inputs_testing) = ", len(x_inputs_testing))
print()


In [None]:
# Create, load and test model

# Create, compile and load the model. Compiling is needed so that evaluate()
# can report loss and accuracy.
version = 'gnn'
model = GnnModel(n_hidden=16, version=version)
model._name = version + '_model'
checkpoint_path = "model_" + version + "/cp.ckpt"
print("checkpoint_path =", checkpoint_path)
loss_func = SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='Adam', loss=loss_func, metrics=['accuracy'])
model.load_weights(checkpoint_path)

# test model
testing_loss, testing_accuracy = model.evaluate(x_inputs_testing, y_testing)
print('testing_loss = {:.2f}'.format(testing_loss))
print('testing_accuracy = {:.2f}'.format(testing_accuracy))
# double-check test set performance: accuracy (sensitivity/specificity/precision not applicable)
# The loss is configured with from_logits=True, so predict() yields logits.
# Convert them to real probabilities for y_prob; predictions use the argmax of
# the logits, which softmax does not change.
y_logits = np.array(model.predict(x_inputs_testing))
y_prob = softmax(y_logits, axis=1)
y_pred = np.argmax(y_logits, axis=1)
y_correct = y_testing == y_pred
print('double-check accuracy = {:.2f}'.format(np.sum(y_correct) / len(y_testing)))
print()


In [None]:
#########################################################################
# MODEL DENOVO SEQUENCING
#########################################################################

In [None]:
# Predict de novo glycans with the pretrained model

def gnn_denovo(mode, input_spectrum_file, glycan_psm, fraction_id_list,
               version, checkpoint_path, x_training_train_mean_path, glycan_y_superset_path,
               beam_size=1, delta_mass_tolerance=1., delta_mass_left=100.):
    """Sequence glycans de novo for each glycoPSM with a GNN-guided beam search.

    Starting from the cores ``n_link_core_fuc`` and ``n_link_core``, every search
    step adds each monosaccharide in ``sugar_classes`` to each leaf of each
    current candidate and scores the extensions with the GNN model. Extensions
    whose mass is within ``delta_mass_tolerance`` of the target glycan mass are
    kept as final answers. Extensions at least ``delta_mass_left`` below the
    target stay in the beam (top ``beam_size`` by score). All others are
    discarded.

    Args:
        mode: 'evaluation' processes every PSM and also collects its database
            glycan (via the global ``glycan_dict``) as the target.
            'prediction' only sequences PSMs whose 'Glycan ID' is empty.
        input_spectrum_file: spectrum file read through ``get_spectrum``.
        glycan_psm: mapping fraction_id -> list of PSM dicts.
        fraction_id_list: fractions to process, in order.
        version: model version passed to ``GnnModel``.
        checkpoint_path: weights restored into the model.
        x_training_train_mean_path: .npy file whose array is subtracted from
            the input features.
        glycan_y_superset_path: pickle holding ``glycan_y_superset``, the list
            that fixes the feature column order.
        beam_size: number of growing candidates kept per search step.
        delta_mass_tolerance: max |target mass - candidate mass| for a final answer.
        delta_mass_left: a candidate keeps growing only while it is at least
            this much lighter than the target mass.

    Returns:
        (target_glycans, predict_glycans). target_glycans holds one cloned
        database glycan per PSM in 'evaluation' mode and is empty in
        'prediction' mode. predict_glycans holds one entry per PSM (skipped
        PSMs included): a list of (glycan, score) sorted by descending score,
        or [] when nothing matched.

    Raises:
        ValueError: if ``mode`` is neither 'evaluation' nor 'prediction'.
    """
    if mode not in ('evaluation', 'prediction'):
        raise ValueError("mode must be 'evaluation' or 'prediction', got {!r}".format(mode))

    print("gnn_denovo()")

    # create and load model (predict() does not need a compiled model)
    model = GnnModel(n_hidden=16, version=version)
    model._name = version + '_model'
    model.load_weights(checkpoint_path)

    # load x_train_mean of the training data for normalization
    x_training_train_mean = np.load(x_training_train_mean_path)

    # load glycan_y_superset of the training data
    with open(glycan_y_superset_path, 'rb') as f:
        glycan_y_superset = pickle.load(f)

    # de novo sequencing outputs
    target_glycans = []
    predict_glycans = []
    print("beam_size =", beam_size)
    # Optional pseudo graph prepended to every batch to control padding, e.g.
    # Graph(x=np.random.rand(18, 8), a=np.random.randint(0, 2, (18, 18)), y=0).
    # Currently disabled; when enabled it is stripped from the batch again.
    pseudo_graph = None

    with open(input_spectrum_file, 'r') as input_spectrum_handle:
        for fraction_id in fraction_id_list:
            psm_list = glycan_psm[fraction_id]

            # # unseen test
            # psm_list = [x for x in psm_list if x['Glycan'] in unseen_glycan_composition]

            print("fraction_id = {0:d}, len(psm_list) = {1:d}".format(fraction_id, len(psm_list)))
            for index, psm in enumerate(psm_list):
                if (index + 1) % 100 == 0:
                    print("Processed {:d} PSMs...".format(index+1))

                # read target glycan mass and the mass left for the peptide
                if mode == 'evaluation':
                    target_glycan_id = psm['Glycan ID']
                    target_glycan = glycan_dict[target_glycan_id]['GLYCAN'].clone()
                    target_glycans.append(target_glycan)
                    target_glycan_mass = float(psm['Glycan Mass'])
                    peptide_only_mass = float(psm['Mass']) - target_glycan_mass
                else:  # 'prediction'
                    # only do prediction on PSM with empty Glycan ID; keep a
                    # placeholder so predict_glycans stays aligned with the PSMs
                    if psm['Glycan ID']:
                        predict_glycans.append([])
                        continue
                    target_glycan_mass = float(psm['Glycan Mass'])
                    peptide_only_mass = float(psm['PepMass'])

                # glypy adds the reducing end to the glycan mass, which is
                # 18.01 Da larger than the PEAKS reported mass
                target_glycan_mass += 18.012115

                # read spectrum
                scan = 'F' + str(fraction_id) + ':' + psm['Scan']
                mz1_list, intensity_list = get_spectrum(input_spectrum_handle, spectrum_location_dict, scan)

                predict_glycans.append(_gnn_denovo_search(
                    model, glycan_y_superset, x_training_train_mean, pseudo_graph,
                    mz1_list, intensity_list, peptide_only_mass, target_glycan_mass,
                    beam_size, delta_mass_tolerance, delta_mass_left))

    return target_glycans, predict_glycans


def _gnn_denovo_search(model, glycan_y_superset, x_training_train_mean, pseudo_graph,
                       mz1_list, intensity_list, peptide_only_mass, target_glycan_mass,
                       beam_size, delta_mass_tolerance, delta_mass_left):
    """Run the beam search for one spectrum; return [(glycan, score), ...] by descending score."""
    final_candidates = []
    final_scores = []
    for core_glycan in [n_link_core_fuc.clone(), n_link_core.clone()]:
        core_len = len(core_glycan.index)
        current_candidates = [core_glycan]
        while current_candidates:
            # extend every candidate by one sugar and collect glycoPSM features
            next_candidates, x_array = _expand_candidates(
                current_candidates, core_len, glycan_y_superset,
                peptide_only_mass, mz1_list, intensity_list)
            x_array_norm = x_array - x_training_train_mean

            # GNN inputs: features first, then one (nodes, adjacency) pair per sugar
            x_n_array, x_a_array = _candidate_graph_inputs(next_candidates, pseudo_graph)
            x_inputs = [x_array_norm] + [(n, a) for n, a in zip(x_n_array, x_a_array)]

            # Scores have shape (candidates * leaves, num_sugars). Flattening them
            # allows all sugars per leaf and matches the order of next_candidates.
            next_scores = model.predict(x_inputs).flatten()

            # check candidates for mass, then keep the top beam_size growing ones
            current_candidates = []
            current_scores = []
            for glycan, score in zip(next_candidates, next_scores):
                glycan_mass = glycan.mass()
                if abs(target_glycan_mass - glycan_mass) <= delta_mass_tolerance:
                    final_candidates.append(glycan.clone())
                    final_scores.append(score)
                elif glycan_mass < (target_glycan_mass - delta_mass_left):
                    current_candidates.append(glycan.clone())
                    current_scores.append(score)
            if current_candidates:
                current_sorted = sorted(zip(current_candidates, current_scores), key=lambda pair: -pair[1])
                current_candidates = [candidate for candidate, _ in current_sorted[:beam_size]]

    return sorted(zip(final_candidates, final_scores), key=lambda pair: -pair[1])


def _expand_candidates(current_candidates, core_len, glycan_y_superset,
                       peptide_only_mass, mz1_list, intensity_list):
    """Extend every candidate by one monosaccharide at each of its leaves.

    Returns (next_candidates, x_array). next_candidates holds one cloned glycan
    per (candidate, leaf, sugar), in that nesting order. x_array has shape
    (candidates * leaves, num_sugars * len(glycan_y_superset)); each row places
    the glycoPSM scores of an extension in the column of its matching
    glycan_y_superset entry.
    """
    x_len = len(glycan_y_superset)
    next_candidates = []
    x_array = []
    for glycan in current_candidates:
        glycan.reindex(method='bfs')
        # core leaves: childless nodes among the last two core positions (BFS order)
        leaves = [node for node in glycan.index[core_len-2:core_len] if len(node.children()) < 1]
        # branch leaves: childless nodes after the core in BFS order
        leaves += [node for node in glycan.index[core_len:] if len(node.children()) < 1]
        for leaf in leaves:
            x_subarray = np.zeros((num_sugars, x_len))
            for candidate, name in enumerate(sugar_classes):
                sugar = glypy.monosaccharides[name].clone()
                leaf.add_monosaccharide(sugar)
                glycan_y_list, glycopsm = compute_glycopsm_score(glycan, peptide_only_mass, mz1_list, intensity_list)
                for y, score in zip(glycan_y_list, glycopsm):
                    if y in glycan_y_superset:
                        x_subarray[candidate, glycan_y_superset.index(y)] = score
                next_candidates.append(glycan.clone())
                # Undo the temporary extension so the next sugar is tried on the
                # same leaf. NOTE(review): relies on add_monosaccharide's default
                # (-1) position and glypy storing the link on both ends -- confirm.
                sugar.drop_monosaccharide(-1)
            x_array.append(x_subarray)

    # (candidates * leaves, num_sugars, x_len) -> (candidates * leaves, num_sugars * x_len)
    x_array = np.reshape(np.array(x_array), (-1, num_sugars * x_len))
    return next_candidates, x_array


def _candidate_graph_inputs(next_candidates, pseudo_graph):
    """Convert candidate glycans to zero-padded GNN node/adjacency arrays.

    Returns [x_n_array, x_a_array], each transposed so axis 0 runs over the
    num_sugars sugar slots and axis 1 over (candidate, leaf) pairs.
    """
    x_trees_labels = [(tree, 0) for tree in next_candidates]
    # convert trees to graphs; the optional pseudo_graph comes first
    x_graphs = Trees_to_Graphs(x_trees_labels, pseudo_graph)
    # BatchLoader zero-pads every graph to the same size; put all graphs in 1 batch
    loader = BatchLoader(x_graphs, batch_size=len(x_graphs), shuffle=False)
    x_n_array = []
    x_a_array = []
    for batch in loader.load():
        # batch = (inputs, labels); inputs[0] are node attributes
        # [batch, n_max, n_node_features], inputs[1] adjacency [batch, n_max, n_max]
        x_n_array.append(batch[0][0])
        x_a_array.append(batch[0][1])
        # because batch_size=len(x_graphs), one iteration is enough
        break
    x_n_array = np.array(x_n_array)
    x_a_array = np.array(x_a_array)
    # exclude pseudo_graph if one was prepended
    if pseudo_graph is not None:
        x_n_array = x_n_array[:, 1:, :, :]
        x_a_array = x_a_array[:, 1:, :, :]
    # regroup rows as (candidates * leaves, num_sugars, ...)
    x_n_array = x_n_array.reshape([-1, num_sugars] + list(x_n_array.shape[2:]))
    x_a_array = x_a_array.reshape([-1, num_sugars] + list(x_a_array.shape[2:]))
    # put the sugar axis first so each sugar becomes its own model input
    return [np.transpose(x, axes=[1, 0, 2, 3]) for x in (x_n_array, x_a_array)]

# input_spectrum_file, glycan_psm: already defined in I/O FUNCTIONS and DATA PRE-PROCESSING
# mode ('evaluation' or 'prediction'): expected to be set alongside data_folder there
fraction_id_list = [1, 2, 3]
version = 'gnn'
checkpoint_path = "model_" + version + "/cp.ckpt"
print("checkpoint_path =", checkpoint_path)
x_training_train_mean_path = "training_arrays/x_training_train_mean.npy"
glycan_y_superset_path = "training_arrays/glycan_y_superset.pkl"
# perf_counter is monotonic and meant for measuring elapsed time
basic_denovo_time = time.perf_counter()
target_glycans, predict_glycans = gnn_denovo(
    mode,
    input_spectrum_file, glycan_psm, fraction_id_list,
    version, checkpoint_path, x_training_train_mean_path, glycan_y_superset_path)
basic_denovo_time = time.perf_counter() - basic_denovo_time
print("basic_denovo_time =", basic_denovo_time)

if mode == 'evaluation':
    test_glycan_accuracy(target_glycans[:1000], predict_glycans[:1000], top=1)
elif mode == 'prediction':
    # The CSV header comes from the first PSM of the first processed fraction,
    # plus the de novo columns.
    fieldnames = list(glycan_psm[fraction_id_list[0]][0].keys())
    fieldnames += ['denovo glycan ID', 'denovo glycan score', 'denovo glycan composition']
    # predict_glycans has one entry per PSM in this same order
    psm_list_flatten = [psm for x in fraction_id_list for psm in glycan_psm[x]]
    num_predict = 0
    unique_compositions = set()
    with open(data_folder + 'denovo_glycan.csv', 'w', newline='') as csvfile, \
            open(data_folder + 'denovo_glycan.txt', 'w') as txtfile:
        csvwriter = csv.DictWriter(csvfile, fieldnames)
        csvwriter.writeheader()
        for psm, candidates in zip(psm_list_flatten, predict_glycans):
            if not candidates:
                continue
            num_predict += 1
            # candidates are sorted by descending score; report the best one
            glycan, score = candidates[0]
            composition = str(glypy.structure.glycan_composition.GlycanComposition.from_glycan(glycan))
            unique_compositions.add(composition)
            # output to csv file
            row = psm.copy()
            row['denovo glycan ID'] = num_predict
            row['denovo glycan score'] = score
            row['denovo glycan composition'] = composition
            csvwriter.writerow(row)
            # Output to txt file. Subtract the 18.012115 Da reducing end that
            # glypy includes, so the mass matches the PEAKS convention.
            lines = ['GLYCAN START\n']
            lines.append('GLYCANID={0}\n'.format(num_predict))
            lines.append('GLYCAN_MASS={0}\n'.format(glycan.mass() - 18.012115))
            lines.append('GLYCAN_TAXON=Others\n')
            lines.append(str(glycan))
            lines.append('GLYCAN END\n')
            lines.append('\n')
            txtfile.write(''.join(lines))
    print("len(psm_list_flatten) =", len(psm_list_flatten))
    print("num_predict =", num_predict)
    print("len(unique_compositions) =", len(unique_compositions))
    print()

