In [1]:
# !pip install --no-cache-dir gensim
# import gensim
# print(gensim.__version__)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from __future__ import print_function
import argparse
import numpy as np
import torch.optim as optim
import torch.utils.data as data_utils
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from collections import Counter
from collections import defaultdict
from tqdm import tqdm
from gensim.models import Word2Vec,KeyedVectors
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import os
import pandas as pd
from itertools import chain
torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Auto-detect GPU


In [3]:
def readDataFromFile(filename):
    file_path = os.path.abspath(filename)  # Ensure absolute path

    # Read CSV file
    df = pd.read_csv(file_path)

    # Rename columns for clarity
    df.columns = ["ID", "Sequence"]

    # Extract virus_ID (part after first "|") and seq_ID (part after second "|")
    df["Virus_ID"] = df["ID"].apply(lambda x: "".join(x.split("|")[1:]) if "|" in x else "")
    df["Seq_ID"] = df["ID"].apply(lambda x: x.split("|")[0] if "|" in x else "")
    df["Class"] = df["ID"].apply(lambda x: x.split("|")[-1] if "|" in x else "")
    df["Length"] = df["Sequence"].apply(lambda x: len(x))

    return df[["Sequence","Virus_ID", "Seq_ID", "Class","Length"]]  # Return relevant columns


In [4]:
def ASW(sequence, l_sub, n):
    """
        sequence (str): The original viral sequence.
        l_sub (int): The length of each subsequence.
        n (int): The number of subsequences to generate.
    """
    l = len(sequence)
    
 
    if n > 1:
        l_stride = (l - l_sub) // (n - 1)
    else:
        l_stride = 1  
    
    subsequences = []
 
    for i in range(0, min(n * l_stride, l - l_sub + 1), l_stride):
        subsequences.append(sequence[i:i + l_sub])
    
    return subsequences


In [5]:
# 0->human  1-> animals
class GatedAttention(nn.Module):
    def __init__(self,nhead,encoderNlayers,embeddingSize,intermidiateDim):
        super(GatedAttention, self).__init__()
        self.M = embeddingSize
        self.L = intermidiateDim
        self.encoderNlayers=encoderNlayers
        self.ATTENTION_BRANCHES = 1
        self.nhead=nhead

        # embedding 
        self.encoder_layer = TransformerEncoderLayer(d_model=self.M, nhead=self.nhead)
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=self.encoderNlayers)
        
        # instance level 
        self.attention_V_1 = nn.Sequential(
            nn.Linear(self.M, self.L), # matrix V
            nn.Tanh()
        )

        self.attention_U_1 = nn.Sequential(
            nn.Linear(self.M, self.L), # matrix U
            nn.Sigmoid()
        )

        self.attention_w_1 = nn.Linear(self.L, self.ATTENTION_BRANCHES) # matrix w (or vector w if self.ATTENTION_BRANCHES==1)


        # bag level 
        self.attention_V_2 = nn.Sequential(
            nn.Linear(self.M, self.L), # matrix V
            nn.Tanh()
        )

        self.attention_U_2 = nn.Sequential(
            nn.Linear(self.M, self.L), # matrix U
            nn.Sigmoid()
        )

        self.attention_w_2 = nn.Linear(self.L, self.ATTENTION_BRANCHES) # matrix w (or vector w if self.ATTENTION_BRANCHES==1)


        
        # classifier
        self.classifier = nn.Sequential(
            nn.Conv1d(in_channels=self.ATTENTION_BRANCHES, out_channels=128, kernel_size=4, padding='same'),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(in_channels=128, out_channels=128, kernel_size=5, padding='same'),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AvgPool1d(kernel_size=2),
            nn.Conv1d(in_channels=128, out_channels=128, kernel_size=7, padding='same'),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AvgPool1d(kernel_size=2),
            nn.Flatten(),  # Converts to 1D before fully connected layers
            nn.Linear(128 * ((self.M) // 4), 256),  # Adjust size based on sequence length
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 1),
            nn.Sigmoid() 
        )

    
    def forward(self, datas,ids,Seq_ids):
        #### STEP 1:embeddings
        datas = datas.float()  # Ensure correct dtype
        instances=self.transformer_encoder(datas) 
        
        #### STEP 2: INSTANCE-LEVEL ATTENTION ####
        # Apply attention mechanisms per bag (over instances_per_bag)
        A_V = self.attention_V_1(instances)  
        A_U = self.attention_U_1(instances)  
        A = self.attention_w_1(A_V * A_U)
        A = torch.transpose(A, 1, 0)  
        inner_bags = torch.unique_consecutive(Seq_ids)
      
        output = torch.empty(((len(inner_bags), self.M))).to(device)
        super_ids = torch.empty(((len(inner_bags))))
        for i, bag in enumerate(inner_bags):
            A_vec=F.softmax(A[0][Seq_ids == bag],dim=0)
            output[i] = torch.matmul(A_vec, instances[Seq_ids == bag])
            super_ids[i]=ids[Seq_ids == bag][0]
        
        ### STEP 3: BAG-LEVEL ATTENTION ####
        A_V_2 = self.attention_V_2(output)  
        A_U_2 = self.attention_U_2(output)  
        A_2 = self.attention_w_2(A_V_2 * A_U_2)  
        A_2 = torch.transpose(A_2, 1,0)   

      
        outer_bags = torch.unique_consecutive(super_ids)
        output2 = torch.empty(((len(outer_bags), self.M))).to(device)

        for i, bag in enumerate(outer_bags):
            A_vec_2=F.softmax(A_2[0][super_ids == bag],dim=0)
            output2[i] = torch.matmul(A_vec_2, output[super_ids == bag])

        
        
        ### STEP 4: CLASSIFICATION ####
        # output2 = output2.view(output2.shape[0], -1)  # Flatten over bags_per_bag for classification
        output2 = output2.unsqueeze(1)  # Add a channel dimension


        Y_prob = self.classifier(output2)  # Shape: [batch_size, 1]
        Y_hat = torch.ge(Y_prob, 0.5).float()  # Convert probabilities to binary predictions
        return Y_prob, Y_hat, A

In [6]:
class MILDataset(Dataset):
    def __init__(self, datas, ids, seq_ids, labels):
        self.datas = datas  # Instance features
        self.ids = ids # Virus (outer bag) IDs
        self.seq_ids = seq_ids  # Sequence (inner bag) IDs
        self.labels = labels.to("cpu")  # Labels at the virus (outer bag) level

        # Unique IDs for outer bags (viruses) and their indices
        self.unique_virus_ids, self.virus_indices = torch.unique(self.ids, return_inverse=True)
        
        # Unique IDs for inner bags (sequences) and their indices
        self.unique_seq_ids, self.seq_indices = torch.unique(self.seq_ids, return_inverse=True)

        # Mapping from virus to instance indices  2d array each list is the virus data indecies
        self.virus_bag_indices_list = [torch.where(self.virus_indices == i)[0].to("cpu") for i in tqdm(range(len(self.unique_virus_ids)))]

        # Mapping from sequence to instance indices 2d array each list is the seq data indecies
        self.seq_bag_indices_list = [torch.where(self.seq_indices == i)[0].to("cpu") for i in tqdm( range(len(self.unique_seq_ids)))]

        # Labels assigned at the virus level (each virus gets one label)
        self.virus_labels = [self.labels[indices[0]] for indices in self.virus_bag_indices_list]

        # Precomputed bag-of-bags structure (virus → [seq])
        self.virus_seq_map = {}  # Maps virus_id -> list of sequence indices
        for i, virus_id in tqdm(enumerate(self.unique_virus_ids)):
            self.virus_seq_map[virus_id.item()] = list((self.seq_ids[self.virus_bag_indices_list[i]].tolist()))

        # Precomputed bag IDs for each virus and sequence
        self.precomputed_virus_ids = [torch.full((indices.shape[0],), self.unique_virus_ids[i], dtype=torch.long) 
                                      for i, indices in enumerate(self.virus_bag_indices_list)]

      
        self.datas = self.datas.cpu()


    def __len__(self):
        return len(self.unique_virus_ids)  # Number of unique viruses (outer bags)

    def __getitem__(self, index):
        """ Return outer bag (virus), inner bags (sequences), and instance-level data. """
        
        # Get all instance indices belonging to this virus
        virus_instance_indices = self.virus_bag_indices_list[index]
        # Retrieve instance-level data
        virus_data = self.datas[virus_instance_indices]
        virus_label = self.virus_labels[index]
        virus_id = self.precomputed_virus_ids[index]
        # Find which sequences belong to this virus
       
        seq_ids_in_virus = self.virus_seq_map[virus_id[0].item()]

        return {
            "virus_id": virus_id,
            "virus_data": virus_data,
            "virus_label": virus_label,
            "seq_id": seq_ids_in_virus
        }


def collate_fn(batch):
    """ Custom collate function for Bag-of-Bags MIL """

    batch_size = len(batch)

    all_virus_ids = []
    all_virus_data = []
    all_virus_labels = []
    all_virus_seq_ids = []
   

    for item in batch:
        virus_id = item["virus_id"].tolist()
        virus_data = item["virus_data"].tolist()
        virus_label = item["virus_label"]
        seq_id = item["seq_id"]

        all_virus_seq_ids.extend(seq_id)
        all_virus_ids.extend(virus_id)
        all_virus_data.extend(virus_data)
        all_virus_labels.append(virus_label)
    
    # Convert to tensors
    batch_virus_labels = torch.tensor(all_virus_labels, dtype=torch.float)
    batch_seq_ids = torch.tensor(all_virus_seq_ids, dtype=torch.float)
    batch_virus_datas = torch.tensor(all_virus_data, dtype=torch.float)
    batch_virus_ids = torch.tensor(all_virus_ids, dtype=torch.float)


    return batch_virus_datas, batch_virus_ids,batch_seq_ids, batch_virus_labels


In [7]:
def train(epoch,dataloader):
    model.train()
    running_loss=0.
    acc=0
    human=0
    animal=0
    total_samples=0
    


    for batch_data, batch_ids,batch_seq_ids, batch_labels in tqdm(dataloader, desc="Processing Batches"):
        batch_data,batch_ids,batch_seq_ids, batch_labels  = batch_data.to(device), batch_ids.to(device),batch_seq_ids.to(device), batch_labels.to(device)
        Y_prob, Y_hat, A =model(batch_data,batch_ids,batch_seq_ids)
        Y_prob=Y_prob.squeeze(1)
        loss = criterion(Y_prob, batch_labels)
        # Optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        acc += ((Y_hat == batch_labels).sum().item())
        total_samples += batch_labels.size(0)  # Track the total number of samples processed
        human+=np.sum(Y_hat.cpu().numpy() == 0)
        animal+=np.sum(Y_hat.cpu().numpy() == 1)
    print(f'Epoch: {epoch}, Loss: {running_loss:.4f}, LR: {scheduler.get_last_lr()}')
    acc=acc/total_samples*100
    print(f'acc: {acc:.1f}%')
    print("human = ",human)
    print("animal = ",animal)
    return running_loss

In [8]:
def test(dataloader):
    # print(len(labels))
    model.eval()
    acc=0
    human=0
    animal=0
    total_samples=0
    output=[]
    b_lables=[]
    with torch.no_grad():
         for batch_data, batch_ids,batch_seq_ids, batch_labels in tqdm(dataloader, desc="Processing Batches"):
            batch_data,batch_ids,batch_seq_ids, batch_labels  = batch_data.to(device), batch_ids.to(device),batch_seq_ids.to(device), batch_labels.to(device)
            Y_prob, Y_hat, A =model(batch_data,batch_ids,batch_seq_ids)         
            # output+=out_embed
            # b_lables+=batch_labels
            Y_prob=Y_prob.squeeze(1)
            Y_hat = Y_hat.view_as(batch_labels)
            acc += ((Y_hat == batch_labels).sum().item())
            human+=np.sum(Y_hat.cpu().numpy() == 0)
            animal+=np.sum(Y_hat.cpu().numpy() == 1)
            total_samples += batch_labels.size(0)  # Track the total number of samples processed

    acc=acc/total_samples*100
    print(f'acc: {acc:.1f}%')
    print("human = ",human)
    print("animal = ",animal)

    return output,b_lables

In [9]:
# read dataset
fileNameOriginalDatas="/kaggle/input/ncbi-data-csv/ncbi_data.csv"
df=readDataFromFile(fileNameOriginalDatas)

# get the length of the longest seq
llongest=max(df['Length'])
lshortest=min(df['Length'])
print("llongest",llongest)
llongest=max(df['Length'])
print("lshortest",lshortest)


llongest 775
lshortest 202


In [10]:
n=193
lower_bound = int(llongest / n)
upper_bound = int(llongest - n + 1)
l_sub_array=np.arange(lower_bound, upper_bound + 1)
l_sub=lshortest-n+1
if l_sub not in l_sub_array:
    print("error ASW")
print(l_sub)



10


In [11]:

df["Class"] = df["Class"].str.lower()  #Ensure consistent casing
labels = np.array((df["Class"] != "human").astype(int))
ids=df["Virus_ID"]
seq_ids=df["Seq_ID"]+" "+df["Virus_ID"]

# convert string id to numeric
_,ids = np.unique(ids, return_inverse=True)
_,seq_ids = np.unique(seq_ids, return_inverse=True)

In [34]:
datas=df["Sequence"]
# Get unique bag IDs
unique_bag_ids = np.unique(ids)

# Split bag IDs into train and test
train_ids, test_val_ids = train_test_split(unique_bag_ids, test_size=0.5, random_state=42)
test_ids, val_ids = train_test_split(test_val_ids, test_size=0.001, random_state=42)

# Get indices corresponding to train/test bag IDs
train_indices = np.where(np.isin(ids, train_ids))[0]
test_indices = np.where(np.isin(ids, test_ids))[0]
val_indices = np.where(np.isin(ids, val_ids))[0]

# # Create train data
train_datas = datas[train_indices]
train_ids = ids[train_indices]
train_seq_ids = seq_ids[train_indices]
train_labels = labels[train_indices]

# # Create test data
test_datas = datas[test_indices]
test_ids = ids[test_indices]
test_seq_ids = seq_ids[test_indices]
test_labels = labels[test_indices]


# # Create val data
val_datas = datas[val_indices]
val_ids = ids[val_indices]
val_seq_ids = seq_ids[val_indices]
val_labels = labels[val_indices]

print(train_datas.shape)


(96487,)


In [13]:
# need validation
sg_embed_size=30
sg_window=5
# Transformer Parameters
nhead = 5         # Number of attention heads
encoderNlayers = 2       # Number of transformer layers
embeddingSize=sg_embed_size
intermidiateDim=512


In [14]:
# Apply ASW 

train_datas = [ASW(sequence,l_sub, n) for sequence in train_datas.tolist()]

train_labels= np.repeat(train_labels, n).tolist()
train_ids=np.repeat(train_ids, n).tolist()
train_seq_ids=np.repeat(train_seq_ids, n).tolist()

print(len(train_datas))
print(len(train_datas[0]))
print(len(train_datas[0][0]))
# Apply skip gram
# Convert k-mers into embeddings
w2v_model = Word2Vec(sentences=tqdm(train_datas, desc=" Skip gram Training"), vector_size=sg_embed_size, window=sg_window, sg=1, min_count=1, workers=5)
# word_vectors = KeyedVectors.load("/kaggle/working/word2vec_vectors.kv")

train_seq_embeddings = np.array([w2v_model.wv[kmer] for kmer in tqdm(train_datas,desc="Skip gram inference")])
train_seq_embeddings = np.array(list(chain.from_iterable(train_seq_embeddings)))


96487
193
10


 Skip gram Training: 100%|██████████| 96487/96487 [00:02<00:00, 34078.43it/s]
Skip gram inference: 100%|██████████| 96487/96487 [00:33<00:00, 2868.66it/s]


In [15]:
train_seq_embeddings=torch.tensor(train_seq_embeddings).to(device)
train_ids=torch.tensor(train_ids).to(device)
train_seq_ids=torch.tensor(train_seq_ids).to(device)
train_labels=torch.tensor(train_labels).to(device)

In [16]:
train_mildataset = MILDataset(train_seq_embeddings, train_ids,train_seq_ids, train_labels)
train_loader = DataLoader(train_mildataset, batch_size=16, shuffle=True,num_workers=4, collate_fn=collate_fn)

100%|██████████| 41854/41854 [00:52<00:00, 799.26it/s]
100%|██████████| 96487/96487 [02:01<00:00, 795.42it/s]
41854it [00:04, 9279.44it/s] 


In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Auto-detect GPU
model = GatedAttention(nhead,encoderNlayers,embeddingSize,intermidiateDim).to(device)  # Move model to GPU
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
criterion = nn.BCELoss().to(device)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True)


print(f"Using device: {device}") 
print('Start Training')
for epoch in range(1, 10+1):
    loss = train(epoch, train_loader)
    scheduler.step(loss)  # Update LR based on loss
    if scheduler.num_bad_epochs >= 5:  # Stop after 10 consecutive non-improving epochs
        print(f"Stopping early: No improvement for {scheduler.num_bad_epochs} epochs")
        break



Using device: cuda
Start Training


  return F.conv1d(
Processing Batches: 100%|██████████| 2616/2616 [09:35<00:00,  4.55it/s]


Epoch: 1, Loss: 501.5304, LR: [0.001]
acc: 830.0%
human =  19810
animal =  22044


Processing Batches: 100%|██████████| 2616/2616 [09:33<00:00,  4.56it/s]


Epoch: 2, Loss: 331.8105, LR: [0.001]
acc: 842.5%
human =  19946
animal =  21908


Processing Batches: 100%|██████████| 2616/2616 [09:34<00:00,  4.56it/s]


Epoch: 3, Loss: 290.0883, LR: [0.001]
acc: 845.3%
human =  19938
animal =  21916


Processing Batches: 100%|██████████| 2616/2616 [09:34<00:00,  4.56it/s]


Epoch: 4, Loss: 262.2073, LR: [0.001]
acc: 845.8%
human =  19925
animal =  21929


Processing Batches: 100%|██████████| 2616/2616 [09:32<00:00,  4.57it/s]


Epoch: 5, Loss: 253.3361, LR: [0.001]
acc: 848.7%
human =  19963
animal =  21891


Processing Batches: 100%|██████████| 2616/2616 [09:32<00:00,  4.57it/s]


Epoch: 6, Loss: 243.9743, LR: [0.001]
acc: 848.9%
human =  19964
animal =  21890


Processing Batches: 100%|██████████| 2616/2616 [09:30<00:00,  4.59it/s]


Epoch: 7, Loss: 236.1119, LR: [0.001]
acc: 848.2%
human =  19949
animal =  21905


Processing Batches: 100%|██████████| 2616/2616 [09:33<00:00,  4.56it/s]


Epoch: 8, Loss: 230.4979, LR: [0.001]
acc: 849.3%
human =  19967
animal =  21887


Processing Batches: 100%|██████████| 2616/2616 [09:31<00:00,  4.58it/s]


Epoch: 9, Loss: 226.4064, LR: [0.001]
acc: 849.6%
human =  19957
animal =  21897


Processing Batches: 100%|██████████| 2616/2616 [09:31<00:00,  4.57it/s]

Epoch: 10, Loss: 229.5003, LR: [0.001]
acc: 849.2%
human =  19938
animal =  21916





In [18]:
torch.save(model, "model.pth")


In [19]:
print(np.sum(train_labels.cpu().numpy() == 0))
print(np.sum(train_labels.cpu().numpy() == 1))

# print(np.sum(test_labels.cpu().numpy() == 0))
# print(np.sum(test_labels.cpu().numpy() == 1))

6820620
11801371


In [22]:
print('Start Testing on train')
out=test(train_loader)


Start Testing on train


Processing Batches: 100%|██████████| 2616/2616 [02:54<00:00, 15.02it/s]

acc: 97.8%
human =  20059
animal =  21795





In [40]:
# Apply ASW 

test_datas = [ASW(sequence,l_sub, n) for sequence in test_datas.tolist()]

test_labels= np.repeat(test_labels, n).tolist()
test_ids=np.repeat(test_ids, n).tolist()
test_seq_ids=np.repeat(test_seq_ids, n).tolist()


# Apply skip gram
keys_wv=set(list(w2v_model.wv.key_to_index.keys()))
   
# Convert k-mers into embeddings
test_seq_embeddings = np.array([
    w2v_model.wv[k] if k in keys_wv else np.zeros(30)
    for kmer in tqdm(test_datas, desc="Skip gram inference")  
    for k in kmer  # kmer should be defined first
])


Skip gram inference:   0%|          | 0/96394 [00:00<?, ?it/s]

KeyError: "Key 'LKGIAPLQLR' not present"

In [45]:
sety=set()
for kmer in tqdm(test_datas, desc="Skip gram inference"):  
    for k in kmer:  # kmer should be defined first
        if k not in keys_wv:
            sety.add(k)


print(len(sety))
print(len(np.unique(test_datas)))


Skip gram inference:   0%|          | 0/96394 [00:00<?, ?it/s][A
Skip gram inference:   6%|▌         | 5871/96394 [00:00<00:01, 58702.78it/s][A
Skip gram inference:  12%|█▏        | 11742/96394 [00:00<00:01, 56536.90it/s][A
Skip gram inference:  18%|█▊        | 17402/96394 [00:00<00:01, 54490.78it/s][A
Skip gram inference:  24%|██▍       | 23286/96394 [00:00<00:01, 56145.76it/s][A
Skip gram inference:  30%|███       | 28937/96394 [00:00<00:01, 56270.09it/s][A
Skip gram inference:  36%|███▌      | 34572/96394 [00:00<00:01, 55609.49it/s][A
Skip gram inference:  42%|████▏     | 40518/96394 [00:00<00:00, 56843.28it/s][A
Skip gram inference:  48%|████▊     | 46658/96394 [00:00<00:00, 58274.15it/s][A
Skip gram inference:  55%|█████▍    | 52645/96394 [00:00<00:00, 58766.26it/s][A
Skip gram inference:  61%|██████    | 58825/96394 [00:01<00:00, 59694.62it/s][A
Skip gram inference:  68%|██████▊   | 65123/96394 [00:01<00:00, 60692.40it/s][A
Skip gram inference:  74%|███████▍  | 71445

115011
291204


In [37]:
test_seq_embeddings=torch.tensor(test_seq_embeddings).to(device)
test_ids=torch.tensor(test_ids).to(device)
test_seq_ids=torch.tensor(test_seq_ids).to(device)
test_labels=torch.tensor(test_labels).to(device)

In [38]:
test_mildataset = MILDataset(test_seq_embeddings, test_ids,test_seq_ids, test_labels)
test_loader = DataLoader(test_mildataset, batch_size=64, shuffle=True,num_workers=0, collate_fn=collate_fn)

100%|██████████| 41812/41812 [00:53<00:00, 788.79it/s]
100%|██████████| 96394/96394 [02:01<00:00, 791.83it/s]
41812it [00:04, 10362.60it/s]


In [39]:
print('Start Testing on test data')
out2=test(test_loader)


Start Testing on test data


Processing Batches: 100%|██████████| 654/654 [07:18<00:00,  1.49it/s]

acc: 97.6%
human =  20184
animal =  21628



