In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import networkx as nx
import random
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple
import math
import time
import os

# Seed PyTorch, NumPy and Python's `random` with one shared value so runs are repeatable
_SEED = 42
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(_SEED)
del _seed_rng

class DAGConverter:
    """Orient an undirected graph into a Directed Acyclic Graph.

    Each edge is pointed from its lower-labelled endpoint to its
    higher-labelled one. Node labels therefore act as a topological order,
    so no cycle can form (the index-ordering scheme described in the paper).
    """

    def __init__(self):
        # Stateless: the class only groups the conversion routine.
        pass

    def convert_to_dag(self, graph: nx.Graph) -> nx.DiGraph:
        """
        Return a new ``nx.DiGraph`` with the same node set as *graph*.

        Each undirected edge ``{u, v}`` becomes the arc from the smaller label
        to the larger one. Node labels must be mutually comparable with ``<``.
        """
        oriented = nx.DiGraph()
        # Copy nodes explicitly so isolated nodes are kept as well
        oriented.add_nodes_from(graph.nodes())
        oriented.add_edges_from(
            (a, b) if a < b else (b, a) for a, b in graph.edges()
        )
        return oriented

class NodeOperation(nn.Module):
    """
    Neural network node operation as described in the paper.

    Performs: Aggregation -> Transformation -> Distribution

      1. Aggregation: normalised weighted sum of the incoming feature maps.
         There is one learnable logit per input position; ``sigmoid`` makes
         it positive and the weights are rescaled to sum to 1.
      2. Transformation: 3x3 conv (padding 1, stride 1, so spatial size is
         kept) -> BatchNorm -> GELU or ReLU.
      3. Distribution: the returned tensor is passed on to successor nodes by
         the caller.

    Args:
        in_channels: channel count of every incoming tensor.
        out_channels: channel count produced by the node.
        max_inputs: number of aggregation weights, i.e. the most inputs the
            node can combine.
        use_gelu: use GELU instead of (in-place) ReLU as the activation.
    """

    def __init__(self, in_channels: int, out_channels: int, max_inputs: int = 15, use_gelu: bool = False):
        super(NodeOperation, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.max_inputs = max_inputs

        # One aggregation logit per input position (made positive via sigmoid in forward)
        self.aggregation_weights = nn.Parameter(torch.randn(max_inputs))

        # Transformation operations
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.batch_norm = nn.BatchNorm2d(out_channels)

        # Activation function (GeLU for first modules, ReLU for others as per paper)
        self.activation = nn.GELU() if use_gelu else nn.ReLU(inplace=True)

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """
        Aggregate *inputs*, then apply conv -> batch norm -> activation.

        Args:
            inputs: non-empty list of feature maps with ``in_channels``
                channels; shapes must be mutually broadcastable.

        Returns:
            The transformed feature map with ``out_channels`` channels.

        Raises:
            ValueError: if *inputs* is empty, or if it holds several tensors
                but more than ``max_inputs`` of them (there would be no
                weight for the surplus, so it is rejected rather than
                silently ignored).
        """
        if not inputs:
            raise ValueError("Node must have at least one input")

        if len(inputs) == 1:
            # A single normalised weight is exactly 1, so skip the arithmetic
            aggregated = inputs[0]
        else:
            if len(inputs) > self.max_inputs:
                raise ValueError(
                    f"Node received {len(inputs)} inputs but only has "
                    f"{self.max_inputs} aggregation weights"
                )
            # Positive weights via sigmoid, normalised to a convex combination
            weights = torch.sigmoid(self.aggregation_weights[:len(inputs)])
            weights = weights / weights.sum()
            aggregated = sum(w * inp for w, inp in zip(weights, inputs))

        x = self.conv(aggregated)
        x = self.batch_norm(x)
        return self.activation(x)

class CNNBCNModule(nn.Module):
    """
    Single CNNBCN module generated from a DAG.

    Data flow:
      - The module input is projected to ``out_channels`` by a 1x1 conv and
        fed to every source node (in-degree 0).
      - Every other node aggregates the outputs of its predecessors through
        its own ``NodeOperation``.
      - Sink nodes (out-degree 0) form the module output. When there are
        several sinks, their outputs are concatenated and fused back to
        ``out_channels`` with a 1x1 conv.

    Args:
        dag: directed acyclic graph whose nodes become ``NodeOperation``s.
        in_channels: channels of the tensor passed to ``forward``.
        out_channels: channels used inside every node and by the output.
        use_gelu: forwarded to each ``NodeOperation``.
    """

    def __init__(self, dag: nx.DiGraph, in_channels: int, out_channels: int, use_gelu: bool = False):
        super(CNNBCNModule, self).__init__()
        self.dag = dag
        self.nodes = list(dag.nodes())
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Size the aggregation weights for the busiest node, but never below 15
        max_inputs = max([dag.in_degree(node) for node in self.nodes] + [1])
        max_inputs = max(max_inputs, 15)

        # Every node consumes either the projected input or another node's
        # output. Both carry ``out_channels`` channels, so each node conv must
        # take ``out_channels`` in (using ``in_channels`` breaks whenever the
        # two differ).
        self.node_ops = nn.ModuleDict()
        for node in self.nodes:
            self.node_ops[str(node)] = NodeOperation(out_channels, out_channels, max_inputs, use_gelu)

        # Source nodes receive the module input; sink nodes produce its output
        self.input_nodes = [node for node in self.nodes if dag.in_degree(node) == 0]
        self.output_nodes = [node for node in self.nodes if dag.out_degree(node) == 0]

        # Defensive fallback: use the smallest / largest node label
        if not self.input_nodes:
            self.input_nodes = [min(self.nodes)]
        if not self.output_nodes:
            self.output_nodes = [max(self.nodes)]

        # 1x1 conv mapping the module input to the node channel width
        self.input_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # 1x1 fusion conv, only needed when several sinks are concatenated
        if len(self.output_nodes) > 1:
            self.output_aggregation = nn.Conv2d(out_channels * len(self.output_nodes), out_channels, kernel_size=1)
        else:
            self.output_aggregation = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the module on *x* (``in_channels`` channels).

        Returns a tensor with ``out_channels`` channels. Every layer used here
        keeps the spatial size unchanged.
        """
        x_proj = self.input_projection(x)

        node_outputs = {}

        # Topological order guarantees each node's predecessors are computed first
        for node in nx.topological_sort(self.dag):
            if node in self.input_nodes:
                node_inputs = [x_proj]
            else:
                node_inputs = [node_outputs[pred] for pred in self.dag.predecessors(node)
                               if pred in node_outputs]

            if node_inputs:
                node_outputs[node] = self.node_ops[str(node)](node_inputs)

        output_tensors = [node_outputs[node] for node in self.output_nodes if node in node_outputs]

        if not output_tensors:
            # Fallback: use the most recently computed node output
            output_tensors = [list(node_outputs.values())[-1]]

        if len(output_tensors) == 1:
            return output_tensors[0]

        # Several sinks: stack along channels and fuse back to out_channels
        return self.output_aggregation(torch.cat(output_tensors, dim=1))

class CNNBCN(nn.Module):
    """
    Complete CNNBCN model with Newman-Watts small-world modules.

    Pipeline:
      - 3x3 stem conv / BN / ReLU;
      - CNNBCN modules separated by stride-2 downsampling blocks;
      - global average pooling -> dropout -> linear classifier.

    Module ``i`` runs at ``[32, 64, 128, 256][i]`` channels; modules past the
    end of that list reuse 256.

    Args:
        num_classes: number of output logits.
        num_modules: number of CNNBCN modules, each built from its own graph.
        nodes_per_module: node count of every Newman-Watts graph.
        k: ring-lattice neighbourhood size passed to
            ``nx.newman_watts_strogatz_graph`` (default 4, as before).
        p: shortcut probability passed to the same generator (default 0.1).
    """

    def __init__(self, num_classes: int = 10, num_modules: int = 4, nodes_per_module: int = 36,
                 k: int = 4, p: float = 0.1):
        super(CNNBCN, self).__init__()

        self.num_modules = num_modules
        self.nodes_per_module = nodes_per_module

        # Generate random graphs and convert to DAGs for each module
        self.dags = []
        self.graph_stats = []
        dag_converter = DAGConverter()

        for i in range(num_modules):
            graph = nx.newman_watts_strogatz_graph(
                n=nodes_per_module,
                k=k,
                p=p,
                seed=42 + i  # Different (reproducible) seed for each module
            )

            # NetworkX joins each node to its k // 2 nearest neighbours on each
            # side, so the ring lattice has n * (k // 2) edges. This equals
            # n * k / 2 only for even k; the remaining edges are random shortcuts.
            initial_edges = nodes_per_module * (k // 2)
            random_edges = graph.number_of_edges() - initial_edges
            self.graph_stats.append({
                'algorithm': 'Newman-Watts',
                'original_edges': initial_edges,
                'random_edges': random_edges,
                'total_edges': graph.number_of_edges(),
                'nodes': nodes_per_module
            })
            self.dags.append(dag_converter.convert_to_dag(graph))

        # Channel progression
        channels = [32, 64, 128, 256]

        def stage_channels(idx: int) -> int:
            # Width of module ``idx``; stages past the list reuse the widest one
            return channels[min(idx, len(channels) - 1)]

        # Input layer
        self.input_conv = nn.Conv2d(3, channels[0], kernel_size=3, padding=1)
        self.input_bn = nn.BatchNorm2d(channels[0])
        self.input_relu = nn.ReLU(inplace=True)

        # CNNBCN modules: GeLU for the first two, ReLU for the others
        self.modules_list = nn.ModuleList()
        for i in range(num_modules):
            width = stage_channels(i)
            self.modules_list.append(CNNBCNModule(self.dags[i], width, width, use_gelu=i < 2))

        # Stride-2 downsampling between consecutive modules (widens channels too)
        self.downsample_layers = nn.ModuleList()
        for i in range(num_modules - 1):
            c_in, c_out = stage_channels(i), stage_channels(i + 1)
            self.downsample_layers.append(nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True)
            ))

        # The classifier must match the width of the last module (or of the
        # stem when there are no modules). A fixed channels[-1] would break
        # any model with fewer than four modules.
        final_channels = stage_channels(num_modules - 1) if num_modules > 0 else channels[0]

        # Global average pooling and classifier
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(final_channels, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of 3-channel images to ``num_classes`` logits per sample."""
        x = self.input_conv(x)
        x = self.input_bn(x)
        x = self.input_relu(x)

        # Each module is followed by a downsampling block, except the last one
        for i, module in enumerate(self.modules_list):
            x = module(x)
            if i < len(self.downsample_layers):
                x = self.downsample_layers[i](x)

        # Global average pooling and classification
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

def get_cifar10_dataloaders(batch_size: int = 64, num_workers: int = 2):
    """
    Build CIFAR-10 train/test ``DataLoader``s, downloading to ``./data`` if needed.

    Transforms:
      - Training images get a random horizontal flip and a 32x32 random crop
        with 4-pixel padding.
      - Both splits are converted to tensors and normalised with the same
        per-channel mean/std.

    Only the training loader shuffles.

    Returns:
        ``(train_loader, test_loader)``
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    def _pipeline(*augmentations):
        # Optional augmentations first, then tensor conversion + normalisation
        return transforms.Compose(
            [*augmentations, transforms.ToTensor(), transforms.Normalize(mean, std)]
        )

    train_tf = _pipeline(transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4))
    test_tf = _pipeline()

    loaders = []
    for is_train, tf in ((True, train_tf), (False, test_tf)):
        dataset = torchvision.datasets.CIFAR10(
            root='./data', train=is_train, download=True, transform=tf
        )
        loaders.append(DataLoader(
            dataset, batch_size=batch_size, shuffle=is_train, num_workers=num_workers
        ))

    return tuple(loaders)

def train_model(model, train_loader, test_loader, num_epochs=10, lr=0.001):
    """
    Train *model* with Adam + cross-entropy, evaluating on *test_loader* after every epoch.

    Setup:
      - runs on CUDA when available, otherwise CPU;
      - Adam with weight decay 1e-4;
      - StepLR schedule multiplying the learning rate by 0.1 every 7 epochs.

    Progress is printed every 200 training batches and at the end of each epoch.

    Returns:
        ``(train_losses, train_accuracies, test_accuracies)`` with one entry
        per epoch. Losses are means over batches; accuracies are percentages.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # Decay the learning rate 10x every 7 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    loss_history, train_acc_history, test_acc_history = [], [], []

    print(f"Training on device: {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    for epoch in range(1, num_epochs + 1):
        tic = time.time()

        # -- optimisation pass over the training set --
        model.train()
        loss_sum, n_correct, n_seen = 0.0, 0, 0
        for step, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = criterion(logits, labels)
            batch_loss.backward()
            optimizer.step()

            step_loss = batch_loss.item()
            loss_sum += step_loss
            _, guesses = logits.max(1)
            n_seen += labels.size(0)
            n_correct += guesses.eq(labels).sum().item()

            if step % 200 == 0:
                print(f'Epoch: {epoch}/{num_epochs}, Batch: {step}/{len(train_loader)}, '
                      f'Loss: {step_loss:.4f}')

        mean_train_loss = loss_sum / len(train_loader)
        train_acc = 100. * n_correct / n_seen

        # -- evaluation pass, gradients disabled --
        model.eval()
        eval_loss_sum, eval_correct, eval_seen = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                logits = model(inputs)
                eval_loss_sum += criterion(logits, labels).item()
                _, guesses = logits.max(1)
                eval_seen += labels.size(0)
                eval_correct += guesses.eq(labels).sum().item()

        test_acc = 100. * eval_correct / eval_seen
        mean_test_loss = eval_loss_sum / len(test_loader)

        scheduler.step()

        loss_history.append(mean_train_loss)
        train_acc_history.append(train_acc)
        test_acc_history.append(test_acc)

        elapsed = time.time() - tic

        print(f'Epoch {epoch}/{num_epochs} ({elapsed:.1f}s):')
        print(f'  Train Loss: {mean_train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'  Test Loss: {mean_test_loss:.4f}, Test Acc: {test_acc:.2f}%')
        print(f'  Learning Rate: {scheduler.get_last_lr()[0]:.6f}')
        print('-' * 60)

    return loss_history, train_acc_history, test_acc_history

def analyze_newman_watts_properties(n_nodes=36, k=4, p=0.1, seed=42):
    """
    Generate one Newman-Watts small-world graph with NetworkX and print its statistics.

    Reported statistics:
      - ring-lattice, random and total edge counts;
      - average degree;
      - average clustering coefficient;
      - average shortest path length.

    Args:
        n_nodes: number of nodes.
        k: ring-lattice neighbourhood size passed to
            ``nx.newman_watts_strogatz_graph``.
        p: shortcut probability passed to the same generator.
        seed: generator seed (default 42, the previously hard-coded value).

    Returns:
        The generated ``nx.Graph``.
    """
    print("Newman-Watts Small-World Network Analysis (NetworkX)")
    print("=" * 50)

    graph = nx.newman_watts_strogatz_graph(n=n_nodes, k=k, p=p, seed=seed)

    # NetworkX links every node to its k // 2 nearest neighbours on each side,
    # so (for k < n) the ring lattice has n * (k // 2) edges. (n * k) // 2 is
    # only correct for even k.
    initial_edges = n_nodes * (k // 2)
    random_edges = graph.number_of_edges() - initial_edges

    print(f"\nNewman-Watts (k={k}, p={p}):")
    print(f"  Initial ring lattice edges: {initial_edges}")
    print(f"  Random edges added: {random_edges}")
    print(f"  Total edges: {graph.number_of_edges()}")
    print(f"  Average degree: {2 * graph.number_of_edges() / n_nodes:.2f}")
    print(f"  Clustering coefficient: {nx.average_clustering(graph):.4f}")
    if nx.is_connected(graph):
        print(f"  Average shortest path: {nx.average_shortest_path_length(graph):.4f}")
    else:
        # average_shortest_path_length raises on disconnected graphs (e.g. k < 2)
        print("  Average shortest path: undefined (graph is disconnected)")

    return graph

def analyze_dag_properties(dags, graph_stats):
    """
    Print a per-module summary of the generated DAGs.

    For each ``(dag, stats)`` pair this reports:
      - the original/random edge counts stored in *stats* (``'N/A'`` if a key
        is missing);
      - the node and edge totals;
      - the max and mean in-/out-degree;
      - the source (in-degree 0) and sink (out-degree 0) nodes, listing at
        most the first five of each.
    """
    print("\nDAG Analysis:")
    print("=" * 50)

    def _preview(items):
        # Count, first five entries, and an ellipsis when truncated
        suffix = '...' if len(items) > 5 else ''
        return f"{len(items)} - {items[:5]}{suffix}"

    for module_no, (dag, stats) in enumerate(zip(dags, graph_stats), start=1):
        print(f"\nModule {module_no} DAG (Newman-Watts):")
        print(f"  Original edges: {stats.get('original_edges', 'N/A')}")
        print(f"  Random edges: {stats.get('random_edges', 'N/A')}")
        print(f"  Total nodes: {dag.number_of_nodes()}")
        print(f"  Total edges: {dag.number_of_edges()}")

        node_list = list(dag.nodes())
        in_deg = [dag.in_degree(v) for v in node_list]
        out_deg = [dag.out_degree(v) for v in node_list]

        print(f"  Max in-degree: {max(in_deg)}")
        print(f"  Max out-degree: {max(out_deg)}")
        print(f"  Avg in-degree: {np.mean(in_deg):.2f}")
        print(f"  Avg out-degree: {np.mean(out_deg):.2f}")

        sources = [v for v, d in zip(node_list, in_deg) if d == 0]
        sinks = [v for v, d in zip(node_list, out_deg) if d == 0]

        print(f"  Input nodes: {_preview(sources)}")
        print(f"  Output nodes: {_preview(sinks)}")

def visualize_training_results(train_losses, train_accs, test_accs,
                               title='Newman-Watts CNNBCN Training Results (36 Nodes)'):
    """
    Plot per-epoch training loss (left) and train/test accuracy (right), then show the figure.

    Args:
        train_losses: per-epoch training losses, as returned by ``train_model``.
        train_accs: per-epoch training accuracies (%), as returned by ``train_model``.
        test_accs: per-epoch test accuracies (%), as returned by ``train_model``.
        title: figure super-title (defaults to the previously hard-coded text).
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    def _epochs(series):
        # The training logs number epochs from 1; match that on the x-axis
        # instead of the default 0-based list index.
        return range(1, len(series) + 1)

    # Plot losses
    ax1.plot(_epochs(train_losses), train_losses, label='Training Loss', color='blue')
    ax1.set_title('Training Loss Over Epochs')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)

    # Plot accuracies
    ax2.plot(_epochs(train_accs), train_accs, label='Training Accuracy', color='blue')
    ax2.plot(_epochs(test_accs), test_accs, label='Test Accuracy', color='red')
    ax2.set_title('Accuracy Over Epochs')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True)

    fig.suptitle(title)
    plt.tight_layout()
    plt.show()

def save_model_checkpoint(model, optimizer, epoch, loss, accuracy, filename):
    """
    Serialise a training checkpoint dict to *filename* with ``torch.save``.

    The dict holds:
      - the epoch, loss and accuracy;
      - the model ``state_dict``;
      - the optimizer ``state_dict`` (``None`` when *optimizer* is falsy);
      - the model's ``graph_stats`` attribute (``None`` when it has none).
    """
    optimizer_state = optimizer.state_dict() if optimizer else None
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer_state,
        loss=loss,
        accuracy=accuracy,
        graph_stats=getattr(model, 'graph_stats', None),
    )
    torch.save(checkpoint, filename)
    print(f"Checkpoint saved: {filename}")

def main(num_epochs: int = 15, nodes_per_module: int = 36, batch_size: int = 64, lr: float = 0.001):
    """
    Run the full Newman-Watts CNNBCN experiment on CIFAR-10.

    Steps:
      - print device info and graph/DAG statistics;
      - train a 4-module CNNBCN;
      - plot the learning curves;
      - save a checkpoint;
      - print a final summary.

    The keyword arguments default to the values previously hard-coded here,
    so ``main()`` behaves as before. Each value is now used consistently;
    before, the epoch count was written in two places.

    Returns:
        ``(model, train_losses, train_accs, test_accs)``

    Raises:
        ValueError: if ``num_epochs`` is less than 1 (there would be no
            results to checkpoint or summarise).
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")

    print("CNNBCN: Convolutional Neural Network Based on Complex Networks")
    print(f"Implementation with Newman-Watts Small-World Networks ({nodes_per_module} Nodes)")
    print("=" * 70)

    # Set device info
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

    # Analyze Newman-Watts network properties (printed only; the graph is not reused)
    print("\n" + "=" * 60)
    print("Graph Algorithm Analysis")
    print("=" * 60)
    analyze_newman_watts_properties(n_nodes=nodes_per_module, k=4, p=0.1)

    # Load data
    train_loader, test_loader = get_cifar10_dataloaders(batch_size=batch_size)

    # Create model
    model = CNNBCN(num_classes=10, num_modules=4, nodes_per_module=nodes_per_module)

    # Analyze model
    print("\n" + "=" * 60)
    print("Model Analysis")
    print("=" * 60)
    analyze_dag_properties(model.dags, model.graph_stats)

    # Train model
    print("\n" + "=" * 60)
    print(f"Training Newman-Watts CNNBCN Model ({nodes_per_module} Nodes)")
    print("=" * 60)

    train_losses, train_accs, test_accs = train_model(
        model, train_loader, test_loader, num_epochs=num_epochs, lr=lr
    )

    # Visualize results
    visualize_training_results(train_losses, train_accs, test_accs)

    # Save model checkpoint. train_model does not return its optimizer,
    # so no optimizer state is stored.
    save_model_checkpoint(model, None, num_epochs, train_losses[-1], test_accs[-1],
                          f"cnnbcn_newman_watts_{nodes_per_module}nodes_final.pth")

    # Print final results
    print("\n" + "=" * 60)
    print("Final Results")
    print("=" * 60)
    print(f"Best Test Accuracy: {max(test_accs):.2f}%")
    print(f"Final Test Accuracy: {test_accs[-1]:.2f}%")
    print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")

    return model, train_losses, train_accs, test_accs

if __name__ == "__main__":
    main()

CNNBCN: Convolutional Neural Network Based on Complex Networks
Implementation with Newman-Watts Small-World Networks (36 Nodes)
Using device: cuda
GPU: Tesla T4
Memory: 14.7 GB

Graph Algorithm Analysis
Newman-Watts Small-World Network Analysis (NetworkX)

Newman-Watts (k=4, p=0.1):
  Initial ring lattice edges: 72
  Random edges added: 8
  Total edges: 80
  Average degree: 4.44
  Clustering coefficient: 0.4435
  Average shortest path: 2.9698

Model Analysis

DAG Analysis:

Module 1 DAG (Newman-Watts):
  Original edges: 72
  Random edges: 8
  Total nodes: 36
  Total edges: 80
  Max in-degree: 4
  Max out-degree: 5
  Avg in-degree: 2.22
  Avg out-degree: 2.22
  Input nodes: 1 - [0]
  Output nodes: 1 - [35]

Module 2 DAG (Newman-Watts):
  Original edges: 72
  Random edges: 10
  Total nodes: 36
  Total edges: 82
  Max in-degree: 5
  Max out-degree: 5
  Avg in-degree: 2.28
  Avg out-degree: 2.28
  Input nodes: 1 - [0]
  Output nodes: 1 - [35]

Module 3 DAG (Newman-Watts):
  Original edges: