In [None]:
!pip install "paho-mqtt<2.0.0"


In [None]:
!pip install wfdb

In [None]:
import os
import threading
import time
from collections import deque

import numpy as np
import paho.mqtt.client as mqtt
import wfdb
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score

# One-time dataset fetch: when no local "mitdb" directory exists yet, ask wfdb
# to download the 'mitdb' database into it; otherwise reuse what is on disk.
_mitdb_dir = "mitdb"
_mitdb_present = os.path.isdir(_mitdb_dir)
if not _mitdb_present:
    wfdb.dl_database('mitdb', _mitdb_dir)

class DataLoader:
    """Loads per-minute heartbeat counts from WFDB records."""

    # Annotation symbols that denote an actual beat in the PhysioNet/WFDB
    # annotation scheme. An 'atr' file also contains non-beat annotations
    # (rhythm changes, signal-quality markers, comments, ...); counting those
    # as heartbeats would inflate the per-minute totals.
    _BEAT_SYMBOLS = frozenset("NLRBAaJSVrFejnE/fQ?")

    @staticmethod
    def load_data(record_name, sampto=300000):
        """
        Loads ECG heartbeat annotation data from a WFDB record.

        The first signal channel is read only to determine how many whole
        minutes the loaded excerpt spans; a trailing partial minute is
        discarded. Beats are taken from the record's 'atr' annotation file,
        keeping only annotations whose symbol is a beat symbol.

        Args:
            record_name (str): Name of the record to load.
            sampto (int): Sample index at which to stop reading; forwarded to
                both ``wfdb.rdrecord`` and ``wfdb.rdann``. Defaults to 300000.

        Returns:
            np.array: Column vector of shape (num_minutes, 1) holding the
            number of beats in each full minute. An empty array of shape
            (0, 1) is returned if anything goes wrong while loading.
        """
        try:
            print(f"Loading record: {record_name}")
            record = wfdb.rdrecord(record_name, sampto=sampto)
            annotation = wfdb.rdann(record_name, 'atr', sampto=sampto)

            # Get ECG signals and annotations
            ecg_signal = record.p_signal[:, 0]
            annotations = annotation.sample

            print(f"ECG signal length: {len(ecg_signal)}")
            print(f"Number of annotations: {len(annotations)}")

            # Keep only the annotations that mark a beat.
            is_beat = np.array(
                [symbol in DataLoader._BEAT_SYMBOLS for symbol in annotation.symbol],
                dtype=bool,
            )
            beat_samples = annotations[is_beat]

            # Number of whole one-minute windows covered by the signal.
            fs = record.fs  # Sampling frequency
            samples_per_minute = int(fs * 60)
            num_minutes = len(ecg_signal) // samples_per_minute

            # Map every beat to the minute it falls in and tally per minute.
            # Beats beyond the last whole minute are dropped.
            minute_index = beat_samples // samples_per_minute
            in_range = (minute_index >= 0) & (minute_index < num_minutes)
            heartbeat_counts = np.bincount(
                minute_index[in_range], minlength=num_minutes
            )

            print(f"Heartbeat counts: {heartbeat_counts.tolist()}")

            return heartbeat_counts.reshape(-1, 1)
        except Exception as e:
            # Deliberately best-effort: callers test ``data.size == 0`` instead
            # of handling exceptions, so report the problem and return an
            # empty column vector.
            print(f"Error loading data: {e}")
            return np.array([]).reshape(-1, 1)

class ECGClassifier:
    """Small wrapper around a scikit-learn SVM classifier."""

    def __init__(self):
        """
        Initializes an ECGClassifier using Support Vector Machine (SVM).
        """
        self.model = svm.SVC()

    def train(self, X_train, y_train):
        """
        Trains the SVM classifier.

        Args:
            X_train (np.array): Training data.
            y_train (np.array): Training labels. May be 1-D or a column
                vector of shape (n, 1); it is flattened before fitting.
        """
        # SVC.fit expects 1-D targets; a column vector triggers a
        # DataConversionWarning, so flatten the labels explicitly.
        self.model.fit(X_train, np.ravel(y_train))

    def evaluate(self, X_test, y_test):
        """
        Evaluates the trained classifier on test data.

        Prints the accuracy and a per-class classification report.

        Args:
            X_test (np.array): Test data.
            y_test (np.array): Test labels. May be 1-D or a column vector of
                shape (n, 1); it is flattened before scoring.

        Returns:
            float: The accuracy that was printed.
        """
        y_true = np.ravel(y_test)
        y_pred = self.model.predict(X_test)
        accuracy = accuracy_score(y_true, y_pred)
        print(f"Accuracy: {accuracy}")
        print(classification_report(y_true, y_pred))
        return accuracy

class BridgeMQTTClient(mqtt.Client):
    """
    MQTT client meant to sit on one side of a two-broker bridge.

    Whether a client subscribes and handles messages depends on its name and
    the configured bridge mode (see ``_BRIDGE_ROUTES``). Note that received
    messages are currently only printed; nothing is published to the other
    broker yet.
    """

    # client_name -> (bridge modes in which that client is active,
    #                 label of the broker its messages are destined for)
    _BRIDGE_ROUTES = {
        "bridge_c1": (("in", "both"), "broker2"),
        "bridge_c2": (("out", "both"), "broker1"),
    }

    def __init__(self, client_name, broker_address, port, topic, bridge_mode):
        """
        Initializes a MQTT client for bridging messages between brokers.

        Args:
            client_name (str): Name of the MQTT client.
            broker_address (str): Address of the MQTT broker.
            port (int): Port number of the MQTT broker.
            topic (str): Topic to subscribe/publish messages.
            bridge_mode (str): Bridge mode ('in', 'out', or 'both').
        """
        super().__init__()
        self.client_name = client_name
        self.broker_address = broker_address
        self.port = port
        self.topic = topic
        self.bridge_mode = bridge_mode
        # Bounded buffer reserved for message bookkeeping; nothing in this
        # class reads or writes it yet.
        self.processed_messages = deque(maxlen=100)

        self.on_connect = self.on_connect_callback
        self.on_message = self.on_message_callback

    def _forward_target(self):
        """
        Returns the label of the broker this client forwards to, or None when
        the client's name/bridge-mode combination makes it inactive.
        """
        route = self._BRIDGE_ROUTES.get(self.client_name)
        if route is None:
            return None
        active_modes, target = route
        return target if self.bridge_mode in active_modes else None

    def connect_to_broker(self):
        """
        Connects to the MQTT broker and starts the client loop.
        """
        self.connect(self.broker_address, self.port)
        self.loop_start()

    def on_connect_callback(self, client, userdata, flags, rc):
        """
        Callback function called when the MQTT client connects to the broker.

        Subscribes to ``self.topic`` only if this client is active for the
        configured bridge mode.

        Args:
            client (mqtt.Client): The MQTT client instance.
            userdata: User data passed during connection.
            flags: Response flags sent by the broker.
            rc (int): Result code of the connection attempt.
        """
        if rc != 0:
            print(f"Connection failed with code {rc} for {self.client_name}")
            return

        print(f"{self.client_name} connected to broker")
        if self._forward_target() is not None:
            self.subscribe(self.topic)

    def on_message_callback(self, client, userdata, msg):
        """
        Callback function called when a message is received from the broker.

        Args:
            client (mqtt.Client): The MQTT client instance.
            userdata: User data passed during message reception.
            msg (mqtt.MQTTMessage): The received message object.
        """
        # Payloads are arbitrary bytes and need not be valid UTF-8. Replace
        # undecodable bytes instead of raising: an uncaught exception inside
        # a paho callback can bring down the background network loop.
        m_decode = msg.payload.decode("utf-8", errors="replace")
        print(f"Received message from {self.client_name}: {m_decode}")

        # Example logic: process and forward message
        # Your actual logic here may vary
        target = self._forward_target()
        if target is not None:
            print(f"Forwarding message to {target}: {m_decode}")

# Main script: load per-minute heartbeat counts, train/evaluate the SVM on
# them, then start the two MQTT bridge clients and idle until interrupted.
if __name__ == "__main__":
    # Initialize data loader and load data
    data_loader = DataLoader()
    data = data_loader.load_data("mitdb/100")

    if data.size == 0:
        print("No data loaded. Please check the record name and path.")
    else:
        # Preprocess data: counts of 70-75 per minute are labelled normal.
        # Flatten to 1-D because scikit-learn expects targets of shape (n,),
        # while `data` is a column vector.
        labels = np.where((data >= 70) & (data <= 75), 0, 1).ravel()  # 0: Normal, 1: Abnormal

        if len(data) < 10:
            print("Not enough data to split into training and testing sets.")
        else:
            # Split data into training and testing sets
            X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.4, random_state=42)

            # Initialize and train the classifier
            classifier = ECGClassifier()
            classifier.train(X_train, y_train)

            # Evaluate the classifier
            classifier.evaluate(X_test, y_test)

            # Initialize MQTT bridge clients
            broker1_address = "broker.emqx.io"
            broker2_address = "test.mosquitto.org"
            port = 1883
            topic = "patient/+/heartbeat"
            bridge_mode = "in"  # Example, adjust as needed

            # Every client is registered here as soon as it exists, so the
            # cleanup below also runs when a later connect attempt fails
            # (otherwise an already-started loop thread would be left behind).
            started_clients = []
            try:
                bridge_client1 = BridgeMQTTClient("bridge_c1", broker1_address, port, topic, bridge_mode)
                started_clients.append(bridge_client1)
                bridge_client1.connect_to_broker()

                bridge_client2 = BridgeMQTTClient("bridge_c2", broker2_address, port, topic, bridge_mode)
                started_clients.append(bridge_client2)
                bridge_client2.connect_to_broker()

                # Keep the main thread running
                while True:
                    time.sleep(4)
            except KeyboardInterrupt:
                print("Interrupted")
            finally:
                print("Stopping clients...")
                # Same order as before: stop all loops first, then disconnect.
                for bridge_client in started_clients:
                    bridge_client.loop_stop()
                for bridge_client in started_clients:
                    bridge_client.disconnect()
