Skip to content

Commit

Permalink
simplify receive loop (#1758)
Browse files Browse the repository at this point in the history
  • Loading branch information
f1shl3gs committed May 23, 2024
1 parent 31fc8b0 commit 30a9835
Showing 1 changed file with 133 additions and 153 deletions.
286 changes: 133 additions & 153 deletions src/sources/mqtt/broker.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::cmp::Ordering;
use std::net::SocketAddr;

use bytes::{Buf, BufMut, Bytes, BytesMut};
use bytes::{Buf, BytesMut};
use event::LogRecord;
use framework::tls::MaybeTlsIncomingStream;
use framework::Pipeline;
use futures_util::{SinkExt, StreamExt};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_util::codec::Framed;
use value::path;

// Specs definition from 2.2.1 MQTT Control Packet:
Expand Down Expand Up @@ -48,186 +47,167 @@ const fn type_name(typ: u8) -> &'static str {
}
}

struct Codec;
pub async fn serve_connection(
peer: SocketAddr,
mut conn: MaybeTlsIncomingStream<TcpStream>,
mut output: Pipeline,
) {
let mut buf = BytesMut::new();

impl tokio_util::codec::Decoder for Codec {
type Item = (u8, Bytes);
type Error = std::io::Error;
'RECV: loop {
if let Err(err) = conn.read_buf(&mut buf).await {
error!(message = "read packet failed", ?err, ?peer,);
return;
}

fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if buf.len() < 2 {
return Ok(None);
continue;
}

let mut remaining = 0usize;
let mut shift = 0;
for pos in 1..buf.len() {
let byte = buf[pos] as usize;
remaining += (byte & 0x7F) << shift;

// stop when continue bit is 0
if byte & 0x80 == 0 {
let want = 1 + pos + remaining;
if buf.len() >= want {
let ctrl_byte = buf[0];
loop {
let ctrl_byte = buf[0];
let mut remaining = 0usize;
let mut shift = 0;
for pos in 1..buf.len() {
let byte = buf[pos] as usize;
remaining += (byte & 0x7F) << shift;

// stop when continue bit is 0
if byte & 0x80 == 0 {
let want = 1 + pos + remaining;
if buf.len() < want {
continue 'RECV;
}

buf.advance(1 + pos);
let payload = buf.split_to(remaining).freeze();

return Ok(Some((ctrl_byte, payload)));
break;
}

break;
}

shift += 7;
shift += 7;

// Only a max of 4 bytes allowed for remaining length
// more than 4 shifts (0, 7, 14, 21) implies bad length
if shift > 21 {
return Err(std::io::Error::other("invalid variable length"));
// Only a max of 4 bytes allowed for remaining length
// more than 4 shifts(0, 7, 14, 21) implies bad length
if shift > 21 {
error!(message = "invalid remaining length");
return;
}
}
}

Ok(None)
}
}

impl tokio_util::codec::Encoder<[u8; 4]> for Codec {
type Error = std::io::Error;

fn encode(&mut self, item: [u8; 4], dst: &mut BytesMut) -> Result<(), Self::Error> {
if item[1] == 0 {
dst.put_slice(&item[..2]);
} else {
dst.put_slice(&item);
}
// handle packets
let mut payload = buf.split_to(remaining).freeze();
match ctrl_byte >> 4 {
MQTT_CONNECT => {
// PROTOCOL NAME
// byte description
// 1 Protocol Name MSB
// 2 Protocol Name LSB
// 3 `M`
// 4 `Q`
// 5 `T`
// 6 `T`
// 7 Protocol version, 4 for MQTT311, 5 for MQTT5
// 8 Connect Flags
// 9 Keepalive MSB
// 10 Keepalive LSB
// 11
// 12
let mut len = payload[0] as usize;
len |= payload[1] as usize;

if len != 4 || payload[2..6].cmp(b"MQTT") != Ordering::Equal {
error!(message = "unknown protocol name");
return;
}

Ok(())
}
}
let version = payload[6];
if payload[6] != MQTT_VERSION_311 {
error!(message = "unsupported MQTT version", version);
return;
}

pub async fn serve_connection(
peer: SocketAddr,
stream: MaybeTlsIncomingStream<TcpStream>,
mut output: Pipeline,
) {
let (mut frame_sink, mut frame_stream) = Framed::new(stream, Codec).split();

while let Some(result) = frame_stream.next().await {
match result {
Ok((ctrl_byte, mut payload)) => {
match ctrl_byte >> 4 {
MQTT_CONNECT => {
// PROTOCOL NAME
// byte description
// 1 Protocol Name MSB
// 2 Protocol Name LSB
// 3 `M`
// 4 `Q`
// 5 `T`
// 6 `T`
// 7 Protocol version, 4 for MQTT311
// 8 Connect Flags
// 9 Keepalive MSB
// 10 Keepalive LSB
// 11
// 12
let mut len = payload[0] as usize;
len |= payload[1] as usize;

if len != 4 || payload[2..6].cmp(b"MQTT") != Ordering::Equal {
error!(message = "unknown protocol name");
return;
if let Err(err) = conn.write_all(&[MQTT_CONNACK << 4, 2, 0, 0]).await {
error!(message = "write CONNACK failed", ?err, ?peer);
return;
}
}
MQTT_PUBLISH => {
let mut len = payload[0] as usize;
len |= payload[1] as usize;
payload.advance(2);

let topic = match String::from_utf8(payload[..len].to_vec()) {
Ok(s) => {
payload.advance(len);
s
}

let version = payload[6];
if payload[6] != MQTT_VERSION_311 {
error!(message = "unsupported MQTT version", version);
Err(err) => {
error!(message = "invalid topic name", ?err, ?peer);
return;
}
};

let qos = (ctrl_byte >> 1) & 0x03;
if qos > MQTT_QOS_LEV0 {
// packet identifier
//
// The Packet Identifier field is only present in
// `PUBLISH` Packets where the QoS level is 1 or 2.
//
// set the identifier that we are replying to
let mut resp = [0u8, 2, payload[0], payload[1]];

if qos == MQTT_QOS_LEV1 {
resp[0] = MQTT_PUBACK << 4;
} else if qos == MQTT_QOS_LEV2 {
resp[0] = MQTT_PUBREC << 4;
}

if let Err(err) = frame_sink.send([MQTT_CONNACK << 4, 2, 0, 0]).await {
error!(message = "write CONNACK failed", ?err, ?peer);
if let Err(err) = conn.write_all(&resp).await {
error!(message = "write PUBLISH response failed", ?err, ?peer);
return;
}
}
MQTT_PUBLISH => {
let mut len = payload[0] as usize;
len |= payload[1] as usize;
payload.advance(2);

let topic = match String::from_utf8(payload[..len].to_vec()) {
Ok(s) => {
payload.advance(len);
s
}
Err(err) => {
error!(message = "invalid topic name", ?err, ?peer);
return;
}
};

let qos = (ctrl_byte >> 1) & 0x03;
if qos > MQTT_QOS_LEV0 {
// packet identifier
//
// The Packet Identifier field is only present in
// `PUBLISH` Packets where the QoS level is 1 or 2.
//
// set the identifier that we are replying to
let mut resp = [0u8, 2, payload[0], payload[1]];

if qos == MQTT_QOS_LEV1 {
resp[0] = MQTT_PUBACK << 4;
} else if qos == MQTT_QOS_LEV2 {
resp[0] = MQTT_PUBREC << 4;
}

if let Err(err) = frame_sink.send(resp).await {
error!(message = "write PUBLISH response failed", ?err, ?peer);
return;
}

payload.advance(2);
}
payload.advance(2);
}

let value: event::log::Value = serde_json::from_slice(&payload).unwrap();
let mut log = LogRecord::from(value);
log.metadata_mut()
.value_mut()
.insert(path!("topic"), topic.to_string());
let value: event::log::Value = serde_json::from_slice(&payload).unwrap();
let mut log = LogRecord::from(value);
log.metadata_mut()
.value_mut()
.insert(path!("topic"), topic.to_string());

if let Err(err) = output.send(log).await {
warn!(message = "send message failed", ?err, ?peer);
return;
}
}
MQTT_PINGREQ => {
let resp = [MQTT_PINGRESP >> 4, 0, 0, 0];
if let Err(err) = frame_sink.send(resp).await {
error!(message = "wrtie PINGRESP failed", ?err, ?peer);
return;
}
}
MQTT_DISCONNECT => {
debug!(message = "client disconnect", ?peer);
if let Err(err) = output.send(log).await {
warn!(message = "send message failed", ?err, ?peer);
return;
}
typ => {
error!(
message = "unsupported packet type",
?peer,
name = type_name(typ),
r#typ
);
}
MQTT_PINGREQ => {
let resp = [MQTT_PINGRESP >> 4, 0];
if let Err(err) = conn.write(&resp).await {
error!(message = "wrtie PINGRESP failed", ?err, ?peer);
return;
}
}
MQTT_DISCONNECT => {
debug!(message = "client disconnect", ?peer);
return;
}
typ => {
error!(
message = "unsupported packet type",
?peer,
name = type_name(typ),
r#typ
);
return;
}
}
Err(err) => {
error!(message = "read packet failed", ?err, ?peer);
return;

if buf.is_empty() {
// reuse buf
buf.clear();
break;
}
}
}
Expand Down

0 comments on commit 30a9835

Please sign in to comment.