In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
import networkx as nx
import random
from collections import defaultdict
import matplotlib.pyplot as plt
import math

# Seed every random number generator this notebook draws from (Python's
# `random`, NumPy and PyTorch) so that data shuffling, augmentation and
# weight initialisation are reproducible between runs.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

class KleinbergGraphGenerator:
    """
    Kleinberg small-world network generator using the NetworkX implementation.

    Creates a 6x6 grid (36 nodes) with short-range lattice edges plus random
    long-range connections, and can orient the result into a DAG.
    """

    # Every module in this file is built from exactly this many nodes.
    EXPECTED_NODES = 36

    def __init__(self, n=6, p=2, q=1, r=2, dim=2, seed=None):
        """
        Args:
            n: Linear dimension of the lattice (6 for 36 nodes in 2D).
            p: Short-range diameter, passed to networkx as ``p``: each node is
                joined to every other node within this lattice distance.
            q: Number of long-range connections per node (networkx ``q``).
            r: Exponent of the distance decay used when drawing long-range
                targets (networkx ``r``).
            dim: Dimension of the underlying lattice (2 for 2D).
            seed: Random seed forwarded to networkx for reproducibility.

        Raises:
            ValueError: if ``n ** dim`` is not 36.
        """
        self.n = n
        self.p = p
        self.q = q
        self.r = r
        self.dim = dim
        self.seed = seed
        self.total_nodes = n ** dim  # n^dim nodes for a dim-dimensional lattice

        # Explicit raise rather than `assert` so the check survives `python -O`.
        if self.total_nodes != self.EXPECTED_NODES:
            raise ValueError(
                f"Expected {self.EXPECTED_NODES} nodes, got {self.total_nodes}"
            )

    def generate_graph(self):
        """
        Generate the Kleinberg small-world network with NetworkX.

        Returns:
            The graph built by ``nx.navigable_small_world_graph`` from this
            generator's parameters.

        Raises:
            RuntimeError: if the generated graph does not have 36 nodes.
        """
        G = nx.navigable_small_world_graph(
            n=self.n,
            p=self.p,
            q=self.q,
            r=self.r,
            dim=self.dim,
            seed=self.seed
        )

        if G.number_of_nodes() != self.EXPECTED_NODES:
            raise RuntimeError(
                f"Graph has {G.number_of_nodes()} nodes, expected {self.EXPECTED_NODES}"
            )

        return G

    def _node_order(self, node):
        """Return the scalar ordering key of ``node``.

        Tuple lattice coordinates map to their row-major linear index, e.g.
        ``(i, j) -> i * n + j`` in 2D. Any other node value is used as-is.
        """
        if isinstance(node, tuple):
            # Horner form of sum(coord * n**k) with the last coordinate weighted n**0.
            order = 0
            for coord in node:
                order = order * self.n + coord
            return order
        return node

    def convert_to_dag(self, G):
        """
        Orient the edges of ``G`` into a Directed Acyclic Graph (DAG).

        Every edge is directed from the endpoint with the smaller order (see
        ``_node_order``) to the one with the larger order. Reciprocal edges
        collapse into a single DAG edge. Edges whose endpoints share an order
        (such as self-loops) are dropped.

        Along any directed path the order strictly increases, so the result
        cannot contain a cycle. No cycle-breaking pass is needed.

        Args:
            G: Graph returned by ``generate_graph``.

        Returns:
            ``(DAG, node_indices)``:
                - ``DAG`` is an ``nx.DiGraph`` containing every node of ``G``.
                - ``node_indices`` maps each node to its order.

        Raises:
            nx.NetworkXUnfeasible: if the result is not acyclic. This is a
                defensive invariant check and should be unreachable.
        """
        node_indices = {node: self._node_order(node) for node in G.nodes()}

        DAG = nx.DiGraph()
        DAG.add_nodes_from(G.nodes())

        # Iterate edges in G's order so predecessor ordering, and therefore the
        # per-edge weight indexing in the modules, stays deterministic.
        for u, v in G.edges():
            if node_indices[u] < node_indices[v]:
                DAG.add_edge(u, v)
            elif node_indices[u] > node_indices[v]:
                DAG.add_edge(v, u)

        if not nx.is_directed_acyclic_graph(DAG):
            raise nx.NetworkXUnfeasible(
                "Could not convert graph to a DAG: node ordering produced a cycle."
            )

        return DAG, node_indices

class CNNBCNModule(nn.Module):
    """
    Individual CNNBCN module based on the generated Kleinberg DAG.

    Each node with at least one predecessor owns two things:
        - a learnable softmax-normalised weighted sum over its predecessors'
          outputs;
        - a Conv3x3 + BatchNorm + activation applied to that sum.
    Nodes without predecessors receive the module input ``x`` unchanged.

    The result is the output of the single sink node, or the mean of the sink
    outputs when there are several. Every conv maps ``num_channels`` to
    ``num_channels`` with padding 1. The input must therefore already have
    ``num_channels`` channels, and spatial size is preserved.
    """
    def __init__(self, dag, node_indices, num_channels, activation_type='gelu'):
        """
        Args:
            dag: Directed acyclic graph defining the wiring. Nodes are keyed by
                ``str(node)``, so their string forms must be unique.
            node_indices: Node -> order mapping produced together with ``dag``.
                It is stored on the module but not used by the forward pass.
            num_channels: Channel count of the input and of every node output.
            activation_type: 'gelu', 'relu' or 'swish'. Anything else falls
                back to ReLU.
        """
        super(CNNBCNModule, self).__init__()
        self.dag = dag
        self.node_indices = node_indices
        self.num_channels = num_channels
        self.activation_type = activation_type

        # Per-node transformations and aggregation logits, keyed by str(node).
        self.node_ops = nn.ModuleDict()
        self.edge_weights = nn.ParameterDict()

        in_degrees = dict(dag.in_degree())

        # Parameters are created in dag.nodes() order. Keep this order so that
        # seeded initialisation does not change.
        for node in dag.nodes():
            in_degree = in_degrees[node]
            if in_degree > 0:  # Only nodes with inputs get parameters
                # One logit per incoming edge; softmax-normalised in forward().
                self.edge_weights[str(node)] = nn.Parameter(
                    torch.randn(in_degree) * 0.1
                )

                # Transformation operations: Conv3x3 + BN + Activation
                self.node_ops[str(node)] = nn.Sequential(
                    nn.Conv2d(num_channels, num_channels, 3, padding=1),
                    nn.BatchNorm2d(num_channels),
                    self._get_activation(activation_type)
                )

        # The wiring never changes after construction. Resolve it once here
        # instead of re-running a topological sort on every forward pass.
        self._topo_order = list(nx.topological_sort(dag))
        # Predecessor order fixes which softmax weight applies to which input.
        self._predecessors = {
            node: list(dag.predecessors(node)) for node in self._topo_order
        }
        self._input_nodes = [node for node in dag.nodes() if in_degrees[node] == 0]
        self._output_nodes = [node for node in dag.nodes() if dag.out_degree(node) == 0]

    def _get_activation(self, activation_type):
        """Return a fresh activation module for ``activation_type`` (ReLU if unknown)."""
        if activation_type == 'gelu':
            return nn.GELU()
        elif activation_type == 'relu':
            return nn.ReLU(inplace=True)
        elif activation_type == 'swish':
            return nn.SiLU()  # Swish activation
        else:
            return nn.ReLU(inplace=True)

    def forward(self, x):
        """
        Forward pass through the Kleinberg DAG-based module.

        Args:
            x: Input feature map with ``num_channels`` channels. Assumed to be
                (batch, channels, H, W), as BatchNorm2d requires.

        Returns:
            A tensor with the same shape as ``x``. If no sink node produced an
            output, ``x`` itself is returned.
        """
        node_outputs = {node: x for node in self._input_nodes}

        for node in self._topo_order:
            predecessors = self._predecessors[node]
            if not predecessors:
                continue  # Source node: already holds the module input.

            # Aggregation: weighted sum of inputs with softmax weights
            weights = F.softmax(self.edge_weights[str(node)], dim=0)
            aggregated = None
            for i, pred in enumerate(predecessors):
                if pred in node_outputs:
                    weighted_input = weights[i] * node_outputs[pred]
                    aggregated = (weighted_input if aggregated is None
                                  else aggregated + weighted_input)

            if aggregated is not None:
                node_outputs[node] = self.node_ops[str(node)](aggregated)

        if len(self._output_nodes) == 1:
            return node_outputs.get(self._output_nodes[0], x)

        # Average of the outputs from multiple terminal nodes.
        outputs = [node_outputs[node] for node in self._output_nodes if node in node_outputs]
        return sum(outputs) / len(outputs) if outputs else x

class CNNBCNModel(nn.Module):
    """
    Complete CNNBCN model with Kleinberg small-world modules (36 nodes each).

    Pipeline:
        1. 3x3 input conv (3 -> ``channels``).
        2. ``num_modules`` DAG modules, with a stride-2 Conv + BN + ReLU
           downsampling block between consecutive modules.
        3. Global average pooling.
        4. A sigmoid channel gate.
        5. Dropout followed by a linear classifier.
    """
    def __init__(self, num_classes=10, num_modules=5, kleinberg_p=2, simple_mode=True):
        """
        Args:
            num_classes: Number of output logits.
            num_modules: Number of Kleinberg DAG modules stacked in sequence.
            kleinberg_p: Base value forwarded to networkx as the short-range
                diameter ``p``: nodes within this lattice distance are joined.
                Module ``i`` uses ``kleinberg_p + 0.2 * i``.
            simple_mode: Use 78 channels when True, 109 otherwise.
        """
        super(CNNBCNModel, self).__init__()

        self.num_modules = num_modules
        self.simple_mode = simple_mode
        self.num_nodes = 36  # Fixed to 36 nodes per module (6x6 lattice)

        self.channels = 78 if simple_mode else 109

        # Sub-modules are created in a fixed order: input conv, DAG modules,
        # downsampling, head. Keep this order so seeded initialisation is
        # reproducible.
        self.input_conv = nn.Conv2d(3, self.channels, 3, padding=1)

        self.module_list = nn.ModuleList()
        for i in range(num_modules):
            # NOTE(review): this is networkx's short-range diameter, not a decay
            # exponent. Lattice (Manhattan) distances are integers, so the +0.2
            # steps only change the wiring once they cross the next integer.
            # With the defaults (2.0 .. 2.8), all modules appear to share the
            # same short-range edges and differ only through their seeds.
            # Confirm this is intended.
            p_param = kleinberg_p + (i * 0.2)

            kleinberg_generator = KleinbergGraphGenerator(
                n=6,  # 6^2 = 36 nodes for 2D lattice
                p=p_param,  # short-range diameter
                q=1,  # number of long-range connections per node
                r=2,  # exponent of the long-range distance decay
                dim=2,
                seed=42 + i  # Different seed for each module
            )
            graph = kleinberg_generator.generate_graph()
            dag, node_indices = kleinberg_generator.convert_to_dag(graph)

            module = CNNBCNModule(dag, node_indices, self.channels,
                                  self._activation_for_module(i))
            self.module_list.append(module)

        # Downsampling layers between modules (stride 2 halves H and W).
        self.downsample_layers = nn.ModuleList()
        for _ in range(num_modules - 1):
            self.downsample_layers.append(
                nn.Sequential(
                    nn.Conv2d(self.channels, self.channels, 3, stride=2, padding=1),
                    nn.BatchNorm2d(self.channels),
                    nn.ReLU(inplace=True)
                )
            )

        # Classifier head: pooled features are re-weighted by a sigmoid gate.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.attention = nn.Sequential(
            nn.Linear(self.channels, self.channels // 4),
            nn.ReLU(inplace=True),
            nn.Linear(self.channels // 4, self.channels),
            nn.Sigmoid()
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.channels, num_classes)
        )

    @staticmethod
    def _activation_for_module(i):
        """Return the activation name for module ``i``: GELU for 0-1, Swish for 2-3, ReLU after."""
        if i < 2:
            return 'gelu'
        if i < 4:
            return 'swish'
        return 'relu'

    def forward(self, x):
        """Map a batch of 3-channel images to class logits of shape (batch, num_classes)."""
        x = self.input_conv(x)

        last = len(self.module_list) - 1
        for i, module in enumerate(self.module_list):
            x = module(x)
            # Downsample between modules, but not after the last one.
            if i < last:
                x = self.downsample_layers[i](x)

        x = self.global_pool(x)
        x = torch.flatten(x, 1)

        # Channel gate computed from the pooled features themselves.
        x = x * self.attention(x)

        return self.classifier(x)

def load_cifar10(batch_size=64, root='./data', num_workers=2):
    """
    Load CIFAR-10 train/test DataLoaders, downloading the data if needed.

    Training images get random crop (padding 4), horizontal flip, rotation
    (up to 15 degrees) and colour jitter. Test images are only converted to
    tensors and normalised.

    Args:
        batch_size: Samples per batch for both loaders.
        root: Directory torchvision stores the dataset in and downloads it to.
        num_workers: Worker processes used by each DataLoader.

    Returns:
        ``(trainloader, testloader)``. Only the training loader shuffles.
    """
    # One normaliser shared by train and test so both are scaled identically.
    # NOTE(review): (0.2023, 0.1994, 0.2010) is a widely copied std. The
    # per-pixel std of the CIFAR-10 training set is usually quoted as about
    # (0.2470, 0.2435, 0.2616). Kept as-is to preserve existing results.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(root=root, train=True,
                                            download=True, transform=transform_train)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                             num_workers=num_workers)

    testset = torchvision.datasets.CIFAR10(root=root, train=False,
                                           download=True, transform=transform_test)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)

    return trainloader, testloader

def train_model(model, trainloader, testloader, num_epochs=15, lr=0.001, device=None):
    """
    Train the CNNBCN model with Kleinberg networks, evaluating after every epoch.

    Uses label-smoothed cross-entropy, AdamW (weight decay 1e-4) and a cosine
    learning-rate schedule stepped once per epoch.

    Args:
        model: ``nn.Module`` that maps an input batch to class logits.
        trainloader: Iterable of ``(inputs, targets)`` training batches.
        testloader: Iterable of ``(inputs, targets)`` evaluation batches.
        num_epochs: Number of passes over ``trainloader``.
        lr: Initial learning rate.
        device: Device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        ``(train_losses, train_accuracies, test_accuracies)``: per-epoch mean
        training loss, training accuracy (%) and test accuracy (%). An empty
        loader yields 0.0 for its metrics instead of raising ZeroDivisionError.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    train_losses = []
    train_accuracies = []
    test_accuracies = []
    best_test_acc = 0.0

    for epoch in range(num_epochs):
        train_loss, train_acc = _train_one_epoch(
            model, trainloader, criterion, optimizer, device, epoch, num_epochs
        )
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)

        test_acc = _evaluate(model, testloader, device)
        test_accuracies.append(test_acc)

        # Only the best accuracy is tracked; no checkpoint is written.
        best_test_acc = max(best_test_acc, test_acc)

        print(f'Epoch {epoch+1}/{num_epochs}: '
              f'Train Loss: {train_losses[-1]:.4f}, '
              f'Train Acc: {train_acc:.2f}%, '
              f'Test Acc: {test_acc:.2f}%, '
              f'Best Test Acc: {best_test_acc:.2f}%')

        scheduler.step()

    return train_losses, train_accuracies, test_accuracies


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, num_epochs):
    """Run one training pass and return ``(mean batch loss, accuracy %)``, both 0.0 if ``loader`` is empty."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0  # counted directly so loaders without len() also work

    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, '
                  f'Loss: {loss.item():.4f}')

    mean_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0
    return mean_loss, accuracy


def _evaluate(model, loader, device):
    """Return the accuracy (%) of ``model`` on ``loader`` without gradients; 0.0 if ``loader`` is empty."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    return 100. * correct / total if total else 0.0


def main():
    """Run the Kleinberg CNNBCN CIFAR-10 experiment: load data, build the model, train, report."""
    banner = "CNNBCN with NetworkX Kleinberg Small-World Networks (36 nodes) for CIFAR-10"
    print(banner)
    print("=" * 80)

    print("Loading CIFAR-10 dataset...")
    trainloader, testloader = load_cifar10(batch_size=64)

    print("Creating CNNBCN model with NetworkX Kleinberg networks (36 nodes each)...")
    model = CNNBCNModel(
        num_classes=10,
        num_modules=5,
        kleinberg_p=2,  # forwarded to networkx as the short-range lattice diameter `p`
        simple_mode=True,  # 78 channels
    )

    params = list(model.parameters())
    total_params = sum(t.numel() for t in params)
    trainable_params = sum(t.numel() for t in params if t.requires_grad)
    for label, count in (("Total parameters", total_params),
                         ("Trainable parameters", trainable_params)):
        print(f"{label}: {count:,}")
    print(f"Nodes per module: {model.num_nodes}")

    print("Starting training...")
    _, _, test_accuracies = train_model(
        model, trainloader, testloader, num_epochs=15, lr=0.001
    )

    final_acc = test_accuracies[-1]
    best_acc = max(test_accuracies)
    print(f"Final Test Accuracy: {final_acc:.2f}%")
    print(f"Best Test Accuracy: {best_acc:.2f}%")


# Start the experiment only when run as a script or notebook cell, not on import.
if __name__ == "__main__":
    main()

CNNBCN with NetworkX Kleinberg Small-World Networks (36 nodes) for CIFAR-10
Loading CIFAR-10 dataset...
Creating CNNBCN model with NetworkX Kleinberg networks (36 nodes each)...
Total parameters: 9,850,087
Trainable parameters: 9,850,087
Nodes per module: 36
Starting training...
Epoch 1/15, Batch 0, Loss: 2.3464
