In [8]:
import json
import numpy as np
from torchvision import datasets, transforms
from kafka import KafkaProducer, KafkaConsumer, TopicPartition
import redis
import time


# Kafka settings read by Neuron.forward (KAFKA_BROKER, OFFSET_RESET). This cell
# never assigned them, so fall back to defaults unless an earlier notebook cell
# already defined them (re-running this cell must not raise NameError).
KAFKA_BROKER = globals().get("KAFKA_BROKER", "kafka:9092")
# Activation messages are produced before each neuron's consumer is created, so
# a consumer group without a committed offset has to start at "earliest";
# with "latest" it starts past its trigger and waits indefinitely.
OFFSET_RESET = globals().get("OFFSET_RESET", "earliest")
offset = "latest"  # legacy name kept for compatibility; not read by this cell


# Redis setup: a single shared client (database 0 on host.docker.internal:6379)
# used below for per-layer outputs and the 'predictions' hash.
redis_client = redis.Redis(
    db=0,
    host='host.docker.internal',
    port=6379,
)

# Activation functions
def relu(x):
    """Rectified linear unit: return max(x, 0) element-wise."""
    floor = 0
    return np.maximum(x, floor)

def softmax(x):
    """Return the softmax of *x*, normalised over all of its elements.

    The maximum is subtracted before exponentiating so large inputs do not
    overflow; the shift cancels out in the ratio and leaves the result unchanged.
    """
    shifted = np.subtract(x, np.max(x))
    exps = np.exp(shifted)
    total = np.sum(exps)
    return exps / total

# Maps the activation names used in the model JSON to their implementations.
ACTIVATIONS = dict(
    relu=relu,
    softmax=softmax,
)

class Neuron:
    """A single neuron whose forward pass waits for a Kafka activation message.

    Each neuron reads its own partition (``neuron_id``) of the topic
    ``layer-<n>``, where ``<n>`` is the text after the last ``_`` in
    ``layer_id`` (e.g. ``"layer_12"`` -> ``"layer-12"``).
    """

    # Where a consumer starts when its group has no committed offset for the
    # partition. The activation message is produced *before* forward() creates
    # its consumer (Layer.forward sends first, then calls forward), so "latest"
    # would start past that message and the neuron would wait forever.
    # Override on the class or an instance if a different policy is needed.
    auto_offset_reset = 'earliest'

    def __init__(self, layer_id, neuron_id, weights, bias, activation, is_final_layer=False):
        self.layer_id = layer_id
        self.neuron_id = neuron_id
        self.weights = np.array(weights)
        self.bias = np.array(bias)
        # Final-layer neurons return the raw pre-activation value; otherwise look
        # up the named activation, falling back to relu for unknown names.
        self.activation_func = None if is_final_layer else ACTIVATIONS.get(activation, relu)
        self.is_final_layer = is_final_layer

    @staticmethod
    def _topic_for(layer_id):
        """Return the Kafka topic for *layer_id* (``"layer_3"`` -> ``"layer-3"``).

        Uses the whole suffix after the last ``_`` so multi-digit layer
        indices (``layer_10`` and up) map to the correct topic.
        """
        return f"layer-{layer_id.split('_')[-1]}"

    def forward(self, inputs):
        """Block until this layer's activation message arrives, then compute the output.

        Args:
            inputs: Input vector, dotted with ``self.weights``.

        Returns:
            ``inputs · weights + bias`` for a final-layer neuron, otherwise that
            value passed through ``self.activation_func``.
        """
        topic = self._topic_for(self.layer_id)
        consumer = KafkaConsumer(
            bootstrap_servers=KAFKA_BROKER,
            value_deserializer=lambda m: json.loads(m.decode('utf-8')),
            auto_offset_reset=self.auto_offset_reset,
            enable_auto_commit=True,
            group_id=f'group-layer-{self.layer_id}',
            # Iteration ends after 120 s without a message; the outer loop retries.
            consumer_timeout_ms=120000,
        )
        try:
            consumer.assign([TopicPartition(topic, self.neuron_id)])

            print(f"Neuron {self.neuron_id} in {self.layer_id} waiting for activation...")

            while True:  # Keep waiting until the matching message is received
                for message in consumer:
                    payload = message.value
                    print(f"Neuron {self.neuron_id} received message: {payload}")
                    # Skip anything that is not a dict addressed to this layer
                    # (e.g. non-object JSON, or payloads for another layer).
                    if isinstance(payload, dict) and payload.get('layer') == self.layer_id:
                        print(f"Neuron {self.neuron_id} in {self.layer_id} activated!")
                        z = np.dot(inputs, self.weights) + self.bias
                        return z if self.is_final_layer else self.activation_func(z)

                print(f"⚠️ Neuron {self.neuron_id} in {self.layer_id} still waiting for activation...")
                time.sleep(2)  # Back off before polling again instead of spinning
        finally:
            # Release the consumer on every exit path: return, error, or KeyboardInterrupt.
            consumer.close()

class Layer:
    """A group of neurons triggered together through one Kafka topic.

    ``layer_id`` is expected to end in ``_<n>`` (e.g. ``"layer_0"``). The text
    after the last ``_`` selects the topic ``layer-<n>`` and numbers the next
    layer announced on the ``activate-layer`` topic.
    """

    # Kafka bootstrap address used by this class's producers; override on the
    # class or an instance to target a different broker.
    bootstrap_servers = 'kafka:9092'

    def __init__(self, layer_id, neurons, is_final_layer=False):
        self.layer_id = layer_id
        self.neurons = neurons
        self.is_final_layer = is_final_layer

    def _new_producer(self, **extra_config):
        """Create a KafkaProducer for ``bootstrap_servers`` that JSON-encodes values."""
        return KafkaProducer(bootstrap_servers=self.bootstrap_servers,
                             value_serializer=lambda v: json.dumps(v).encode('utf-8'),
                             **extra_config)

    def forward(self, input_data):
        """Trigger every neuron via Kafka, collect their outputs, then signal the next layer.

        Sends ``{'layer': layer_id}`` to partition ``i`` of ``layer-<n>`` for
        each neuron ``i``, runs each neuron's ``forward`` on *input_data*,
        stores the outputs (as float32 bytes) in Redis under ``layer_id``, and
        publishes the next layer (or ``'final'``) on ``activate-layer``.

        Args:
            input_data: Input vector passed to every neuron.

        Returns:
            np.ndarray of neuron outputs, softmax-normalised for the final layer.
        """
        # Full suffix after the last '_' so layer_10+ get the right topic/index.
        suffix = self.layer_id.split('_')[-1]
        topic = f'layer-{suffix}'
        activation_message = {'layer': self.layer_id}

        producer = self._new_producer(retries=5, request_timeout_ms=10000)
        try:
            print(f"📤 Layer {self.layer_id} sending activation messages to topic {topic}...")

            # One trigger per neuron, on the partition that neuron consumes.
            for neuron_id in range(len(self.neurons)):
                producer.send(topic, key=str(neuron_id).encode(), value=activation_message, partition=neuron_id)
                print(f"✅ Message sent to {topic}, partition {neuron_id}")

            # Block until all triggers are sent before the neurons start waiting.
            producer.flush()
        finally:
            producer.close()

        outputs = np.array([neuron.forward(input_data) for neuron in self.neurons])
        if self.is_final_layer:
            # Final-layer neurons return raw values; normalise across the layer.
            outputs = softmax(outputs)

        redis_client.set(self.layer_id, outputs.astype(np.float32).tobytes())

        # Announce the next layer (or 'final').
        # NOTE(review): this payload uses a hyphen ('layer-1') while layer ids use
        # an underscore ('layer_1'); confirm what the 'activate-layer' consumer expects.
        next_layer = 'final' if self.is_final_layer else f'layer-{int(suffix) + 1}'
        producer = self._new_producer()
        try:
            producer.send('activate-layer', {'layer': next_layer})
            producer.flush()
        finally:
            producer.close()

        return outputs

# Load JSON file
def load_network(filename):
    """Load the serialized network description from the JSON file *filename*.

    The file is decoded as UTF-8, the encoding JSON requires, rather than the
    platform's locale default, so loading behaves the same on every OS.

    Returns:
        The parsed JSON document. build_network expects a dict keyed by layer
        name.
    """
    with open(filename, encoding='utf-8') as f:
        return json.load(f)

# Build network
def build_network(json_data):
    """Build Layer objects from the parsed model JSON, ordered by layer index.

    Each key of *json_data* is a layer name whose text after the last ``_`` is
    an integer (e.g. ``"layer_0"``). Each value has a ``'nodes'`` list, and
    every node supplies ``'weights'``, ``'biases'`` and ``'activation'``. The
    layer with the largest index is flagged as the final layer, as are its
    neurons.
    """
    def layer_index(name):
        return int(name.split('_')[-1])

    ordered_names = sorted(json_data, key=layer_index)
    final_position = len(ordered_names) - 1

    network_layers = []
    for position, name in enumerate(ordered_names):
        is_last = position == final_position
        node_specs = json_data[name]['nodes']

        layer_neurons = []
        for node_index, node in enumerate(node_specs):
            layer_neurons.append(
                Neuron(
                    layer_id=name,
                    neuron_id=node_index,
                    weights=np.array(node['weights']),
                    bias=np.array(node['biases']),
                    activation=node['activation'],
                    is_final_layer=is_last,
                )
            )

        network_layers.append(Layer(layer_id=name, neurons=layer_neurons, is_final_layer=is_last))
    return network_layers

# Forward pass for single image
def forward_pass(layers, input_data, image_id):
    """Run one input through every layer and record the predicted class.

    Args:
        layers: Layers in execution order, as returned by build_network.
        input_data: Flattened input vector for the first layer.
        image_id: Field under which the prediction is stored in the Redis
            ``'predictions'`` hash.

    Returns:
        int: Index of the largest value in the last layer's output.
    """
    producer = KafkaProducer(bootstrap_servers='kafka:9092',
                             value_serializer=lambda v: json.dumps(v).encode('utf-8'))
    try:
        print("🔥 Sending initial activation message to layer-0...")
        # NOTE(review): neurons only react when payload['layer'] equals their
        # layer_id (e.g. 'layer_0', with an underscore), so this 'layer-0'
        # payload is skipped. The real per-neuron triggers come from Layer.forward.
        producer.send('layer-0', {'layer': 'layer-0'})
        producer.flush()
    finally:
        # Close even if send/flush raises, so the producer's connection is released.
        producer.close()

    for layer in layers:
        input_data = layer.forward(input_data)

    prediction = int(np.argmax(input_data))
    redis_client.hset('predictions', image_id, prediction)  # Store all predictions
    return prediction

# Load network
data = load_network("node_based_model.json")
network = build_network(data)

# Load MNIST test split; Normalize((0.5,), (0.5,)) maps ToTensor's pixel
# values via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
mnist_test = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

NUM_TEST_IMAGES = 10

# Test the first NUM_TEST_IMAGES images one by one, keeping (prediction, label)
# locally so accuracy reflects this run only.
results = {}
for i in range(NUM_TEST_IMAGES):
    image, label = mnist_test[i]
    image_np = image.view(-1).numpy()  # flatten the image tensor to a 1-D vector
    prediction = forward_pass(network, image_np, i)
    results[i] = (prediction, label)
    print(f"Image {i} Prediction: {prediction}, Label: {label}")

# The Redis 'predictions' hash persists across runs and can contain stale
# entries for other image ids, so it is not used for accuracy. It is still read
# here so later cells can use `predictions`.
predictions = redis_client.hgetall('predictions')

# Calculate accuracy for this run (guard against an empty run).
correct = sum(int(pred) == int(lbl) for pred, lbl in results.values())
accuracy = correct / len(results) if results else 0.0
print(f"Test Accuracy: {accuracy * 100:.2f}%")


🔥 Sending initial activation message to layer-0...
📤 Layer layer_0 sending activation messages to topic layer-0...
✅ Message sent to layer-0, partition 0
✅ Message sent to layer-0, partition 1
✅ Message sent to layer-0, partition 2
✅ Message sent to layer-0, partition 3
✅ Message sent to layer-0, partition 4
✅ Message sent to layer-0, partition 5
✅ Message sent to layer-0, partition 6
✅ Message sent to layer-0, partition 7
✅ Message sent to layer-0, partition 8
✅ Message sent to layer-0, partition 9
✅ Message sent to layer-0, partition 10
✅ Message sent to layer-0, partition 11
✅ Message sent to layer-0, partition 12
✅ Message sent to layer-0, partition 13
✅ Message sent to layer-0, partition 14
✅ Message sent to layer-0, partition 15
✅ Message sent to layer-0, partition 16
✅ Message sent to layer-0, partition 17
✅ Message sent to layer-0, partition 18
✅ Message sent to layer-0, partition 19
✅ Message sent to layer-0, partition 20
✅ Message sent to layer-0, partition 21
✅ Message sent 

KeyboardInterrupt: 