diff --git a/Cargo.lock b/Cargo.lock index 02fc5087..3d85016c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -550,6 +550,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +[[package]] +name = "hex-literal" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" + [[package]] name = "hickory-proto" version = "0.24.0" @@ -1561,6 +1567,7 @@ dependencies = [ "derive_more", "dns-lookup", "etcetera", + "hex-literal", "hickory-resolver", "humantime", "indexmap", diff --git a/Cargo.toml b/Cargo.toml index f23cc62e..9b2026d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,6 +75,7 @@ windows-sys = { version = "0.48.0", features = [ ] } [dev-dependencies] +hex-literal = "0.4.1" rand = "0.8.5" test-case = "3.2.1" diff --git a/src/backend/trace.rs b/src/backend/trace.rs index 6f5174b0..fabbc30b 100644 --- a/src/backend/trace.rs +++ b/src/backend/trace.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use std::iter::once; use std::net::{IpAddr, Ipv4Addr}; use std::time::Duration; -use trippy::tracing::{Probe, ProbeStatus, TracerRound}; +use trippy::tracing::{Extensions, Probe, ProbeStatus, TracerRound}; /// The state of all hops in a trace. #[derive(Debug, Clone)] @@ -121,6 +121,7 @@ pub struct Hop { mean: f64, m2: f64, samples: Vec, + extensions: Option, } impl Hop { @@ -200,6 +201,10 @@ impl Hop { pub fn samples(&self) -> &[Duration] { &self.samples } + + pub fn extensions(&self) -> Option<&Extensions> { + self.extensions.as_ref() + } } impl Default for Hop { @@ -216,6 +221,7 @@ impl Default for Hop { mean: 0f64, m2: 0f64, samples: Vec::default(), + extensions: None, } } } @@ -319,6 +325,7 @@ impl TraceData { } let host = probe.host.unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)); *hop.addrs.entry(host).or_default() += 1; + hop.extensions = probe.extensions.clone(); } ProbeStatus::Awaited => { let index = usize::from(probe.ttl.0) - 1; diff --git a/src/config.rs b/src/config.rs index 87484561..4ac3af96 100644 --- a/src/config.rs +++ b/src/config.rs @@ -199,6 +199,7 @@ pub struct TrippyConfig { pub max_inflight: u8, pub initial_sequence: u16, pub tos: u8, + pub icmp_extensions: bool, pub read_timeout: Duration, pub packet_size: u16, pub payload_pattern: u8, @@ -377,6 +378,11 @@ impl TryFrom<(Args, &Platform)> for TrippyConfig { cfg_file_strategy.tos, constants::DEFAULT_STRATEGY_TOS, ); + let icmp_extensions = cfg_layer_bool_flag( + args.icmp_extensions, + cfg_file_strategy.icmp_extensions, + false, + ); let read_timeout = cfg_layer( args.read_timeout, cfg_file_strategy.read_timeout, @@ -559,6 +565,7 @@ impl TryFrom<(Args, &Platform)> for TrippyConfig { packet_size, payload_pattern, tos, + icmp_extensions, source_addr, interface, port_direction, diff --git a/src/config/cmd.rs b/src/config/cmd.rs index 88dfcb05..43e50573 100644 --- a/src/config/cmd.rs +++ b/src/config/cmd.rs @@ -129,6 +129,10 @@ pub struct Args { #[arg(short = 'Q', long)] pub tos: Option, + /// Parse ICMP extensions + #[arg(short = 'e', long)] + pub icmp_extensions: bool, + /// The socket read timeout [default: 10ms] #[arg(long)] pub read_timeout: Option, diff --git a/src/config/file.rs b/src/config/file.rs index 77661b86..bad5e4fb 100644 --- a/src/config/file.rs +++ b/src/config/file.rs @@ -111,6 +111,7 @@ pub struct ConfigStrategy { pub packet_size: Option, pub payload_pattern: Option, pub tos: Option, + pub icmp_extensions: Option, pub read_timeout: Option, } diff --git a/src/main.rs b/src/main.rs index f802b2ba..f38c1628 100644 --- a/src/main.rs +++ b/src/main.rs @@ -253,6 +253,7 @@ fn make_channel_config( args.payload_pattern, args.multipath_strategy, args.tos, + args.icmp_extensions, args.read_timeout, args.min_round_duration, ) diff --git a/src/tracing.rs b/src/tracing.rs index 99ba13d5..0dbd7737 100644 --- a/src/tracing.rs +++ b/src/tracing.rs @@ -17,5 +17,5 @@ pub use config::{ pub use net::channel::TracerChannel; pub use net::source::SourceAddr; pub use net::SocketImpl; -pub use probe::{IcmpPacketType, Probe, ProbeStatus}; +pub use probe::{Extension, Extensions, IcmpPacketType, Probe, ProbeStatus}; pub use tracer::{Tracer, TracerRound}; diff --git a/src/tracing/config.rs b/src/tracing/config.rs index 0566d1d8..f9c14d9c 100644 --- a/src/tracing/config.rs +++ b/src/tracing/config.rs @@ -176,6 +176,7 @@ pub struct TracerChannelConfig { pub payload_pattern: PayloadPattern, pub multipath_strategy: MultipathStrategy, pub tos: TypeOfService, + pub icmp_extensions: bool, pub read_timeout: Duration, pub tcp_connect_timeout: Duration, } @@ -193,6 +194,7 @@ impl TracerChannelConfig { payload_pattern: u8, multipath_strategy: MultipathStrategy, tos: u8, + icmp_extensions: bool, read_timeout: Duration, tcp_connect_timeout: Duration, ) -> Self { @@ -206,6 +208,7 @@ impl TracerChannelConfig { payload_pattern: PayloadPattern(payload_pattern), multipath_strategy, tos: TypeOfService(tos), + icmp_extensions, read_timeout, tcp_connect_timeout, } diff --git a/src/tracing/net.rs b/src/tracing/net.rs index 259bdd2b..8e79459d 100644 --- a/src/tracing/net.rs +++ b/src/tracing/net.rs @@ -11,6 +11,9 @@ mod ipv4; /// IPv6 implementation. mod ipv6; +/// ICMP extensions. +mod extension; + /// Platform specific network code. mod platform; diff --git a/src/tracing/net/channel.rs b/src/tracing/net/channel.rs index 773aa071..f0f1bcf3 100644 --- a/src/tracing/net/channel.rs +++ b/src/tracing/net/channel.rs @@ -29,6 +29,7 @@ pub struct TracerChannel { payload_pattern: PayloadPattern, multipath_strategy: MultipathStrategy, tos: TypeOfService, + icmp_extensions: bool, read_timeout: Duration, tcp_connect_timeout: Duration, send_socket: Option, @@ -68,6 +69,7 @@ impl TracerChannel { payload_pattern: config.payload_pattern, multipath_strategy: config.multipath_strategy, tos: config.tos, + icmp_extensions: config.icmp_extensions, read_timeout: config.read_timeout, tcp_connect_timeout: config.tcp_connect_timeout, send_socket, @@ -95,7 +97,7 @@ impl Network for TracerChannel { resp => Ok(resp), }, }?; - if let Some(resp) = prob_response { + if let Some(resp) = &prob_response { tracing::debug!(?resp); } Ok(prob_response) @@ -170,10 +172,10 @@ impl TracerChannel { fn dispatch_tcp_probe(&mut self, probe: Probe) -> TraceResult<()> { let socket = match (self.src_addr, self.dest_addr) { (IpAddr::V4(src_addr), IpAddr::V4(dest_addr)) => { - ipv4::dispatch_tcp_probe(probe, src_addr, dest_addr, self.tos) + ipv4::dispatch_tcp_probe(&probe, src_addr, dest_addr, self.tos) } (IpAddr::V6(src_addr), IpAddr::V6(dest_addr)) => { - ipv6::dispatch_tcp_probe(probe, src_addr, dest_addr) + ipv6::dispatch_tcp_probe(&probe, src_addr, dest_addr) } _ => unreachable!(), }?; @@ -187,8 +189,16 @@ impl TracerChannel { fn recv_icmp_probe(&mut self) -> TraceResult> { if self.recv_socket.is_readable(self.read_timeout)? { match self.dest_addr { - IpAddr::V4(_) => ipv4::recv_icmp_probe(&mut self.recv_socket, self.protocol), - IpAddr::V6(_) => ipv6::recv_icmp_probe(&mut self.recv_socket, self.protocol), + IpAddr::V4(_) => ipv4::recv_icmp_probe( + &mut self.recv_socket, + self.protocol, + self.icmp_extensions, + ), + IpAddr::V6(_) => ipv6::recv_icmp_probe( + &mut self.recv_socket, + self.protocol, + self.icmp_extensions, + ), } } else { Ok(None) diff --git a/src/tracing/net/extension.rs b/src/tracing/net/extension.rs new file mode 100644 index 00000000..4c28c539 --- /dev/null +++ b/src/tracing/net/extension.rs @@ -0,0 +1,66 @@ +use crate::tracing::error::TracerError; +use crate::tracing::packet::icmp_extension::extension_header::ExtensionHeaderPacket; +use crate::tracing::packet::icmp_extension::extension_object::{ClassNum, ExtensionObjectPacket}; +use crate::tracing::packet::icmp_extension::extension_structure::ExtensionsPacket; +use crate::tracing::packet::icmp_extension::mpls_label_stack::MplsLabelStackPacket; +use crate::tracing::packet::icmp_extension::mpls_label_stack_member::MplsLabelStackMemberPacket; +use crate::tracing::probe::{Extension, Extensions, MplsLabelStack, MplsLabelStackMember}; +use crate::tracing::util::Required; + +/// The supported ICMP extension version number. +const ICMP_EXTENSION_VERSION: u8 = 2; + +impl TryFrom<&[u8]> for Extensions { + type Error = TracerError; + + fn try_from(value: &[u8]) -> Result { + Self::try_from(ExtensionsPacket::new_view(value).req()?) + } +} + +impl TryFrom> for Extensions { + type Error = TracerError; + + fn try_from(value: ExtensionsPacket<'_>) -> Result { + let header = ExtensionHeaderPacket::new_view(value.header()).req()?; + if header.get_version() != ICMP_EXTENSION_VERSION { + return Ok(Self::default()); + } + let extensions = value + .objects() + .flat_map(|obj| ExtensionObjectPacket::new_view(obj).req()) + .map(|obj| match obj.get_class_num() { + ClassNum::MultiProtocolLabelSwitchingLabelStack => { + MplsLabelStackPacket::new_view(obj.payload()) + .req() + .map(|mpls| Extension::Mpls(MplsLabelStack::from(mpls))) + } + _ => Ok(Extension::Unknown), + }) + .collect::>()?; + Ok(Self { extensions }) + } +} + +impl From> for MplsLabelStack { + fn from(value: MplsLabelStackPacket<'_>) -> Self { + Self { + members: value + .members() + .flat_map(|member| MplsLabelStackMemberPacket::new_view(member).req()) + .map(MplsLabelStackMember::from) + .collect(), + } + } +} + +impl From> for MplsLabelStackMember { + fn from(value: MplsLabelStackMemberPacket<'_>) -> Self { + Self { + label: value.get_label(), + exp: value.get_exp(), + bos: value.get_bos(), + ttl: value.get_ttl(), + } + } +} diff --git a/src/tracing/net/ipv4.rs b/src/tracing/net/ipv4.rs index 34e8e94e..fd7f32ca 100644 --- a/src/tracing/net/ipv4.rs +++ b/src/tracing/net/ipv4.rs @@ -14,8 +14,8 @@ use crate::tracing::packet::tcp::TcpPacket; use crate::tracing::packet::udp::UdpPacket; use crate::tracing::packet::IpProtocol; use crate::tracing::probe::{ - ProbeResponse, ProbeResponseData, ProbeResponseSeq, ProbeResponseSeqIcmp, ProbeResponseSeqTcp, - ProbeResponseSeqUdp, + Extensions, ProbeResponse, ProbeResponseData, ProbeResponseSeq, ProbeResponseSeqIcmp, + ProbeResponseSeqTcp, ProbeResponseSeqUdp, }; use crate::tracing::types::{PacketSize, PayloadPattern, Sequence, TraceId, TypeOfService}; use crate::tracing::util::Required; @@ -187,7 +187,7 @@ fn dispatch_udp_probe_non_raw( #[instrument(skip(probe))] pub fn dispatch_tcp_probe( - probe: Probe, + probe: &Probe, src_addr: Ipv4Addr, dest_addr: Ipv4Addr, tos: TypeOfService, @@ -206,12 +206,13 @@ pub fn dispatch_tcp_probe( pub fn recv_icmp_probe( recv_socket: &mut S, protocol: TracerProtocol, + icmp_extensions: bool, ) -> TraceResult> { let mut buf = [0_u8; MAX_PACKET_SIZE]; match recv_socket.read(&mut buf) { Ok(bytes_read) => { let ipv4 = Ipv4Packet::new_view(&buf[..bytes_read]).req()?; - Ok(extract_probe_resp(protocol, &ipv4)?) + Ok(extract_probe_resp(protocol, icmp_extensions, &ipv4)?) } Err(err) => match err.kind() { ErrorKind::WouldBlock => Ok(None), @@ -248,11 +249,10 @@ pub fn recv_tcp_socket( } if platform::is_host_unreachable_error(code) { let error_addr = tcp_socket.icmp_error_info()?; - return Ok(Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( - SystemTime::now(), - error_addr, - resp_seq, - )))); + return Ok(Some(ProbeResponse::TimeExceeded( + ProbeResponseData::new(SystemTime::now(), error_addr, resp_seq), + None, + ))); } } } @@ -343,6 +343,7 @@ fn udp_payload_size(packet_size: usize) -> usize { #[instrument] fn extract_probe_resp( protocol: TracerProtocol, + icmp_extensions: bool, ipv4: &Ipv4Packet<'_>, ) -> TraceResult> { let recv = SystemTime::now(); @@ -352,15 +353,28 @@ fn extract_probe_resp( IcmpType::TimeExceeded => { let packet = TimeExceededPacket::new_view(icmp_v4.packet()).req()?; let nested_ipv4 = Ipv4Packet::new_view(packet.payload()).req()?; + let extension = if icmp_extensions { + packet.extension().map(Extensions::try_from).transpose()? + } else { + None + }; extract_probe_resp_seq(&nested_ipv4, protocol)?.map(|resp_seq| { - ProbeResponse::TimeExceeded(ProbeResponseData::new(recv, src, resp_seq)) + ProbeResponse::TimeExceeded(ProbeResponseData::new(recv, src, resp_seq), extension) }) } IcmpType::DestinationUnreachable => { let packet = DestinationUnreachablePacket::new_view(icmp_v4.packet()).req()?; let nested_ipv4 = Ipv4Packet::new_view(packet.payload()).req()?; + let extension = if icmp_extensions { + packet.extension().map(Extensions::try_from).transpose()? + } else { + None + }; extract_probe_resp_seq(&nested_ipv4, protocol)?.map(|resp_seq| { - ProbeResponse::DestinationUnreachable(ProbeResponseData::new(recv, src, resp_seq)) + ProbeResponse::DestinationUnreachable( + ProbeResponseData::new(recv, src, resp_seq), + extension, + ) }) } IcmpType::EchoReply => match protocol { diff --git a/src/tracing/net/ipv6.rs b/src/tracing/net/ipv6.rs index 304ba433..8fe6cfb1 100644 --- a/src/tracing/net/ipv6.rs +++ b/src/tracing/net/ipv6.rs @@ -14,8 +14,8 @@ use crate::tracing::packet::tcp::TcpPacket; use crate::tracing::packet::udp::UdpPacket; use crate::tracing::packet::IpProtocol; use crate::tracing::probe::{ - ProbeResponse, ProbeResponseData, ProbeResponseSeq, ProbeResponseSeqIcmp, ProbeResponseSeqTcp, - ProbeResponseSeqUdp, + Extensions, ProbeResponse, ProbeResponseData, ProbeResponseSeq, ProbeResponseSeqIcmp, + ProbeResponseSeqTcp, ProbeResponseSeqUdp, }; use crate::tracing::types::{PacketSize, PayloadPattern, Sequence, TraceId}; use crate::tracing::util::Required; @@ -158,7 +158,7 @@ fn dispatch_udp_probe_non_raw( #[instrument(skip(probe))] pub fn dispatch_tcp_probe( - probe: Probe, + probe: &Probe, src_addr: Ipv6Addr, dest_addr: Ipv6Addr, ) -> TraceResult { @@ -175,6 +175,7 @@ pub fn dispatch_tcp_probe( pub fn recv_icmp_probe( recv_socket: &mut S, protocol: TracerProtocol, + icmp_extensions: bool, ) -> TraceResult> { let mut buf = [0_u8; MAX_PACKET_SIZE]; match recv_socket.recv_from(&mut buf) { @@ -184,7 +185,12 @@ pub fn recv_icmp_probe( SocketAddr::V6(addr) => addr.ip(), SocketAddr::V4(_) => panic!(), }; - Ok(extract_probe_resp(protocol, &icmp_v6, *src_addr)?) + Ok(extract_probe_resp( + protocol, + icmp_extensions, + &icmp_v6, + *src_addr, + )?) } Err(err) => match err.kind() { ErrorKind::WouldBlock => Ok(None), @@ -221,11 +227,10 @@ pub fn recv_tcp_socket( } if platform::is_host_unreachable_error(code) { let error_addr = tcp_socket.icmp_error_info()?; - return Ok(Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( - SystemTime::now(), - error_addr, - resp_seq, - )))); + return Ok(Some(ProbeResponse::TimeExceeded( + ProbeResponseData::new(SystemTime::now(), error_addr, resp_seq), + None, + ))); } } } @@ -288,6 +293,7 @@ fn udp_payload_size(packet_size: usize) -> usize { fn extract_probe_resp( protocol: TracerProtocol, + icmp_extensions: bool, icmp_v6: &IcmpPacket<'_>, src: Ipv6Addr, ) -> TraceResult> { @@ -297,15 +303,28 @@ fn extract_probe_resp( IcmpType::TimeExceeded => { let packet = TimeExceededPacket::new_view(icmp_v6.packet()).req()?; let nested_ipv6 = Ipv6Packet::new_view(packet.payload()).req()?; + let extension = if icmp_extensions { + packet.extension().map(Extensions::try_from).transpose()? + } else { + None + }; extract_probe_resp_seq(&nested_ipv6, protocol)?.map(|resp_seq| { - ProbeResponse::TimeExceeded(ProbeResponseData::new(recv, ip, resp_seq)) + ProbeResponse::TimeExceeded(ProbeResponseData::new(recv, ip, resp_seq), extension) }) } IcmpType::DestinationUnreachable => { let packet = DestinationUnreachablePacket::new_view(icmp_v6.packet()).req()?; let nested_ipv6 = Ipv6Packet::new_view(packet.payload()).req()?; + let extension = if icmp_extensions { + packet.extension().map(Extensions::try_from).transpose()? + } else { + None + }; extract_probe_resp_seq(&nested_ipv6, protocol)?.map(|resp_seq| { - ProbeResponse::DestinationUnreachable(ProbeResponseData::new(recv, ip, resp_seq)) + ProbeResponse::DestinationUnreachable( + ProbeResponseData::new(recv, ip, resp_seq), + extension, + ) }) } IcmpType::EchoReply => match protocol { diff --git a/src/tracing/packet.rs b/src/tracing/packet.rs index ab9f5616..18d886e7 100644 --- a/src/tracing/packet.rs +++ b/src/tracing/packet.rs @@ -9,6 +9,9 @@ pub mod icmpv4; /// `ICMPv6` packets. pub mod icmpv6; +/// `ICMP` extensions. +pub mod icmp_extension; + /// `IPv4` packets. pub mod ipv4; diff --git a/src/tracing/packet/icmp_extension.rs b/src/tracing/packet/icmp_extension.rs new file mode 100644 index 00000000..78cc8ef6 --- /dev/null +++ b/src/tracing/packet/icmp_extension.rs @@ -0,0 +1,1090 @@ +pub mod extension_structure { + use crate::tracing::packet::buffer::Buffer; + use crate::tracing::packet::icmp_extension::extension_object::ExtensionObjectPacket; + + /// Represents an ICMP `ExtensionsPacket` pseudo object. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct ExtensionsPacket<'a> { + buf: Buffer<'a>, + } + + impl<'a> ExtensionsPacket<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + + #[must_use] + pub fn header(&self) -> &[u8] { + &self.buf.as_slice()[..Self::minimum_packet_size()] + } + + /// An iterator of Extension Objects contained within this `ExtensionsPacket`. + #[must_use] + pub fn objects(&self) -> ExtensionObjectIter<'_> { + ExtensionObjectIter::new(&self.buf) + } + } + + pub struct ExtensionObjectIter<'a> { + buf: &'a Buffer<'a>, + offset: usize, + } + + impl<'a> ExtensionObjectIter<'a> { + #[must_use] + pub fn new(buf: &'a Buffer<'_>) -> Self { + Self { + buf, + offset: ExtensionsPacket::minimum_packet_size(), + } + } + } + + impl<'a> Iterator for ExtensionObjectIter<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + if self.offset >= self.buf.as_slice().len() { + None + } else { + let object_bytes = &self.buf.as_slice()[self.offset..]; + if let Some(object) = ExtensionObjectPacket::new_view(object_bytes) { + let length = object.get_length(); + // If a malformed extension object has a length of 0 then we end iteration. + if length == 0 { + return None; + } + self.offset += usize::from(length); + Some(object_bytes) + } else { + None + } + } + } + } + + #[cfg(test)] + mod tests { + use super::*; + use crate::tracing::packet::icmp_extension::extension_header::ExtensionHeaderPacket; + use crate::tracing::packet::icmp_extension::extension_object::{ + ClassNum, ClassSubType, ExtensionObjectPacket, + }; + + #[test] + fn test_header() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extensions = ExtensionsPacket::new_view(&buf).unwrap(); + let header = ExtensionHeaderPacket::new_view(extensions.header()).unwrap(); + assert_eq!(2, header.get_version()); + assert_eq!(0x993A, header.get_checksum()); + } + + #[test] + fn test_object_iterator() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extensions = ExtensionsPacket::new_view(&buf).unwrap(); + let mut object_iter = extensions.objects(); + let object_bytes = object_iter.next().unwrap(); + let object = ExtensionObjectPacket::new_view(object_bytes).unwrap(); + assert_eq!(8, object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + object.get_class_num() + ); + assert_eq!(ClassSubType(1), object.get_class_subtype()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], object.payload()); + assert!(object_iter.next().is_none()); + } + + #[test] + fn test_object_iterator_zero_length() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x00, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extensions = ExtensionsPacket::new_view(&buf).unwrap(); + let mut object_iter = extensions.objects(); + assert!(object_iter.next().is_none()); + } + } +} + +pub mod extension_header { + use crate::tracing::packet::buffer::Buffer; + use std::fmt::{Debug, Formatter}; + + const VERSION_OFFSET: usize = 0; + const CHECKSUM_OFFSET: usize = 2; + + /// Represents an ICMP `ExtensionHeaderPacket`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct ExtensionHeaderPacket<'a> { + buf: Buffer<'a>, + } + + impl<'a> ExtensionHeaderPacket<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn get_version(&self) -> u8 { + (self.buf.read(VERSION_OFFSET) & 0xf0) >> 4 + } + + #[must_use] + pub fn get_checksum(&self) -> u16 { + u16::from_be_bytes(self.buf.get_bytes(CHECKSUM_OFFSET)) + } + + pub fn set_version(&mut self, val: u8) { + *self.buf.write(VERSION_OFFSET) = + (self.buf.read(VERSION_OFFSET) & 0xf) | ((val & 0xf) << 4); + } + + pub fn set_checksum(&mut self, val: u16) { + self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + } + + impl Debug for ExtensionHeaderPacket<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtensionHeader") + .field("version", &self.get_version()) + .field("checksum", &self.get_checksum()) + .finish() + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_version() { + let mut buf = [0_u8; ExtensionHeaderPacket::minimum_packet_size()]; + let mut extension = ExtensionHeaderPacket::new(&mut buf).unwrap(); + extension.set_version(0); + assert_eq!(0, extension.get_version()); + assert_eq!([0x00], extension.packet()[0..1]); + extension.set_version(2); + assert_eq!(2, extension.get_version()); + assert_eq!([0x20], extension.packet()[0..1]); + extension.set_version(15); + assert_eq!(15, extension.get_version()); + assert_eq!([0xF0], extension.packet()[0..1]); + } + + #[test] + fn test_checksum() { + let mut buf = [0_u8; ExtensionHeaderPacket::minimum_packet_size()]; + let mut extension = ExtensionHeaderPacket::new(&mut buf).unwrap(); + extension.set_checksum(0); + assert_eq!(0, extension.get_checksum()); + assert_eq!([0x00, 0x00], extension.packet()[2..=3]); + extension.set_checksum(1999); + assert_eq!(1999, extension.get_checksum()); + assert_eq!([0x07, 0xCF], extension.packet()[2..=3]); + extension.set_checksum(39226); + assert_eq!(39226, extension.get_checksum()); + assert_eq!([0x99, 0x3A], extension.packet()[2..=3]); + extension.set_checksum(u16::MAX); + assert_eq!(u16::MAX, extension.get_checksum()); + assert_eq!([0xFF, 0xFF], extension.packet()[2..=3]); + } + + #[test] + fn test_extension_header_view() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extension = ExtensionHeaderPacket::new_view(&buf).unwrap(); + assert_eq!(2, extension.get_version()); + assert_eq!(0x993A, extension.get_checksum()); + } + } +} + +pub mod extension_object { + use crate::tracing::packet::buffer::Buffer; + use crate::tracing::packet::fmt_payload; + use std::fmt::{Debug, Formatter}; + + /// The ICMP Extension Object Class Num. + #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] + pub enum ClassNum { + MultiProtocolLabelSwitchingLabelStack, + InterfaceInformationObject, + InterfaceIdentificationObject, + ExtendedInformation, + Other(u8), + } + + impl ClassNum { + #[must_use] + pub fn id(&self) -> u8 { + match self { + Self::MultiProtocolLabelSwitchingLabelStack => 1, + Self::InterfaceInformationObject => 2, + Self::InterfaceIdentificationObject => 3, + Self::ExtendedInformation => 4, + Self::Other(id) => *id, + } + } + } + + impl From for ClassNum { + fn from(val: u8) -> Self { + match val { + 1 => Self::MultiProtocolLabelSwitchingLabelStack, + 2 => Self::InterfaceInformationObject, + 3 => Self::InterfaceIdentificationObject, + 4 => Self::ExtendedInformation, + id => Self::Other(id), + } + } + } + + /// The ICMP Extension Object Class Sub-type. + #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] + pub struct ClassSubType(pub u8); + + impl From for ClassSubType { + fn from(val: u8) -> Self { + Self(val) + } + } + + const LENGTH_OFFSET: usize = 0; + const CLASS_NUM_OFFSET: usize = 2; + const CLASS_SUBTYPE_OFFSET: usize = 3; + + /// Represents an ICMP `ExtensionObjectPacket`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct ExtensionObjectPacket<'a> { + buf: Buffer<'a>, + } + + impl<'a> ExtensionObjectPacket<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + pub fn set_length(&mut self, val: u16) { + self.buf.set_bytes(LENGTH_OFFSET, val.to_be_bytes()); + } + + pub fn set_class_num(&mut self, val: ClassNum) { + *self.buf.write(CLASS_NUM_OFFSET) = val.id(); + } + + pub fn set_class_subtype(&mut self, val: ClassSubType) { + *self.buf.write(CLASS_SUBTYPE_OFFSET) = val.0; + } + + pub fn set_payload(&mut self, vals: &[u8]) { + let current_offset = Self::minimum_packet_size(); + self.buf.as_slice_mut()[current_offset..current_offset + vals.len()] + .copy_from_slice(vals); + } + + #[must_use] + pub fn get_length(&self) -> u16 { + u16::from_be_bytes(self.buf.get_bytes(LENGTH_OFFSET)) + } + + #[must_use] + pub fn get_class_num(&self) -> ClassNum { + ClassNum::from(self.buf.read(CLASS_NUM_OFFSET)) + } + + #[must_use] + pub fn get_class_subtype(&self) -> ClassSubType { + ClassSubType::from(self.buf.read(CLASS_SUBTYPE_OFFSET)) + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + + #[must_use] + pub fn payload(&self) -> &[u8] { + &self.buf.as_slice()[Self::minimum_packet_size()..usize::from(self.get_length())] + } + } + + impl Debug for ExtensionObjectPacket<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtensionObject") + .field("length", &self.get_length()) + .field("class_num", &self.get_class_num()) + .field("class_subtype", &self.get_class_subtype()) + .field("payload", &fmt_payload(self.payload())) + .finish() + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_length() { + let mut buf = [0_u8; ExtensionObjectPacket::minimum_packet_size()]; + let mut extension = ExtensionObjectPacket::new(&mut buf).unwrap(); + extension.set_length(0); + assert_eq!(0, extension.get_length()); + assert_eq!([0x00, 0x00], extension.packet()[0..=1]); + extension.set_length(8); + assert_eq!(8, extension.get_length()); + assert_eq!([0x00, 0x08], extension.packet()[0..=1]); + extension.set_length(u16::MAX); + assert_eq!(u16::MAX, extension.get_length()); + assert_eq!([0xFF, 0xFF], extension.packet()[0..=1]); + } + + #[test] + fn test_class_num() { + let mut buf = [0_u8; ExtensionObjectPacket::minimum_packet_size()]; + let mut extension = ExtensionObjectPacket::new(&mut buf).unwrap(); + extension.set_class_num(ClassNum::MultiProtocolLabelSwitchingLabelStack); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + extension.get_class_num() + ); + assert_eq!([0x01], extension.packet()[2..3]); + extension.set_class_num(ClassNum::InterfaceInformationObject); + assert_eq!( + ClassNum::InterfaceInformationObject, + extension.get_class_num() + ); + assert_eq!([0x02], extension.packet()[2..3]); + extension.set_class_num(ClassNum::InterfaceIdentificationObject); + assert_eq!( + ClassNum::InterfaceIdentificationObject, + extension.get_class_num() + ); + assert_eq!([0x03], extension.packet()[2..3]); + extension.set_class_num(ClassNum::ExtendedInformation); + assert_eq!(ClassNum::ExtendedInformation, extension.get_class_num()); + assert_eq!([0x04], extension.packet()[2..3]); + extension.set_class_num(ClassNum::Other(255)); + assert_eq!(ClassNum::Other(255), extension.get_class_num()); + assert_eq!([0xFF], extension.packet()[2..3]); + } + + #[test] + fn test_class_subtype() { + let mut buf = [0_u8; ExtensionObjectPacket::minimum_packet_size()]; + let mut extension = ExtensionObjectPacket::new(&mut buf).unwrap(); + extension.set_class_subtype(ClassSubType(0)); + assert_eq!(ClassSubType(0), extension.get_class_subtype()); + assert_eq!([0x00], extension.packet()[3..4]); + extension.set_class_subtype(ClassSubType(1)); + assert_eq!(ClassSubType(1), extension.get_class_subtype()); + assert_eq!([0x01], extension.packet()[3..4]); + extension.set_class_subtype(ClassSubType(255)); + assert_eq!(ClassSubType(255), extension.get_class_subtype()); + assert_eq!([0xff], extension.packet()[3..4]); + } + + #[test] + fn test_extension_header_view() { + let buf = [0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01]; + let object = ExtensionObjectPacket::new_view(&buf).unwrap(); + assert_eq!(8, object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + object.get_class_num() + ); + assert_eq!(ClassSubType(1), object.get_class_subtype()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], object.payload()); + } + } +} + +pub mod mpls_label_stack { + use crate::tracing::packet::buffer::Buffer; + use crate::tracing::packet::icmp_extension::mpls_label_stack_member::MplsLabelStackMemberPacket; + + /// Represents an ICMP `MplsLabelStackPacket`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct MplsLabelStackPacket<'a> { + buf: Buffer<'a>, + } + + impl<'a> MplsLabelStackPacket<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + + #[must_use] + pub fn members(&self) -> MplsLabelStackIter<'_> { + MplsLabelStackIter::new(&self.buf) + } + } + + pub struct MplsLabelStackIter<'a> { + buf: &'a Buffer<'a>, + offset: usize, + bos: u8, + } + + impl<'a> MplsLabelStackIter<'a> { + #[must_use] + pub fn new(buf: &'a Buffer<'_>) -> Self { + Self { + buf, + offset: 0, + bos: 0, + } + } + } + + impl<'a> Iterator for MplsLabelStackIter<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + if self.bos > 0 || self.offset >= self.buf.as_slice().len() { + None + } else { + let member_bytes = &self.buf.as_slice()[self.offset..]; + if let Some(member) = MplsLabelStackMemberPacket::new_view(member_bytes) { + self.bos = member.get_bos(); + self.offset += MplsLabelStackMemberPacket::minimum_packet_size(); + Some(member_bytes) + } else { + None + } + } + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_stack_member_iterator() { + let buf = [0x04, 0xbb, 0x41, 0x01]; + let stack = MplsLabelStackPacket::new_view(&buf).unwrap(); + let mut member_iter = stack.members(); + let member_bytes = member_iter.next().unwrap(); + let member = MplsLabelStackMemberPacket::new_view(member_bytes).unwrap(); + assert_eq!(19380, member.get_label()); + assert_eq!(0, member.get_exp()); + assert_eq!(1, member.get_bos()); + assert_eq!(1, member.get_ttl()); + assert!(member_iter.next().is_none()); + } + } +} + +pub mod mpls_label_stack_member { + use crate::tracing::packet::buffer::Buffer; + use std::fmt::{Debug, Formatter}; + + const LABEL_OFFSET: usize = 0; + const EXP_OFFSET: usize = 2; + const BOS_OFFSET: usize = 2; + const TTL_OFFSET: usize = 3; + + /// Represents an ICMP `MplsLabelStackMemberPacket`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct MplsLabelStackMemberPacket<'a> { + buf: Buffer<'a>, + } + + impl<'a> MplsLabelStackMemberPacket<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn get_label(&self) -> u32 { + u32::from_be_bytes([ + 0x0, + self.buf.read(LABEL_OFFSET), + self.buf.read(LABEL_OFFSET + 1), + self.buf.read(LABEL_OFFSET + 2), + ]) >> 4 + } + + #[must_use] + pub fn get_exp(&self) -> u8 { + (self.buf.read(EXP_OFFSET) & 0x0e) >> 1 + } + + #[must_use] + pub fn get_bos(&self) -> u8 { + self.buf.read(BOS_OFFSET) & 0x01 + } + + #[must_use] + pub fn get_ttl(&self) -> u8 { + self.buf.read(TTL_OFFSET) + } + + pub fn set_label(&mut self, val: u32) { + let bytes = (val << 4).to_be_bytes(); + *self.buf.write(LABEL_OFFSET) = bytes[1]; + *self.buf.write(LABEL_OFFSET + 1) = bytes[2]; + *self.buf.write(LABEL_OFFSET + 2) = + (self.buf.read(LABEL_OFFSET + 2) & 0x0f) | (bytes[3] & 0xf0); + } + + pub fn set_exp(&mut self, exp: u8) { + *self.buf.write(EXP_OFFSET) = (self.buf.read(EXP_OFFSET) & 0xf1) | ((exp << 1) & 0x0e); + } + + pub fn set_bos(&mut self, bos: u8) { + *self.buf.write(BOS_OFFSET) = (self.buf.read(BOS_OFFSET) & 0xfe) | (bos & 0x01); + } + + pub fn set_ttl(&mut self, ttl: u8) { + *self.buf.write(TTL_OFFSET) = ttl; + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + } + + impl Debug for MplsLabelStackMemberPacket<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MplsLabelStackMember") + .field("label", &self.get_label()) + .field("exp", &self.get_exp()) + .field("bos", &self.get_bos()) + .field("ttl", &self.get_ttl()) + .finish() + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_label() { + let mut buf = [0_u8; MplsLabelStackMemberPacket::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMemberPacket::new(&mut buf).unwrap(); + mpls_extension.set_label(0); + assert_eq!(0, mpls_extension.get_label()); + assert_eq!([0x00, 0x00, 0x00], mpls_extension.packet()[0..3]); + mpls_extension.set_label(19380); + assert_eq!(19380, mpls_extension.get_label()); + assert_eq!([0x04, 0xbb, 0x40], mpls_extension.packet()[0..3]); + mpls_extension.set_label(1_048_575); + assert_eq!(1_048_575, mpls_extension.get_label()); + assert_eq!([0xff, 0xff, 0xf0], mpls_extension.packet()[0..3]); + } + + #[test] + fn test_exp() { + let mut buf = [0_u8; MplsLabelStackMemberPacket::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMemberPacket::new(&mut buf).unwrap(); + mpls_extension.set_exp(0); + assert_eq!(0, mpls_extension.get_exp()); + assert_eq!([0x00], mpls_extension.packet()[2..3]); + mpls_extension.set_exp(7); + assert_eq!(7, mpls_extension.get_exp()); + assert_eq!([0x0e], mpls_extension.packet()[2..3]); + } + + #[test] + fn test_bos() { + let mut buf = [0_u8; MplsLabelStackMemberPacket::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMemberPacket::new(&mut buf).unwrap(); + mpls_extension.set_bos(0); + assert_eq!(0, mpls_extension.get_bos()); + assert_eq!([0x00], mpls_extension.packet()[2..3]); + mpls_extension.set_bos(1); + assert_eq!(1, mpls_extension.get_bos()); + assert_eq!([0x01], mpls_extension.packet()[2..3]); + } + + #[test] + fn test_ttl() { + let mut buf = [0_u8; MplsLabelStackMemberPacket::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMemberPacket::new(&mut buf).unwrap(); + mpls_extension.set_ttl(0); + assert_eq!(0, mpls_extension.get_ttl()); + assert_eq!([0x00], mpls_extension.packet()[3..4]); + mpls_extension.set_ttl(1); + assert_eq!(1, mpls_extension.get_ttl()); + assert_eq!([0x01], mpls_extension.packet()[3..4]); + mpls_extension.set_ttl(255); + assert_eq!(255, mpls_extension.get_ttl()); + assert_eq!([0xff], mpls_extension.packet()[3..4]); + } + + #[test] + fn test_combined() { + let mut buf = [0_u8; MplsLabelStackMemberPacket::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMemberPacket::new(&mut buf).unwrap(); + mpls_extension.set_label(19380); + mpls_extension.set_exp(0); + mpls_extension.set_bos(1); + mpls_extension.set_ttl(1); + assert_eq!(19380, mpls_extension.get_label()); + assert_eq!(0, mpls_extension.get_exp()); + assert_eq!(1, mpls_extension.get_bos()); + assert_eq!(1, mpls_extension.get_ttl()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], mpls_extension.packet()[0..4]); + mpls_extension.set_label(1_048_575); + mpls_extension.set_exp(7); + mpls_extension.set_bos(1); + mpls_extension.set_ttl(255); + assert_eq!(1_048_575, mpls_extension.get_label()); + assert_eq!(7, mpls_extension.get_exp()); + assert_eq!(1, mpls_extension.get_bos()); + assert_eq!(255, mpls_extension.get_ttl()); + assert_eq!([0xff, 0xff, 0xff, 0xff], mpls_extension.packet()[0..4]); + } + + #[test] + fn test_view() { + let buf = [0x04, 0xbb, 0x41, 0x01]; + let object = MplsLabelStackMemberPacket::new_view(&buf).unwrap(); + assert_eq!(19380, object.get_label()); + assert_eq!(0, object.get_exp()); + assert_eq!(1, object.get_bos()); + assert_eq!(1, object.get_ttl()); + } + } +} + +pub mod extension_splitter { + use crate::tracing::packet::icmp_extension::extension_header::ExtensionHeaderPacket; + + const ICMP_ORIG_DATAGRAM_MIN_LENGTH: usize = 128; + const MIN_HEADER: usize = ExtensionHeaderPacket::minimum_packet_size(); + + /// Separate an ICMP payload from ICMP extensions as defined in rfc4884. + /// + /// Applies to `TimeExceeded` and `DestinationUnreachable` ICMP messages only. + /// + /// From rfc4884 (section 3) entitled "Summary of Changes to ICMP": + /// + /// "When the ICMP Extension Structure is appended to an ICMP message + /// and that ICMP message contains an "original datagram" field, the + /// "original datagram" field MUST contain at least 128 octets." + #[must_use] + pub fn split(rfc4884_length: u8, icmp_payload: &[u8]) -> (&[u8], Option<&[u8]>) { + let length = usize::from(rfc4884_length * 4); + if length > icmp_payload.len() { + return (&[], None); + } + if icmp_payload.len() > ICMP_ORIG_DATAGRAM_MIN_LENGTH { + if length > ICMP_ORIG_DATAGRAM_MIN_LENGTH { + // a 'compliant' ICMP extension longer than 128 octets. + do_split(length, icmp_payload) + } else if length > 0 { + // a 'compliant' ICMP extension padded to at least 128 octets. + match do_split(ICMP_ORIG_DATAGRAM_MIN_LENGTH, icmp_payload) { + (&[], ext) => (&[], ext), + (payload, extension) => (&payload[..length], extension), + } + } else { + // a 'non-compliant' ICMP extension padded to 128 octets. + do_split(ICMP_ORIG_DATAGRAM_MIN_LENGTH, icmp_payload) + } + } else { + // no extension present + (icmp_payload, None) + } + } + + /// Split the ICMP payload into payload and extension parts. + /// + /// If the extension is not empty and is at least as long as the minimum + /// extension header then Some(extension) is returned. + /// + /// If the extension is empty then None is returned. + /// + /// If the extension is non-empty but not as long as the minimum extension + /// header then the payload is invalid and so we return an empty payload + /// and extension. + fn do_split(index: usize, icmp_payload: &[u8]) -> (&[u8], Option<&[u8]>) { + match icmp_payload.split_at(index) { + (payload, extension) if extension.len() >= MIN_HEADER => (payload, Some(extension)), + (payload, extension) if extension.is_empty() => (payload, None), + _ => (&[], None), + } + } + + #[cfg(test)] + mod tests { + use crate::tracing::packet::icmp_extension::extension_header::ExtensionHeaderPacket; + use crate::tracing::packet::icmp_extension::extension_object::{ + ClassNum, ClassSubType, ExtensionObjectPacket, + }; + use crate::tracing::packet::icmp_extension::extension_splitter::split; + use crate::tracing::packet::icmp_extension::extension_structure::ExtensionsPacket; + use crate::tracing::packet::icmp_extension::mpls_label_stack::MplsLabelStackPacket; + use crate::tracing::packet::icmp_extension::mpls_label_stack_member::MplsLabelStackMemberPacket; + use crate::tracing::packet::icmpv4::echo_request::EchoRequestPacket; + use crate::tracing::packet::icmpv4::time_exceeded::TimeExceededPacket; + use crate::tracing::packet::icmpv4::{IcmpCode, IcmpType}; + use crate::tracing::packet::ipv4::Ipv4Packet; + use std::net::Ipv4Addr; + + #[test] + fn test_split_empty_payload() { + let icmp_payload: [u8; 0] = []; + let (payload, extension) = split(0, &icmp_payload); + assert!(payload.is_empty() && extension.is_none()); + } + + // Test ICMP payload which is 12 bytes and has rfc4884 length of 3 (12 + // bytes) so payload is 12 bytes and there is no extension. + #[test] + fn test_split_payload_with_compliant_empty_extension() { + let rfc4884_length = 3; + let icmp_payload: [u8; 12] = [0; 12]; + let (payload, extension) = split(rfc4884_length, &icmp_payload); + assert_eq!(payload, &[0; 12]); + assert_eq!(extension, None); + } + + // Test ICMP payload with a minimal compliant extension. + #[test] + fn test_split_payload_with_compliant_minimal_extension() { + let icmp_payload: [u8; 132] = [0; 132]; + let (payload, extension) = split(32, &icmp_payload); + assert_eq!(payload, &[0; 128]); + assert_eq!(extension, Some([0; 4].as_slice())); + } + + // Test handling of an ICMP payload which has an rfc4884 length that + // is longer than the original datagram. + #[test] + fn test_split_payload_with_invalid_rfc4884_length() { + let icmp_payload: [u8; 128] = [0; 128]; + let (payload, extension) = split(33, &icmp_payload); + assert!(payload.is_empty() && extension.is_none()); + } + + // Test handling of an ICMP payload which has a compliant extension + // which is not as long as the minimum size for an ICMP extension + // header (4 bytes). + #[test] + fn test_split_payload_with_compliant_invalid_extension() { + let icmp_payload: [u8; 129] = [0; 129]; + let (payload, extension) = split(32, &icmp_payload); + assert!(payload.is_empty() && extension.is_none()); + } + + // This ICMP TimeExceeded packet which contains single `MPLS` extension + // object with a single member. The packet does not have a `length` + // field and is therefore rfc4884 non-complaint. + #[test] + #[allow(clippy::cognitive_complexity)] + fn test_split_extension_ipv4_time_exceeded_non_compliant_mpls() { + let buf = hex_literal::hex!( + " + 0b 00 f4 ff 00 00 00 00 45 00 00 54 cc 1c 40 00 + 01 01 b5 f4 c0 a8 01 15 5d b8 d8 22 08 00 0f e3 + 65 da 82 42 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 20 00 99 3a 00 08 01 01 + 04 bb 41 01 + " + ); + let time_exceeded_packet = TimeExceededPacket::new_view(&buf).unwrap(); + assert_eq!(IcmpType::TimeExceeded, time_exceeded_packet.get_icmp_type()); + assert_eq!(IcmpCode(0), time_exceeded_packet.get_icmp_code()); + assert_eq!(62719, time_exceeded_packet.get_checksum()); + assert_eq!(0, time_exceeded_packet.get_length()); + assert_eq!(&buf[8..136], time_exceeded_packet.payload()); + assert_eq!(Some(&buf[136..]), time_exceeded_packet.extension()); + + let nested_ipv4 = Ipv4Packet::new_view(time_exceeded_packet.payload()).unwrap(); + assert_eq!(Ipv4Addr::from([192, 168, 1, 21]), nested_ipv4.get_source()); + assert_eq!( + Ipv4Addr::from([93, 184, 216, 34]), + nested_ipv4.get_destination() + ); + assert_eq!(&buf[28..136], nested_ipv4.payload()); + + let nested_echo = EchoRequestPacket::new_view(nested_ipv4.payload()).unwrap(); + assert_eq!(IcmpCode(0), nested_echo.get_icmp_code()); + assert_eq!(IcmpType::EchoRequest, nested_echo.get_icmp_type()); + assert_eq!(0x0FE3, nested_echo.get_checksum()); + assert_eq!(26074, nested_echo.get_identifier()); + assert_eq!(33346, nested_echo.get_sequence()); + assert_eq!(&buf[36..136], nested_echo.payload()); + + let extensions = + ExtensionsPacket::new_view(time_exceeded_packet.extension().unwrap()).unwrap(); + + let extension_header = ExtensionHeaderPacket::new_view(extensions.header()).unwrap(); + assert_eq!(2, extension_header.get_version()); + assert_eq!(0x993A, extension_header.get_checksum()); + + let object_bytes = extensions.objects().next().unwrap(); + let extension_object = ExtensionObjectPacket::new_view(object_bytes).unwrap(); + + assert_eq!(8, extension_object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + extension_object.get_class_num() + ); + assert_eq!(ClassSubType(1), extension_object.get_class_subtype()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], extension_object.payload()); + + let mpls_stack = MplsLabelStackPacket::new_view(extension_object.payload()).unwrap(); + let mpls_stack_member_bytes = mpls_stack.members().next().unwrap(); + let mpls_stack_member = + MplsLabelStackMemberPacket::new_view(mpls_stack_member_bytes).unwrap(); + assert_eq!(19380, mpls_stack_member.get_label()); + assert_eq!(0, mpls_stack_member.get_exp()); + assert_eq!(1, mpls_stack_member.get_bos()); + assert_eq!(1, mpls_stack_member.get_ttl()); + } + + // This ICMP TimeExceeded packet does not have any ICMP extensions. + // It has a rfc4884 complaint `length` field. + #[test] + fn test_split_extension_ipv4_time_exceeded_compliant_no_extension() { + let buf = hex_literal::hex!( + " + 0b 00 f4 ee 00 11 00 00 45 00 00 54 a2 ee 40 00 + 01 01 df 22 c0 a8 01 15 5d b8 d8 22 08 00 0f e1 + 65 da 82 44 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 + " + ); + let time_exceeded_packet = TimeExceededPacket::new_view(&buf).unwrap(); + assert_eq!(IcmpType::TimeExceeded, time_exceeded_packet.get_icmp_type()); + assert_eq!(IcmpCode(0), time_exceeded_packet.get_icmp_code()); + assert_eq!(62702, time_exceeded_packet.get_checksum()); + assert_eq!(17, time_exceeded_packet.get_length()); + assert_eq!(&buf[8..76], time_exceeded_packet.payload()); + assert_eq!(None, time_exceeded_packet.extension()); + + let nested_ipv4 = Ipv4Packet::new_view(&buf[8..76]).unwrap(); + assert_eq!(Ipv4Addr::from([192, 168, 1, 21]), nested_ipv4.get_source()); + assert_eq!( + Ipv4Addr::from([93, 184, 216, 34]), + nested_ipv4.get_destination() + ); + assert_eq!(&buf[28..76], nested_ipv4.payload()); + + let nested_echo = EchoRequestPacket::new_view(nested_ipv4.payload()).unwrap(); + assert_eq!(IcmpCode(0), nested_echo.get_icmp_code()); + assert_eq!(IcmpType::EchoRequest, nested_echo.get_icmp_type()); + assert_eq!(0x0FE1, nested_echo.get_checksum()); + assert_eq!(26074, nested_echo.get_identifier()); + assert_eq!(33348, nested_echo.get_sequence()); + assert_eq!(&buf[36..76], nested_echo.payload()); + } + + // This is an real example that was observed in the wild whilst testing. + // + // It has a rfc4884 complaint `length` field set to be 17 and so has + // an original datagram if length 68 octet (17 * 4 = 68) but is padded + // to be 128 octets. + // + // See https://github.com/fujiapple852/trippy/issues/804 for further + // discussion and analysis of this case. + #[test] + fn test_split_extension_ipv4_time_exceeded_compliant_extension() { + let buf = hex_literal::hex!( + " + 0b 00 f4 ee 00 11 00 00 45 00 00 54 20 c3 40 00 + 02 01 b5 7e 64 63 08 2a 5d b8 d8 22 08 00 11 8d + 65 83 80 ef 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 20 00 78 56 00 08 01 01 + 65 9f 01 01 + " + ); + let time_exceeded_packet = TimeExceededPacket::new_view(&buf).unwrap(); + assert_eq!(68, time_exceeded_packet.payload().len()); + assert_eq!(12, time_exceeded_packet.extension().unwrap().len()); + let extensions = + ExtensionsPacket::new_view(time_exceeded_packet.extension().unwrap()).unwrap(); + + let extension_header = ExtensionHeaderPacket::new_view(extensions.header()).unwrap(); + assert_eq!(2, extension_header.get_version()); + assert_eq!(0x7856, extension_header.get_checksum()); + + let object_bytes = extensions.objects().next().unwrap(); + let extension_object = ExtensionObjectPacket::new_view(object_bytes).unwrap(); + + assert_eq!(8, extension_object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + extension_object.get_class_num() + ); + assert_eq!(ClassSubType(1), extension_object.get_class_subtype()); + assert_eq!([0x65, 0x9f, 0x01, 0x01], extension_object.payload()); + + let mpls_stack = MplsLabelStackPacket::new_view(extension_object.payload()).unwrap(); + let mpls_stack_member_bytes = mpls_stack.members().next().unwrap(); + let mpls_stack_member = + MplsLabelStackMemberPacket::new_view(mpls_stack_member_bytes).unwrap(); + assert_eq!(416_240, mpls_stack_member.get_label()); + assert_eq!(0, mpls_stack_member.get_exp()); + assert_eq!(1, mpls_stack_member.get_bos()); + assert_eq!(1, mpls_stack_member.get_ttl()); + } + } +} diff --git a/src/tracing/packet/icmpv4.rs b/src/tracing/packet/icmpv4.rs index 577e41ec..77032f0c 100644 --- a/src/tracing/packet/icmpv4.rs +++ b/src/tracing/packet/icmpv4.rs @@ -631,12 +631,14 @@ pub mod echo_reply { pub mod time_exceeded { use crate::tracing::packet::buffer::Buffer; use crate::tracing::packet::fmt_payload; + use crate::tracing::packet::icmp_extension::extension_splitter::split; use crate::tracing::packet::icmpv4::{IcmpCode, IcmpType}; use std::fmt::{Debug, Formatter}; const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; + const LENGTH_OFFSET: usize = 5; /// Represents an ICMP `TimeExceeded` packet. /// @@ -689,6 +691,11 @@ pub mod time_exceeded { u16::from_be_bytes(self.buf.get_bytes(CHECKSUM_OFFSET)) } + #[must_use] + pub fn get_length(&self) -> u8 { + self.buf.read(LENGTH_OFFSET) + } + pub fn set_icmp_type(&mut self, val: IcmpType) { *self.buf.write(TYPE_OFFSET) = val.id(); } @@ -701,6 +708,10 @@ pub mod time_exceeded { self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); } + pub fn set_length(&mut self, val: u8) { + *self.buf.write(LENGTH_OFFSET) = val; + } + pub fn set_payload(&mut self, vals: &[u8]) { let current_offset = Self::minimum_packet_size(); self.buf.as_slice_mut()[current_offset..current_offset + vals.len()] @@ -714,7 +725,20 @@ pub mod time_exceeded { #[must_use] pub fn payload(&self) -> &[u8] { - &self.buf.as_slice()[Self::minimum_packet_size()..] + let (payload, _) = self.split_payload_extension(); + payload + } + + #[must_use] + pub fn extension(&self) -> Option<&[u8]> { + let (_, extension) = self.split_payload_extension(); + extension + } + + fn split_payload_extension(&self) -> (&[u8], Option<&[u8]>) { + let rfc4884_length = self.get_length(); + let icmp_payload = &self.buf.as_slice()[Self::minimum_packet_size()..]; + split(rfc4884_length, icmp_payload) } } @@ -724,6 +748,7 @@ pub mod time_exceeded { .field("icmp_type", &self.get_icmp_type()) .field("icmp_code", &self.get_icmp_code()) .field("checksum", &self.get_checksum()) + .field("length", &self.get_length()) .field("payload", &fmt_payload(self.payload())) .finish() } @@ -784,6 +809,21 @@ pub mod time_exceeded { assert_eq!([0xFF, 0xFF], packet.packet()[2..=3]); } + #[test] + fn test_length() { + let mut buf = [0_u8; TimeExceededPacket::minimum_packet_size()]; + let mut packet = TimeExceededPacket::new(&mut buf).unwrap(); + packet.set_length(0); + assert_eq!(0, packet.get_length()); + assert_eq!([0x00], packet.packet()[5..6]); + packet.set_length(8); + assert_eq!(8, packet.get_length()); + assert_eq!([0x08], packet.packet()[5..6]); + packet.set_length(u8::MAX); + assert_eq!(u8::MAX, packet.get_length()); + assert_eq!([0xFF], packet.packet()[5..6]); + } + #[test] fn test_view() { let buf = [0x0b, 0x00, 0xf4, 0xee, 0x00, 0x11, 0x00, 0x00]; @@ -791,6 +831,7 @@ pub mod time_exceeded { assert_eq!(IcmpType::TimeExceeded, packet.get_icmp_type()); assert_eq!(IcmpCode(0), packet.get_icmp_code()); assert_eq!(62702, packet.get_checksum()); + assert_eq!(17, packet.get_length()); assert!(packet.payload().is_empty()); } } @@ -799,13 +840,14 @@ pub mod time_exceeded { pub mod destination_unreachable { use crate::tracing::packet::buffer::Buffer; use crate::tracing::packet::fmt_payload; + use crate::tracing::packet::icmp_extension::extension_splitter::split; use crate::tracing::packet::icmpv4::{IcmpCode, IcmpType}; use std::fmt::{Debug, Formatter}; const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; - const UNUSED_OFFSET: usize = 4; + const LENGTH_OFFSET: usize = 5; const NEXT_HOP_MTU_OFFSET: usize = 6; /// Represents an ICMP `DestinationUnreachable` packet. @@ -860,8 +902,8 @@ pub mod destination_unreachable { } #[must_use] - pub fn get_unused(&self) -> u16 { - u16::from_be_bytes(self.buf.get_bytes(UNUSED_OFFSET)) + pub fn get_length(&self) -> u8 { + self.buf.read(LENGTH_OFFSET) } #[must_use] @@ -881,8 +923,8 @@ pub mod destination_unreachable { self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); } - pub fn set_unused(&mut self, val: u16) { - self.buf.set_bytes(UNUSED_OFFSET, val.to_be_bytes()); + pub fn set_length(&mut self, val: u8) { + *self.buf.write(LENGTH_OFFSET) = val; } pub fn set_next_hop_mtu(&mut self, val: u16) { @@ -902,7 +944,20 @@ pub mod destination_unreachable { #[must_use] pub fn payload(&self) -> &[u8] { - &self.buf.as_slice()[Self::minimum_packet_size()..] + let (payload, _) = self.split_payload_extension(); + payload + } + + #[must_use] + pub fn extension(&self) -> Option<&[u8]> { + let (_, extension) = self.split_payload_extension(); + extension + } + + fn split_payload_extension(&self) -> (&[u8], Option<&[u8]>) { + let rfc4884_length = self.get_length(); + let icmp_payload = &self.buf.as_slice()[Self::minimum_packet_size()..]; + split(rfc4884_length, icmp_payload) } } @@ -912,7 +967,7 @@ pub mod destination_unreachable { .field("icmp_type", &self.get_icmp_type()) .field("icmp_code", &self.get_icmp_code()) .field("checksum", &self.get_checksum()) - .field("unused", &self.get_unused()) + .field("length", &self.get_length()) .field("next_hop_mtu", &self.get_next_hop_mtu()) .field("payload", &fmt_payload(self.payload())) .finish() @@ -974,6 +1029,21 @@ pub mod destination_unreachable { assert_eq!([0xFF, 0xFF], packet.packet()[2..=3]); } + #[test] + fn test_length() { + let mut buf = [0_u8; DestinationUnreachablePacket::minimum_packet_size()]; + let mut packet = DestinationUnreachablePacket::new(&mut buf).unwrap(); + packet.set_length(0); + assert_eq!(0, packet.get_length()); + assert_eq!([0x00], packet.packet()[5..6]); + packet.set_length(8); + assert_eq!(8, packet.get_length()); + assert_eq!([0x08], packet.packet()[5..6]); + packet.set_length(u8::MAX); + assert_eq!(u8::MAX, packet.get_length()); + assert_eq!([0xFF], packet.packet()[5..6]); + } + #[test] fn test_view() { let buf = [0x03, 0x03, 0xdf, 0xdc, 0x00, 0x00, 0x00, 0x00]; @@ -981,6 +1051,7 @@ pub mod destination_unreachable { assert_eq!(IcmpType::DestinationUnreachable, packet.get_icmp_type()); assert_eq!(IcmpCode(3), packet.get_icmp_code()); assert_eq!(57308, packet.get_checksum()); + assert_eq!(0, packet.get_length()); assert!(packet.payload().is_empty()); } } diff --git a/src/tracing/packet/icmpv6.rs b/src/tracing/packet/icmpv6.rs index 704f9e41..ef6acc19 100644 --- a/src/tracing/packet/icmpv6.rs +++ b/src/tracing/packet/icmpv6.rs @@ -631,12 +631,14 @@ pub mod echo_reply { pub mod time_exceeded { use crate::tracing::packet::buffer::Buffer; use crate::tracing::packet::fmt_payload; + use crate::tracing::packet::icmp_extension::extension_splitter::split; use crate::tracing::packet::icmpv6::{IcmpCode, IcmpType}; use std::fmt::{Debug, Formatter}; const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; + const LENGTH_OFFSET: usize = 4; /// Represents an ICMP `TimeExceeded` packet. /// @@ -689,6 +691,11 @@ pub mod time_exceeded { u16::from_be_bytes(self.buf.get_bytes(CHECKSUM_OFFSET)) } + #[must_use] + pub fn get_length(&self) -> u8 { + self.buf.read(LENGTH_OFFSET) + } + pub fn set_icmp_type(&mut self, val: IcmpType) { *self.buf.write(TYPE_OFFSET) = val.id(); } @@ -701,6 +708,10 @@ pub mod time_exceeded { self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); } + pub fn set_length(&mut self, val: u8) { + *self.buf.write(LENGTH_OFFSET) = val; + } + pub fn set_payload(&mut self, vals: &[u8]) { let current_offset = Self::minimum_packet_size(); self.buf.as_slice_mut()[current_offset..current_offset + vals.len()] @@ -714,7 +725,20 @@ pub mod time_exceeded { #[must_use] pub fn payload(&self) -> &[u8] { - &self.buf.as_slice()[Self::minimum_packet_size()..] + let (payload, _) = self.split_payload_extension(); + payload + } + + #[must_use] + pub fn extension(&self) -> Option<&[u8]> { + let (_, extension) = self.split_payload_extension(); + extension + } + + fn split_payload_extension(&self) -> (&[u8], Option<&[u8]>) { + let rfc4884_length = self.get_length(); + let icmp_payload = &self.buf.as_slice()[Self::minimum_packet_size()..]; + split(rfc4884_length, icmp_payload) } } @@ -724,6 +748,7 @@ pub mod time_exceeded { .field("icmp_type", &self.get_icmp_type()) .field("icmp_code", &self.get_icmp_code()) .field("checksum", &self.get_checksum()) + .field("length", &self.get_length()) .field("payload", &fmt_payload(self.payload())) .finish() } @@ -784,13 +809,29 @@ pub mod time_exceeded { assert_eq!([0xFF, 0xFF], packet.packet()[2..=3]); } + #[test] + fn test_length() { + let mut buf = [0_u8; TimeExceededPacket::minimum_packet_size()]; + let mut packet = TimeExceededPacket::new(&mut buf).unwrap(); + packet.set_length(0); + assert_eq!(0, packet.get_length()); + assert_eq!([0x00], packet.packet()[4..5]); + packet.set_length(8); + assert_eq!(8, packet.get_length()); + assert_eq!([0x08], packet.packet()[4..5]); + packet.set_length(u8::MAX); + assert_eq!(u8::MAX, packet.get_length()); + assert_eq!([0xFF], packet.packet()[4..5]); + } + #[test] fn test_view() { - let buf = [0x03, 0x00, 0xf4, 0xee, 0x00, 0x11, 0x00, 0x00]; + let buf = [0x03, 0x00, 0xf4, 0xee, 0x11, 0x00, 0x00, 0x00]; let packet = TimeExceededPacket::new_view(&buf).unwrap(); assert_eq!(IcmpType::TimeExceeded, packet.get_icmp_type()); assert_eq!(IcmpCode(0), packet.get_icmp_code()); assert_eq!(62702, packet.get_checksum()); + assert_eq!(17, packet.get_length()); assert!(packet.payload().is_empty()); } } @@ -799,13 +840,14 @@ pub mod time_exceeded { pub mod destination_unreachable { use crate::tracing::packet::buffer::Buffer; use crate::tracing::packet::fmt_payload; + use crate::tracing::packet::icmp_extension::extension_splitter::split; use crate::tracing::packet::icmpv6::{IcmpCode, IcmpType}; use std::fmt::{Debug, Formatter}; const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; - const UNUSED_OFFSET: usize = 4; + const LENGTH_OFFSET: usize = 4; const NEXT_HOP_MTU_OFFSET: usize = 6; /// Represents an ICMP `DestinationUnreachable` packet. @@ -860,8 +902,8 @@ pub mod destination_unreachable { } #[must_use] - pub fn get_unused(&self) -> u16 { - u16::from_be_bytes(self.buf.get_bytes(UNUSED_OFFSET)) + pub fn get_length(&self) -> u8 { + self.buf.read(LENGTH_OFFSET) } #[must_use] @@ -881,8 +923,8 @@ pub mod destination_unreachable { self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); } - pub fn set_unused(&mut self, val: u16) { - self.buf.set_bytes(UNUSED_OFFSET, val.to_be_bytes()); + pub fn set_length(&mut self, val: u8) { + *self.buf.write(LENGTH_OFFSET) = val; } pub fn set_next_hop_mtu(&mut self, val: u16) { @@ -902,7 +944,20 @@ pub mod destination_unreachable { #[must_use] pub fn payload(&self) -> &[u8] { - &self.buf.as_slice()[Self::minimum_packet_size()..] + let (payload, _) = self.split_payload_extension(); + payload + } + + #[must_use] + pub fn extension(&self) -> Option<&[u8]> { + let (_, extension) = self.split_payload_extension(); + extension + } + + fn split_payload_extension(&self) -> (&[u8], Option<&[u8]>) { + let rfc4884_length = self.get_length(); + let icmp_payload = &self.buf.as_slice()[Self::minimum_packet_size()..]; + split(rfc4884_length, icmp_payload) } } @@ -912,7 +967,7 @@ pub mod destination_unreachable { .field("icmp_type", &self.get_icmp_type()) .field("icmp_code", &self.get_icmp_code()) .field("checksum", &self.get_checksum()) - .field("unused", &self.get_unused()) + .field("length", &self.get_length()) .field("next_hop_mtu", &self.get_next_hop_mtu()) .field("payload", &fmt_payload(self.payload())) .finish() @@ -974,6 +1029,21 @@ pub mod destination_unreachable { assert_eq!([0xFF, 0xFF], packet.packet()[2..=3]); } + #[test] + fn test_length() { + let mut buf = [0_u8; DestinationUnreachablePacket::minimum_packet_size()]; + let mut packet = DestinationUnreachablePacket::new(&mut buf).unwrap(); + packet.set_length(0); + assert_eq!(0, packet.get_length()); + assert_eq!([0x00], packet.packet()[4..5]); + packet.set_length(8); + assert_eq!(8, packet.get_length()); + assert_eq!([0x08], packet.packet()[4..5]); + packet.set_length(u8::MAX); + assert_eq!(u8::MAX, packet.get_length()); + assert_eq!([0xFF], packet.packet()[4..5]); + } + #[test] fn test_view() { let buf = [0x01, 0x03, 0xdf, 0xdc, 0x00, 0x00, 0x00, 0x00]; @@ -981,6 +1051,7 @@ pub mod destination_unreachable { assert_eq!(IcmpType::DestinationUnreachable, packet.get_icmp_type()); assert_eq!(IcmpCode(3), packet.get_icmp_code()); assert_eq!(57308, packet.get_checksum()); + assert_eq!(0, packet.get_length()); assert!(packet.payload().is_empty()); } } diff --git a/src/tracing/probe.rs b/src/tracing/probe.rs index 81864e69..bdb0b123 100644 --- a/src/tracing/probe.rs +++ b/src/tracing/probe.rs @@ -3,7 +3,7 @@ use std::net::IpAddr; use std::time::{Duration, SystemTime}; /// The state of an ICMP echo request/response -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct Probe { /// The sequence of the probe. pub sequence: Sequence, @@ -27,6 +27,8 @@ pub struct Probe { pub received: Option, /// The type of ICMP response packet received for the probe. pub icmp_packet_type: Option, + /// The ICMP response extensions. + pub extensions: Option, } impl Probe { @@ -53,6 +55,7 @@ impl Probe { host: None, received: None, icmp_packet_type: None, + extensions: None, } } @@ -67,12 +70,12 @@ impl Probe { } #[must_use] - pub const fn with_status(self, status: ProbeStatus) -> Self { + pub fn with_status(self, status: ProbeStatus) -> Self { Self { status, ..self } } #[must_use] - pub const fn with_icmp_packet_type(self, icmp_packet_type: IcmpPacketType) -> Self { + pub fn with_icmp_packet_type(self, icmp_packet_type: IcmpPacketType) -> Self { Self { icmp_packet_type: Some(icmp_packet_type), ..self @@ -80,7 +83,7 @@ impl Probe { } #[must_use] - pub const fn with_host(self, host: IpAddr) -> Self { + pub fn with_host(self, host: IpAddr) -> Self { Self { host: Some(host), ..self @@ -88,12 +91,17 @@ impl Probe { } #[must_use] - pub const fn with_received(self, received: SystemTime) -> Self { + pub fn with_received(self, received: SystemTime) -> Self { Self { received: Some(received), ..self } } + + #[must_use] + pub fn with_extensions(self, extensions: Option) -> Self { + Self { extensions, ..self } + } } /// The status of a `Echo` for a single TTL. @@ -130,17 +138,46 @@ pub enum IcmpPacketType { } /// The response to a probe. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub enum ProbeResponse { - TimeExceeded(ProbeResponseData), - DestinationUnreachable(ProbeResponseData), + TimeExceeded(ProbeResponseData, Option), + DestinationUnreachable(ProbeResponseData, Option), EchoReply(ProbeResponseData), TcpReply(ProbeResponseData), TcpRefused(ProbeResponseData), } +/// The ICMP extensions for a probe response. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct Extensions { + pub extensions: Vec, +} + +/// A probe response extension. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum Extension { + #[default] + Unknown, + Mpls(MplsLabelStack), +} + +/// The members of a MPLS probe response extension. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct MplsLabelStack { + pub members: Vec, +} + +/// A member of a MPLS probe response extension. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct MplsLabelStackMember { + pub label: u32, + pub exp: u8, + pub bos: u8, + pub ttl: u8, +} + /// The data in the probe response. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ProbeResponseData { /// Timestamp of the probe response. pub recv: SystemTime, @@ -160,14 +197,14 @@ impl ProbeResponseData { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub enum ProbeResponseSeq { Icmp(ProbeResponseSeqIcmp), Udp(ProbeResponseSeqUdp), Tcp(ProbeResponseSeqTcp), } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ProbeResponseSeqIcmp { pub identifier: u16, pub sequence: u16, @@ -182,7 +219,7 @@ impl ProbeResponseSeqIcmp { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ProbeResponseSeqUdp { pub identifier: u16, pub src_port: u16, @@ -201,7 +238,7 @@ impl ProbeResponseSeqUdp { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ProbeResponseSeqTcp { pub src_port: u16, pub dest_port: u16, diff --git a/src/tracing/tracer.rs b/src/tracing/tracer.rs index d302f7b3..e9e94844 100644 --- a/src/tracing/tracer.rs +++ b/src/tracing/tracer.rs @@ -144,17 +144,19 @@ impl)> Tracer { fn recv_response(&self, network: &mut N, st: &mut TracerState) -> TraceResult<()> { let next = network.recv_probe()?; match next { - Some(ProbeResponse::TimeExceeded(data)) => { + Some(ProbeResponse::TimeExceeded(data, extensions)) => { let (trace_id, sequence, received, host) = self.extract(&data); let is_target = host == self.config.target_addr; if self.check_trace_id(trace_id) && st.in_round(sequence) { - st.complete_probe_time_exceeded(sequence, host, received, is_target); + st.complete_probe_time_exceeded( + sequence, host, received, is_target, extensions, + ); } } - Some(ProbeResponse::DestinationUnreachable(data)) => { + Some(ProbeResponse::DestinationUnreachable(data, extensions)) => { let (trace_id, sequence, received, host) = self.extract(&data); if self.check_trace_id(trace_id) && st.in_round(sequence) { - st.complete_probe_unreachable(sequence, host, received); + st.complete_probe_unreachable(sequence, host, received, extensions); } } Some(ProbeResponse::EchoReply(data)) => { @@ -279,11 +281,13 @@ impl)> Tracer { /// the `TracerState` struct. mod state { use crate::tracing::constants::MAX_SEQUENCE_PER_ROUND; + use crate::tracing::probe::Extensions; use crate::tracing::types::{MaxRounds, Port, Round, Sequence, TimeToLive, TraceId}; use crate::tracing::{ IcmpPacketType, MultipathStrategy, PortDirection, Probe, ProbeStatus, TracerConfig, TracerProtocol, }; + use std::array::from_fn; use std::net::IpAddr; use std::time::SystemTime; use tracing::instrument; @@ -347,7 +351,7 @@ mod state { pub fn new(config: TracerConfig) -> Self { Self { config, - buffer: [Probe::default(); BUFFER_SIZE as usize], + buffer: from_fn(|_| Probe::default()), sequence: config.initial_sequence, round_sequence: config.initial_sequence, ttl: config.first_ttl, @@ -368,7 +372,7 @@ mod state { /// Get the `Probe` for `sequence` pub fn probe_at(&self, sequence: Sequence) -> Probe { - self.buffer[usize::from(sequence - self.round_sequence)] + self.buffer[usize::from(sequence - self.round_sequence)].clone() } pub const fn ttl(&self) -> TimeToLive { @@ -430,7 +434,7 @@ mod state { self.round, SystemTime::now(), ); - self.buffer[usize::from(self.sequence - self.round_sequence)] = probe; + self.buffer[usize::from(self.sequence - self.round_sequence)] = probe.clone(); debug_assert!(self.ttl < TimeToLive(u8::MAX)); self.ttl += TimeToLive(1); debug_assert!(self.sequence < Sequence(u16::MAX)); @@ -460,7 +464,7 @@ mod state { self.round, SystemTime::now(), ); - self.buffer[usize::from(self.sequence - self.round_sequence)] = probe; + self.buffer[usize::from(self.sequence - self.round_sequence)] = probe.clone(); debug_assert!(self.sequence < Sequence(u16::MAX)); self.sequence += Sequence(1); probe @@ -553,6 +557,7 @@ mod state { host: IpAddr, received: SystemTime, is_target: bool, + extensions: Option, ) { self.complete_probe( sequence, @@ -560,6 +565,7 @@ mod state { host, received, is_target, + extensions, ); } @@ -570,8 +576,16 @@ mod state { sequence: Sequence, host: IpAddr, received: SystemTime, + extensions: Option, ) { - self.complete_probe(sequence, IcmpPacketType::Unreachable, host, received, true); + self.complete_probe( + sequence, + IcmpPacketType::Unreachable, + host, + received, + true, + extensions, + ); } /// Mark the `Probe` at `sequence` completed as `EchoReply` and update the round state. @@ -582,7 +596,14 @@ mod state { host: IpAddr, received: SystemTime, ) { - self.complete_probe(sequence, IcmpPacketType::EchoReply, host, received, true); + self.complete_probe( + sequence, + IcmpPacketType::EchoReply, + host, + received, + true, + None, + ); } /// Mark the `Probe` at `sequence` completed as `NotApplicable` and update the round state. @@ -599,6 +620,7 @@ mod state { host, received, true, + None, ); } @@ -623,6 +645,7 @@ mod state { host: IpAddr, received: SystemTime, is_target: bool, + extensions: Option, ) { // Retrieve and update the `Probe` at `sequence`. let probe = self @@ -630,8 +653,9 @@ mod state { .with_status(ProbeStatus::Complete) .with_icmp_packet_type(icmp_packet_type) .with_host(host) - .with_received(received); - self.buffer[usize::from(sequence - self.round_sequence)] = probe; + .with_received(received) + .with_extensions(extensions); + self.buffer[usize::from(sequence - self.round_sequence)] = probe.clone(); // If this `Probe` found the target then we set the `target_tll` if not already set, // being careful to account for `Probes` being received out-of-order. @@ -737,7 +761,7 @@ mod state { // Update the state of the probe 1 after receiving a TimeExceeded let received_1 = SystemTime::now(); let host = IpAddr::V4(Ipv4Addr::LOCALHOST); - state.complete_probe_time_exceeded(Sequence(33000), host, received_1, false); + state.complete_probe_time_exceeded(Sequence(33000), host, received_1, false, None); // Validate the state of the probe 1 after the update let probe_1_fetch = state.probe_at(Sequence(33000)); @@ -766,8 +790,8 @@ mod state { // Validate the probes() iterator returns returns only a single probe { let mut probe_iter = state.probes().iter(); - let probe_next1 = *probe_iter.next().unwrap(); - assert_eq!(probe_1_fetch, probe_next1); + let probe_next1 = probe_iter.next().unwrap(); + assert_eq!(&probe_1_fetch, probe_next1); assert_eq!(None, probe_iter.next()); } @@ -809,7 +833,7 @@ mod state { // Update the state of probe 2 after receiving a TimeExceeded let received_2 = SystemTime::now(); let host = IpAddr::V4(Ipv4Addr::LOCALHOST); - state.complete_probe_time_exceeded(Sequence(33001), host, received_2, false); + state.complete_probe_time_exceeded(Sequence(33001), host, received_2, false, None); let probe_2_recv = state.probe_at(Sequence(33001)); // Validate the TracerState after the update to probe 2 @@ -825,10 +849,10 @@ mod state { // Validate the probes() iterator returns the two probes in the states we expect { let mut probe_iter = state.probes().iter(); - let probe_next1 = *probe_iter.next().unwrap(); - assert_eq!(probe_2_recv, probe_next1); - let probe_next2 = *probe_iter.next().unwrap(); - assert_eq!(probe_3, probe_next2); + let probe_next1 = probe_iter.next().unwrap(); + assert_eq!(&probe_2_recv, probe_next1); + let probe_next2 = probe_iter.next().unwrap(); + assert_eq!(&probe_3, probe_next2); } // Update the state of probe 3 after receiving a EchoReply @@ -850,10 +874,10 @@ mod state { // Validate the probes() iterator returns the two probes in the states we expect { let mut probe_iter = state.probes().iter(); - let probe_next1 = *probe_iter.next().unwrap(); - assert_eq!(probe_2_recv, probe_next1); - let probe_next2 = *probe_iter.next().unwrap(); - assert_eq!(probe_3_recv, probe_next2); + let probe_next1 = probe_iter.next().unwrap(); + assert_eq!(&probe_2_recv, probe_next1); + let probe_next2 = probe_iter.next().unwrap(); + assert_eq!(&probe_3_recv, probe_next2); } } diff --git a/trippy-config-sample.toml b/trippy-config-sample.toml index 1ea474a0..a3865db2 100644 --- a/trippy-config-sample.toml +++ b/trippy-config-sample.toml @@ -160,6 +160,18 @@ payload-pattern = 0 # This is also known as DSCP+ECN. tos = 0 +# Whether to parse ICMP extensions. +# +# If enabled, all extensions attached to incoming ICMP TimeExceeded and DestinationUnavailable messages will be parsed +# and provided as part of the trace response data. +# +# The following ICMP Extension Object Classes are supported: +# 1 - MPLS Label Stack Class (RFC4950) +# +# Extension objects with an unknown class will be parsed to capture generic information including the class, subtype, +# length and payload bytes. +icmp_extensions = false + # The socket read timeout [default: 10ms] read-timeout = "10ms"