In [1]:
import os
import sys
import gc
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchmetrics.classification import MultilabelF1Score, MultilabelAccuracy
from tqdm.auto import tqdm
import networkx as nx       # D√πng cho GOPropagator
from collections import defaultdict # D√πng cho NegativePropagator

# Suppress non-essential warnings (optional)
import warnings
warnings.filterwarnings('ignore')

print("‚úÖ Libraries imported successfully!")

‚úÖ Libraries imported successfully!


# CONFIG

In [2]:
class Config:
    # Central, class-level configuration for the whole notebook: run-mode flags,
    # input paths, checkpoint locations, hyper-parameters and post-processing.

    # --- Control Flags ---
    IS_TRAINING = False       # True: retrain the 3 per-aspect models; False: run inference and write the submission
    FORCE_RERUN_PREDICT = False 
    
    # --- Paths ---
    ESM_DIR = '/kaggle/input/cafa6-protein-embeddings-esm2'
    TRAIN_TERMS = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv'
    TEST_FASTA = '/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta'
    GO_OBO = '/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo'
    GOA_FILE = '/kaggle/input/protein-go-annotations/goa_uniprot_all.csv'
    HOMOLOGY_FILE = '/kaggle/input/foldseek-cafa/foldseek_submission.tsv'
    
    # --- Model Checkpoints (Multi-Ontology) ---
    # Point these paths at the dataset that holds the 3 models once training is done
    CKPT_BPO = '/kaggle/input/multi-ontology/pytorch/default/1/model_BPO.pth'
    CKPT_CCO = '/kaggle/input/multi-ontology/pytorch/default/1/model_CCO.pth'
    CKPT_MFO = '/kaggle/input/multi-ontology/pytorch/default/1/model_MFO.pth'
    
    # --- Model Configs (Customize per Aspect) ---
    EMBED_DIM = 1280
    BATCH_SIZE = 128
    LR = 1e-3
    EPOCHS = 10
    
    # Number of labels per aspect (BPO usually has more labels than the others)
    NUM_LABELS_BPO = 1500
    NUM_LABELS_CCO = 800
    NUM_LABELS_MFO = 800
    
    # --- Post-processing ---
    THRESHOLD = 0.005 # Kept slightly low because predictions are meant to be filtered more strictly afterwards
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Report which device was selected so the run log shows whether a GPU is in use.
print(f"‚öôÔ∏è Running on device: {Config.DEVICE}")

‚öôÔ∏è Running on device: cpu


In [3]:
def get_term_aspect_map(obo_path):
    """Read an OBO file and map every GO term ID to its aspect code.

    Args:
        obo_path: Path to the OBO file (e.g. go-basic.obo).

    Returns:
        dict mapping a GO ID string (e.g. "GO:0008150") to one of
        'BPO', 'CCO' or 'MFO'. Terms whose namespace is not one of the three
        known GO namespaces are left out.
    """
    print(f"üìö Parsing Aspect (Namespace) from {obo_path}...")

    # OBO namespace -> conventional three-letter aspect code
    ns_map = {
        'biological_process': 'BPO',
        'cellular_component': 'CCO',
        'molecular_function': 'MFO'
    }

    term_to_aspect = {}
    current_term = None

    with open(obo_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith("["):
                # Any stanza header ([Term], [Typedef], ...) closes the previous
                # term; otherwise tags of a non-term stanza would be attributed
                # to the last GO term seen.
                current_term = None
            elif line.startswith("id: GO:"):
                current_term = line[4:]
            elif line.startswith("namespace:") and current_term:
                fields = line.split()
                # Guard against a bare "namespace:" tag with no value.
                if len(fields) > 1 and fields[1] in ns_map:
                    term_to_aspect[current_term] = ns_map[fields[1]]

    print(f"   Mapped {len(term_to_aspect)} terms to aspects.")
    return term_to_aspect

# DATASET

In [4]:
class ProteinDataset(Dataset):
    """In-memory dataset of protein embeddings with optional multi-label targets."""

    def __init__(self, embeddings, labels=None):
        """
        embeddings: numpy array or tensor (N_samples, Embed_Dim)
        labels: numpy array or tensor (N_samples, Num_Labels) - Optional
        """
        # Convert to float32 tensors once, up front, so per-item access does no conversion
        self.embeddings = torch.tensor(embeddings, dtype=torch.float32)
        self.labels = None if labels is None else torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        # Unlabelled data yields just X; labelled data yields the (X, y) pair
        sample = self.embeddings[idx]
        if self.labels is None:
            return sample
        return sample, self.labels[idx]

# MODEL

In [5]:
class ResidualBlock(nn.Module):
    """Two-layer bottleneck MLP with a skip connection: relu(x + block(x))."""

    def __init__(self, in_features, hidden_features, dropout=0.3):
        super().__init__()
        # Layer order and positions are kept so state_dict keys (block.0, block.1, ...) stay stable.
        layers = [
            nn.Linear(in_features, hidden_features),
            nn.BatchNorm1d(hidden_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, in_features),
            nn.BatchNorm1d(in_features),
        ]
        self.block = nn.Sequential(*layers)
        self.relu = nn.ReLU()

    def forward(self, x):
        transformed = self.block(x)
        return self.relu(x + transformed)

class ProteinClassifier(nn.Module):
    """MLP head over a protein embedding that emits one logit per GO label.

    Pipeline: input BatchNorm -> Linear(input_dim, 512) -> ReLU -> Dropout(0.3)
    -> ResidualBlock(512, 256) -> Linear(512, num_classes). The output is raw
    logits (no sigmoid is applied here).
    """

    def __init__(self, input_dim=1280, num_classes=1500):
        super().__init__()
        hidden_dim = 512
        # Attribute names and creation order are kept so saved checkpoints still load.
        self.bn_input = nn.BatchNorm1d(input_dim)
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.res_block = ResidualBlock(hidden_dim, 256)
        self.out = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        normed = self.bn_input(x)
        hidden = self.layer1(normed)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        hidden = self.res_block(hidden)
        return self.out(hidden)

# DATA MANAGER

In [6]:
class CAFA6DataManager:
    """Loads protein embeddings and builds train/test DataLoaders per GO aspect.

    Attributes set on the instance:
        cfg: the configuration object passed to the constructor.
        term_to_idx: label vocabulary (GO term -> column index) of the aspect
            most recently processed by `_get_labels_matrix`.
        pid_to_emb_idx: protein ID -> row index in `all_embeds`.
        all_embeds: memory-mapped embedding matrix (None until loaded).
        test_ids: protein IDs read from the test FASTA, in file order.
        term_aspect_map: GO term -> aspect code, parsed from the OBO file.
    """

    def __init__(self, config):
        self.cfg = config
        self.term_to_idx = {}
        self.pid_to_emb_idx = {}
        self.all_embeds = None
        self.test_ids = []
        # Parse the term -> aspect map right away; every label matrix needs it
        self.term_aspect_map = get_term_aspect_map(config.GO_OBO)

    def _load_safe_ids(self, path):
        """Best-effort load of an ID array saved with numpy; returns [] on failure.

        Byte-string IDs are decoded to UTF-8 text. Only ordinary errors are
        caught, so KeyboardInterrupt/SystemExit still propagate.
        """
        try:
            # NOTE(review): allow_pickle=True can execute arbitrary code when the
            # file comes from an untrusted source; only use with trusted datasets.
            ids = np.load(path, allow_pickle=True)
            if len(ids) > 0 and isinstance(ids[0], (bytes, np.bytes_)):
                return [i.decode('utf-8') for i in ids]
            return ids.tolist()
        except Exception as exc:
            print(f"   Could not load IDs from {path}: {exc}")
            return []

    def load_embeddings(self):
        """Load protein IDs and memory-map the embedding matrix.

        IDs are read from protein_ids.csv (column 'protein_id'); if that fails,
        protein_ids.npy is tried instead.

        Raises:
            ValueError: if neither ID source yields any protein ID, since every
                later lookup would silently miss.
        """
        print("üì• Loading Embeddings source...")
        try:
            all_pids = pd.read_csv(os.path.join(self.cfg.ESM_DIR, "protein_ids.csv"))['protein_id'].tolist()
        except Exception:
            all_pids = self._load_safe_ids(os.path.join(self.cfg.ESM_DIR, "protein_ids.npy"))

        if len(all_pids) == 0:
            raise ValueError(
                f"No protein IDs could be loaded from {self.cfg.ESM_DIR} "
                "(tried protein_ids.csv and protein_ids.npy)."
            )

        # mmap_mode='r' keeps the matrix on disk; rows are read on demand
        self.all_embeds = np.load(os.path.join(self.cfg.ESM_DIR, "protein_embeddings.npy"), mmap_mode='r')
        self.pid_to_emb_idx = {pid: i for i, pid in enumerate(all_pids)}
        print(f"   Loaded {len(all_pids)} protein embeddings.")

    def _get_labels_matrix(self, valid_pids, target_aspect, num_labels):
        """Build the multi-hot label matrix Y restricted to one aspect.

        Args:
            valid_pids: protein IDs, in the row order wanted for Y.
            target_aspect: 'BPO', 'CCO' or 'MFO'.
            num_labels: number of columns; the most frequent terms of the
                aspect are used (fewer if the aspect has fewer terms, in which
                case the remaining columns stay all-zero).

        Returns:
            float32 array of shape (len(valid_pids), num_labels).

        Side effects:
            Overwrites self.term_to_idx with the vocabulary of this aspect.

        Raises:
            ValueError: if no training annotation belongs to the aspect.
        """
        print(f"   Processing Labels Matrix for aspect: {target_aspect}...")
        df = pd.read_csv(self.cfg.TRAIN_TERMS, sep="\t")

        # 1. Map each term to its aspect and keep only the target aspect
        df['aspect'] = df['term'].map(self.term_aspect_map)
        df = df[df['aspect'] == target_aspect]

        if df.empty:
            raise ValueError(f"Kh√¥ng t√¨m th·∫•y d·ªØ li·ªáu cho aspect {target_aspect}. Ki·ªÉm tra OBO parser.")

        # 2. Take the most frequent labels of that aspect
        top_terms = df['term'].value_counts().index[:num_labels].tolist()
        self.term_to_idx = {t: i for i, t in enumerate(top_terms)}

        # 3. Fill the matrix Y
        pid_to_idx = {pid: i for i, pid in enumerate(valid_pids)}
        y = np.zeros((len(valid_pids), num_labels), dtype=np.float32)

        df = df[df['term'].isin(top_terms) & df['EntryID'].isin(valid_pids)]
        row_idx = df['EntryID'].map(pid_to_idx)
        col_idx = df['term'].map(self.term_to_idx)
        valid = row_idx.notna() & col_idx.notna()

        y[row_idx[valid].astype(int).to_numpy(), col_idx[valid].astype(int).to_numpy()] = 1.0
        return y

    def prepare_train_loaders(self, aspect, num_labels):
        """Build (train, validation) DataLoaders for a single aspect.

        Returns:
            (train_dl, val_dl) from a 90/10 split with random_state=42.
        """
        if self.all_embeds is None: self.load_embeddings()

        print(f"üõ†Ô∏è Preparing Training Data for [{aspect}]...")

        # Keep training proteins that have an embedding. Iterate in order of
        # first appearance rather than through a set: set order of strings
        # changes between interpreter runs, which would make the seeded split
        # below non-reproducible.
        df_terms = pd.read_csv(self.cfg.TRAIN_TERMS, sep="\t", usecols=['EntryID'])
        valid_pids = [p for p in df_terms['EntryID'].unique() if p in self.pid_to_emb_idx]

        # Construct X
        train_emb_indices = [self.pid_to_emb_idx[p] for p in valid_pids]
        X = np.array([self.all_embeds[i] for i in train_emb_indices])

        # Construct y (for this aspect only)
        y = self._get_labels_matrix(valid_pids, aspect, num_labels)

        # Split & Loaders
        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1, random_state=42)

        train_ds = ProteinDataset(X_train, y_train)
        val_ds = ProteinDataset(X_val, y_val)

        train_dl = DataLoader(train_ds, batch_size=self.cfg.BATCH_SIZE, shuffle=True, num_workers=0)
        val_dl = DataLoader(val_ds, batch_size=self.cfg.BATCH_SIZE, shuffle=False, num_workers=0)

        # Drop the numpy copies; the datasets hold their own tensors
        del X, y, X_train, y_train, X_val, y_val
        gc.collect()
        return train_dl, val_dl

    def prepare_test_loader(self):
        """Read test IDs from the FASTA file and build an unshuffled test DataLoader.

        Side effects:
            Resets and fills self.test_ids (first whitespace-delimited token
            after '>' on each header line). Proteins with no embedding keep an
            all-zero feature row.
        """
        if self.all_embeds is None: self.load_embeddings()

        print(f"üìñ Reading Test Fasta: {self.cfg.TEST_FASTA}")
        self.test_ids = []
        with open(self.cfg.TEST_FASTA, 'r') as f:
            for line in f:
                if line.startswith('>'): self.test_ids.append(line.strip()[1:].split()[0])

        print("   Constructing Test Matrix...")
        X_test = np.zeros((len(self.test_ids), self.cfg.EMBED_DIM), dtype=np.float32)

        found = 0
        for i, pid in enumerate(self.test_ids):
            emb_row = self.pid_to_emb_idx.get(pid)
            if emb_row is not None:
                X_test[i] = self.all_embeds[emb_row]
                found += 1
        print(f"   Found vectors for {found}/{len(self.test_ids)} test proteins.")     
        test_dl = DataLoader(ProteinDataset(X_test), batch_size=self.cfg.BATCH_SIZE, shuffle=False, num_workers=0)
        return test_dl

# TRAINER 

In [7]:
class Trainer:
    """Runs training, validation and prediction for one multi-label classifier."""

    def __init__(self, model, train_loader, val_loader, config):
        self.cfg = config
        self.model = model
        self.train_dl, self.val_dl = train_loader, val_loader

        # The label count is read from the model's final Linear layer (model.out)
        # rather than from the config, so it always matches the model at hand.
        self.num_labels = model.out.out_features

        self.criterion = nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.Adam(model.parameters(), lr=config.LR)

        # Metrics are sized with self.num_labels, not a config-level constant
        metric_kwargs = dict(num_labels=self.num_labels, average='micro')
        self.f1_score = MultilabelF1Score(**metric_kwargs).to(config.DEVICE)
        self.acc_score = MultilabelAccuracy(**metric_kwargs).to(config.DEVICE)

    def train_epoch(self, epoch_idx):
        """Run one optimisation pass over the training loader and print the epoch summary."""
        device = self.cfg.DEVICE
        self.model.train()
        self.f1_score.reset()
        self.acc_score.reset()

        batch_losses = []
        for features, targets in self.train_dl:
            features = features.to(device)
            targets = targets.to(device)

            self.optimizer.zero_grad()
            logits = self.model(features)
            loss = self.criterion(logits, targets)
            loss.backward()
            self.optimizer.step()

            batch_losses.append(loss.item())
            self.f1_score.update(logits, targets)
            self.acc_score.update(logits, targets)

        avg_loss = sum(batch_losses) / len(self.train_dl)
        f1 = self.f1_score.compute()
        acc = self.acc_score.compute()
        print(f"   Epoch {epoch_idx}/{self.cfg.EPOCHS} | Loss: {avg_loss:.4f} | Train F1: {f1:.4f} | Acc: {acc:.4f}")

    def validate(self):
        """Compute and print the F1 score over the validation loader."""
        device = self.cfg.DEVICE
        self.model.eval()
        self.f1_score.reset()
        with torch.no_grad():
            for features, targets in self.val_dl:
                features = features.to(device)
                targets = targets.to(device)
                self.f1_score.update(self.model(features), targets)
        val_f1 = self.f1_score.compute()
        print(f"   üî• Validation F1: {val_f1:.4f}")

    def fit(self):
        """Train for cfg.EPOCHS epochs, then validate once at the end."""
        print("üöÄ Start Training...")
        epoch = 1
        while epoch <= self.cfg.EPOCHS:
            self.train_epoch(epoch)
            epoch += 1
        self.validate()
        print("‚úÖ Training Completed.")

    def predict(self, test_dl):
        """Return sigmoid probabilities for every batch of test_dl, stacked row-wise."""
        print("üîÆ Predicting...")
        self.model.eval()
        device = self.cfg.DEVICE
        batch_probs = []
        with torch.no_grad():
            for features in tqdm(test_dl):
                logits = self.model(features.to(device))
                batch_probs.append(torch.sigmoid(logits).cpu().numpy())
        return np.vstack(batch_probs)

# GO PROPAGATOR

In [8]:
import networkx as nx
class GOPropagator:
    """Propagates predicted scores up the GO `is_a` hierarchy.

    Only `is_a` relations from the OBO file are used, and only edges whose
    child and parent are both in the model's label vocabulary take part.
    """

    def __init__(self, obo_path, term_to_idx):
        self.obo_path = obo_path
        self.term_to_idx = term_to_idx
        self.idx_to_term = {v: k for k, v in term_to_idx.items()}
        self.graph = None

    def load_obo(self):
        """Parse the OBO file into a directed graph with child -> parent edges."""
        print(f"üå≥ Parsing OBO structure from: {self.obo_path}")
        graph = nx.DiGraph()

        current_term = None
        with open(self.obo_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith("["):
                    # Any stanza header ends the previous term
                    current_term = None
                elif line.startswith("id: GO:"):
                    current_term = line[4:]
                    graph.add_node(current_term)
                elif line.startswith("is_a: GO:") and current_term:
                    parent = line.split()[1]
                    # Edge goes child -> parent so scores can be pushed upwards
                    graph.add_edge(current_term, parent)

        # Publish only a fully parsed graph; a failed parse leaves self.graph untouched
        self.graph = graph

    def propagate(self, probs_matrix):
        """Apply parent = max(parent, child) until no score changes.

        Args:
            probs_matrix: array of shape (n_proteins, n_labels); not modified.

        Returns:
            A new array of the same shape in which every parent label scores
            at least as high as each of its in-vocabulary descendants.
        """
        if self.graph is None: self.load_obo()

        print("üåä Starting Score Propagation (Bottom-up)...")

        valid_terms = set(self.term_to_idx.keys())

        # Keep only relations whose child and parent are both predicted labels
        relevant_edges = [
            (self.term_to_idx[child], self.term_to_idx[parent])
            for child, parent in self.graph.edges()
            if child in valid_terms and parent in valid_terms
        ]

        print(f"   Found {len(relevant_edges)} parent-child relationships within the {len(valid_terms)} predicted labels.")

        new_probs = probs_matrix.copy()

        # Edges are visited in arbitrary order, so a fixed number of sweeps can
        # stop before a score has travelled a long chain. Sweep until nothing
        # changes instead. A score moves at least one edge per sweep, so
        # len(relevant_edges) sweeps plus one confirming sweep always suffice.
        for sweep in range(len(relevant_edges) + 1):
            changed = False
            for child_idx, parent_idx in relevant_edges:
                child_scores = new_probs[:, child_idx]
                parent_scores = new_probs[:, parent_idx]
                # The max is vectorised over all proteins at once
                if np.any(child_scores > parent_scores):
                    new_probs[:, parent_idx] = np.maximum(parent_scores, child_scores)
                    changed = True

            print(f"   Iteration {sweep+1} complete.")
            if not changed:
                break

        return new_probs


# NEGATIVE PROPAGATOR

In [9]:
class NegativePropagator:
    """Collects 'NOT' annotations and extends them to every descendant GO term.

    Results are stored in `negative_pairs` as strings of the form
    "<protein_id>_<go_term>". Call `parse_obo()` before `load_negatives()`;
    otherwise no descendants are known and only the directly annotated terms
    are recorded.
    """

    def __init__(self, config):
        self.cfg = config
        self.children = defaultdict(list)   # parent GO id -> list of direct child GO ids
        self.negative_pairs = set()

    def parse_obo(self):
        """Fill `children` from the `is_a` relations of cfg.GO_OBO."""
        print(f"üå≥ Parsing OBO: {self.cfg.GO_OBO}")
        term = None
        with open(self.cfg.GO_OBO, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith('[Term]'):
                    term = None
                elif line.startswith('id: GO:'):
                    term = line[4:]
                elif line.startswith('is_a: GO:') and term:
                    self.children[line.split()[1]].append(term)

    def _get_descendants(self, root):
        """Return the set of all terms reachable from `root` through `children`."""
        desc = set()
        stack = [root]
        while stack:
            curr = stack.pop()
            # .get avoids inserting empty entries into the defaultdict
            for kid in self.children.get(curr, ()):
                if kid not in desc:
                    desc.add(kid)
                    stack.append(kid)
        return desc

    def load_negatives(self, target_ids):
        """Scan cfg.GOA_FILE for 'NOT' annotations of the given proteins.

        Args:
            target_ids: iterable of protein IDs to keep.

        Side effects:
            Adds "<pid>_<term>" to `negative_pairs` for each negated term and
            for each of its descendants. Does nothing when the file is missing.
        """
        if not os.path.exists(self.cfg.GOA_FILE):
            print("‚ö†Ô∏è GOA File not found. Skipping.")
            return

        print("üö´ Loading Negative Annotations...")
        target_set = set(target_ids)
        columns = ['protein_id', 'go_term', 'qualifier']
        # Read as text: a chunk whose qualifier column is entirely empty would
        # otherwise be inferred as float and break the .str accessor below.
        chunk_iter = pd.read_csv(self.cfg.GOA_FILE, chunksize=500000, usecols=columns,
                                 dtype={name: str for name in columns})

        # The same term is negated for many proteins; walk the hierarchy once per term
        descendants_of = {}

        for chunk in chunk_iter:
            neg = chunk[(chunk['qualifier'].str.contains('NOT', na=False)) & (chunk['protein_id'].isin(target_set))]
            for pid, term in zip(neg['protein_id'], neg['go_term']):
                if term not in descendants_of:
                    descendants_of[term] = self._get_descendants(term)
                self.negative_pairs.add(f"{pid}_{term}")
                for d in descendants_of[term]:
                    self.negative_pairs.add(f"{pid}_{d}")
        print(f"   Found {len(self.negative_pairs)} negative constraints.")

# MAIN EXECUTION

In [10]:
def load_specialist_model(ckpt_path, num_labels):
    """Restore one per-aspect classifier from a checkpoint file.

    Args:
        ckpt_path: path to a checkpoint holding 'model_state_dict' and 'term_to_idx'.
        num_labels: output size used to build the ProteinClassifier.

    Returns:
        (model, term_to_idx): the model on Config.DEVICE in eval mode, and the
        label vocabulary stored in the checkpoint.
    """
    print(f"üì• Loading specialist model from {ckpt_path}...")
    device = Config.DEVICE
    saved = torch.load(ckpt_path, map_location=device)

    # Build the network, move it to the target device, then restore the weights
    net = ProteinClassifier(Config.EMBED_DIM, num_labels)
    net = net.to(device)
    net.load_state_dict(saved['model_state_dict'])
    net.eval()

    vocabulary = saved['term_to_idx']
    return net, vocabulary

In [11]:
def get_weighted_thresholds(term_to_idx, train_terms_path, min_t=0.01, max_t=0.15):
    """Compute a per-label threshold from label frequency in the training set.

    Frequent labels get a threshold near max_t, rare labels near min_t:
        T_i = min_t + (max_t - min_t) * (count_i / max_count)
    where max_count is the count of the most frequent term in the whole file.

    Args:
        term_to_idx: dict mapping GO term -> column index.
        train_terms_path: TSV file with a 'term' column.
        min_t: threshold for labels that are rare or absent from the file.
        max_t: threshold reached by the most frequent term.

    Returns:
        float32 array of length len(term_to_idx). Labels missing from the file
        (or all labels, if the file has no rows) keep min_t.
    """
    print("‚öñÔ∏è Calculating Weighted Thresholds based on Term Frequency...")

    # 1. Count how often each term occurs in the training annotations
    df = pd.read_csv(train_terms_path, sep="\t", usecols=['term'])
    term_counts = df['term'].value_counts()

    # 2. Start every label at min_t
    num_labels = len(term_to_idx)
    thresholds = np.full(num_labels, min_t, dtype=np.float32)

    # 3. Linear interpolation per label; skipped entirely for an empty file,
    #    where there is no "most frequent" count to normalise by
    found_count = 0
    if len(term_counts) > 0:
        max_count = term_counts.iloc[0]  # value_counts sorts descending
        count_of = term_counts.to_dict()  # plain dict lookups inside the loop
        for term, idx in term_to_idx.items():
            count = count_of.get(term)
            if count is None:
                continue
            freq_ratio = count / max_count
            thresholds[idx] = min_t + (max_t - min_t) * freq_ratio
            found_count += 1

    print(f"   Calculated thresholds for {found_count}/{num_labels} terms.")
    if num_labels > 0:
        # min()/max() raise on an empty array, so only report a non-empty range
        print(f"   Range: {thresholds.min():.4f} -> {thresholds.max():.4f}")

    return thresholds

In [12]:
def load_foldseek_matrix(foldseek_path, target_pids, term_to_idx, shape):
    """
    Read a Foldseek prediction file into a dense float32 matrix of the given
    shape, with rows ordered like target_pids and columns like term_to_idx.

    The file is read as a header-less TSV whose first three columns are
    (EntryID, term, score). Rows for unknown proteins or terms are dropped.
    On any error an all-zero matrix is returned, so only the DL model's
    scores end up being used.
    """
    print(f"üß¨ Loading Foldseek results from: {foldseek_path}")

    # Start from an all-zero matrix
    fs_matrix = np.zeros(shape, dtype=np.float32)

    # Protein ID -> row index lookup
    pid_to_row = {}
    for row, pid in enumerate(target_pids):
        pid_to_row[pid] = row

    try:
        # Assumes no header line: column 0 = ID, column 1 = term, column 2 = score
        column_names = ['EntryID', 'term', 'score']
        df = pd.read_csv(foldseek_path, sep='\t', header=None, names=column_names,
                         dtype={'score': np.float32}, usecols=[0, 1, 2])

        # Keep only rows whose protein is in the target set and whose term is a model label
        known_protein = df['EntryID'].isin(pid_to_row)
        known_term = df['term'].isin(term_to_idx)
        df = df[known_protein & known_term]

        print(f"   Mapping {len(df)} Foldseek predictions to matrix...")

        # Vectorised fill: translate IDs and terms to indices, then assign in one shot.
        # For duplicate (protein, term) rows, later rows overwrite earlier ones.
        target_rows = df['EntryID'].map(pid_to_row).values
        target_cols = df['term'].map(term_to_idx).values
        fs_matrix[target_rows, target_cols] = df['score'].values

        print("   Foldseek matrix constructed successfully.")
        return fs_matrix

    except Exception as e:
        print(f"‚ö†Ô∏è Error loading Foldseek: {e}")
        print("   Returning empty matrix (only DL model will be used).")
        return np.zeros(shape, dtype=np.float32)

def blend_predictions(dl_probs, fs_probs):
    """
    Blend strategy: trust the DL scores, and fall back to the Foldseek score
    only in cells where the DL score is below 0.001 while Foldseek is positive.
    Returns a new array; neither input is modified.
    """
    # Cells where the DL model is effectively silent
    dl_is_silent = dl_probs < 0.001
    # Cells where Foldseek found some evidence
    fs_has_signal = fs_probs > 0.0
    fill_mask = dl_is_silent & fs_has_signal

    # Work on a copy of the DL matrix and patch in the Foldseek scores
    blended = dl_probs.copy()
    blended[fill_mask] = fs_probs[fill_mask]
    return blended

In [13]:
# Entry point. Depending on Config.IS_TRAINING this either trains one model per
# GO aspect and saves a checkpoint for each, or loads the three checkpoints,
# predicts on the test set and writes submission.tsv.
if __name__ == "__main__":
    
    # --- Adjust the paths here if you have trained models ---
    # Each entry is (aspect, number of labels, checkpoint path).
    # In training mode the path is a local file name to save to; otherwise it is
    # the Config.CKPT_* path to load from. An aspect whose processing fails is
    # reported and skipped by the try/except blocks below.
    aspects_config = [
        ('BPO', Config.NUM_LABELS_BPO, 'model_BPO.pth' if Config.IS_TRAINING else Config.CKPT_BPO),
        ('CCO', Config.NUM_LABELS_CCO, 'model_CCO.pth' if Config.IS_TRAINING else Config.CKPT_CCO),
        ('MFO', Config.NUM_LABELS_MFO, 'model_MFO.pth' if Config.IS_TRAINING else Config.CKPT_MFO)
    ]

    # ==========================
    # PHASE 1: TRAINING
    # ==========================
    if Config.IS_TRAINING:
        print("üöÄ STARTING MULTI-ONTOLOGY TRAINING PIPELINE...")
        
        # Embeddings are loaded once and shared by all three aspects
        dm = CAFA6DataManager(Config)
        dm.load_embeddings()
        
        for aspect, num_lbls, save_path in aspects_config:
            print(f"\n{'='*40}")
            print(f" >>> Training Specialist Model: {aspect} ({num_lbls} labels) <<<")
            print(f"{'='*40}")
            
            try:
                # 1. Prepare Data
                train_dl, val_dl = dm.prepare_train_loaders(aspect, num_lbls)
                
                # 2. Init & Train
                model = ProteinClassifier(Config.EMBED_DIM, num_lbls).to(Config.DEVICE)
                trainer = Trainer(model, train_dl, val_dl, Config)
                trainer.fit()
                
                # 3. Save weights together with the label vocabulary the data
                #    manager built for this aspect, so inference can map output
                #    columns back to GO terms
                checkpoint = {
                    'model_state_dict': model.state_dict(),
                    'term_to_idx': dm.term_to_idx,
                    'config': {'num_labels': num_lbls}
                }
                torch.save(checkpoint, save_path)
                print(f"‚úÖ Model saved to: {save_path}")
                
            except Exception as e:
                # A failure in one aspect is reported and the loop moves on to the next
                print(f"‚ùå Training failed for {aspect}: {e}")
            
            # 4. Cleanup: release this aspect's objects before the next iteration
            if 'model' in locals(): del model
            if 'trainer' in locals(): del trainer
            if 'train_dl' in locals(): del train_dl
            if 'val_dl' in locals(): del val_dl
            gc.collect()
            torch.cuda.empty_cache()
            
        print("\nüéâ TRAINING COMPLETED.")

    # ==========================
    # PHASE 2: INFERENCE (ENSEMBLE)
    # ==========================
    else:
        print("üîÆ STARTING MULTI-ONTOLOGY INFERENCE PIPELINE...")
        
        # 1. Load Test Data (the same loader is reused for every aspect)
        dm = CAFA6DataManager(Config)
        test_dl = dm.prepare_test_loader()
        
        # One DataFrame of (EntryID, term, score) rows per successfully processed aspect
        all_dfs = [] 
        
        # 2. Loop through models
        for aspect, num_lbls, load_path in aspects_config:
            print(f"\n >>> Processing Aspect: {aspect} <<<")
            
            # Initialize variables to None to prevent NameError in cleanup
            model = None
            trainer = None
            probs = None
            
            try:
                # a. Load Model
                model, term_to_idx = load_specialist_model(load_path, num_lbls)
                
                # b. Predict (DL); no train/val loaders are needed for prediction
                trainer = Trainer(model, None, None, Config)
                probs = trainer.predict(test_dl)
                
                # c. Propagation: raise parent scores to at least their children's
                propagator = GOPropagator(Config.GO_OBO, term_to_idx)
                probs = propagator.propagate(probs)
                
                # d. Create Temp DataFrame from every cell above the global threshold
                idx_to_term = {v: k for k, v in term_to_idx.items()}
                rows, cols = np.where(probs > Config.THRESHOLD)
                print(f"   Found {len(rows)} predictions for {aspect}.")
                
                df_temp = pd.DataFrame({
                    'EntryID': [dm.test_ids[i] for i in rows],
                    'term': [idx_to_term[j] for j in cols],
                    'score': probs[rows, cols]
                })
                all_dfs.append(df_temp)
                
            except Exception as e:
                # A failure in one aspect is reported and the remaining aspects still run
                print(f"‚ö†Ô∏è Error processing {aspect}: {e}")
                print(f"   (Check if {load_path} matches the model architecture for {aspect})")
            
            # Cleanup safely
            if model is not None: del model
            if trainer is not None: del trainer
            if probs is not None: del probs
            gc.collect()
            torch.cuda.empty_cache()
            
        # 3. Concatenate the per-aspect predictions and write the submission
        #    as a header-less, index-less TSV
        if len(all_dfs) > 0:
            print("\nüîó Concatenating predictions from all models...")
            final_df = pd.concat(all_dfs, ignore_index=True)
            print(f"üìù Final submission has {len(final_df)} rows.")
            final_df.to_csv('submission.tsv', sep='\t', header=False, index=False)
            print("‚úÖ DONE! submission.tsv created.")
            print(final_df.head())
        else:
            print("‚ùå No predictions generated! Check your model paths.")

üîÆ STARTING MULTI-ONTOLOGY INFERENCE PIPELINE...
üìö Parsing Aspect (Namespace) from /kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo...
   Mapped 48101 terms to aspects.
üì• Loading Embeddings source...
   Loaded 287001 protein embeddings.
üìñ Reading Test Fasta: /kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta
   Constructing Test Matrix...
   Found vectors for 224309/224309 test proteins.

 >>> Processing Aspect: BPO <<<
üì• Loading specialist model from /kaggle/input/multi-ontology/pytorch/default/1/model_BPO.pth...
üîÆ Predicting...


  0%|          | 0/1753 [00:00<?, ?it/s]

üå≥ Parsing OBO structure from: /kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo
üåä Starting Score Propagation (Bottom-up)...
   Found 863 parent-child relationships within the 1500 predicted labels.
   Iteration 1 complete.
   Iteration 2 complete.
   Iteration 3 complete.
   Found 12430733 predictions for BPO.

 >>> Processing Aspect: CCO <<<
üì• Loading specialist model from /kaggle/input/multi-ontology/pytorch/default/1/model_CCO.pth...
üîÆ Predicting...


  0%|          | 0/1753 [00:00<?, ?it/s]

üå≥ Parsing OBO structure from: /kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo
üåä Starting Score Propagation (Bottom-up)...
   Found 452 parent-child relationships within the 800 predicted labels.
   Iteration 1 complete.
   Iteration 2 complete.
   Iteration 3 complete.
   Found 6731045 predictions for CCO.

 >>> Processing Aspect: MFO <<<
üì• Loading specialist model from /kaggle/input/multi-ontology/pytorch/default/1/model_MFO.pth...
üîÆ Predicting...


  0%|          | 0/1753 [00:00<?, ?it/s]

üå≥ Parsing OBO structure from: /kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo
üåä Starting Score Propagation (Bottom-up)...
   Found 598 parent-child relationships within the 800 predicted labels.
   Iteration 1 complete.
   Iteration 2 complete.
   Iteration 3 complete.
   Found 4449028 predictions for MFO.

üîó Concatenating predictions from all models...
üìù Final submission has 23610806 rows.
‚úÖ DONE! submission.tsv created.
      EntryID        term     score
0  A0A0C5B5G6  GO:0006355  0.029428
1  A0A0C5B5G6  GO:0045893  0.027978
2  A0A0C5B5G6  GO:0006974  0.008850
3  A0A0C5B5G6  GO:0008285  0.005888
4  A0A0C5B5G6  GO:0050832  0.115165


In [14]:
# Sanity check: highest score in the submission table.
# final_df is only defined when the inference branch of the previous cell ran
# (Config.IS_TRAINING is False) and produced at least one prediction.
final_df['score'].max()

0.9999929666519165