In [None]:
# ==============================================================================
# 1. INITIAL SETUP AND LIBRARY IMPORTS
# ==============================================================================

# --- Install Required Packages ---
# obonet is for parsing the Gene Ontology .obo file.
# Biopython is for parsing FASTA sequence files.
# transformers and sentencepiece are for the protein language model.
!pip install -q obonet biopython transformers sentencepiece

# --- Core Libraries ---
import os
import gc
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# --- Bioinformatics & Data Handling ---
import obonet
from Bio import SeqIO
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score

# --- Machine Learning ---
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier


# --- Hugging Face Transformers for Protein Language Models ---
import torch
from transformers import T5Tokenizer, T5EncoderModel

# --- Configuration ---
# Seed NumPy and PyTorch so weight initialisation and DataLoader shuffling are
# reproducible across runs (torch.manual_seed seeds every device, CUDA included).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device configuration: use a CUDA GPU when one is available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Base path for the competition data. Defaults to the Kaggle input mount; set the
# CAFA_BASE_PATH environment variable to point at a local copy instead.
BASE_PATH = os.environ.get("CAFA_BASE_PATH", "/kaggle/input/cafa-6-protein-function-prediction/")

In [None]:
# ==============================================================================
# 2. DATA LOADING AND INITIAL PARSING
# ==============================================================================
# In this cell, we load all the primary files provided by the competition.

print("Loading primary data files...")

# --- Load Gene Ontology Graph ---
# The obonet library is used to parse the .obo file into a networkx graph
go_graph = obonet.read_obo(os.path.join(BASE_PATH, 'Train/go-basic.obo'))
print(f"Gene Ontology graph loaded with {len(go_graph)} nodes and {len(go_graph.edges)} edges.")

# --- Load Training Terms ---
# A single-character tab separator keeps pandas on its fast C parser; a
# multi-character sep (such as a literal backslash + 't') is treated as a regex
# and forces the slower Python engine.
train_terms_df = pd.read_csv(os.path.join(BASE_PATH, 'Train/train_terms.tsv'), sep='\t')
print(f"Training terms loaded. Shape: {train_terms_df.shape}")

# --- Load Training Sequences ---
# We will parse the FASTA file later when we need the sequences.
train_fasta_path = os.path.join(BASE_PATH, 'Train/train_sequences.fasta')
print(f"Training sequences path set: {train_fasta_path}")

# --- Load Test Sequences ---
test_fasta_path = os.path.join(BASE_PATH, 'Test/testsuperset.fasta')
print(f"Test sequences path set: {test_fasta_path}")

# --- Load Information Accretion (Weights) ---
# IA.tsv has no header row: column 0 is the GO term ID, column 1 its IA score.
ia_df = pd.read_csv(os.path.join(BASE_PATH, 'IA.tsv'), sep='\t', header=None, names=['term_id', 'ia_score'])
ia_map = dict(zip(ia_df['term_id'], ia_df['ia_score']))
print(f"Information Accretion scores loaded for {len(ia_map)} terms.")

# --- Display a sample of the training terms data ---
print("\nSample of train_terms.tsv:")
display(train_terms_df.head())

# Table 5: Summary of GO Term Distribution in Training Data
print("\nTable 5: Summary of GO Term Distribution in Training Data")
display(train_terms_df['aspect'].value_counts().reset_index())

In [None]:
# ==============================================================================
# 3. DATA PREPROCESSING AND TARGET MATRIX CONSTRUCTION
# ==============================================================================
# In this cell, we will process the raw training data into a format suitable
# for machine learning: a feature matrix (X) and a target matrix (Y).

# --- Limit the number of labels for computational efficiency ---
# The full GO ontology has ~40,000 terms, many of them extremely rare.
# Keeping only the N most frequent terms bounds the size of the output layer.
# NOTE(review): N_LABELS is larger than the number of distinct terms present in
# train_terms.tsv (the model summary reports 26,125 outputs), so nlargest()
# currently keeps every term. Lower it to actually restrict the label space.
N_LABELS = 40122

# --- Identify the top N most frequent GO terms ---
top_n_labels = train_terms_df['term'].value_counts().nlargest(N_LABELS).index.tolist()
# Report how many were actually kept; this is capped by the number of distinct terms.
print(f"Identified the top {len(top_n_labels)} most frequent GO terms (N_LABELS={N_LABELS}).")

# --- Filter the training data to only include these top labels ---
train_terms_filtered_df = train_terms_df[train_terms_df['term'].isin(top_n_labels)]
print(f"Filtered training terms. New shape: {train_terms_filtered_df.shape}")

# --- Proteins that define the rows of Y ---
# Taken from the *unfiltered* terms, so a protein whose terms were all removed by
# the filter above still gets a row (all zeros).
# NOTE(review): the precomputed embeddings loaded in the next cell carry no IDs and
# must be stored in exactly this protein order. Confirm against how they were built.
unique_proteins = train_terms_df['EntryID'].unique()
print(f"Number of unique proteins in the training terms: {len(unique_proteins)}")

# --- Map each protein ID to the list of its kept GO terms ---
protein_to_go_map = train_terms_filtered_df.groupby('EntryID')['term'].apply(list).to_dict()

# --- Build the binary target matrix Y (proteins x GO terms) ---
# Column j corresponds to top_n_labels[j] (== mlb.classes_[j]). Binarize into a
# sparse matrix and densify as uint8. A dense int64 result takes 8 bytes per cell
# (roughly 17 GB for the ~82k x ~26k matrix in this run); uint8 takes 1 byte and
# still holds the 0/1 values.
mlb = MultiLabelBinarizer(classes=top_n_labels, sparse_output=True)
Y = (
    mlb.fit_transform([protein_to_go_map.get(prot, []) for prot in unique_proteins])
    .astype(np.uint8)
    .toarray()
)

print(f"Target matrix Y created with shape: {Y.shape}")
# Y.shape is (number of unique proteins, number of kept GO terms)

# Table 6: Sparsity Analysis of the Target Matrix
# NumPy sums a uint8 array with a 64-bit unsigned accumulator, so Y.sum() cannot overflow.
matrix_density = Y.sum() / (Y.shape[0] * Y.shape[1])
print("\n" + "="*50)
print("Table 6: Sparsity Analysis of the Target Matrix")
print("="*50)
print(f"Number of Proteins (Rows): {Y.shape[0]}")
print(f"Number of GO Terms (Columns): {Y.shape[1]}")
print(f"Total Annotations: {Y.sum()}")
print(f"Matrix Density: {matrix_density:.4%}")
print("="*50)

In [None]:
# Load the pre-computed protein sequence embeddings (expected: one row per protein).
# NOTE(review): row i must describe the same protein as row i of Y (i.e.
# unique_proteins[i]). The .npy file carries no protein IDs, so only the row
# count can be checked here.
X = np.load("/kaggle/input/train-set-embedding/train_embeddings.npy")
print(f'X is created with shapes:{X.shape}')
if X.shape[0] != Y.shape[0]:
    raise ValueError(
        f"Embedding rows ({X.shape[0]}) do not match target rows ({Y.shape[0]}); "
        "X and Y must describe the same proteins in the same order."
    )

# 5. DEEP LEARNING MODEL: MULTI-LABEL GO TERM PREDICTION

We'll implement a PyTorch-based deep neural network for multi-label classification. The architecture includes:

- **Input**: Pre-computed protein embeddings (1024-dim from ProtT5)
- **Hidden Layers**: Multiple fully-connected layers with dropout for regularization
- **Output**: One logit per GO term; a sigmoid (applied inside the loss and at inference time) turns each logit into an independent probability
- **Loss**: Binary Cross-Entropy with Logits Loss (BCEWithLogitsLoss)

**Key Design Choices:**
- Batch normalization for training stability
- Dropout to prevent overfitting
- Stacked fully-connected layers to capture non-linear patterns in the embeddings
- Sigmoid activation for independent multi-label prediction

In [6]:
# ==============================================================================
# 5.1. DEFINE DEEP LEARNING MODEL ARCHITECTURE
# ==============================================================================

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class GOTermPredictor(nn.Module):
    """
    Deep Neural Network for Multi-Label GO Term Prediction.

    Architecture:
    - Input: a 2-D batch of protein embeddings, shape (batch, embedding_dim)
    - For each width in ``hidden_dims``: Linear -> BatchNorm1d -> ReLU -> Dropout
    - Output: one raw logit per GO term, shape (batch, num_labels). No sigmoid is
      applied here, so train with BCEWithLogitsLoss and apply torch.sigmoid at
      inference time to obtain probabilities.

    Args:
        embedding_dim: Size of each input embedding vector.
        num_labels: Number of GO terms (output logits).
        hidden_dims: Widths of the hidden layers, in order. Any iterable of ints
            is accepted; an empty one yields a single linear layer.
        dropout: Dropout probability applied after every hidden layer.
    """

    # Tuple default: a shared mutable list default is a classic Python pitfall.
    def __init__(self, embedding_dim, num_labels, hidden_dims=(2048, 1024, 512), dropout=0.3):
        super().__init__()

        layers = []
        in_features = embedding_dim
        for hidden_dim in hidden_dims:
            layers += [
                nn.Linear(in_features, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            in_features = hidden_dim

        # Output layer: no activation (BCEWithLogitsLoss applies the sigmoid).
        layers.append(nn.Linear(in_features, num_labels))

        # Keep the attribute name `network` so saved state_dict keys stay compatible.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Map embeddings (batch, embedding_dim) to logits (batch, num_labels)."""
        return self.network(x)


# Instantiate the network: its input width comes from the embedding matrix and
# its output width from the number of GO-term columns in the target matrix.
embedding_dim = X.shape[1]  # Should be 1024 for ProtT5
num_labels = Y.shape[1]      # Number of GO terms

model = GOTermPredictor(
    embedding_dim=embedding_dim,
    num_labels=num_labels,
    hidden_dims=[2048, 1024, 512],
    dropout=0.3,
).to(DEVICE)

# Parameter counts (every parameter is trainable unless frozen elsewhere).
model_params = list(model.parameters())
total_params = sum(param.numel() for param in model_params)
trainable_params = sum(param.numel() for param in model_params if param.requires_grad)

separator = "=" * 60
for heading_line in (separator, "MODEL ARCHITECTURE", separator):
    print(heading_line)
print(model)
print(separator)
for summary_line in (
    f"Total parameters: {total_params:,}",
    f"Trainable parameters: {trainable_params:,}",
    f"Input dimension: {embedding_dim}",
    f"Output dimension: {num_labels}",
):
    print(summary_line)
print(separator)

MODEL ARCHITECTURE
GOTermPredictor(
  (network): Sequential(
    (0): Linear(in_features=1024, out_features=2048, bias=True)
    (1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=2048, out_features=1024, bias=True)
    (5): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.3, inplace=False)
    (8): Linear(in_features=1024, out_features=512, bias=True)
    (9): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Dropout(p=0.3, inplace=False)
    (12): Linear(in_features=512, out_features=26125, bias=True)
  )
)
Total parameters: 18,131,469
Trainable parameters: 18,131,469
Input dimension: 1024
Output dimension: 26125


In [7]:
# ==============================================================================
# 5.2. CREATE DATASET AND DATALOADERS
# ==============================================================================

class ProteinDataset(Dataset):
    """Map-style dataset pairing protein embeddings with multi-hot GO-term labels.

    Args:
        embeddings: Array-like of per-protein embedding vectors (first axis = protein).
        labels: Array-like of per-protein label vectors, aligned row-for-row with
            ``embeddings``.

    Both are stored as float32 tensors. ``torch.as_tensor`` reuses the NumPy buffer
    when the input is already float32, instead of always making a copy as the
    legacy ``torch.FloatTensor`` constructor does. This matters for the very large
    label matrix.

    Raises:
        ValueError: if ``embeddings`` and ``labels`` have different numbers of rows.
    """

    def __init__(self, embeddings, labels):
        self.embeddings = torch.as_tensor(embeddings, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.float32)
        if len(self.embeddings) != len(self.labels):
            raise ValueError(
                f"embeddings has {len(self.embeddings)} rows but labels has {len(self.labels)}"
            )

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the (embedding, label) tensor pair for protein ``idx``."""
        return self.embeddings[idx], self.labels[idx]


# Hold out 15% of proteins for validation; the fixed random_state makes the
# split reproducible. (train_test_split is imported in the first cell.)
X_train, X_val, Y_train, Y_val = train_test_split(
    X, Y, test_size=0.15, random_state=SEED, shuffle=True
)

train_dataset = ProteinDataset(X_train, Y_train)
val_dataset = ProteinDataset(X_val, Y_val)

BATCH_SIZE = 32

# num_workers=0 loads batches in the main process (no worker subprocesses).
# Only the training loader is shuffled; validation order stays fixed.
loader_options = {"batch_size": BATCH_SIZE, "num_workers": 0}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

rule = "=" * 60
print(rule)
print("DATASET INFORMATION")
print(rule)
print(f"Training samples: {len(train_dataset):,}")
print(f"Validation samples: {len(val_dataset):,}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print("\nAverage GO terms per protein:")
for split_name, split_labels in (("Training", Y_train), ("Validation", Y_val)):
    print(f"  {split_name}: {split_labels.sum(axis=1).mean():.2f}")
print(rule)

DATASET INFORMATION
Training samples: 70,043
Validation samples: 12,361
Batch size: 32
Training batches: 2189
Validation batches: 387

Average GO terms per protein:
  Training: 6.52
  Validation: 6.48


In [8]:
# ==============================================================================
# 5.3. TRAINING SETUP
# ==============================================================================

# Loss: BCEWithLogitsLoss applies a sigmoid to each logit internally, so every GO
# term is scored as an independent binary decision (multi-label setting).
criterion = nn.BCEWithLogitsLoss()

# Hyperparameters are defined once and reused below, so the printed
# configuration cannot drift from what the optimizer actually uses.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5

# Optimizer: Adam with weight decay for regularization
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Halve the LR after 3 epochs without validation-loss improvement. `verbose` is
# deprecated in recent PyTorch releases; the training loop prints the current LR.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)

print("="*60)
print("TRAINING CONFIGURATION")
print("="*60)
print("Loss function: BCEWithLogitsLoss")
print("Optimizer: Adam")
print(f"Initial learning rate: {LEARNING_RATE}")
print(f"Weight decay: {WEIGHT_DECAY}")
print("LR scheduler: ReduceLROnPlateau")
print(f"Device: {DEVICE}")
print("="*60)

TRAINING CONFIGURATION
Loss function: BCEWithLogitsLoss
Optimizer: Adam
Initial learning rate: 0.001
Weight decay: 1e-5
LR scheduler: ReduceLROnPlateau
Device: cpu


In [9]:
# ==============================================================================
# 5.4. TRAINING LOOP
# ==============================================================================

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one full pass over ``train_loader``.

    Args:
        model: The network to optimise; it is switched to training mode.
        train_loader: Iterable of ``(embeddings, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        float: Mean training loss per sample over the epoch. Each batch loss is
        weighted by its batch size, so a short final batch is not over-weighted.
        This assumes ``criterion`` uses ``reduction='mean'``.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0

    for embeddings, labels in tqdm(train_loader, desc="Training", leave=False):
        embeddings, labels = embeddings.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(embeddings)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        batch_size = embeddings.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train_loader yielded no samples")
    return loss_sum / n_samples


def validate(model, val_loader, criterion, device):
    """Compute the loss of ``model`` on ``val_loader`` without updating weights.

    Args:
        model: The network to evaluate; it is switched to eval mode, which freezes
            BatchNorm statistics and disables dropout.
        val_loader: Iterable of ``(embeddings, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        float: Mean loss per sample over the whole loader. Each batch loss is
        weighted by its batch size, so a short final batch does not skew the
        value used for checkpoint selection and LR scheduling. This assumes
        ``criterion`` uses ``reduction='mean'``.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_samples = 0

    with torch.no_grad():
        for embeddings, labels in tqdm(val_loader, desc="Validating", leave=False):
            embeddings, labels = embeddings.to(device), labels.to(device)

            outputs = model(embeddings)
            loss = criterion(outputs, labels)

            batch_size = embeddings.size(0)
            loss_sum += loss.item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("val_loader yielded no samples")
    return loss_sum / n_samples


# Training loop
NUM_EPOCHS = 20
best_val_loss = float('inf')
# Per-epoch losses; the plotting and summary cells read this dict.
history = {'train_loss': [], 'val_loss': []}

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 60)

    # One optimisation pass over the training data...
    train_loss = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    history['train_loss'].append(train_loss)

    # ...then a gradient-free pass over the validation data.
    val_loss = validate(model, val_loader, criterion, DEVICE)
    history['val_loss'].append(val_loss)

    # The plateau scheduler reacts to the validation loss.
    scheduler.step(val_loss)

    # Print progress
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")

    # Checkpoint whenever validation loss improves; the evaluation cell reloads
    # 'best_model.pth' to score the best epoch rather than the last one.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, 'best_model.pth')
        print(f"✓ Best model saved! (Val Loss: {val_loss:.4f})")

print("\n" + "="*60)
print("TRAINING COMPLETED")
print("="*60)
print(f"Best validation loss: {best_val_loss:.4f}")


STARTING TRAINING

Epoch 1/20
------------------------------------------------------------


Training:   0%|          | 0/2189 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# ==============================================================================
# 5.5. VISUALIZE TRAINING HISTORY
# ==============================================================================

train_hist = history['train_loss']
val_hist = history['val_loss']

# If training was interrupted before the first validation pass finished, there
# is nothing to plot; indexing val_hist[0] / [-1] below would raise IndexError.
if not val_hist:
    print("No completed epochs recorded in `history`; skipping the loss plot and summary.")
else:
    fig, ax = plt.subplots(figsize=(12, 5))

    # 1-based x values match the "Epoch k/N" lines printed during training. The
    # two series are plotted separately because an interrupt during validation
    # can leave train_hist one entry longer than val_hist.
    ax.plot(range(1, len(train_hist) + 1), train_hist, label='Training Loss', marker='o', linewidth=2)
    ax.plot(range(1, len(val_hist) + 1), val_hist, label='Validation Loss', marker='s', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss (BCE)', fontsize=12)
    ax.set_title('Training and Validation Loss Over Time', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

    # Print summary statistics
    print("="*60)
    print("TRAINING SUMMARY")
    print("="*60)
    print(f"Final training loss: {train_hist[-1]:.4f}")
    print(f"Final validation loss: {val_hist[-1]:.4f}")
    print(f"Best validation loss: {best_val_loss:.4f}")
    # Relative drop from the first epoch's validation loss to the best one.
    print(f"Improvement: {((val_hist[0] - best_val_loss) / val_hist[0] * 100):.2f}%")
    print("="*60)

In [None]:
# ==============================================================================
# 5.6. EVALUATION ON VALIDATION SET
# ==============================================================================

# Load the best checkpoint written by the training loop. map_location lets a
# checkpoint saved on a GPU be restored on a CPU-only machine (and vice versa).
checkpoint = torch.load('best_model.pth', map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']+1}")

# Collect per-term probabilities and true labels for every validation protein.
model.eval()
all_predictions = []
all_labels = []

with torch.no_grad():
    for embeddings, labels in tqdm(val_loader, desc="Generating predictions"):
        embeddings = embeddings.to(DEVICE)
        outputs = model(embeddings)

        # The model emits logits; sigmoid turns each into an independent probability.
        probs = torch.sigmoid(outputs)

        all_predictions.append(probs.cpu().numpy())
        all_labels.append(labels.numpy())

# Concatenate all batches
Y_val_pred_proba = np.vstack(all_predictions)
Y_val_true = np.vstack(all_labels)

print(f"\nPrediction shape: {Y_val_pred_proba.shape}")
print(f"True labels shape: {Y_val_true.shape}")

# Sweep a global decision threshold and keep the one with the highest micro-F1.
# NOTE: the threshold is chosen on the same validation set these scores are
# reported on, so they are optimistic estimates of held-out performance.
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
print("\n" + "="*60)
print("PERFORMANCE AT DIFFERENT THRESHOLDS")
print("="*60)

best_threshold = 0.5
best_f1 = 0

for threshold in thresholds:
    Y_val_pred = (Y_val_pred_proba >= threshold).astype(int)

    micro_f1 = f1_score(Y_val_true, Y_val_pred, average='micro', zero_division=0)
    macro_f1 = f1_score(Y_val_true, Y_val_pred, average='macro', zero_division=0)
    samples_f1 = f1_score(Y_val_true, Y_val_pred, average='samples', zero_division=0)

    avg_preds = Y_val_pred.sum(axis=1).mean()

    print(f"\nThreshold: {threshold:.1f}")
    print(f"  Micro F1:   {micro_f1:.4f}")
    print(f"  Macro F1:   {macro_f1:.4f}")
    print(f"  Samples F1: {samples_f1:.4f}")
    print(f"  Avg predictions per protein: {avg_preds:.2f}")

    if micro_f1 > best_f1:
        best_f1 = micro_f1
        best_threshold = threshold

print("\n" + "="*60)
print(f"BEST THRESHOLD: {best_threshold} (Micro F1: {best_f1:.4f})")
print("="*60)

In [None]:
# ==============================================================================
# 5.7. DETAILED PERFORMANCE ANALYSIS
# ==============================================================================

# Binarize the validation probabilities at the threshold selected above.
Y_val_pred_final = (Y_val_pred_proba >= best_threshold).astype(int)

from sklearn.metrics import precision_score, recall_score, hamming_loss

banner = "=" * 60
print(banner)
print("FINAL VALIDATION SET PERFORMANCE")
print(banner)

# F1 under the three averaging schemes, computed in the order micro, macro, samples.
f1_by_average = {
    avg: f1_score(Y_val_true, Y_val_pred_final, average=avg, zero_division=0)
    for avg in ('micro', 'macro', 'samples')
}
micro_f1 = f1_by_average['micro']
macro_f1 = f1_by_average['macro']
samples_f1 = f1_by_average['samples']

micro_precision = precision_score(Y_val_true, Y_val_pred_final, average='micro', zero_division=0)
micro_recall = recall_score(Y_val_true, Y_val_pred_final, average='micro', zero_division=0)

print("\nF1 Scores:")
print(f"  Micro F1:   {micro_f1:.4f}")
print(f"  Macro F1:   {macro_f1:.4f}")
print(f"  Samples F1: {samples_f1:.4f}")

print("\nPrecision & Recall (Micro):")
print(f"  Precision: {micro_precision:.4f}")
print(f"  Recall:    {micro_recall:.4f}")

# Aggregate counts; the element-wise product is 1 only where prediction and truth agree on a positive.
hamming = hamming_loss(Y_val_true, Y_val_pred_final)
n_true = int(Y_val_true.sum())
n_pred = int(Y_val_pred_final.sum())
n_correct = int((Y_val_pred_final * Y_val_true).sum())
print("\nPrediction Statistics:")
print(f"  Hamming Loss: {hamming:.4f}")
print(f"  Total true labels: {n_true}")
print(f"  Total predictions: {n_pred}")
print(f"  Correct predictions: {n_correct}")

# Per-protein statistics
avg_true = Y_val_true.sum(axis=1).mean()
avg_pred = Y_val_pred_final.sum(axis=1).mean()
print("\nAverage GO terms per protein:")
print(f"  True:      {avg_true:.2f}")
print(f"  Predicted: {avg_pred:.2f}")

# Spread of the number of predicted terms per protein.
pred_counts = Y_val_pred_final.sum(axis=1)
print("\nPrediction distribution:")
print(f"  Min predictions: {pred_counts.min()}")
print(f"  Max predictions: {pred_counts.max()}")
print(f"  Median predictions: {np.median(pred_counts):.0f}")
print(f"  Std predictions: {pred_counts.std():.2f}")

print(banner)

In [None]:
# ==============================================================================
# 5.8. VISUALIZE PREDICTION QUALITY
# ==============================================================================

fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Plot 1: Prediction probability distribution
axes[0, 0].hist(Y_val_pred_proba.flatten(), bins=50, edgecolor='black', alpha=0.7, color='steelblue')
axes[0, 0].axvline(x=best_threshold, color='red', linestyle='--', linewidth=2, label=f'Best threshold ({best_threshold})')
axes[0, 0].set_xlabel('Prediction Probability', fontsize=11)
axes[0, 0].set_ylabel('Frequency', fontsize=11)
axes[0, 0].set_title('Distribution of Prediction Probabilities', fontsize=12, fontweight='bold')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Number of predictions per protein
pred_counts = Y_val_pred_final.sum(axis=1)
true_counts = Y_val_true.sum(axis=1)

axes[0, 1].hist(true_counts, bins=30, alpha=0.5, label='True', color='green', edgecolor='black')
axes[0, 1].hist(pred_counts, bins=30, alpha=0.5, label='Predicted', color='orange', edgecolor='black')
axes[0, 1].axvline(x=true_counts.mean(), color='green', linestyle='--', linewidth=2)
axes[0, 1].axvline(x=pred_counts.mean(), color='orange', linestyle='--', linewidth=2)
axes[0, 1].set_xlabel('Number of GO Terms per Protein', fontsize=11)
axes[0, 1].set_ylabel('Frequency', fontsize=11)
axes[0, 1].set_title('GO Terms per Protein: True vs Predicted', fontsize=12, fontweight='bold')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: True vs Predicted scatter
axes[1, 0].scatter(true_counts, pred_counts, alpha=0.4, s=20, color='purple')
axes[1, 0].plot([0, true_counts.max()], [0, true_counts.max()], 'r--', linewidth=2, label='Perfect prediction')
axes[1, 0].set_xlabel('True Number of GO Terms', fontsize=11)
axes[1, 0].set_ylabel('Predicted Number of GO Terms', fontsize=11)
axes[1, 0].set_title('True vs Predicted GO Terms per Protein', fontsize=12, fontweight='bold')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Per-label F1 for the 20 labels most frequent in the validation set.
# Pick the top 20 by (vectorized) frequency first, then compute F1 only for those,
# instead of calling f1_score once per column for every label. A stable sort on
# descending frequency keeps ties in column order; labels absent from the
# validation set are excluded.
label_freq = Y_val_true.sum(axis=0)
freq_order = np.argsort(-label_freq, kind='stable')
indices = [int(i) for i in freq_order[:20] if label_freq[i] > 0]
f1_scores = [f1_score(Y_val_true[:, i], Y_val_pred_final[:, i], zero_division=0) for i in indices]
frequencies = [label_freq[i] for i in indices]

# Column j of Y (and hence of Y_val_true) corresponds to mlb.classes_[j].
go_ids = [mlb.classes_[i] for i in indices]

x_pos = np.arange(len(indices))
axes[1, 1].bar(x_pos, f1_scores, color='teal', alpha=0.7, edgecolor='black')
axes[1, 1].set_xlabel('Top 20 Most Frequent GO Terms (validation set)', fontsize=11)
axes[1, 1].set_ylabel('F1 Score', fontsize=11)
axes[1, 1].set_title('Per-Label F1 Score (Top 20 Labels)', fontsize=12, fontweight='bold')
axes[1, 1].set_ylim([0, 1])
axes[1, 1].grid(True, alpha=0.3, axis='y')
axes[1, 1].set_xticks(x_pos)
axes[1, 1].set_xticklabels(go_ids, rotation=90, fontsize=8)

plt.tight_layout()
plt.show()

print(f"\nAverage F1 score across top 20 labels: {np.mean(f1_scores):.4f}")

# 6. MODEL COMPARISON & SUMMARY

Let's summarize the deep learning model's configuration and validation results, and list recommendations for further improvements.

In [None]:
# ==============================================================================
# 6.1. MODEL ARCHITECTURE SUMMARY & RECOMMENDATIONS
# ==============================================================================

# Read the configuration back from the live objects rather than restating it,
# so the summary stays correct if a hyperparameter is changed upstream.
hidden_widths = [layer.out_features for layer in model.network if isinstance(layer, nn.Linear)][:-1]
epochs_completed = len(history['val_loss'])  # may be < NUM_EPOCHS if training was interrupted

print("="*70)
print("DEEP LEARNING MODEL SUMMARY")
print("="*70)
print("\n📊 MODEL ARCHITECTURE:")
print(f"  • Input dimension: {embedding_dim} (ProtT5 embeddings)")
print(f"  • Hidden layers: [{' → '.join(str(w) for w in hidden_widths)}]")
print(f"  • Output dimension: {num_labels} GO terms")
print(f"  • Total parameters: {total_params:,}")
print("  • Regularization: Batch Normalization + Dropout (0.3)")

print("\n🎯 TRAINING CONFIGURATION:")
print("  • Loss function: BCEWithLogitsLoss")
print(f"  • Optimizer: Adam (lr={optimizer.defaults['lr']}, weight_decay={optimizer.defaults['weight_decay']})")
print(f"  • Batch size: {BATCH_SIZE}")
print(f"  • Epochs trained: {epochs_completed} of {NUM_EPOCHS}")
print(f"  • Best validation loss: {best_val_loss:.4f}")

print("\n📈 PERFORMANCE METRICS:")
print(f"  • Micro F1 Score: {micro_f1:.4f}")
print(f"  • Macro F1 Score: {macro_f1:.4f}")
print(f"  • Precision (micro): {micro_precision:.4f}")
print(f"  • Recall (micro): {micro_recall:.4f}")
print(f"  • Optimal threshold: {best_threshold}")

print("\n💡 RECOMMENDATIONS FOR IMPROVEMENT:")
print("  1. Fine-tune ProtT5 embeddings (currently using pre-computed)")
print("  2. Implement hierarchical loss considering GO DAG structure")
print("  3. Use focal loss to handle label imbalance")
print("  4. Try attention mechanisms to focus on important embedding regions")
print("  5. Ensemble multiple models trained with different random seeds")
print("  6. Implement per-ontology models (BPO, MFO, CCO separately)")
print("  7. Add graph neural network layers to leverage GO hierarchy")
print("  8. Use information accretion (IA) scores as sample weights")

print("\n🔧 HYPERPARAMETER TUNING SUGGESTIONS:")
print("  • Hidden layer dimensions: Try [1024, 512, 256] or [3072, 2048, 1024]")
print("  • Dropout rate: Experiment with 0.2-0.5")
print("  • Learning rate: Try 5e-4, 1e-3, 2e-3")
print("  • Batch size: Test 16, 64, 128")
print("  • Add residual connections for deeper networks")

print("="*70)