In [10]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def load_sequences(filename):
    """Parse a FASTA file into a ``{sequence_id: sequence}`` dict.

    The record ID is the first whitespace-separated token after '>'; the rest
    of the header line is ignored. Sequence lines are stripped of surrounding
    whitespace and concatenated. Blank lines are skipped, and any lines that
    appear before the first header are ignored. If an ID repeats, the last
    record wins (the key keeps the position of its first occurrence).

    Args:
        filename: Path to the FASTA file, read as UTF-8.

    Returns:
        Dict mapping each record ID to its full sequence string.

    Raises:
        ValueError: If a header line has no identifier after '>'.
    """
    sequences = {}
    current_id = None
    current_seq = []

    # Explicit encoding so parsing does not depend on the platform locale.
    with open(filename, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('>'):
                # Flush the previous record before starting a new one.
                if current_id is not None:
                    sequences[current_id] = ''.join(current_seq)
                header_tokens = line[1:].split(maxsplit=1)
                if not header_tokens:
                    raise ValueError(
                        f"{filename}:{line_no}: FASTA header has no identifier"
                    )
                current_id = header_tokens[0]
                current_seq = []
            elif current_id is not None:
                current_seq.append(line)

    # The final record has no following header to trigger a flush.
    if current_id is not None:
        sequences[current_id] = ''.join(current_seq)

    return sequences

class ProteinDataset(Dataset):
    """Map-style dataset of precomputed sequence features and label rows.

    Features for every sequence are computed once in ``__init__`` with
    ``prepare_input_features`` and stacked into a single tensor;
    ``__getitem__`` returns ``(features[idx], labels[idx])``.
    """

    def __init__(self, sequences, labels):
        """Build the feature tensor and wrap ``labels`` as a float32 tensor.

        Args:
            sequences: Sequence strings, aligned row-for-row with ``labels``.
            labels: Array-like of targets with one row per sequence.

        Raises:
            ValueError: If ``sequences`` and ``labels`` differ in length.
        """
        self.features = torch.stack([prepare_input_features(seq) for seq in sequences])
        # torch.as_tensor reuses the buffer of a float32 NumPy array instead of
        # copying it (other dtypes/containers are converted to a new float32
        # tensor, as torch.FloatTensor did). For a large label matrix this
        # avoids holding two full copies in memory. Consequence: later in-place
        # edits to a float32 NumPy `labels` array are visible here.
        self.labels = torch.as_tensor(labels, dtype=torch.float32)
        if len(self.features) != len(self.labels):
            raise ValueError(
                f"got {len(self.features)} sequences but {len(self.labels)} label rows"
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``(features, labels)`` pair for sample ``idx``."""
        return self.features[idx], self.labels[idx]

class ProteinFunctionPredictor(nn.Module):
    """Multi-label MLP classifier producing one probability per class.

    Architecture: for each hidden width, a Linear -> BatchNorm1d -> ReLU ->
    Dropout block, followed by a final Linear layer and a Sigmoid. The layers
    are stored in ``self.model`` (an ``nn.Sequential``) in that order, so
    ``state_dict`` keys match checkpoints saved by earlier versions.
    """

    def __init__(self, num_features=21, hidden_dims=(1024, 512, 256), num_classes=None,
                 dropout=0.3):
        """Build the layer stack.

        Args:
            num_features: Width of each input row.
            hidden_dims: Widths of the hidden layers, in order (any iterable of
                ints; a tuple default avoids a shared mutable default).
            num_classes: Number of output labels. Required.
            dropout: Dropout probability applied after each hidden block.

        Raises:
            ValueError: If ``num_classes`` is missing or not positive.
        """
        super().__init__()

        if num_classes is None or num_classes < 1:
            raise ValueError(f"num_classes must be a positive integer, got {num_classes!r}")

        layers = []
        prev_dim = num_features
        for dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, dim),
                nn.BatchNorm1d(dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = dim

        layers.append(nn.Linear(prev_dim, num_classes))
        # Sigmoid output pairs with nn.BCELoss in the training loop; note that
        # BCEWithLogitsLoss on raw logits would be more numerically stable.
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-class probabilities for a batch ``x`` of shape (batch, num_features)."""
        return self.model(x)

# The 20 standard amino acids; this order defines feature positions 0-19.
_AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'


def prepare_input_features(protein_sequence):
    """Encode a protein sequence as a 21-element float32 feature vector.

    Elements 0-19 are the counts of the 20 standard amino acids (in
    ``_AMINO_ACIDS`` order) divided by the full sequence length. Letters are
    matched case-insensitively. Any other symbol (e.g. 'X', 'U') is not counted
    but still adds to the length, so the frequencies may sum to less than 1.
    Element 20 is the raw sequence length.

    Args:
        protein_sequence: Amino-acid string (may be empty).

    Returns:
        1-D ``torch.float32`` tensor of length 21.
    """
    residues = protein_sequence.upper()
    # str.count scans in C; 20 scans are far cheaper than a Python loop that
    # performs one scalar tensor update per residue.
    counts = torch.tensor([residues.count(aa) for aa in _AMINO_ACIDS], dtype=torch.float32)

    length = len(protein_sequence)
    # The epsilon keeps an empty sequence at all-zero frequencies instead of NaN.
    frequencies = counts / (length + 1e-8)

    # Build the length as float explicitly rather than relying on torch.cat
    # promoting an int64 tensor. NOTE: it is appended unscaled, so it is orders
    # of magnitude larger than the frequencies; kept for model compatibility.
    length_feature = torch.tensor([float(length)])
    return torch.cat([frequencies, length_feature])

def calculate_metrics(outputs, labels, threshold=0.5):
    """Score thresholded predictions against multi-hot labels.

    Every (sample, class) cell is treated as one binary decision, so precision,
    recall and F1 are micro-averaged over all cells of the batch.

    Args:
        outputs: Tensor of predicted probabilities, same shape as ``labels``.
        labels: Tensor of 0/1 targets.
        threshold: Probability at or above which a cell counts as positive.

    Returns:
        Dict of floats with keys ``accuracy``, ``precision``, ``recall``,
        ``f1`` and ``mean_class_accuracy``. A ratio whose denominator is zero
        is reported as 0.0 (scikit-learn's ``zero_division=0`` convention),
        without emitting warnings.
    """
    # Direct NumPy counting: this runs on every training batch, and sklearn's
    # input validation plus per-batch UndefinedMetricWarning are costly noise.
    # detach() lets tensors still attached to the autograd graph convert cleanly.
    predicted = (outputs.detach() >= threshold).cpu().numpy()
    actual = labels.detach().cpu().numpy() == 1

    tp = int(np.count_nonzero(predicted & actual))
    fp = int(np.count_nonzero(predicted & ~actual))
    fn = int(np.count_nonzero(~predicted & actual))
    correct = predicted == actual

    if correct.size == 0:
        accuracy = mean_class_accuracy = 0.0
    else:
        accuracy = float(correct.mean())
        # Per-class accuracy averaged over classes. Every class column has the
        # same number of rows, so this always equals `accuracy`; kept because
        # callers read this key.
        mean_class_accuracy = float(correct.mean(axis=0).mean())

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    # 2TP / (2TP + FP + FN) equals 2PR / (P + R) wherever that is defined.
    f1 = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'mean_class_accuracy': mean_class_accuracy
    }

# Load data
DATA_DIR = "CAFA 5 Protein Function Prediction/Train"
print("Loading data...")
train_terms = pd.read_csv(f"{DATA_DIR}/train_terms.tsv", sep="\t")
sequences = load_sequences(f"{DATA_DIR}/train_sequences.fasta")
print(f"Loaded {len(sequences)} sequences")

# Prepare labels: one output column per distinct term in train_terms.
print("Preparing labels...")
unique_terms = train_terms['term'].unique()
term_to_idx = {term: idx for idx, term in enumerate(unique_terms)}
num_classes = len(unique_terms)

# Create label matrix: rows follow FASTA order, columns follow term_to_idx.
# float32 instead of NumPy's default float64 halves the footprint of this
# dense (proteins x terms) matrix; the dataset stores labels as float32 anyway.
protein_list = list(sequences.keys())
labels = np.zeros((len(protein_list), num_classes), dtype=np.float32)
protein_to_idx = {pid: idx for idx, pid in enumerate(protein_list)}

# Vectorised fill instead of DataFrame.iterrows(), which is very slow on large
# annotation tables. Annotations whose EntryID has no sequence in the FASTA
# file are skipped, as before.
row_idx = train_terms['EntryID'].map(protein_to_idx)
has_sequence = row_idx.notna()
col_idx = train_terms.loc[has_sequence, 'term'].map(term_to_idx)
labels[row_idx[has_sequence].to_numpy(dtype=np.int64), col_idx.to_numpy(dtype=np.int64)] = 1

print(f"Label matrix shape: {labels.shape}")

# Create dataset
print("Creating dataset...")
sequence_list = [sequences[pid] for pid in protein_list]
dataset = ProteinDataset(sequence_list, labels)

# Split dataset 80/20; the seeded generator makes the validation set
# reproducible across runs.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

_METRIC_KEYS = ('accuracy', 'precision', 'recall', 'f1', 'mean_class_accuracy')


def _run_epoch(model, loader, criterion, device, optimizer=None, epoch=0):
    """Make one pass over ``loader`` and return ``(mean_loss, mean_metrics)``.

    Trains (forward, backward, optimizer step) when ``optimizer`` is given;
    otherwise evaluates under ``torch.no_grad()`` with the model in eval mode.
    Loss and metrics are unweighted means of per-batch values: a short final
    batch counts as much as a full one, and precision/recall/F1 are averages
    of batch ratios rather than dataset-level ratios.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to model.eval()
    total_loss = 0.0
    totals = dict.fromkeys(_METRIC_KEYS, 0.0)
    num_batches = 0

    with torch.enable_grad() if training else torch.no_grad():
        for batch_idx, (features, targets) in enumerate(loader):
            features, targets = features.to(device), targets.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, targets)
            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            batch_metrics = calculate_metrics(outputs, targets)
            for key in totals:
                totals[key] += batch_metrics[key]
            num_batches += 1

            if training and batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}:')
                print(f'Loss: {loss.item():.4f}')
                print(f'Batch Accuracy: {batch_metrics["accuracy"]:.4f}')
                print(f'Batch F1: {batch_metrics["f1"]:.4f}\n')

    # An empty loader yields zeros instead of raising ZeroDivisionError.
    denom = max(num_batches, 1)
    return total_loss / denom, {key: value / denom for key, value in totals.items()}


def train_model(model, train_loader, val_loader, num_epochs=10, device='cpu'):
    """Train ``model`` with Adam and BCE loss, validating and checkpointing each epoch.

    Args:
        model: Module whose outputs are probabilities in [0, 1] (``nn.BCELoss``
            is applied to them directly).
        train_loader: DataLoader yielding ``(features, labels)`` training batches.
        val_loader: DataLoader yielding ``(features, labels)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        device: ``torch.device`` or device string to train on (CPU by default).

    Side effects:
        Writes ``checkpoint_epoch_{epoch}.pt`` after every epoch and overwrites
        ``best_model.pt`` whenever mean validation F1 improves; the first epoch
        always writes it.

    Returns:
        The trained model, on ``device``.
    """
    device = torch.device(device)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.BCELoss()
    # Start below any achievable F1 so best_model.pt is written at least once,
    # even if validation F1 never rises above 0.
    best_val_f1 = float('-inf')

    for epoch in range(num_epochs):
        train_loss, train_metrics = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer, epoch=epoch
        )
        val_loss, val_metrics = _run_epoch(model, val_loader, criterion, device)

        # Print epoch results
        print(f'\nEpoch {epoch+1}/{num_epochs} Results:')
        print(f'Training Loss: {train_loss:.4f}')
        print(f'Training Accuracy: {train_metrics["accuracy"]:.4f}')
        print(f'Training F1: {train_metrics["f1"]:.4f}')
        print(f'Training Precision: {train_metrics["precision"]:.4f}')
        print(f'Training Recall: {train_metrics["recall"]:.4f}')
        print(f'\nValidation Loss: {val_loss:.4f}')
        print(f'Validation Accuracy: {val_metrics["accuracy"]:.4f}')
        print(f'Validation F1: {val_metrics["f1"]:.4f}')
        print(f'Validation Precision: {val_metrics["precision"]:.4f}')
        print(f'Validation Recall: {val_metrics["recall"]:.4f}\n')

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_metrics': train_metrics,
            'val_metrics': val_metrics
        }

        # Save best model
        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            torch.save(checkpoint, 'best_model.pt')

        # Save per-epoch checkpoint
        torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pt')

    return model

# Build a classifier with one output per label column, then train it
# with the default schedule (10 epochs).
print("Initializing model...")
model = ProteinFunctionPredictor(num_classes=num_classes)

print("Starting training...")
train_model(model, train_loader=train_loader, val_loader=val_loader, num_epochs=10)

Loading data...
Loaded 142246 sequences
Preparing labels...
Label matrix shape: (142246, 31466)
Creating dataset...
Initializing model...
Starting training...
Epoch 1, Batch 0:
Loss: 0.7204
Batch Accuracy: 0.4980
Batch F1: 0.0016

Epoch 1, Batch 100:
Loss: 0.0124
Batch Accuracy: 0.9992
Batch F1: 0.1512

Epoch 1, Batch 200:
Loss: 0.0071
Batch Accuracy: 0.9990
Batch F1: 0.1209

Epoch 1, Batch 300:
Loss: 0.0072
Batch Accuracy: 0.9986
Batch F1: 0.1301

Epoch 1, Batch 400:
Loss: 0.0053
Batch Accuracy: 0.9989
Batch F1: 0.1579

Epoch 1, Batch 500:
Loss: 0.0066
Batch Accuracy: 0.9988
Batch F1: 0.1247

Epoch 1, Batch 600:
Loss: 0.0077
Batch Accuracy: 0.9985
Batch F1: 0.1084

Epoch 1, Batch 700:
Loss: 0.0057
Batch Accuracy: 0.9988
Batch F1: 0.1396

Epoch 1, Batch 800:
Loss: 0.0064
Batch Accuracy: 0.9987
Batch F1: 0.1264

Epoch 1, Batch 900:
Loss: 0.0050
Batch Accuracy: 0.9991
Batch F1: 0.1788

Epoch 1, Batch 1000:
Loss: 0.0052
Batch Accuracy: 0.9989
Batch F1: 0.1642

Epoch 1, Batch 1100:
Loss: 0