In [17]:
!pip install -q fair-esm peft transformers torch biopython

In [None]:
import numpy as np 
import pandas as pd 
from Bio import SeqIO
import os
import seaborn as sns
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
import esm
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
# Suppress FutureWarning messages for the rest of the session.
warnings.simplefilter(action='ignore', category=FutureWarning)

# Directory holding the FASTA inputs, the file names inside it, and the full paths.
Data_dir = "/kaggle/input/south-north-gasaid"
filenames = [
    "south-animal-H1N1.fasta",
    "south-animal-H3N2.fasta",
    "south-human-H1N1.fasta",
    "south-human-H3N2.fasta",
]
file_paths = [os.path.join(Data_dir, fasta_name) for fasta_name in filenames]

In [None]:
def readDataFromFile(filenames):
    """Parse FASTA files into a single DataFrame with one row per record.

    Each record description is split on "|". When the description contains
    at least one "|":
      * ``Class`` is the second-to-last field,
      * ``seq_ID`` is the second field,
      * ``virus_ID`` is every field from the third onward, re-joined with "|".
    When it contains no "|", all three columns hold the whole description.
    ``Sequence`` is the record's sequence as a string and ``Length`` its length.

    Args:
        filenames: Iterable of paths to FASTA files.

    Returns:
        A DataFrame concatenating the rows of all files under a fresh integer
        index, or an empty DataFrame when ``filenames`` yields nothing.
    """

    def describe(record):
        # Build the row dict for one parsed FASTA record.
        description = record.description
        if "|" in description:
            fields = description.split("|")
            label = fields[-2]
            virus_id = "|".join(fields[2:])
            seq_id = "|".join(fields[1:2])
        else:
            label = virus_id = seq_id = description
        return {
            "Class": label,
            "virus_ID": virus_id,
            "seq_ID": seq_id,
            "Sequence": str(record.seq),
            "Length": len(record.seq),
        }

    frames = []
    for name in filenames:
        fasta_path = os.path.abspath(name)
        rows = [describe(rec) for rec in SeqIO.parse(fasta_path, "fasta")]
        frames.append(pd.DataFrame.from_records(rows))

    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)

In [None]:
# NOTE(review): only the first two paths are loaded, and both are the
# "south-animal-*" files; the two "south-human-*" files are not read here.
# Confirm this is intentional before relying on labels derived from `df`.
print(file_paths[0:2])
df = readDataFromFile(file_paths[0:2])

# DataFrame.info() prints its report itself and returns None, so wrapping it
# in print() only appends a stray "None" line to the output.
df.info()


['/kaggle/input/south-north-gasaid/south-animal-H1N1.fasta', '/kaggle/input/south-north-gasaid/south-animal-H3N2.fasta']
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1637 entries, 0 to 1636
Data columns (total 5 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   Class     1637 non-null   object
 1   virus_ID  1637 non-null   object
 2   seq_ID    1637 non-null   object
 3   Sequence  1637 non-null   object
 4   Length    1637 non-null   int64 
dtypes: int64(1), object(4)
memory usage: 64.1+ KB
None


In [None]:
# Load the ESM-2 checkpoint and its tokenizer, and place the model on the GPU
# when one is available (CPU otherwise).
model_name = "facebook/esm2_t33_650M_UR50D"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)  # request the fast tokenizer
model = AutoModel.from_pretrained(model_name).to(device)

# Switch to eval (inference) mode; as the cell's last expression this also displays the module.
model.eval()

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 1280, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
    (position_embeddings): Embedding(1026, 1280, padding_idx=1)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-32): 33 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=1280, out_features=1280, bias=True)
            (key): Linear(in_features=1280, out_features=1280, bias=True)
            (value): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmInter

In [None]:
class SequenceDataset(Dataset):
    """Dataset over a DataFrame that yields the ``"Sequence"`` value of each row."""

    def __init__(self, df):
        # Keep a reference to the caller's frame; no copy is made.
        self.df = df

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return the ``"Sequence"`` entry of the row at position ``idx``."""
        return self.df["Sequence"].iloc[idx]

def collate_fn(batch, max_length=None):
    """Tokenize a batch of sequences and move the result to ``device``.

    Args:
        batch: Sequence strings, as yielded by the dataset for one batch.
        max_length: Optional cap on the tokenized length. When ``None`` (the
            default) sequences are padded to the longest one in the batch and
            never truncated. When set, it is forwarded to the tokenizer as
            ``max_length`` and truncation is enabled.

    Returns:
        The tokenizer output as PyTorch tensors, moved to the module-level ``device``.
    """
    # Passing truncation=True without a max_length makes the tokenizer warn and
    # fall back to no truncation, so truncation is requested only when a limit
    # is actually supplied. The default call therefore tokenizes as before,
    # without the warning.
    tokenized_batch = tokenizer(
        list(batch),
        padding=True,
        truncation=max_length is not None,
        max_length=max_length,
        return_tensors="pt",
    )
    # NOTE(review): moving to `device` here assumes the DataLoader runs without
    # worker processes (num_workers=0) when `device` is CUDA; confirm if that changes.
    return tokenized_batch.to(device)

# Wrap the parsed sequences and batch them without shuffling, so batches
# follow the DataFrame's row order.
dataset = SequenceDataset(df)
dataloader = DataLoader(
    dataset,
    collate_fn=collate_fn,
    shuffle=False,
    batch_size=20,
)

In [None]:
# Pre-allocate one row per sequence on `device`; the width is the model's hidden size.
num_samples = len(df)
hidden_dim = model.config.hidden_size
cls_embeddings = torch.zeros((num_samples, hidden_dim), device=device)

# Row offset in `cls_embeddings` where the next batch will be written.
index = 0

with torch.no_grad():  # no gradients are tracked during the forward passes
    for batch in tqdm(dataloader, desc="Running Inference"):
        results = model(**batch, output_hidden_states=True)

        # First-token vector from the last entry of hidden_states,
        # shape (batch_size, hidden_dim).
        batch_cls_embeddings = results.hidden_states[-1][:, 0, :]

        # Copy this batch into its slice of the output buffer, then advance.
        batch_size = batch_cls_embeddings.shape[0]
        next_index = index + batch_size
        cls_embeddings[index:next_index] = batch_cls_embeddings
        index = next_index

cls_embeddings.shape

Running Inference:   0%|          | 0/82 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Running Inference: 100%|██████████| 82/82 [05:29<00:00,  4.02s/it]


torch.Size([1637, 1280])

In [None]:
# Binary target: 1 for every record whose Class is not "human", 0 for "human".
# Class is lower-cased first so the comparison ignores casing.
df["Class"] = df["Class"].str.lower()
is_non_human = (df["Class"] != "human").astype(int)
# Convert through NumPy instead of handing the Series to torch.tensor directly,
# so the conversion does not depend on the DataFrame's index labels.
labels = torch.tensor(is_non_human.to_numpy(), dtype=torch.int8, device=device)

# Identifier arrays, taken in DataFrame row order.
seq_IDs = df["seq_ID"].to_numpy()
virus_ID = df["virus_ID"].to_numpy()

# Move tensors to CPU before saving.
torch.save(labels.cpu(), "south_labels.pt")
torch.save(cls_embeddings.cpu(), "south_embd.pt")
# NOTE(review): df.info() reports these columns as object dtype; np.save pickles
# object arrays, so loading them back needs np.load(..., allow_pickle=True).
np.save("south_seq_IDs.npy", seq_IDs)
np.save("south_virus_ID.npy", virus_ID)