diff --git a/fuzz/src/full_stack.rs b/fuzz/src/full_stack.rs
index 674be81d823..97a74871ea4 100644
--- a/fuzz/src/full_stack.rs
+++ b/fuzz/src/full_stack.rs
@@ -195,7 +195,7 @@ struct Peer<'a> {
 	peers_connected: &'a RefCell<[bool; 256]>,
 }
 impl<'a> SocketDescriptor for Peer<'a> {
-	fn send_data(&mut self, data: &[u8], _resume_read: bool) -> usize {
+	fn send_data(&mut self, data: &[u8], _continue_read: bool) -> usize {
 		data.len()
 	}
 	fn disconnect_socket(&mut self) {
@@ -695,7 +695,7 @@ pub fn do_test(mut data: &[u8], logger: &Arc) {
 			}
 			let mut peer = Peer { id: peer_id, peers_connected: &peers };
 			match loss_detector.handler.read_event(&mut peer, get_slice!(get_slice!(1)[0])) {
-				Ok(res) => assert!(!res),
+				Ok(()) => {},
 				Err(_) => {
 					peers.borrow_mut()[peer_id as usize] = false;
 				},
diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs
index 2b8bfb0c1a2..4d139257d00 100644
--- a/lightning-background-processor/src/lib.rs
+++ b/lightning-background-processor/src/lib.rs
@@ -774,7 +774,7 @@ use futures_util::{dummy_waker, Joiner, OptionalSelector, Selector, SelectorOutp
 /// # #[derive(Eq, PartialEq, Clone, Hash)]
 /// # struct SocketDescriptor {}
 /// # impl lightning::ln::peer_handler::SocketDescriptor for SocketDescriptor {
-/// #   fn send_data(&mut self, _data: &[u8], _resume_read: bool) -> usize { 0 }
+/// #   fn send_data(&mut self, _data: &[u8], _continue_read: bool) -> usize { 0 }
 /// #   fn disconnect_socket(&mut self) {}
 /// # }
 /// # type ChainMonitor = lightning::chain::chainmonitor::ChainMonitor, Arc, Arc, Arc, Arc, Arc>;
@@ -1878,7 +1878,7 @@ mod tests {
 	#[derive(Clone, Hash, PartialEq, Eq)]
 	struct TestDescriptor {}
 	impl SocketDescriptor for TestDescriptor {
-		fn send_data(&mut self, _data: &[u8], _resume_read: bool) -> usize {
+		fn send_data(&mut self, _data: &[u8], _continue_read: bool) -> usize {
 			0
 		}
diff --git a/lightning-net-tokio/src/lib.rs b/lightning-net-tokio/src/lib.rs
index 2ec69de3f5d..068f77a84bb 100644
--- a/lightning-net-tokio/src/lib.rs
+++ b/lightning-net-tokio/src/lib.rs
@@ -243,13 +243,8 @@ impl Connection {
 				Ok(len) => {
 					let read_res = peer_manager.as_ref().read_event(&mut our_descriptor, &buf[0..len]);
-					let mut us_lock = us.lock().unwrap();
 					match read_res {
-						Ok(pause_read) => {
-							if pause_read {
-								us_lock.read_paused = true;
-							}
-						},
+						Ok(()) => {},
 						Err(_) => break Disconnect::CloseConnection,
 					}
 				},
@@ -533,7 +528,7 @@ impl SocketDescriptor {
 	}
 }
 impl peer_handler::SocketDescriptor for SocketDescriptor {
-	fn send_data(&mut self, data: &[u8], resume_read: bool) -> usize {
+	fn send_data(&mut self, data: &[u8], continue_read: bool) -> usize {
 		// To send data, we take a lock on our Connection to access the TcpStream, writing to it if
 		// there's room in the kernel buffer, or otherwise create a new Waker with a
 		// SocketDescriptor in it which can wake up the write_avail Sender, waking up the
@@ -544,13 +539,16 @@ impl peer_handler::SocketDescriptor for SocketDescriptor {
 			return 0;
 		}
 
-		if resume_read && us.read_paused {
+		let read_was_paused = us.read_paused;
+		us.read_paused = !continue_read;
+
+		if continue_read && read_was_paused {
 			// The schedule_read future may go to lock up but end up getting woken up by there
 			// being more room in the write buffer, dropping the other end of this Sender
 			// before we get here, so we ignore any failures to wake it up.
-			us.read_paused = false;
 			let _ = us.read_waker.try_send(());
 		}
+
 		if data.is_empty() { return 0; }
@@ -576,16 +574,7 @@ impl peer_handler::SocketDescriptor for SocketDescriptor {
 				}
 			},
 			task::Poll::Ready(Err(_)) => return written_len,
-			task::Poll::Pending => {
-				// We're queued up for a write event now, but we need to make sure we also
-				// pause read given we're now waiting on the remote end to ACK (and in
-				// accordance with the send_data() docs).
-				us.read_paused = true;
-				// Further, to avoid any current pending read causing a `read_event` call, wake
-				// up the read_waker and restart its loop.
-				let _ = us.read_waker.try_send(());
-				return written_len;
-			},
+			task::Poll::Pending => return written_len,
 		}
 	}
 }
diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs
index 3cf6c6cc2ad..74f081b03ae 100644
--- a/lightning/src/ln/peer_handler.rs
+++ b/lightning/src/ln/peer_handler.rs
@@ -632,16 +632,15 @@ pub trait SocketDescriptor: cmp::Eq + hash::Hash + Clone {
 	///
 	/// If the returned size is smaller than `data.len()`, a
 	/// [`PeerManager::write_buffer_space_avail`] call must be made the next time more data can be
-	/// written. Additionally, until a `send_data` event completes fully, no further
-	/// [`PeerManager::read_event`] calls should be made for the same peer! Because this is to
-	/// prevent denial-of-service issues, you should not read or buffer any data from the socket
-	/// until then.
+	/// written.
 	///
-	/// If a [`PeerManager::read_event`] call on this descriptor had previously returned true
-	/// (indicating that read events should be paused to prevent DoS in the send buffer),
-	/// `resume_read` may be set indicating that read events on this descriptor should resume. A
-	/// `resume_read` of false carries no meaning, and should not cause any action.
-	fn send_data(&mut self, data: &[u8], resume_read: bool) -> usize;
+	/// If `continue_read` is *not* set, further [`PeerManager::read_event`] calls should be
+	/// avoided until another call is made with it set. This allows us to pause reads when too
+	/// many outgoing messages are queued for a peer, avoiding DoS issues where a peer fills our
+	/// outbound buffer by sending messages which require responses without ever reading them.
+	///
+	/// Note that calls may be made with an empty `data` to update the `continue_read` flag.
+	fn send_data(&mut self, data: &[u8], continue_read: bool) -> usize;
 	/// Disconnect the socket pointed to by this SocketDescriptor.
 	///
 	/// You do *not* need to call [`PeerManager::socket_disconnected`] with this socket after this
@@ -782,6 +781,9 @@ struct Peer {
 	/// Note that these messages are *not* encrypted/MAC'd, and are only serialized.
 	gossip_broadcast_buffer: VecDeque,
 	awaiting_write_event: bool,
+	/// Set to true if the last call to [`SocketDescriptor::send_data`] for this peer had the
+	/// `continue_read` flag unset, indicating we've told the driver to stop reading from this peer.
+	sent_pause_read: bool,

 	pending_read_buffer: Vec,
 	pending_read_buffer_pos: usize,
@@ -1441,6 +1443,7 @@ where
 			pending_outbound_buffer_first_msg_offset: 0,
 			gossip_broadcast_buffer: VecDeque::new(),
 			awaiting_write_event: false,
+			sent_pause_read: false,

 			pending_read_buffer,
 			pending_read_buffer_pos: 0,
@@ -1501,6 +1504,7 @@ where
 			pending_outbound_buffer_first_msg_offset: 0,
 			gossip_broadcast_buffer: VecDeque::new(),
 			awaiting_write_event: false,
+			sent_pause_read: false,

 			pending_read_buffer,
 			pending_read_buffer_pos: 0,
@@ -1523,7 +1527,7 @@ where
 		}
 	}

-	fn peer_should_read(&self, peer: &mut Peer) -> bool {
+	fn should_read_from(&self, peer: &mut Peer) -> bool {
 		peer.should_read(self.gossip_processing_backlogged.load(Ordering::Relaxed))
 	}

@@ -1536,10 +1540,14 @@ where
 	}

 	fn do_attempt_write_data(
-		&self, descriptor: &mut Descriptor, peer: &mut Peer, force_one_write: bool,
+		&self, descriptor: &mut Descriptor, peer: &mut Peer, mut force_one_write: bool,
 	) {
-		let mut have_written = false;
-		while !peer.awaiting_write_event {
+		// If we detect that we should be reading from the peer but reads are currently paused, or
+		// vice versa, we need to tell the socket driver to update its internal flag indicating
+		// whether reads are paused. Do this by forcing a write with the desired `continue_read`
+		// value, even if no outbound messages are currently queued.
+		force_one_write |= self.should_read_from(peer) == peer.sent_pause_read;
+		while force_one_write || !peer.awaiting_write_event {
 			if peer.should_buffer_onion_message() {
 				if let Some((peer_node_id, _)) = peer.their_node_id {
 					let handler = &self.message_handler.onion_message_handler;
@@ -1604,23 +1612,23 @@ where
 				self.maybe_send_extra_ping(peer);
 			}

-			let should_read = self.peer_should_read(peer);
+			let should_read = self.should_read_from(peer);
 			let next_buff = match peer.pending_outbound_buffer.front() {
 				None => {
-					if force_one_write && !have_written {
-						if should_read {
-							let data_sent = descriptor.send_data(&[], should_read);
-							debug_assert_eq!(data_sent, 0, "Can't write more than no data");
-						}
+					if force_one_write {
+						let data_sent = descriptor.send_data(&[], should_read);
+						debug_assert_eq!(data_sent, 0, "Can't write more than no data");
+						peer.sent_pause_read = !should_read;
 					}
 					return;
 				},
 				Some(buff) => buff,
 			};
+			force_one_write = false;

 			let pending = &next_buff[peer.pending_outbound_buffer_first_msg_offset..];
 			let data_sent = descriptor.send_data(pending, should_read);
-			have_written = true;
+			peer.sent_pause_read = !should_read;
 			peer.pending_outbound_buffer_first_msg_offset += data_sent;
 			if peer.pending_outbound_buffer_first_msg_offset == next_buff.len() {
 				peer.pending_outbound_buffer_first_msg_offset = 0;
@@ -1664,7 +1672,10 @@ where
 			Some(peer_mutex) => {
 				let mut peer = peer_mutex.lock().unwrap();
 				peer.awaiting_write_event = false;
-				self.do_attempt_write_data(descriptor, &mut peer, false);
+				// We force at least one write here: if the net driver thought we had messages to
+				// send when we don't, its state is already confused, and it may also hold a stale
+				// read-paused flag which this gives it a chance to clear.
+				self.do_attempt_write_data(descriptor, &mut peer, true);
 			},
 		};
 		Ok(())
@@ -1676,11 +1687,9 @@ where
 	///
 	/// Will *not* call back into [`send_data`] on any descriptors to avoid reentrancy complexity.
 	/// Thus, however, you should call [`process_events`] after any `read_event` to generate
-	/// [`send_data`] calls to handle responses.
-	///
-	/// If `Ok(true)` is returned, further read_events should not be triggered until a
-	/// [`send_data`] call on this descriptor has `resume_read` set (preventing DoS issues in the
-	/// send buffer).
+	/// [`send_data`] calls to handle responses. This is also important to give [`send_data`] calls
+	/// a chance to pause reads if too many messages have been queued in response, which would
+	/// otherwise allow a peer to bloat our memory.
 	///
 	/// In order to avoid processing too many messages at once per peer, `data` should be on the
 	/// order of 4KiB.
 	///
 	/// [`process_events`]: PeerManager::process_events
 	pub fn read_event(
 		&self, peer_descriptor: &mut Descriptor, data: &[u8],
-	) -> Result<bool, PeerHandleError> {
+	) -> Result<(), PeerHandleError> {
 		match self.do_read_event(peer_descriptor, data) {
 			Ok(res) => Ok(res),
 			Err(e) => {
@@ -1718,8 +1727,7 @@ where
 	fn do_read_event(
 		&self, peer_descriptor: &mut Descriptor, data: &[u8],
-	) -> Result<bool, PeerHandleError> {
-		let mut pause_read = false;
+	) -> Result<(), PeerHandleError> {
 		let peers = self.peers.read().unwrap();
 		let mut msgs_to_forward = Vec::new();
 		let mut peer_node_id = None;
@@ -1994,7 +2002,6 @@ where
 				},
 			}
 		}
-		pause_read = !self.peer_should_read(peer);

 		if let Some(message) = msg_to_handle {
 			match self.handle_message(&peer_mutex, peer_lock, message) {
@@ -2027,7 +2034,7 @@ where
 			);
 		}

-		Ok(pause_read)
+		Ok(())
 	}

 	/// Process an incoming message and return a decision (ok, lightning error, peer handling error) regarding the next action with the peer
@@ -3725,7 +3732,7 @@ mod tests {
 	}

 	impl SocketDescriptor for FileDescriptor {
-		fn send_data(&mut self, data: &[u8], _resume_read: bool) -> usize {
+		fn send_data(&mut self, data: &[u8], _continue_read: bool) -> usize {
 			if self.hang_writes.load(Ordering::Acquire) {
 				0
 			} else {
@@ -3939,12 +3946,8 @@ mod tests {
 	fn try_establish_connection<'a>(
 		peer_a: &TestPeer<'a>, peer_b: &TestPeer<'a>,
-	) -> (
-		FileDescriptor,
-		FileDescriptor,
-		Result<bool, PeerHandleError>,
-		Result<bool, PeerHandleError>,
-	) {
+	) -> (FileDescriptor, FileDescriptor, Result<(), PeerHandleError>, Result<(), PeerHandleError>)
+	{
 		let addr_a = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1000 };
 		let addr_b = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1001 };
@@ -3958,11 +3961,11 @@
 		let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap();
 		peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap();
-		assert_eq!(peer_a.read_event(&mut fd_a, &initial_data).unwrap(), false);
+		peer_a.read_event(&mut fd_a, &initial_data).unwrap();
 		peer_a.process_events();

 		let a_data = fd_a.outbound_data.lock().unwrap().split_off(0);
-		assert_eq!(peer_b.read_event(&mut fd_b, &a_data).unwrap(), false);
+		peer_b.read_event(&mut fd_b, &a_data).unwrap();
 		peer_b.process_events();

 		let b_data = fd_b.outbound_data.lock().unwrap().split_off(0);
@@ -3989,8 +3992,8 @@
 		let (fd_a, fd_b, a_refused, b_refused) = try_establish_connection(peer_a, peer_b);
-		assert_eq!(a_refused.unwrap(), false);
-		assert_eq!(b_refused.unwrap(), false);
+		a_refused.unwrap();
+		b_refused.unwrap();

 		assert_eq!(peer_a.peer_by_node_id(&id_b).unwrap().counterparty_node_id, id_b);
 		assert_eq!(peer_a.peer_by_node_id(&id_b).unwrap().socket_address, Some(addr_b));
@@ -4113,11 +4116,11 @@
 		let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap();
 		peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap();
-		assert_eq!(peer_a.read_event(&mut fd_a, &initial_data).unwrap(), false);
+		peer_a.read_event(&mut fd_a, &initial_data).unwrap();
 		peer_a.process_events();

 		let a_data = fd_a.outbound_data.lock().unwrap().split_off(0);
-		assert_eq!(peer_b.read_event(&mut fd_b, &a_data).unwrap(), false);
+		peer_b.read_event(&mut fd_b, &a_data).unwrap();
 		peer_b.process_events();

 		let b_data = fd_b.outbound_data.lock().unwrap().split_off(0);
@@ -4144,11 +4147,11 @@
 		let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap();
 		peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap();
-		assert_eq!(peer_a.read_event(&mut fd_a, &initial_data).unwrap(), false);
+		peer_a.read_event(&mut fd_a, &initial_data).unwrap();
 		peer_a.process_events();

 		let a_data = fd_a.outbound_data.lock().unwrap().split_off(0);
-		assert_eq!(peer_b.read_event(&mut fd_b, &a_data).unwrap(), false);
+		peer_b.read_event(&mut fd_b, &a_data).unwrap();
 		peer_b.process_events();

 		let b_data = fd_b.outbound_data.lock().unwrap().split_off(0);
@@ -4220,7 +4223,7 @@
 		peers[0].process_events();
 		let a_data = fd_a.outbound_data.lock().unwrap().split_off(0);
-		assert_eq!(peers[1].read_event(&mut fd_b, &a_data).unwrap(), false);
+		peers[1].read_event(&mut fd_b, &a_data).unwrap();
 	}

 	#[test]
@@ -4240,13 +4243,13 @@
 		let mut dup_encryptor = PeerChannelEncryptor::new_outbound(id_a, SecretKey::from_slice(&[42; 32]).unwrap());
 		let initial_data = dup_encryptor.get_act_one(&peers[1].secp_ctx);
-		assert_eq!(peers[0].read_event(&mut fd_dup, &initial_data).unwrap(), false);
+		peers[0].read_event(&mut fd_dup, &initial_data).unwrap();
 		peers[0].process_events();

 		let a_data = fd_dup.outbound_data.lock().unwrap().split_off(0);
 		let (act_three, _) = dup_encryptor.process_act_two(&a_data[..], &&cfgs[1].node_signer).unwrap();
-		assert_eq!(peers[0].read_event(&mut fd_dup, &act_three).unwrap(), false);
+		peers[0].read_event(&mut fd_dup, &act_three).unwrap();

 		let not_init_msg = msgs::Ping { ponglen: 4, byteslen: 0 };
 		let msg_bytes = dup_encryptor.encrypt_message(&not_init_msg);
@@ -4504,10 +4507,10 @@
 			assert_eq!(peers_len, 1);
 		}

-		assert_eq!(peers[0].read_event(&mut fd_a, &initial_data).unwrap(), false);
+		peers[0].read_event(&mut fd_a, &initial_data).unwrap();
 		peers[0].process_events();
 		let a_data = fd_a.outbound_data.lock().unwrap().split_off(0);
-		assert_eq!(peers[1].read_event(&mut fd_b, &a_data).unwrap(), false);
+		peers[1].read_event(&mut fd_b, &a_data).unwrap();
 		peers[1].process_events();

 		// ...but if we get a second timer tick, we should disconnect the peer
@@ -4557,11 +4560,11 @@
 		let act_one = peer_b.new_outbound_connection(a_id, fd_b.clone(), None).unwrap();
 		peer_a.new_inbound_connection(fd_a.clone(), None).unwrap();
-		assert_eq!(peer_a.read_event(&mut fd_a, &act_one).unwrap(), false);
+		peer_a.read_event(&mut fd_a, &act_one).unwrap();
 		peer_a.process_events();

 		let act_two = fd_a.outbound_data.lock().unwrap().split_off(0);
-		assert_eq!(peer_b.read_event(&mut fd_b, &act_two).unwrap(), false);
+		peer_b.read_event(&mut fd_b, &act_two).unwrap();
 		peer_b.process_events();

 		// Calling this here triggers the race on inbound connections.
@@ -4575,7 +4578,7 @@ mod tests {
 			assert!(!handshake_complete);
 		}

-		assert_eq!(peer_a.read_event(&mut fd_a, &act_three_with_init_b).unwrap(), false);
+		peer_a.read_event(&mut fd_a, &act_three_with_init_b).unwrap();
 		peer_a.process_events();

 		{
@@ -4595,7 +4598,7 @@ mod tests {
 			assert!(!handshake_complete);
 		}

-		assert_eq!(peer_b.read_event(&mut fd_b, &init_a).unwrap(), false);
+		peer_b.read_event(&mut fd_b, &init_a).unwrap();
 		peer_b.process_events();

 		{
@@ -4632,7 +4635,7 @@ mod tests {
 			peer_a.process_events();
 			let msg = fd_a.outbound_data.lock().unwrap().split_off(0);
 			assert!(!msg.is_empty());
-			assert_eq!(peer_b.read_event(&mut fd_b, &msg).unwrap(), false);
+			peer_b.read_event(&mut fd_b, &msg).unwrap();
 			peer_b.process_events();
 		};

@@ -4675,12 +4678,12 @@ mod tests {
 			let msg = fd_a.outbound_data.lock().unwrap().split_off(0);
 			if !msg.is_empty() {
-				assert_eq!(peers[1].read_event(&mut fd_b, &msg).unwrap(), false);
+				peers[1].read_event(&mut fd_b, &msg).unwrap();
 				continue;
 			}
 			let msg = fd_b.outbound_data.lock().unwrap().split_off(0);
 			if !msg.is_empty() {
-				assert_eq!(peers[0].read_event(&mut fd_a, &msg).unwrap(), false);
+				peers[0].read_event(&mut fd_a, &msg).unwrap();
 				continue;
 			}
 			break;
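
The sketch below is not part of the patch; it illustrates how a custom socket driver might satisfy the new `send_data` contract documented above. The `ExampleDescriptor` and `ConnState` types and the `reads_paused` field are hypothetical driver-side state; only the `SocketDescriptor` trait (with its two methods as changed in this diff) comes from `lightning::ln::peer_handler`.

use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};

use lightning::ln::peer_handler::SocketDescriptor;

// Hypothetical per-connection state owned by the driver.
struct ConnState {
	reads_paused: bool,
	// ... a handle to the underlying TCP stream would live here ...
}

#[derive(Clone)]
struct ExampleDescriptor {
	conn: Arc<Mutex<ConnState>>,
	id: u64,
}

impl PartialEq for ExampleDescriptor {
	fn eq(&self, other: &Self) -> bool {
		self.id == other.id
	}
}
impl Eq for ExampleDescriptor {}
impl Hash for ExampleDescriptor {
	fn hash<H: Hasher>(&self, state: &mut H) {
		self.id.hash(state)
	}
}

impl SocketDescriptor for ExampleDescriptor {
	fn send_data(&mut self, data: &[u8], continue_read: bool) -> usize {
		let mut conn = self.conn.lock().unwrap();
		// Under the new contract this call may arrive with empty `data` purely to
		// update the read-pause state, so record the flag before anything else.
		conn.reads_paused = !continue_read;
		// Write as much of `data` as the kernel will accept and return that length.
		// If the return value is short, PeerManager::write_buffer_space_avail must be
		// called once the socket becomes writable again.
		data.len() // pretend everything was accepted
	}
	fn disconnect_socket(&mut self) {
		// Close the underlying stream; no PeerManager::socket_disconnected call is
		// needed in response to this method.
	}
}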
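
On the read side, the same hypothetical driver would skip `read_event` while the last `send_data` call carried `continue_read = false`, and run `process_events` after every read so queued responses (and any read-pause update) are flushed back out through `send_data`, as the updated `read_event` docs describe. A minimal sketch reusing the assumed types above:

use lightning::ln::peer_handler::APeerManager;

// Hypothetical read-path helper; `buf` is whatever was just read from the socket.
fn on_readable<PM: APeerManager>(
	peer_manager: &PM, descriptor: &mut PM::Descriptor, conn: &Arc<Mutex<ConnState>>, buf: &[u8],
) -> Result<(), ()> {
	if conn.lock().unwrap().reads_paused {
		// The last send_data call had `continue_read` unset; feed no more data in
		// until a later send_data call (possibly with empty data) lifts the pause.
		return Ok(());
	}
	// An error here means the driver should close the connection.
	peer_manager.as_ref().read_event(descriptor, buf).map_err(|_| ())?;
	// read_event never calls back into send_data, so responses are only generated
	// once process_events runs.
	peer_manager.as_ref().process_events();
	Ok(())
}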