# In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

def generate_complex_graph(num_communities=3, nodes_per_community=10, bottleneck_nodes=2):
    """
    Generate a graph of densely connected communities joined by bottleneck edges.

    Each community is a complete graph on ``nodes_per_community`` nodes. Node ids
    are assigned contiguously: community ``i`` owns ids
    ``i * nodes_per_community`` .. ``(i + 1) * nodes_per_community - 1``.
    Consecutive communities ``i`` and ``i + 1`` are linked by ``bottleneck_nodes``
    edges, each pairing the j-th node of one community with the j-th node of the next.

    Args:
        num_communities: Number of communities; also the number of label classes.
        nodes_per_community: Number of nodes in each community.
        bottleneck_nodes: Number of bridging edges between consecutive
            communities. Must not exceed ``nodes_per_community``.

    Returns:
        A ``(G, data)`` tuple:
        - ``G`` is the NetworkX graph.
        - ``data`` is a PyTorch Geometric ``Data`` object with one-hot node
          features ``x`` (an identity matrix), an ``edge_index`` that contains
          both directions of every undirected edge, and community labels ``y``.

    Raises:
        ValueError: If ``bottleneck_nodes`` exceeds ``nodes_per_community``.
    """
    if bottleneck_nodes > nodes_per_community:
        # Otherwise the bridge endpoints would fall outside their community's id
        # range. The last bridge would then create brand-new, unlabeled nodes,
        # and x and y would end up with different lengths.
        raise ValueError(
            f"bottleneck_nodes ({bottleneck_nodes}) must not exceed "
            f"nodes_per_community ({nodes_per_community})"
        )

    G = nx.Graph()
    labels = []

    for i in range(num_communities):
        offset = i * nodes_per_community
        # Complete graph for this community, relabelled into its contiguous id range.
        community = nx.complete_graph(nodes_per_community)
        mapping = {node: node + offset for node in community.nodes()}
        G = nx.compose(G, nx.relabel_nodes(community, mapping))
        labels.extend([i] * nodes_per_community)

    # Bottleneck connections between consecutive communities.
    for i in range(num_communities - 1):
        for j in range(bottleneck_nodes):
            G.add_edge(i * nodes_per_community + j, (i + 1) * nodes_per_community + j)

    # G.edges yields each undirected edge once. GCNConv passes messages along the
    # edge_index direction (source -> target), so append every reversed edge to
    # let information flow both ways. reshape(-1, 2) keeps an edgeless graph at
    # shape (2, 0) instead of collapsing to a 1-D tensor.
    edges = torch.tensor(list(G.edges), dtype=torch.long).reshape(-1, 2).t()
    edge_index = torch.cat([edges, edges.flip(0)], dim=1).contiguous()
    x = torch.eye(G.number_of_nodes(), dtype=torch.float)  # One-hot features: node i -> row i
    y = torch.tensor(labels, dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, y=y)
    return G, data

class SimpleGNN(torch.nn.Module):
    """Three-layer GCN that outputs per-node class log-probabilities.

    Architecture:
        GCNConv(in -> hidden) -> ReLU -> GCNConv(hidden -> embedding) -> ReLU
        -> GCNConv(embedding -> classes) -> log_softmax over dim 1.

    The default sizes reproduce the original fixed configuration:
    - 30 input features: the one-hot features of ``generate_complex_graph()``
      with its default arguments (3 communities x 10 nodes).
    - Hidden widths 16 and 8.
    - 3 output classes.

    Pass other values when the graph size or number of communities changes.
    """

    def __init__(self, in_channels=30, hidden_channels=16, embedding_channels=8, num_classes=3):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, embedding_channels)
        self.conv3 = GCNConv(embedding_channels, num_classes)  # one output per class

    def forward(self, data):
        """Return class log-probabilities for every node in ``data``.

        Args:
            data: Graph object exposing ``x`` (node features; expected to have
                ``in_channels`` columns) and ``edge_index``.

        Returns:
            Output of the last GCN layer, normalised with ``log_softmax`` over
            dim 1 (the class dimension).
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        return F.log_softmax(x, dim=1)

model = SimpleGNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# SimpleGNN.forward already returns log-probabilities (log_softmax), so the
# matching loss is NLLLoss. CrossEntropyLoss would apply log_softmax a second
# time: harmless numerically, but redundant work.
criterion = torch.nn.NLLLoss()

def train():
    """Perform a single full-batch optimisation step.

    Reads the module-level ``model``, ``optimizer``, ``criterion`` and ``data``.

    Returns:
        The loss of this step as a Python float.
    """
    optimizer.zero_grad()
    model.train()
    log_probs = model(data)
    step_loss = criterion(log_probs, data.y)
    step_loss.backward()
    optimizer.step()
    return step_loss.item()

G, data = generate_complex_graph()

# Training loop: each call to train() is one full-batch step over all nodes.
for epoch in range(50):
    loss = train()
    print(f'Epoch {epoch+1}, Loss: {loss:.4f}')

# Evaluation: inference only, so skip building the autograd graph.
model.eval()
with torch.no_grad():
    pred = model(data).argmax(dim=1)  # predicted class index per node
correct = (pred == data.y).sum().item()
# NOTE: there is no train/test split, so this is accuracy on the training nodes.
accuracy = correct / data.num_nodes
print(f'Accuracy: {accuracy:.4f}')

# Visualization function
def visualize_graph(G, data, pred=None, title="Graph Visualization", seed=42, pos=None):
    """Draw ``G`` with nodes coloured by class label.

    Args:
        G: NetworkX graph to draw.
        data: Object whose ``y`` tensor holds the true node labels. ``y`` gives
            the node colours when ``pred`` is None, and always fixes the colour
            scale.
        pred: Optional tensor of predicted labels, one per node. When given,
            nodes are coloured by it instead of ``data.y``.
        title: Figure title.
        seed: Seed for ``nx.spring_layout``. Repeated calls on the same graph
            then place nodes identically, which keeps true and predicted plots
            comparable.
        pos: Optional precomputed ``{node: (x, y)}`` layout. When given, it is
            used as-is and ``seed`` is ignored.
    """
    if pos is None:
        pos = nx.spring_layout(G, seed=seed)
    plt.figure(figsize=(10, 10))

    # Node colors
    node_labels = data.y if pred is None else pred
    colors = node_labels.cpu().numpy()

    # Pin the colormap range to the true-label range. Otherwise matplotlib
    # rescales to each array's own min/max, and the same class could get a
    # different colour in the predicted plot (e.g. when a class is never predicted).
    if data.y.numel() > 0:
        vmin, vmax = int(data.y.min()), int(data.y.max())
    else:
        vmin = vmax = None

    # Draw the graph
    nx.draw(G, pos, with_labels=True, node_color=colors, cmap=plt.get_cmap('coolwarm'),
            vmin=vmin, vmax=vmax, node_size=500, font_size=10, font_color='white')
    plt.title(title)
    plt.show()

# Plot the ground-truth communities first, then the model's predictions.
for _node_labels, _plot_title in (
    (None, "Complex Graph with True Labels"),
    (pred, "Complex Graph with Predicted Labels"),
):
    visualize_graph(G, data, pred=_node_labels, title=_plot_title)
