-
Notifications
You must be signed in to change notification settings - Fork 2
/
mqtt.py
116 lines (97 loc) · 4.27 KB
/
mqtt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
MQTT bridging
=============
dataClay includes support for client communication via MQTT.
In order to use this functionality, the client class has to inherit the MQTTMixin class.
The client will be able to specify how the messages should be handled, which topics will be subscribed to,
and send messages with a topic.
"""
import inspect
import logging
import os
import warnings
from json import dumps
from typing import Any

from dataclay import activemethod
try:
    import paho.mqtt.client as mqtt
except ImportError:
    # paho-mqtt is treated as optional: importing this module still succeeds without it,
    # but every method below that references ``mqtt`` will then fail with NameError.
    warnings.warn("Warning: <import paho.mqtt.client> failed", category=ImportWarning)

try:
    # Share the client pool of an already-loaded ``dataclay.contrib.mqtt`` module, if any, so
    # every copy of this code reuses one connection per broker address. When this file *is*
    # ``dataclay.contrib.mqtt``, the import fails during its own first load (the name is not
    # defined yet), and the fallback below runs.
    from dataclay.contrib.mqtt import MQTT_PRODUCERS
except ImportError:
    warnings.warn(
        "Warning: <from dataclay.contrib.mqtt import MQTT_PRODUCERS> failed", category=ImportWarning
    )
    # Pool of MQTT clients keyed by "host:port". Created only when no shared pool could be
    # imported: assigning it unconditionally would discard the pool imported above.
    MQTT_PRODUCERS = {}

logger = logging.getLogger(__name__)
def _get_mqtt_producer(on_message=None):
    """Return the pooled MQTT client for the configured broker, creating it on first use.

    The broker address is read from the ``MQTT_HOST`` (default ``"mqtt5"``) and ``MQTT_PORT``
    (default ``"1883"``) environment variables. Clients are cached in ``MQTT_PRODUCERS`` under
    the key ``"host:port"``, so all callers share one client per broker address.

    Args:
        on_message: Optional default ``on_message`` callback. It is installed only when a new
            client is created here; an already-pooled client is returned unchanged.

    Returns:
        paho.mqtt.client.Client: The pooled client. ``connect()`` and ``loop_start()`` have
        already been called on it.
    """
    mqtt_host = os.getenv("MQTT_HOST", "mqtt5")
    mqtt_port = int(os.getenv("MQTT_PORT", "1883"))
    mqtt_address = f"{mqtt_host}:{mqtt_port}"

    mqtt_producer = MQTT_PRODUCERS.get(mqtt_address)
    if mqtt_producer is None:
        client_id = os.getenv("MQTT_PRODUCER_ID", "dataclay_mqtt_producer")
        mqtt_producer = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id)
        if on_message is not None:
            mqtt_producer.on_message = on_message
        mqtt_producer.connect(mqtt_host, mqtt_port)
        mqtt_producer.loop_start()
        # Cache only after connect() returned, so a failed attempt is retried on the next call.
        MQTT_PRODUCERS[mqtt_address] = mqtt_producer
    return mqtt_producer


class MQTTMixin:
    """Mixin that gives a dataClay class MQTT subscribe/publish capabilities.

    Subclasses override :meth:`message_handling` to process incoming messages.
    :meth:`subscribe_to_topic` routes a topic's messages to that handler, and
    :meth:`produce_mqtt_msg` / :meth:`send_to_mqtt` publish JSON payloads. All methods share
    one pooled client per broker address (see ``MQTT_PRODUCERS``).
    """

    @activemethod
    def message_handling(self, client, userdata, msg):
        """Handle one message received on a subscribed topic; subclasses must override this.

        Args:
            client (paho.mqtt.client.Client): The client instance for this callback.
            userdata: The private user data as set in ``Client()`` or ``user_data_set()``.
            msg (paho.mqtt.client.MQTTMessage): The received message, with members
                ``topic``, ``payload``, ``qos`` and ``retain``.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError("Must override message_handling")

    @activemethod
    def subscribe_to_topic(self, topic: str = "dataclay"):
        """Subscribe the pooled client to ``topic`` and deliver its messages to :meth:`message_handling`.

        A failed subscribe request is logged as a warning rather than raised.

        Args:
            topic (str, optional): Topic filter to subscribe to. Defaults to "dataclay".
        """
        mqtt_producer = _get_mqtt_producer(on_message=self.message_handling)
        # The pooled client may already exist. If produce_mqtt_msg created it, it has no
        # on_message handler at all; if another object subscribed first, on_message is bound to
        # *that* object. A topic-specific callback guarantees this topic reaches this object's
        # handler. paho only falls back to on_message for messages that match no such callback.
        mqtt_producer.message_callback_add(topic, self.message_handling)
        result, _mid = mqtt_producer.subscribe(topic)
        if result != mqtt.MQTT_ERR_SUCCESS:
            logger.warning("MQTT subscribe to topic %r failed (rc=%s)", topic, result)

    @activemethod
    def produce_mqtt_msg(self, data: dict[str, Any], topic: str = "dataclay", **more):
        """Publish ``data`` as a UTF-8 encoded JSON payload to ``topic`` with QoS 1.

        A publish call that reports a non-success return code is logged as a warning.

        Args:
            data (dict[str, Any]): Message body; must be JSON-serializable, or ``dumps``
                raises ``TypeError``.
            topic (str, optional): Topic of the message. Defaults to "dataclay".
            **more: Ignored. Accepted so that :meth:`send_to_mqtt` can pass every public
                attribute of the object as a keyword argument.
        """
        mqtt_producer = _get_mqtt_producer()
        payload = dumps(data).encode("utf-8")
        msg_info = mqtt_producer.publish(topic, payload, qos=1)
        if msg_info.rc != mqtt.MQTT_ERR_SUCCESS:
            logger.warning("MQTT publish to topic %r failed (rc=%s)", topic, msg_info.rc)

    @activemethod
    def send_to_mqtt(self):
        """Publish this object through :meth:`produce_mqtt_msg`, using its attributes as arguments.

        Every non-routine member of the class whose name does not start with ``_`` is read from
        ``self`` and passed as a keyword argument. The class must therefore expose a ``data``
        attribute and may expose ``topic``. Any other public attributes land in
        ``produce_mqtt_msg``'s ``**more`` and are ignored.
        """
        members = inspect.getmembers(self.__class__, lambda member: not inspect.isroutine(member))
        field_values = {
            name: getattr(self, name) for name, _value in members if not name.startswith("_")
        }
        self.produce_mqtt_msg(**field_values)