In [None]:
from __future__ import print_function

import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from gensim.models import FastText
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from data_preprocessing import *
from model_utilities import *
from CONSTANTS import *
from ATT_MIL import *
from utilities import *

# Pin the global RNG state of PyTorch and NumPy (in that order) to the same
# fixed seed so repeated runs of the notebook draw the same random numbers.
for _seed_rng in (torch.manual_seed, np.random.seed):
    _seed_rng(42)



  from .autonotebook import tqdm as notebook_tqdm


### Data preparation

In [None]:
# Clean the held-out gasaid test FASTA and store the result as a CSV with
# the columns "ID" and "Data".
gasaid_test = ['./gasaid-viruses/GASAID_test.fasta']
gasaid_test_csv = "./processed-data-csv/gasaid_cleaned_test_data.csv"

# Create the output directory before doing any work: DataFrame.to_csv does
# not create missing parent directories and would otherwise fail at the very
# end of the cell on a fresh checkout.
os.makedirs(os.path.dirname(gasaid_test_csv), exist_ok=True)

gasaid_df = read_data_from_file(gasaid_test)
gasaid_datas = gasaid_df['Sequence'].values
gasaid_ids = gasaid_df['ID'].values

# Each cleaning helper takes (sequences, ids) and returns the filtered pair,
# so the steps can be applied one after another in this fixed order.
# Note: unlike the training cells, no remove_similar_sequences step is run here.
gasaid_list_data, gasaid_list_id = gasaid_datas, gasaid_ids
for cleaning_step in (
    remove_duplicateSeq,
    remove_duplicateIds,
    remove_x_seq,
    remove_incomplete_seq,
):
    gasaid_list_data, gasaid_list_id = cleaning_step(gasaid_list_data, gasaid_list_id)

gasaid_test_df = pd.DataFrame({"ID": gasaid_list_id, "Data": gasaid_list_data})
gasaid_test_df.to_csv(gasaid_test_csv, index=False)
print("test gasaid Saved")

6387it [00:00, 235750.35it/s]


test gasaid Saved


In [None]:
# Clean the per-region gasaid FASTA files (training data only) and store the
# result as a CSV with the columns "ID" and "Data".
file_names_gasaid_datas = [
    "./gasaid-viruses/africa.fasta",
    "./gasaid-viruses/asia.fasta",
    "./gasaid-viruses/europe.fasta",
    "./gasaid-viruses/north.fasta",
    "./gasaid-viruses/south.fasta",
]
gasaid_train_csv = "./processed-data-csv/gasaid_cleaned_train_data.csv"

# Create the output directory first: DataFrame.to_csv does not create missing
# parent directories, and failing only after the slow cleaning steps below
# would throw that work away.
os.makedirs(os.path.dirname(gasaid_train_csv), exist_ok=True)

df = read_data_from_file(file_names_gasaid_datas)
# train only
datas = df['Sequence'].values
ids = df['ID'].values

# Each cleaning helper takes (sequences, ids) and returns the filtered pair;
# they are applied in this fixed order, with the similarity filter last.
list_data, list_id = datas, ids
for cleaning_step in (
    remove_duplicateSeq,
    remove_duplicateIds,
    remove_x_seq,
    remove_incomplete_seq,
    remove_similar_sequences,
):
    list_data, list_id = cleaning_step(list_data, list_id)

gasaid_data = pd.DataFrame({"ID": list_id, "Data": list_data})
gasaid_data.to_csv(gasaid_train_csv, index=False)

print("train gasaid Saved")



43192it [00:00, 263211.59it/s]
100%|██████████| 7463/7463 [01:38<00:00, 75.68it/s] 
100%|██████████| 2/2 [00:00<?, ?it/s]
100%|██████████| 8/8 [00:00<00:00, 511.98it/s]
100%|██████████| 5/5 [00:00<?, ?it/s]
100%|██████████| 7/7 [00:00<?, ?it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]
100%|██████████| 12/12 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]
100%|██████████| 5/5 [00:00<?, ?it/s]
100%|██████████| 5/5 [00:00<?, ?it/s]
100%|██████████| 30/30 [00:00<?, ?it/s]
100%|██████████| 4/4 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<?, ?it/s]
100%|██████████| 8/8 [00:00<00:00, 28244.47it/s]
100%|██████████| 26/26 [00:00<?, ?it/s]
100%|██████████| 47/47 [00:00<00:00, 3016.65it/s]
100%|██████████| 6/6 [00:00<00:00, 1969.47it/s]
100%|██████████| 1/1 [00:00<?, ?it/s]
100%|██████████| 35/35 [00:00<00:00, 2783.64it/s]
100%|██████████| 25/25 [00:00<00:00, 31998.05it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]
100%|██

train gasaid Saved


In [None]:
# Clean the per-region NCBI FASTA files, split them by year into a training
# and a test set, and store both as CSV files.
file_names_ncbi_datas = [
    "./ncbi-viruses/Australia.fa",
    "./ncbi-viruses/africa.fa",
    "./ncbi-viruses/asia.fa",
    "./ncbi-viruses/europe.fa",
    "./ncbi-viruses/north.fa",
    "./ncbi-viruses/south.fa",
]
ncbi_train_csv = "./processed-data-csv/ncbi_cleaned_train_data.csv"
ncbi_test_csv = "./processed-data-csv/ncbi_cleaned_test_data.csv"

# Create the output directories first: DataFrame.to_csv does not create
# missing parent directories, and failing only after the slow cleaning steps
# below would throw that work away.
for output_csv in (ncbi_train_csv, ncbi_test_csv):
    os.makedirs(os.path.dirname(output_csv), exist_ok=True)

df = read_data_from_file(file_names_ncbi_datas)

datas = df['Sequence'].values
ids = df['ID'].values

# Each cleaning helper takes (sequences, ids) and returns the filtered pair;
# they are applied in this fixed order.
list_data, list_id = datas, ids
for cleaning_step in (
    remove_duplicateSeq,
    remove_duplicateIds,
    remove_x_seq,
    remove_incomplete_seq,
):
    list_data, list_id = cleaning_step(list_data, list_id)

df = pd.DataFrame({"ID": list_id, "Data": list_data})
# The year is taken from the second-to-last "|"-separated field of the ID.
df["year"] = df["ID"].apply(lambda x: x.split("|")[-2])

# Temporal split: year 2023 and later -> test, everything before 2023 -> train.
# The comparison is done on strings, which orders correctly only while the
# field is a plain four-digit year.
# NOTE(review): IDs with a different year format would be assigned by
# lexical order - confirm the ID layout of the input files.
test_df = df[df['year'] >= '2023']
train_df = df[df['year'] < '2023']

# The similarity filter is applied to the training portion only.
list_data, list_id = remove_similar_sequences(train_df['Data'].values, train_df['ID'].values)

new_train_df = pd.DataFrame({"ID": list_id, "Data": list_data})
new_train_df.to_csv(ncbi_train_csv, index=False)
print("ncbi Train Saved")

# The test CSV keeps the extra "year" column; the training CSV does not.
test_df.to_csv(ncbi_test_csv, index=False)
print("ncbi Test Saved")


195824it [00:01, 128779.14it/s]
100%|██████████| 22391/22391 [16:48<00:00, 22.21it/s]
100%|██████████| 422/422 [00:01<00:00, 361.75it/s] 
100%|██████████| 1783/1783 [00:18<00:00, 95.70it/s] 
100%|██████████| 2565/2565 [00:37<00:00, 68.56it/s] 
100%|██████████| 2424/2424 [00:27<00:00, 88.53it/s] 
100%|██████████| 456/456 [00:00<00:00, 772.12it/s] 
100%|██████████| 986/986 [00:05<00:00, 180.22it/s] 
100%|██████████| 969/969 [00:04<00:00, 198.06it/s] 
100%|██████████| 25/25 [00:00<00:00, 4571.75it/s]
100%|██████████| 80/80 [00:00<00:00, 1825.71it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]
100%|██████████| 4/4 [00:00<00:00, 3820.82it/s]
100%|██████████| 39/39 [00:00<00:00, 2928.98it/s]
100%|██████████| 26/26 [00:00<00:00, 25593.03it/s]
100%|██████████| 1/1 [00:00<?, ?it/s]
100%|██████████| 7/7 [00:00<00:00, 1158.28it/s]
100%|██████████| 6/6 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<?, ?it/s]
100%|██████████| 9/9 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 1241.10it/s]
100%|███████

ncbi Train Saved
ncbi Test Saved


### Training

In [None]:
# Locations of the cleaned training CSVs written by the data-preparation cells.
file_name_gasaid_datas = "./processed-data-csv/gasaid_cleaned_train_data.csv"
file_name_ncbi_datas = "./processed-data-csv/ncbi_cleaned_train_data.csv"

# Load both sources (gasaid first, then ncbi) with the project CSV reader.
df_gasaid, df_ncbi = (
    read_data_from_csv(csv_path)
    for csv_path in (file_name_gasaid_datas, file_name_ncbi_datas)
)

# Stack the two frames into a single training frame with a fresh 0..n-1 index.
df = pd.concat((df_gasaid, df_ncbi), ignore_index=True)

# Project helper; in the recorded run it prints the longest and shortest
# sequence lengths.
# NOTE(review): presumably it also derives the sub-sequence length used later
# as CONSTANTS.l_sub - confirm in its definition.
get_lsub_sequence(df)



llongest 775
lshortest 201


In [None]:
# Make the class names case-insensitive before comparing them.
df["Class"] = df["Class"].str.lower()

# Binary target: "human" -> 0, every other class -> 1.
labels = np.array(df["Class"].ne("human").astype(int))

# Convert the string identifiers to integer codes: the inverse index returned
# by np.unique is the position of each value among the sorted unique values.
# Bag id = virus id; instance id = "<Seq_ID> <Virus_ID>".
ids = np.unique(df["Virus_ID"], return_inverse=True)[1]
seq_ids = np.unique(df["Seq_ID"] + " " + df["Virus_ID"], return_inverse=True)[1]

In [8]:
# Split the data into train and validation sets at bag level, so that all
# rows sharing a bag id end up on the same side of the split.
datas = df["Sequence"]

# Distinct bag ids (ids holds one integer bag id per row).
unique_bag_ids = np.unique(ids)

# 80/20 split of the bag ids with a fixed seed.
train_bag_ids, val_bag_ids = train_test_split(unique_bag_ids, test_size=0.2, random_state=42)

# Row positions belonging to the train / validation bags.
train_indices = np.where(np.isin(ids, train_bag_ids))[0]
val_indices = np.where(np.isin(ids, val_bag_ids))[0]

# Every row must land in exactly one of the two sets.
assert len(train_indices) + len(val_indices) == len(ids), "train/val split is not a partition"

# train_indices / val_indices are POSITIONS. labels, ids and seq_ids are numpy
# arrays and are indexed positionally; the sequences are a pandas Series, so
# .iloc is required to index them positionally as well. Plain datas[...] would
# look the numbers up as index labels and only matches while df has a default
# 0..n-1 index.
train_datas = datas.iloc[train_indices]
train_ids = ids[train_indices]
train_seq_ids = seq_ids[train_indices]
train_labels = labels[train_indices]

val_datas = datas.iloc[val_indices]
val_ids = ids[val_indices]
val_seq_ids = seq_ids[val_indices]
val_labels = labels[val_indices]

print("length of train ",train_datas.shape)
print("length of validation ",val_datas.shape)


length of train  (76702,)
length of validation  (19328,)


In [11]:
# Turn every training sequence into a token list with the project helper ASW,
# using the sub-sequence length stored in CONSTANTS.l_sub.
train_token_datas = [
    ASW(train_sequence, CONSTANTS.l_sub) for train_sequence in train_datas.tolist()
]

# FastText hyper-parameters (gensim exposes the same API as Word2Vec).
fasttext_options = dict(
    vector_size=SG_EMBEDD_SIZE,
    window=SG_WINDOW,
    # gensim: sg=1 selects skip-gram, sg=0 selects CBOW.
    # NOTE(review): the SG_* constants and the "ft_skipgram.model" file name
    # used when saving suggest skip-gram, but CBOW is what is trained here -
    # confirm which one is intended.
    sg=0,
    min_count=1,
    workers=5,
    epochs=10,
)
ft_model = FastText(
    sentences=tqdm(train_token_datas, desc="FastText Training"),
    **fasttext_options
)

# Look up one vector per training sequence.
# NOTE(review): this iterates over the raw sequences in train_datas, not over
# the tokens in train_token_datas, so each whole sequence is looked up as a
# single word. The result is not used by the later cells of this notebook and
# the recorded run took several minutes - confirm whether this is still needed.
train_seq_embeddings = np.array(
    [ft_model.wv[raw_sequence] for raw_sequence in tqdm(train_datas, desc="FastText inference")]
)


FastText Training: 100%|██████████| 76702/76702 [00:05<00:00, 14743.42it/s]
FastText inference: 100%|██████████| 76702/76702 [05:49<00:00, 219.74it/s]


In [None]:
# Build the train and validation loaders (train first, then validation) from
# the matching sequences / labels / bag ids / instance ids, sharing the same
# trained FastText model.
train_loader, val_loader = (
    create_data_loader(split_datas, split_labels, split_ids, split_seq_ids, ft_model)
    for split_datas, split_labels, split_ids, split_seq_ids in (
        (train_datas, train_labels, train_ids, train_seq_ids),
        (val_datas, val_labels, val_ids, val_seq_ids),
    )
)

FastText inference: 100%|██████████| 76702/76702 [00:36<00:00, 2087.04it/s]
100%|██████████| 41707/41707 [21:55<00:00, 31.71it/s]
 27%|██▋       | 20621/76702 [13:03<41:48, 22.36it/s]  

KeyboardInterrupt: 

 27%|██▋       | 20621/76702 [13:15<41:48, 22.36it/s]

In [None]:
# Build the model, optimiser, loss and LR scheduler, then train with early stopping.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ATT_MIL(N_HEAD, ENCODER_N_LAYERS, EMBEDDING_SIZE, INTERMIDIATE_DIM).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
criterion = nn.BCELoss().to(device)
# Halve the learning rate once the monitored loss has not improved for more
# than 3 consecutive epochs. The `verbose` flag is deprecated in recent
# PyTorch releases, so learning-rate changes are reported manually below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

max_epochs = 10
early_stop_patience = 5  # stop after this many consecutive non-improving epochs

print(f"Using device: {device}")
print('Start Training')

# Early stopping keeps its own counter. scheduler.num_bad_epochs cannot be
# used for this: the scheduler resets it to 0 every time it exceeds its
# patience (3), so it never reaches 5.
best_loss = float("inf")
epochs_without_improvement = 0

for epoch in range(1, max_epochs + 1):
    # NOTE(review): the scheduler and early stopping are driven by the value
    # returned by train() on train_loader; val_loader is never used in this
    # notebook - confirm whether a validation loss should be monitored instead.
    loss = train(epoch, train_loader)
    epoch_loss = float(loss)

    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step(epoch_loss)  # Update LR based on loss
    lr_after = optimizer.param_groups[0]["lr"]
    if lr_after != lr_before:
        print(f"Epoch {epoch}: reducing learning rate to {lr_after:.4e}")

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= early_stop_patience:
        print(f"Stopping early: No improvement for {epochs_without_improvement} epochs")
        break

In [None]:
# Persist the trained classifier weights and the FastText embedding model.
# Create the target directory first so that saving does not fail on a fresh
# checkout after training has already finished.
os.makedirs("./models", exist_ok=True)
torch.save(model.state_dict(), "./models/model_weights.pth")
# NOTE(review): the file name says "skipgram", but the FastText model in this
# notebook is trained with sg=0 (CBOW) - confirm before renaming, since other
# code may load the model by this path.
ft_model.save("./models/ft_skipgram.model")