## The exmple of using VirulentHunter

In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from Bio import SeqIO

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel

#### 1. 

In [8]:
# 1. Download to your local or from Hugging Face Hub 
BASE_MODEL_PATH = "/mnt/data/cs/ESM2_Model/esm2_t30_150M_UR50D" 
VirulentHunter_Binary_MODEL_PATH = 'models/binary'
VirulentHunter_Multi_label_Model_PATH = 'models/multi-label'
LABEL_INFO_PATH = 'data/labels.csv'

In [6]:
# 2. Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
binary_model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL_PATH)
binary_model       = PeftModel.from_pretrained(binary_model, VirulentHunter_Binary_MODEL_PATH)
binary_model.eval()

category_model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL_PATH,
                                                                        num_labels=14,)
category_model     = PeftModel.from_pretrained(category_model, VirulentHunter_Multi_label_Model_PATH)
category_model.eval()
    
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  
binary_model.to(device)
category_model.to(device)

  return self.fget.__get__(instance, owner)()
Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at /mnt/data/cs/ESM2_Model/esm2_t30_150M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at /mnt/data/cs/ESM2_Model/esm2_t30_150M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): EsmForSequenceClassification(
      (esm): EsmModel(
        (embeddings): EsmEmbeddings(
          (word_embeddings): Embedding(33, 640, padding_idx=1)
          (dropout): Dropout(p=0.0, inplace=False)
          (position_embeddings): Embedding(1026, 640, padding_idx=1)
        )
        (encoder): EsmEncoder(
          (layer): ModuleList(
            (0-29): 30 x EsmLayer(
              (attention): EsmAttention(
                (self): EsmSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=640, out_features=640, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.2, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=640, out_features=16, bias=False)
                    )
                    (lora_B): ModuleDict(
         

In [12]:
label_info = pd.read_csv(LABEL_INFO_PATH)

In [20]:
# 3. define the predict function
def predict(input_fasta, output_path, max_length=2000):
    print(f'Read fasta from {input_fasta}')
    sequences_dict = {}
    for record in SeqIO.parse(input_fasta, "fasta"):
        sequences_dict[record.id] = record.seq
    
    binary_logits = {}     
    for seq_id, sequence in tqdm(sequences_dict.items()):       
        encoding = tokenizer(str(sequence), truncation=True, return_tensors='pt', 
                                  padding='max_length', max_length=max_length)
        encoding = encoding.to(device)
        with torch.no_grad():   # 
            outputs = binary_model(**encoding) # 
            logits = outputs.logits        # 
            binary_logits[seq_id] = np.round(torch.nn.functional.softmax(logits,dim=-1).cpu().tolist()[0],3)

    prob_df = pd.DataFrame(binary_logits.values(), binary_logits.keys(), columns=['no_vf_prob', 'vf_prob']) 
    prob_df['id'] = prob_df.index
    prob_df = prob_df[['id', 'vf_prob']]
    prob_df.reset_index(drop=True, inplace=True)

    for cat in label_info['category'].unique():
        prob_df[cat] = 0.0
    
    for seq_id, sequence in tqdm(sequences_dict.items()):
        vf_prob = prob_df.loc[prob_df['id']==seq_id, 'vf_prob'].values[0]
        if vf_prob >=  0.5:
            encoding = tokenizer(str(sequence), truncation=True, return_tensors='pt', 
                                  padding='max_length', max_length=max_length)
            encoding = encoding.to(device)
            with torch.no_grad():   
                outputs = category_model(**encoding) 
                logits = outputs.logits        
                probs = torch.nn.functional.sigmoid(logits)
                probs = np.round(probs.cpu().numpy().squeeze().tolist(),3).tolist()
                prob_df.loc[prob_df['id']==seq_id, 'Exotoxin':'Regulation'] = probs

    prob_df.to_csv(os.path.join(output_path, f'predict_results.csv'), sep=',')
    return prob_df

In [15]:
input_fasta_file = 'data/test.fasta'
output_path = 'results/'

In [21]:
results = predict(input_fasta_file, output_path)

Read fasta from data/test.fasta


100%|██████████| 22/22 [00:03<00:00,  6.78it/s]
100%|██████████| 22/22 [00:02<00:00,  7.69it/s]


In [22]:
results

Unnamed: 0,id,vf_prob,Exotoxin,Stress survival,Biofilm,Immune modulation,Invasion,Adherence,Effector delivery system,Nutritional/Metabolic factor,Motility,Antimicrobial activity/Competitive advantage,Others,Post-translational modification,Exoenzyme,Regulation
0,VFG037170(gb|WP_001081754),0.856,0.999,0.01,0.004,0.0,0.0,0.0,0.001,0.0,0.0,0.0,0.001,0.0,0.0,0.0
1,VFG037179(gb|YP_001085084),0.682,0.973,0.019,0.002,0.0,0.0,0.0,0.0,0.01,0.002,0.006,0.016,0.0,0.0,0.004
2,VFG037189(gb|WP_000632992),0.955,0.996,0.002,0.008,0.0,0.0,0.003,0.0,0.004,0.001,0.0,0.001,0.0,0.0,0.0
3,sp|A0R6D9|MAK_MYCS2,0.032,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,sp|A0R6H7|IRTB_MYCS2,0.926,0.003,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.001,0.0,0.0,0.0
5,sp|A0R7F9|RS6_MYCS2,0.022,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,sp|A0R7G6|INO1_MYCS2,0.107,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,VFG048596(gb|WP_014229387),0.978,0.0,0.0,0.0,0.0,0.0,0.003,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
8,VFG048641(gb|WP_004224342),0.992,0.0,0.0,0.0,0.0,0.0,0.0,0.999,0.0,0.0,0.0,0.0,0.0,0.0,0.0
9,VFG014502(gb|WP_003378750),0.994,0.001,0.0,0.001,0.001,0.0,0.0,0.0,0.003,1.0,0.0,0.001,0.0,0.0,0.001
