# CAFA6 Protein Function Prediction - Plan 1: Initial Approach

This notebook implements a solution for the CAFA6 challenge using ESM2 protein embeddings, GO term embeddings, and cosine similarity for associations.

- **ESM2 Model**: facebook/esm2_t6_8M_UR50D
- **GO Embeddings**: Simple graph-based embeddings
- **Association**: Cosine similarity with thresholding and ontology propagation

Data path: `/kaggle/input/cafa-6-protein-function-prediction/`

## Step 1: Setup Environment and Import Dependencies

In [None]:
# Install any missing packages if needed
# !pip install transformers torch networkx obonet biopython

import os
import pandas as pd
import numpy as np
from Bio import SeqIO
import torch
from transformers import AutoTokenizer, AutoModel
import networkx as nx
import obonet
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Competition input files, all resolved against a single root directory.
data_dir = '/kaggle/input/cafa-6-protein-function-prediction/'
train_seq_file, train_terms_file, go_obo_file, ia_file, test_seq_file = (
    os.path.join(data_dir, filename)
    for filename in (
        'train_sequences.fasta',
        'train_terms.tsv',
        'go-basic.obo',
        'IA.tsv',
        'testsuperset.fasta',
    )
)

print('Environment setup complete.')

## Step 2: Load and Preprocess Training Data

In [None]:
# Load training sequences, keyed by the FASTA record id
train_sequences = {
    record.id: str(record.seq) for record in SeqIO.parse(train_seq_file, 'fasta')
}
print(f'Loaded {len(train_sequences)} training sequences')

# Load training terms.
# The column names are supplied explicitly; if the file also starts with its own
# header row, that row would otherwise be read as an annotation (and 'term'
# would show up as a GO term), so it is dropped when detected.
train_term_columns = ['EntryID', 'term', 'aspect']
train_terms = pd.read_csv(
    train_terms_file, sep='\t', header=None, names=train_term_columns, dtype=str
)
if len(train_terms) > 0 and train_terms.iloc[0].tolist() == train_term_columns:
    train_terms = train_terms.iloc[1:].reset_index(drop=True)
print(f'Loaded {len(train_terms)} training term annotations')
print(train_terms.head())

# Load IA weights (term -> information accretion weight)
ia_weights = pd.read_csv(ia_file, sep='\t', header=None, names=['term', 'weight'])
ia_dict = dict(zip(ia_weights['term'], ia_weights['weight']))
print(f'Loaded IA weights for {len(ia_dict)} terms')

# Load GO ontology
go_graph = obonet.read_obo(go_obo_file)
print(f'Loaded GO graph with {len(go_graph.nodes)} nodes and {len(go_graph.edges)} edges')

# Get unique GO terms from training
go_terms = train_terms['term'].unique()
print(f'Unique GO terms in training: {len(go_terms)}')

# Subset GO graph to the training terms plus all of their ontology ancestors.
# obonet stores edges as child -> parent, so the more general (ancestor) terms of
# a node are its *descendants* in networkx terminology; nx.ancestors would
# collect the more specific child terms instead.
relevant_nodes = set(go_terms)
for term in go_terms:
    if term in go_graph:
        relevant_nodes.update(nx.descendants(go_graph, term))
go_subgraph = go_graph.subgraph(relevant_nodes)
print(f'Relevant GO subgraph: {len(go_subgraph.nodes)} nodes')

## Step 3: Compute ESM2 Protein Embeddings

In [None]:
# Fetch the pretrained ESM2 checkpoint and its matching tokenizer.
model_name = 'facebook/esm2_t6_8M_UR50D'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
# Move the weights to the selected device and switch to inference mode;
# both calls return the module itself, so they can be chained.
model.to(device).eval()
print(f'Loaded ESM2 model: {model_name}')

# Function to get embeddings
def get_esm_embedding(sequence):
    """Embed a single protein sequence with the loaded ESM2 model.

    The sequence is tokenized (truncated to at most 1024 tokens), run through
    the model without gradient tracking, and the final hidden states are
    averaged over the token dimension.

    Args:
        sequence: Amino-acid sequence as a string.

    Returns:
        A flattened (1-D) numpy array holding the mean-pooled hidden state.
    """
    encoded = tokenizer(sequence, return_tensors='pt', truncation=True, max_length=1024)
    encoded = encoded.to(device)
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
    # Average across the token axis, then bring the result back to the host.
    pooled = hidden_states.mean(dim=1)
    return pooled.cpu().numpy().flatten()

# Compute embeddings for training proteins (batch to save memory)
batch_size = 10  # Adjust based on GPU memory
protein_ids = list(train_sequences.keys())
protein_embeddings = {}

for i in range(0, len(protein_ids), batch_size):
    batch_ids = protein_ids[i:i+batch_size]
    batch_seqs = [train_sequences[pid] for pid in batch_ids]

    # Tokenize batch; padding=True pads every sequence to the longest one in the batch
    inputs = tokenizer(batch_seqs, return_tensors='pt', truncation=True, max_length=1024, padding=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Masked mean pooling: average only over real tokens. A plain mean over the
    # sequence axis would also average the hidden states at padding positions,
    # making a protein's embedding depend on which other proteins share its batch.
    hidden_states = outputs.last_hidden_state
    token_mask = inputs['attention_mask'].unsqueeze(-1).to(hidden_states.dtype)
    token_counts = token_mask.sum(dim=1).clamp(min=1)  # guard against division by zero
    embeddings = ((hidden_states * token_mask).sum(dim=1) / token_counts).cpu().numpy()

    for j, pid in enumerate(batch_ids):
        protein_embeddings[pid] = embeddings[j]

    if (i // batch_size) % 10 == 0:
        print(f'Processed {i + len(batch_ids)} / {len(protein_ids)} proteins')

print(f'Computed embeddings for {len(protein_embeddings)} proteins, embedding dim: {protein_embeddings[protein_ids[0]].shape}')

## Step 4: Compute GO Term Embeddings

In [None]:
# Simple GO embedding using graph features: degree, depth, and random projection
from sklearn.decomposition import PCA  # NOTE(review): no longer used in this cell; kept in case other cells rely on it
from sklearn.preprocessing import StandardScaler

if len(go_subgraph.nodes) == 0:
    raise ValueError('GO subgraph is empty; cannot compute GO term embeddings')

# Depth = number of edges on the shortest path from a term up to a root of the
# subgraph. obonet edges point child -> parent, so roots are the nodes without
# outgoing edges, and walking the reversed graph from all roots at once gives
# every term's distance to its nearest root in a single pass.
root_terms = {node for node in go_subgraph.nodes if go_subgraph.out_degree(node) == 0}
if root_terms:
    depth_by_term = dict(
        nx.multi_source_dijkstra_path_length(go_subgraph.reverse(copy=False), root_terms)
    )
else:
    depth_by_term = {}

# Compute features for each term
go_features = {
    node: [go_subgraph.degree(node), depth_by_term.get(node, 0)]
    for node in go_subgraph.nodes
}

# Convert to array
go_terms_list = list(go_features.keys())
features_array = np.array([go_features[term] for term in go_terms_list], dtype=float)

# Standardize, then lift to the protein embedding dimension with a fixed-seed
# Gaussian random projection. PCA cannot do this step: it can only *reduce*
# dimensionality, and PCA(n_components=320) on a 2-column matrix raises ValueError.
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features_array)
go_embedding_dim = model.config.hidden_size  # same width as the ESM2 protein embeddings
projection_rng = np.random.default_rng(42)
projection_matrix = projection_rng.standard_normal((features_scaled.shape[1], go_embedding_dim))
go_embeddings = features_scaled @ projection_matrix

# Dict for quick lookup
go_embeddings_dict = {term: go_embeddings[i] for i, term in enumerate(go_terms_list)}

print(f'Computed embeddings for {len(go_embeddings_dict)} GO terms, embedding dim: {go_embeddings.shape[1]}')

## Step 5: Develop Similarity-Based Association Method

In [None]:
# Function to predict GO terms for a protein using cosine similarity
def predict_go_terms(protein_emb, go_embeddings_dict, threshold=0.5):
    """Score every GO term against one protein embedding by cosine similarity.

    All similarities are computed in one vectorised pass instead of one
    library call per GO term.

    Args:
        protein_emb: 1-D array-like protein embedding.
        go_embeddings_dict: Mapping of GO term id -> embedding with the same
            length as ``protein_emb``.
        threshold: Only terms whose similarity is strictly greater than this
            value are returned.

    Returns:
        Dict mapping GO term id -> cosine similarity (float) for the terms
        above the threshold. Empty when ``go_embeddings_dict`` is empty.
    """
    if not go_embeddings_dict:
        return {}

    terms = list(go_embeddings_dict.keys())
    go_matrix = np.asarray([go_embeddings_dict[term] for term in terms], dtype=float)
    query = np.asarray(protein_emb, dtype=float).ravel()

    dot_products = go_matrix @ query
    norm_products = np.linalg.norm(go_matrix, axis=1) * np.linalg.norm(query)
    # A zero-length vector has no direction; score it 0 instead of dividing by zero.
    similarities = np.divide(
        dot_products,
        norm_products,
        out=np.zeros_like(dot_products),
        where=norm_products > 0,
    )

    return {
        term: float(sim)
        for term, sim in zip(terms, similarities)
        if sim > threshold
    }

# Sanity check: count the terms predicted for the first five embedded proteins
sample_proteins = list(protein_embeddings)[:5]
for pid in sample_proteins:
    preds = predict_go_terms(
        protein_embeddings[pid],
        go_embeddings_dict,
        threshold=0.7,
    )
    n_predicted = len(preds)
    print(f'Protein {pid}: {n_predicted} predicted terms')

# Note: Full prediction for all proteins/terms is computationally expensive; optimize for test set

## Step 6: Implement Ontology Propagation for Predictions

In [None]:
# Function to propagate predictions to ancestors
def propagate_predictions(predictions, go_graph):
    """Propagate predicted scores up the ontology to every ancestor term.

    Each ancestor receives the maximum score among the predicted terms below
    it (and its own score, if it was predicted directly).

    Args:
        predictions: Mapping of GO term id -> score. Not modified.
        go_graph: GO graph as loaded by obonet, where edges point from a child
            term to its parent term.

    Returns:
        A new dict containing the original predictions plus their ancestors.
        Terms that are not in ``go_graph`` are kept but not propagated.
    """
    propagated = dict(predictions)
    for term, score in predictions.items():
        if term not in go_graph:
            continue
        # With child -> parent edges, the more general terms of `term` are the
        # nodes reachable from it, i.e. nx.descendants; nx.ancestors would
        # return the more specific child terms instead.
        for parent_term in nx.descendants(go_graph, term):
            if propagated.get(parent_term, float('-inf')) < score:
                propagated[parent_term] = score
    return propagated

# Sanity check: compare term counts before and after propagation for the first sample protein
first_sample_embedding = protein_embeddings[sample_proteins[0]]
sample_preds = predict_go_terms(first_sample_embedding, go_embeddings_dict, threshold=0.7)
propagated_preds = propagate_predictions(sample_preds, go_subgraph)
n_original, n_propagated = len(sample_preds), len(propagated_preds)
print(f'Original: {n_original}, Propagated: {n_propagated}')

## Step 7: Train and Validate Model on Training Data

In [None]:
# Hold out 20% of the embedded proteins for validation (fixed seed for repeatability)
train_proteins, val_proteins = train_test_split(
    list(protein_embeddings.keys()),
    test_size=0.2,
    random_state=42,
)
print(f'Train: {len(train_proteins)}, Val: {len(val_proteins)}')

# Similarity cut-off shared with the test-set predictions below.
# Currently fixed; tune it against validation performance.
threshold = 0.5

# Predict and propagate for the first 10 validation proteins only (subset for speed)
val_predictions = {
    pid: propagate_predictions(
        predict_go_terms(protein_embeddings[pid], go_embeddings_dict, threshold),
        go_subgraph,
    )
    for pid in val_proteins[:10]
}

print(f'Validation predictions computed for {len(val_predictions)} proteins')

# Note: Implement full F1 evaluation here if time allows

## Step 8: Generate Predictions for Test Superset

In [None]:
# Load test sequences, keyed by the FASTA record id
test_sequences = {
    record.id: str(record.seq) for record in SeqIO.parse(test_seq_file, 'fasta')
}
print(f'Loaded {len(test_sequences)} test sequences')

# Compute embeddings for test proteins
test_embeddings = {}
test_ids = list(test_sequences.keys())
for i in range(0, len(test_ids), batch_size):
    batch_ids = test_ids[i:i+batch_size]
    batch_seqs = [test_sequences[pid] for pid in batch_ids]

    # padding=True pads every sequence to the longest one in the batch
    inputs = tokenizer(batch_seqs, return_tensors='pt', truncation=True, max_length=1024, padding=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Masked mean pooling: average only over real tokens. A plain mean over the
    # sequence axis would also average the hidden states at padding positions,
    # making a protein's embedding depend on which other proteins share its batch.
    hidden_states = outputs.last_hidden_state
    token_mask = inputs['attention_mask'].unsqueeze(-1).to(hidden_states.dtype)
    token_counts = token_mask.sum(dim=1).clamp(min=1)  # guard against division by zero
    embeddings = ((hidden_states * token_mask).sum(dim=1) / token_counts).cpu().numpy()

    for j, pid in enumerate(batch_ids):
        test_embeddings[pid] = embeddings[j]

    if (i // batch_size) % 10 == 0:
        print(f'Processed {i + len(batch_ids)} / {len(test_ids)} test proteins')

# Generate predictions for test set: threshold the similarities, then propagate up the ontology
test_predictions = {}
for pid in test_ids:
    preds = predict_go_terms(test_embeddings[pid], go_embeddings_dict, threshold)
    propagated = propagate_predictions(preds, go_subgraph)
    test_predictions[pid] = propagated

print(f'Generated predictions for {len(test_predictions)} test proteins')

## Step 9: Prepare and Output Submission File

In [None]:
# Flatten the nested {protein: {term: score}} predictions into rows,
# keeping only strictly positive scores
submission_rows = [
    [pid, term, score]
    for pid, preds in test_predictions.items()
    for term, score in preds.items()
    if score > 0
]

submission_columns = ['Protein ID', 'GO Term', 'Score']
submission_df = pd.DataFrame(submission_rows, columns=submission_columns).sort_values(
    ['Protein ID', 'GO Term']
)

# Write a tab-separated file without header or index
submission_df.to_csv('/kaggle/working/submission.tsv', sep='\t', index=False, header=False)
print('Submission file saved to /kaggle/working/submission.tsv')
print(f'Total predictions: {len(submission_df)}')