In [1]:
import random
import transformers
from transformers import (AutoTokenizer, AutoModel, PreTrainedTokenizerFast,
                          AdamW, BertForSequenceClassification, AutoModelForSequenceClassification)
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from typing import List
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import logging
import os
from functools import lru_cache
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from torch.utils.data import DataLoader, Dataset
from argparse import Namespace
from sklearn.metrics import classification_report, multilabel_confusion_matrix
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import torchmetrics
from torchmetrics import F1Score as f1
from torchmetrics.functional import accuracy, auroc
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pylab import rcParams
from matplotlib import rc
import pytorch_lightning.callbacks as pt_callbacks
from torchmetrics.functional import accuracy
from captum.attr import (LayerConductance, IntegratedGradients, LayerIntegratedGradients,
                         configure_interpretable_embedding_layer,
                         remove_interpretable_embedding_layer)
from captum.attr import visualization as viz


# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [2]:
# Model parameters shared by every cell below.
# Name/path handed to AutoModelForSequenceClassification.from_pretrained in the model classes.
model_name = 'glycosyltransferase_roberta_350000_steps'
# Fast tokenizer loaded from a local tokenizer.json file.
tokenizer= PreTrainedTokenizerFast(tokenizer_file="pks_tokenizer/tokenizer.json")
# Register '[PAD]' as the tokenizer's padding token.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Passed as max_length to every tokenizer call in the prediction helpers.
maxlength = 200

# Define the LightningModule wrappers around the pretrained sequence classifier

class all_GT_Model(pl.LightningModule):
  """LightningModule wrapping the pretrained ``model_name`` sequence classifier.

  The wrapped model is created with ``n_classes`` output labels and is asked to
  return hidden states and attentions as a plain tuple (``return_dict=False``).
  """

  def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
    """Build the wrapped classifier and store the training-schedule settings.

    Args:
      n_classes: number of labels for the classification head.
      n_training_steps: stored on the instance; not used in this class.
      n_warmup_steps: stored on the instance; not used in this class.
    """
    super().__init__()
    pretrained_options = dict(
        num_labels=n_classes,
        output_hidden_states=True,
        output_attentions=True,
        return_dict=False,
    )
    self.bert = AutoModelForSequenceClassification.from_pretrained(model_name, **pretrained_options)
    # Dropout layer is created but not applied in forward().
    self.dropout = nn.Dropout(0.2)
    self.n_training_steps = n_training_steps
    self.n_warmup_steps = n_warmup_steps
    self.criterion = nn.BCELoss()

  def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
    """Return ``(loss, scores)`` for a batch.

    ``scores`` is the element-wise sigmoid of the first item of the wrapped
    model's output tuple. ``loss`` is the BCE between ``scores`` and ``labels``
    when labels are given, otherwise the integer 0. ``token_type_ids`` is
    accepted but ignored.
    """
    raw_scores = self.bert(input_ids, attention_mask=attention_mask)[0]
    scores = torch.sigmoid(raw_scores)
    if labels is None:
      return 0, scores
    return self.criterion(scores, labels), scores

class C_GT_Model(pl.LightningModule):
  """Second classifier wrapper built on the same pretrained ``model_name`` backbone.

  Structurally this class mirrors the first-stage wrapper: a sequence
  classification model with ``n_classes`` labels whose first output is passed
  through a sigmoid and scored with binary cross-entropy.
  """

  def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
    """Create the wrapped classifier.

    Args:
      n_classes: number of labels for the classification head.
      n_training_steps: kept as an attribute only.
      n_warmup_steps: kept as an attribute only.
    """
    super().__init__()
    self.bert = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=n_classes,
        output_hidden_states=True,
        output_attentions=True,
        return_dict=False,
    )
    # Instantiated for parity with the original definition; forward() never calls it.
    self.dropout = nn.Dropout(0.2)
    self.n_training_steps = n_training_steps
    self.n_warmup_steps = n_warmup_steps
    self.criterion = nn.BCELoss()

  def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
    """Compute sigmoid scores and, when ``labels`` is provided, the BCE loss.

    Returns:
      A ``(loss, scores)`` tuple; ``loss`` is the integer 0 when ``labels`` is
      None. ``token_type_ids`` is not forwarded to the wrapped model.
    """
    first_output = self.bert(input_ids, attention_mask=attention_mask)[0]
    scores = torch.sigmoid(first_output)
    loss = self.criterion(scores, labels) if labels is not None else 0
    return loss, scores

In [3]:
# Load the trained first-stage classifier from its Lightning checkpoint.
# n_classes=2 is forwarded to all_GT_Model.__init__.
# NOTE(review): the "newly initialized" classifier warning printed below is emitted by
# from_pretrained inside __init__; presumably the checkpoint weights then replace them — confirm.
agt_model = all_GT_Model.load_from_checkpoint(checkpoint_path="./GTBERT-gtclassifer.ckpt", n_classes=2)
agt_model.to(device)
# Inference only: switch to eval mode and clear any stored gradients.
agt_model.eval()
agt_model.zero_grad()

Some weights of the model checkpoint at glycosyltransferase_roberta_350000_steps were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at glycosyltransferase_roberta_350000_steps and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out

In [4]:
# Load the trained second classifier from its Lightning checkpoint.
# n_classes=2 is forwarded to C_GT_Model.__init__.
cgt_model = C_GT_Model.load_from_checkpoint(checkpoint_path="./CGTBERT-cgt-like.ckpt", n_classes=2)
cgt_model.to(device)
# Inference only: switch to eval mode and clear any stored gradients.
cgt_model.eval()
cgt_model.zero_grad()

Some weights of the model checkpoint at glycosyltransferase_roberta_350000_steps were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at glycosyltransferase_roberta_350000_steps and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out

In [3]:
def predict_all_gt(sequence):
    """Score one sequence with the first-stage classifier ``agt_model``.

    Args:
        sequence: sequence string to tokenize (assumes a single sequence, not a batch).

    Returns:
        The first value of the flattened sigmoid output as a numpy scalar,
        i.e. the score at label index 0 for the first input.
    """
    encoding = tokenizer(
        sequence,
        max_length=maxlength,
        # Explicit truncation: passing max_length alone only truncates through the
        # tokenizer's backward-compatibility fallback, which also logs a warning.
        truncation=True,
        return_tensors='pt',
    )
    encoding.to(device)
    # The result is detached right below, so no autograd graph is needed.
    with torch.no_grad():
        _, test_prediction = agt_model(encoding["input_ids"], encoding["attention_mask"], encoding["token_type_ids"])
    test_prediction = test_prediction.detach().cpu().flatten().numpy()
    return test_prediction[0]

def predict_c_gt(sequence):
    """Score one sequence with the second classifier ``cgt_model``.

    Args:
        sequence: sequence string to tokenize (assumes a single sequence, not a batch).

    Returns:
        The first value of the flattened sigmoid output as a numpy scalar,
        i.e. the score at label index 0 for the first input.
    """
    encoding = tokenizer(
        sequence,
        max_length=maxlength,
        # Explicit truncation: passing max_length alone only truncates through the
        # tokenizer's backward-compatibility fallback, which also logs a warning.
        truncation=True,
        return_tensors='pt',
    )
    encoding.to(device)
    # The result is detached right below, so no autograd graph is needed.
    with torch.no_grad():
        _, test_prediction = cgt_model(encoding["input_ids"], encoding["attention_mask"], encoding["token_type_ids"])
    test_prediction = test_prediction.detach().cpu().flatten().numpy()
    return test_prediction[0]

In [4]:
def strip_fasta(fasta_path):
    """Dump the sequences of a FASTA file, one per line, into ``<fasta_path>.csv``.

    Record headers are dropped and no header row is written. All records are
    read before the output file is opened.

    Args:
        fasta_path: path of the FASTA file; the output path is this string
            with ".csv" appended.
    """
    from Bio import SeqIO

    with open(fasta_path) as handle:
        sequences = [record.seq for record in SeqIO.parse(handle, "fasta")]

    with open(fasta_path+".csv", 'w') as f:
        f.writelines(f"{sequence}\n" for sequence in sequences)

In [9]:
# Writes ./Full_ferox.fasta.transdecoder.pep.csv (one sequence per line) for the next cell to read.
strip_fasta("./Full_ferox.fasta.transdecoder.pep")

In [10]:
# NOTE: this rebinds `tqdm`, which the import cell took from tqdm.auto, to the plain tqdm class.
from tqdm import tqdm
# The file written by strip_fasta has no header row, so the single column is named here.
df = pd.read_csv('Full_ferox.fasta.transdecoder.pep.csv', names=["sequence"])
df

Unnamed: 0,sequence
0,SLIRKLTSLRRCTPPLAKALTFDKVKHCLNIIEMIQGGGWRGREVG...
1,AQSLINRGKILRILKDNLLKAQQRMERMANLHRTNKSFDIGDWVYL...
2,RPHPQSRAALVTEPSSSPPFPLHRDPRARPLSPRSSRSRPGLISSP...
3,FSFLKNLSLSIHGRQRILHSHLKSNPEILSSSRIIAFLSTQMIQDH...
4,KSSSRMAFYLGDNAITWASHKQKTVALSSCEAEFVTTSGAACQAIW...
...,...
126493,MGAGRTTETVHARERSLQSNMSTNSSVSARNLPKCNLAGVIFGCTH...
126494,KFPQFSLPIKNPEEKTERERERERERERPSLKSESMVLLLLYHTFV...
126495,MESKSTRGRFTLLKQSSMAPERNNGSDAVDATGFEEENEVPANIRL...
126496,PSTHTLSTVEVLYSWEIIPYVKRSALVLCVSELVMEVSGHLQMSAM...


In [27]:
# Keep sequences whose first-stage score exceeds 0.9.
# predict_all_gt reads the global `agt_model`, so the checkpoint-loading cell must have run in
# the current kernel; the NameError recorded below comes from a session where it had not.
allgt=[ i for i in tqdm(df.sequence) if predict_all_gt(str(i)) > 0.9]

  0%|                                                                                         | 0/56180 [00:00<?, ?it/s]

NameError: name 'agt_model' is not defined

In [10]:
# Keep first-stage hits of at least 400 characters ...
filtered = [i for i in allgt if len(i)>=400]
# ... that also start with "M".
filtered2 = [i for i in filtered if i[0]=="M"]
filtered2

['MECARAIAGRDNARVQQLLWMLNELSSPYGDTEQKVASYFLQALFARLTSSGPRTLRTLLSASDRTASFDSTRRTALKFQELSPWSSFGHVACNGAILESFLDGVAAAASSSSSSASSSTSPQRLHILDLSNTFCTQWPTLLEALATRSADDTPHLTITTVVSSPSASSSSAQRVMKEIGARMEKFARLMGVPFKFNIVHHAADLSTLDLDRLNLAADSSLLAINCVNALHGISPTGGRRDAFIASLRRLRPRILTVVEEEADMGSDEEEEGFVRGFGEALRFFSSYLESLEESLPKTSNERLALERAAGRAVMDVVSCPASESRERREKAGGWSARMRAAGFMPTNFSDDVADDVRALLRRYREGWTMTTAAEEAAGSSAGIFLAWKDEPVVWASAWRLAHV*',
 'MTNLNGHAPLMANGNGHAVAPFANGSNGHAAAAAPLHVGVLPAAGPGHHDHFLELSKRLALLGCRISFFSTDSELQLLYSPQLPSNLHLVPLSPSPTTSASPNAAYDSISVSLAAFLSSSTPPVDWLLCDYAPYWAPPIARSLNVKSAFISIYSAAAMAFLGPPQDLILNKKWPIKLQSDGLTVVPEWVHFYTPIAFKPHEVPRLPRNMEDSPDKPYPPVNVTDFYRFGVTIRDTDAVAIRSFTDLEQPQWFGLLQDRIFKKPVLPLGLFPPTPTPSPSDDQWSRQLAWLDRRPRGSVVYAAHGIGVILSKEEVGEIARGLELSGLPFVWALRWAEGEASPLPEGFEGRVGDRGIVCKGWVPQVRFLAHPAIGGFLTNCGWSSVGEALQFGVPLILLSMVRDQKTNEKVLADRGVGKEVPREGQDRTFTAEGIAATVRMVVVEKEGELYRAKAKEMMGSIGDGKAQNGQIKVFVEYLKDHKPVRI*',
 'MAIRTAAAIGGYLAQNFAASAGIRCAPCRVFHESTARRPLYLFANQRLDDDPSLTTSRARDRDKHRSAKCDRSRPTVQDRSRPHVSAHAPASMSRYST

In [11]:
len(filtered2)

95

In [12]:
# Re-score the first-stage hits with the second classifier; keep scores of at least 0.5.
cgt1=[ i for i in tqdm(allgt) if predict_c_gt(str(i)) >= 0.5]
cgt1

100%|█████████████████████████████████████████████████████████████████████████████████| 619/619 [00:09<00:00, 66.22it/s]


['GKIFKEELLPVRLQKNPQHLPLLKSQKMYAHNHQLSSASLLKNTFANISPHKAETQTETERAMSSDVPKRQSPPHIALLPSAGMGHFAPFCRLAAALSSRGCDVTFITTQPTVSSAESLRTSDLLSSFPSIR']

In [13]:
# Length filter for second-stage hits: 450 characters here, versus 400 for the first-stage hits.
filtered_cgt = [i for i in cgt1 if len(i)>=450]
filtered_cgt

[]

In [15]:
# One-off export of the loaded models as whole-object pickles; "./agt" and "./cgt" are the
# files read back by the torch.load cell. Presumably left disabled to avoid overwriting them.
#torch.save(agt_model, "./agt")
#torch.save(cgt_model, "./cgt")

In [5]:
# Reload the two classifiers from whole-object pickles.
# SECURITY: torch.load unpickles arbitrary Python objects; only load files you created or trust.
# NOTE(review): unpickling presumably needs the all_GT_Model / C_GT_Model class definitions in
# scope, so run the class-definition cell first — confirm.
# map_location lets pickles saved on a GPU load on a CPU-only host, matching the fallback in `device`.
agt = torch.load("./agt", map_location=device)
agt.eval()
cgt = torch.load("./cgt", map_location=device)
cgt.eval()

C_GT_Model(
  (bert): RobertaForSequenceClassification(
    (roberta): RobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(52000, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0): RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): RobertaSelfOutput(
                (dense): Linear(in_fe

In [6]:
class cgt2layers(pl.LightningModule):
  """Two-stage cascade over the globally loaded ``agt`` and ``cgt`` models.

  The first model gates the second: the second model's score is returned only
  when the first model's score reaches 0.50, otherwise the result is 0.
  Inference only, and only the first item of a batch is scored.
  """

  def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
    """Wire up the cascade.

    The constructor arguments are accepted but not used; the stages are taken
    from the module-level ``agt`` and ``cgt`` objects.
    """
    super().__init__()
    self.layer1 = agt
    self.layer2 = cgt
    self.criterion = nn.BCELoss()

  def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
    """Return the gated second-stage score as a numpy scalar.

    ``[1][0][0]`` selects the second element of each stage's return value,
    then its first row and first column (assumes each stage returns
    ``(loss, scores)`` with scores indexed as [batch, label] — TODO confirm).
    ``token_type_ids`` and ``labels`` are ignored.
    """
    # The result is converted to numpy, so gradients can never flow out of this call;
    # skip building the autograd graph.
    with torch.no_grad():
      gate_score = self.layer1(input_ids, attention_mask=attention_mask)[1][0][0]
      if gate_score >= 0.50:
        output = self.layer2(input_ids, attention_mask=attention_mask)[1][0][0].squeeze()
      else:
        # Float zero keeps the return dtype consistent with the accepted branch.
        output = torch.tensor(0.0)
    return output.detach().cpu().flatten().numpy()[0]


In [7]:
# Build the two-stage cascade; the positional 2 fills n_classes, which cgt2layers.__init__ does not use.
model = cgt2layers(2)
model = model.to(device)

In [8]:
def model_predict(sequence):
    """Score one sequence with the two-stage cascade ``model``.

    Args:
        sequence: sequence string to tokenize (assumes a single sequence, not a batch).

    Returns:
        Whatever ``model`` returns for the encoded sequence (a numpy scalar for
        the cgt2layers cascade defined in this notebook).
    """
    encoding = tokenizer(
        sequence,
        max_length=maxlength,
        # Explicit truncation: passing max_length alone only truncates through the
        # tokenizer's backward-compatibility fallback, which also logs a warning.
        truncation=True,
        return_tensors='pt',
    )
    encoding.to(device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        pred = model(encoding["input_ids"], encoding["attention_mask"], encoding["token_type_ids"])
    return pred

In [13]:
# Screen every sequence in df with the cascade; keep scores of at least 0.7.
# Long-running: the recorded run took about 33 minutes for 126498 sequences.
potential_cgts=[ i for i in tqdm(df.sequence) if model_predict(str(i)) >= 0.7]

100%|███████████████████████████████████████████████████████████████████████████| 126498/126498 [33:09<00:00, 63.59it/s]


In [14]:
potential_cgts

['MGHLTPFCRLAALLSSRGGGDLDVSFITAEPTVSLKEYLQIRDLLSSFPSIKSHRFKVPELSPSQFPSSPDPFFQQWESIRRCAHLLPPLLASATALIIDTASASAVLPVAKQLNIPAYILFTSSASMLSLVAHFPSHVISSTSPSAPLMSNILIPGLKHPVPAAWVPPPLHVPGHIFSTLTVDNGRCLPEADGVIINTFDALEPEVLAALNGGRVAHGLPRVFAVGPLVAPPPPQQMEETGGGDYPIAWLDRQRDKSVVYVSFGSRTAMSPEQIRELAAGLERSGCSFLWVIKTKKVDRDEEETDLGALLGEGYVERVKGRGLVVNGWVEQEEILRHRAVGGFVSHCGWNSVTEAAVA',
 'MLPFLSHPVEPFLSASCSASLSAAAYPPSPFRYLFLRRFLILPNAPSTAYFPKISLRRLLDLPHLTPLIVPPFRFADYSSFLSSLSTVPSSEPFLPIRDPPHGYSSS*',
 'DDSPGATATTNDIDPGDELRIFRVRLQQIRHALFDMATDHPPYLHPHTTSDDHHHHDHSCTTPSSTYTDVKGFAEPPTASWAIVIIRYLSTTSTRITLAGLAGAATTEAAKAAWGDLA',
 'IFSIAATLTEFETLDLGSQTLDPGCPALSRTFQSLVGLAALIRGHRNPTASQAASLEPPLLSHRLLPWSLPSCRTNLTTLPTTFLDADRLLPKTQRLRLKLPVVAQV*',
 'PFSYTPLVSATVTLLKRWSSFSSFTPLALSPRALYSTLFSKPPSLLHHLSIPQLATGRSSFSIPHLHATTLRTVALHTLVHVEAPQHFHLSLPDPTYVAVHNTCTIPSDL*',
 'MAHRRALPTPPSAQSWVELAPWWPDNQMVQRLRGHRVGWGTRVVGTPHVVSHLWSHSDRSSAGVSGLAKVVVTGLQQTPALSLTALHLYLGVTVPRAYIGRWTLYELN*']

In [71]:
# NOTE(review): this reads a .fasta file directly as a one-column CSV, without strip_fasta.
# If the file has ">" header lines or wrapped sequences, those lines become rows too — confirm
# the file really holds one bare sequence per line.
df2 = pd.read_csv('train_glycosyltransferase.fasta', names=["sequence"])

In [72]:
# Training-set sequences the cascade scores at 0.9 or higher.
all_cgts=[ i for i in df2.sequence if model_predict(str(i)) >= 0.9]
len(all_cgts)

1890

In [73]:
len(all_cgts)

1890

In [75]:
# Same screen with a looser cutoff of 0.5.
all_cgts2=[ i for i in df2.sequence if model_predict(str(i)) >= 0.5]
len(all_cgts2)

4755

In [74]:
# Export the current contents of all_cgts as FASTA, with headers >sequence_0, >sequence_1, ...
# enumerate replaces the hand-maintained counter.
with open("all_potential_CGT_predicted.fasta", "w") as f:
    for n, line in enumerate(all_cgts):
        f.write(f">sequence_{n}\n{line}\n")

In [33]:
all_cgts

['MMGDLTTSFPATTLTTNDQPHVVVCSGAGMGHLTPFLNLASALSSAPYNCKVTLLIVIPLITDAESHHISSFFSSHPTIHRLDFHVNLPAPKPNVDPFFLRYKSISDSAHRLPVHLSALSPPISAVFSDFLFTQGLNTTLPHLPNYTFTTTSARFFTLMSYVPHLAKSSSSSPVEIPGLEPFPTDNIPPPFFNPEHIFTSFTISNAKYFSLSKGILVNTFDSFEPETLSALNSGDTLSDLPPVIPIGPLNELEHNKQEELLPWLDQQPEKSVLYVSFGNRTAMSSDQILELGMGLERSDCRFIWVVKTSKIDKDDKSELRKLFGEELYLKLSEKGKLVKWVNQTEILGHTAVGGFLSHCGWNSVMEAARRGVPILAWPQHGDQRENAWVVEKAGLGVWEREWASGIQAAIVEKVKMIMGNNDLRKSAMKVGEEAKRACDVGGSSATALMNIIGSLKR',
 'MMGDLTTSFPATTLTTNEQPHVVVCSGAGMGHLIPFLNLASTLSSAPYRCKVTLLIVIPLITDAESHHISSFFSSHPTIHRLDFHVNLPAPKPNVDPFFLRYKSISDSAHRLPVHLSTLAPPISAVFSDFLFTQGLNTTLPHLPNYTFTTTSARFFTLMSYVPHLAKSSSSSPVEIPGLEPFPTDNIPPPFFNPDHIFTSFTISNANYLSLSKGIIVNTFDSFEPETLSALNSGDSLPDLPPVIPIGPLNELEHNKQEELLPWLDQQPEKSVLYVSFGNRTAMSSDQILELGMGLERSDCRFIWVVKTSKIDKDDKSELRKLFGEELYVKLSEKGKLVKWVNQTEILGHTAVGGFLSHCGWNSVMEAARRGVPILAWPQHGDQRENAWVVEKAGLGVWEREWSSGIQVAIVEKVKMIMGNNDLRNSAVRVGEEAKRACDVGGSSATALMNIIGSLKR',
 'MAEAGVLLAGGGTAGHVNPLLAVADELRRRHPHGRFGVLGTAEGLEARLVPEHGYDLHVVPRVPLPRRPTPDWL

In [59]:
# Writes ./ginger.fasta.csv (one sequence per line) for the next cell to read.
strip_fasta("./ginger.fasta")

In [60]:
# NOTE: rebinds `df`, which held the Full_ferox sequences in earlier cells.
df = pd.read_csv('ginger.fasta.csv', names=["sequence"])
df

Unnamed: 0,sequence
0,MAAYQKPKAIFMAFGTKGDVFPIAAIASAFACDQKQYQVVLITHRA...
1,MEKRLKIWIYKEGEPPIVHGGPAASIYAIEGHFISEMEREANPFVA...
2,MAGMSNASSFVPLLLLLLLILLLILSFSFFDQQRPYNFKAPLAVLR...
3,MGASSVATTRRPSKSSFSYLSRVIVFLAAAVALLVVFYGVYTTYSS...
4,MSPSFDWWGEDAHRGTSVVVKMENPNWSISEISSPADNDEDYSGEG...
...,...
1225,MAQKKGQRGATSVAADDSEVVSRVPLQAVLLADSFNLRFRPITLER...
1226,MAEAGVQIQAFQTLTLLPSNIVVKLNTESAHVYEASSILTGGRATA...
1227,MSTSHSALCSTLSSRPWTFSTPSTLRTAPSWLSARTLPSTSTLSSS...
1228,MVSLVEFVCQYPKCNRSVAYFPLSFTSMSILVQALLQNATESMKES...


In [65]:
# Ginger sequences the cascade scores at 0.8 or higher.
# NOTE: rebinds `all_cgts`, which held the training-set hits in earlier cells.
all_cgts=[ i for i in df.sequence if model_predict(str(i)) >= 0.8]
len(all_cgts)

10

In [69]:
all_cgts

['MTEPQPYADRRRTPHIVFLASSGMGHLNPFLRLAAQLSARDCRVSLLTTHPTVSAAEERHIDAFFTAFPGVRRLDLHLPLLDPSELNCADPFTLRFESIRRSAHLLPRVLADASPPVSAAIVDISVASSYLPAVAEMGIPSYILFISCAAMLALCAYFPTYLATKNTFGVGDIDIPGVWTVPKSSVPVVLHDQNQLFAKQFLENGQALPKANGILVNTLHAWEPEALAALNEGKVTPDMPPVLAIGPLLPVPKMSVEGESVSAPLPWLDRQPDRSVVYVSFGNRTGMQAEQIRELGIGLKRSGLRFLWVVKSKVVDRAEVEVGLEELLGEEYLAKIKDRGMVVKDWVEQAEILRHRAVGGFLSHCGWNSVVEAALYGVRILGWPMGGDQRVSAAAMARGGLGIWVEGWSWQAEEPIKAEEISERLKDLMAVEGLRASATTMGAAAADVGGTSYQVLTHFIESLTL',
 'MEAKGFDKSDSGADAPHVVLLATPGAGHLIPLAEFAKLLVDRHGFSVTIITYTVFASKAQDALLSSLAPSVASLVLPAVPFDDLPPDARIETIISVVAVRSVPTLRAELSRLQASFNVVALVGDLFATEAIAVARDLGVPPYLFFPSNLLTLSLILHLPELDASVACEYRDLPEPLRLPGCVPIPGPELLHPIQDRSNEVYKWILRHGRRYRDAEGIIGNTFEAFEPDAAKVLKQGQPPLYLVGPLTQSRPAGAEEDAECLRWLDKQPVGSVLFVSFGSGGTLTTAQLAELALGLEMSGQRFLWVVRSPSDGGNASANYFTTQSKTDPFPFLPAGFVERTREVGLLIPSWAPQVSVLGHAATGGFLSHCGWNSTLESVKAGVPIIAWPLFAEQRQNAVMLAEGAKIALRLRAAEDGLVPREEVARVVKELMEGQEGKSARQRVSELQEAALRCLEEGGAALEALDEVVNRWKSRN',
 'MEAKGFDKPDSGADAPHVVLLATPGAGHLIPLAEFAKLLVDRHGFSVT

In [70]:
# Interactive spot check: blocks waiting for a sequence typed at the prompt, so this cell
# cannot complete unattended under "Run All".
model_predict(input())

MEAKGFDKPDSGADAPHVVLLATPGAGHLIPLAEFAKLLVDRHGFSVTIITYTVFASKTQDALLSSLAPSVASLVLPAVPFDDLPPDARIETIMSVVAVRSVPTLRAELSRLQASFNVVALVGDLFATEAIAVARDLGVLPYLFFPSNLLTLSLILHLPELDASVACEYRDLPEPLRLPGCVPIPGPELLHPIQDRSNEVYKWILRHGRRYRDAEGIIANTFEAFEPDAAKVLKQGQPPLYLVGPLTQSRPAGAEEDAECLRWLDKQPVGSVLFVSFGSGGTLTTAQLAELALGLEMSGQRFLWVVRSPSDGGNASANYFTTQSKADPFPFLPAGFVERTREVGLLIPSWAPQVAVLGHAATGGFLSHCGWNSTLESVKAGVPIIAWPLFAEQRQNAVMLAEGAKIALRLRAAEDGLVPREEVARVVKELMEGQEGKSARQRVSELQEAALRCLEEGGAALEALDEVVNRWKSRN


0.8810906