Skip to content

Commit

Permalink
Buffer TCP writes
Browse files Browse the repository at this point in the history
  • Loading branch information
bugadani committed Sep 28, 2023
1 parent 5edc74f commit f88a6ba
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 12 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ object-chain = "0.1.3"
bad-server = { path = "bad-server" }
defmt = { version = "=0.3.5" }
ufmt = "0.2.0"
slice-string = { git = "https://github.com/bugadani/slice-string.git", branch = "ufmt" }
slice-string = "0.7.0"
tinyvec = "1.6.0"

[dependencies]
embassy-futures = { version = "0.1.0" }
Expand Down
6 changes: 3 additions & 3 deletions src/board/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ pub mod utils;
pub mod wifi;

use alloc::boxed::Box;
use embassy_net::tcp::client::TcpClientState;
use esp_backtrace as _;

#[cfg(feature = "esp32s2")]
Expand All @@ -34,6 +33,7 @@ use crate::{
initialized::Board,
wifi::sta::{ConnectionState, Sta},
},
buffered_tcp_client::BufferedTcpClientState,
states::display_message,
};

Expand All @@ -43,14 +43,14 @@ pub struct MiscPins {
}

pub struct HttpClientResources {
pub client_state: TcpClientState<1, 4096, 4096>,
pub client_state: BufferedTcpClientState<1, 4096, 4096, 1024>,
pub rx_buffer: [u8; 512],
}

impl HttpClientResources {
pub fn new_boxed() -> Box<Self> {
Box::new(Self {
client_state: TcpClientState::new(),
client_state: BufferedTcpClientState::new(),
rx_buffer: [0; 512],
})
}
Expand Down
225 changes: 225 additions & 0 deletions src/buffered_tcp_client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
use core::{
cell::UnsafeCell,
mem::MaybeUninit,
ptr::NonNull,
sync::atomic::{AtomicBool, Ordering},
};

use embassy_net::{
driver::Driver,
tcp::{
client::{TcpClient, TcpClientState, TcpConnection},
Error,
},
Stack,
};
use embedded_io::{
asynch::{Read, Write},
Io,
};
use embedded_nal_async::{SocketAddr, TcpConnect};
use slice_string::tinyvec::SliceVec;

/// TCP client connection pool compatible with `embedded-nal-async` traits.
///
/// The pool is capable of managing up to N concurrent connections with tx and rx buffers according to TX_SZ and RX_SZ.
pub struct BufferedTcpClient<
'd,
D: Driver,
const N: usize,
const TX_SZ: usize = 1024,
const RX_SZ: usize = 1024,
const W_SZ: usize = 1024,
> {
inner: TcpClient<'d, D, N, TX_SZ, RX_SZ>,
pool: &'d Pool<[u8; W_SZ], N>,
}

impl<'d, D: Driver, const N: usize, const TX_SZ: usize, const RX_SZ: usize, const W_SZ: usize>
BufferedTcpClient<'d, D, N, TX_SZ, RX_SZ, W_SZ>
{
/// Create a new `TcpClient`.
pub fn new(
stack: &'d Stack<D>,
state: &'d BufferedTcpClientState<N, TX_SZ, RX_SZ, W_SZ>,
) -> Self {
Self {
inner: TcpClient::new(stack, &state.inner),
pool: &state.pool,
}
}
}

impl<'d, D: Driver, const N: usize, const TX_SZ: usize, const RX_SZ: usize, const W_SZ: usize>
TcpConnect for BufferedTcpClient<'d, D, N, TX_SZ, RX_SZ, W_SZ>
{
type Error = Error;
type Connection<'m> = BufferedTcpConnection<'m, N, TX_SZ, RX_SZ, W_SZ> where Self: 'm;

async fn connect<'a>(&'a self, remote: SocketAddr) -> Result<Self::Connection<'a>, Self::Error>
where
Self: 'a,
{
let connection = self.inner.connect(remote).await?;

BufferedTcpConnection::new(connection, self.pool)
}
}

/// Opened TCP connection in a [`TcpClient`].
pub struct BufferedTcpConnection<
'd,
const N: usize,
const TX_SZ: usize,
const RX_SZ: usize,
const W_SZ: usize,
> {
inner: TcpConnection<'d, N, TX_SZ, RX_SZ>,
pool: &'d Pool<[u8; W_SZ], N>,
bufs: NonNull<[u8; W_SZ]>,
write_buffer: SliceVec<'d, u8>,
}

impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize, const W_SZ: usize>
BufferedTcpConnection<'d, N, TX_SZ, RX_SZ, W_SZ>
{
fn new(
inner: TcpConnection<'d, N, TX_SZ, RX_SZ>,
pool: &'d Pool<[u8; W_SZ], N>,
) -> Result<Self, Error> {
let bufs = pool.alloc().ok_or(Error::ConnectionReset)?;
let write_buffer =
SliceVec::from_slice_len(unsafe { bufs.as_ptr().as_mut().unwrap().as_mut_slice() }, 0);
Ok(Self {
inner,
pool,
bufs,
write_buffer,
})
}

async fn write_buffered(&mut self) -> Result<(), Error> {
if !self.write_buffer.is_empty() {
self.inner.write(&self.write_buffer).await?;
self.write_buffer.clear();
}
Ok(())
}
}

impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize, const W_SZ: usize> Drop
for BufferedTcpConnection<'d, N, TX_SZ, RX_SZ, W_SZ>
{
fn drop(&mut self) {
unsafe {
self.pool.free(self.bufs);
}
}
}

impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize, const W_SZ: usize> Io
for BufferedTcpConnection<'d, N, TX_SZ, RX_SZ, W_SZ>
{
type Error = Error;
}

impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize, const W_SZ: usize> Read
for BufferedTcpConnection<'d, N, TX_SZ, RX_SZ, W_SZ>
{
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
self.inner.read(buf).await
}
}

impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize, const W_SZ: usize> Write
for BufferedTcpConnection<'d, N, TX_SZ, RX_SZ, W_SZ>
{
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
if buf.len() > self.write_buffer.capacity() {
self.write_buffered().await?;
return self.inner.write(buf).await;
}

let space = self.write_buffer.capacity() - self.write_buffer.len();
let len = buf.len().min(space);

self.write_buffer.extend_from_slice(&buf[..len]);

if self.write_buffer.len() == self.write_buffer.capacity() {
self.write_buffered().await?;
}

Ok(len)
}

async fn flush(&mut self) -> Result<(), Self::Error> {
self.write_buffered().await?;
self.inner.flush().await
}
}

/// State for TcpClient
pub struct BufferedTcpClientState<
const N: usize,
const TX_SZ: usize,
const RX_SZ: usize,
const W_SZ: usize,
> {
inner: TcpClientState<N, TX_SZ, RX_SZ>,
pool: Pool<[u8; W_SZ], N>,
}

impl<const N: usize, const TX_SZ: usize, const RX_SZ: usize, const W_SZ: usize>
BufferedTcpClientState<N, TX_SZ, RX_SZ, W_SZ>
{
/// Create a new `TcpClientState`.
pub const fn new() -> Self {
Self {
inner: TcpClientState::new(),
pool: Pool::new(),
}
}
}

unsafe impl<const N: usize, const TX_SZ: usize, const RX_SZ: usize, const W_SZ: usize> Sync
for BufferedTcpClientState<N, TX_SZ, RX_SZ, W_SZ>
{
}

struct Pool<T, const N: usize> {
used: [AtomicBool; N],
data: [UnsafeCell<MaybeUninit<T>>; N],
}

impl<T, const N: usize> Pool<T, N> {
const VALUE: AtomicBool = AtomicBool::new(false);
const UNINIT: UnsafeCell<MaybeUninit<T>> = UnsafeCell::new(MaybeUninit::uninit());

const fn new() -> Self {
Self {
used: [Self::VALUE; N],
data: [Self::UNINIT; N],
}
}
}

impl<T, const N: usize> Pool<T, N> {
fn alloc(&self) -> Option<NonNull<T>> {
for n in 0..N {
if self.used[n].swap(true, Ordering::SeqCst) == false {
let p = self.data[n].get() as *mut T;
return Some(unsafe { NonNull::new_unchecked(p) });
}
}
None
}

/// safety: p must be a pointer obtained from self.alloc that hasn't been freed yet.
unsafe fn free(&self, p: NonNull<T>) {
let origin = self.data.as_ptr() as *mut T;
let n = p.as_ptr().offset_from(origin);
assert!(n >= 0);
assert!((n as usize) < N);
self.used[n as usize].store(false, Ordering::SeqCst);
}
}
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ use crate::{
};

mod board;
mod buffered_tcp_client;
mod heap;
mod replace_with;
mod sleep;
Expand Down
5 changes: 3 additions & 2 deletions src/states/firmware_update.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use embassy_futures::select::{select, Either};
use embassy_net::{dns::DnsSocket, tcp::client::TcpClient};
use embassy_net::dns::DnsSocket;
use embassy_time::{Duration, Instant, Timer};
use embedded_io::asynch::Read;
use reqwless::{client::HttpClient, request::Method, response::Status};
Expand All @@ -11,6 +11,7 @@ use crate::{
ota::{Ota0Partition, Ota1Partition, OtaClient, OtaDataPartition},
wait_for_connection, HttpClientResources,
},
buffered_tcp_client::BufferedTcpClient,
states::{display_message, menu::AppMenu},
AppState, SerialNumber,
};
Expand Down Expand Up @@ -87,7 +88,7 @@ async fn do_update(board: &mut Board) -> UpdateResult {

let mut resources = HttpClientResources::new_boxed();

let client = TcpClient::new(sta.stack(), &resources.client_state);
let client = BufferedTcpClient::new(sta.stack(), &resources.client_state);
let dns = DnsSocket::new(sta.stack());

let mut client = HttpClient::new(&client, &dns);
Expand Down
10 changes: 4 additions & 6 deletions src/states/upload_or_store_measurement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ use core::{

use alloc::{boxed::Box, vec::Vec};
use embassy_futures::select::{select, Either};
use embassy_net::{
dns::DnsSocket,
tcp::client::{TcpClient, TcpClientState},
};
use embassy_net::dns::DnsSocket;
use embassy_time::{Duration, Timer};
use embedded_menu::items::NavigationItem;
use embedded_nal_async::{Dns, TcpConnect};
Expand All @@ -30,6 +27,7 @@ use crate::{
initialized::{Board, StaMode},
wait_for_connection, HttpClientResources,
},
buffered_tcp_client::BufferedTcpClient,
states::{
display_menu_screen, display_message, menu::storage::MeasurementAction, MenuEventHandler,
},
Expand Down Expand Up @@ -185,7 +183,7 @@ async fn try_to_upload(board: &mut Board, buffer: &[u8]) -> StoreMeasurement {

let mut resources = HttpClientResources::new_boxed();

let client = TcpClient::new(sta.stack(), &resources.client_state);
let client = BufferedTcpClient::new(sta.stack(), &resources.client_state);
let dns = DnsSocket::new(sta.stack());

let mut client = HttpClient::new(&client, &dns);
Expand Down Expand Up @@ -239,7 +237,7 @@ async fn upload_stored(board: &mut Board) -> bool {

let mut resources = HttpClientResources::new_boxed();

let client = TcpClient::new(sta.stack(), &resources.client_state);
let client = BufferedTcpClient::new(sta.stack(), &resources.client_state);
let dns = DnsSocket::new(sta.stack());

let mut client = HttpClient::new(&client, &dns);
Expand Down

0 comments on commit f88a6ba

Please sign in to comment.