In [1]:
# from google.colab import drive
# drive.mount('/content/drive')
fs = '/content/drive/My Drive/'
fs = 'data/'

!pip install glycowork
!pip install 'glycowork[ml]'



In [2]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# Fetching ESM Representations

Facebook `fair esm` provides protein embeddings for amino acid sequences.

In [None]:
# Run the following on GPU

# from google.colab import drive
# drive.mount('/content/drive')

# !pip install glycowork
# !pip install 'glycowork[ml]'
# !pip install transformers

import torch
from transformers import AutoTokenizer, AutoModel
from glycowork.glycan_data.loader import glycan_binding
import pickle

# model_checkpoint = "facebook/esm2_t33_650M_UR50D"
model_checkpoint = "facebook/esm1b_t33_650M_UR50S"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Fetch all lectin amino acid sequences (1392 lectins from glycan_binding)
all_prots = list(set(glycan_binding.target.tolist()))
batch_size = 10

# Path to save the pickle file
filename = fs + 'esm1b_embeddings_full.pkl'

# Function to process a batch of proteins
def process_batch(proteins):
    inputs = tokenizer(proteins, padding=True, max_length=1000, truncation=True, return_tensors="pt")
    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
    last_hidden_states = outputs.last_hidden_state

    return {prot: last_hidden_states[i, 1:len(prot) + 1].mean(0).cpu().numpy().tolist() for i, prot in enumerate(proteins)}

# Iterate over the proteins in batches and save the results
with open(filename, 'ab') as file:  # 'ab' for appending in binary mode
    for i in range(0, len(all_prots), batch_size):
        batch_prots = all_prots[i:i + batch_size]
        prot_dic = process_batch(batch_prots)
        pickle.dump(prot_dic, file)

print("All data processed and saved successfully.")


After executing the following code, `data/esm1b_embeddings_full.pkl` and `data/esm2_embeddings_full.pkl` are generated.

# Loading Saved ESM Embeddings

In [2]:
import pickle

# Path to the pickle file
filename = fs + 'esm1b_embeddings_full.pkl'

# Initialize an empty dictionary to hold all the data
prot_dic = {}

# Open the pickle file and read from it
with open(filename, 'rb') as file:
    while True:
        try:
            # Load the data from the file and update the main dictionary
            data = pickle.load(file)
            prot_dic.update(data)
        except EOFError:
            # End of file reached
            break

# Check the loaded data
print(f"Loaded data for {len(prot_dic)} proteins.")


Loaded data for 1392 proteins.


# Conversion of IUPAC Glycan Motif String Representation to Graph Object

# Training Lectin Oracle Model from Scratch

## Create Torch DataLoaders

In [3]:
from glycowork.glycan_data.loader import glycan_binding
from glycowork.motif.graph import glycan_to_nxGraph
from glycowork.motif.processing import get_lib, expand_lib


lib = get_lib(glycan_binding.columns.unique().tolist()[:-2])
lib = expand_lib(lib, ['GlcNAcOS', 'GalOS', 'HexNAc'])
print(len(lib))
# generate a dictionary to look up graph representation of glycan

glycan_graph_dic = {glyc_string: glycan_to_nxGraph(glyc_string, libr=lib) for glyc_string in glycan_binding.columns.unique().tolist()[:-2]}
len(glycan_graph_dic)

87


927

In [4]:
import pickle

import numpy as np
import pandas as pd

#retrieve representations for protein sequences
filename = fs + 'stored_protein_embeddings.pkl'

old_dic = {}
with open(filename, 'rb') as file:
    old_dic = pickle.load(file)

print(f"Loaded data for {len(old_dic)} proteins.")

#converts data into format (protein, glycan, binding)
def generate_pair_data(df):
  """creates 3-tuples of lectin, glycan, and the corresponding binding value measured in relative fluorescence units"""
  out = []
  for k in range(len(df)):
    for j in range(df.shape[1]-1):
      if ~np.isnan(df.iloc[k,j]) and np.isfinite(df.iloc[k,j]):
        out.append((df.loc[k,'target'], df.columns.values.tolist()[j], df.iloc[k,j]))
  return out

data = generate_pair_data(glycan_binding.iloc[:, :-1])
df = pd.DataFrame(data, columns = ['seq', 'glycan', 'binding'])
df.to_csv(fs + 'all_binding_pairs.csv', index=False)
df.head()

Loaded data for 1393 proteins.


Unnamed: 0,seq,glycan,binding
0,AADSIPSISPTGIITPTPTQSGMVSNCNKFYDVHSNDGCSAIASSQ...,Fuc(a1-2)Gal,0.293462
1,AADSIPSISPTGIITPTPTQSGMVSNCNKFYDVHSNDGCSAIASSQ...,Fuc(a1-2)Gal(b1-3)GalNAc,-1.316793
2,AADSIPSISPTGIITPTPTQSGMVSNCNKFYDVHSNDGCSAIASSQ...,Fuc(a1-2)Gal(b1-3)GalNAc(a1-3)[Fuc(a1-2)]Gal(b...,-0.860744
3,AADSIPSISPTGIITPTPTQSGMVSNCNKFYDVHSNDGCSAIASSQ...,Fuc(a1-2)Gal(b1-3)GalNAc(a1-3)[Fuc(a1-2)]Gal(b...,-1.211838
4,AADSIPSISPTGIITPTPTQSGMVSNCNKFYDVHSNDGCSAIASSQ...,Fuc(a1-2)Gal(b1-3)GalNAc(b1-3)Gal,-0.335253


In [5]:
from sklearn.model_selection import train_test_split
import pandas as pd


df = pd.read_csv(fs + 'all_binding_pairs.csv')
print(df.head())

df['seq'] = df['seq'].map(old_dic)

df['glycan'] = df['glycan'].map(glycan_graph_dic)

df_train, df_test = train_test_split(df, test_size = 0.1, random_state = 1, shuffle = True)
df_train.head()

                                                 seq  \
0  AADSIPSISPTGIITPTPTQSGMVSNCNKFYDVHSNDGCSAIASSQ...   
1  AADSIPSISPTGIITPTPTQSGMVSNCNKFYDVHSNDGCSAIASSQ...   
2  AADSIPSISPTGIITPTPTQSGMVSNCNKFYDVHSNDGCSAIASSQ...   
3  AADSIPSISPTGIITPTPTQSGMVSNCNKFYDVHSNDGCSAIASSQ...   
4  AADSIPSISPTGIITPTPTQSGMVSNCNKFYDVHSNDGCSAIASSQ...   

                                              glycan   binding  
0                                       Fuc(a1-2)Gal  0.293462  
1                           Fuc(a1-2)Gal(b1-3)GalNAc -1.316793  
2  Fuc(a1-2)Gal(b1-3)GalNAc(a1-3)[Fuc(a1-2)]Gal(b... -0.860744  
3  Fuc(a1-2)Gal(b1-3)GalNAc(a1-3)[Fuc(a1-2)]Gal(b... -1.211838  
4                  Fuc(a1-2)Gal(b1-3)GalNAc(b1-3)Gal -0.335253  


Unnamed: 0,seq,glycan,binding
55048,"[0.04727237671613693, 0.25703755021095276, 0.0...","(0, 1, 2)",-0.143531
511556,"[0.07094207406044006, 0.2685469388961792, 0.02...","(0, 1, 2)",-0.160575
76083,"[0.03967776522040367, 0.25796017050743103, 0.0...","(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",-0.26808
557630,"[0.10139428824186325, 0.2674408257007599, 0.03...","(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)",-0.282433
207864,"[0.016354024410247803, 0.233301043510437, 0.12...","(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",-0.32405


In [13]:
from sklearn.model_selection import KFold, train_test_split
from torch_geometric.utils.convert import from_networkx
from torch_geometric.loader import DataLoader
from torch.utils.data import Subset

def create_folds(data, n_splits=10):

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    for train_idx, val_idx in kf.split(data):
        yield train_idx, val_idx

def dataset_to_graphs(glycan_graphs, labels, libr=None, label_type = torch.float):

    if libr is None:
      libr = lib
    data = [from_networkx(k) for k in glycan_graphs]

    for data_obj, label in zip(data, labels):
        data_obj.y = torch.tensor(label, dtype = label_type)

    return data


def dataset_to_dataloader(glycan_graphs, labels, libr=lib, batch_size=128, shuffle=True, drop_last = False, extra_feature=None, label_type=torch.float):

    if libr is None:
      libr = lib

    # Converting glycans and labels to PyTorch Geometric Data objects
    glycan_graphs = dataset_to_graphs(glycan_graphs, labels, libr = libr, label_type = label_type)

    if extra_feature is not None:
        for graph, feature in zip(glycan_graphs, extra_feature):
            graph.train_idx = torch.tensor(feature, dtype = torch.float)
    # Generating the dataloader from the data objects
    return DataLoader(glycan_graphs, batch_size = batch_size, shuffle = shuffle, drop_last = drop_last)


def get_dataloaders(dataset, train_idx, val_idx, batch_size=128):
    """
    dataset (pd.DataFrame)
    train_idx: array of indices to use as examples for training the model during cross validation
    val_idx: array of indices to use as tests for validating the model during cross validation
    batch_size: size of minibatches
    """
    train_subs = Subset(dataset, train_idx)
    val_subs = Subset(dataset, val_idx)

    train_loader = dataset_to_dataloader(train_subs.dataset.glycan.tolist(), train_subs.dataset.binding.tolist(), libr=lib, batch_size=128, shuffle=True, extra_feature=train_subs.dataset.seq.tolist(), label_type= torch.float)
    val_loader = dataset_to_dataloader(val_subs.dataset.glycan.tolist(), val_subs.dataset.binding.tolist(), libr=lib, batch_size=128, shuffle=False, extra_feature=val_subs.dataset.seq.tolist(), label_type= torch.float)

    return {'train': train_loader, 'val': val_loader}


##  LectinOracle Model Architecture Specification

In [14]:
#defining LectinOracle
import torch.nn as nn
from torch_geometric.nn import TopKPooling, GraphConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

def sigmoid_range(x, low, high):
    "Sigmoid function with range `(low, high)`"
    return torch.sigmoid(x) * (high - low) + low

class SigmoidRange(nn.Module):
    "Sigmoid module with range `(low, x_max)`"
    def __init__(self, low, high):
      super(SigmoidRange, self).__init__()
      self.low, self.high = low,high
    def forward(self, x): return sigmoid_range(x, self.low, self.high)

class LectinOracle(nn.Module):
    def __init__(self, input_size_glyco, hidden_size, num_classes, data_min,
               data_max, input_size_prot = 1280, n_layers = 1):
        super(LectinOracle,self).__init__()
        self.input_size_prot = input_size_prot # 1280 as defined by ESM1b
        self.input_size_glyco = input_size_glyco # 84 as defined by the glycan library len(lib)+1
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.n_layers = n_layers

        # these are the node embeddings for the glycan graph
        self.item_embedding = torch.nn.Embedding(num_embeddings = self.input_size_glyco, embedding_dim = self.hidden_size)

        # doesn't seem like this line is used
        self.glyco_encoder = nn.Embedding(self.input_size_glyco, self.hidden_size, padding_idx = self.input_size_glyco - 1)

        # 3 GCN layers with Pooling
        self.conv1 = GraphConv(self.hidden_size, self.hidden_size)
        self.pool1 = TopKPooling(self.hidden_size, ratio = 0.8)
        self.conv2 = GraphConv(self.hidden_size, self.hidden_size)
        self.pool2 = TopKPooling(self.hidden_size, ratio = 0.8)
        self.conv3 = GraphConv(self.hidden_size, self.hidden_size)
        self.pool3 = TopKPooling(self.hidden_size, ratio = 0.8)


        # Layers related to ESM1b
        self.prot_encoder1 = nn.Linear(self.input_size_prot, 400)
        self.prot_encoder2 = nn.Linear(400, 128)
        self.dp_prot1 = nn.Dropout(0.2)
        self.dp_prot2 = nn.Dropout(0.1)
        self.act_prot1 = torch.nn.LeakyReLU()
        self.act_prot2 = torch.nn.LeakyReLU()
        self.bn_prot1 = nn.BatchNorm1d(400)
        self.bn_prot2 = nn.BatchNorm1d(128)

        # Fully connected

        self.dp1 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(128+2*self.hidden_size, int(np.round(self.hidden_size/2)))
        self.fc2 = nn.Linear(int(np.round(self.hidden_size/2)), num_classes)
        self.bn1 = nn.BatchNorm1d(int(np.round(self.hidden_size/2)))
        self.act1 = torch.nn.LeakyReLU()

        self.sigmoid = SigmoidRange(data_min, data_max)


    def forward(self, prot, nodes, edge_index, batch, inference = False):
        # 1280 -> 400 -> 128  (two hidden layers, feedforward with dropout 0.2 on the first, then dropout 0.1 on the second)
        # batch norm applied after Leaky ReLU activation??
        embedded_prot = self.bn_prot1(self.act_prot1(self.dp_prot1(self.prot_encoder1(prot))))
        embedded_prot = self.bn_prot2(self.act_prot2(self.dp_prot2(self.prot_encoder2(embedded_prot))))

        # x is the node vector which originaly has len(lib) size but we use torch.nn.Embedding which
        x = self.item_embedding(nodes)
        x = x.squeeze(1)

        x = F.leaky_relu(self.conv1(x, edge_index))

        x, edge_index, _, batch, _, _= self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim = 1) # global mean pool, global max pool

        x = F.leaky_relu(self.conv2(x, edge_index))

        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim = 1)

        x = F.leaky_relu(self.conv3(x, edge_index))

        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim = 1)

        x = x1 + x2 + x3 # is this a bit like ResNet direct connections?

        h_n = torch.cat((embedded_prot, x), 1) # concatenated representation of lectin, glycan before fed into feed forward layers

        h_n = self.bn1(self.act1(self.fc1(h_n)))

        # some kind of ensemble averaging going on?
        #1
        x1 = self.fc2(self.dp1(h_n))
        #2
        x2 = self.fc2(self.dp1(h_n))
        #3
        x3 = self.fc2(self.dp1(h_n))
        #4
        x4 = self.fc2(self.dp1(h_n))
        #5
        x5 = self.fc2(self.dp1(h_n))
        #6
        x6 = self.fc2(self.dp1(h_n))
        #7
        x7 = self.fc2(self.dp1(h_n))
        #8
        x8 = self.fc2(self.dp1(h_n))

        out =  self.sigmoid(torch.mean(torch.stack([x1, x2, x3, x4, x5, x6, x7, x8]), dim = 0))

        if inference:
          return out, embedded_prot, x
        else:
          return out

## Model Fitting

In [None]:
from glycowork.ml.model_training import train_model

def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform(m.weight)


data_min = glycan_binding.iloc[:, :-2].min().min()
data_max = glycan_binding.iloc[:, :-2].max().max()

# Setting up 10 fold cross validation

n_splits = 10
for train_idx, val_idx in create_folds(df_train, n_splits):
    print(train_idx, val_idx)
    dataloaders = get_dataloaders(df_train, train_idx, val_idx)
    model = LectinOracle(input_size_prot = 1280, input_size_glyco = len(lib) + 1, hidden_size = 128, num_classes = 1, data_min = data_min, data_max = data_max)
    model.apply(init_weights)
    model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.0005, weight_decay = 0.0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 80)
    criterion = nn.MSELoss().cuda()
    model_ft = train_model(model, dataloaders, optimizer=optimizer, scheduler=scheduler, criterion=criterion, num_epochs = 100, patience = 20, mode = 'regression')
    break

[     0      1      2 ... 508179 508180 508181] [     6      7     30 ... 508129 508149 508171]


# Using Pretrained LectinOracle Model

In [None]:
import numpy as np
import pandas as pd
import torch
from glycowork.glycan_data.loader import lib, unwrap, build_custom_df, df_glycan, glycan_binding
from glycowork.ml.processing import dataset_to_dataloader
from glycowork.motif.tokenization import prot_to_coded
from glycowork.ml.models import prep_model

# Revised helper function that will enable us to retrieve both the prot representation of the model and the binding prediction
def get_multi_pred(prot, glycans, model, prot_dic,
                   background_correction = False, correction_df = None,
                   batch_size = 128, libr = None, flex = False, mode = 'rep'):
  """Inner function to actually get predictions for lectin-glycan binding from LectinOracle-type model\n
  | Arguments:
  | :-
  | prot (string): protein amino acid sequence
  | glycans (list): list of glycans in IUPACcondensed
  | model (PyTorch object): trained LectinOracle-type model
  | prot_dic (dictionary): dictionary of type protein sequence:ESM1b representation
  | background_correction (bool): whether to correct predictions for background; default:False
  | correction_df (dataframe): background prediction for (ideally) all provided glycans; default:None
  | batch_size (int): change to batch_size used during training; default:128
  | libr (dict): dictionary of form glycoletter:index
  | flex (bool): depends on whether you use LectinOracle (False) or LectinOracle_flex (True); default:False\n
  | Returns:
  | :-
  | Returns dataframe of glycan sequences and predicted binding to prot
  """
  if libr is None:
      libr = lib
  # Preparing dataset for PyTorch
  if flex:
      prot = prot_to_coded([prot])
      feature = prot * len(glycans)
  else:
      rep = prot_dic.get(prot, "new protein, no stored embedding")
      feature = [rep] * len(glycans)
  train_loader = dataset_to_dataloader(glycans, [0.99]*len(glycans),
                                         libr = libr, batch_size = batch_size,
                                         shuffle = False, extra_feature = feature)
  model = model.eval()
  res = []
  # Get predictions for each mini-batch
  for k in train_loader:
    x, y, edge_index, prot, batch = k.labels, k.y, k.edge_index, k.train_idx, k.batch
    x, y, edge_index, prot, batch = x.to(device), y.to(device), edge_index.to(device), prot.view(max(batch) + 1, -1).float().to(device), batch.to(device)
    pred, reps, _ = model(prot, x, edge_index, batch, inference=True)
    if mode == 'rep':
      res.append(reps)
    else:
      res.append(pred)
  res = unwrap([res[k].detach().cpu().numpy() for k in range(len(res))])

  return res


In [None]:
import torch
from torch_geometric.nn import GraphConv
from torch_geometric.nn import global_mean_pool as gap


def sigmoid_range(x, low, high):
    "Sigmoid function with range `(low, high)`"
    return torch.sigmoid(x) * (high - low) + low


class SigmoidRange(torch.nn.Module):
    "Sigmoid module with range `(low, x_max)`"

    def __init__(self, low, high):
      super(SigmoidRange, self).__init__()
      self.low, self.high = low, high

    def forward(self, x):
        return sigmoid_range(x, self.low, self.high)


class LectinOracle(nn.Module):
    def __init__(self, input_size_glyco, hidden_size, num_classes, data_min,
               data_max, input_size_prot = 1280, n_layers = 1):

        super(LectinOracle,self).__init__()
        self.input_size_prot = input_size_prot # 1280
        self.input_size_glyco = input_size_glyco # 84
        self.hidden_size = hidden_size # 128
        self.num_classes = num_classes # 1
        self.n_layers = n_layers # 1

        self.conv1 = GraphConv(self.hidden_size, self.hidden_size)
        self.pool1 = TopKPooling(self.hidden_size, ratio = 0.8)
        self.conv2 = GraphConv(self.hidden_size, self.hidden_size)
        self.pool2 = TopKPooling(self.hidden_size, ratio = 0.8)
        self.conv3 = GraphConv(self.hidden_size, self.hidden_size)
        self.pool3 = TopKPooling(self.hidden_size, ratio = 0.8)

        # Glycan node embedding 84 -> 128
        self.item_embedding = torch.nn.Embedding(num_embeddings = self.input_size_glyco, embedding_dim = self.hidden_size, padding_idx = self.input_size_glyco - 1)

        # ESM1b 1280 -> 400 -> 128
        self.prot_encoder1 = nn.Linear(self.input_size_prot, 400)
        self.prot_encoder2 = nn.Linear(400, 128)

        self.dp1 = nn.Dropout(0.5)
        self.dp_prot1 = nn.Dropout(0.2)
        self.dp_prot2 = nn.Dropout(0.1)
        self.fc1 = nn.Linear(128+2*self.hidden_size, int(np.round(self.hidden_size/2)))
        self.fc2 = nn.Linear(int(np.round(self.hidden_size/2)), num_classes)
        self.bn1 = nn.BatchNorm1d(int(np.round(self.hidden_size/2)))
        self.bn_prot1 = nn.BatchNorm1d(400)
        self.bn_prot2 = nn.BatchNorm1d(128)
        self.act1 = torch.nn.LeakyReLU()
        self.act_prot1 = torch.nn.LeakyReLU()
        self.act_prot2 = torch.nn.LeakyReLU()
        self.sigmoid = SigmoidRange(data_min, data_max)


    def forward(self, prot, nodes, edge_index, batch, inference = False):

        # ESM1b 1280 -> 400 -> 128
        embedded_prot = self.bn_prot1(self.act_prot1(self.dp_prot1(self.prot_encoder1(prot))))
        embedded_prot = self.bn_prot2(self.act_prot2(self.dp_prot2(self.prot_encoder2(embedded_prot))))

        # glycan graph
        x = self.item_embedding(nodes)
        x = x.squeeze(1)

        x = F.leaky_relu(self.conv1(x, edge_index))

        x, edge_index, _, batch, _, _= self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim = 1)

        x = F.leaky_relu(self.conv2(x, edge_index))

        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim = 1)

        x = F.leaky_relu(self.conv3(x, edge_index))

        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim = 1)

        # some resnet style architecture going on here? this is not just GCN layered on top of each other...
        x = x1 + x2 + x3

        # glycan graph embeddings + ESM1 embeddings concactenated this is the start of the fully connected layer
        h_n = torch.cat((embedded_prot, x), 1)

        h_n = self.bn1(self.act1(self.fc1(h_n)))

        # some kind of ensemble method going on here?
        #1
        x1 = self.fc2(self.dp1(h_n))
        #2
        x2 = self.fc2(self.dp1(h_n))
        #3
        x3 = self.fc2(self.dp1(h_n))
        #4
        x4 = self.fc2(self.dp1(h_n))
        #5
        x5 = self.fc2(self.dp1(h_n))
        #6
        x6 = self.fc2(self.dp1(h_n))
        #7
        x7 = self.fc2(self.dp1(h_n))
        #8
        x8 = self.fc2(self.dp1(h_n))

        out =  self.sigmoid(torch.mean(torch.stack([x1, x2, x3, x4, x5, x6, x7, x8]), dim = 0))

        if inference:
          return out, embedded_prot, x
        else:
          return out


# Choosing the right computing architecture
device = "cuda" if torch.cuda.is_available() else "cpu"


from glycowork.glycan_data.loader import lib

print(len(lib))

model = LectinOracle(len(lib), num_classes = 1)
model.load_state_dict(torch.load('models/glycowork_lectinoracle_600.pt', map_location = device))
model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters: {num_params}")

# LeOr_flex = prep_model('LectinOracle_flex', 1, trained = True)

2366
Number of parameters: 983089


In [None]:


# Choosing the right computing architecture
device = "cuda" if torch.cuda.is_available() else "cpu"

from glycowork.motif.processing import get_lib, expand_lib

lib = get_lib(glycan_binding.columns.unique().tolist()[:-2])
lib = expand_lib(lib, ['GlcNAcOS', 'GalOS', 'HexNAc'])

print(len(lib))

model = LectinOracle(input_size_prot = 1280, input_size_glyco = len(lib), hidden_size = 128,
          num_classes = 1, data_min = glycan_binding.iloc[:, :-2].min().min(), data_max = glycan_binding.iloc[:, :-2].max().max())

model.load_state_dict(torch.load('models/LectinOracle_565.pt', map_location=torch.device('cpu')))
# model_ft = model_ft.cuda()


num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters: {num_params}")


83


RuntimeError: Error(s) in loading state_dict for LectinOracle:
	Missing key(s) in state_dict: "conv1.lin_rel.weight", "conv1.lin_rel.bias", "conv1.lin_root.weight", "pool1.select.weight", "conv2.lin_rel.weight", "conv2.lin_rel.bias", "conv2.lin_root.weight", "pool2.select.weight", "conv3.lin_rel.weight", "conv3.lin_rel.bias", "conv3.lin_root.weight", "pool3.select.weight". 
	Unexpected key(s) in state_dict: "conv1.lin_l.weight", "conv1.lin_l.bias", "conv1.lin_r.weight", "pool1.weight", "conv2.lin_l.weight", "conv2.lin_l.bias", "conv2.lin_r.weight", "pool2.weight", "conv3.lin_l.weight", "conv3.lin_l.bias", "conv3.lin_r.weight", "pool3.weight". 
	size mismatch for item_embedding.weight: copying a param with shape torch.Size([84, 128]) from checkpoint, the shape in current model is torch.Size([83, 128]).
	size mismatch for glyco_encoder.weight: copying a param with shape torch.Size([84, 128]) from checkpoint, the shape in current model is torch.Size([83, 128]).

In [None]:
# making predictions for human glycan motifs
df_species = build_custom_df(df_glycan, 'df_species')
glyc = df_species.loc[df_species.Species == 'Homo_sapiens'].glycan.unique().tolist()

# fetching glycan motifs associated with pancreatic cancer according to curated df_disease
df_disease = build_custom_df(df_glycan, 'df_disease')
df_panc = df_disease[df_disease['disease_association'] == 'pancreatic_cancer']
biomarkers = df_panc.glycan.unique().tolist()
biomarkers[:10]

['Fuc(a1-3)[Gal(b1-4)]Glc6S-ol',
 'Fuc(a1-3)[Gal(b1-4)]GlcNAc6S(b1-4)Gal(b1-4)[Fuc(a1-3)]GlcNAc-ol',
 'Fuc(a1-3)[Gal(b1-4)]GlcNAc6S(b1-6)Man(a1-6)Man(b1-4)GlcNAc-ol',
 'Fuc(a1-3)[Gal(b1-4)]GlcNAc6S(b1-6)Man(a1-6)[Man(a1-3)]Man(b1-4)GlcNAc-ol',
 'Fuc4S(a1-3)[Gal(b1-4)]GlcNAc(b1-2)Man(a1-6)Man(b1-4)GlcNAc-ol',
 'GalNAc(a1-4)GlcA(b1-2)Glc-ol',
 'GalNAc(a1-4)GlcA(b1-3)Gal(b1-3)Gal(b1-4)Xyl-ol',
 'GalNAc(a1-4)GlcA(b1-3)GlcNAc-ol',
 'GalNAc(a1-4)GlcA(b1-4)GlcNAc-ol',
 'GlcA(b1-3)Glc-ol']

In [None]:
#getting learned protein represenations from LectinOracle
#this takes a long time to execute
prot_reps = {}
binding_pred = {}
for i, seq in enumerate(glycan_binding.target.tolist()[:2]):
  if(i%10 == 0):
    print(i)
  prot_reps[seq] = np.array(get_multi_pred(seq, [glyc[0]], LeOr, old_dic)).squeeze()
  binding_pred[seq] = unwrap(get_multi_pred(seq, glyc, LeOr, old_dic, mode='pred'))

# pd.DataFrame(prot_reps).T.to_csv('/content/drive/My Drive/LeOr_embeddings.csv')
# motif_pred = pd.DataFrame(binding_pred).T
# motif_pred.columns = glyc
# motif_pred.to_csv('/content/drive/My Drive/human_glycan_motif_predicted.csv')

In [None]:
len(lib)

2366

['3-Anhydro-Gal(a1-3)Gal(b1-4)3-Anhydro-Gal(a1-3)Gal4S',
 '3-Anhydro-Gal(a1-3)Gal4S(b1-4)3-Anhydro-Gal(a1-3)Gal4S',
 '3-Anhydro-Gal(a1-3)Gal4S(b1-4)3-Anhydro-Gal(a1-3)Gal4S(b1-4)3-Anhydro-Gal(a1-3)Gal4S',
 '3-Anhydro-Gal(a1-3)Gal4S(b1-4)3-Anhydro-Gal(a1-3)Gal4S(b1-4)3-Anhydro-Gal(a1-3)Gal4S(b1-4)3-Anhydro-Gal(a1-3)Gal4S',
 '3-Anhydro-Gal(a1-3)Gal4S(b1-4)3-Anhydro-Gal2S(a1-3)Gal4S(b1-4)3-Anhydro-Gal(a1-3)Gal4S',
 '3dGal(b1-3)[Fuc(a1-4)]Glc',
 '3dGal(b1-4)Glc',
 '4d8dNeu5Ac(a2-3)Gal(b1-4)Glc',
 '4dNeu5Ac(a2-3)Gal(b1-4)Glc',
 '7dNeu5Ac(a2-3)Gal(b1-4)Glc',
 '8dNeu5Ac(a2-3)Gal(b1-4)Glc',
 '9dNeu5Ac(a2-3)Gal(b1-4)Glc',
 'Ara(a1-5)Ara(a1-5)Ara(a1-5)Ara(a1-5)Ara(a1-5)Ara',
 'Ara(a1-5)Ara(a1-5)Ara(a1-5)Ara(a1-5)Ara(a1-5)Ara(a1-5)Ara',
 'Fuc(a1-2)Gal',
 'Fuc(a1-2)Gal(b1-3)GalNAc',
 'Fuc(a1-2)Gal(b1-3)GalNAc(a1-3)[Fuc(a1-2)]Gal(b1-4)Glc',
 'Fuc(a1-2)Gal(b1-3)GalNAc(a1-3)[Fuc(a1-2)]Gal(b1-4)GlcNAc',
 'Fuc(a1-2)Gal(b1-3)GalNAc(b1-3)Gal',
 'Fuc(a1-2)Gal(b1-3)GalNAc(b1-3)Gal(a1-4)Gal(b1-4)Glc',
 'Fuc