In [None]:
!pip install biopython > dev null
!pip install obonet > dev null 
!pip install -q torch esm
!pip install -q lightgbm

import pandas as pd
import torch
from esm import pretrained
import lightgbm as lgb
import numpy as np
from Bio import SeqIO
from tqdm import tqdm
from collections import Counter
import networkx as nx

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
 
# Modeling
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import f1_score

import warnings
import obonet
import gc
import time
import os

# Notebook-wide display settings: silence every warning category and let
# pandas render all columns of wide frames.
warnings.simplefilter('ignore')
pd.options.display.max_columns = None

class CFG:
    """Central configuration: input locations, plot styling and model defaults."""

    # --- Competition inputs (Kaggle dataset mount) ---
    TRAIN_TERMS_PATH = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv'
    TRAIN_SEQUENCES_PATH = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta'
    TRAIN_TAXONOMY_PATH = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_taxonomy.tsv'
    TEST_SEQUENCES_PATH = '/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta'
    IA_PATH = '/kaggle/input/cafa-6-protein-function-prediction/IA.tsv'

    # --- EDA / plotting palette ---
    COLORS = ['#221f1f', '#b20710', '#e50914', 'grey']
    BACKGROUND_COLOR = '#f5f6f6'

    # --- Modeling ---
    # Default probability cut-off for predictions.
    PROBABILITY_THRESHOLD = 0.02
    # Gene Ontology definition file (OBO format).
    GO_PATH = '/kaggle/input/go-dag/go-basic.obo'


# Decision threshold per ontology aspect code. These are starting values and
# are meant to be tuned.
ONTOLOGY_THRESH = {
    'P': 0.05,
    'F': 0.03,
    'C': 0.03,
}

# Aspect code -> GO namespace string, used when propagating predictions
# through the GO graph.
ASPECT_TO_NAMESPACE = {
    'P': 'biological_process',
    'F': 'molecular_function',
    'C': 'cellular_component',
}


# Load ESM-2 model via Torch Hub (for esm 3.x)
MODEL = "esm2_t33_650M_UR50D"

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading ESM-2 model...")
esm_model, esm_alphabet = torch.hub.load("facebookresearch/esm:main", MODEL)

# The batch converter turns (label, sequence) tuples into token tensors.
esm_batch_converter = esm_alphabet.get_batch_converter()

# Inference only: put the network in eval mode and place it on the device.
esm_model = esm_model.to(device).eval()
print(f"ESM-2 model loaded on {device}.")

In [None]:
print("Loading tabular datasets...")
# All three inputs are tab-separated; the taxonomy and IA files are read
# without a header row, so their column names are supplied here.
train_terms_df = pd.read_table(CFG.TRAIN_TERMS_PATH)
train_taxonomy_df = pd.read_table(
    CFG.TRAIN_TAXONOMY_PATH,
    header=None,
    names=['EntryID', 'taxonomyID'],
)
ia_df = pd.read_table(CFG.IA_PATH, names=['term', 'ia'])

# Load GO graph for ancestor propagation
print("Loading GO graph for hierarchy propagation...")
go_graph = obonet.read_obo(CFG.GO_PATH)
# Plain DiGraph view of the ontology (parallel edges between a pair collapse).
DG = nx.DiGraph(go_graph)
print(f"GO graph loaded with {DG.number_of_nodes():,} terms and {DG.number_of_edges():,} edges.")

def load_fasta_to_dataframe(file_path, is_train=True):
    """Parse a FASTA file into a DataFrame with ``EntryID`` and ``sequence`` columns.

    Args:
        file_path: Path of the FASTA file to read.
        is_train: When True and the record id contains ``|``, the second
            ``|``-separated field is used as the entry id. In every other
            case the first whitespace-separated token of the id is used.

    Returns:
        pandas.DataFrame with one row per FASTA record.
    """
    def _entry_id(raw_id):
        # Training headers of the form "xx|ACCESSION|name" carry the accession second.
        if is_train and '|' in raw_id:
            return raw_id.split('|')[1]
        return raw_id.split()[0]

    file_name = file_path.split('/')[-1]
    fasta_records = SeqIO.parse(file_path, "fasta")
    rows = [
        {'EntryID': _entry_id(rec.id), 'sequence': str(rec.seq)}
        for rec in tqdm(fasta_records, desc=f"Parsing {file_name}")
    ]
    return pd.DataFrame(rows)

print("\nLoading sequence datasets...")
train_sequences_df = load_fasta_to_dataframe(CFG.TRAIN_SEQUENCES_PATH, is_train=True)

print("Consolidating data into a master dataframe for EDA...")
# One row per protein holding the list of all its annotated GO terms.
protein_labels = (
    train_terms_df
    .groupby('EntryID')['term']
    .apply(list)
    .reset_index(name='labels')
)

# Sequences keep their taxonomy when known (left join); proteins without any
# label are dropped (inner join).
train_df_eda = (
    train_sequences_df
    .merge(train_taxonomy_df, on='EntryID', how='left')
    .merge(protein_labels, on='EntryID', how='inner')
)
train_df_eda['seq_length'] = train_df_eda['sequence'].str.len()
train_df_eda['num_labels'] = train_df_eda['labels'].str.len()
print("Data loading complete.")

In [None]:
def extract_esm_embedding(sequence):
    """Embed one protein sequence as the mean of its layer-33 ESM-2 token vectors.

    Args:
        sequence: Amino-acid sequence. Only the first 1022 characters are
            passed to the model (ESM-2 length limit).

    Returns:
        1-D numpy array. For an empty or missing sequence this is an all-zero
        vector of length 1280 (the ESM-2 embedding dimension).
    """
    # Nothing to embed: return the zero vector.
    if not sequence or len(sequence) == 0:
        return np.zeros(1280)

    # The converter expects a list of (label, sequence) tuples.
    truncated = sequence[:1022]
    _, _, batch_tokens = esm_batch_converter([("protein", truncated)])

    with torch.no_grad():
        output = esm_model(
            batch_tokens.to(device),
            repr_layers=[33],
            return_contacts=False,
        )
        # Drop the first and last token positions before averaging over the
        # sequence axis.
        per_residue = output["representations"][33][0, 1:-1, :]
        pooled = per_residue.mean(dim=0).cpu().numpy()

    return pooled

### TODO (pending): obtain the InterPro features .tsv file; until it is available the loader below returns an empty feature matrix.
def load_interpro_features(protein_ids, interpro_file='/kaggle/input/interpro-cafa6/interpro_features.tsv'):
    """Load binary InterPro domain-presence features aligned to ``protein_ids``.

    The file is read as a tab-separated table with the columns ``EntryID``,
    ``InterPro_ID`` and ``present``.

    Args:
        protein_ids: Protein identifiers; the output rows follow this order.
        interpro_file: Path of the InterPro feature table.

    Returns:
        2-D integer numpy array of shape ``(len(protein_ids), n_domains)``.
        Proteins absent from the file get an all-zero row. When the file does
        not exist an array with zero columns is returned so that callers can
        still concatenate it with other features.
    """
    if not os.path.exists(interpro_file):
        print(f"InterPro file {interpro_file} not found. Skipping InterPro features.")
        return np.zeros((len(protein_ids), 0))  # Return empty matrix

    interpro_df = pd.read_csv(interpro_file, sep='\t')

    # Proteins x domains matrix. Grouping first collapses repeated
    # (EntryID, InterPro_ID) rows, which DataFrame.pivot rejects with a
    # ValueError; the maximum keeps a domain "present" if any row says so.
    domain_matrix = (
        interpro_df
        .groupby(['EntryID', 'InterPro_ID'])['present']
        .max()
        .unstack(fill_value=0)
        .fillna(0)
        .astype(int)
    )

    # Align to the requested proteins; unknown proteins become all-zero rows.
    return domain_matrix.reindex(protein_ids, fill_value=0).values


def truepathrule_propagation(pred_df, go_graph, aspect, ia_df, min_ia=0.1, decay=0.9, ancestor_cache=None):
    """Propagate predictions to GO superterms (true-path rule) and filter by IA.

    Every predicted term whose information accretion (IA) is at least
    ``min_ia`` is kept, and its score multiplied by ``decay`` is assigned to
    each of its superterms in the same namespace that also passes the IA
    filter. When a (protein, term) pair receives several scores the maximum
    is kept.

    Args:
        pred_df: DataFrame with the columns ``'Protein Id'``, ``'GO Term Id'``
            and ``'Prediction'``.
        go_graph: Directed GO graph whose edges point from a term to its
            parent, as produced by obonet. Superterms are therefore reached by
            following *outgoing* edges (``networkx.ancestors`` on such a graph
            would return subterms instead). Edges of every relation type
            present in the graph are followed.
        aspect: GO namespace string (e.g. ``'biological_process'``); only
            superterms whose ``namespace`` node attribute equals it are used.
        ia_df: DataFrame with the columns ``'term'`` and ``'ia'``.
        min_ia: Minimum IA a term needs in order to be emitted.
        decay: Factor applied to a score when it is copied to a superterm.
        ancestor_cache: Optional dict reused across calls to memoise superterm
            lookups; entries are keyed by ``(aspect, term)``.

    Returns:
        DataFrame with the columns ``'Protein Id'``, ``'GO Term Id'`` and
        ``'Prediction'``, one row per unique (protein, term) pair.
    """
    columns = ['Protein Id', 'GO Term Id', 'Prediction']
    if ancestor_cache is None:
        ancestor_cache = {}

    def _superterms(term):
        """Return the same-namespace superterms of ``term`` (memoised)."""
        key = (aspect, term)
        if key in ancestor_cache:
            return ancestor_cache[key]
        reached = set()
        if term in go_graph:
            # Iterative walk along outgoing (child -> parent) edges.
            stack = [term]
            while stack:
                node = stack.pop()
                for parent in go_graph.successors(node):
                    if parent not in reached:
                        reached.add(parent)
                        stack.append(parent)
            reached.discard(term)
        kept = [anc for anc in reached if go_graph.nodes[anc].get('namespace') == aspect]
        ancestor_cache[key] = kept
        return kept

    # Get IA values
    ia_dict = dict(zip(ia_df['term'], ia_df['ia']))

    propagated = []
    for protein_id, term, score in zip(pred_df['Protein Id'], pred_df['GO Term Id'], pred_df['Prediction']):
        # Skip low-information terms (and terms without an IA value).
        if term not in ia_dict or ia_dict[term] < min_ia:
            continue

        propagated.append((protein_id, term, score))

        # Only the superterms of terms that were actually predicted are
        # computed, instead of a map for every node of the ontology.
        for anc in _superterms(term):
            if anc in ia_dict and ia_dict[anc] >= min_ia:
                propagated.append((protein_id, anc, score * decay))

    if not propagated:
        return pd.DataFrame(columns=columns)

    # Keep max score per term-protein pair
    result_df = pd.DataFrame(propagated, columns=columns)
    result_df = result_df.groupby(['Protein Id', 'GO Term Id'])['Prediction'].max().reset_index()

    return result_df

def optimize_thresholds(val_predictions, val_labels, aspect_code):
    """Grid-search the probability threshold that maximises micro-averaged F1.

    Candidate thresholds run from 0.01 up to (but excluding) 0.3 in steps of
    0.01; a prediction counts as positive when its score is strictly greater
    than the threshold.

    Args:
        val_predictions: Array of predicted scores, samples x classes.
        val_labels: Binary indicator matrix of the same shape; either a sparse
            matrix exposing ``toarray()`` or anything ``np.asarray`` accepts.
        aspect_code: Ontology code, used only in the printed summary.

    Returns:
        The best threshold found (0.01 when no candidate reaches an F1 above 0).
    """
    # Densify the labels once; the conversion does not depend on the threshold.
    if hasattr(val_labels, 'toarray'):
        dense_labels = val_labels.toarray()
    else:
        dense_labels = np.asarray(val_labels)

    best_threshold = 0.01
    best_f1 = 0

    for threshold in np.arange(0.01, 0.3, 0.01):
        pred_binary = (val_predictions > threshold).astype(int)
        f1 = f1_score(dense_labels, pred_binary, average='micro', zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    print(f"Optimal threshold for {aspect_code}: {best_threshold:.3f} (F1: {best_f1:.3f})")
    return best_threshold
        
    
print("Extracting ESM-2 embeddings for training proteins...")

# Only proteins that have at least one annotated term are embedded.
protein_ids_in_scope = set(train_terms_df['EntryID'].unique())
train_sequences_dict = dict(zip(train_sequences_df['EntryID'], train_sequences_df['sequence']))

# X_train_list[i] is the embedding of the protein y_train_proteins[i].
X_train_list, y_train_proteins = [], []
for pid, seq in tqdm(train_sequences_dict.items(), desc="ESM-2 features"):
    if pid not in protein_ids_in_scope:
        continue
    y_train_proteins.append(pid)
    X_train_list.append(extract_esm_embedding(seq))

X_train = np.array(X_train_list)
print(f"ESM-2 feature matrix shape: {X_train.shape}")

# Persist the matrix so the embeddings need not be recomputed.
np.save('train_esm2_embeddings.npy', X_train)
print("ESM-2 embeddings saved to train_esm2_embeddings.npy")

In [None]:
print("Preparing labels and training models for each ontology...")
models, mlb_dict = {}, {}

# Ontology name -> aspect code as it appears in train_terms_df['aspect'].
ontology_map = {'BPO': 'P', 'CCO': 'C', 'MFO': 'F'}


def _build_ovr_lgbm(n_estimators):
    """Return an untrained one-vs-rest classifier: StandardScaler -> LightGBM per class."""
    return OneVsRestClassifier(
        make_pipeline(
            StandardScaler(),
            lgb.LGBMClassifier(
                num_leaves=256,
                learning_rate=0.05,
                n_estimators=n_estimators,
                class_weight='balanced',
                random_state=42,
                # Must be passed as a keyword; a bare `verbosity-1` is a
                # positional expression after keyword arguments (SyntaxError).
                verbosity=-1,
            ),
        ),
        n_jobs=-1,
    )


for aspect_name, aspect_code in ontology_map.items():
    print(f"\n--- Processing {aspect_name} ({aspect_code}) ---")

    # Restrict the annotations to the current ontology.
    ont_terms_df = train_terms_df[train_terms_df['aspect'] == aspect_code]

    # Nothing to learn for this ontology.
    if ont_terms_df.empty:
        print(f"No terms found for {aspect_name}. Skipping model training.")
        continue

    # Label lists in the same order as the rows of X_train; proteins without
    # a term in this ontology get an empty list.
    protein_terms = ont_terms_df.groupby('EntryID')['term'].apply(list).to_dict()
    labels_list = [protein_terms.get(pid, []) for pid in y_train_proteins]

    mlb = MultiLabelBinarizer(sparse_output=True)
    y_train_ont = mlb.fit_transform(labels_list)

    print(f"Number of unique {aspect_name} terms: {y_train_ont.shape[1]}")

    # Robustness check: only train if there are labels
    if y_train_ont.shape[1] == 0:
        print(f"Skipping model training for {aspect_name} as no labels were binarized.")
        continue

    mlb_dict[aspect_code] = mlb
    num_samples = X_train.shape[0]
    num_features = X_train.shape[1] if X_train.ndim == 2 else len(X_train[0])
    num_classes = y_train_ont.shape[1]
    total_positive_labels = y_train_ont.nnz
    avg_labels_per_sample = (total_positive_labels / num_samples) if num_samples > 0 else 0.0
    print(f"Starting training for {aspect_name}: samples={num_samples}, features={num_features}, classes={num_classes}, positives={total_positive_labels}, avg_labels/sample={avg_labels_per_sample:.2f}")
    print(f"Backend: LightGBM CPU, n_jobs=-1 (detected cores={os.cpu_count()}); 5-fold CV ensemble")

    # Tune the decision threshold on a 10% hold-out using a smaller model,
    # then store it for use at prediction time.
    X_tr, X_val, y_tr, y_val = train_test_split(X_train, y_train_ont, test_size=0.1, random_state=42)
    temp_model = _build_ovr_lgbm(n_estimators=200)
    temp_model.fit(X_tr, y_tr)
    y_val_pred = temp_model.predict_proba(X_val)
    optimal_thresh = optimize_thresholds(y_val_pred, y_val, aspect_code)
    ONTOLOGY_THRESH[aspect_code] = optimal_thresh

    # Train 5-fold CV ensemble: one model per fold, each fitted on the
    # training part of its split. The held-out part of each fold is not used,
    # so it is not materialised.
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    trained_models = []
    start_time_all = time.time()
    for fold, (train_idx, _) in enumerate(kf.split(X_train)):
        print(f"  -> Training fold {fold+1}/5")
        X_fold_tr = X_train[train_idx]
        y_fold_tr = y_train_ont[train_idx]
        model = _build_ovr_lgbm(n_estimators=500)
        start_time = time.time()
        model.fit(X_fold_tr, y_fold_tr)
        elapsed = time.time() - start_time
        per_class = (num_classes / elapsed) if elapsed > 0 else float('inf')
        print(f"     Fold {fold+1} done. Elapsed: {elapsed:.1f}s, classes/sec: {per_class:.1f}")
        trained_models.append(model)
    total_elapsed = time.time() - start_time_all
    print(f"All folds for {aspect_name} trained. Total elapsed: {total_elapsed:.1f}s")
    models[aspect_code] = trained_models

In [None]:
print("\nLoading and processing test sequences for submission...")
test_sequences_df = load_fasta_to_dataframe(CFG.TEST_SEQUENCES_PATH, is_train=False)
test_protein_ids = test_sequences_df['EntryID'].tolist()
test_sequences_dict = dict(zip(test_sequences_df['EntryID'], test_sequences_df['sequence']))
BATCH_SIZE = 5000
# Accumulates (protein id, GO term id, score) tuples for the submission.
submission_list = []

# Collect prediction stats per ontology
predict_stats = {code: {'selected': 0, 'after_prop': 0} for code in ['P','F','C']}

for i in tqdm(range(0, len(test_protein_ids), BATCH_SIZE), desc="Predicting on Test Set"):
    batch_pids = test_protein_ids[i:i+BATCH_SIZE]
    batch_seqs = [test_sequences_dict[pid] for pid in batch_pids]
    X_batch = np.array([extract_esm_embedding(seq) for seq in batch_seqs])
    batch_pid_array = np.asarray(batch_pids, dtype=object)
    for aspect_code, model_list in models.items():
        mlb = mlb_dict[aspect_code]
        # Average probabilities across the CV-fold models. A running sum is
        # used so that only one extra probability matrix is held at a time.
        y_pred_proba_sum = None
        for model in model_list:
            proba_part = model.predict_proba(X_batch)
            if y_pred_proba_sum is None:
                y_pred_proba_sum = proba_part
            else:
                y_pred_proba_sum += proba_part
        y_pred_proba = y_pred_proba_sum / len(model_list)

        # Select every (protein, term) cell above the threshold for the whole
        # batch at once, instead of building one DataFrame per protein.
        row_idx, col_idx = np.nonzero(y_pred_proba > ONTOLOGY_THRESH[aspect_code])
        predict_stats[aspect_code]['selected'] += len(row_idx)
        if len(row_idx) == 0:
            continue
        pred_df = pd.DataFrame({
            'Protein Id': batch_pid_array[row_idx],
            'GO Term Id': np.asarray(mlb.classes_)[col_idx],
            'Prediction': y_pred_proba[row_idx, col_idx].astype(float),
        })

        # Apply TruePathRule propagation once per batch and aspect; the
        # function groups by (protein, term), so proteins stay separate.
        propagated_df = truepathrule_propagation(pred_df, DG, ASPECT_TO_NAMESPACE[aspect_code], ia_df, min_ia=0.05)
        predict_stats[aspect_code]['after_prop'] += len(propagated_df)
        submission_list.extend(
            propagated_df[['Protein Id', 'GO Term Id', 'Prediction']].itertuples(index=False, name=None)
        )

print(f"Generated {len(submission_list):,} total predictions.")
submission_df = pd.DataFrame(
    submission_list,
    columns=['Protein Id', 'GO Term Id', 'Prediction'],
)

# Optional: keep only rows whose GO term is a node of the loaded GO graph.
go_nodes = set(DG.nodes())
before_filter = len(submission_df)
is_known_term = submission_df['GO Term Id'].isin(go_nodes)
submission_df = submission_df[is_known_term]
after_filter = len(submission_df)
print(f"Filtered invalid GO terms: {before_filter - after_filter} removed, remaining {after_filter:,}.")

# Per-ontology summary: threshold used, terms selected, rows after propagation.
for code, stats in predict_stats.items():
    print(f"Stats {code}: threshold={ONTOLOGY_THRESH.get(code):.3f}, selected={stats['selected']:,}, after_propagation={stats['after_prop']:,}")

print("Applying 1500 prediction limit per protein...")
# Sort by protein, then by descending score, so head(1500) keeps the
# highest-scoring rows of each protein.
submission_df = submission_df.sort_values(
    by=['Protein Id', 'Prediction'],
    ascending=[True, False],
)
final_submission_df = (
    submission_df
    .groupby('Protein Id')
    .head(1500)
    .reset_index(drop=True)
)
# Tab-separated output without header or index column.
final_submission_df.to_csv('submission.tsv', sep='\t', index=False, header=False)

print("\nSubmission file 'submission.tsv' created successfully.")
print(f"Total predictions in final submission: {len(final_submission_df):,}")
print("Submission DataFrame Head:")
display(final_submission_df.head())