# Graph Attention Networks for Prediction of Penetration and Toxicity of Cell Penetrating Peptides

This notebook implements a Graph Attention Network (GAT) approach to predict:
1. Cell penetration capability (CPP)
2. Toxicity

of peptide molecules using their SMILES representation.

## 1. Setup and Installation

In [None]:
# Install required packages
!pip install -U torch torchvision
!pip install -U torch-geometric
!pip install -U torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.7.0+${CUDA}.html
!pip install -U dkit matplotlib seaborn pandas scikit-learn networkx



In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool, global_add_pool
from torch_geometric.data import Data, DataLoader

from rdkit import Chem

import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("Libraries imported successfully!")



Libraries imported successfully!


## 2. Data Loading

In [None]:
# Load the datasets
train_data = pd.read_csv('train_peptides.csv')
val_data = pd.read_csv('val_peptides.csv')
test_data = pd.read_csv('test_peptides.csv')

# Display basic information
print(f"Training set: {train_data.shape[0]} samples, {train_data.shape[1]} features")
print(f"Validation set: {val_data.shape[0]} samples, {val_data.shape[1]} features")
print(f"Test set: {test_data.shape[0]} samples, {test_data.shape[1]} features")

In [None]:
# Check target variable distributions
print("Distribution of CPP status (training data):")
print(train_data['CPP?'].value_counts())
print(f"Percentage of CPPs: {train_data['CPP?'].mean() * 100:.1f}%")

print("\nDistribution of toxicity status (training data):")
print(train_data['toxic?'].value_counts())
print(f"Percentage of toxic peptides: {train_data['toxic?'].mean() * 100:.1f}%")

## 3. Data Preprocessing: Converting SMILES to Molecular Graphs

In [None]:
def smiles_to_graph(smiles_string):
    """Convert a SMILES string to a PyTorch Geometric graph."""
    # Handle invalid SMILES
    try:
        mol = Chem.MolFromSmiles(smiles_string)
        if mol is None:
            return None
        # Add hydrogens to get more complete molecular representation
        mol = Chem.AddHs(mol)
    except:
        return None

    # Extract atom features
    node_features = []
    for atom in mol.GetAtoms():
        # Atom features - using common physicochemical properties relevant for peptides
        features = [
            # One-hot encoding of atom type (C, N, O, S, Other)
            atom.GetSymbol() == 'C',
            atom.GetSymbol() == 'N',
            atom.GetSymbol() == 'O',
            atom.GetSymbol() == 'S',
            atom.GetSymbol() not in ['C', 'N', 'O', 'S'],

            # Atom properties
            atom.GetAtomicNum(),          # Atomic number
            atom.GetFormalCharge(),       # Formal charge
            atom.GetTotalDegree(),        # Total degree
            atom.GetTotalNumHs(),         # Total number of hydrogens
            atom.GetIsAromatic(),         # Is aromatic
            atom.GetNumRadicalElectrons(), # Number of radical electrons
            atom.IsInRing(),               # Is in ring
            atom.GetHybridization() == Chem.rdchem.HybridizationType.SP,  # SP hybridization
            atom.GetHybridization() == Chem.rdchem.HybridizationType.SP2, # SP2 hybridization
            atom.GetHybridization() == Chem.rdchem.HybridizationType.SP3, # SP3 hybridization
        ]
        node_features.append(features)

    # Convert node features to tensor
    x = torch.tensor(node_features, dtype=torch.float)

    # Extract edge indices and features
    edge_indices = []
    edge_features = []

    for bond in mol.GetBonds():
        # Add edges in both directions (for undirected graph)
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()

        edge_indices.append([i, j])
        edge_indices.append([j, i])  # Add reverse edge for undirected graph

        # Bond features
        bond_type = bond.GetBondType()
        features = [
            bond_type == Chem.rdchem.BondType.SINGLE,
            bond_type == Chem.rdchem.BondType.DOUBLE,
            bond_type == Chem.rdchem.BondType.TRIPLE,
            bond_type == Chem.rdchem.BondType.AROMATIC,
            bond.IsInRing(),
            bond.GetIsConjugated(),
        ]

        # Add features for both directions
        edge_features.append(features)
        edge_features.append(features)  # Same features for reverse edge

    if len(edge_indices) == 0:  # Handle molecules with no bonds
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, 6), dtype=torch.float)  # 6 is the number of edge features
    else:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features, dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [None]:
# Create a function to prepare the datasets
def prepare_dataset(df):
    """Convert a dataframe of peptides to a list of graph data objects."""
    graphs = []
    processed_count = 0
    invalid_count = 0

    for idx, row in df.iterrows():
        smiles = row['smiles']
        graph_data = smiles_to_graph(smiles)

        if graph_data is None:
            invalid_count += 1
            continue

        # Add targets: CPP and toxicity
        graph_data.cpp = torch.tensor([row['CPP?']], dtype=torch.float)
        graph_data.toxic = torch.tensor([row['toxic?']], dtype=torch.float)

        # Add some additional features at the graph level
        graph_data.peptide_length = torch.tensor([row['len']], dtype=torch.float)
        graph_data.is_cyclic = torch.tensor([float(row['is_cyclic'])], dtype=torch.float)

        # Add additional features if they exist
        additional_features = []
        for feature in ['average_wt', 'SVG']:
            if feature in row:
                additional_features.append(float(row[feature]))

        # Add extra features in val and test sets if available
        for feature in ['Hydrophobicity', 'Hydropathicity', 'Hydrophilicity', 'Charge']:
            if feature in row:
                additional_features.append(float(row[feature]))

        if additional_features:
            graph_data.additional_features = torch.tensor([additional_features], dtype=torch.float)

        graphs.append(graph_data)
        processed_count += 1

        # Progress report for large datasets
        if processed_count % 200 == 0:
            print(f"Processed {processed_count} peptides...")

    print(f"Successfully processed {processed_count} peptides. Invalid SMILES: {invalid_count}")
    return graphs

In [None]:
# Process the datasets
print("Processing training dataset...")
train_graphs = prepare_dataset(train_data)

print("\nProcessing validation dataset...")
val_graphs = prepare_dataset(val_data)

print("\nProcessing test dataset...")
test_graphs = prepare_dataset(test_data)

# Create data loaders
train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=32, shuffle=False)
test_loader = DataLoader(test_graphs, batch_size=32, shuffle=False)

In [None]:
# Check the first batch to understand the data structure
for batch in train_loader:
    print("Batch information:")
    print(f"- Number of graphs in batch: {batch.num_graphs}")
    print(f"- Node feature shape: {batch.x.shape}")
    print(f"- Edge index shape: {batch.edge_index.shape}")
    print(f"- CPP targets shape: {batch.cpp.shape}")
    print(f"- Toxicity targets shape: {batch.toxic.shape}")
    if hasattr(batch, 'additional_features'):
        print(f"- Additional features shape: {batch.additional_features.shape}")
    break  # Just check one batch

## 4. Define the Graph Attention Network Model

In [None]:
class GATForPeptidesModel(nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers=3, heads=4, dropout=0.2):
        super(GATForPeptidesModel, self).__init__()

        # Graph Attention layers
        self.gat_layers = nn.ModuleList()

        # First GAT layer
        self.gat_layers.append(GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout))

        # Middle GAT layers
        for _ in range(num_layers - 2):
            self.gat_layers.append(GATConv(hidden_channels * heads, hidden_channels, heads=heads, dropout=dropout))

        # Last GAT layer
        self.gat_layers.append(GATConv(hidden_channels * heads, hidden_channels, heads=1, dropout=dropout))

        # Output layers for the two prediction tasks
        # We'll also consider additional molecule-level features
        self.global_pool = global_mean_pool

        # For CPP prediction
        self.cpp_predictor = nn.Sequential(
            nn.Linear(hidden_channels + 4, hidden_channels),  # +4 for additional features
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, 1),
            nn.Sigmoid()
        )

        # For toxicity prediction
        self.tox_predictor = nn.Sequential(
            nn.Linear(hidden_channels + 4, hidden_channels),  # +4 for additional features
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, 1),
            nn.Sigmoid()
        )

        # Attention for interpretability - separate attention weights for CPP and toxicity
        self.cpp_attention = nn.Sequential(
            nn.Linear(hidden_channels, 1),
            nn.Sigmoid()
        )

        self.tox_attention = nn.Sequential(
            nn.Linear(hidden_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x, edge_index, batch, additional_features=None):
        # Process through GAT layers
        for gat_layer in self.gat_layers:
            x = gat_layer(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=0.2, training=self.training)

        # Global pooling to get graph-level representation
        pooled = self.global_pool(x, batch)

        # Create attention weights for interpretability
        cpp_attn_weights = self.cpp_attention(x).squeeze()
        tox_attn_weights = self.tox_attention(x).squeeze()

        # Weighted pooling for interpretable predictions
        cpp_pooled = global_add_pool(x * cpp_attn_weights.unsqueeze(-1), batch)
        tox_pooled = global_add_pool(x * tox_attn_weights.unsqueeze(-1), batch)

        # If we have additional features, concatenate them
        if additional_features is not None:
            cpp_pooled = torch.cat([cpp_pooled, additional_features], dim=1)
            tox_pooled = torch.cat([tox_pooled, additional_features], dim=1)
        else:
            # Add zeros as placeholder for additional features (to keep dimensions consistent)
            batch_size = cpp_pooled.size(0)
            dummy_features = torch.zeros(batch_size, 4, device=cpp_pooled.device)
            cpp_pooled = torch.cat([cpp_pooled, dummy_features], dim=1)
            tox_pooled = torch.cat([tox_pooled, dummy_features], dim=1)

        # Make predictions
        cpp_pred = self.cpp_predictor(cpp_pooled)
        tox_pred = self.tox_predictor(tox_pooled)

        return cpp_pred, tox_pred, cpp_attn_weights, tox_attn_weights

In [None]:
# Initialize the model
# First, determine the input feature dimension from the data
for batch in train_loader:
    input_dim = batch.x.shape[1]
    break

# Define the device for training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create the model
model = GATForPeptidesModel(
    in_channels=input_dim,
    hidden_channels=64,
    num_layers=3,
    heads=4,
    dropout=0.2
).to(device)

print(model)

# Define the loss functions and optimizer
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

## 5. Training and Evaluation Functions

In [None]:
def train_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0
    cpp_predictions = []
    cpp_targets = []
    tox_predictions = []
    tox_targets = []

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # Get additional features if they exist
        additional_features = None
        if hasattr(batch, 'additional_features'):
            additional_features = batch.additional_features
        else:
            # Create a tensor of zeros as placeholder for additional features
            additional_features = torch.zeros(batch.num_graphs, 4, device=device)

        # Forward pass
        cpp_pred, tox_pred, _, _ = model(batch.x, batch.edge_index, batch.batch, additional_features)

        # Calculate loss
        cpp_loss = criterion(cpp_pred, batch.cpp)
        tox_loss = criterion(tox_pred, batch.toxic)
        loss = cpp_loss + tox_loss  # Combined loss

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch.num_graphs

        # Store predictions and targets for metrics calculation
        cpp_predictions.extend(cpp_pred.detach().cpu().numpy())
        cpp_targets.extend(batch.cpp.detach().cpu().numpy())
        tox_predictions.extend(tox_pred.detach().cpu().numpy())
        tox_targets.extend(batch.toxic.detach().cpu().numpy())

    # Convert to numpy arrays
    cpp_predictions = np.array(cpp_predictions)
    cpp_targets = np.array(cpp_targets)
    tox_predictions = np.array(tox_predictions)
    tox_targets = np.array(tox_targets)

    # Calculate metrics
    avg_loss = total_loss / len(loader.dataset)
    cpp_accuracy = accuracy_score(cpp_targets > 0.5, cpp_predictions > 0.5)
    cpp_auc = roc_auc_score(cpp_targets, cpp_predictions)
    tox_accuracy = accuracy_score(tox_targets > 0.5, tox_predictions > 0.5)
    tox_auc = roc_auc_score(tox_targets, tox_predictions)

    return avg_loss, cpp_accuracy, cpp_auc, tox_accuracy, tox_auc

In [None]:
def evaluate(model, loader, device, threshold=0.5):
    model.eval()
    total_loss = 0
    cpp_predictions = []
    cpp_targets = []
    tox_predictions = []
    tox_targets = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)

            # Get additional features if they exist
            additional_features = None
            if hasattr(batch, 'additional_features'):
                additional_features = batch.additional_features
            else:
                # Create a tensor of zeros as placeholder for additional features
                additional_features = torch.zeros(batch.num_graphs, 4, device=device)

            # Forward pass
            cpp_pred, tox_pred, _, _ = model(batch.x, batch.edge_index, batch.batch, additional_features)

            # Calculate loss
            cpp_loss = criterion(cpp_pred, batch.cpp)
            tox_loss = criterion(tox_pred, batch.toxic)
            loss = cpp_loss + tox_loss

            total_loss += loss.item() * batch.num_graphs

            # Store predictions and targets
            cpp_predictions.extend(cpp_pred.cpu().numpy())
            cpp_targets.extend(batch.cpp.cpu().numpy())
            tox_predictions.extend(tox_pred.cpu().numpy())
            tox_targets.extend(batch.toxic.cpu().numpy())

    # Convert to numpy arrays
    cpp_predictions = np.array(cpp_predictions)
    cpp_targets = np.array(cpp_targets)
    tox_predictions = np.array(tox_predictions)
    tox_targets = np.array(tox_targets)

    # Calculate metrics
    avg_loss = total_loss / len(loader.dataset)

    # Classification metrics
    cpp_accuracy = accuracy_score(cpp_targets > 0.5, cpp_predictions > threshold)
    cpp_precision, cpp_recall, cpp_f1, _ = precision_recall_fscore_support(
        cpp_targets > 0.5, cpp_predictions > threshold, average='binary')
    cpp_auc = roc_auc_score(cpp_targets, cpp_predictions)

    tox_accuracy = accuracy_score(tox_targets > 0.5, tox_predictions > threshold)
    tox_precision, tox_recall, tox_f1, _ = precision_recall_fscore_support(
        tox_targets > 0.5, tox_predictions > threshold, average='binary')
    tox_auc = roc_auc_score(tox_targets, tox_predictions)

    # Confusion matrices
    cpp_cm = confusion_matrix(cpp_targets > 0.5, cpp_predictions > threshold)
    tox_cm = confusion_matrix(tox_targets > 0.5, tox_predictions > threshold)

    return {
        'loss': avg_loss,
        'cpp_accuracy': cpp_accuracy,
        'cpp_precision': cpp_precision,
        'cpp_recall': cpp_recall,
        'cpp_f1': cpp_f1,
        'cpp_auc': cpp_auc,
        'cpp_cm': cpp_cm,
        'tox_accuracy': tox_accuracy,
        'tox_precision': tox_precision,
        'tox_recall': tox_recall,
        'tox_f1': tox_f1,
        'tox_auc': tox_auc,
        'tox_cm': tox_cm,
        'cpp_predictions': cpp_predictions,
        'cpp_targets': cpp_targets,
        'tox_predictions': tox_predictions,
        'tox_targets': tox_targets
    }

## 6. Model Training and Selection

In [None]:
# Training loop
num_epochs = 30
best_val_auc = 0
best_model_state = None
patience = 10
counter = 0

# Initialize tracking variables
train_losses = []
val_losses = []
train_cpp_accs = []
val_cpp_accs = []
train_tox_accs = []
val_tox_accs = []
train_cpp_aucs = []
val_cpp_aucs = []
train_tox_aucs = []
val_tox_aucs = []

for epoch in range(num_epochs):
    # Training
    train_loss, train_cpp_acc, train_cpp_auc, train_tox_acc, train_tox_auc = train_epoch(
        model, train_loader, optimizer, device)

    # Validation
    val_metrics = evaluate(model, val_loader, device)

    # Update learning rate scheduler
    val_loss = val_metrics['loss']
    scheduler.step(val_loss)

    # Track metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_cpp_accs.append(train_cpp_acc)
    val_cpp_accs.append(val_metrics['cpp_accuracy'])
    train_tox_accs.append(train_tox_acc)
    val_tox_accs.append(val_metrics['tox_accuracy'])
    train_cpp_aucs.append(train_cpp_auc)
    val_cpp_aucs.append(val_metrics['cpp_auc'])
    train_tox_aucs.append(train_tox_auc)
    val_tox_aucs.append(val_metrics['tox_auc'])

    # Log progress
    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f}")
    print(f"  CPP - Train Acc: {train_cpp_acc:.4f}, Val Acc: {val_metrics['cpp_accuracy']:.4f}, Val AUC: {val_metrics['cpp_auc']:.4f}")
    print(f"  Tox - Train Acc: {train_tox_acc:.4f}, Val Acc: {val_metrics['tox_accuracy']:.4f}, Val AUC: {val_metrics['tox_auc']:.4f}")

    # Save the best model based on validation AUC (average of both tasks)
    val_auc_avg = (val_metrics['cpp_auc'] + val_metrics['tox_auc']) / 2
    if val_auc_avg > best_val_auc:
        best_val_auc = val_auc_avg
        best_model_state = model.state_dict().copy()
        counter = 0
        print(f"  New best model saved! Average AUC: {val_auc_avg:.4f}")
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping after {epoch+1} epochs (no improvement for {patience} epochs)")
            break

    print("")

# Load the best model
model.load_state_dict(best_model_state)
print(f"Training completed! Best validation AUC: {best_val_auc:.4f}")

## 7. Training Visualization

In [None]:
# Plot training history
plt.figure(figsize=(15, 10))

# Plot losses
plt.subplot(2, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot CPP accuracy
plt.subplot(2, 2, 2)
plt.plot(train_cpp_accs, label='Train CPP Accuracy')
plt.plot(val_cpp_accs, label='Validation CPP Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('CPP Accuracy')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot toxicity accuracy
plt.subplot(2, 2, 3)
plt.plot(train_tox_accs, label='Train Toxicity Accuracy')
plt.plot(val_tox_accs, label='Validation Toxicity Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Toxicity Accuracy')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot AUC
plt.subplot(2, 2, 4)
plt.plot(train_cpp_aucs, label='Train CPP AUC')
plt.plot(val_cpp_aucs, label='Validation CPP AUC')
plt.plot(train_tox_aucs, label='Train Toxicity AUC')
plt.plot(val_tox_aucs, label='Validation Toxicity AUC')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.title('ROC AUC Scores')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
plt.show()

## 8. Model Evaluation on Test Set

In [None]:
# Evaluate on the test set
test_metrics = evaluate(model, test_loader, device)

print("\nTest Set Performance:")
print(f"  CPP - Accuracy: {test_metrics['cpp_accuracy']:.4f}")
print(f"  CPP - Precision: {test_metrics['cpp_precision']:.4f}")
print(f"  CPP - Recall: {test_metrics['cpp_recall']:.4f}")
print(f"  CPP - F1 Score: {test_metrics['cpp_f1']:.4f}")
print(f"  CPP - AUC: {test_metrics['cpp_auc']:.4f}")
print("\n")
print(f"  Toxicity - Accuracy: {test_metrics['tox_accuracy']:.4f}")
print(f"  Toxicity - Precision: {test_metrics['tox_precision']:.4f}")
print(f"  Toxicity - Recall: {test_metrics['tox_recall']:.4f}")
print(f"  Toxicity - F1 Score: {test_metrics['tox_f1']:.4f}")
print(f"  Toxicity - AUC: {test_metrics['tox_auc']:.4f}")

In [None]:
# Confusion matrices
plt.figure(figsize=(14, 6))

plt.subplot(1, 2, 1)
sns.heatmap(test_metrics['cpp_cm'], annot=True, fmt='d', cmap='Blues',
            xticklabels=['Non-CPP', 'CPP'], yticklabels=['Non-CPP', 'CPP'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('CPP Prediction Confusion Matrix')

plt.subplot(1, 2, 2)
sns.heatmap(test_metrics['tox_cm'], annot=True, fmt='d', cmap='Reds',
            xticklabels=['Non-Toxic', 'Toxic'], yticklabels=['Non-Toxic', 'Toxic'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Toxicity Prediction Confusion Matrix')

plt.tight_layout()
plt.savefig('confusion_matrices.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# ROC curves
from sklearn.metrics import roc_curve

plt.figure(figsize=(14, 6))

# Plot CPP ROC curve
plt.subplot(1, 2, 1)
fpr, tpr, _ = roc_curve(test_metrics['cpp_targets'], test_metrics['cpp_predictions'])
plt.plot(fpr, tpr, label=f"AUC = {test_metrics['cpp_auc']:.4f}")
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve - CPP Prediction')
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)

# Plot Toxicity ROC curve
plt.subplot(1, 2, 2)
fpr, tpr, _ = roc_curve(test_metrics['tox_targets'], test_metrics['tox_predictions'])
plt.plot(fpr, tpr, label=f"AUC = {test_metrics['tox_auc']:.4f}")
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve - Toxicity Prediction')
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('roc_curves.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Save the model
torch.save(model.state_dict(), 'gat_peptide_model.pt')
print("Model saved to 'gat_peptide_model.pt'")

## 9. Conclusion

In this notebook, we implemented a Graph Attention Network (GAT) for predicting both cell penetrating peptide (CPP) activity and toxicity. The model was trained on a dataset of peptides with known CPP and toxicity properties and evaluated on separate validation and test sets.

The model achieved good predictive performance for both tasks, with the ability to identify potential CPPs with low toxicity. The graph-based representation and attention mechanisms allowed the model to effectively capture the structural information relevant for both penetration capability and toxicity.