Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add base64 simd support #12

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@ categories = ["embedded", "no-std", "network-programming"]
readme = "README.md"

[dependencies]
sha1 = "0.6"
heapless = "0.5"
byteorder = { version = "1.4", default-features = false }
httparse = { version = "1.4", default-features = false }
rand_core = "0.6"
sha1 = "0.10.1"
heapless = "0.7.14"
byteorder = { version = "1.4.3", default-features = false }
httparse = { version = "1.7.1", default-features = false }
rand_core = "0.6.3"
base64 = { version = "0.13.0", default-features = false }
base64-simd = { version = "0.5.0", default-features = false, optional = true }
cfg-if = "1.0.0"

[dev-dependencies]
rand = "0.8.3"
rand = "0.8.5"

# see readme for no_std support
[features]
default = ["std"]
# default = []
std = []
std = []
107 changes: 0 additions & 107 deletions src/base64.rs

This file was deleted.

40 changes: 30 additions & 10 deletions src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use heapless::{String, Vec};
/// Websocket details extracted from the http header
pub struct WebSocketContext {
/// The list of sub protocols is restricted to a maximum of 3
pub sec_websocket_protocol_list: Vec<WebSocketSubProtocol, U3>,
pub sec_websocket_protocol_list: Vec<WebSocketSubProtocol, 3>,
/// The websocket key user to build the accept string to complete the opening handshake
pub sec_websocket_key: WebSocketKey,
}
Expand Down Expand Up @@ -40,7 +40,7 @@ pub struct WebSocketContext {
pub fn read_http_header<'a>(
headers: impl Iterator<Item = (&'a str, &'a [u8])>,
) -> Result<Option<WebSocketContext>> {
let mut sec_websocket_protocol_list: Vec<String<U24>, U3> = Vec::new();
let mut sec_websocket_protocol_list: Vec<String<24>, 3> = Vec::new();
let mut is_websocket_request = false;
let mut sec_websocket_key = String::new();

Expand Down Expand Up @@ -130,13 +130,23 @@ pub fn build_connect_handshake_request(
rng: &mut impl RngCore,
to: &mut [u8],
) -> Result<(usize, WebSocketKey)> {
let mut http_request: String<U1024> = String::new();
let mut http_request: String<1024> = String::new();
let mut key_as_base64: [u8; 24] = [0; 24];

let mut key: [u8; 16] = [0; 16];
rng.fill_bytes(&mut key);
base64::encode(&key, &mut key_as_base64);
let sec_websocket_key: String<U24> = String::from(str::from_utf8(&key_as_base64)?);

cfg_if::cfg_if! {
if #[cfg(feature = "base64-simd")] {
use base64_simd::{Base64, OutBuf};
Base64::STANDARD.encode(&key, OutBuf::from_slice_mut(&mut key_as_base64))?;
} else {
base64::encode_config_slice(&key, base64::STANDARD, &mut key_as_base64);
}
}


let sec_websocket_key: String<24> = String::from(str::from_utf8(&key_as_base64)?);

http_request.push_str("GET ")?;
http_request.push_str(websocket_options.path)?;
Expand Down Expand Up @@ -177,7 +187,7 @@ pub fn build_connect_handshake_response(
sec_websocket_protocol: Option<&WebSocketSubProtocol>,
to: &mut [u8],
) -> Result<usize> {
let mut http_response: String<U1024> = String::new();
let mut http_response: String<1024> = String::new();
http_response.push_str(
"HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\nUpgrade: websocket\r\n",
Expand All @@ -204,13 +214,23 @@ pub fn build_connect_handshake_response(

pub fn build_accept_string(sec_websocket_key: &WebSocketKey, output: &mut [u8]) -> Result<()> {
// concatenate the key with a known websocket GUID (as per the spec)
let mut accept_string: String<U64> = String::new();
let mut accept_string: String<64> = String::new();
accept_string.push_str(sec_websocket_key)?;
accept_string.push_str("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")?;

// calculate the base64 encoded sha1 hash of the accept string above
let sha1 = Sha1::from(&accept_string);
let input = sha1.digest().bytes();
base64::encode(&input, output); // no need for slices since the output WILL be 28 bytes
let mut sha1 = Sha1::new();
sha1.update(&accept_string);
let input = sha1.finalize();

cfg_if::cfg_if! {
if #[cfg(feature = "base64-simd")] {
use base64_simd::{Base64, OutBuf};
Base64::STANDARD.encode(&input, OutBuf::from_slice_mut(output))?;
} else {
base64::encode_config_slice(&input, base64::STANDARD, output); // no need for slices since the output WILL be 28 bytes
}
}

Ok(())
}
21 changes: 14 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@

use byteorder::{BigEndian, ByteOrder};
use core::{cmp, result, str};
use heapless::consts::{U1024, U24, U256, U3, U64};
use heapless::{String, Vec};
use rand_core::RngCore;
use sha1::Sha1;
use sha1::{Sha1, Digest};

mod base64;
mod http;
pub mod random;
pub use self::http::{read_http_header, WebSocketContext};
Expand All @@ -36,10 +34,10 @@ const MASK_KEY_LEN: usize = 4;
pub type Result<T> = result::Result<T, Error>;

/// A fixed length 24-character string used to hold a websocket key for the opening handshake
pub type WebSocketKey = String<U24>;
pub type WebSocketKey = String<24>;

/// A maximum sized 24-character string used to store a sub protocol (e.g. `chat`)
pub type WebSocketSubProtocol = String<U24>;
pub type WebSocketSubProtocol = String<24>;

/// Websocket send message type used when sending a websocket frame
#[derive(PartialEq, Debug, Copy, Clone)]
Expand Down Expand Up @@ -213,6 +211,8 @@ pub enum Error {
ConvertInfallible,
RandCore,
UnexpectedContinuationFrame,
#[cfg(feature = "base64-simd")]
Base64Error,
}

impl From<httparse::Error> for Error {
Expand All @@ -239,6 +239,13 @@ impl From<()> for Error {
}
}

#[cfg(feature = "base64-simd")]
impl From<base64_simd::Error> for Error {
fn from(_: base64_simd::Error) -> Error {
Error::Base64Error
}
}

#[derive(Copy, Clone, Debug, PartialEq)]
enum WebSocketOpCode {
ContinuationFrame = 0,
Expand Down Expand Up @@ -680,7 +687,7 @@ where
if self.state == WebSocketState::Open {
self.state = WebSocketState::CloseSent;
if let Some(status_description) = status_description {
let mut from_buffer: Vec<u8, U256> = Vec::new();
let mut from_buffer: Vec<u8, 256> = Vec::new();
BigEndian::write_u16(&mut from_buffer, close_status.to_u16());

// restrict the max size of the status_description
Expand All @@ -690,7 +697,7 @@ where
254
};

from_buffer.extend(status_description[..len].as_bytes());
from_buffer.extend_from_slice(status_description[..len].as_bytes())?;
self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
} else {
let mut from_buffer: [u8; 2] = [0; 2];
Expand Down