In [2]:
# 安装必要的库
!pip install paho-mqtt wfdb

import wfdb
import numpy as np
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
import paho.mqtt.client as mqtt
import time
from collections import deque
import os
import unittest

# Download PhysioNet database 'mitdb' (MIT-BIH Arrhythmia Database) into the
# local directory "mitdb"; skipped when that directory already exists so
# re-running the notebook does not fetch the data again.
_mitdb_cached = os.path.isdir("mitdb")
if not _mitdb_cached:
    wfdb.dl_database('mitdb', 'mitdb')

# Annotation symbols that mark an actual beat in MIT-BIH ``atr`` files (the
# "beat annotations" of the PhysioNet annotation table). Every other label --
# rhythm changes ("+"), noise/signal-quality ("~"), isolated QRS-like
# artifacts ("|"), non-conducted P waves ("x"), comments, ... -- is not a
# heartbeat and must not be counted.
MITDB_BEAT_SYMBOLS = frozenset("NLRBAaJSVrFejnE/fQ?")


def load_data(record_name, sampto=300000, beat_symbols=MITDB_BEAT_SYMBOLS):
    """Return per-minute heartbeat counts for one WFDB record.

    Reads the first ``sampto`` samples of the record and of its ``atr``
    annotation file, cuts the signal into consecutive whole minutes (a
    trailing partial minute is dropped) and counts the beat annotations that
    fall inside each minute.

    Args:
        record_name: WFDB record path without extension, e.g. ``"mitdb/100"``.
        sampto: Number of samples to read from the start of the record.
        beat_symbols: Annotation symbols counted as heartbeats; annotations
            with any other symbol are ignored.

    Returns:
        Integer array of shape ``(num_minutes, 1)``. On any loading or
        processing error the error is printed and an empty ``(0, 1)`` array
        is returned instead of raising (best-effort behaviour the callers
        check for via ``data.size``).
    """
    try:
        record = wfdb.rdrecord(record_name, sampto=sampto)
        annotation = wfdb.rdann(record_name, 'atr', sampto=sampto)

        samples_per_minute = int(record.fs * 60)
        num_minutes = record.p_signal.shape[0] // samples_per_minute
        counted_span = num_minutes * samples_per_minute

        # Keep only annotations whose symbol denotes a heartbeat.
        is_beat = np.array(
            [symbol in beat_symbols for symbol in annotation.symbol], dtype=bool
        )
        beat_samples = np.asarray(annotation.sample, dtype=np.int64)[is_beat]
        # Beats after the last whole minute are not counted.
        beat_samples = beat_samples[(beat_samples >= 0) & (beat_samples < counted_span)]

        # Bin every beat into its minute index in one pass instead of
        # re-scanning all annotations once per minute.
        counts = np.bincount(beat_samples // samples_per_minute, minlength=num_minutes)
        return counts.reshape(-1, 1)
    except Exception as e:
        print(f"加载数据时出错: {e}")
        return np.array([]).reshape(-1, 1)

class BridgeMQTTClient(mqtt.Client):
    """MQTT client that classifies heartbeat messages and relays them.

    Each received payload is parsed as an integer heartbeat count and
    classified with ``model`` (prediction ``0`` is reported as normal,
    anything else as abnormal). Depending on the client's name and the
    ``bridge_mode`` setting, the message is then republished through the peer
    bridge client.

    NOTE(review): ``bridge_mode``, ``topic``, ``bridge_client1`` and
    ``bridge_client2`` are module-level globals not defined in this part of
    the notebook. They must exist before the callbacks run for clients named
    ``"bridge_c1"``/``"bridge_c2"``.
    """

    def __init__(self, cname, model, **kwargs):
        """Create the client and wire up its callbacks.

        Args:
            cname: Logical client name; ``"bridge_c1"`` and ``"bridge_c2"``
                enable the subscribe/forward behaviour.
            model: Fitted estimator exposing ``predict`` (e.g. an sklearn SVC).
            **kwargs: Forwarded unchanged to ``paho.mqtt.client.Client``.
        """
        super().__init__(**kwargs)
        self.cname = cname
        self.on_connect = self.on_connect_callback
        self.on_message = self.on_message_callback
        # Bounded history buffer; not read or written by the callbacks below
        # (presumably for external users -- TODO confirm it is still needed).
        self.processed_messages = deque(maxlen=100)
        self.model = model

    def on_connect_callback(self, client, userdata, flags, rc):
        """Report the connection result and subscribe when bridging is enabled.

        ``rc == 0`` means the broker accepted the connection. Only the bridge
        end whose direction is enabled by ``bridge_mode`` subscribes to the
        global ``topic``.
        """
        if rc == 0:
            print(f"{self.cname} 已连接到 broker")
            if (self.cname == "bridge_c1" and bridge_mode in ["in", "both"]) or \
               (self.cname == "bridge_c2" and bridge_mode in ["out", "both"]):
                self.subscribe(topic)
        else:
            print(f"连接失败，代码 {rc}，客户端 {self.cname}")

    def on_message_callback(self, client, userdata, msg):
        """Classify one heartbeat message and forward it to the peer broker.

        Malformed payloads (non-UTF-8 bytes or non-integer text) are reported
        and dropped rather than raised. With paho's default settings, an
        exception raised inside a callback propagates out of the network loop.
        """
        try:
            m_decode = msg.payload.decode("utf-8")
        except UnicodeDecodeError as err:
            print(f"{self.cname} 收到无法解码的消息: {err}")
            return

        print(f"{self.cname} 收到消息: {m_decode}")
        try:
            heartbeat = int(m_decode)
        except ValueError:
            print(f"{self.cname} 收到无效的心跳数据: {m_decode!r}")
            return

        # predict() returns one label per input row; take the single one.
        prediction = self.model.predict([[heartbeat]])[0]
        if prediction == 0:
            print(f"来自 {msg.topic} 的心跳正常")
        else:
            print(f"来自 {msg.topic} 的心跳异常")

        if self.cname == "bridge_c1" and bridge_mode in ["in", "both"]:
            bridge_client2.publish(msg.topic, m_decode)
            print(f"已转发消息到 broker2: {m_decode}")

        elif self.cname == "bridge_c2" and bridge_mode in ["out", "both"]:
            bridge_client1.publish(msg.topic, m_decode)
            print(f"已转发消息到 broker1: {m_decode}")

class TestECGDataAndModel(unittest.TestCase):
    """Checks ``load_data`` on record ``mitdb/100`` and an SVC trained on it."""

    def test_load_data(self):
        """Record 100 yields a non-empty column vector of per-minute counts."""
        data = load_data("mitdb/100")
        self.assertGreater(data.size, 0, "数据未正确加载")
        self.assertEqual(data.shape[1], 1, "数据形状不正确")

    def test_model_training(self):
        """An SVC trained on the 60-100 per-minute rule reaches >= 70% accuracy."""
        data = load_data("mitdb/100")
        if data.size == 0:
            # Loading failure is already reported by test_load_data.
            self.skipTest("数据未正确加载")

        # Label 0 when the per-minute count is within [60, 100], else 1.
        # ravel() gives the 1-D target array sklearn expects.
        labels = np.where((data >= 60) & (data <= 100), 0, 1).ravel()
        if len(np.unique(labels)) < 2:
            self.skipTest("数据集中只有一个类，无法进行训练")

        X_train, X_test, y_train, y_test = train_test_split(
            data, labels, test_size=0.4, random_state=42
        )
        # With few, imbalanced samples the split can leave only one class in
        # the training set, which SVC.fit rejects with a ValueError.
        if len(np.unique(y_train)) < 2:
            self.skipTest("数据集中只有一个类，无法进行训练")

        model = svm.SVC()
        model.fit(X_train, y_train)
        accuracy = accuracy_score(y_test, model.predict(X_test))
        self.assertGreaterEqual(accuracy, 0.7, "模型准确率低于70%")

class TestMQTTClient(unittest.TestCase):
    """Exercises ``BridgeMQTTClient.on_message_callback`` with a trained SVC."""

    def setUp(self):
        """Fit an SVC on record ``mitdb/100``, or skip when that is impossible."""
        self.model = svm.SVC()
        data = load_data("mitdb/100")
        if data.size == 0:
            self.skipTest("数据未正确加载")

        # Label 0 when the per-minute count is within [60, 100], else 1.
        labels = np.where((data >= 60) & (data <= 100), 0, 1).ravel()
        if len(np.unique(labels)) < 2:
            self.skipTest("数据集中只有一个类，无法进行训练")

        # Only the training split is needed here.
        X_train, _, y_train, _ = train_test_split(
            data, labels, test_size=0.4, random_state=42
        )
        # SVC.fit raises ValueError if the training split holds a single class.
        if len(np.unique(y_train)) < 2:
            self.skipTest("数据集中只有一个类，无法进行训练")
        self.model.fit(X_train, y_train)

    def test_on_message_callback(self):
        """Each message is logged and classified; a non-bridge client never forwards."""
        import contextlib
        import io

        # The constructor already wires on_connect/on_message to the callbacks.
        client = BridgeMQTTClient(cname="test_client", model=self.model)

        class Msg:
            """Minimal stand-in for an MQTT message (payload and topic only)."""

            def __init__(self, payload, topic):
                self.payload = payload
                self.topic = topic

        msg_topic = "patient/123/heartbeat"
        for payload in (b'75', b'50'):
            captured = io.StringIO()
            with contextlib.redirect_stdout(captured):
                client.on_message_callback(client, None, Msg(payload=payload, topic=msg_topic))
            output = captured.getvalue()

            self.assertIn(f"test_client 收到消息: {payload.decode('utf-8')}", output)
            # Either the normal or the abnormal classification line is printed.
            self.assertIn(f"来自 {msg_topic} 的心跳", output)
            # "test_client" is neither bridge end, so nothing is forwarded.
            self.assertNotIn("已转发", output)

if __name__ == "__main__":
    # A placeholder argv keeps unittest from parsing the host process's own
    # command-line arguments (e.g. a notebook kernel's), and exit=False stops
    # unittest.main() from calling sys.exit() when the run finishes.
    unittest.main(exit=False, argv=[''])




.ss
----------------------------------------------------------------------
Ran 3 tests in 0.201s

OK (skipped=2)
