Skip to content

Commit

Permalink
Fixing race condition in TLS
Browse files Browse the repository at this point in the history
  • Loading branch information
Mallets committed Mar 16, 2021
1 parent 2acdb4d commit 2fbbf5a
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions zenoh/src/net/protocol/link/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use async_std::channel::{bounded, Receiver, Sender};
use async_std::fs;
use async_std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
use async_std::prelude::*;
use async_std::sync::{Arc, Barrier, RwLock};
use async_std::sync::{Arc, Barrier, Mutex, RwLock};
use async_std::task;
use async_trait::async_trait;
use std::cell::UnsafeCell;
Expand Down Expand Up @@ -118,6 +118,7 @@ async fn get_tls_dns(locator: &Locator) -> ZResult<DNSName> {
}

#[allow(unreachable_patterns)]
#[inline(always)]
fn get_tls_prop(property: &LocatorProperty) -> ZResult<&LocatorPropertyTls> {
match property {
LocatorProperty::Tls(prop) => Ok(prop),
Expand Down Expand Up @@ -304,6 +305,9 @@ pub struct LinkTls {
src_addr: SocketAddr,
// The destination socket address of this link (address used on the local host)
dst_addr: SocketAddr,
// Make sure there are no concurrent read or writes
write_mtx: Mutex<()>,
read_mtx: Mutex<()>,
}

unsafe impl Send for LinkTls {}
Expand Down Expand Up @@ -342,12 +346,14 @@ impl LinkTls {
inner: UnsafeCell::new(socket),
src_addr,
dst_addr,
write_mtx: Mutex::new(()),
read_mtx: Mutex::new(()),
}
}

// NOTE: It is safe to suppress Clippy warning since no concurrent reads
// or concurrent writes will ever happen. This is enforced by the
// transmission and reception logic in zenoh.
// or concurrent writes will ever happen. The read_mtx and write_mtx
// are respectively acquired in any read and write operation.
#[allow(clippy::mut_from_ref)]
fn get_sock_mut(&self) -> &mut TlsStream<TcpStream> {
unsafe { &mut *self.inner.get() }
Expand All @@ -356,6 +362,7 @@ impl LinkTls {
pub(crate) async fn close(&self) -> ZResult<()> {
log::trace!("Closing TLS link: {}", self);
// Flush the TLS stream
let _guard = zasynclock!(self.write_mtx);
let tls_stream = self.get_sock_mut();
let res = tls_stream.flush().await;
log::trace!("TLS link flush {}: {:?}", self, res);
Expand All @@ -370,8 +377,9 @@ impl LinkTls {
})
}

#[inline]
#[inline(always)]
pub(crate) async fn write(&self, buffer: &[u8]) -> ZResult<usize> {
let _guard = zasynclock!(self.write_mtx);
match self.get_sock_mut().write(buffer).await {
Ok(n) => Ok(n),
Err(e) => {
Expand All @@ -383,8 +391,9 @@ impl LinkTls {
}
}

#[inline]
#[inline(always)]
pub(crate) async fn write_all(&self, buffer: &[u8]) -> ZResult<()> {
let _guard = zasynclock!(self.write_mtx);
match self.get_sock_mut().write_all(buffer).await {
Ok(_) => Ok(()),
Err(e) => {
Expand All @@ -396,8 +405,9 @@ impl LinkTls {
}
}

#[inline]
#[inline(always)]
pub(crate) async fn read(&self, buffer: &mut [u8]) -> ZResult<usize> {
let _guard = zasynclock!(self.read_mtx);
match self.get_sock_mut().read(buffer).await {
Ok(n) => Ok(n),
Err(e) => {
Expand All @@ -409,8 +419,9 @@ impl LinkTls {
}
}

#[inline]
#[inline(always)]
pub(crate) async fn read_exact(&self, buffer: &mut [u8]) -> ZResult<()> {
let _guard = zasynclock!(self.read_mtx);
match self.get_sock_mut().read_exact(buffer).await {
Ok(_) => Ok(()),
Err(e) => {
Expand All @@ -422,22 +433,22 @@ impl LinkTls {
}
}

#[inline]
#[inline(always)]
pub(crate) fn get_src(&self) -> Locator {
Locator::Tls(LocatorTls::SocketAddr(self.src_addr))
}

#[inline]
#[inline(always)]
pub(crate) fn get_dst(&self) -> Locator {
Locator::Tls(LocatorTls::SocketAddr(self.dst_addr))
}

#[inline]
#[inline(always)]
pub(crate) fn get_mtu(&self) -> usize {
*TLS_DEFAULT_MTU
}

#[inline]
#[inline(always)]
pub(crate) fn is_reliable(&self) -> bool {
true
}
Expand Down

0 comments on commit 2fbbf5a

Please sign in to comment.