In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

In [5]:
dataset = Planetoid(root='/tmp/Cora', name='Cora')

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch}, Loss: {loss.item()}')


Epoch: 0, Loss: 1.9507325887680054
Epoch: 1, Loss: 1.8381447792053223
Epoch: 2, Loss: 1.718522310256958
Epoch: 3, Loss: 1.5691254138946533
Epoch: 4, Loss: 1.4506096839904785
Epoch: 5, Loss: 1.2855370044708252
Epoch: 6, Loss: 1.1746701002120972
Epoch: 7, Loss: 1.0253092050552368
Epoch: 8, Loss: 0.9569247364997864
Epoch: 9, Loss: 0.8305391073226929
Epoch: 10, Loss: 0.7266712784767151
Epoch: 11, Loss: 0.6746046543121338
Epoch: 12, Loss: 0.558982253074646
Epoch: 13, Loss: 0.4995765686035156
Epoch: 14, Loss: 0.44301655888557434
Epoch: 15, Loss: 0.3690904378890991
Epoch: 16, Loss: 0.32364463806152344
Epoch: 17, Loss: 0.3387809097766876
Epoch: 18, Loss: 0.2908649146556854
Epoch: 19, Loss: 0.24899815022945404
Epoch: 20, Loss: 0.1836688071489334
Epoch: 21, Loss: 0.20824334025382996
Epoch: 22, Loss: 0.14268237352371216
Epoch: 23, Loss: 0.14479960501194
Epoch: 24, Loss: 0.1329120546579361
Epoch: 25, Loss: 0.141646608710289
Epoch: 26, Loss: 0.10513527691364288
Epoch: 27, Loss: 0.13053934276103973


# SagPool

In [17]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn import TopKPooling


In [22]:
class SAGPool(torch.nn.Module):
    def __init__(self, in_channels, ratio=0.5):
        super(SAGPool, self).__init__()
        self.ratio = ratio
        self.gcn = GCNConv(in_channels, 1)
        self.pool = TopKPooling(in_channels, ratio=self.ratio)

    def forward(self, x, edge_index, batch):
        # Compute attention scores
        score = self.gcn(x, edge_index).sigmoid()  # Ensure scores are in [0, 1]

        # Apply the attention scores to the node features
        x = x * score.view(-1, 1)  # Broadcast score across features

        # Perform pooling
        x, edge_index, _, batch, _, _ = self.pool(x, edge_index, None, batch)
        return x, edge_index, batch



class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        num_node_features = dataset.num_node_features
        self.conv1 = GCNConv(num_node_features, 128)
        self.pool1 = SAGPool(128, ratio=0.8)

        self.conv2 = GCNConv(128, 128)
        self.pool2 = SAGPool(128, ratio=0.6)

        self.conv3 = GCNConv(128, 128)
        self.lin = torch.nn.Linear(128, dataset.num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x, edge_index, batch = self.pool1(x, edge_index, batch)

        x = self.conv2(x, edge_index)
        x = torch.relu(x)
        x, edge_index, batch = self.pool2(x, edge_index, batch)

        x = self.conv3(x, edge_index)
        x = torch.relu(x)

        x = global_mean_pool(x, batch)
        x = self.lin(x)

        return torch.log_softmax(x, dim=-1)


In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.CrossEntropyLoss()

# Add training loop here
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset

# Example datasets, replace with actual dataset
dataset = TUDataset(root='/tmp/PROTEINS', name='PROTEINS')
dataset = dataset.shuffle()
n = len(dataset) // 10
train_dataset = dataset[:-n]
test_dataset = dataset[-n:]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

def test(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data).max(dim=1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)

for epoch in range(1, 201):
    train_loss = train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f}, '
          f'Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')


Epoch: 001, Loss: 0.6907, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 002, Loss: 0.6899, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 003, Loss: 0.6890, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 004, Loss: 0.6880, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 005, Loss: 0.6867, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 006, Loss: 0.6842, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 007, Loss: 0.6786, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 008, Loss: 0.6709, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 009, Loss: 0.6653, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 010, Loss: 0.6639, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 011, Loss: 0.6624, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 012, Loss: 0.6610, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 013, Loss: 0.6601, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 014, Loss: 0.6580, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 015, Loss: 0.6560, Train Acc: 0.5958, Test Acc: 0.5946
Epoch: 016, Loss: 0.6545, Train Acc: 0.6118, Test Acc: 0.5946
Epoch: 0

# modification

The task here is changing this selection process,
While selecting from top-k, we do not want to select nodes who are neighbors. So if a node is selected,
then It should eliminate its neighbors from candidate list and the select the next node who is not
neighbors of a selected node.

Use Pytorch Geometric (https://pytorch.org/docs/stable/nn.html#pooling-layers)

Apply your pooling method for one graph classification data set. Compare results with original sagpool.

In [31]:
import torch
from torch_geometric.utils import to_dense_batch, to_dense_adj, dense_to_sparse, subgraph
from torch_geometric.nn import GCNConv

class CustomSAGPool(torch.nn.Module):
    def __init__(self, in_channels, ratio=0.5):
        super(CustomSAGPool, self).__init__()
        self.ratio = ratio
        self.gcn = GCNConv(in_channels, 1)

    def forward(self, x, edge_index, batch):
        # Compute attention scores using a GCN layer
        scores = self.gcn(x, edge_index).sigmoid()

        # Process graphs in batch separately
        x_out = []
        edge_index_out = []
        batch_out = []
        
        for batch_id in range(batch.max() + 1):
            # Mask to process only nodes for the current batch
            batch_mask = (batch == batch_id)
            
            # Extract the subgraph for the current batch
            sub_x = x[batch_mask]
            sub_score = scores[batch_mask]
            sub_edge_index, _ = subgraph(batch_mask, edge_index, relabel_nodes=True, num_nodes=x.size(0))

            # Dense adjacency matrix for subgraph
            dense_adj = to_dense_adj(sub_edge_index, max_num_nodes=sub_x.size(0)).squeeze(0)

            # Sort nodes based on scores
            num_nodes = int(self.ratio * sub_x.size(0))
            _, idx = sub_score.view(-1).sort(descending=True)

            # Mask to keep track of remaining nodes
            mask = torch.ones(sub_x.size(0), dtype=torch.bool, device=sub_x.device)

            # List to collect selected nodes
            selected = []

            for i in idx:
                if mask[i]:
                    # Add node to selected list
                    selected.append(i)
                    # Remove the node and its neighbors from the mask
                    mask[i] = False
                    mask[dense_adj[i] > 0] = False
                    if len(selected) >= num_nodes:
                        break

            # Filter nodes based on the final selected list
            selected = torch.tensor(selected, device=sub_x.device)
            new_edge_index, _ = dense_to_sparse(dense_adj[selected][:, selected])

            # Append subgraph results to outputs
            x_out.append(sub_x[selected])
            edge_index_out.append(new_edge_index)
            batch_out.append(torch.full((selected.size(0),), batch_id, dtype=torch.long, device=sub_x.device))

        # Concatenate all subgraphs back into batched format
        x_out = torch.cat(x_out, dim=0)
        edge_index_out = torch.cat(edge_index_out, dim=1)
        batch_out = torch.cat(batch_out, dim=0)

        return x_out, edge_index_out, batch_out



class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        num_node_features = dataset.num_node_features
        num_classes = dataset.num_classes
        self.conv1 = GCNConv(num_node_features, 128)
        self.pool1 = CustomSAGPool(128, ratio=0.8)

        self.conv2 = GCNConv(128, 128)
        self.pool2 = CustomSAGPool(128, ratio=0.6)

        self.conv3 = GCNConv(128, 128)
        self.lin = torch.nn.Linear(128, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x, edge_index, batch = self.pool1(x, edge_index, batch)

        x = self.conv2(x, edge_index)
        x = torch.relu(x)
        x, edge_index, batch = self.pool2(x, edge_index, batch)

        x = self.conv3(x, edge_index)
        x = torch.relu(x)

        x = global_mean_pool(x, batch)
        x = self.lin(x)

        return torch.log_softmax(x, dim=-1)


In [34]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# Add training loop here
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset

# Example datasets, replace with actual dataset
dataset = TUDataset(root='/tmp/PROTEINS', name='PROTEINS')
dataset = dataset.shuffle()
n = len(dataset) // 10
train_dataset = dataset[:-n]
test_dataset = dataset[-n:]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

def test(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data).max(dim=1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)

train_losses = []
train_accuracies = []
test_accuracies = []
for epoch in range(1, 101):
    train_loss = train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    test_accuracies.append(test_acc)
    print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f}, '
          f'Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')


Epoch: 001, Loss: 0.6724, Train Acc: 0.6028, Test Acc: 0.5315
Epoch: 002, Loss: 0.6558, Train Acc: 0.5818, Test Acc: 0.6126
Epoch: 003, Loss: 0.6679, Train Acc: 0.6028, Test Acc: 0.5315
Epoch: 004, Loss: 0.6637, Train Acc: 0.6028, Test Acc: 0.5315
Epoch: 005, Loss: 0.6613, Train Acc: 0.6517, Test Acc: 0.5856
Epoch: 006, Loss: 0.6576, Train Acc: 0.6527, Test Acc: 0.5766
Epoch: 007, Loss: 0.6554, Train Acc: 0.6747, Test Acc: 0.7027
Epoch: 008, Loss: 0.6383, Train Acc: 0.6477, Test Acc: 0.7838
Epoch: 009, Loss: 0.6688, Train Acc: 0.6028, Test Acc: 0.5315
Epoch: 010, Loss: 0.6635, Train Acc: 0.6437, Test Acc: 0.6036
Epoch: 011, Loss: 0.6389, Train Acc: 0.6727, Test Acc: 0.6486
Epoch: 012, Loss: 0.6255, Train Acc: 0.6587, Test Acc: 0.6396
Epoch: 013, Loss: 0.6287, Train Acc: 0.6607, Test Acc: 0.6486
Epoch: 014, Loss: 0.6230, Train Acc: 0.6966, Test Acc: 0.7117
Epoch: 015, Loss: 0.6219, Train Acc: 0.6717, Test Acc: 0.6757
Epoch: 016, Loss: 0.6228, Train Acc: 0.7006, Test Acc: 0.6757
Epoch: 0

In [None]:
# plot train loss, train accuracy, and test accuracy
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Train Acc')
plt.plot(test_accuracies, label='Test Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()