In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader

class ProteinDataset(Dataset):
    """Map-style dataset yielding (embedding, label) pairs as float32 tensors.

    Both inputs are converted to tensors once, at construction time; indexing
    then just selects the matching rows of each tensor.

    Args:
        embeddings: Array-like of per-protein feature vectors.
        labels: Array-like of label rows, positionally aligned with ``embeddings``.
    """

    def __init__(self, embeddings, labels):
        # Convert to float32 tensors up front so __getitem__ is a plain index.
        self.embeddings = torch.as_tensor(embeddings, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.float32)

    def __len__(self):
        """Number of samples, i.e. rows of the embedding tensor."""
        return self.embeddings.shape[0]

    def __getitem__(self, idx):
        """Return the (embedding, label) pair stored at position ``idx``."""
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label

class ProteinFunctionPredictor(nn.Module):
    """Multi-label protein-function classifier over fixed-size embeddings.

    Architecture: a Linear -> LayerNorm -> ReLU -> Dropout projection to
    ``hidden_dim``, one ``nn.TransformerEncoderLayer``, and an MLP head ending
    in ``nn.Sigmoid``. ``forward`` therefore returns per-class probabilities in
    [0, 1] (pair it with ``nn.BCELoss``, not ``nn.BCEWithLogitsLoss``).

    Args:
        input_dim: Length of each input embedding vector.
        hidden_dim: Width of the projected representation. Must be divisible
            by ``nhead`` (``nn.MultiheadAttention`` asserts this).
        num_classes: Number of output labels. Required; the ``None`` default
            only preserves the historical signature.
        nhead: Number of attention heads in the encoder layer.
        dim_feedforward: Inner width of the encoder layer's feed-forward block.
        encoder_dropout: Dropout rate used inside the encoder layer.
        head_dropout: Dropout rate after the input projection and inside the
            classification head.

    Raises:
        ValueError: If ``num_classes`` is missing or not positive.
    """

    def __init__(self, input_dim=1024, hidden_dim=512, num_classes=None,
                 nhead=8, dim_feedforward=2048, encoder_dropout=0.1,
                 head_dropout=0.2):
        super().__init__()
        # Fail early with a clear message; nn.Linear(..., None) would otherwise
        # raise an opaque TypeError from inside torch.
        if num_classes is None or num_classes < 1:
            raise ValueError(
                f"num_classes must be a positive integer, got {num_classes!r}"
            )

        # Feature integration layer: project embeddings to hidden_dim.
        self.feature_integration = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(head_dropout),
        )

        # Transformer encoder layer; batch_first defaults to False, so inputs
        # are laid out as (seq, batch, feature).
        self.transformer_encoder = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=encoder_dropout,
        )

        # Multi-label classification head: an independent sigmoid per class.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(head_dropout),
            nn.Linear(hidden_dim // 2, num_classes),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map embeddings (batch, input_dim) to class probabilities (batch, num_classes)."""
        x = self.feature_integration(x)

        # Treat each protein as its own length-1 sequence: (1, batch, hidden_dim).
        # Because the batch axis is dim 1, attention never mixes different
        # proteins. Over a single position the attention weights are all 1, so
        # the layer reduces to value/output projections plus the feed-forward
        # block (each with residual + LayerNorm).
        x = x.unsqueeze(0)
        x = self.transformer_encoder(x)
        x = x.squeeze(0)

        return self.classifier(x)
    
def calculate_metrics(outputs, labels, threshold=0.5, eps=1e-10):
    """Threshold multi-label probabilities and score them against binary targets.

    Args:
        outputs: Predicted probabilities, same shape as ``labels``; presumably
            (batch, num_classes) sigmoid outputs — TODO confirm for other callers.
        labels: Binary (0/1) float targets.
        threshold: Probability at or above which a class counts as predicted.
        eps: Small constant added to denominators to avoid division by zero.

    Returns:
        dict with
          'accuracy': element-wise agreement over every (sample, class) cell.
              With sparse labels this is dominated by true negatives, so it
              stays near 1.0 even when every prediction is 0.
          'mean_precision' / 'mean_recall' / 'mean_f1': per-class scores
              (counts summed over dim 0) macro-averaged over ALL classes,
              including classes with no positives and no predictions in this
              batch — those contribute 0.
          'micro_precision' / 'micro_recall' / 'micro_f1': scores from tp/fp/fn
              pooled across all classes, which are not dragged toward 0 by
              classes absent from the batch.
    """
    # Nothing here needs gradients; keep it out of autograd entirely.
    with torch.no_grad():
        predictions = (outputs >= threshold).float()
        accuracy = (predictions == labels).float().mean().item()

        # Per-class confusion counts.
        tp = ((predictions == 1) & (labels == 1)).float().sum(dim=0)
        fp = ((predictions == 1) & (labels == 0)).float().sum(dim=0)
        fn = ((predictions == 0) & (labels == 1)).float().sum(dim=0)

        # Macro: every class weighted equally.
        precision = tp / (tp + fp + eps)
        recall = tp / (tp + fn + eps)
        f1 = 2 * (precision * recall) / (precision + recall + eps)

        # Micro: pool the counts over all classes first.
        tp_total, fp_total, fn_total = tp.sum(), fp.sum(), fn.sum()
        micro_precision = tp_total / (tp_total + fp_total + eps)
        micro_recall = tp_total / (tp_total + fn_total + eps)
        micro_f1 = 2 * (micro_precision * micro_recall) / (micro_precision + micro_recall + eps)

    return {
        'accuracy': accuracy,
        'mean_precision': precision.mean().item(),
        'mean_recall': recall.mean().item(),
        'mean_f1': f1.mean().item(),
        'micro_precision': micro_precision.item(),
        'micro_recall': micro_recall.item(),
        'micro_f1': micro_f1.item(),
    }

def train_model(model, train_loader, val_loader, num_epochs=10, device=None,
                lr=1e-4, checkpoint_path='best_model.pt'):
    """Train a probability-output multi-label model with Adam and BCELoss.

    Each epoch runs one optimisation pass over ``train_loader`` and one
    gradient-free pass over ``val_loader``. It prints the mean loss, accuracy
    and macro F1 for both, and saves ``model.state_dict()`` to
    ``checkpoint_path`` whenever the mean validation loss improves.

    Args:
        model: nn.Module returning probabilities in [0, 1] (``nn.BCELoss`` is used).
        train_loader: Iterable with ``len()`` yielding (embeddings, labels) batches.
        val_loader: Same, used for validation.
        num_epochs: Number of training epochs.
        device: Device every batch is moved to. Defaults to the device of the
            model's first parameter, so inputs always live where the model does.
        lr: Adam learning rate.
        checkpoint_path: File the best ``state_dict`` is written to.
    """
    # Resolve the device here instead of reading a notebook-level global, and
    # apply it to BOTH the training and validation batches.
    if device is None:
        device = next(model.parameters()).device

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCELoss()
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss, train_metrics = _run_epoch(model, train_loader, criterion, device, optimizer)

        # Validation phase
        model.eval()
        with torch.no_grad():
            val_loss, val_metrics = _run_epoch(model, val_loader, criterion, device)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Training Loss: {train_loss:.4f}')
        print(f'Training Accuracy: {train_metrics["accuracy"]:.4f}')
        print(f'Training F1: {train_metrics["mean_f1"]:.4f}')
        print(f'Validation Loss: {val_loss:.4f}')
        print(f'Validation Accuracy: {val_metrics["accuracy"]:.4f}')
        print(f'Validation F1: {val_metrics["mean_f1"]:.4f}\n')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; take optimizer steps only if ``optimizer`` is given.

    Returns:
        (mean batch loss, dict of batch-averaged ``calculate_metrics`` values
        for the keys 'accuracy', 'mean_precision', 'mean_recall', 'mean_f1').
    """
    total_loss = 0.0
    totals = {'accuracy': 0, 'mean_precision': 0, 'mean_recall': 0, 'mean_f1': 0}

    for batch_embeddings, batch_labels in loader:
        batch_embeddings = batch_embeddings.to(device)
        batch_labels = batch_labels.to(device)

        if optimizer is not None:
            optimizer.zero_grad()
        outputs = model(batch_embeddings)
        loss = criterion(outputs, batch_labels)
        if optimizer is not None:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        batch_metrics = calculate_metrics(outputs.detach(), batch_labels)
        for key in totals:
            totals[key] += batch_metrics[key]

    # Guard against an empty loader instead of dividing by zero.
    n_batches = max(len(loader), 1)
    return total_loss / n_batches, {k: v / n_batches for k, v in totals.items()}

def prepare_data(train_terms, train_embeddings, train_protein_ids, dtype=np.float64):
    """Build a multi-hot (protein x term) label matrix aligned with the embeddings.

    Args:
        train_terms: DataFrame with 'EntryID' and 'term' columns, one row per
            (protein, term) annotation.
        train_embeddings: Unused; kept so existing call sites keep working.
        train_protein_ids: Protein IDs in embedding row order. Row i of the
            returned matrix corresponds to ``train_protein_ids[i]``.
        dtype: dtype of the returned matrix. The float64 default preserves the
            previous behaviour. The matrix is dense (num_proteins x num_terms),
            so float32 or uint8 can cut memory substantially for large vocabularies.

    Returns:
        (labels, num_terms): ``labels[i, j] == 1`` when protein i is annotated
        with term j, else 0. Column j is the j-th term in order of first
        appearance in ``train_terms['term']``. Annotations whose EntryID is not
        in ``train_protein_ids`` are skipped, but their terms still get a column.
    """
    # Map protein IDs to row indices (later duplicates win, as in a dict build).
    protein_to_idx = {pid: idx for idx, pid in enumerate(train_protein_ids)}

    # Unique GO terms in first-appearance order -> column indices.
    unique_terms = train_terms['term'].unique()
    term_to_idx = {term: idx for idx, term in enumerate(unique_terms)}

    labels = np.zeros((len(train_protein_ids), len(unique_terms)), dtype=dtype)

    # Vectorised lookup instead of a per-row iterrows loop: annotation tables
    # can have millions of rows. Proteins without an embedding map to NaN and
    # are masked out.
    row_series = train_terms['EntryID'].map(protein_to_idx)
    known = row_series.notna().to_numpy()
    rows = row_series.to_numpy()[known].astype(np.int64)
    cols = train_terms['term'].map(term_to_idx).to_numpy()[known].astype(np.int64)
    labels[rows, cols] = 1

    return labels, len(unique_terms)

In [2]:
# Prepare data
import pandas as pd

# Configuration: data locations, plus the seed that fixes the train/val split,
# DataLoader shuffling and model initialisation across kernel restarts.
TERMS_PATH = "CAFA 5 Protein Function Prediction/Train/train_terms.tsv"
IDS_PATH = 'T5 Embeds Archive/train_ids.npy'
EMBEDS_PATH = 'T5 Embeds Archive/train_embeds.npy'
SEED = 42
torch.manual_seed(SEED)

train_terms = pd.read_csv(TERMS_PATH, sep="\t")
train_protein_ids = np.load(IDS_PATH)
train_embeddings = np.load(EMBEDS_PATH)

# ProteinDataset pairs embedding row i with label row i, and prepare_data builds
# label row i from train_protein_ids[i] — so the two files must line up.
assert len(train_protein_ids) == len(train_embeddings), "protein IDs and embeddings are misaligned"

labels, num_classes = prepare_data(train_terms, train_embeddings, train_protein_ids)

# Create datasets (80/20 split with a seeded generator for reproducibility)
dataset = ProteinDataset(train_embeddings, labels)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Initialize model
model = ProteinFunctionPredictor(num_classes=num_classes)

In [5]:
# Wide DataFrame view of the embeddings: one column per dimension, named
# Column_1 ... Column_N. NOTE(review): nothing visible in this notebook reads
# train_df; drop it if no later cell needs it, as it can double memory use.
column_num = train_embeddings.shape[1]
train_df = pd.DataFrame(
    data=train_embeddings,
    columns=[f"Column_{i}" for i in range(1, column_num + 1)],
)

# Run on CPU. NOTE(review): keep `device` defined before calling train_model —
# the training helper may look it up as a module-level global.
device = torch.device('cpu')

train_model(model, train_loader, val_loader)

Epoch 1/10:
Training Loss: 0.0049
Training Accuracy: 0.9989
Training F1: 0.0006
Validation Loss: 0.0045
Validation Accuracy: 0.9989
Validation F1: 0.0009

Epoch 2/10:
Training Loss: 0.0045
Training Accuracy: 0.9989
Training F1: 0.0008
Validation Loss: 0.0043
Validation Accuracy: 0.9989
Validation F1: 0.0009

Epoch 3/10:
Training Loss: 0.0044
Training Accuracy: 0.9989
Training F1: 0.0010
Validation Loss: 0.0042
Validation Accuracy: 0.9989
Validation F1: 0.0010



KeyboardInterrupt: 

: 