# LOAD UNIPROT DATA AND TOKENIZATION

In [None]:
import torch
import sys
import torch.nn as nn
import math
import urllib.request as request
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *
from contextlib import closing
from pathlib import Path
import numpy as np 
import torch.optim as optim
import pickle
from torch.optim import lr_scheduler
import copy
import time
from torch.nn import Linear, GRU, Conv2d, Dropout, MaxPool2d, BatchNorm1d
from torch.nn.functional import relu, elu, relu6, sigmoid, tanh, softmax
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import re
import tape
from tape import ProteinBertModel
from tape import ProteinConfig
import regex as re
from sklearn import metrics


# Seed the global PyTorch and NumPy random number generators so that the
# random draws made further down (e.g. the np.random.permutation calls used
# for subsampling) give the same result on every run.
torch.manual_seed(0)
np.random.seed(0)


def write_test_set(dataset, data_path, glyc_type, max_sequences=1000):
    """Write the test partition to a FASTA file and an ``.npz`` target file.

    The files are meant to be fed to the netO / netN models.

    Args:
        dataset: object exposing ``inputs`` and ``labels`` (indexable, each
            item convertible with ``.data.numpy()``) and supporting ``len()``.
        data_path: directory (``pathlib.Path``) the two files are written to.
        glyc_type: string inserted in the file names
            (``test_seqs_<glyc_type>.fsa`` / ``test_tars_<glyc_type>.npz``).
        max_sequences: upper bound on the number of sequences written.
            Defaults to 1000, the count that used to be hard-coded.
    """
    inp_tokenizer = TAPETokenizer(vocab='iupac')
    tar_tokenizer = TAPETokenizer(vocab='glycolysation')
    special_tokens = ("<pad>", "<mask>", "<cls>", "<sep>", "<unk>")

    # Never index past the end of a partition smaller than max_sequences.
    n_sequences = min(max_sequences, len(dataset))

    seqs = []
    tars = []
    for i in range(n_sequences):
        as_protein = inp_tokenizer.convert_ids_to_tokens(dataset.inputs[i].data.numpy())
        as_protein = "".join([x for x in as_protein if x not in special_tokens])
        header = ">test_seq_{}\n".format(i + 1)
        seqs.append(header + as_protein)

        as_one_hot = tar_tokenizer.convert_ids_to_tokens(dataset.labels[i].data.numpy())
        as_one_hot = np.array([x for x in as_one_hot if x not in special_tokens], dtype=int)
        tars.append(as_one_hot)

    with open(data_path / "test_seqs_{}.fsa".format(glyc_type), "w") as seq_file:
        for seq in seqs:
            print(seq, end="\n", file=seq_file)

    # Targets have one entry per residue, so they are usually ragged. Recent
    # NumPy versions refuse to build an array from ragged input implicitly;
    # fall back to an explicit 1-D object array in that case.
    try:
        tar_array = np.asarray(tars)
    except ValueError:
        tar_array = np.empty(len(tars), dtype=object)
        for k, tar in enumerate(tars):
            tar_array[k] = tar
    np.savez_compressed(data_path / "test_tars_{}.npz".format(glyc_type), tar_array)


def write_token_dataset_to_original(dataset):
    """Convert a tokenized partition back to sequence strings and int targets.

    Returns a pair ``(seqs, tars)``: ``seqs`` holds one string per dataset
    entry (input tokens joined, special tokens removed) and ``tars`` one
    integer numpy array per entry (target tokens, special tokens removed).
    """
    inp_tokenizer = TAPETokenizer(vocab='iupac')
    tar_tokenizer = TAPETokenizer(vocab='glycolysation')
    special_tokens = ("<pad>", "<mask>", "<cls>", "<sep>", "<unk>")

    def strip_special(tokens):
        # keep only tokens that are not tokenizer bookkeeping symbols
        return [tok for tok in tokens if tok not in special_tokens]

    seqs, tars = [], []
    for idx in range(len(dataset)):
        residue_tokens = inp_tokenizer.convert_ids_to_tokens(dataset.inputs[idx].data.numpy())
        seqs.append("".join(strip_special(residue_tokens)))

        label_tokens = tar_tokenizer.convert_ids_to_tokens(dataset.labels[idx].data.numpy())
        tars.append(np.array(strip_special(label_tokens), dtype=int))

    return seqs, tars


def concat_data(inp, tar, inp_non, tar_non):
    """
    Append a random subset of the non-glyco rows to the glyco rows.

    At most ``len(inp)`` rows are drawn (without replacement, via the global
    NumPy RNG) from ``inp_non`` / ``tar_non``; the same row indices are used
    for both so inputs and targets stay aligned. Returns ``(inp, tar)`` with
    the drawn rows stacked underneath the originals.
    """
    n_glyco = len(inp)
    chosen_rows = np.random.permutation(len(inp_non))[:n_glyco]
    merged_inp = np.concatenate((inp, inp_non[chosen_rows, :]), axis=0)
    merged_tar = np.concatenate((tar, tar_non[chosen_rows, :]), axis=0)
    return merged_inp, merged_tar


def save_tokenized_data(inp, tar, save_path):
    """Split tokenized arrays 80/10/10 and store the partitions in one npz.

    ``inp`` and ``tar`` are numpy arrays; they are wrapped as tensors, handed
    to ``construct_datasets`` and the resulting train / validation /
    prediction partitions are stacked back into arrays and written to
    ``save_path`` under the keys ``<train|val|pred>_<inp|tar>_seq``.
    """
    print("Saving tokenized dataset in train, val and prediction partitions, to {}".format(save_path))
    train, validation, prediction = construct_datasets(
        torch.from_numpy(inp),
        torch.from_numpy(tar),
        Dataset,
        p_train=0.8,
        p_val=0.1,
        p_test=0.1,
    )

    arrays = {}
    for prefix, partition in (("train", train), ("val", validation), ("pred", prediction)):
        # every partition item is an (input, target) pair of tensors
        arrays[f"{prefix}_inp_seq"] = np.vstack([pair[0].data.numpy() for pair in partition])
        arrays[f"{prefix}_tar_seq"] = np.vstack([pair[1].data.numpy() for pair in partition])

    np.savez_compressed(save_path, **arrays)


# Static configuration: download URLs, on-disk locations and run settings.
ROOT_DIR = Path.cwd()
full_data_url = r'ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.xml.gz'
glyc_only_data_url = r'https://www.uniprot.org/uniprot/?query=annotation:(type:carbohyd)&fil=reviewed%3Ayes&format=xml&compress=yes'

# Every artefact lives in <cwd>/data; create the directory on first use.
datapath = ROOT_DIR / 'data'
datapath.mkdir(exist_ok=True)

# raw downloads
full_data_path = datapath / 'uniprot_sprot.xml.gz'
glyc_only_data_path = datapath / 'glyconly_uniprot_sprot.xml.gz'

# tokenized data: full dump, then the partitioned N- and O-glyc sets
tokenized_data_path = datapath / 'tokenized_seq_data.npz'
tokenized_n_glyc_path = datapath / 'n_glyc_data.npz'
tokenized_o_glyc_path = datapath / 'o_glyc_data.npz'

# BERT embedded versions of the partitioned sets
bert_embedded_n_glyc_path = datapath / 'bert_embedded_n_glyc_data.npz'
bert_embedded_o_glyc_path = datapath / 'bert_embedded_o_glyc_data.npz'

loss_data_path = datapath / 'loss_data.npz'
checkpoint_path = datapath / 'model_checkpoint.pt'

glyc_type = "n"
batch_size = 1
# cap on sequence length, to avoid inflation of tensors by rare large proteins
max_sequence_length = 2000

def _download_if_missing(url, destination):
    """Fetch ``url`` into ``destination`` unless that file already exists.

    The payload is streamed to a sibling ``<name>.part`` file and renamed
    only after the copy has finished. An interrupted download therefore
    never leaves a truncated file at ``destination`` that a later run would
    mistake for a complete one.

    Args:
        url: address passed to ``urllib.request.urlopen``.
        destination: ``pathlib.Path`` of the final file.
    """
    if destination.is_file():
        return
    print("Downloading {}".format(url))
    partial_path = destination.with_name(destination.name + ".part")
    with closing(request.urlopen(url)) as response, open(partial_path, 'wb') as out_file:
        shutil.copyfileobj(response, out_file)
    partial_path.replace(destination)


# Get Swiss-Prot data if not already downloaded
_download_if_missing(full_data_url, full_data_path)

# glycosylated proteins
_download_if_missing(glyc_only_data_url, glyc_only_data_path)

# Tokenize the raw Swiss-Prot dump once and cache the padded arrays on disk;
# on later runs the cached file is used directly.
if tokenized_data_path.is_file():
    print("Loading tokenized dataset {}".format(tokenized_data_path))
else:
    print("Tokenizing dataset from {}".format(full_data_path))

    dataset = load_unencoded_data(full_data_path)
    # we filter very long proteins which inflate tensor size
    inp_n, inp_o, inp_non, tar_n, tar_o, tar_non = tokenize_dataset(
        dataset,
        max_length=max_sequence_length,
    )

    # pad every array to the common length and write them in one archive
    np.savez_compressed(
        tokenized_data_path,
        inp_seq_n=pad_sequences(inp_n, max_length=max_sequence_length),
        tar_seq_n=pad_sequences(tar_n, max_length=max_sequence_length),
        inp_seq_o=pad_sequences(inp_o, max_length=max_sequence_length),
        tar_seq_o=pad_sequences(tar_o, max_length=max_sequence_length),
        inp_seq_non=pad_sequences(inp_non, max_length=max_sequence_length),
        tar_seq_non=pad_sequences(tar_non, max_length=max_sequence_length),
    )

# From here on both paths are identical: read the arrays back from the cache.
data_load = np.load(tokenized_data_path, allow_pickle=True)

# glycosylated protein data
inp_o, tar_o = data_load['inp_seq_o'], data_load['tar_seq_o']
# subsample n-glyc dataset as it is too large to handle
inp_n = data_load['inp_seq_n'][0:3000]
tar_n = data_load['tar_seq_n'][0:3000]

# non glycosylated protein data
inp_non, tar_non = data_load['inp_seq_non'], data_load['tar_seq_non']

# concatenating glyco and non glyco protein data (O first, then N, so the
# random draws happen in a fixed order)
inp_o, tar_o = concat_data(inp_o, tar_o, inp_non, tar_non)
inp_n, tar_n = concat_data(inp_n, tar_n, inp_non, tar_non)

# save training, validation and prediction datasets
if not tokenized_n_glyc_path.is_file():
    save_tokenized_data(inp_n, tar_n, tokenized_n_glyc_path)
if not tokenized_o_glyc_path.is_file():
    save_tokenized_data(inp_o, tar_o, tokenized_o_glyc_path)

# BERT PROTEIN EMBEDDING

In [None]:
### Custom BERT BOY ###
class ProteinBERTBoy(nn.Module):
    """Thin wrapper around the pretrained TAPE 'bert-base' protein model.

    ``forward`` returns only the first of the two values the wrapped model
    produces.
    """

    def __init__(self):
        super().__init__()
        bert = ProteinBertModel.from_pretrained('bert-base')
        bert.max_length_embedding = max_sequence_length
        self.bertbase = bert

    def forward(self, source_seq):
        encoded_protein, _unused = self.bertbase(source_seq)
        return encoded_protein


def BertEncode(BertModel, loader, number_of_proteins):
    """
    Use pretrained protein BERT model to encode tokenized protein data.
    set number_of_proteins to None to encode all proteins.

    Args:
        BertModel: callable mapping a batch of token ids to its embedding.
        loader: iterable of ``(src, tgt)`` tensor pairs; each pair is
            squeezed, so it is expected to hold a single protein.
        number_of_proteins: stop after this many proteins (``None`` = all).

    Returns:
        ``(labels, encoded_proteins, protein_lengths)`` - three lists with one
        entry per encoded protein: the label array, the embedding array and
        the number of label positions whose id is 5 or 6.
    """
    # Ids 5 and 6 are the two per-residue classes used in the rest of this
    # file (non-glycosylated / glycosylated); any other id is not counted.
    labelled_position_ids = (5, 6)
    labels = list()
    encoded_proteins = list()
    protein_lengths = list()
    with torch.no_grad():
        for count, (src, tgt) in enumerate(loader):
            if count == number_of_proteins:
                break
            print(f"Bert encoding protein: {count}")
            encoded_protein = BertModel(src).squeeze().numpy()
            label = tgt.squeeze().numpy()

            # Count over the label's own length instead of a module-level
            # constant, so shorter or longer label vectors are handled too.
            protein_lengths.append(int(np.isin(label, labelled_position_ids).sum()))

            encoded_proteins.append(encoded_protein)
            labels.append(label)
    return labels, encoded_proteins, protein_lengths


def bert_encode_dataset(data_load_path, save_path):
    """Embed the train / val / pred partitions of a tokenized npz with BERT.

    The three partitions are read from ``data_load_path``, run through the
    module-level ``BertModel_instance`` and written to ``save_path`` under
    ``<prefix>_inp_seq``, ``<prefix>_tar_seq`` and
    ``<prefix>_protein_lengths``. The same nine objects are returned as a
    tuple ordered (labels, embeddings, lengths) for train, val and pred.
    """
    data_load = np.load(data_load_path, allow_pickle=True)

    # (npz key prefix, name used in the progress message)
    partitions = (("train", "train"), ("val", "validation"), ("pred", "prediction"))

    # build all loaders up front, then encode them one after the other
    loaders = {}
    for prefix, _ in partitions:
        partition = Dataset(
            torch.from_numpy(data_load[f'{prefix}_inp_seq']),
            torch.from_numpy(data_load[f'{prefix}_tar_seq']),
        )
        loaders[prefix] = data.DataLoader(partition, batch_size=batch_size, shuffle=False)

    to_save = {}
    results = []
    for prefix, display_name in partitions:
        print(f"BERT encoding {display_name} dataset")
        part_labels, part_embedded, part_lengths = BertEncode(BertModel_instance, loaders[prefix], None)
        to_save[f'{prefix}_inp_seq'] = part_embedded
        to_save[f'{prefix}_tar_seq'] = part_labels
        to_save[f'{prefix}_protein_lengths'] = part_lengths
        results.extend((part_labels, part_embedded, part_lengths))

    print("Saving BERT embedded dataset in train, val and prediction partitions, to {}".format(save_path))
    np.savez_compressed(save_path, **to_save)
    return tuple(results)

# Instantiate the pretrained encoder and switch off gradients for all of its
# parameters (it is only used for inference here).
BertModel_instance = ProteinBERTBoy()
for param in BertModel_instance.parameters():
    param.requires_grad_(False)

# Create the BERT embedded datasets, skipping any that already exist on disk.
if not bert_embedded_n_glyc_path.is_file():
    encoded_n_data = bert_encode_dataset(tokenized_n_glyc_path, bert_embedded_n_glyc_path)

if not bert_embedded_o_glyc_path.is_file():
    encoded_o_data = bert_encode_dataset(tokenized_o_glyc_path, bert_embedded_o_glyc_path)

# Downsampling training and validation datasets (only used in BERT + FFN)

In [None]:
# NOTE(review): these rebind the names that the BERT encoding cell may have
# assigned its results to. Presumably the intent is to drop the large
# in-memory arrays before the downsampling step reloads them from disk -
# confirm before removing.
encoded_o_data =  "foo"
encoded_n_data =  "foo"
def downsampling(labels, encoded_proteins, embedding_size=768):
    """
    Downsampling the dataset to have equally many
    glycosylated and non-glycosylated positions.

    Args:
        labels: array of per-position label ids; 6 marks a glycosylated
            position, 5 a non-glycosylated one, anything else is ignored
            (padding).
        encoded_proteins: array holding one embedding vector per position;
            it is reshaped to ``(-1, embedding_size)`` and must line up with
            the flattened ``labels``.
        embedding_size: length of one embedding vector (default 768).

    Returns:
        ``(dsampled_encodings, dsampled_labels)``: the selected
        non-glycosylated embeddings followed by all glycosylated ones, and a
        matching int vector of 0s followed by 1s. If there are fewer
        non-glycosylated than glycosylated positions, all of the former are
        kept and the label vector is sized accordingly.
    """
    flat_labels = np.asarray(labels).ravel()
    encoded_positions = np.asarray(encoded_proteins).reshape(-1, embedding_size)

    # Boolean masks keep the 2-D shape even when a class is empty, which a
    # list of rows converted with np.asarray would not.
    glyco_encodings = encoded_positions[flat_labels == 6]
    non_glyco_encodings = encoded_positions[flat_labels == 5]
    glycosylated_positions = len(glyco_encodings)

    rand_idx = np.random.permutation(len(non_glyco_encodings))[0:glycosylated_positions]
    dsampled_non_glyco_encodings = non_glyco_encodings[rand_idx]

    dsampled_encodings = np.concatenate([dsampled_non_glyco_encodings, glyco_encodings])
    # Size the negative labels by what was actually drawn so encodings and
    # labels always have the same length.
    dsampled_labels = np.concatenate([
        np.zeros(len(dsampled_non_glyco_encodings), dtype=int),
        np.ones(glycosylated_positions, dtype=int),
    ])

    return dsampled_encodings, dsampled_labels

def downsample_sequences(bert_encoded_data_path, save_path):
    """
    Run downsampling on train and validation partitions.

    Args:
        bert_encoded_data_path: ``pathlib.Path`` of an npz archive containing
            ``train_inp_seq``, ``train_tar_seq``, ``val_inp_seq`` and
            ``val_tar_seq``.
        save_path: where the downsampled arrays are written (same four keys).

    Returns:
        ``(train_encodings, train_labels, val_encodings, val_labels)`` as
        produced by ``downsampling``.

    Raises:
        FileNotFoundError: if ``bert_encoded_data_path`` does not exist.
        SystemExit: if the archive lacks one of the expected keys.
    """
    # Fail with a clear error instead of running into an unbound local below.
    if not bert_encoded_data_path.is_file():
        raise FileNotFoundError(
            f"BERT embedded data file not found: {bert_encoded_data_path}"
        )

    data_load = np.load(bert_encoded_data_path, allow_pickle=True)
    try:
        train_seqs = data_load['train_inp_seq']
        train_labels = data_load['train_tar_seq']

        val_seqs = data_load['val_inp_seq']
        val_labels = data_load['val_tar_seq']
    except KeyError:
        sys.exit(
            "Object from data file could not be loaded. Please ensure that you are "
            "loading BERT embedded data file generated from above cells."
        )

    train_seq_pos_dsampled, train_label_pos_dsampled = downsampling(train_labels, train_seqs)
    val_seq_pos_dsampled, val_label_pos_dsampled = downsampling(val_labels, val_seqs)

    np.savez_compressed(
        save_path,
        train_inp_seq=train_seq_pos_dsampled,
        train_tar_seq=train_label_pos_dsampled,
        val_inp_seq=val_seq_pos_dsampled,
        val_tar_seq=val_label_pos_dsampled,
    )
    return (
        train_seq_pos_dsampled,
        train_label_pos_dsampled,
        val_seq_pos_dsampled,
        val_label_pos_dsampled,
    )

# Output locations of the position-level, class-balanced datasets.
downsampled_encoded_glyco_position_data_path_n = datapath / 'dsampled_BERT_encoded_n_glyc_data_t.npz'
downsampled_encoded_glyco_position_data_path_o = datapath / 'dsampled_BERT_encoded_o_glyc_data_t.npz'

# Each dataset is only (re)built when its output file is absent.
if not downsampled_encoded_glyco_position_data_path_n.is_file():
    print("Downsampling N-glycosylated dataset")
    downsampled_n_data = downsample_sequences(
        bert_embedded_n_glyc_path,
        downsampled_encoded_glyco_position_data_path_n,
    )

if not downsampled_encoded_glyco_position_data_path_o.is_file():
    print("Downsampling O-glycosylated dataset")
    downsampled_o_data = downsample_sequences(
        bert_embedded_o_glyc_path,
        downsampled_encoded_glyco_position_data_path_o,
    )

# One Hot/tokenized Encoding

In [None]:
def tokenized_to_one_hot(data_load_path, save_path):
    """One-hot encode the inputs of a partitioned, tokenized npz archive.

    For each of the train / val / pred partitions the input token ids are
    expanded to one-hot vectors with 30 classes; the targets are copied
    unchanged. The result is written to ``save_path`` under the same keys
    as the source archive.
    """
    data_load = np.load(data_load_path, allow_pickle=True)

    arrays = {}
    for split in ("train", "val", "pred"):
        token_ids = torch.from_numpy(data_load[f'{split}_inp_seq'])
        arrays[f'{split}_inp_seq'] = torch.nn.functional.one_hot(token_ids, num_classes=30)
        arrays[f'{split}_tar_seq'] = torch.from_numpy(data_load[f'{split}_tar_seq'])

    np.savez_compressed(save_path, **arrays)
    
# Output locations of the one-hot encoded datasets.
one_hot_path_o = datapath / 'one_hot_o_glyc_data.npz'
one_hot_path_n = datapath / 'one_hot_n_glyc_data.npz'

# Build each archive only when it is not on disk yet.
if not one_hot_path_o.is_file():
    tokenized_to_one_hot(tokenized_o_glyc_path, one_hot_path_o)

if not one_hot_path_n.is_file():
    tokenized_to_one_hot(tokenized_n_glyc_path, one_hot_path_n)

# Models: Training and Prediction

In [None]:
def write_token_dataset_to_original(dataset):
    """Turn a tokenized partition back into sequence strings and int targets.

    Nothing is written to disk: the function returns ``(seqs, tars)`` where
    ``seqs[i]`` is the i-th input with special tokens removed and joined to a
    string, and ``tars[i]`` is the i-th target as an integer numpy array with
    special tokens removed.
    """
    special_tokens = ("<pad>", "<mask>", "<cls>", "<sep>", "<unk>")
    inp_tokenizer = TAPETokenizer(vocab='iupac')
    tar_tokenizer = TAPETokenizer(vocab='glycolysation')

    seqs = []
    tars = []
    for position in range(len(dataset)):
        input_ids = dataset.inputs[position].data.numpy()
        residues = [
            token
            for token in inp_tokenizer.convert_ids_to_tokens(input_ids)
            if token not in special_tokens
        ]
        seqs.append("".join(residues))

        label_ids = dataset.labels[position].data.numpy()
        kept_labels = [
            token
            for token in tar_tokenizer.convert_ids_to_tokens(label_ids)
            if token not in special_tokens
        ]
        tars.append(np.array(kept_labels, dtype=int))

    return seqs, tars


def train_model(model, criterion, optimizer, scheduler, train_loader, val_loader, epochs=20):
    """Train ``model`` and return it with the best validation weights loaded.

    Each epoch runs a training pass followed by a validation pass. The
    weights of the epoch with the lowest validation loss are restored before
    returning. If no epoch improves on the starting point (e.g. ``epochs`` is
    0 or the validation loss is never a finite number) the weights the model
    had on entry are restored.

    Relies on the module-level names ``device``, ``t_dataset_size`` and
    ``v_dataset_size``.

    Args:
        model: the network to train.
        criterion: loss called as ``criterion(output, y)``.
        optimizer: optimizer stepping the model's parameters.
        scheduler: learning-rate scheduler, stepped once per epoch after the
            training pass.
        train_loader / val_loader: iterables of ``(x, y)`` batches.
        epochs: number of epochs.
    """
    best_loss = float('inf')
    # Snapshot the incoming weights so there is always a state to restore.
    best_weights = copy.deepcopy(model.state_dict())

    for epoch in range(epochs):
        print(f'Epoch:{epoch + 1}/{epochs}')
        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                model.train()
                loader = train_loader
                dataset_size = t_dataset_size
            else:
                model.eval()
                loader = val_loader
                dataset_size = v_dataset_size
            cur_loss = 0.0
            n_correct = 0

            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                # reset parameter gradients
                optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    output = model(x)
                    _, preds = torch.max(output, 1)
                    loss = criterion(output, y)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # stats - kept as Python numbers so the division below is a
                # true division on every PyTorch version
                cur_loss += loss.item()
                n_correct += torch.sum(preds == y).item()

            epoch_loss = cur_loss / dataset_size
            epoch_acc = n_correct / dataset_size

            if is_training:
                scheduler.step()
            elif epoch_loss < best_loss:
                # save best model
                best_loss = epoch_loss
                best_weights = copy.deepcopy(model.state_dict())

            # print stats
            print(f'{phase} loss: {epoch_loss} accuracy: {epoch_acc}')

    model.load_state_dict(best_weights)
    return model


def predict_glycosylation_from_embedded_FFN(dataset, positions_to_predict, the_model):
    """Predict glycosylation for single embedded positions with an FFN.

    Args:
        dataset: dataset of ``(embedding, label)`` pairs; it is iterated one
            item at a time. Label 5 means not glycosylated, 6 glycosylated,
            any other label is treated as padding and skipped.
        positions_to_predict: stop after this many dataset items have been
            visited (skipped padding items count too); ``None`` visits all.
        the_model: model returning one score per class for an embedding.

    Returns:
        numpy array with four rows and one column per predicted position:
        probability assigned to glycosylation, target (0/1), predicted class
        and probability assigned to the correct class.
    """
    # Use the dataset that was passed in rather than a module-level global.
    pred_loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    the_model.eval()

    label_to_target = {5: 0, 6: 1}  # no glycosylation / glycosylation

    probs_for_true_class = list()
    probs_for_glyc_only = list()  # needed for ROC and AUC curve
    predictions = list()  # needed for confusion matrix
    targets = list()
    softmax_function = nn.Softmax(dim=1)

    with torch.set_grad_enabled(False):
        for count, (AA, tgt) in enumerate(pred_loader):
            if count == positions_to_predict:
                break

            target_index = label_to_target.get(tgt.item())
            if target_index is None:
                # ignore padding
                continue
            targets.append(target_index)

            output = the_model(AA)
            _, preds = torch.max(output, 1)
            # store a plain int so all four result rows are 1-D and stackable
            predictions.append(preds.item())
            probs = softmax_function(output).squeeze()
            probs_for_true_class.append(probs[target_index].item())
            probs_for_glyc_only.append(probs[1].item())

    output_predictions = np.vstack([np.asarray(probs_for_glyc_only), np.asarray(targets, dtype=int),
                                    np.asarray(predictions), np.asarray(probs_for_true_class)])

    return output_predictions

def train_model_transformer(model, criterion, optimizer, scheduler, train_loader, val_loader, pad_index, epochs=20):
    """Train the transformer model and return it with the best weights loaded.

    Every epoch runs a training and a validation pass. The model is called as
    ``model(x[:, 1:], y[:, :-1])`` and scored against ``y[:, 1:]``. Accuracy
    is computed over the positions of ``y[:, 1:]`` that differ from
    ``pad_index``. Loss and accuracy are averaged over batches. The weights
    of the epoch with the lowest validation loss are restored before
    returning; if no epoch improves on the starting point (e.g. ``epochs`` is
    0) the weights the model had on entry are restored.

    Relies on the module-level name ``device``.

    Args:
        model: the network to train.
        criterion: loss called as ``criterion(output.transpose(1, 2), y[:, 1:])``.
        optimizer: optimizer stepping the model's parameters.
        scheduler: scheduler stepped with the mean training loss each epoch.
        train_loader / val_loader: iterables of ``(x, y)`` batches.
        pad_index: target id that marks padding.
        epochs: number of epochs.
    """
    best_loss = float('inf')
    # Snapshot the incoming weights so there is always a state to restore.
    best_weights = copy.deepcopy(model.state_dict())

    for epoch in range(epochs):
        print(f'Epoch:{epoch + 1}/{epochs}')
        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                model.train()
                loader = train_loader
            else:
                model.eval()
                loader = val_loader
            cur_loss = 0
            running_acc = 0
            n_batches = 0

            for x, y in loader:
                n_batches += 1
                x = x.to(device)
                y = y.to(device)
                # reset parameter gradients
                optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    output = model(x[:, 1:], y[:, :-1])

                    loss = criterion(output.transpose(1, 2), y[:, 1:])

                    pad_mask = y[:, 1:] != pad_index
                    preds = output.argmax(2)[pad_mask]
                    acc = (preds == y[:, 1:][pad_mask]).sum().to(dtype=torch.float) / preds.shape[0]

                    if is_training:
                        loss.backward()
                        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
                        optimizer.step()

                # stats
                cur_loss += loss.item()
                running_acc += acc.item()

            epoch_loss = cur_loss / n_batches
            epoch_acc = running_acc / n_batches

            if is_training:
                scheduler.step(epoch_loss)
            elif epoch_loss < best_loss:
                # save best model
                best_loss = epoch_loss
                best_weights = copy.deepcopy(model.state_dict())

            # print stats
            print(f'{phase} loss: {epoch_loss} accuracy: {epoch_acc}')

    model.load_state_dict(best_weights)
    return model

def predict_glycosylation_from_transformer(prediction_dataset, model, device):
    """Run the transformer over every sequence of a dataset.

    For each ``(seq, tgt)`` pair the first position is dropped and the
    sequence is cut at the first occurrence of the end label (3), or left at
    full length when there is none. The remaining positions are predicted
    with ``predict_sequence_glycosylation_transformer``.

    Returns a numpy array with four rows and one column per position:
    probability assigned to glycosylation (label 6), binary target, binary
    prediction and probability assigned to the position's target label.
    """
    model.eval()

    start_label, end_label = 2, 3
    negative_label, positive_label = 5, 6

    glyc_probs = []        # needed for ROC and AUC curve
    targets = []
    predictions = []       # needed for confusion matrix
    true_class_probs = []

    for i, (seq, tgt) in enumerate(prediction_dataset):
        print(f"Sequence {i+1} / {len(prediction_dataset)}")

        # Remove start token & padding
        end_positions = np.where(seq == end_label)[0]
        end_index = end_positions[0] if end_positions.shape[0] > 0 else seq.shape[0]
        seq = seq[1:end_index]
        tgt = tgt[1:end_index]

        targets.extend([1 if label == positive_label else 0 for label in tgt])

        preds, probs = predict_sequence_glycosylation_transformer(model, seq, start_label, device)
        predictions.extend([1 if label == positive_label else 0 for label in preds])

        # pick, per position, the probability of the target label and of
        # the positive label
        rows = torch.arange(0, probs.shape[0])
        true_class_probs.extend(probs[rows, tgt].tolist())
        glyc_probs.extend(probs[rows, positive_label].tolist())

    output_predictions = np.vstack([
        np.asarray(glyc_probs),
        np.asarray(targets, dtype=int),
        np.asarray(predictions, dtype=int),
        np.asarray(true_class_probs),
    ])

    return output_predictions

def predict_sequence_glycosylation_transformer(model, seq, start_label, device):
    """Greedily decode one label per position of ``seq``.

    Decoding starts from ``start_label``; at every step the model is called
    with the whole source and the labels produced so far, and the arg-max of
    its last output position becomes the next label.

    Returns ``(labels, probs)``: the decoded labels (without the start label)
    as a list of ints and a tensor with one softmax distribution per step.
    """
    decoded = [start_label]
    step_probs = []

    with torch.no_grad():
        source = torch.tensor(seq, device=device).unsqueeze(0)
        n_steps = source.shape[1]

        for _ in range(n_steps):
            history = torch.tensor(decoded, device=device).unsqueeze(0)
            # scores of the most recent position of the single batch entry
            last_scores = model(source, history)[0, -1, :]

            decoded.append(last_scores.argmax().item())
            step_probs.append(last_scores.softmax(dim=0))

    return decoded[1:], torch.stack(step_probs)

def cm_analysis(y_true, y_pred, labels, ymap=None, figsize=(10,10)):
    """
    Plot an annotated confusion matrix and save the image to disk.

    Adapted from https://gist.github.com/hitvoice/36cf44689065ca9b927431546381a3f7
    args:
      y_true:    true label of the data, with shape (nsamples,)
      y_pred:    prediction of the data, with shape (nsamples,)
      labels:    names put on the rows and columns of the plotted matrix,
                 with shape (nclass,). They are not passed to
                 ``confusion_matrix``, so they only name the axes.
      ymap:      dict: any -> string, length == nclass.
                 if not None, y_true, y_pred and labels are mapped through it.
      figsize:   the size of the figure plotted.

    Diagonal cells show the row percentage (rounded to 2 decimals) and the
    count, empty off-diagonal cells stay blank and the remaining cells show
    the percentage with one decimal and the count. The figure is written to
    "confusion_matrix" at 800 dpi.
    """
    if ymap is not None:
        y_pred = [ymap[value] for value in y_pred]
        y_true = [ymap[value] for value in y_true]
        labels = [ymap[value] for value in labels]

    counts = confusion_matrix(y_true, y_pred)
    row_totals = np.sum(counts, axis=1, keepdims=True)
    percentages = counts / row_totals.astype(float) * 100

    def cell_text(row, col):
        # text shown inside one heatmap cell
        count = counts[row, col]
        share = percentages[row, col]
        if row == col:
            return f'{round(share,2)}%\n{count}'
        if count == 0:
            return ''
        return '%.1f%%\n%d' % (share, count)

    annot = np.empty_like(counts).astype(str)
    for row, col in np.ndindex(*counts.shape):
        annot[row, col] = cell_text(row, col)

    frame = pd.DataFrame(counts, index=labels, columns=labels)
    frame.index.name = 'True'
    frame.columns.name = 'Predicted'

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(frame, annot=annot, fmt='', ax=ax, cmap='OrRd', linewidths = 1.5, linecolor='black')
    plt.savefig("confusion_matrix", dpi=800)


def predict_glycosylation(prob_glyc, threshold = 0.5):
    """Binarise glycosylation probabilities.

    Used for the LSTM predictions, whose output is the probability of
    glycosylation. Returns a list with 1 for every probability that is at
    least ``threshold`` and 0 otherwise.
    """
    return [1 if prob >= threshold else 0 for prob in prob_glyc]


def parse_output_file(path, new_protein_regex, glyc_pred_regex, no_match_string):
    """Parse a netN / netO output file into per-protein score vectors.

    Args:
        path: file to read.
        new_protein_regex: pattern whose group 1 is the protein length; a
            match starts a fresh all-zero score vector of that length.
        glyc_pred_regex: pattern whose group 1 is a 1-based position and
            group 2 its score; the score is stored at that position.
        no_match_string: exact (stripped) line that closes a protein without
            any predicted site.

    A protein's vector is emitted when a line starting with "----" follows at
    least one parsed prediction, or when a line equals ``no_match_string``.

    Returns:
        object-dtype numpy array holding the emitted vectors in file order.
    """
    protein_pattern = re.compile(new_protein_regex)
    prediction_pattern = re.compile(glyc_pred_regex)

    net_preds = []
    scores = None
    seen_prediction = False

    with open(path, "r") as handle:
        for raw_line in handle:
            line = raw_line.strip()

            header = protein_pattern.search(line)
            if header is not None:
                scores = np.zeros(int(header.group(1)))

            prediction = prediction_pattern.search(line)
            if prediction is not None:
                scores[int(prediction.group(1)) - 1] = float(prediction.group(2))
                seen_prediction = True

            block_finished = line.startswith("----") and seen_prediction
            if block_finished or line == no_match_string:
                net_preds.append(scores)
                scores = None
                seen_prediction = False

    return np.array(net_preds, dtype=object)
    
    
# Locations used by the model cells; everything sits below the working dir.
ROOT_DIR = Path.cwd()
datapath = ROOT_DIR / 'data'

# raw downloads
full_data_path = datapath / 'uniprot_sprot.xml.gz'
glyc_only_data_path = datapath / 'glyconly_uniprot_sprot.xml.gz'

# tokenized and BERT embedded datasets
encoded_data_path = datapath / 'encode_and_decode_seq_data.npz'
encoded_n_glyc_path = datapath / 'n_glyc_data.npz'
encoded_o_glyc_path = datapath / 'o_glyc_data.npz'
bert_embedded_n_glyc_path = datapath / 'bert_embedded_n_glyc_data.npz'
bert_embedded_o_glyc_path = datapath / 'bert_embedded_o_glyc_data.npz'

loss_data_path = datapath / 'loss_data.npz'
checkpoint_path = datapath / 'model_checkpoint.pt'

# trained models go to <cwd>/models, created if absent
model_save_dir = ROOT_DIR / 'models'
model_save_dir.mkdir(exist_ok=True)

# Everything runs on the CPU; use torch.device("cuda") here to run on a GPU.
device = torch.device("cpu")

# Transformer O glyc

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    The encoding table is stored in the buffer ``pe`` with shape
    ``[max_seq_len, 1, embedding_size]``. ``forward`` slices it with
    ``x.shape[0]``, so dim 0 of the input is treated as the sequence
    position and the table broadcasts over dim 1.

    Args:
        embedding_size: Size of the last dimension of the inputs. Odd sizes
            are supported (the last sine column has no cosine partner).
        dropout: Dropout probability applied after adding the encoding.
        max_seq_len: Number of positions in the precomputed table.
    """

    def __init__(self, embedding_size, dropout=0.1, max_seq_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embedding_size, 2).float() * (-math.log(10000.0) / embedding_size)
        )
        angles = position * div_term  # [max_seq_len, ceil(embedding_size / 2)]

        pe = torch.zeros(max_seq_len, embedding_size)
        pe[:, 0::2] = torch.sin(angles)
        # For odd embedding sizes there is one cosine column fewer than sine
        # columns, so trim the angles to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(angles[:, : embedding_size // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_seq_len, 1, embedding_size]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:x.shape[0]])``."""
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)


class BERTEncoder(nn.Module):
    """Wrapper around the pretrained TAPE ``bert-base`` protein model."""

    def __init__(self):
        super().__init__()
        self.bert = tape.ProteinBertModel.from_pretrained('bert-base')

    def freeze_encoder(self):
        """Switch off gradient tracking for every parameter of this module."""
        for weight in self.parameters():
            weight.requires_grad_(False)

    def forward(self, src):
        """Return the first of the two outputs of the wrapped model.

        Per the original author's notes ``src`` is
        ``[batch_size, seq_length]`` and the result is
        ``[batch_size, seq_length, embed_size]``.
        """
        sequence_embedding, _ = self.bert(src)
        return sequence_embedding


class TransformerBERTDecoder(nn.Module):
    """Transformer decoder that attends over precomputed memory embeddings.

    Target tokens are embedded, scaled by ``sqrt(bert_embedding_size)``,
    given positional encodings and decoded against ``mem_embed`` with a
    causal mask and a padding mask; a linear layer maps the decoder output
    to ``trg_vocab_size`` logits.

    Args:
        trg_vocab_size: Number of target tokens (embedding rows and logits).
        num_heads: Attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward network.
        dropout: Dropout probability for the layers and positional encoding.
        trg_pad_idx: Target id treated as padding in the key padding mask.
        max_seq_len: Number of positions in the positional-encoding table.
        device: Default device for the causal mask and for
            ``change_max_seq_len``.
        bert_embedding_size: Model width; must match the last dimension of
            ``mem_embed``.
    """

    def __init__(
        self,
        trg_vocab_size,
        num_heads,
        num_layers,
        dim_feedforward,
        dropout,
        trg_pad_idx,
        max_seq_len,
        device,
        bert_embedding_size=768
    ):
        super().__init__()

        self.bert_embedding_size = bert_embedding_size
        self.sqrt_embedding_size = math.sqrt(self.bert_embedding_size)
        self.dropout = dropout
        self.trg_pad_idx = trg_pad_idx
        self.device = device

        self.trg_word_embedding = nn.Embedding(trg_vocab_size, self.bert_embedding_size)
        self.trg_position_embedding = PositionalEncoding(
            self.bert_embedding_size, self.dropout, max_seq_len
        )

        # Kept as an attribute (and therefore in the state dict) so existing
        # saved models keep loading.
        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.bert_embedding_size, nhead=num_heads,
            dim_feedforward=dim_feedforward, dropout=self.dropout
        )

        self.decoder_norm = nn.LayerNorm(normalized_shape=self.bert_embedding_size)

        self.decoder = nn.TransformerDecoder(
            decoder_layer=self.decoder_layer, num_layers=num_layers, norm=self.decoder_norm
        )

        self.out_fc = nn.Linear(
            in_features=self.bert_embedding_size,
            out_features=trg_vocab_size
        )

        self.initialize_decoder()

    def initialize_decoder(self):
        """Xavier-initialise every decoder parameter with more than one dim."""
        for param in self.decoder.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return a ``[sz, sz]`` causal mask.

        Entries on and below the diagonal are ``0``; entries above it are
        ``-inf``. The mask is created on ``device``, or on ``self.device``
        when ``device`` is ``None``.
        """
        if device is None:
            device = self.device
        mask = torch.triu(torch.ones((sz, sz), device=device)).transpose(0, 1)
        mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0)
        return mask

    def change_max_seq_len(self, max_seq_len):
        """Replace the positional encoding with one covering ``max_seq_len``."""
        self.trg_position_embedding = PositionalEncoding(
            self.bert_embedding_size, self.dropout, max_seq_len
        ).to(self.device)

    def forward(self, mem_embed, trg):
        """Decode ``trg`` against ``mem_embed`` and return per-token logits.

        Args:
            mem_embed: ``[batch_size, seq_length, embed_size]`` memory.
            trg: ``[batch_size, seq_length]`` target token ids.

        Returns:
            ``[batch_size, seq_length, trg_vocab_size]`` logits.
        """
        trg_embed = self.trg_word_embedding(trg) * self.sqrt_embedding_size

        # PositionalEncoding slices its table with dim 0 of its input, so the
        # embeddings must be sequence-first *before* the encoding is added;
        # adding it to the batch-first tensor would key positions by batch index.
        trg_embed = self.trg_position_embedding(trg_embed.transpose(0, 1))

        trg_padding_mask = (trg == self.trg_pad_idx)  # [batch_size, seq_length]

        # Build the mask next to the inputs so a stale ``self.device`` cannot
        # cause a device mismatch.
        trg_mask = self.generate_square_subsequent_mask(
            trg.shape[1], device=trg.device
        )  # [seq_length, seq_length]

        out = self.decoder(
            trg_embed,
            mem_embed.transpose(0, 1),
            tgt_mask=trg_mask,
            tgt_key_padding_mask=trg_padding_mask
        ).transpose(0, 1)  # [batch_size, seq_length, embed_size]

        # [batch_size, seq_length, out_classes]
        return self.out_fc(out)


class TransformerBERTEncoderDecoder(nn.Module):
    """End-to-end model: a ``BERTEncoder`` feeding a ``TransformerBERTDecoder``.

    All constructor arguments are forwarded unchanged to
    ``TransformerBERTDecoder``; the encoder takes no arguments.
    """

    def __init__(
        self,
        trg_vocab_size,
        num_heads,
        num_layers,
        dim_feedforward,
        dropout,
        trg_pad_idx,
        max_seq_len,
        device,
        bert_embedding_size=768
    ):
        super().__init__()

        self.encoder = BERTEncoder()
        decoder_args = (
            trg_vocab_size,
            num_heads,
            num_layers,
            dim_feedforward,
            dropout,
            trg_pad_idx,
            max_seq_len,
            device,
            bert_embedding_size,
        )
        self.decoder = TransformerBERTDecoder(*decoder_args)

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg`` against the resulting memory."""
        return self.decoder(self.encoder(src), trg)

In [None]:
### Load data in dataloaders ###
# O-glycosylation run: build the train / validation loaders consumed by the
# Transformer training cell below.
# NOTE(review): `Dataset` and `data` are not defined in this part of the
# notebook; presumably they come from `from utils import *` — confirm.
glyc_type = "o"
batch_size = 6

# Inputs and targets are read from the 'train_inp_seq' / 'train_tar_seq' keys.
train_npz = np.load(datapath / "o_embedded_train.npz", allow_pickle=True)
train = Dataset(torch.from_numpy(train_npz['train_inp_seq']), torch.from_numpy(train_npz['train_tar_seq']))
train_loader = data.DataLoader(train, batch_size=batch_size, shuffle=True)

# Validation split, read from the 'val_inp_seq' / 'val_tar_seq' keys.
# NOTE(review): the validation loader is shuffled as well.
val_npz = np.load(datapath / "o_embedded_val.npz", allow_pickle=True)
val = Dataset(torch.from_numpy(val_npz['val_inp_seq']), torch.from_numpy(val_npz['val_tar_seq']))
val_loader = data.DataLoader(val, batch_size=batch_size, shuffle=True)

In [None]:
### Training ###
# Trains the decoder-only Transformer on the O-glycosylation loaders built in
# the previous cell and optionally saves the trained model.

# Training hyperparameters
num_epochs = 20
learning_rate = 1e-6
# NOTE(review): hard-coded to CUDA; this cell fails on a machine without a GPU.
device = torch.device("cuda")
save_model = 'yes_please'  # the model is saved only when this equals 'yes_please'
timestamp = str(time.time())  # appended to the saved file name

# Model hyperparameters
num_heads = 8
num_layers = 3
dim_feedforward = 1024
dropout = 0.1
max_seq_len = 2500
trg_vocab_size = 7 # GLYC_VOCAB
trg_pad_idx = 0
trg_positive_idx = 6

# Decoder-only variant; the commented lines switch to the encoder+decoder
# model with a frozen encoder.
model = TransformerBERTDecoder(
#model = TransformerBERTEncoderDecoder(
    trg_vocab_size=trg_vocab_size,
    num_heads=num_heads,
    num_layers=num_layers,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    trg_pad_idx=trg_pad_idx,
    max_seq_len=max_seq_len,
    device=device
).to(device)
#model.encoder.freeze_encoder()

optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# Multiply the learning rate by 0.1 after 10 scheduler steps without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=10, verbose=True
)

# Class weights: 100 for index `trg_positive_idx`, 1 for every other class;
# targets equal to `trg_pad_idx` are ignored by the loss.
loss_weights = torch.ones((trg_vocab_size,), device=device)
loss_weights[trg_positive_idx] = 100
criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx, weight=loss_weights).to(device)

# NOTE(review): `train_model_transformer` is defined outside this section;
# its return value is treated as the trained model below.
trained_model = train_model_transformer(model, criterion, optimizer, scheduler, train_loader, val_loader, pad_index=trg_pad_idx, epochs=num_epochs)
if save_model == 'yes_please':
    # The whole model object is saved, not only its state dict.
    filename = 'Transformer_atop_BERT_model' + '_' + glyc_type + '-' + timestamp
    torch.save(trained_model, model_save_dir / filename)

In [None]:
### Prediction ###
# Runs the trained O-glycosylation Transformer on the prediction split, saves
# the results and plots a confusion matrix.
device = torch.device("cuda")
pred_npz = np.load(datapath / "o_embedded_pred.npz", allow_pickle=True)
cm_plot_labels = ["No O glycosylation", "O glycosylation"]
save_pred_path = datapath / 'Transformer_o_glyc_predictions'
pred = Dataset(torch.from_numpy(pred_npz['pred_inp_seq']), torch.from_numpy(pred_npz['pred_tar_seq']))

trained_model.to(device)
# TransformerBERTDecoder builds its causal mask on its own `device` attribute.
# When the trained model *is* that decoder (as in the training cell above),
# `trained_model.decoder` is the inner nn.TransformerDecoder, so the attribute
# has to be set on the model itself; only the encoder+decoder wrapper keeps
# the TransformerBERTDecoder under `.decoder`.
mask_owner = trained_model if isinstance(trained_model, TransformerBERTDecoder) else trained_model.decoder
mask_owner.device = device
out = predict_glycosylation_from_transformer(pred, trained_model, device)
pred = out[0]  # Predicted glyc probability
tar_seq = out[1]  # Target label
actual_pred = out[2]  # Predicted label
true_class_prob = out[3]  # Predicted probability of target label
# np.savez_compressed appends ".npz" to the file name.
np.savez_compressed(save_pred_path, pred=out[0], tar_seq=out[1], actual_pred=out[2], true_class_prob=out[3])
cm_analysis(tar_seq, actual_pred, cm_plot_labels)

# Transformer N-glyc

In [None]:
### Load data in dataloaders ###
# N-glycosylation run: build the train / validation loaders consumed by the
# Transformer training cell below.
# NOTE(review): `Dataset` and `data` are not defined in this part of the
# notebook; presumably they come from `from utils import *` — confirm.
glyc_type = "n"
batch_size = 6

# Inputs and targets are read from the 'train_inp_seq' / 'train_tar_seq' keys.
train_npz = np.load(datapath / "n_embedded_train.npz", allow_pickle=True)
train = Dataset(torch.from_numpy(train_npz['train_inp_seq']), torch.from_numpy(train_npz['train_tar_seq']))
train_loader = data.DataLoader(train, batch_size=batch_size, shuffle=True)

# Validation split, read from the 'val_inp_seq' / 'val_tar_seq' keys.
# NOTE(review): the validation loader is shuffled as well.
val_npz = np.load(datapath / "n_embedded_val.npz", allow_pickle=True)
val = Dataset(torch.from_numpy(val_npz['val_inp_seq']), torch.from_numpy(val_npz['val_tar_seq']))
val_loader = data.DataLoader(val, batch_size=batch_size, shuffle=True)

In [None]:
### Training ###
# Trains the decoder-only Transformer on the N-glycosylation loaders built in
# the previous cell and optionally saves the trained model.

# Training hyperparameters
num_epochs = 20
learning_rate = 1e-6
# NOTE(review): hard-coded to CUDA; this cell fails on a machine without a GPU.
device = torch.device("cuda")
save_model = 'yes_please'  # the model is saved only when this equals 'yes_please'
timestamp = str(time.time())  # appended to the saved file name

# Model hyperparameters
num_heads = 8
num_layers = 3
dim_feedforward = 1024
dropout = 0.1
max_seq_len = 2500
trg_vocab_size = 7 # GLYC_VOCAB
trg_pad_idx = 0
trg_positive_idx = 6

# Decoder-only variant; the commented lines switch to the encoder+decoder
# model with a frozen encoder.
model = TransformerBERTDecoder(
#model = TransformerBERTEncoderDecoder(
    trg_vocab_size=trg_vocab_size,
    num_heads=num_heads,
    num_layers=num_layers,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    trg_pad_idx=trg_pad_idx,
    max_seq_len=max_seq_len,
    device=device
).to(device)
#model.encoder.freeze_encoder()

optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# Multiply the learning rate by 0.1 after 10 scheduler steps without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=10, verbose=True
)

# Class weights: 100 for index `trg_positive_idx`, 1 for every other class;
# targets equal to `trg_pad_idx` are ignored by the loss.
loss_weights = torch.ones((trg_vocab_size,), device=device)
loss_weights[trg_positive_idx] = 100
criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx, weight=loss_weights).to(device)

# NOTE(review): `train_model_transformer` is defined outside this section;
# its return value is treated as the trained model below.
trained_model = train_model_transformer(model, criterion, optimizer, scheduler, train_loader, val_loader, pad_index=trg_pad_idx, epochs=num_epochs)
if save_model == 'yes_please':
    # The whole model object is saved, not only its state dict.
    filename = 'Transformer_atop_BERT_model' + '_' + glyc_type + '-' + timestamp
    torch.save(trained_model, model_save_dir / filename)

In [None]:
### Prediction ###
# Runs the trained N-glycosylation Transformer on the prediction split, saves
# the results and plots a confusion matrix.
device = torch.device("cuda")
pred_npz = np.load(datapath / "n_embedded_pred.npz", allow_pickle=True)
cm_plot_labels = ["No N glycosylation", "N glycosylation"]
save_pred_path = datapath / 'Transformer_n_glyc_predictions'
pred = Dataset(torch.from_numpy(pred_npz['pred_inp_seq']), torch.from_numpy(pred_npz['pred_tar_seq']))

trained_model.to(device)
# TransformerBERTDecoder builds its causal mask on its own `device` attribute.
# When the trained model *is* that decoder (as in the training cell above),
# `trained_model.decoder` is the inner nn.TransformerDecoder, so the attribute
# has to be set on the model itself; only the encoder+decoder wrapper keeps
# the TransformerBERTDecoder under `.decoder`.
mask_owner = trained_model if isinstance(trained_model, TransformerBERTDecoder) else trained_model.decoder
mask_owner.device = device
out = predict_glycosylation_from_transformer(pred, trained_model, device)
pred = out[0]  # Predicted glyc probability
tar_seq = out[1]  # Target label
actual_pred = out[2]  # Predicted label
true_class_prob = out[3]  # Predicted probability of target label
# np.savez_compressed appends ".npz" to the file name.
np.savez_compressed(save_pred_path, pred=out[0], tar_seq=out[1], actual_pred=out[2], true_class_prob=out[3])
cm_analysis(tar_seq, actual_pred, cm_plot_labels)

# FFN N-glyc

In [None]:
class FFN(nn.Module):
    """Feed-forward classifier over 768-dimensional input vectors.

    Three hidden blocks (Linear -> ReLU -> BatchNorm1d, with Dropout(0.35)
    after the first two) of widths 450, 300 and 200 are followed by a linear
    layer producing 2 logits.
    """

    def __init__(self):
        super().__init__()
        input_width = 768  # BERT embedding size
        n_classes = 2  # glycosylated / not glycosylated
        self.ff_l1_size, self.ff_l2_size, self.ff_l3_size = 450, 300, 200

        layers = [
            nn.Linear(input_width, self.ff_l1_size),
            nn.ReLU(),
            BatchNorm1d(self.ff_l1_size),
            nn.Dropout(0.35),
            nn.Linear(self.ff_l1_size, self.ff_l2_size),
            nn.ReLU(),
            BatchNorm1d(self.ff_l2_size),
            nn.Dropout(0.35),
            nn.Linear(self.ff_l2_size, self.ff_l3_size),
            nn.ReLU(),
            BatchNorm1d(self.ff_l3_size),
            nn.Linear(self.ff_l3_size, n_classes),
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits of the sequential stack for ``x``."""
        return self.ff(x)

In [None]:
### Load data in dataloaders ###
# N-glycosylation data for the FFN: train / validation loaders built from the
# 'dsampled_BERT_encoded_n_glyc_data.npz' archive.
# NOTE(review): `Dataset` and `data` are not defined in this part of the
# notebook; presumably they come from `from utils import *` — confirm.
glyc_type = "n"
bert_embedded_data_path = datapath / 'dsampled_BERT_encoded_n_glyc_data.npz'
batch_size = 2
data_load = np.load(bert_embedded_data_path, allow_pickle=True)
# Targets are cast to int64 (.long()).
train = Dataset(torch.from_numpy(data_load['train_inp_seq']), torch.from_numpy(data_load['train_tar_seq']).long())
train_loader = data.DataLoader(train, batch_size=batch_size, shuffle=True)
t_dataset_size = len(train)

val = Dataset(torch.from_numpy(data_load['val_inp_seq']), torch.from_numpy(data_load['val_tar_seq']).long())
# NOTE(review): the validation loader is shuffled as well.
val_loader = data.DataLoader(val, batch_size=batch_size, shuffle=True)
v_dataset_size = len(val)

In [None]:
### Training ###
# Trains a fresh FFN on the N-glycosylation loaders and optionally saves it.
ffn = FFN()
# NOTE(review): requires a CUDA device; there is no CPU fallback here.
ffn.cuda()
device = torch.device("cuda")
timestamp = str(time.time())  # appended to the saved file name
criterion = nn.CrossEntropyLoss()
LEARNING_RATE = 0.001
optimizer = optim.Adam(ffn.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by 0.1 every 5 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

save_model = 'yes_please'  # the model is saved only when this equals 'yes_please'
# NOTE(review): `train_model` is defined outside this section; its return
# value is treated as the trained model below.
trained_model = train_model(ffn, criterion, optimizer, exp_lr_scheduler, train_loader, val_loader, epochs=10)
if save_model == 'yes_please':
    # The whole model object is saved, not only its state dict.
    filename = 'FFN_atop_BERT_model' + '_' + glyc_type + '-' + timestamp
    torch.save(trained_model, model_save_dir / filename)

In [None]:
### Predictions on prediction dataset ###
# Evaluates the trained FFN on the N-glycosylation prediction split (on CPU),
# saves the results and plots a confusion matrix.
data_load = np.load(datapath /  'bert_embedded_n_glyc_data.npz', allow_pickle=True)
# NOTE(review): both label strings are misspelled ("glycosyalation",
# "glcyosylation"); they are passed to cm_analysis as written.
cm_plot_labels = ["No N glycosyalation", "N glcyosylation"]
save_pred_path = datapath / 'FFN_n_glyc_predictions'
trained_model.eval()
trained_model.cpu()   
pred_seqs = torch.from_numpy(data_load['pred_inp_seq']) 
pred_labels = torch.from_numpy( data_load['pred_tar_seq']).long()

### FFN takes one AA at a time
# Collapse all leading dimensions so each row is one 768-wide embedding with
# one label.
pred_seqs = pred_seqs.reshape(-1, 768) 
pred_labels = pred_labels.flatten()

prediction_dataset = Dataset(pred_seqs, pred_labels)
device = torch.device("cpu")
# NOTE(review): `predict_glycosylation_from_embedded_FFN` is defined outside
# this section; the meaning of its four outputs below is taken from the
# variable names and the keys they are saved under.
out = predict_glycosylation_from_embedded_FFN(prediction_dataset, None, trained_model)
pred=out[0]
tar_seq=out[1]
actual_pred=out[2]
true_class_prob=out[3]
# np.savez_compressed appends ".npz" to the file name.
np.savez_compressed(save_pred_path, glyc_prob=out[0], tar_seq=out[1], actual_pred=out[2], true_class_prob=out[3])
cm_analysis(tar_seq, actual_pred, cm_plot_labels)

# FFN O-glyc

In [None]:
### Load data in dataloaders ###
# O-glycosylation data for the FFN: train / validation loaders built from the
# 'dsampled_BERT_encoded_o_glyc_data.npz' archive.
# NOTE(review): `Dataset` and `data` are not defined in this part of the
# notebook; presumably they come from `from utils import *` — confirm.
glyc_type = "o"
bert_embedded_data_path = datapath / 'dsampled_BERT_encoded_o_glyc_data.npz'
batch_size = 2
data_load = np.load(bert_embedded_data_path, allow_pickle=True)
# Targets are cast to int64 (.long()).
train = Dataset(torch.from_numpy(data_load['train_inp_seq']), torch.from_numpy(data_load['train_tar_seq']).long())
train_loader = data.DataLoader(train, batch_size=batch_size, shuffle=True)
t_dataset_size = len(train)
val = Dataset(torch.from_numpy(data_load['val_inp_seq']), torch.from_numpy(data_load['val_tar_seq']).long())
# NOTE(review): the validation loader is shuffled as well.
val_loader = data.DataLoader(val, batch_size=batch_size, shuffle=True)
v_dataset_size = len(val)

In [None]:
### Training ###
# Trains a fresh FFN on the O-glycosylation loaders and optionally saves it.
ffn = FFN()
# NOTE(review): requires a CUDA device; there is no CPU fallback here.
ffn.cuda()
device = torch.device("cuda")
timestamp = str(time.time())  # appended to the saved file name
criterion = nn.CrossEntropyLoss()
LEARNING_RATE = 0.001
optimizer = optim.Adam(ffn.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by 0.1 every 5 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

save_model = 'yes_please'  # the model is saved only when this equals 'yes_please'
# NOTE(review): `train_model` is defined outside this section; its return
# value is treated as the trained model below.
trained_model = train_model(ffn, criterion, optimizer, exp_lr_scheduler, train_loader, val_loader, epochs=10)
if save_model == 'yes_please':
    # The whole model object is saved, not only its state dict.
    filename = 'FFN_atop_BERT_model' + '_' + glyc_type + '-' + timestamp
    torch.save(trained_model, model_save_dir / filename)

In [None]:
### Prediction ###
# Evaluates the trained FFN on the O-glycosylation prediction split (on CPU),
# saves the results and plots a confusion matrix.
data_load = np.load(datapath /  'bert_embedded_o_glyc_data.npz', allow_pickle=True)
# NOTE(review): both label strings are misspelled ("glycosyalation",
# "glcyosylation"); they are passed to cm_analysis as written.
cm_plot_labels = ["No O glycosyalation", "O glcyosylation"]
save_pred_path = datapath / 'FFN_o_glyc_predictions'
trained_model.eval()
trained_model.cpu()
pred_seqs = torch.from_numpy(data_load['pred_inp_seq']) 
pred_labels = torch.from_numpy( data_load['pred_tar_seq']).long() 

### FFN takes one AA at a time
# Collapse all leading dimensions so each row is one 768-wide embedding with
# one label.
pred_seqs = pred_seqs.reshape(-1, 768) 
pred_labels = pred_labels.flatten()

prediction_dataset = Dataset(pred_seqs, pred_labels)

device = torch.device("cpu")
# NOTE(review): `predict_glycosylation_from_embedded_FFN` is defined outside
# this section; the meaning of its four outputs below is taken from the
# variable names and the keys they are saved under.
out = predict_glycosylation_from_embedded_FFN(prediction_dataset, None, trained_model)
pred=out[0]
tar_seq=out[1]
actual_pred=out[2]
true_class_prob=out[3]
# np.savez_compressed appends ".npz" to the file name.
np.savez_compressed(save_pred_path, glyc_prob=out[0], tar_seq=out[1], actual_pred=out[2], true_class_prob=out[3])
cm_analysis(tar_seq, actual_pred, cm_plot_labels)

# LSTM

In [None]:
class LSTM(nn.Module):
    """Three-layer unidirectional LSTM with a bias-free linear output layer.

    Args:
        hidden_size: Width of the LSTM hidden state.
        src_vocab_size: Size of the last dimension of the input.
        tgt_vocab_size: Number of output logits per time step.
    """

    def __init__(
            self,
            hidden_size = 50,
            src_vocab_size = 30,
            tgt_vocab_size = 7
    ):
        super().__init__()

        # Recurrent part: 3 stacked layers with dropout 0.1 between them.
        self.lstm = nn.LSTM(
            input_size=src_vocab_size,
            hidden_size=hidden_size,
            num_layers=3,
            dropout=0.1,
            bidirectional=False,
        )

        # Projection of every hidden state onto the target vocabulary.
        self.l_out = nn.Linear(hidden_size, tgt_vocab_size, bias=False)

    def forward(self, x):
        """Return logits of shape ``[N, tgt_vocab_size]``.

        The LSTM output is flattened to ``[-1, hidden_size]``, passed through
        ReLU and then through the output layer; the final hidden and cell
        states are discarded.
        """
        hidden_states, _ = self.lstm(x)
        flat = hidden_states.view(-1, self.lstm.hidden_size)
        return self.l_out(F.relu(flat))
    

def train_lstm(model, optimizer, lr_sched, criterion, load_model, save_model, epochs, checkpoint_path, loss_data_path, device):
    """
    Training loop for LSTM

    Iterates over the module-level ``train_loader`` and ``val_loader`` (they
    are not parameters of this function) for ``epochs`` epochs, printing the
    mean per-batch train and validation loss after each epoch.

    Args:
        model: Module called as ``model(x)`` with ``x`` permuted to put dim 1
            first; its output is compared against the flattened targets.
        optimizer: Optimizer stepped after every training batch.
        lr_sched: Scheduler stepped once per epoch.
        criterion: Loss called as ``criterion(output, y.view(-1))``.
        load_model: If truthy, the loss histories are restored via
            ``load_checkpoint`` before training starts.
        save_model: If truthy, ``save_checkpoint`` is called after every epoch.
        epochs: Number of epochs to run.
        checkpoint_path: Checkpoint location; after the last epoch it is
            overwritten with ``{"model_state": model.state_dict()}``.
        loss_data_path: Passed through to the checkpoint helpers.
        device: Device the batches are moved to.
    """

    train_loss = list()
    valid_loss = list()

    if load_model:
        train_loss, valid_loss = load_checkpoint(model, optimizer, checkpoint_path, loss_data_path)

    # training loop
    print("Starting training for {} epochs".format(epochs))
    for epoch in range(epochs):
        model.train()
        cur_train_loss = 0
        cur_valid_loss = 0
        for x, y in train_loader:
            #x = BertModel_instance(x)
            x = x.permute(1, 0, 2).float().to(device)  # change dims to match LSTM module expected dim
            y = y.long().to(device)
            output = model(x)
            loss = criterion(output, y.view(-1))
            cur_train_loss += loss.item()

            # propagate gradient
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate on the whole validation loader, do not propagate gradients
        with torch.no_grad():
            model.eval()
            for x, y in val_loader:
                #x = BertModel_instance(x)
                x = x.permute(1, 0, 2).float().to(device)
                y = y.long().to(device)
                output = model(x)
                loss = criterion(output, y.view(-1))
                cur_valid_loss += loss.item()

        # Each accumulated term is one batch's loss, so the mean is taken over
        # the number of batches (identical to the dataset size only when
        # batch_size == 1).
        train_loss.append(cur_train_loss / len(train_loader))
        valid_loss.append(cur_valid_loss / len(val_loader))

        print("Epoch %2i : Train Loss %f, Validation Loss %f" % (epoch+1, train_loss[-1], valid_loss[-1]))

        if save_model:
            save_checkpoint(model, optimizer, train_loss, valid_loss, checkpoint_path, loss_data_path)
        lr_sched.step()

    # save final model
    torch.save({"model_state": model.state_dict()}, checkpoint_path)
    

def predict_lstm(model, checkpoint_path, test_loader, rnn_preds, rnn_preds_readable, device):
    """
    Make predictions on a dataset with a trained LSTM model

    Loads ``model_state`` from ``checkpoint_path`` into ``model``, runs it
    over ``test_loader`` and writes two outputs:

    * ``rnn_preds_readable``: a text file with the predicted and target token
      lists of every batch;
    * ``datapath / rnn_preds``: a compressed npz archive with the keys
      ``pred`` and ``tar``, both flattened over all sequences.

    Only the first target of each batch (``y[0]``) is used, so the loader is
    expected to yield one sequence per batch. ``datapath`` is a module-level
    name, not a parameter; joining it with an absolute ``rnn_preds`` yields
    ``rnn_preds`` itself.

    Args:
        model: Freshly constructed model with the checkpoint's architecture.
        checkpoint_path: File holding ``{"model_state": state_dict}``.
        test_loader: Iterable of ``(x, y)`` batches.
        rnn_preds: File name / path of the npz archive to write.
        rnn_preds_readable: Path of the human-readable text file to write.
        device: Device used for inference.
    """
    # map_location lets a checkpoint written on one device be loaded on another.
    model_data = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(model_data["model_state"])
    model = model.to(device)
    model.eval()

    # make predictions on test set
    tar_tokenizer = TAPETokenizer(vocab='glycolysation')
    pred_array = list()
    tar_array = list()
    # The context manager closes the text file even if inference raises.
    with open(rnn_preds_readable, "w") as outfile, torch.no_grad():
        for x, y in test_loader:
            x = x.permute(1, 0, 2).float().to(device)
            # Bring the logits to the CPU once; .numpy() does not work on
            # CUDA tensors.
            output = model(x).cpu()

            # convert to human readable format and print to outfile
            output_max = F.softmax(output, dim=1).numpy().argmax(axis=1)  # most likely token id
            output_max = tar_tokenizer.convert_ids_to_tokens(output_max)
            y = tar_tokenizer.convert_ids_to_tokens(y[0].cpu().numpy())
            print("Prediction:", output_max, "Target:   ",  y, sep="\n", file=outfile)
            print("\n", file=outfile)

            # Keep only the positions between the first token and "<sep>";
            # predictions in the padding region are not evaluated.
            end_pos = y.index("<sep>")
            # Softmax over the logits from column 5 onwards; column 1 of the
            # result is stored as the score.
            # NOTE(review): presumably columns 5 and 6 are the negative /
            # positive classes of the 7-token vocabulary — confirm.
            output_sliced = F.softmax(output[1:end_pos, 5:], dim=1)
            y_sliced = [int(token) for token in y[1:end_pos]]

            pred_array.append(output_sliced[:, 1].numpy())
            tar_array.append(y_sliced)

    pred_array = np.hstack(pred_array)
    tar_array = np.hstack(tar_array)
    np.savez_compressed(datapath / rnn_preds, pred=pred_array, tar=tar_array)

## O-glyc with BERT encoded data

In [None]:
# Configuration and data for the LSTM on BERT-embedded O-glycosylation data.
batch_size = 1
epochs = 50
load_model = False  # set True to load a checkpoint before resuming training
save_model = True  # set True to make a checkpoint of the model after each epoch
# NOTE(review): `device` is not assigned here; the value left by an earlier
# cell is used.
#device = torch.device("cpu")


# load data
data_load = np.load(datapath / "bert_embedded_o_glyc_data.npz", allow_pickle=True)

# NOTE(review): `Dataset` and `data` presumably come from `from utils import *`.
# train_lstm reads `train_loader` / `val_loader` as globals.
train = Dataset(torch.from_numpy(data_load['train_inp_seq']), torch.from_numpy(data_load['train_tar_seq']).long())
train_loader = data.DataLoader(train, batch_size=batch_size, shuffle=True)

val = Dataset(torch.from_numpy(data_load['val_inp_seq']), torch.from_numpy(data_load['val_tar_seq']).long())
val_loader = data.DataLoader(val, batch_size=batch_size, shuffle=True)

# set paths for model checkpoints
checkpoint_path = datapath / 'model_rnn_o.pt'
loss_data_path = datapath / 'loss_data_o.npz'

# path for saving predictions
rnn_preds = datapath / "rnn_pred_o.npz"
rnn_preds_readable = datapath / "rnn_preds_translated_o.txt"

In [None]:
# Train the LSTM on 768-wide (BERT-embedded) inputs with 7 output classes.
model = LSTM(src_vocab_size=768, tgt_vocab_size=7).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)
# Multiply the learning rate by 0.1 every 10 epochs.
lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = nn.CrossEntropyLoss()

train_lstm(
    model, 
    optimizer, 
    lr_sched, 
    criterion, 
    load_model, 
    save_model, 
    epochs,
    checkpoint_path,
    loss_data_path,
    device
)

In [None]:
# predict on test set
# The 'pred_*' arrays of the archive loaded above serve as the test set.
test = Dataset(torch.from_numpy(data_load['pred_inp_seq']), torch.from_numpy(data_load['pred_tar_seq']).long())
test_loader = data.DataLoader(test, batch_size=1, shuffle=False)
# A fresh model is built; predict_lstm loads the trained weights from
# `checkpoint_path` into it.
model = LSTM(src_vocab_size=768, tgt_vocab_size=7).to(device)
predict_lstm(model, checkpoint_path, test_loader, rnn_preds, rnn_preds_readable, device)

## N-glyc with BERT encoded data

In [None]:
# Configuration and data for the LSTM on BERT-embedded N-glycosylation data.
batch_size = 1
# NOTE(review): `epochs` and `device` are not assigned here; the values left
# by earlier cells are used.
#epochs = 50
load_model = False  # set True to load a checkpoint before resuming training
save_model = True  # set True to make a checkpoint of the model after each epoch
#device = torch.device("cpu")

# load data
data_load = np.load(datapath / "bert_embedded_n_glyc_data.npz", allow_pickle=True)

# NOTE(review): `Dataset` and `data` presumably come from `from utils import *`.
# train_lstm reads `train_loader` / `val_loader` as globals.
train = Dataset(torch.from_numpy(data_load['train_inp_seq']), torch.from_numpy(data_load['train_tar_seq']).long())
train_loader = data.DataLoader(train, batch_size=batch_size, shuffle=True)

val = Dataset(torch.from_numpy(data_load['val_inp_seq']), torch.from_numpy(data_load['val_tar_seq']).long())
val_loader = data.DataLoader(val, batch_size=batch_size, shuffle=True)

# set paths for model checkpoints
checkpoint_path = datapath / 'model_rnn_n.pt'
loss_data_path = datapath / 'loss_data_n.npz'

# path for saving predictions
rnn_preds = datapath / "rnn_pred_n.npz"
rnn_preds_readable = datapath / "rnn_preds_translated_n.txt"

In [None]:
# Train the LSTM on 768-wide (BERT-embedded) inputs with 7 output classes.
model = LSTM(src_vocab_size=768, tgt_vocab_size=7).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)
# Multiply the learning rate by 0.1 every 10 epochs.
lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = nn.CrossEntropyLoss()

train_lstm(
    model, 
    optimizer, 
    lr_sched, 
    criterion, 
    load_model, 
    save_model, 
    epochs,
    checkpoint_path,
    loss_data_path,
    device
)

In [None]:
# predict on test set
# The 'pred_*' arrays of the archive loaded above serve as the test set.
test = Dataset(torch.from_numpy(data_load['pred_inp_seq']), torch.from_numpy(data_load['pred_tar_seq']).long())
test_loader = data.DataLoader(test, batch_size=1, shuffle=False)
# A fresh model is built; predict_lstm loads the trained weights from
# `checkpoint_path` into it.
model = LSTM(src_vocab_size=768, tgt_vocab_size=7).to(device)
predict_lstm(model, checkpoint_path, test_loader, rnn_preds, rnn_preds_readable, device)

## O-glyc one hot encoded data

In [None]:
# Configuration and data for the LSTM on one-hot encoded O-glycosylation data.
batch_size = 1
# NOTE(review): `epochs` is not assigned here; the value left by an earlier
# cell is used.
#epochs = 50
load_model = False  # set True to load a checkpoint before resuming training
save_model = True  # set True to make a checkpoint of the model after each epoch
device = torch.device("cpu")

# load data
data_load = np.load(datapath / "one_hot_o_glyc_data.npz", allow_pickle=True)

# NOTE(review): `Dataset` and `data` presumably come from `from utils import *`.
# train_lstm reads `train_loader` / `val_loader` as globals.
train = Dataset(torch.from_numpy(data_load['train_inp_seq']), torch.from_numpy(data_load['train_tar_seq']).long())
train_loader = data.DataLoader(train, batch_size=batch_size, shuffle=True)

val = Dataset(torch.from_numpy(data_load['val_inp_seq']), torch.from_numpy(data_load['val_tar_seq']).long())
val_loader = data.DataLoader(val, batch_size=batch_size, shuffle=True)

# set paths for model checkpoints
checkpoint_path = datapath / 'model_rnn_o_one_hot.pt'
loss_data_path = datapath / 'loss_data_o_one_hot.npz'

# path for saving predictions
rnn_preds = datapath / "rnn_pred_o_one_hot.npz"
rnn_preds_readable = datapath / "rnn_preds_translated_o_one_hot.txt"

In [None]:
# Train the LSTM on 30-wide (one-hot) inputs with 7 output classes.
model = LSTM(src_vocab_size=30, tgt_vocab_size=7).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)
# Multiply the learning rate by 0.1 every 10 epochs.
lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = nn.CrossEntropyLoss()

train_lstm(
    model, 
    optimizer, 
    lr_sched, 
    criterion, 
    load_model, 
    save_model, 
    epochs,
    checkpoint_path,
    loss_data_path,
    device
)

In [None]:
# predict on test set
# The 'pred_*' arrays of the archive loaded above serve as the test set.
test = Dataset(torch.from_numpy(data_load['pred_inp_seq']), torch.from_numpy(data_load['pred_tar_seq']).long())
test_loader = data.DataLoader(test, batch_size=1, shuffle=False)
# Unlike the other prediction cells, the model is not rebuilt here: the
# `model` object from the training cell is reused and predict_lstm reloads
# its weights from `checkpoint_path`.
predict_lstm(model, checkpoint_path, test_loader, rnn_preds, rnn_preds_readable, device)

## N-glyc one hot encoded data

In [None]:
# Configuration and data for the LSTM on one-hot encoded N-glycosylation data.
batch_size = 1
# NOTE(review): `epochs` and `device` are not assigned here; the values left
# by earlier cells are used.
#epochs = 50
load_model = False  # set True to load a checkpoint before resuming training
save_model = True  # set True to make a checkpoint of the model after each epoch
#device = torch.device("cpu")

# load data
data_load = np.load(datapath / "one_hot_n_glyc_data.npz", allow_pickle=True)

# NOTE(review): `Dataset` and `data` presumably come from `from utils import *`.
# train_lstm reads `train_loader` / `val_loader` as globals.
train = Dataset(torch.from_numpy(data_load['train_inp_seq']), torch.from_numpy(data_load['train_tar_seq']).long())
train_loader = data.DataLoader(train, batch_size=batch_size, shuffle=True)

val = Dataset(torch.from_numpy(data_load['val_inp_seq']), torch.from_numpy(data_load['val_tar_seq']).long())
val_loader = data.DataLoader(val, batch_size=batch_size, shuffle=True)

# set paths for model checkpoints
checkpoint_path = datapath / 'model_rnn_n_one_hot.pt'
loss_data_path = datapath / 'loss_data_n_one_hot.npz'

# path for saving predictions
rnn_preds = datapath / "rnn_pred_n_one_hot.npz"
rnn_preds_readable = datapath / "rnn_preds_translated_n_one_hot.txt"

In [None]:
# Train the LSTM on 30-wide (one-hot) inputs with 7 output classes.
model = LSTM(src_vocab_size=30, tgt_vocab_size=7).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)
# Multiply the learning rate by 0.1 every 10 epochs.
lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = nn.CrossEntropyLoss()

train_lstm(
    model, 
    optimizer, 
    lr_sched, 
    criterion, 
    load_model, 
    save_model, 
    epochs,
    checkpoint_path,
    loss_data_path,
    device
)

In [None]:
# predict on test set
# The 'pred_*' arrays of the archive loaded above serve as the test set.
test = Dataset(torch.from_numpy(data_load['pred_inp_seq']), torch.from_numpy(data_load['pred_tar_seq']).long())
test_loader = data.DataLoader(test, batch_size=1, shuffle=False)
# A fresh model is built; predict_lstm loads the trained weights from
# `checkpoint_path` into it.
model = LSTM(src_vocab_size=30, tgt_vocab_size=7).to(device)
predict_lstm(model, checkpoint_path, test_loader, rnn_preds, rnn_preds_readable, device)

In [None]:
# N glyc
# Confusion matrix for the one-hot LSTM: scores >= 0.5 count as positive.
# NOTE(review): both label strings are misspelled ("glycosyalation",
# "glcyosylation"); they are passed to cm_analysis as written.
cm_plot_labels = ["No N glycosyalation", "N glcyosylation"]

data_load = np.load(datapath / "rnn_pred_n_one_hot.npz", allow_pickle=True)
targets = data_load["tar"]
preds = data_load["pred"]
mask = preds >= 0.5
preds_binary = mask.astype(int)

cm_analysis(targets, preds_binary, cm_plot_labels)

In [None]:
# Confusion matrix for the BERT-embedded N-glyc LSTM: scores >= 0.5 count as
# positive. `cm_plot_labels` is reused from the previous cell.
data_load = np.load(datapath / "rnn_pred_n.npz", allow_pickle=True)
targets = data_load["tar"]
preds = data_load["pred"]
mask = preds >= 0.5
preds_binary = mask.astype(int)


cm_analysis(targets, preds_binary, cm_plot_labels)

In [None]:
# O glyc
# Confusion matrix for the one-hot LSTM: scores >= 0.5 count as positive.
# NOTE(review): both label strings are misspelled ("glycosyalation",
# "glcyosylation"); they are passed to cm_analysis as written.
cm_plot_labels = ["No O glycosyalation", "O glcyosylation"]

data_load = np.load(datapath / "rnn_pred_o_one_hot.npz", allow_pickle=True)
targets = data_load["tar"]
preds = data_load["pred"]
mask = preds >= 0.5
preds_binary = mask.astype(int)

cm_analysis(targets, preds_binary, cm_plot_labels)

In [None]:
# Confusion matrix for the BERT-embedded O-glyc LSTM: scores >= 0.5 count as
# positive. `cm_plot_labels` is reused from the previous cell.
data_load = np.load(datapath / "rnn_pred_o.npz", allow_pickle=True)
targets = data_load["tar"]
preds = data_load["pred"]
mask = preds >= 0.5
preds_binary = mask.astype(int)

cm_analysis(targets, preds_binary, cm_plot_labels)

# NetOglyc and NetNglyc

These programs were downloaded from https://services.healthtech.dtu.dk/service.php?NetOGlyc-4.0 and https://services.healthtech.dtu.dk/service.php?NetNGlyc-1.0 and run locally on the test set used in the above cells. Their installation requires a bit of tinkering, and so they are not included in this notebook.

# Performance comparison

In [None]:
def parse_output_file(path, new_protein_regex, glyc_pred_regex, no_match_string):
    """Parse a netN/netO glyc output file into per-protein score arrays.

    The file is read line by line (each line stripped of surrounding
    whitespace):

    * a line matching ``new_protein_regex`` starts a new protein; capture
      group 1 is its length and a zero-filled array of that length is created;
    * a line matching ``glyc_pred_regex`` stores a score in that array;
      group 1 is the 1-based position, group 2 the score;
    * a protein is finished by a line starting with ``----`` that follows at
      least one stored score, or by a line equal to ``no_match_string``.

    A protein whose scores are not followed by a ``----`` line (end of file,
    or the next protein header arrives first) is still kept instead of being
    silently dropped.

    Args:
        path: Path of the output file to read.
        new_protein_regex: Pattern whose group 1 is the protein length.
        glyc_pred_regex: Pattern whose groups 1 and 2 are position and score.
        no_match_string: Exact (stripped) line that marks a protein without
            predicted sites.

    Returns:
        ``np.ndarray`` with ``dtype=object`` built from the list of
        per-protein score arrays, in file order.

    Raises:
        ValueError: If a score line is found before any protein header.
    """
    net_preds = []
    tar_seq = None
    parse_flag = False  # True once a score was stored for the current protein
    with open(path, "r") as file:
        for raw_line in file:
            line = raw_line.strip()

            new_protein = re.search(new_protein_regex, line)
            if new_protein is not None:
                if parse_flag:
                    # Previous protein was never closed by a "----" line.
                    net_preds.append(tar_seq)
                    parse_flag = False
                tar_seq = np.zeros(int(new_protein.group(1)))

            glyc_pred = re.search(glyc_pred_regex, line)
            if glyc_pred is not None:
                if tar_seq is None:
                    raise ValueError(
                        "Prediction line found before any protein header: {!r}".format(line)
                    )
                position = int(glyc_pred.group(1)) - 1  # file positions are 1-based
                tar_seq[position] = float(glyc_pred.group(2))
                parse_flag = True

            # parse_flag keeps "----" lines that do not follow a stored score
            # from closing the protein.
            if (line.startswith("----") and parse_flag) or line == no_match_string:
                net_preds.append(tar_seq)
                tar_seq = None
                parse_flag = False

    if parse_flag:
        # File ended right after the last score line.
        net_preds.append(tar_seq)
    return np.array(net_preds, dtype=object)


# Locations of the prediction archives compared in this section.
ROOT_DIR = Path.cwd()
#netNglyc_path = ROOT_DIR / 'data' / 'netNglyc_out.txt'
#netOglyc_path = ROOT_DIR / 'data' / 'netOglyc_out.txt'
rnn_path_n = ROOT_DIR / 'data' / 'rnn_pred_n.npz'
rnn_path_o = ROOT_DIR / 'data' / 'rnn_pred_o.npz'
rnn_one_hot_n = datapath / "rnn_pred_n_one_hot.npz"
rnn_one_hot_o = datapath / "rnn_pred_o_one_hot.npz"


# The FFN prediction cells save to the lower-case stems
# 'FFN_o_glyc_predictions' / 'FFN_n_glyc_predictions' (np.savez_compressed
# appends ".npz"); the names must match exactly on case-sensitive filesystems.
preds_path_o = ROOT_DIR / "data" / "FFN_o_glyc_predictions.npz"
preds_path_n = ROOT_DIR / "data" / "FFN_n_glyc_predictions.npz"
trans_pred_path_o = ROOT_DIR / "data" / "Transformer_o_glyc_predictions.npz"
trans_pred_path_n = ROOT_DIR / "data" / "Transformer_n_glyc_predictions.npz"

# grab prediction scores from netNglyc output
#no_match_string = "No sites predicted in this sequence."
#new_protein_regex = r"Name:\s+test_seq_\d+\s+Length:\s+(\d+)"
#glyc_pred_regex = r"test_seq_\d+\s+(\d+)\s\w+\s+(\d{1}.\d+)"
#netN_preds = parse_output_file(netNglyc_path, new_protein_regex, glyc_pred_regex, no_match_string)

# grab prediction scores from netOglyc output
#no_match_string = "No sites predicted in this sequence."  # TODO figure out the no hit message and put it here
#new_protein_regex = r"Name:\s+ts_\d+\s+Length:\s+(\d+)"
#glyc_pred_regex = r"ts_\d+\s+\w+\s+(\d+)\s+(\d\.\d+)"
#netO_preds = parse_output_file(netOglyc_path, new_protein_regex, glyc_pred_regex, no_match_string)

def _load_npz_arrays(path, *keys):
    """Open the ``.npz`` archive at *path* once and return the requested arrays.

    The archive is opened in a ``with`` block so its file handle is closed as
    soon as the arrays have been read, instead of re-opening the same file
    for every key and leaving the handles dangling.

    Args:
        path: location of the ``.npz`` archive.
        *keys: names of the arrays to read from the archive.

    Returns:
        A tuple with one array per key, in the order the keys were given.
    """
    # NOTE: allow_pickle=True unpickles stored objects; only use on trusted files.
    with np.load(path, allow_pickle=True) as archive:
        return tuple(archive[key] for key in keys)


# grab prediction scores and targets from LSTM
rnn_preds_n, rnn_tars_n = _load_npz_arrays(rnn_path_n, "pred", "tar")
rnn_preds_o, rnn_tars_o = _load_npz_arrays(rnn_path_o, "pred", "tar")

rnn_preds_one_hot_n, rnn_tars_one_hot_n = _load_npz_arrays(rnn_one_hot_n, "pred", "tar")
rnn_preds_one_hot_o, rnn_tars_one_hot_o = _load_npz_arrays(rnn_one_hot_o, "pred", "tar")

# grab FFN scores
# (data_load_n / data_load_o are kept as open archives, as before, in case
# other cells read further keys from them)
data_load_n = np.load(preds_path_n, allow_pickle=True)
ffn_preds_n = data_load_n['glyc_prob']
target_values_n = data_load_n['tar_seq'].astype(int)

data_load_o = np.load(preds_path_o, allow_pickle=True)
ffn_preds_o = data_load_o['glyc_prob']
target_values_o = data_load_o['tar_seq'].astype(int)

# grab transformer predictions and targets
trans_pred_o, trans_tar_o = _load_npz_arrays(trans_pred_path_o, "pred", "tar_seq")
trans_pred_n, trans_tar_n = _load_npz_arrays(trans_pred_path_n, "pred", "tar_seq")


# flatten arrays to allow for ROC AUC calc
#targets_n = np.hstack(targets_n)
#targets_o = np.hstack(targets_o)
#netN_preds = np.hstack(netN_preds)
#netO_preds = np.hstack(netO_preds)
#trans_pred_o = np.hstack([x[:-1] for x in trans_pred_o])
#trans_pred_n = np.hstack([x[:-1] for x in trans_pred_n])

# netN and netO auc
#fpr_net_N, tpr_net_N, thresh_net_N = metrics.roc_curve(targets_n, netN_preds, pos_label=1)
#auc_net_N = metrics.auc(fpr_net_N, tpr_net_N)

#fpr_net_O, tpr_net_O, thresh_net_O = metrics.roc_curve(targets_o, netO_preds, pos_label=1)
#auc_net_O = metrics.auc(fpr_net_O, tpr_net_O)

def _roc_and_auc(targets, scores):
    """Compute the ROC curve and its area for one set of predictions.

    Args:
        targets: true labels, with 1 marking the positive class.
        scores: predicted scores, aligned with *targets*.

    Returns:
        ``(fpr, tpr, thresholds, auc)`` where the first three come from
        ``metrics.roc_curve`` and ``auc`` is ``metrics.auc(fpr, tpr)``.
    """
    fpr, tpr, thresholds = metrics.roc_curve(targets, scores, pos_label=1)
    return fpr, tpr, thresholds, metrics.auc(fpr, tpr)


# LSTM aucs
fpr_rnn_N, tpr_rnn_N, thresh_rnn_N, auc_rnn_N = _roc_and_auc(rnn_tars_n, rnn_preds_n)
fpr_rnn_O, tpr_rnn_O, thresh_rnn_O, auc_rnn_O = _roc_and_auc(rnn_tars_o, rnn_preds_o)

# One-hot LSTM aucs. The thresholds get their own names so they no longer
# overwrite thresh_rnn_N / thresh_rnn_O computed just above.
(fpr_rnn_one_hot_N, tpr_rnn_one_hot_N,
 thresh_rnn_one_hot_N, auc_rnn_one_hot_N) = _roc_and_auc(rnn_tars_one_hot_n, rnn_preds_one_hot_n)
(fpr_rnn_one_hot_O, tpr_rnn_one_hot_O,
 thresh_rnn_one_hot_O, auc_rnn_one_hot_O) = _roc_and_auc(rnn_tars_one_hot_o, rnn_preds_one_hot_o)

# FFN aucs
fpr_ffn_N, tpr_ffn_N, thresh_ffn_N, auc_ffn_N = _roc_and_auc(target_values_n, ffn_preds_n)
fpr_ffn_O, tpr_ffn_O, thresh_ffn_O, auc_ffn_O = _roc_and_auc(target_values_o, ffn_preds_o)

# transformer aucs
fpr_trans_O, tpr_trans_O, thresh_trans_O, auc_trans_O = _roc_and_auc(trans_tar_o, trans_pred_o)
fpr_trans_N, tpr_trans_N, thresh_trans_N, auc_trans_N = _roc_and_auc(trans_tar_n, trans_pred_n)

# Shared axis limits [xmin, xmax, ymin, ymax] passed to ax.axis() in both ROC plots.
xmin, xmax, ymin, ymax = 0.0, 0.7, 0.3, 1.0

In [None]:
# N-plot: ROC curves of every model for N-linked glycosylation.
ax = plt.subplot()
#ax.plot(fpr_net_N, tpr_net_N)
# The plotting order below must match the order of the legend entries.
ax.plot(fpr_rnn_N, tpr_rnn_N)
ax.plot(fpr_ffn_N, tpr_ffn_N)
ax.plot(fpr_trans_N, tpr_trans_N)
ax.plot(fpr_rnn_one_hot_N, tpr_rnn_one_hot_N)
# Dashed diagonal as the chance-level reference (not listed in the legend).
ax.plot(np.linspace(0, 1), np.linspace(0, 1), "--")

#ax.legend(["netNglyc-1.0, AUC={}".format(round(auc_net_N, 4)),
#            "BERT + LSTM, AUC={}".format(round(auc_rnn_N, 4)),
#            "BERT + FFN, AUC={}".format(round(auc_ffn_N, 4)),
#            "BERT + Transf. decoder, AUC={}".format(round(auc_trans_N, 4)),
#            "LSTM, AUC={}".format(round(auc_rnn_one_hot_N, 4))],
#          prop={"size": 12})
ax.legend(["BERT + LSTM, AUC={}".format(round(auc_rnn_N, 4)),
            "BERT + FFN, AUC={}".format(round(auc_ffn_N, 4)),
            "BERT + Transf. decoder, AUC={}".format(round(auc_trans_N, 4)),
            "LSTM, AUC={}".format(round(auc_rnn_one_hot_N, 4))],
          prop={"size": 12})

# Zoom into the region set by the shared axis limits.
ax.axis([xmin,xmax,ymin,ymax])
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.title("ROC curves for N-linked glycosylation")
plt.ylabel("TPR")
plt.xlabel("FPR")
# Save before showing: after plt.show() the figure may already be released,
# in which case savefig would write an empty image.
plt.savefig('N_glyc.png', dpi=800)
plt.show()
#plt.close()

In [None]:
# O-plot: ROC curves of every model for O-linked glycosylation.
ax = plt.subplot()
#ax.plot(fpr_net_O, tpr_net_O)
# The plotting order below must match the order of the legend entries.
ax.plot(fpr_rnn_O, tpr_rnn_O)
ax.plot(fpr_ffn_O, tpr_ffn_O)
ax.plot(fpr_trans_O, tpr_trans_O)
ax.plot(fpr_rnn_one_hot_O, tpr_rnn_one_hot_O)
# Dashed diagonal as the chance-level reference (not listed in the legend).
ax.plot(np.linspace(0, 1), np.linspace(0, 1), "--")

#ax.legend(["netOglyc-3.1, AUC={}".format(round(auc_net_O, 4)),
#            "BERT + LSTM, AUC={}".format(round(auc_rnn_O, 4)),
#            "BERT + FFN, AUC={}".format(round(auc_ffn_O, 4)),
#            "BERT + Transf. decoder, AUC={}".format(round(auc_trans_O, 4)),
#            "LSTM, AUC={}".format(round(auc_rnn_one_hot_O, 4))],
#          prop={"size": 12})


ax.legend(["BERT + LSTM, AUC={}".format(round(auc_rnn_O, 4)),
            "BERT + FFN, AUC={}".format(round(auc_ffn_O, 4)),
            "BERT + Transf. decoder, AUC={}".format(round(auc_trans_O, 4)),
            "LSTM, AUC={}".format(round(auc_rnn_one_hot_O, 4))],
          prop={"size": 12})

# Zoom into the region set by the shared axis limits.
ax.axis([xmin,xmax,ymin,ymax])
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.title("ROC curves for O-linked glycosylation")
plt.ylabel("TPR")
plt.xlabel("FPR")
# Save before showing: after plt.show() the figure may already be released,
# in which case savefig would write an empty image.
plt.savefig('O_glyc.png', dpi=800)
plt.show()
#plt.close()