Skip to content

Commit 47671b7

Browse files
committed
fix: Remaining broken unit tests. Also refactor tests to not use macros.
1 parent b97ce43 commit 47671b7

File tree

7 files changed

+259
-165
lines changed

7 files changed

+259
-165
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ http = "0.2.3"
2121
failure = "0.1.5"
2222
aws_lambda_events = "0.4.0"
2323
tokio = { version = "1", features = ["full"] }
24+
parking_lot = "0.11.1"
2425

2526
[dev-dependencies]
2627
# Enable test-utilities in dev mode only. This is mostly for tests.
2728
tokio = { version = "1", features = ["test-util"] }
29+
tokio-test = "0.4.0"

src/builder.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
use crate::config::*;
2-
use crate::handler::RocketHandler;
1+
use std::sync::Arc;
2+
33
use lamedh_http::handler;
44
use lamedh_runtime::run;
5-
use rocket::local::asynchronous::Client;
65
use rocket::Rocket;
7-
use std::sync::Arc;
6+
7+
use crate::config::*;
8+
use crate::handler::RocketHandler;
9+
use crate::LazyClient;
10+
use parking_lot::Mutex;
811

912
/// A builder to create and configure a [RocketHandler](RocketHandler).
1013
pub struct RocketHandlerBuilder {
@@ -40,14 +43,12 @@ impl RocketHandlerBuilder {
4043
/// use lamedh_runtime::run;
4144
/// use lamedh_http::handler;
4245
///
43-
/// let rocket_handler = rocket::ignite().lambda().into_handler();
46+
/// let rocket_handler = tokio_test::block_on(rocket::ignite().lambda().into_handler());
4447
/// run(handler(rocket_handler));
4548
/// ```
4649
pub async fn into_handler(self) -> RocketHandler {
47-
// TODO: Change this to async Client?
48-
let client = Arc::new(Client::untracked(self.rocket).await.unwrap());
4950
RocketHandler {
50-
client,
51+
lazy_client: Arc::new(Mutex::new(LazyClient::Uninitialized(Some(self.rocket)))),
5152
config: Arc::new(self.config),
5253
}
5354
}
@@ -64,9 +65,9 @@ impl RocketHandlerBuilder {
6465
///
6566
/// ```rust,no_run
6667
/// use rocket_lamb::RocketExt;
67-
/// use lambda_http::lambda::lambda;
68+
/// use lamedh_http::lambda::lambda;
6869
///
69-
/// rocket::ignite().lambda().launch();
70+
/// tokio_test::block_on(rocket::ignite().lambda().launch());
7071
/// ```
7172
pub async fn launch(self) -> ! {
7273
run(handler(self.into_handler().await)).await.unwrap();

src/handler.rs

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,40 @@
1-
use crate::config::*;
2-
use crate::error::RocketLambError;
3-
use crate::request_ext::RequestExt as _;
1+
use std::future::Future;
2+
use std::pin::Pin;
3+
use std::sync::Arc;
4+
45
use aws_lambda_events::encodings::Body;
56
use lamedh_http::{Handler, Request, RequestExt, Response};
67
use lamedh_runtime::Context;
8+
use parking_lot::Mutex;
79
use rocket::http::{uri::Uri, Header};
810
use rocket::local::asynchronous::{Client, LocalRequest, LocalResponse};
9-
use std::future::Future;
10-
use std::pin::Pin;
11-
use std::sync::Arc;
11+
use rocket::{Rocket, Route};
12+
13+
use crate::config::*;
14+
use crate::error::RocketLambError;
15+
use crate::request_ext::RequestExt as _;
1216

1317
/// A Lambda handler for API Gateway events that processes requests using a [Rocket](rocket::Rocket) instance.
1418
pub struct RocketHandler {
15-
pub(super) client: Arc<Client>,
19+
pub(super) lazy_client: Arc<Mutex<LazyClient>>,
1620
pub(super) config: Arc<Config>,
1721
}
1822

23+
pub(super) enum LazyClient {
24+
Uninitialized(Option<Rocket>),
25+
Ready(Arc<Client>),
26+
}
27+
1928
impl Handler for RocketHandler {
2029
type Error = failure::Error;
2130
type Response = Response<Body>;
2231
type Fut = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + 'static>>;
2332

2433
fn call(&mut self, req: Request, _ctx: Context) -> Self::Fut {
25-
let client = Arc::clone(&self.client);
2634
let config = Arc::clone(&self.config);
35+
let lazy_client = Arc::clone(&self.lazy_client);
2736
let fut = async {
28-
process_request(client, config, req)
37+
process_request(lazy_client, config, req)
2938
.await
3039
.map_err(failure::Error::from)
3140
.map_err(failure::Error::into)
@@ -35,10 +44,9 @@ impl Handler for RocketHandler {
3544
}
3645

3746
fn get_path_and_query(config: &Config, req: &Request) -> String {
38-
// TODO: Figure out base path behavior per request since the client doesn't have it now
39-
let mut uri = match config.base_path_behaviour {
40-
BasePathBehaviour::Include | BasePathBehaviour::RemountAndInclude => req.full_path(),
41-
BasePathBehaviour::Exclude => req.api_path().to_owned(),
47+
let mut uri = match &config.base_path_behaviour {
48+
BasePathBehaviour::Include | BasePathBehaviour::RemountAndInclude => dbg!(req.full_path()),
49+
BasePathBehaviour::Exclude => dbg!(req.api_path().to_owned()),
4250
};
4351
let query = req.query_string_parameters();
4452

@@ -58,15 +66,46 @@ fn get_path_and_query(config: &Config, req: &Request) -> String {
5866
}
5967

6068
async fn process_request(
61-
client: Arc<Client>,
69+
lazy_client: Arc<Mutex<LazyClient>>,
6270
config: Arc<Config>,
6371
req: Request,
6472
) -> Result<Response<Body>, RocketLambError> {
73+
let client = get_client_from_lazy(&lazy_client, &config, &req).await;
6574
let local_req = create_rocket_request(&client, Arc::clone(&config), req)?;
6675
let local_res = local_req.dispatch().await;
6776
create_lambda_response(config, local_res).await
6877
}
6978

79+
async fn get_client_from_lazy(
80+
lazy_client_lock: &Mutex<LazyClient>,
81+
config: &Config,
82+
req: &Request,
83+
) -> Arc<Client> {
84+
let mut lazy_client = lazy_client_lock.lock();
85+
match &mut *lazy_client {
86+
LazyClient::Ready(c) => Arc::clone(&c),
87+
LazyClient::Uninitialized(r) => {
88+
let r = r
89+
.take()
90+
.expect("It should not be possible for this to be None");
91+
let base_path = req.base_path();
92+
let client = if config.base_path_behaviour == BasePathBehaviour::RemountAndInclude
93+
&& !base_path.is_empty()
94+
{
95+
let routes: Vec<Route> = r.routes().cloned().collect();
96+
let rocket = r.mount(&base_path, routes);
97+
Client::untracked(rocket).await.unwrap()
98+
} else {
99+
Client::untracked(r).await.unwrap()
100+
};
101+
let client = Arc::new(client);
102+
let client_clone = Arc::clone(&client);
103+
*lazy_client = LazyClient::Ready(client);
104+
client_clone
105+
}
106+
}
107+
}
108+
70109
fn create_rocket_request(
71110
client: &Client,
72111
config: Arc<Config>,

src/lib.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ This *should* also work with requests from an AWS Application Load Balancer, but
88
## Usage
99
1010
```rust,no_run
11-
#![feature(proc_macro_hygiene, decl_macro)]
12-
1311
#[macro_use] extern crate rocket;
1412
use rocket_lamb::RocketExt;
1513
@@ -18,11 +16,12 @@ fn hello() -> &'static str {
1816
"Hello, world!"
1917
}
2018
21-
fn main() {
19+
#[tokio::main]
20+
async fn main() {
2221
rocket::ignite()
2322
.mount("/hello", routes![hello])
2423
.lambda() // launch the Rocket as a Lambda
25-
.launch();
24+
.launch().await;
2625
}
2726
```
2827
*/

src/request_ext.rs

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
use aws_lambda_events::event::apigw::{
2+
ApiGatewayProxyRequestContext, ApiGatewayV2httpRequestContext,
3+
ApiGatewayV2httpRequestContextHttpDescription,
4+
};
15
use http::header::HOST;
26
use lamedh_http::request::RequestContext;
37
use lamedh_http::{Request, RequestExt as _};
@@ -12,7 +16,9 @@ pub(crate) trait RequestExt {
1216

1317
impl RequestExt for Request {
1418
fn full_path(&self) -> String {
15-
if matches!(self.request_context(), RequestContext::Alb(_)) || !is_default_api_gateway_url(self) {
19+
if matches!(self.request_context(), RequestContext::Alb(_))
20+
|| !is_default_api_gateway_url(self)
21+
{
1622
self.uri().path().to_owned()
1723
} else {
1824
let mut path = self.base_path();
@@ -22,35 +28,35 @@ impl RequestExt for Request {
2228
}
2329

2430
fn base_path(&self) -> String {
25-
// TODO: Find out what this is for and is supposed to return >.>
26-
String::new()
27-
// match self.request_context() {
28-
// RequestContext::ApiGateway {
29-
// stage,
30-
// resource_path,
31-
// ..
32-
// } => {
33-
// if is_default_api_gateway_url(self) {
34-
// format!("/{}", stage)
35-
// } else {
36-
// let resource_path = populate_resource_path(self, resource_path);
37-
// let full_path = self.uri().path();
38-
// let resource_path_index =
39-
// full_path.rfind(&resource_path).unwrap_or_else(|| {
40-
// panic!(
41-
// "Could not find segment '{}' in path '{}'.",
42-
// resource_path, full_path
43-
// )
44-
// });
45-
// full_path[..resource_path_index].to_owned()
46-
// }
47-
// }
48-
// RequestContext::Alb { .. } => String::new(),
49-
// }
31+
let (stage, path) = match self.request_context() {
32+
RequestContext::ApiGatewayV1(ApiGatewayProxyRequestContext {
33+
stage,
34+
resource_path,
35+
..
36+
}) => (stage, resource_path),
37+
RequestContext::ApiGatewayV2(ApiGatewayV2httpRequestContext {
38+
stage,
39+
http: ApiGatewayV2httpRequestContextHttpDescription { path, .. },
40+
..
41+
}) => (stage, path),
42+
RequestContext::Alb(..) => (None, None),
43+
};
44+
if is_default_api_gateway_url(self) {
45+
format!("/{}", stage.unwrap_or("".to_string()))
46+
} else {
47+
let path = populate_resource_path(self, path.unwrap_or("".to_string()));
48+
let full_path = self.uri().path();
49+
let resource_path_index = full_path.rfind(&path).unwrap_or_else(|| {
50+
panic!("Could not find segment '{}' in path '{}'.", path, full_path)
51+
});
52+
full_path[..resource_path_index].to_owned()
53+
}
5054
}
5155

5256
fn api_path(&self) -> &str {
53-
if matches!(self.request_context(), RequestContext::Alb(_)) || is_default_api_gateway_url(self) {
57+
if matches!(self.request_context(), RequestContext::Alb(_))
58+
|| is_default_api_gateway_url(self)
59+
{
5460
self.uri().path()
5561
} else {
5662
&self.uri().path()[self.base_path().len()..]
@@ -66,21 +72,21 @@ fn is_default_api_gateway_url(req: &Request) -> bool {
6672
.unwrap_or(false)
6773
}
6874

69-
// fn populate_resource_path(req: &Request, resource_path: String) -> String {
70-
// let path_parameters = req.path_parameters();
71-
// resource_path
72-
// .split('/')
73-
// .map(|segment| {
74-
// if segment.starts_with('{') {
75-
// let end = if segment.ends_with("+}") { 2 } else { 1 };
76-
// let param = &segment[1..segment.len() - end];
77-
// path_parameters
78-
// .get(param)
79-
// .unwrap_or_else(|| panic!("Could not find path parameter '{}'.", param))
80-
// } else {
81-
// segment
82-
// }
83-
// })
84-
// .collect::<Vec<&str>>()
85-
// .join("/")
86-
// }
75+
fn populate_resource_path(req: &Request, resource_path: String) -> String {
76+
let path_parameters = req.path_parameters();
77+
resource_path
78+
.split('/')
79+
.map(|segment| {
80+
if segment.starts_with('{') {
81+
let end = if segment.ends_with("+}") { 2 } else { 1 };
82+
let param = &segment[1..segment.len() - end];
83+
path_parameters
84+
.get(param)
85+
.unwrap_or_else(|| panic!("Could not find path parameter '{}'.", param))
86+
} else {
87+
segment
88+
}
89+
})
90+
.collect::<Vec<&str>>()
91+
.join("/")
92+
}

0 commit comments

Comments
 (0)