In [None]:
!pip install paho-mqtt
!pip install wfdb

import wfdb
import numpy as np
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
import paho.mqtt.client as mqtt
import time
from collections import deque
import os

# Download the dataset
# Fetch the MIT-BIH Arrhythmia Database into ./mitdb on first run only. The mere
# presence of the ./mitdb directory is taken to mean the data is already there.
if not os.path.isdir("mitdb"):
    wfdb.dl_database('mitdb', 'mitdb')
else:
    print('You already have the data.')

# Training
# MIT-BIH / PhysioNet annotation symbols that mark an actual beat. The same
# "atr" file also carries non-beat annotations (e.g. rhythm change "+", noise
# "~", isolated artifact "|"), which must not be counted as heartbeats.
BEAT_SYMBOLS = frozenset("NLRBAaJSVrFejnE/fQ?")


def load_data(record_name, sampto=300000, channel=0):
    """Load a WFDB record and count the heartbeats in each whole minute.

    Args:
        record_name: WFDB record path without extension (e.g. "mitdb/100");
            both the signal record and its "atr" annotation file are read.
        sampto: Read only samples [0, sampto) of the record and annotations.
        channel: Column of ``record.p_signal`` whose length determines how
            many whole minutes are counted.

    Returns:
        A (num_minutes, 1) integer array whose row i is the number of beat
        annotations falling in minute i; a trailing partial minute is dropped.
        On any error the message is printed and an empty (0, 1) array is
        returned instead of raising.
    """
    try:
        print(f"Loading record: {record_name}")
        record = wfdb.rdrecord(record_name, sampto=sampto)
        annotation = wfdb.rdann(record_name, 'atr', sampto=sampto)

        # Get ECG signal and annotation positions (in samples)
        ecg_signal = record.p_signal[:, channel]
        annotations = np.asarray(annotation.sample)

        # Keep only annotations that denote a beat.
        is_beat = np.array([symbol in BEAT_SYMBOLS for symbol in annotation.symbol], dtype=bool)
        beat_samples = annotations[is_beat]

        print(f"ECG signal length: {len(ecg_signal)}")
        print(f"Number of annotations: {len(annotations)}")
        print(f"Number of beat annotations: {len(beat_samples)}")

        # Split the ECG signal into one-minute segments and count beats per minute
        fs = record.fs  # Sampling frequency
        samples_per_minute = int(fs * 60)
        num_minutes = len(ecg_signal) // samples_per_minute

        # Minute index of every beat; beats in the trailing partial minute are
        # discarded so the result has exactly num_minutes entries.
        minute_of_beat = beat_samples // samples_per_minute
        minute_of_beat = minute_of_beat[minute_of_beat < num_minutes]
        heartbeat_counts = np.bincount(minute_of_beat.astype(np.int64), minlength=num_minutes)

        print(f"Heartbeat counts: {heartbeat_counts.tolist()}")

        return heartbeat_counts.reshape(-1, 1)
    except Exception as e:
        # Deliberate best-effort: the caller checks for an empty result.
        print(f"Error loading data: {e}")
        return np.array([]).reshape(-1, 1)

# Load data
data = load_data("mitdb/100")

# Heart-rate band (beats per minute, inclusive) labelled as normal.
NORMAL_BPM_MIN = 60
NORMAL_BPM_MAX = 100


class ThresholdModel:
    """Rule-based stand-in exposing the same ``predict`` interface as the SVM.

    Applies the inclusive NORMAL_BPM_MIN..NORMAL_BPM_MAX rule that also produces
    the training labels, returning 0 (normal) or 1 (abnormal) per input row.
    It is used whenever an SVM cannot be trained, so that ``model.predict``
    remains available to code that classifies incoming heartbeats.
    """

    def predict(self, X):
        X = np.asarray(X)
        return np.where((X >= NORMAL_BPM_MIN) & (X <= NORMAL_BPM_MAX), 0, 1).ravel()


# Fallback until (and unless) an SVM is successfully trained below.
model = ThresholdModel()

# Check if data is empty
if data.size == 0:
    print("No data loaded. Please check the record name and path.")
else:
    # 0: Normal, 1: Abnormal. Flattened to 1-D because scikit-learn expects
    # y of shape (n_samples,), not a column vector.
    labels = np.where((data >= NORMAL_BPM_MIN) & (data <= NORMAL_BPM_MAX), 0, 1).ravel()

    # Check if the sample size is sufficient for splitting into training and testing sets
    if len(data) < 10:  # Assuming at least 10 samples are needed
        print("Not enough data to split into training and testing sets.")
    else:
        # Split the dataset: 60% training, 40% testing
        X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.4, random_state=42)

        # SVC.fit raises ValueError when y holds a single class, which happens
        # whenever every minute of the record falls on one side of the band
        # (e.g. a record in steady normal sinus rhythm).
        if np.unique(y_train).size < 2:
            print("Training set contains only one class; using the threshold rule instead of an SVM.")
        else:
            model = svm.SVC()  # Initialize the SVM model
            model.fit(X_train, y_train)  # Train the SVM model

            # Evaluate the model
            y_pred = model.predict(X_test)
            print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
            print(classification_report(y_test, y_pred))

# Endpoints for the two sides of the bridge; both are plain MQTT on the same port.
broker1_address, broker2_address = "broker.emqx.io", "test.mosquitto.org"
port = 1883

# "+" is the MQTT single-level wildcard, so one subscription matches every
# patient id in patient/<id>/heartbeat.
topic = "patient/+/heartbeat"

# Direction(s) in which the bridge forwards messages:
#   "in"   - bridge_c1 subscribes on broker1 and forwards to broker2
#   "out"  - bridge_c2 subscribes on broker2 and forwards to broker1
#   "both" - both of the above
bridge_mode = "in"

class BridgeMQTTClient(mqtt.Client):
    """MQTT client that classifies heartbeat messages and forwards them across brokers.

    Two instances are used: ``bridge_c1`` (connected to broker1) and
    ``bridge_c2`` (connected to broker2). Which of them subscribes to ``topic``
    and forwards what it receives is controlled by the module-level
    ``bridge_mode``:

    * ``"in"``   - bridge_c1 subscribes and forwards broker1 -> broker2
    * ``"out"``  - bridge_c2 subscribes and forwards broker2 -> broker1
    * ``"both"`` - both directions

    Forwarding goes through the module-level ``bridge_client1`` /
    ``bridge_client2`` instances and classification uses the module-level
    ``model``; all of them must exist before messages arrive.
    """

    def __init__(self, cname, **kwargs):
        # paho-mqtt 2.x introduced an explicit callback API version; pin
        # VERSION1, whose signatures match the callbacks below. Older paho
        # releases have no CallbackAPIVersion and need nothing extra.
        if hasattr(mqtt, "CallbackAPIVersion"):
            kwargs.setdefault("callback_api_version", mqtt.CallbackAPIVersion.VERSION1)
        super().__init__(**kwargs)
        self.cname = cname
        self.on_connect = self.on_connect_callback
        self.on_message = self.on_message_callback
        # (topic, payload) pairs the peer bridge client has just published to
        # this client's broker. When they arrive here they are echoes of our
        # own forwarding, not new data.
        self.processed_messages = deque(maxlen=100)

    def on_connect_callback(self, client, userdata, flags, rc):
        """Subscribe to ``topic`` on a successful connect if this side receives in ``bridge_mode``."""
        if rc == 0:
            print(f"{self.cname} connected to broker")
            # Subscribing here (rather than once after connect()) also
            # re-subscribes after an automatic reconnect.
            if (self.cname == "bridge_c1" and bridge_mode in ["in", "both"]) or \
               (self.cname == "bridge_c2" and bridge_mode in ["out", "both"]):
                self.subscribe(topic)
        else:
            print(f"Connection failed with code {rc} for {self.cname}")

    def on_message_callback(self, client, userdata, msg):
        """Classify an incoming heartbeat and forward it to the other broker."""
        key = (msg.topic, msg.payload)
        # In "both" mode the peer forwards into our broker on a topic we are
        # subscribed to, so its copy comes straight back here. Drop it instead
        # of forwarding it again, which would loop forever between the brokers.
        if key in self.processed_messages:
            try:
                self.processed_messages.remove(key)
            except ValueError:  # evicted meanwhile by the peer's appends
                pass
            return

        m_decode = msg.payload.decode("utf-8", errors="replace")

        # Process and publish message
        print(f"Received message from {self.cname}: {m_decode}")

        # Payloads come from a public broker. A malformed one must not raise
        # here: an exception in a paho callback can propagate out of the
        # network loop and stop the bridge.
        try:
            heartbeat = int(m_decode)
        except ValueError:
            print(f"Ignoring non-integer heartbeat from {msg.topic}: {m_decode!r}")
        else:
            prediction = model.predict([[heartbeat]])[0]
            if prediction == 0:
                print(f"Heartbeat from {msg.topic} is normal")
            else:
                print(f"Heartbeat from {msg.topic} is abnormal")

        # Forwarding does not depend on classification succeeding.
        if self.cname == "bridge_c1" and bridge_mode in ["in", "both"]:
            self._forward(bridge_client2, msg, "broker2", m_decode)
        elif self.cname == "bridge_c2" and bridge_mode in ["out", "both"]:
            self._forward(bridge_client1, msg, "broker1", m_decode)

    def _forward(self, peer, msg, broker_name, text):
        """Publish ``msg`` unchanged (same topic and payload bytes) via ``peer``."""
        if bridge_mode == "both":
            # Record before publishing so the echo cannot arrive at the peer first.
            peer.processed_messages.append((msg.topic, msg.payload))
        peer.publish(msg.topic, msg.payload)  # Keep the original topic
        print(f"Forwarded message to {broker_name}: {text}")

# Create both bridge clients before any network traffic starts. The message
# callback forwards through the module-level names bridge_client1 and
# bridge_client2, so both must exist before either client can receive anything.
bridge_client1 = BridgeMQTTClient(cname="bridge_c1")  # talks to broker1
bridge_client2 = BridgeMQTTClient(cname="bridge_c2")  # talks to broker2

try:
    # Bring up broker2 (the forwarding target in the default "in" mode) first,
    # so messages received by bridge_c1 have a connected client to go out on.
    bridge_client2.connect(broker2_address, port)
    bridge_client2.loop_start()
    bridge_client1.connect(broker1_address, port)
    bridge_client1.loop_start()

    # Keep the main thread running; the network loops run in background threads.
    while True:
        time.sleep(4)
except KeyboardInterrupt:
    pass
finally:
    # Runs even if a connect() above failed, so no loop thread is left behind.
    print('now stop')
    # Disconnect while the loops are still running so DISCONNECT is flushed,
    # then stop the background threads.
    bridge_client1.disconnect()
    bridge_client2.disconnect()
    bridge_client1.loop_stop()
    bridge_client2.loop_stop()
