Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 88 additions & 48 deletions josh-proxy/src/bin/josh-proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,10 @@ async fn call_service(
}
};

let remote_url = [
serv.upstream_url.as_str(),
parsed_url.upstream_repo.as_str(),
]
.join("");
let upstream_repo = parsed_url.upstream_repo;
let filter = parsed_url.filter;

let remote_url = [serv.upstream_url.as_str(), upstream_repo.as_str()].join("");

if parsed_url.pathinfo.starts_with("/info/lfs") {
return Ok(Response::builder()
Expand Down Expand Up @@ -498,7 +497,7 @@ async fn call_service(
let block = block.split(";").collect::<Vec<_>>();

for b in block {
if b == parsed_url.upstream_repo {
if b == upstream_repo {
return Ok(make_response(
hyper::Body::from(formatdoc!(
r#"
Expand All @@ -510,9 +509,22 @@ async fn call_service(
}
}

if parsed_url.api == "/~/graphql" {
return serve_graphql(serv, req, upstream_repo.to_owned(), remote_url, auth).await;
}

if parsed_url.api == "/~/graphiql" {
let addr = format!("/~/graphql{}", upstream_repo);
return Ok(tokio::task::spawn_blocking(move || {
josh_proxy::juniper_hyper::graphiql(&addr, None)
})
.in_current_span()
.await??);
}

match fetch_upstream(
serv.clone(),
parsed_url.upstream_repo.to_owned(),
upstream_repo.to_owned(),
&auth,
remote_url.to_owned(),
&headref,
Expand All @@ -535,48 +547,16 @@ async fn call_service(
}
}

if parsed_url.api == "/~/graphiql" {
let addr = format!("/~/graphql{}", parsed_url.upstream_repo);
return Ok(tokio::task::spawn_blocking(move || {
josh_proxy::juniper_hyper::graphiql(&addr, None)
})
.in_current_span()
.await??);
}

if parsed_url.api == "/~/graphql" {
return serve_graphql(
serv,
req,
parsed_url.upstream_repo.to_owned(),
remote_url,
auth,
)
.await;
}

if let (Some(q), true) = (
req.uri().query().map(|x| x.to_string()),
parsed_url.pathinfo.is_empty(),
) {
return serve_query(
serv,
q,
parsed_url.upstream_repo,
parsed_url.filter,
headref,
)
.await;
return serve_query(serv, q, upstream_repo, filter, headref).await;
}

let temp_ns = prepare_namespace(
serv.clone(),
&parsed_url.upstream_repo,
&parsed_url.filter,
&headref,
)
.in_current_span()
.await?;
let temp_ns = prepare_namespace(serv.clone(), &upstream_repo, &filter, &headref)
.in_current_span()
.await?;

let repo_path = serv
.repo_path
Expand Down Expand Up @@ -606,8 +586,8 @@ async fn call_service(
remote_url: remote_url.clone(),
auth,
port: serv.port.clone(),
filter_spec: parsed_url.filter.clone(),
base_ns: josh::to_ns(&parsed_url.upstream_repo),
filter_spec: filter.clone(),
base_ns: josh::to_ns(&upstream_repo),
git_ns: temp_ns.name().to_string(),
git_dir: repo_path.clone(),
mirror_git_dir: mirror_repo_path.clone(),
Expand Down Expand Up @@ -1007,6 +987,11 @@ async fn serve_graphql(
remote_url: String,
auth: josh_proxy::auth::Handle,
) -> josh::JoshResult<Response<hyper::Body>> {
let parsed = match josh_proxy::juniper_hyper::parse_req(req).await {
Ok(r) => r,
Err(resp) => return Ok(resp),
};

let transaction_mirror = josh::cache::Transaction::open(
&serv.repo_path.join("mirror"),
Some(&format!(
Expand All @@ -1031,9 +1016,64 @@ async fn serve_graphql(
.to_string(),
false,
));
let gql_result = josh_proxy::juniper_hyper::graphql(root_node, context.clone(), req)
.in_current_span()
.await?;

let res = {
// First attempt to serve GraphQL query. If we can serve it
// that means all requested revisions were specified by SHA and we could find
// all of them locally, so no need to fetch.
let res = parsed.execute(&root_node, &context).await;

// The "allow_refs" flag will be set by the query handler if we need to do a fetch
// to complete the query.
if !*context.allow_refs.lock().unwrap() {
res
} else {
match fetch_upstream(
serv.clone(),
upstream_repo.to_owned(),
&auth,
remote_url.to_owned(),
&"HEAD",
false,
)
.in_current_span()
.await
{
Ok(res) => {
if !res {
let builder = Response::builder()
.header("WWW-Authenticate", "Basic realm=User Visible Realm")
.status(hyper::StatusCode::UNAUTHORIZED);
return Ok(builder.body(hyper::Body::empty())?);
}
}
Err(res) => {
let builder =
Response::builder().status(hyper::StatusCode::INTERNAL_SERVER_ERROR);
return Ok(builder.body(hyper::Body::from(res.0))?);
}
};

parsed.execute(&root_node, &context).await
}
};

let code = if res.is_ok() {
hyper::StatusCode::OK
} else {
hyper::StatusCode::BAD_REQUEST
};

let body = hyper::Body::from(serde_json::to_string_pretty(&res).unwrap());
let mut resp = Response::new(hyper::Body::empty());
*resp.status_mut() = code;
resp.headers_mut().insert(
hyper::header::CONTENT_TYPE,
hyper::header::HeaderValue::from_static("application/json"),
);
*resp.body_mut() = body;
let gql_result = resp;

tokio::task::spawn_blocking(move || -> josh::JoshResult<_> {
let temp_ns = Arc::new(josh_proxy::TmpGitNamespace::new(
&serv.repo_path.join("overlay"),
Expand Down
4 changes: 2 additions & 2 deletions josh-proxy/src/juniper_hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ where
})
}

async fn parse_req<S: ScalarValue>(
pub async fn parse_req<S: ScalarValue>(
req: Request<Body>,
) -> Result<GraphQLBatchRequest<S>, Response<Body>> {
match *req.method() {
Expand Down Expand Up @@ -182,7 +182,7 @@ where
resp
}

async fn execute_request<CtxT, QueryT, MutationT, SubscriptionT, S>(
pub async fn execute_request<CtxT, QueryT, MutationT, SubscriptionT, S>(
root_node: Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>,
context: Arc<CtxT>,
request: GraphQLBatchRequest<S>,
Expand Down
33 changes: 30 additions & 3 deletions src/graphql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,7 @@ pub struct Context {
pub transaction: std::sync::Arc<std::sync::Mutex<cache::Transaction>>,
pub transaction_mirror: std::sync::Arc<std::sync::Mutex<cache::Transaction>>,
pub to_push: std::sync::Arc<std::sync::Mutex<Vec<(String, git2::Oid)>>>,
pub allow_refs: std::sync::Mutex<bool>,
}

impl juniper::Context for Context {}
Expand Down Expand Up @@ -747,6 +748,13 @@ impl RepositoryMut {
add: Vec<MarkersInput>,
context: &Context,
) -> FieldResult<bool> {
{
let mut allow_refs = context.allow_refs.lock()?;
if !*allow_refs {
*allow_refs = true;
return Err(josh_error("ref query not allowed").into());
};
}
let transaction = context.transaction.lock()?;
let transaction_mirror = context.transaction_mirror.lock()?;

Expand Down Expand Up @@ -834,6 +842,13 @@ impl Repository {
}

fn refs(&self, context: &Context, pattern: Option<String>) -> FieldResult<Vec<Reference>> {
{
let mut allow_refs = context.allow_refs.lock()?;
if !*allow_refs {
*allow_refs = true;
return Err(josh_error("ref query not allowed").into());
};
}
let transaction_mirror = context.transaction_mirror.lock()?;
let refname = format!(
"{}{}",
Expand Down Expand Up @@ -863,10 +878,21 @@ impl Repository {
let rev = format!("{}{}", self.ns, at);

let transaction_mirror = context.transaction_mirror.lock()?;
let commit_id = if let Ok(id) = git2::Oid::from_str(&at) {
let commit_id = {
let mut allow_refs = context.allow_refs.lock()?;
let id = if let Ok(id) = git2::Oid::from_str(&at) {
id
} else if *allow_refs {
transaction_mirror.repo().revparse_single(&rev)?.id()
} else {
git2::Oid::zero()
};

if !transaction_mirror.repo().odb()?.exists(id) {
*allow_refs = true;
return Err(josh_error("ref query not allowed").into());
}
id
} else {
transaction_mirror.repo().revparse_single(&rev)?.id()
};

Ok(Revision {
Expand All @@ -887,6 +913,7 @@ pub fn context(transaction: cache::Transaction, transaction_mirror: cache::Trans
transaction_mirror: std::sync::Arc::new(std::sync::Mutex::new(transaction_mirror)),
transaction: std::sync::Arc::new(std::sync::Mutex::new(transaction)),
to_push: std::sync::Arc::new(std::sync::Mutex::new(vec![])),
allow_refs: std::sync::Mutex::new(false),
}
}

Expand Down