In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import networkx as nx
import random
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple
import math

# Seed every random number generator this file relies on (PyTorch, NumPy and
# the standard library) with one shared value so repeated runs are comparable.
SEED = 42
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(SEED)

class DAGConverter:
    """Convert an undirected graph to a Directed Acyclic Graph (DAG).

    Node labels are treated as indices: every undirected edge is oriented from
    its smaller endpoint to its larger one, which makes a directed cycle
    impossible.
    """

    def __init__(self):
        pass

    def convert_to_dag(self, graph: nx.Graph) -> nx.DiGraph:
        """Return a DAG with the nodes of ``graph`` and index-ordered edges.

        Args:
            graph: Undirected graph whose node labels can be compared with
                ``<`` (for example the integer labels produced by the
                NetworkX graph generators).

        Returns:
            A new ``nx.DiGraph`` holding every node of ``graph`` and, for each
            undirected edge ``{u, v}`` with ``u != v``, the single directed
            edge from the smaller to the larger label.
        """
        dag = nx.DiGraph()

        # Keep every node, including ones that end up without edges
        dag.add_nodes_from(graph.nodes())

        # Orient each edge small -> large. A self-loop (u == v) would be a
        # one-node cycle and break the acyclicity guarantee, so it is dropped.
        dag.add_edges_from(
            (u, v) if u < v else (v, u)
            for u, v in graph.edges()
            if u != v
        )

        return dag

class NodeOperation(nn.Module):
    """
    Computation performed by one graph node.

    A node blends its incoming feature maps with learnable weights
    (aggregation) and then applies conv -> batch norm -> activation
    (transformation). The tensor it returns is what the caller hands on to the
    successor nodes (distribution).
    """

    def __init__(self, in_channels: int, out_channels: int, max_inputs: int = 8, use_gelu: bool = False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.max_inputs = max_inputs

        # One raw weight per possible incoming edge. forward() squashes them
        # through a sigmoid, so the effective weights are always positive.
        self.aggregation_weights = nn.Parameter(torch.randn(max_inputs))

        # Transformation: 3x3 convolution (padding=1 keeps H and W) + batch norm
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.batch_norm = nn.BatchNorm2d(out_channels)

        # GELU when requested, otherwise an in-place ReLU
        self.activation = nn.GELU() if use_gelu else nn.ReLU(inplace=True)

    def _aggregate(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Blend ``inputs`` into one tensor using the normalized sigmoid weights."""
        # A lone input is passed through untouched (no weight is applied)
        if len(inputs) == 1:
            return inputs[0]

        # Inputs beyond max_inputs have no weight slot and are ignored
        kept = inputs[:min(len(inputs), self.max_inputs)]

        # Positive weights that sum to one over the inputs actually used
        mix = torch.sigmoid(self.aggregation_weights[:len(kept)])
        mix = mix / mix.sum()

        blended = torch.zeros_like(kept[0])
        for coeff, feature_map in zip(mix, kept):
            blended += coeff * feature_map
        return blended

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """
        Aggregate the incoming tensors and transform the result.

        Args:
            inputs: Non-empty list of tensors from the predecessor nodes; at
                most ``max_inputs`` of them (the first ones) are used.

        Returns:
            The activated, batch-normalized convolution of the aggregate.

        Raises:
            ValueError: If ``inputs`` is empty.
        """
        if not inputs:
            raise ValueError("Node must have at least one input")

        features = self.conv(self._aggregate(inputs))
        features = self.batch_norm(features)
        return self.activation(features)

class CNNBCNModule(nn.Module):
    """
    Single CNNBCN module generated from a DAG.

    Every DAG node becomes a ``NodeOperation``. Source nodes (in-degree 0) are
    fed a 1x1-projected copy of the module input, every other node is fed the
    outputs of its predecessors, and the outputs of the sink nodes
    (out-degree 0) form the module output.
    """

    def __init__(self, dag: nx.DiGraph, in_channels: int, out_channels: int, use_gelu: bool = False):
        """
        Args:
            dag: Directed acyclic graph describing the wiring between nodes.
            in_channels: Channel count of the tensor passed to ``forward``.
            out_channels: Channel count used inside the module and returned.
            use_gelu: Forwarded to every ``NodeOperation``.

        Raises:
            ValueError: If ``dag`` has no nodes.
            networkx.NetworkXUnfeasible: If ``dag`` contains a cycle (raised by
                the topological sort).
        """
        super(CNNBCNModule, self).__init__()
        self.dag = dag
        self.nodes = list(dag.nodes())
        self.in_channels = in_channels
        self.out_channels = out_channels

        if not self.nodes:
            raise ValueError("DAG must contain at least one node")

        # The wiring never changes after construction, so resolve the
        # processing order and predecessor lists once instead of on every
        # forward pass.
        self._topo_order = list(nx.topological_sort(dag))
        self._predecessors = {node: list(dag.predecessors(node)) for node in self.nodes}

        # Size the aggregation weights for the busiest node (never below 8)
        max_inputs = max(max(dag.in_degree(node) for node in self.nodes), 8)

        # Every node consumes out_channels-wide tensors: source nodes get the
        # projected input and all others get predecessor outputs. The node
        # convolutions therefore map out_channels -> out_channels.
        self.node_ops = nn.ModuleDict()
        for node in self.nodes:
            self.node_ops[str(node)] = NodeOperation(out_channels, out_channels, max_inputs, use_gelu)

        # Sources and sinks; a non-empty DAG always has at least one of each
        self.input_nodes = [node for node in self.nodes if dag.in_degree(node) == 0]
        self.output_nodes = [node for node in self.nodes if dag.out_degree(node) == 0]
        self._input_node_set = set(self.input_nodes)

        # 1x1 projection that brings the module input to out_channels
        self.input_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # With several sinks, their outputs are concatenated and fused by a 1x1 conv
        if len(self.output_nodes) > 1:
            self.output_aggregation = nn.Conv2d(out_channels * len(self.output_nodes), out_channels, kernel_size=1)
        else:
            self.output_aggregation = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the module and return a tensor with ``out_channels`` channels."""
        x_proj = self.input_projection(x)

        # Topological order guarantees every predecessor is computed first
        node_outputs = {}
        for node in self._topo_order:
            if node in self._input_node_set:
                node_inputs = [x_proj]
            else:
                node_inputs = [node_outputs[pred] for pred in self._predecessors[node]]
            node_outputs[node] = self.node_ops[str(node)](node_inputs)

        output_tensors = [node_outputs[node] for node in self.output_nodes]
        if len(output_tensors) == 1:
            return output_tensors[0]

        # Concatenate along the channel axis and fuse back to out_channels
        return self.output_aggregation(torch.cat(output_tensors, dim=1))

class CNNBCN(nn.Module):
    """
    Complete CNNBCN model: a stem convolution, a chain of DAG-wired modules
    built from NetworkX Watts-Strogatz graphs with stride-2 downsampling
    between them, and a pooled linear classifier.
    """

    def __init__(self, num_classes: int = 10, num_modules: int = 4, nodes_per_module: int = 36):
        """
        Args:
            num_classes: Number of output logits.
            num_modules: Number of CNNBCN modules to chain.
            nodes_per_module: Node count of each module's random graph.
        """
        super(CNNBCN, self).__init__()

        self.num_modules = num_modules
        self.nodes_per_module = nodes_per_module

        # One random small-world graph per module, turned into a DAG.
        # Original author's note: WS(Z=4, P=0.75) as specified in the paper.
        # A distinct seed per module gives each module its own wiring.
        dag_converter = DAGConverter()
        self.dags = []
        for i in range(num_modules):
            ws_graph = nx.watts_strogatz_graph(n=nodes_per_module, k=4, p=0.75, seed=42+i)
            self.dags.append(dag_converter.convert_to_dag(ws_graph))

        # Channel progression. Original author's note: the paper's simple mode
        # uses 78 channels; a smaller progression is used here for CIFAR-10.
        channels = [32, 64, 128, 256]

        # Width of each module; modules past the end of the table reuse the
        # last width.
        module_channels = [channels[min(i, len(channels) - 1)] for i in range(num_modules)]

        # Input layer
        self.input_conv = nn.Conv2d(3, channels[0], kernel_size=3, padding=1)
        self.input_bn = nn.BatchNorm2d(channels[0])
        self.input_relu = nn.ReLU(inplace=True)

        # CNNBCN modules (each keeps its channel count)
        self.modules_list = nn.ModuleList()
        for i, width in enumerate(module_channels):
            # GeLU for the first two modules, ReLU for the others
            # (original author's note: as per paper)
            use_gelu = i < 2
            self.modules_list.append(CNNBCNModule(self.dags[i], width, width, use_gelu))

        # Stride-2 downsampling between consecutive modules; it also performs
        # the channel change from one module width to the next.
        self.downsample_layers = nn.ModuleList()
        for width_in, width_out in zip(module_channels, module_channels[1:]):
            self.downsample_layers.append(nn.Sequential(
                nn.Conv2d(width_in, width_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(width_out),
                nn.ReLU(inplace=True)
            ))

        # The classifier must match the width actually produced by the last
        # module (or by the stem when there are no modules); using the last
        # table entry unconditionally breaks any model with fewer modules than
        # table entries.
        final_channels = module_channels[-1] if module_channels else channels[0]

        # Global average pooling and classifier
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(final_channels, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)`` for a 3-channel image batch."""
        # Stem
        x = self.input_conv(x)
        x = self.input_bn(x)
        x = self.input_relu(x)

        # Modules, with a downsample after every module except the last
        for i, module in enumerate(self.modules_list):
            x = module(x)
            if i < len(self.downsample_layers):
                x = self.downsample_layers[i](x)

        # Pool to one value per channel, flatten, classify
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

def get_cifar10_dataloaders(batch_size: int = 64, num_workers: int = 2):
    """Build the CIFAR-10 training and test data loaders.

    The datasets are stored under ``./data`` and downloaded when missing.
    Training images get a random horizontal flip and a padded random 32x32
    crop; both splits are converted to tensors and normalized per channel.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Worker process count used by both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    def to_normalized_tensor():
        # Steps shared by both splits; fresh transform objects on every call
        return [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]

    # Augmentation is applied to the training split only
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        *to_normalized_tensor(),
    ])
    test_transform = transforms.Compose(to_normalized_tensor())

    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=train_transform
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=test_transform
    )

    loader_options = dict(batch_size=batch_size, num_workers=num_workers)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

    return train_loader, test_loader

def train_model(model, train_loader, test_loader, num_epochs=15, lr=0.001):
    """Fit ``model`` on ``train_loader`` and score it on ``test_loader`` after every epoch.

    Uses cross-entropy loss, Adam (weight decay 1e-4) and a StepLR schedule
    that multiplies the learning rate by 0.1 every 7 epochs. The model is
    moved to CUDA when available, otherwise it stays on the CPU. Progress is
    printed every 100 batches and at the end of each epoch.

    Args:
        model: Module mapping a batch of inputs to class logits.
        train_loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        test_loader: Iterable of ``(data, target)`` batches.
        num_epochs: Number of passes over ``train_loader``.
        lr: Initial Adam learning rate.

    Returns:
        ``(train_losses, train_accuracies, test_accuracies)``: three lists
        with one entry per epoch (mean batch loss, and accuracies in percent).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Per-epoch history returned to the caller
    train_losses, train_accuracies, test_accuracies = [], [], []

    parameter_count = sum(p.numel() for p in model.parameters())
    print(f"Training on device: {device}")
    print(f"Model parameters: {parameter_count:,}")

    def count_hits(logits, labels):
        # Number of samples whose highest logit is the true class
        return torch.max(logits, 1)[1].eq(labels).sum().item()

    for epoch_number in range(1, num_epochs + 1):
        # ---- optimisation pass over the training data ----
        model.train()
        loss_sum, hits, seen = 0.0, 0, 0

        for step, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = criterion(logits, labels)
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.item()
            seen += labels.size(0)
            hits += count_hits(logits, labels)

            if step % 100 == 0:
                print(f'Epoch: {epoch_number}/{num_epochs}, Batch: {step}, '
                      f'Loss: {batch_loss.item():.4f}')

        epoch_loss = loss_sum / len(train_loader)
        train_acc = 100. * hits / seen

        # ---- evaluation pass, no gradients ----
        model.eval()
        test_hits, test_seen = 0, 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                logits = model(inputs)
                test_seen += labels.size(0)
                test_hits += count_hits(logits, labels)

        test_acc = 100. * test_hits / test_seen

        # Advance the learning-rate schedule once per epoch
        scheduler.step()

        train_losses.append(epoch_loss)
        train_accuracies.append(train_acc)
        test_accuracies.append(test_acc)

        print(f'Epoch {epoch_number}/{num_epochs}:')
        print(f'  Train Loss: {epoch_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'  Test Acc: {test_acc:.2f}%')
        print('-' * 50)

    return train_losses, train_accuracies, test_accuracies

def visualize_watts_strogatz_graph(n=32, k=4, p=0.75, seed=42):
    """Plot a Watts-Strogatz graph next to its DAG conversion and print graph statistics.

    Args:
        n: Number of nodes.
        k: Number of nearest neighbours each node is joined to in the ring.
        p: Rewiring probability.
        seed: Seed passed to the NetworkX graph generator.
    """
    small_world = nx.watts_strogatz_graph(n=n, k=k, p=p, seed=seed)

    # Drawing options common to both panels
    node_style = dict(with_labels=True, node_size=300, font_size=8, font_weight='bold')

    plt.figure(figsize=(12, 5))

    # Left panel: the undirected graph on a circle, which makes the rewired
    # long-range edges easy to see
    plt.subplot(1, 2, 1)
    ring_positions = nx.circular_layout(small_world)
    nx.draw(small_world, ring_positions, node_color='lightblue', **node_style)
    plt.title(f"WS Small-World Network\nNodes: {n}, k: {k}, p: {p}")
    plt.axis('off')

    # Right panel: the same graph after edge orientation, placed with a
    # force-directed (spring) layout and drawn with arrows
    plt.subplot(1, 2, 2)
    dag = DAGConverter().convert_to_dag(small_world)
    spring_positions = nx.spring_layout(dag, k=1, iterations=50)
    nx.draw(dag, spring_positions, node_color='lightcoral',
            arrows=True, arrowsize=10, arrowstyle='->', **node_style)
    plt.title(f"Converted DAG\nNodes: {dag.number_of_nodes()}, Edges: {dag.number_of_edges()}")
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    # Text summary of the undirected graph
    print("Watts-Strogatz Graph Statistics:")
    print(f"  Nodes: {small_world.number_of_nodes()}")
    print(f"  Edges: {small_world.number_of_edges()}")
    print(f"  Average degree: {2 * small_world.number_of_edges() / small_world.number_of_nodes():.2f}")
    print(f"  Clustering coefficient: {nx.average_clustering(small_world):.4f}")

    # The average shortest path length is only defined for a connected graph
    if not nx.is_connected(small_world):
        print("  Graph is not connected - cannot compute average path length")
    else:
        print(f"  Average shortest path length: {nx.average_shortest_path_length(small_world):.4f}")

def analyze_dag_properties(dags):
    """Print size and degree statistics for every DAG in ``dags``.

    For each graph the report lists node and edge counts, the maximum and
    mean in-/out-degree, and how many nodes have no incoming edges (inputs)
    or no outgoing edges (outputs).

    Args:
        dags: Iterable of directed graphs, reported as "Module 1", "Module 2", ...
    """
    print("\nDAG Analysis:")
    print("=" * 30)

    for module_number, dag in enumerate(dags, start=1):
        print(f"\nModule {module_number} DAG:")
        print(f"  Nodes: {dag.number_of_nodes()}")
        print(f"  Edges: {dag.number_of_edges()}")

        # Per-node fan-in and fan-out
        node_list = list(dag.nodes())
        fan_in = [dag.in_degree(node) for node in node_list]
        fan_out = [dag.out_degree(node) for node in node_list]

        print(f"  Max in-degree: {max(fan_in)}")
        print(f"  Max out-degree: {max(fan_out)}")
        print(f"  Avg in-degree: {np.mean(fan_in):.2f}")
        print(f"  Avg out-degree: {np.mean(fan_out):.2f}")

        # Sources have zero fan-in, sinks have zero fan-out
        source_count = sum(1 for degree in fan_in if degree == 0)
        sink_count = sum(1 for degree in fan_out if degree == 0)

        print(f"  Input nodes: {source_count}")
        print(f"  Output nodes: {sink_count}")

def main():
    """Build the CNNBCN model, report its structure, smoke-test it and train it on CIFAR-10.

    Steps: create the model and print its DAG statistics and parameter counts,
    run one forward pass on random input, load the CIFAR-10 loaders, train for
    15 epochs and print the best and final accuracies.
    """

    print("CNNBCN Model Implementation for CIFAR-10")
    print("Using NetworkX Watts-Strogatz Graph Generator")
    print("=" * 50)

    # 1. Demonstrate NetworkX Watts-Strogatz Graph Generation
    print("\n1. Generating NetworkX Watts-Strogatz Small-World Network...")

    # Visualize the generated graph (optional, comment out if no display available)
    # visualize_watts_strogatz_graph(n=32, k=4, p=0.75, seed=42)

    # 2. Create and display model architecture
    print("\n2. Creating CNNBCN Model...")
    model = CNNBCN(num_classes=10, num_modules=4, nodes_per_module=36)

    # Analyze DAG properties
    analyze_dag_properties(model.dags)

    # Print model summary
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\nModel Architecture:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Number of modules: {model.num_modules}")
    print(f"  Nodes per module: {model.nodes_per_module}")

    # 3. Test with dummy input
    print("\n3. Testing model with dummy input...")
    dummy_input = torch.randn(2, 3, 32, 32)  # Batch size 2, RGB, 32x32
    # Run the smoke test in eval mode: in training mode the BatchNorm layers
    # would fold this random batch into their running statistics and dropout
    # would be active, polluting the model before real training starts.
    model.eval()
    with torch.no_grad():
        output = model(dummy_input)
    model.train()  # back to the mode a freshly constructed module is in
    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
    print("Model forward pass successful!")

    # 4. Load CIFAR-10 data
    print("\n4. Loading CIFAR-10 dataset...")
    train_loader, test_loader = get_cifar10_dataloaders(batch_size=64)
    print(f"Training batches: {len(train_loader)}")
    print(f"Test batches: {len(test_loader)}")

    # 5. Train the model (reduced epochs for demonstration)
    print("\n5. Training CNNBCN model...")
    train_losses, train_accs, test_accs = train_model(
        model, train_loader, test_loader, num_epochs=15, lr=0.001
    )

    # 6. Display final results
    print("\n6. Final Results:")
    print(f"Best training accuracy: {max(train_accs):.2f}%")
    print(f"Best test accuracy: {max(test_accs):.2f}%")
    print(f"Final test accuracy: {test_accs[-1]:.2f}%")

    print("\nTraining completed!")

# Run the demo only when executed as a script, not when imported
if __name__ == "__main__":
    main()

CNNBCN Model Implementation for CIFAR-10
Using NetworkX Watts-Strogatz Graph Generator

1. Generating NetworkX Watts-Strogatz Small-World Network...

2. Creating CNNBCN Model...

DAG Analysis:

Module 1 DAG:
  Nodes: 36
  Edges: 72
  Max in-degree: 7
  Max out-degree: 6
  Avg in-degree: 2.00
  Avg out-degree: 2.00
  Input nodes: 7
  Output nodes: 4

Module 2 DAG:
  Nodes: 36
  Edges: 72
  Max in-degree: 5
  Max out-degree: 5
  Avg in-degree: 2.00
  Avg out-degree: 2.00
  Input nodes: 5
  Output nodes: 4

Module 3 DAG:
  Nodes: 36
  Edges: 72
  Max in-degree: 5
  Max out-degree: 6
  Avg in-degree: 2.00
  Avg out-degree: 2.00
  Input nodes: 6
  Output nodes: 4

Module 4 DAG:
  Nodes: 36
  Edges: 72
  Max in-degree: 7
  Max out-degree: 7
  Avg in-degree: 2.00
  Avg out-degree: 2.00
  Input nodes: 8
  Output nodes: 7

Model Architecture:
  Total parameters: 29,278,666
  Trainable parameters: 29,278,666
  Number of modules: 4
  Nodes per module: 36

3. Testing model with dummy input...
Inpu

100%|██████████| 170M/170M [00:03<00:00, 49.1MB/s]


Training batches: 782
Test batches: 157

5. Training CNNBCN model...
Training on device: cuda
Model parameters: 29,278,666
Epoch: 1/15, Batch: 0, Loss: 2.2908
Epoch: 1/15, Batch: 100, Loss: 2.0125
Epoch: 1/15, Batch: 200, Loss: 1.9656
Epoch: 1/15, Batch: 300, Loss: 1.9773
Epoch: 1/15, Batch: 400, Loss: 1.8446
Epoch: 1/15, Batch: 500, Loss: 1.8207
Epoch: 1/15, Batch: 600, Loss: 1.9157
Epoch: 1/15, Batch: 700, Loss: 1.7288
Epoch 1/15:
  Train Loss: 1.9322, Train Acc: 24.05%
  Test Acc: 32.49%
--------------------------------------------------
Epoch: 2/15, Batch: 0, Loss: 1.7979
Epoch: 2/15, Batch: 100, Loss: 1.6052
Epoch: 2/15, Batch: 200, Loss: 1.6842
Epoch: 2/15, Batch: 300, Loss: 1.7130
Epoch: 2/15, Batch: 400, Loss: 1.4952
Epoch: 2/15, Batch: 500, Loss: 1.5868
Epoch: 2/15, Batch: 600, Loss: 1.7410
Epoch: 2/15, Batch: 700, Loss: 1.4865
Epoch 2/15:
  Train Loss: 1.5889, Train Acc: 38.89%
  Test Acc: 42.86%
--------------------------------------------------
Epoch: 3/15, Batch: 0, Loss: 