In [None]:
import warnings
warnings.filterwarnings('ignore')
!pip install biopython obonet --quiet
!pip install transformers biopython --quiet

In [None]:
import torch
from transformers import EsmTokenizer, EsmModel
from Bio import SeqIO
import os
from tqdm.auto import tqdm
import gc
import numpy as np

# --- Top taxa list ---
# Taxon IDs as strings, so they can be compared directly against the taxon
# token parsed from the FASTA header (presumably NCBI taxonomy IDs).
top_taxa = [
    "9606", "10090", "3702", "559292", "10116", "284812",
    "83333", "7227", "6239", "83332",
]
taxon_to_index_top = dict(zip(top_taxa, range(len(top_taxa))))
# Every taxon not listed above shares one extra "others" slot at the end.
others_index = len(top_taxa)
num_taxon_top = others_index + 1  # 11 dimensions: 10 top taxa + "others"


# --- Taxonomy one-hot encoder ---
def prot_taxon_onehot(taxon_id):
    """Return a float32 one-hot vector of length ``num_taxon_top`` for a taxon.

    Taxa present in ``top_taxa`` map to their own slot; any other value
    (including ``None`` for a missing taxon) maps to the trailing "others" slot.
    """
    slot = taxon_to_index_top.get(taxon_id, others_index)
    onehot = torch.zeros(num_taxon_top, dtype=torch.float32)
    onehot[slot] = 1
    return onehot

# --- Load ESM-2 ---
# --- ESM-2 encoder setup ---
model_name = "facebook/esm2_t30_150M_UR50D"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = EsmTokenizer.from_pretrained(model_name)
# Inference only: place weights on the chosen device and switch to eval mode.
model_esm2 = EsmModel.from_pretrained(model_name).to(device).eval()

# Width of each per-token hidden state; the sequence part of every feature vector.
EMBEDDING_DIM = model_esm2.config.hidden_size
print(f"ESM embedding dim: {EMBEDDING_DIM}, Device: {device}")

# --- Parse FASTA ---
fasta_file = "/kaggle/input/train-test-cafa6-v2/Test/testsuperset.fasta"
prot_to_seq = {}    # protein ID -> amino-acid sequence string
prot_to_taxon = {}  # protein ID -> taxon ID string, or None if the header has none

for record in SeqIO.parse(fasta_file, "fasta"):
    # Expected header format: ">ProteinID TaxonID".
    # Biopython's ``record.id`` holds only the first whitespace-delimited token
    # of the header, so splitting it can never yield the taxon. The full header
    # line (minus ">") is in ``record.description``.
    # NOTE(review): if training features were built with the old
    # ``record.id.split()`` parsing, every training taxon was None ("others");
    # regenerate them so train and test taxonomy vectors match.
    parts = record.description.split()
    pid = parts[0]
    taxon_id = parts[1] if len(parts) > 1 else None
    prot_to_seq[pid] = str(record.seq)
    prot_to_taxon[pid] = taxon_id

print(f"Total proteins in test set: {len(prot_to_seq)}")

# --- Hàm process batch ---
def process_and_embed_batch(prot_ids, max_length=1024, *, pooling="first"):
    """Embed proteins with ESM-2 and append a taxonomy one-hot vector.

    Each sequence is split into consecutive, non-overlapping chunks of at most
    ``max_length`` residues. Each chunk is encoded separately and pooled to one
    vector, and the chunk vectors are averaged with equal weight.

    Args:
        prot_ids: Iterable of protein IDs. Each must be a key of the global
            ``prot_to_seq``; its taxon is looked up in ``prot_to_taxon``.
        max_length: Maximum number of residues per chunk.
        pooling: How a chunk's per-token hidden states become one vector.
            ``"first"`` (default) takes the hidden state at position 0. Because
            the tokenizer is called with ``add_special_tokens=False``, that is
            the FIRST RESIDUE, not a CLS token. It stays the default so features
            remain identical to those produced by earlier runs of this code.
            ``"mean"`` averages over all unmasked residue positions.

    Returns:
        List of ``(pid, features)`` tuples. ``features`` is a CPU tensor made of
        the ESM hidden-state vector followed by the taxonomy one-hot from
        ``prot_taxon_onehot``.

    Raises:
        ValueError: If ``pooling`` is not recognised or a sequence is empty.
    """
    if pooling not in ("first", "mean"):
        raise ValueError(f"Unknown pooling mode: {pooling!r}")

    features_list = []
    with torch.no_grad():
        for pid in prot_ids:
            seq = prot_to_seq[pid]
            if not seq:
                # Otherwise torch.stack([]) below fails with an opaque error.
                raise ValueError(f"Protein {pid!r} has an empty sequence")
            taxon_id = prot_to_taxon.get(pid)

            chunk_embeddings = []
            # Chunk long sequences into windows of at most max_length residues.
            for start in range(0, len(seq), max_length):
                chunk = seq[start:start + max_length]
                tokens = tokenizer(
                    chunk,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    add_special_tokens=False,
                )
                input_ids = tokens["input_ids"].to(device)
                attention_mask = tokens["attention_mask"].to(device)
                hidden = model_esm2(
                    input_ids=input_ids, attention_mask=attention_mask
                ).last_hidden_state  # (batch=1, tokens, hidden)

                if pooling == "first":
                    emb = hidden[:, 0, :]
                else:
                    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
                    emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
                chunk_embeddings.append(emb.squeeze(0))

            seq_emb = torch.stack(chunk_embeddings, dim=0).mean(dim=0)
            taxon_vec = prot_taxon_onehot(taxon_id).to(device)
            features = torch.cat([seq_emb, taxon_vec], dim=0)
            features_list.append((pid, features.cpu()))

            del chunk_embeddings, seq_emb, taxon_vec, features
            # Return cached GPU memory once per protein. Doing it after every
            # chunk forces a costly allocator flush without changing results.
            torch.cuda.empty_cache()

    return features_list

# --- Process tất cả protein với progress bar ---
# --- Embed every protein, with a progress bar ---
all_prot_ids = list(prot_to_seq.keys())
final_features_list = []

# The context manager closes the tqdm bar even if embedding raises.
with tqdm(total=len(all_prot_ids), desc="Processing proteins", unit="protein") as pbar:
    # enumerate supplies the running count. Previously a counter was never
    # incremented, which printed "0" for every protein.
    for idx, pid in enumerate(all_prot_ids):
        if idx % 2000 == 0:
            # Plain-text checkpoint that survives in saved notebook logs.
            print(idx)
        final_features_list.extend(process_and_embed_batch([pid]))
        pbar.update(1)

# --- Lưu ra file npy ---
# --- Write features and IDs to .npy files ---
# Split the (pid, features) pairs into two parallel, index-aligned lists.
protein_ids = [pid for pid, _ in final_features_list]
feature_tensors = [feat for _, feat in final_features_list]

X_test = torch.stack(feature_tensors).numpy()   # one row per protein
protein_ids_test = np.array(protein_ids)        # row i of X_test <-> protein_ids_test[i]

output_dir = "/kaggle/working/"
os.makedirs(output_dir, exist_ok=True)

for file_name, array in (
    ("X_test.npy", X_test),
    ("protein_ids_test.npy", protein_ids_test),
):
    np.save(os.path.join(output_dir, file_name), array)

print(f"✓ X_test shape: {X_test.shape}")
print(f"✓ protein_ids_test shape: {protein_ids_test.shape}")
print("Saved X_test.npy and protein_ids_test.npy")
