Skip to content

Commit

Permalink
snapviewGH-46 Return the response in case of non-101 response from se…
Browse files Browse the repository at this point in the history
…rver

Following redirects inside the lib (snapview#148)
has few flows: there is no redirect loop prevention in this case, in case of using the lib in proxy it's impossible to return the upstream response to browser, etc.

With this change response is propagated to the lib's user, so it can decide what to do with it
in case of redirects: either send it to browser for it to follow redirects or implement redirect following on their side.
  • Loading branch information
inikulin committed Dec 1, 2020
1 parent 96a8499 commit b4383b9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 17 deletions.
3 changes: 3 additions & 0 deletions src/error.rs
Expand Up @@ -65,6 +65,8 @@ pub enum Error {
Url(Cow<'static, str>),
/// HTTP error.
Http(http::StatusCode),
/// HTTP response error.
HttpResponse(http::Response<()>),
/// HTTP format error.
HttpFormat(http::Error),
}
Expand All @@ -83,6 +85,7 @@ impl fmt::Display for Error {
Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP error: {}", code),
Error::HttpResponse(ref res) => write!(f, "HTTP response error: {}", res.status()),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
}
}
Expand Down
41 changes: 24 additions & 17 deletions src/handshake/client.rs
Expand Up @@ -90,6 +90,12 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
result,
tail,
} => {
// If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if result.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::HttpResponse(result));
}

self.verify_data.verify_response(&result)?;
debug!("Client handshake done.");
let websocket =
Expand All @@ -105,16 +111,18 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
let mut req = Vec::new();
let uri = request.uri();

let authority = uri.authority()
let authority = uri
.authority()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?
.as_str();
let host = if let Some(idx) = authority.find('@') { // handle possible name:password@
let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@
authority.split_at(idx + 1).1
} else {
authority
};
if authority.is_empty() {
return Err(Error::Url("URL contains empty host name".into()))
return Err(Error::Url("URL contains empty host name".into()));
}

write!(
Expand All @@ -138,7 +146,7 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {

for (k, v) in request.headers() {
let mut k = k.as_str();
if k == "sec-websocket-protocol" {
if k == "sec-websocket-protocol" {
k = "Sec-WebSocket-Protocol";
}
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap();
Expand All @@ -157,14 +165,9 @@ struct VerifyData {

impl VerifyData {
pub fn verify_response(&self, response: &Response) -> Result<()> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::Http(response.status()));
}
let headers = response.headers();

// 2. If the response lacks an |Upgrade| header field or the |Upgrade|
// 1. If the response lacks an |Upgrade| header field or the |Upgrade|
// header field contains a value that is not an ASCII case-
// insensitive match for the value "websocket", the client MUST
// _Fail the WebSocket Connection_. (RFC 6455)
Expand All @@ -178,7 +181,7 @@ impl VerifyData {
"No \"Upgrade: websocket\" in server reply".into(),
));
}
// 3. If the response lacks a |Connection| header field or the
// 2. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an
// ASCII case-insensitive match for the value "Upgrade", the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
Expand All @@ -192,7 +195,7 @@ impl VerifyData {
"No \"Connection: upgrade\" in server reply".into(),
));
}
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// 3. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455)
Expand All @@ -205,14 +208,14 @@ impl VerifyData {
"Key mismatch in Sec-WebSocket-Accept".into(),
));
}
// 5. If the response includes a |Sec-WebSocket-Extensions| header
// 4. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
// TODO

// 6. If the response includes a |Sec-WebSocket-Protocol| header field
// 5. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was
// not present in the client's handshake (the server has indicated a
// subprotocol not requested by the client), the client MUST _Fail
Expand Down Expand Up @@ -266,8 +269,8 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::super::machine::TryParse;
use crate::client::IntoClientRequest;
use super::{generate_key, generate_request, Response};
use crate::client::IntoClientRequest;

#[test]
fn random_keys() {
Expand Down Expand Up @@ -304,7 +307,9 @@ mod tests {

#[test]
fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let request = "wss://localhost:9001/getCaseCount"
.into_client_request()
.unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
Expand All @@ -321,7 +326,9 @@ mod tests {

#[test]
fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let request = "wss://user:pass@localhost:9001/getCaseCount"
.into_client_request()
.unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
Expand Down

0 comments on commit b4383b9

Please sign in to comment.