In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import gc
import os

from sklearn.metrics import precision_recall_fscore_support

def compute_fmax(y_true, y_pred_probs, steps=20):
    """Estimate F-max by sweeping decision thresholds over the predictions.

    For each of ``steps`` evenly spaced thresholds in [0.01, 0.99] the
    probabilities are binarised (``prob >= threshold``) and scored with
    scikit-learn's sample-averaged F1. The best score and the threshold
    that produced it are printed and returned.

    NOTE: this is a quick estimate based on ``average='samples'``; it is not
    the weighted F-max used by the official CAFA evaluation.

    Args:
        y_true: Binary (0/1) ground-truth label matrix.
        y_pred_probs: Predicted probabilities, same shape as ``y_true``.
            Anything ``np.asarray`` accepts (ndarray, nested lists, ...).
        steps: Number of thresholds to try between 0.01 and 0.99 inclusive.

    Returns:
        Tuple ``(best_f1, best_threshold)``. Both stay at 0.0 when no
        threshold yields an F1 above zero.
    """
    print("  > Calculating F-max...")
    # Coerce once so the thresholding below also works for non-ndarray input.
    y_pred_probs = np.asarray(y_pred_probs)

    best_f1 = 0.0
    best_threshold = 0.0

    for t in np.linspace(0.01, 0.99, steps):
        # Turn probabilities into hard 0/1 labels at threshold t.
        y_pred_binary = (y_pred_probs >= t).astype(int)

        # Only the F1 component is needed; zero_division=0 keeps samples
        # without any predicted/true label from raising warnings.
        _, _, f1, _ = precision_recall_fscore_support(
            y_true, y_pred_binary, average='samples', zero_division=0
        )

        # Strict '>' keeps the lowest threshold among ties.
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = t

    print(f"  [RESULT] Best Threshold: {best_threshold:.2f} | Validation F-Max: {best_f1:.4f}")
    return best_f1, best_threshold

# ==========================================
# 1. CONFIGURATION
# ==========================================
class Config:
    """Static settings for the pipeline: input paths and hyper-parameters."""

    # Shared path prefixes; the public attributes below are built from them.
    _KAGGLE_INPUT = '/kaggle/input'
    _CAFA_DIR = _KAGGLE_INPUT + '/cafa-6-protein-function-prediction'

    # Directory holding the ESM-2 embeddings (the only embedding source used).
    ESM_DIR = _KAGGLE_INPUT + '/cafa6-protein-embeddings-esm2'

    # Competition files: training annotations and the test FASTA.
    TRAIN_TERMS = _CAFA_DIR + '/Train/train_terms.tsv'
    TEST_FASTA = _CAFA_DIR + '/Test/testsuperset.fasta'

    # Homology-based prediction files that are merged with the model output.
    HOMOLOGY_1 = _KAGGLE_INPUT + '/foldseek-blastp-parthenos/submission.tsv'
    HOMOLOGY_2 = _KAGGLE_INPUT + '/foldseek-cafa/foldseek_submission.tsv'

    # Model / training hyper-parameters.
    NUM_LABELS = 1500   # number of most frequent terms used as targets
    BATCH_SIZE = 64
    LR = 1e-3
    EPOCHS = 2
    DEVICE = torch.device("cpu")

# Report the device configured in Config.DEVICE before any work starts.
print(f"Running optimized pipeline on: {Config.DEVICE}")

# ==========================================
# 2. MODEL & DATA UTILS
# ==========================================
class ResidualBlock(nn.Module):
    """Bottleneck MLP with an identity skip connection.

    The inner stack is Linear -> BatchNorm -> ReLU -> Dropout -> Linear ->
    BatchNorm, mapping ``in_features`` to ``hidden_features`` and back. Its
    output is added to the input and passed through a final ReLU, so the
    output has the same width as the input.

    Args:
        in_features: Width of the input (and output) features.
        hidden_features: Width of the inner hidden layer.
        dropout: Dropout probability applied after the first activation.
    """

    def __init__(self, in_features, hidden_features, dropout=0.3):
        super().__init__()
        # Layers are instantiated in forward order and wrapped in a
        # Sequential, so they are registered as block.0 ... block.5.
        layers = [
            nn.Linear(in_features, hidden_features),
            nn.BatchNorm1d(hidden_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, in_features),
            nn.BatchNorm1d(in_features),
        ]
        self.block = nn.Sequential(*layers)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(x + block(x))``."""
        residual = self.block(x)
        return self.relu(x + residual)

class ProteinClassifier(nn.Module):
    """Multi-label classifier head applied to fixed-size input vectors.

    Pipeline: input BatchNorm -> Linear(input_dim, 512) -> ReLU -> Dropout
    -> one ResidualBlock(512, 256) -> Linear(512, num_classes). The forward
    pass returns raw logits; no sigmoid is applied here.

    Args:
        input_dim: Width of each input vector.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # The network is kept small on purpose so it stays cheap on CPU.
        hidden_dim = 512
        bottleneck_dim = 256

        self.bn_input = nn.BatchNorm1d(input_dim)
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.res1 = ResidualBlock(hidden_dim, bottleneck_dim)
        self.layer_out = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return the logits for a batch ``x``."""
        normalized = self.bn_input(x)
        hidden = self.layer1(normalized)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        hidden = self.res1(hidden)
        return self.layer_out(hidden)

def get_label_matrix(train_ids, terms_path=None, num_labels=None):
    """Build the multi-hot label matrix for ``train_ids``.

    Reads a tab-separated annotation table with ``EntryID`` and ``term``
    columns, keeps the ``num_labels`` most frequent terms, and sets
    ``y[i, j] = 1`` when protein ``train_ids[i]`` is annotated with the
    j-th most frequent term.

    Args:
        train_ids: Sequence of protein IDs; row ``i`` of the result belongs
            to ``train_ids[i]``.
        terms_path: Path of the annotation TSV. Defaults to
            ``Config.TRAIN_TERMS``.
        num_labels: Number of most frequent terms to keep (also the number
            of columns). Defaults to ``Config.NUM_LABELS``.

    Returns:
        Tuple ``(y_data, term_to_idx)`` where ``y_data`` is a float32 array
        of shape ``(len(train_ids), num_labels)`` and ``term_to_idx`` maps
        each kept term to its column index.
    """
    if terms_path is None:
        terms_path = Config.TRAIN_TERMS
    if num_labels is None:
        num_labels = Config.NUM_LABELS

    print("  > Creating Labels...")
    df_terms = pd.read_csv(terms_path, sep="\t")
    top_terms = df_terms['term'].value_counts().index[:num_labels].tolist()
    term_to_idx = {term: i for i, term in enumerate(top_terms)}

    pid_to_idx = {pid: i for i, pid in enumerate(train_ids)}
    y_data = np.zeros((len(train_ids), num_labels), dtype=np.float32)

    # Map every annotation row to (row, column) coordinates. Rows whose term
    # is not kept or whose protein is not requested map to NaN.
    term_indices = df_terms['term'].map(term_to_idx)
    pid_indices = df_terms['EntryID'].map(pid_to_idx)

    # Use ONE joint mask so the two index arrays stay aligned row-for-row.
    # Dropping NaNs from each array separately and truncating to the shorter
    # length would pair proteins with the wrong terms.
    keep = term_indices.notna() & pid_indices.notna()
    rows = pid_indices[keep].astype(int).values
    cols = term_indices[keep].astype(int).values
    y_data[rows, cols] = 1.0

    return y_data, term_to_idx

# ==========================================
# TRAIN & PREDICT (WITH VALIDATION SCORING)
# ==========================================
def _predict_probabilities(model, loader, progress_desc=None):
    """Return stacked sigmoid probabilities of ``model`` over ``loader``.

    ``loader`` must yield 1-tuples ``(inputs,)``. When ``progress_desc`` is
    given, a tqdm progress bar with that description is shown.
    """
    model.eval()
    batches = tqdm(loader, desc=progress_desc) if progress_desc else loader
    chunks = []
    with torch.no_grad():
        for (inputs,) in batches:
            logits = model(inputs.to(Config.DEVICE))
            chunks.append(torch.sigmoid(logits).cpu().numpy())
    return np.vstack(chunks)


def train_and_predict_with_score(X, y, X_test, model_name):
    """Train a ProteinClassifier, report a validation F-max, predict the test set.

    The data is split 90/10 (fixed ``random_state=42``), the model is trained
    for ``Config.EPOCHS`` epochs with Adam and ``BCEWithLogitsLoss``, the
    held-out 10% is scored with ``compute_fmax`` (which prints the result),
    and finally probabilities for ``X_test`` are returned.

    Args:
        X: Training feature matrix, one row per sample.
        y: Training label matrix aligned with ``X``.
        X_test: Test feature matrix with the same number of columns as ``X``.
        model_name: Label used only in the progress banner.

    Returns:
        Array of sigmoid probabilities with one row per row of ``X_test``
        and ``Config.NUM_LABELS`` columns.
    """
    print(f"\n>>> TRAINING {model_name} WITH VALIDATION <<<")

    # 1. Hold out 10% of the training data to compute the validation score.
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1, random_state=42)

    print(f"  Train size: {len(X_train)} | Val size: {len(X_val)}")

    # as_tensor converts straight to float32 without the extra intermediate
    # copy that torch.tensor(...).float() makes.
    train_ds = torch.utils.data.TensorDataset(
        torch.as_tensor(X_train, dtype=torch.float32),
        torch.as_tensor(y_train, dtype=torch.float32),
    )
    val_ds = torch.utils.data.TensorDataset(torch.as_tensor(X_val, dtype=torch.float32))
    test_ds = torch.utils.data.TensorDataset(torch.as_tensor(X_test, dtype=torch.float32))

    # BatchNorm1d cannot run in training mode on a batch with a single
    # sample, so drop the trailing batch only when it would have size 1.
    drop_last = len(train_ds) > 1 and len(train_ds) % Config.BATCH_SIZE == 1

    # Validation and test loaders keep the original order (no shuffle).
    train_dl = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True,
                          num_workers=0, drop_last=drop_last)
    val_dl = DataLoader(val_ds, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=0)
    test_dl = DataLoader(test_ds, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=0)

    # Model setup
    model = ProteinClassifier(X_train.shape[1], Config.NUM_LABELS).to(Config.DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=Config.LR)
    criterion = nn.BCEWithLogitsLoss()

    # --- TRAINING LOOP ---
    for epoch in range(Config.EPOCHS):
        model.train()
        train_loss = 0
        for inputs, targets in train_dl:
            inputs, targets = inputs.to(Config.DEVICE), targets.to(Config.DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        print(f"  Epoch {epoch+1}: Loss {train_loss/len(train_dl):.4f}")

    # --- VALIDATION & SCORING ---
    print("  > Validating...")
    val_preds = _predict_probabilities(model, val_dl)

    # compute_fmax prints the best threshold and score; the returned values
    # are not needed here.
    compute_fmax(y_val, val_preds)

    # --- PREDICTION ON TEST SET ---
    print("  > Predicting on Test Set...")
    return _predict_probabilities(model, test_dl, progress_desc="Inference")

# ==========================================
# 3. MAIN PIPELINE (ESM-2 ONLY)
# ==========================================

# A. Collect the IDs of the proteins that need predictions
print(">>> Loading Target IDs...")
with open(Config.TEST_FASTA, 'r') as fasta_file:
    # Only header lines (starting with '>') carry an ID: strip the marker
    # and keep the first whitespace-separated token.
    test_ids_list = [
        header.strip()[1:].split()[0]
        for header in fasta_file
        if header.startswith('>')
    ]
print(f"Total Targets: {len(test_ids_list)}")

# B. Load ESM-2 Data
print(">>> Loading ESM-2 Data...")
# Protein IDs: prefer the CSV listing and fall back to the .npy array when the
# CSV is missing, unreadable, or lacks the expected column. Only those
# failures are caught so that Ctrl-C / SystemExit are not swallowed.
try:
    df_ids = pd.read_csv(os.path.join(Config.ESM_DIR, "protein_ids.csv"))
    all_pids = df_ids["protein_id"].tolist()
except (OSError, KeyError, ValueError):
    # NOTE(review): allow_pickle=True can execute arbitrary code embedded in
    # the file; only acceptable because this dataset is trusted.
    all_pids_raw = np.load(os.path.join(Config.ESM_DIR, "protein_ids.npy"), allow_pickle=True)
    all_pids = [i.decode('utf-8') if isinstance(i, (bytes, np.bytes_)) else i for i in all_pids_raw]

# Embeddings are memory-mapped so rows are read from disk only when accessed.
all_embeds = np.load(os.path.join(Config.ESM_DIR, "protein_embeddings.npy"), mmap_mode='r')
if len(all_pids) != all_embeds.shape[0]:
    # Row i of the embeddings is assumed to belong to all_pids[i]; a length
    # mismatch means that pairing cannot be trusted.
    print(f"Warning: {len(all_pids)} protein IDs but {all_embeds.shape[0]} embedding rows.")
pid_to_idx_map = {pid: i for i, pid in enumerate(all_pids)}

# C. Prepare Train/Test Arrays
print(">>> Preparing Train/Test Matrices...")
# Training proteins = annotated proteins that also have an embedding.
# unique() keeps first-appearance order, so the row order (and therefore the
# seeded train/validation split) is the same on every run; iterating a set of
# strings would reorder the rows between interpreter sessions.
df_terms = pd.read_csv(Config.TRAIN_TERMS, sep="\t")
valid_train = [pid for pid in df_terms['EntryID'].unique() if pid in pid_to_idx_map]

print(f"  Valid Train Proteins: {len(valid_train)}")

# X_train: gather all training rows from the memory-mapped array in one
# indexing operation instead of one Python-level read per protein.
train_indices = np.fromiter(
    (pid_to_idx_map[pid] for pid in valid_train),
    dtype=np.intp,
    count=len(valid_train),
)
X_train = np.asarray(all_embeds[train_indices])
y_train, term_map = get_label_matrix(valid_train)

# X_test: filled row by row to avoid materialising a second full-size copy.
# Proteins without an embedding keep an all-zero vector; report how many.
emb_dim = X_train.shape[1]
X_test = np.zeros((len(test_ids_list), emb_dim), dtype=np.float32)
missing_test = 0
for i, pid in enumerate(test_ids_list):
    emb_idx = pid_to_idx_map.get(pid)
    if emb_idx is None:
        missing_test += 1
    else:
        X_test[i] = all_embeds[emb_idx]
if missing_test:
    print(f"  Warning: {missing_test} test proteins have no embedding (zero vector used).")

# D. Train & Predict
preds_dl = train_and_predict_with_score(X_train, y_train, X_test, "ESM-2")

# Free the large arrays right away; only preds_dl is needed from here on.
del X_train, y_train, X_test, all_embeds, df_terms, train_indices
gc.collect()

# E. Format DL Results
print(">>> Formatting Results...")
idx_to_term = {i: t for t, i in term_map.items()}

# Term name for every prediction column; columns without a known term map to
# None and are excluded below instead of raising KeyError.
term_by_col = np.array(
    [idx_to_term.get(col) for col in range(preds_dl.shape[1])], dtype=object
)
has_term = np.array([term is not None for term in term_by_col], dtype=bool)

# Keep scores > 0.01. np.nonzero walks the matrix row by row, so the output
# order matches iterating proteins and then their label columns.
rows, cols = np.nonzero((preds_dl > 0.01) & has_term)

df_dl = pd.DataFrame({
    'Id': np.asarray(test_ids_list, dtype=object)[rows],
    'Term': term_by_col[cols],
    # Store as float64 so later max/round steps run in double precision.
    'Score': preds_dl[rows, cols].astype(np.float64),
})
del preds_dl
gc.collect()

# ==========================================
# 4. HOMOLOGY MERGE & SAVE
# ==========================================
print(">>> Merging with Homology Data...")

# Final score per (Id, Term) = maximum over the DL model and both homology
# sources. Stacking the sources and taking a grouped max also collapses
# duplicate (Id, Term) rows inside a source, which an outer merge would
# multiply into duplicate submission lines.
dl_scores = df_dl[['Id', 'Term', 'Score']]

try:
    # Load Homology (headerless TSV: Id, Term, Score)
    homology_frames = [
        pd.read_csv(path, sep='\t', header=None, names=['Id', 'Term', 'Score'])
        for path in (Config.HOMOLOGY_1, Config.HOMOLOGY_2)
    ]

    print("  Final Join...")
    stacked = pd.concat([dl_scores, *homology_frames], ignore_index=True)
    # Fail here (and fall back to DL only) if a file contains non-numeric scores.
    stacked['Score'] = pd.to_numeric(stacked['Score'])
    df_final = stacked.groupby(['Id', 'Term'], as_index=False)['Score'].max()

except Exception as e:
    # Deliberate best-effort: any problem with the homology data must not
    # prevent the DL-only submission from being written.
    print(f"Warning: Homology merge failed ({e}). Saving DL only.")
    df_final = dl_scores

# Save
print(">>> Saving submission.tsv...")
# .copy() gives an independent frame, so rounding does not assign into a
# slice of df_final (or of df_dl on the fallback path).
submission = df_final.loc[df_final['Score'] > 0.01, ['Id', 'Term', 'Score']].copy()
submission['Score'] = submission['Score'].round(3)

submission.to_csv('submission.tsv', sep='\t', header=False, index=False)
print("Done!")

Running optimized pipeline on: cpu
>>> Loading Target IDs...
Total Targets: 224309
>>> Loading ESM-2 Data...
>>> Preparing Train/Test Matrices...
  Valid Train Proteins: 82404
  > Creating Labels...

>>> TRAINING ESM-2 WITH VALIDATION <<<
  Train size: 74163 | Val size: 8241
  Epoch 1: Loss 0.0162
  Epoch 2: Loss 0.0123
  > Validating...
  > Calculating F-max...
  [RESULT] Best Threshold: 0.22 | Validation F-Max: 0.2989
  > Predicting on Test Set...


Inference: 100%|██████████| 3505/3505 [00:13<00:00, 259.83it/s]


>>> Formatting Results...


100%|██████████| 224309/224309 [00:08<00:00, 25816.83it/s]


>>> Merging with Homology Data...
  Final Join...
>>> Saving submission.tsv...
Done!
