In [7]:
import torch
import torch_geometric as pyg
from torch_geometric.nn import MessagePassing, global_mean_pool
import torch.optim as optim
from sklearn.metrics import average_precision_score
import numpy as np
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
import networkx as nx
import torch.nn as nn
from torch_geometric.utils import to_networkx, subgraph
from torch_geometric.datasets import TUDataset
from tqdm import tqdm

# Ensure reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# GINE Layer with Virtual Nodes
class GINELayerWithVN(MessagePassing):
    def __init__(self, in_channels, out_channels, edge_dim):
        super(GINELayerWithVN, self).__init__(aggr='add')  # "Add" aggregation.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(out_channels, out_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(out_channels, out_channels)
        )
        self.edge_encoder = torch.nn.Linear(edge_dim, out_channels)
        # Remove node_encoder from here
        self.virtual_node_mlp = torch.nn.Sequential(
            torch.nn.Linear(out_channels, out_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(out_channels, out_channels),
            torch.nn.ReLU(),
        )
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.edge_encoder.weight)
        for m in self.mlp:
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
        for m in self.virtual_node_mlp:
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x, edge_index, edge_attr, vn_embed, batch):
        # x is already encoded via node_encoder in the main model
        x = x.float()  # Ensure x is FloatTensor
        edge_attr = edge_attr.float()  # Ensure edge_attr is FloatTensor
        edge_attr = self.edge_encoder(edge_attr)

        # Add virtual node embedding to node features
        vn_expanded = vn_embed[batch]
        x = x + vn_expanded

        # Message Passing
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr)

        # Update node embeddings
        out = self.mlp(out)
        return out

    def message(self, x_j, edge_attr):
        # Compute messages
        return x_j + edge_attr

    def update(self, aggr_out):
        return aggr_out

# Laplacian Positional Encodings (LapPE)
def compute_laplace_pe(data, num_eigenvec=10):
    G = to_networkx(data, to_undirected=True)
    A = nx.adjacency_matrix(G).astype(float)
    num_nodes = A.shape[0]
    D = np.diag(np.array(A.sum(axis=1)).flatten())
    L = D - A.todense()
    L = torch.tensor(L, dtype=torch.float, device=device)
    try:
        eigenvalues, eigenvectors = torch.linalg.eigh(L)
    except RuntimeError:
        eigenvalues, eigenvectors = torch.symeig(L, eigenvectors=True)
    available_eigenvec = eigenvectors.shape[1] - 1
    actual_num_eigenvec = min(num_eigenvec, available_eigenvec)
    eigenvectors = eigenvectors[:, 1:1 + actual_num_eigenvec]
    if actual_num_eigenvec < num_eigenvec:
        pad_size = num_eigenvec - actual_num_eigenvec
        padding = torch.zeros(eigenvectors.shape[0], pad_size, device=device)
        eigenvectors = torch.cat([eigenvectors, padding], dim=1)
    return eigenvectors  # Shape: (num_nodes, num_eigenvec)

# Random Walk Structural Embeddings (RWSE)
def compute_rwse(data, walk_length=10):
    G = to_networkx(data, to_undirected=True)
    A = nx.adjacency_matrix(G).astype(float)
    A = A.todense()
    num_nodes = A.shape[0]
    A = torch.tensor(A, dtype=torch.float, device=device)
    rw_features = []
    A_power = A.clone()
    for _ in range(walk_length):
        diag = torch.diagonal(A_power)
        rw_features.append(diag)
        A_power = torch.matmul(A_power, A)
    rwse = torch.stack(rw_features, dim=1)  # (num_nodes, walk_length)
    return rwse  # Shape: (num_nodes, walk_length)

# SignNet to ensure sign invariance
class SignNet(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(SignNet, self).__init__()
        self.phi = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim)
        )

    def forward(self, x):
        return self.phi(x) + self.phi(-x)

# Graph Transformer Layer with Masking
class GraphTransformerLayer(nn.Module):
    def __init__(self, in_dim, out_dim, num_heads=4, dropout=0.1):
        super(GraphTransformerLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim=in_dim, num_heads=num_heads, dropout=dropout)
        self.linear1 = nn.Linear(in_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(out_dim, in_dim)
        self.norm1 = nn.LayerNorm(in_dim)
        self.norm2 = nn.LayerNorm(in_dim)
        self.activation = nn.ReLU()

    def forward(self, x, key_padding_mask=None):
        # x: (sequence_length, batch_size, embed_dim)
        attn_output, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        x = x + attn_output
        x = self.norm1(x)
        linear_output = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = x + linear_output
        x = self.norm2(x)
        return x

# Updated GNN Model with Virtual Node, GINE Layers, and Graph Transformer
class GNNWithVirtualNodeAndGINEAndTransformer(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, edge_attr_dim, num_layers=5, lap_pe_dim=10, rwse_dim=10):
        super(GNNWithVirtualNodeAndGINEAndTransformer, self).__init__()
        self.num_layers = num_layers
        self.hidden_features = hidden_features

        # Node Encoder
        self.node_encoder = nn.Linear(in_features, hidden_features)

        self.convs = torch.nn.ModuleList()
        for layer in range(num_layers):
            self.convs.append(GINELayerWithVN(
                in_channels=hidden_features,
                out_channels=hidden_features,
                edge_dim=edge_attr_dim
            ))

        self.virtual_node_embedding = torch.nn.Embedding(1, hidden_features)
        torch.nn.init.constant_(self.virtual_node_embedding.weight.data, 0)

        self.mlp_virtual_node = torch.nn.Sequential(
            torch.nn.Linear(hidden_features, hidden_features),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_features, hidden_features),
            torch.nn.ReLU(),
        )

        # Positional Encodings
        self.lap_pe_dim = lap_pe_dim
        self.rwse_dim = rwse_dim
        self.lap_pe_linear = nn.Linear(hidden_features, hidden_features)
        self.rwse_linear = nn.Linear(rwse_dim, hidden_features)
        self.signnet = SignNet(lap_pe_dim, hidden_features)

        # Graph Transformer
        self.transformer_layers = nn.ModuleList([
            GraphTransformerLayer(hidden_features, hidden_features) for _ in range(3)
        ])

        self.fc = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x, edge_index, edge_attr, batch, data):
        # Apply node_encoder first
        x = self.node_encoder(x)  # [num_nodes, hidden_features]
        # Initialize positional encodings tensor
        pos_enc = torch.zeros_like(x).to(device)  # [num_nodes, hidden_features]

        # Iterate over each graph in the batch
        num_graphs = batch.max().item() + 1
        for graph_id in range(num_graphs):
            mask = (batch == graph_id)
            num_nodes_graph = mask.sum().item()

            # Extract node indices for the current graph
            node_idx = torch.where(batch == graph_id)[0]

            # Extract subgraph using pyg.utils.subgraph
            sub_edge_index, sub_edge_attr = pyg.utils.subgraph(
                node_idx,
                edge_index,
                edge_attr,
                relabel_nodes=True,
                num_nodes=x.size(0)
            )

            # Create sub_data
            sub_data = pyg.data.Data(
                x=x[node_idx],
                edge_index=sub_edge_index,
                edge_attr=sub_edge_attr
            )

            # Compute Positional Encodings for the sub-graph
            lap_pe = compute_laplace_pe(sub_data, num_eigenvec=self.lap_pe_dim)
            rwse = compute_rwse(sub_data, walk_length=self.rwse_dim)

            # Apply SignNet to LapPE
            lap_pe = self.signnet(lap_pe)  # [num_nodes_graph, hidden_features]

            # Linear transformation
            lap_pe = self.lap_pe_linear(lap_pe)  # [num_nodes_graph, hidden_features]
            rwse = self.rwse_linear(rwse)        # [num_nodes_graph, hidden_features]

            # Combine positional encodings
            graph_pos_enc = lap_pe + rwse  # [num_nodes_graph, hidden_features]

            # Assign to pos_enc
            pos_enc[node_idx] = graph_pos_enc  # [num_nodes, hidden_features]

        # Add positional encodings to node features
        x = x + pos_enc  # [num_nodes, hidden_features]

        # Initialize virtual node embedding
        batch_size = num_graphs
        vn_embed = self.virtual_node_embedding.weight.repeat(batch_size, 1)  # [batch_size, hidden_features]

        for conv in self.convs:
            x = conv(x, edge_index, edge_attr, vn_embed, batch)  # [num_nodes, hidden_features]
            x = F.relu(x)

            # Update virtual node embedding
            vn_aggr = global_mean_pool(x, batch)  # [batch_size, hidden_features]
            vn_embed = vn_embed + self.mlp_virtual_node(vn_aggr)  # [batch_size, hidden_features]

        # Prepare for Graph Transformer
        # Group node features by graph and pad
        x_padded, mask = pyg.utils.to_dense_batch(x, batch)  # x_padded: [batch_size, max_num_nodes, hidden_features]

        # Transpose to match expected input of Transformer
        x_padded = x_padded.transpose(0, 1)  # x_padded: [max_num_nodes, batch_size, hidden_features]

        # mask remains of shape [batch_size, max_num_nodes], which matches key_padding_mask
        # Invert mask for key_padding_mask (True indicates positions to be masked)
        key_padding_mask = ~mask  # [batch_size, max_num_nodes]

        # Apply Transformer layers
        for transformer in self.transformer_layers:
            x_padded = transformer(x_padded, key_padding_mask=key_padding_mask)

        # Transpose back
        x_padded = x_padded.transpose(0, 1)  # x_padded: [batch_size, max_num_nodes, hidden_features]

        # Flatten x_padded back to x
        x = x_padded[mask]  # x: [num_nodes, hidden_features]

        # Apply global mean pooling
        x = global_mean_pool(x, batch)  # [batch_size, hidden_features]
        x = self.fc(x)  # [batch_size, out_features]
        return x

# Training and evaluation functions
def train(model, loader, optimizer, loss_fn):
    model.train()
    total_loss = 0
    for batch in tqdm(loader, desc='Training'):
        batch = batch.to(device)
        batch.x = batch.x.float()  # Convert node features to float
        batch.edge_attr = batch.edge_attr.float()  # Convert edge attributes to float
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch, batch)
        loss = loss_fn(out, batch.y.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    average_loss = total_loss / len(loader)
    return average_loss

def evaluate(model, loader):
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in tqdm(loader, desc='Evaluating'):
            batch = batch.to(device)
            batch.x = batch.x.float()  # Convert node features to float
            batch.edge_attr = batch.edge_attr.float()  # Convert edge attributes to float
            out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch, batch)
            y_pred.append(out.cpu())
            y_true.append(batch.y.cpu())
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    # Compute per-class AP
    ap_per_class = []
    for i in range(y_true.shape[1]):
        try:
            ap = average_precision_score(y_true[:, i], y_pred[:, i])
        except ValueError:
            ap = 0.0  # Handle cases where a class has no positive samples
        ap_per_class.append(ap)
    mean_ap = np.mean(ap_per_class)
    return mean_ap

def plot_results(epochs, train_losses, val_aps, learning_rates=None):
    epochs_range = range(1, epochs + 1)

    # Plot Training Loss
    plt.figure(figsize=(10, 5))
    plt.plot(epochs_range, train_losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.savefig('Training_Loss.png')
    plt.show()

    # Plot Validation AP Score
    plt.figure(figsize=(10, 5))
    plt.plot(epochs_range, val_aps, label='Validation AP Score', color='orange')
    plt.xlabel('Epoch')
    plt.ylabel('Average Precision Score')
    plt.title('Validation AP Score over Epochs')
    plt.legend()
    plt.grid(True)
    plt.savefig('Validation_AP_Score.png')
    plt.show()

    # Plot Learning Rate if provided
    if learning_rates is not None:
        plt.figure(figsize=(10, 5))
        plt.plot(epochs_range, learning_rates, label='Learning Rate', color='green')
        plt.xlabel('Epoch')
        plt.ylabel('Learning Rate')
        plt.title('Learning Rate over Epochs')
        plt.legend()
        plt.grid(True)
        plt.savefig('Learning_Rate.png')
        plt.show()

def main(epochs=100, lr=0.001, hidden_features=256):
    # Compute edge_attr_dim and num_tasks from the dataset
    edge_attr_dim = dataset[0].edge_attr.shape[1]
    num_tasks = dataset[0].y.shape[-1]

    # Initialize the model, optimizer, and loss function
    model = GNNWithVirtualNodeAndGINEAndTransformer(
        in_features=dataset.num_node_features,
        hidden_features=hidden_features,
        out_features=num_tasks,
        edge_attr_dim=edge_attr_dim,
        num_layers=5,  # Increased depth as per the paper's suggestion
        lap_pe_dim=10,
        rwse_dim=10
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.9, patience=10)

    loss_fn = torch.nn.BCEWithLogitsLoss()

    # Lists to store losses and AP scores
    train_losses = []
    val_aps = []
    learning_rates = []

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        train_loss = train(model, train_loader, optimizer, loss_fn)
        val_ap = evaluate(model, val_loader)
        train_losses.append(train_loss)
        val_aps.append(val_ap)
        current_lr = optimizer.param_groups[0]['lr']
        learning_rates.append(current_lr)
        print(f"Train Loss: {train_loss:.4f}, Validation AP Score: {val_ap:.4f}, Learning Rate: {current_lr:.6f}")
        scheduler.step(val_ap)

    # Final test evaluation
    test_ap = evaluate(model, test_loader)
    print(f"Test AP Score: {test_ap:.4f}")

    # Plotting the results
    plot_results(epochs, train_losses, val_aps, learning_rates)

# Task 4: Draw the molecule represented by peptides_train[0]
def draw_molecule(data, def_col=0):
    G = pyg.utils.to_networkx(data, to_undirected=True)
    node_features = data.x.numpy()
    edge_index = data.edge_index.numpy()
    edge_attr = data.edge_attr.numpy()
    bond_types = edge_attr[:, 0].astype(int)
    atom_types = None
    atom_type_indices = None
    for i, (u, v) in enumerate(zip(edge_index[0], edge_index[1])):
        G.edges[u, v]['bond_type'] = bond_types[i]
    if def_col == 0:
        atom_types = {
            5: 'C',
            6: 'N',
            7: 'O',
        }
        atom_type_indices = node_features[:, def_col].astype(int)
    elif def_col == 2:
        atom_types = {4: 'C', 3: 'O', 1: 'N'}
        atom_type_indices = node_features[:, def_col].astype(int)
    elif def_col == 4:
        atom_types = {1: 'C', 0: 'O', 2: 'N'}
        atom_type_indices = node_features[:, def_col].astype(int)
    bond_color_mapping = {
        0: 'black',
        1: 'blue',
        3: 'red',
    }
    edges = list(G.edges())
    edge_colors = []
    for u, v in edges:
        bond_type = G.edges[u, v]['bond_type']
        color = bond_color_mapping.get(bond_type, 'green')
        edge_colors.append(color)
    labels = {i: atom_types.get(atom_type_indices[i], 'X') for i in range(atom_type_indices.shape[0])}
    size=12
    plt.figure(figsize=(size, size))
    pos = nx.kamada_kawai_layout(G, scale=5)
    nx.draw(
        G, pos,
        with_labels=False,
        node_size=50,
        node_color='lightblue',
        edgelist=edges,
        edge_color=edge_colors,
        width=1.5
    )
    nx.draw_networkx_labels(
        G, pos,
        labels=labels,
        font_size=6,
        font_weight='bold'
    )
    plt.title('Molecule Visualization of peptides_train[0]')
    plt.axis('off')
    plt.savefig('Molecule_Visualization.png')
    plt.show()

if __name__ == "__main__":
    # Load dataset and create data loaders
    # Replace LRGBDataset with an appropriate dataset loader if needed
    # Here, I'll assume you're using a custom dataset similar to TUDataset

    try:
        dataset = pyg.datasets.LRGBDataset(root='dataset/peptides-func', name="Peptides-func")
    except AttributeError:
        # If LRGBDataset is not available, use a placeholder
        # Replace this with the actual dataset loader you're using
        print("LRGBDataset not found. Please replace with the actual dataset loader.")
        dataset = TUDataset(root='dataset/Mutagenicity', name='Mutagenicity')

    # Check if dataset has splits; if not, create them manually
    if hasattr(dataset, 'train_val_test_idx'):
        peptides_train = dataset[dataset.train_val_test_idx['train']]
        peptides_val = dataset[dataset.train_val_test_idx['val']]
        peptides_test = dataset[dataset.train_val_test_idx['test']]
    else:
        # Create train, val, test splits manually
        num_train = int(0.8 * len(dataset))
        num_val = int(0.1 * len(dataset))
        num_test = len(dataset) - num_train - num_val
        peptides_train, peptides_val, peptides_test = torch.utils.data.random_split(dataset, [num_train, num_val, num_test])

    batch_size = 32
    train_loader = pyg.loader.DataLoader(peptides_train, batch_size=batch_size, shuffle=True)
    val_loader = pyg.loader.DataLoader(peptides_val, batch_size=batch_size, shuffle=False)
    test_loader = pyg.loader.DataLoader(peptides_test, batch_size=batch_size, shuffle=False)

    # Check number of classes and label distribution
    if hasattr(dataset, 'num_tasks'):
        num_classes = dataset.num_tasks
    elif hasattr(dataset, 'num_classes'):
        num_classes = dataset.num_classes
    else:
        # Assume binary classification if not specified
        num_classes = 1
    print(f"Number of classes: {num_classes}")

    all_labels = np.concatenate([data.y.numpy() for data in dataset], axis=0)
    label_distribution = np.mean(all_labels, axis=0)
    print(f"Label distribution: {label_distribution}")

    # Run the main training loop
    main(epochs=300, lr=0.001, hidden_features=32)

    # Draw the molecule for Task 4
    if len(peptides_train) > 0:
        draw_molecule(peptides_train[0])
    else:
        print("Training set is empty. Cannot draw a molecule.")


Using device: cuda
Number of classes: 10
Label distribution: [0.08884393 0.03540881 0.06419571 0.06226432 0.6272418  0.19755358
 0.10687023 0.18412581 0.01995769 0.2598179 ]
Epoch 1/300


Training: 100%|██████████| 272/272 [01:44<00:00,  2.61it/s]
Evaluating: 100%|██████████| 34/34 [00:09<00:00,  3.46it/s]


Train Loss: 0.3689, Validation AP Score: 0.2334, Learning Rate: 0.001000
Epoch 2/300


Training: 100%|██████████| 272/272 [01:40<00:00,  2.70it/s]
Evaluating: 100%|██████████| 34/34 [00:10<00:00,  3.20it/s]


Train Loss: 0.3523, Validation AP Score: 0.2331, Learning Rate: 0.001000
Epoch 3/300


Training: 100%|██████████| 272/272 [01:38<00:00,  2.76it/s]
Evaluating: 100%|██████████| 34/34 [00:10<00:00,  3.14it/s]


Train Loss: 0.3516, Validation AP Score: 0.2202, Learning Rate: 0.001000
Epoch 4/300


Training: 100%|██████████| 272/272 [01:37<00:00,  2.78it/s]
Evaluating: 100%|██████████| 34/34 [00:10<00:00,  3.26it/s]


Train Loss: 0.3424, Validation AP Score: 0.2184, Learning Rate: 0.001000
Epoch 5/300


Training: 100%|██████████| 272/272 [01:55<00:00,  2.35it/s]
Evaluating: 100%|██████████| 34/34 [00:12<00:00,  2.66it/s]


Train Loss: 0.3375, Validation AP Score: 0.2168, Learning Rate: 0.001000
Epoch 6/300


Training: 100%|██████████| 272/272 [02:00<00:00,  2.26it/s]
Evaluating: 100%|██████████| 34/34 [00:12<00:00,  2.64it/s]


Train Loss: 0.3326, Validation AP Score: 0.2116, Learning Rate: 0.001000
Epoch 7/300


Training:  12%|█▏        | 32/272 [00:14<01:45,  2.28it/s]


KeyboardInterrupt: 