Skip to content

Commit

Permalink
Fix --base-url improper path and protocol handling using zola serve (
Browse files Browse the repository at this point in the history
…#2311)

* Fix --base-url improper path and protocol handling.

* Fix formatting.
  • Loading branch information
jamwil authored and Keats committed Jun 20, 2024
1 parent aa81986 commit c072b32
Showing 1 changed file with 113 additions and 27 deletions.
140 changes: 113 additions & 27 deletions src/cmd/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,34 @@ fn set_serve_error(msg: &'static str, e: errors::Error) {
}
}

async fn handle_request(req: Request<Body>, mut root: PathBuf) -> Result<Response<Body>> {
async fn handle_request(
req: Request<Body>,
mut root: PathBuf,
base_path: String,
) -> Result<Response<Body>> {
let path_str = req.uri().path();
if !path_str.starts_with(&base_path) {
return Ok(not_found());
}

let trimmed_path = &path_str[base_path.len() - 1..];

let original_root = root.clone();
let mut path = RelativePathBuf::new();
// https://zola.discourse.group/t/percent-encoding-for-slugs/736
let decoded = match percent_encoding::percent_decode_str(req.uri().path()).decode_utf8() {
let decoded = match percent_encoding::percent_decode_str(trimmed_path).decode_utf8() {
Ok(d) => d,
Err(_) => return Ok(not_found()),
};

for c in decoded.split('/') {
let decoded_path = if base_path != "/" && decoded.starts_with(&base_path) {
// Remove the base_path from the request path before processing
decoded[base_path.len()..].to_string()
} else {
decoded.to_string()
};

for c in decoded_path.split('/') {
path.push(c);
}

Expand Down Expand Up @@ -318,6 +336,39 @@ fn rebuild_done_handling(broadcaster: &Sender, res: Result<()>, reload_path: &st
}
}

fn construct_url(base_url: &str, no_port_append: bool, interface_port: u16) -> String {
if base_url == "/" {
return String::from("/");
}

let (protocol, stripped_url) = match base_url {
url if url.starts_with("http://") => ("http://", &url[7..]),
url if url.starts_with("https://") => ("https://", &url[8..]),
url => ("http://", url),
};

let (domain, path) = {
let parts: Vec<&str> = stripped_url.splitn(2, '/').collect();
if parts.len() > 1 {
(parts[0], format!("/{}", parts[1]))
} else {
(parts[0], String::new())
}
};

let full_address = if no_port_append {
format!("{}{}{}", protocol, domain, path)
} else {
format!("{}{}:{}{}", protocol, domain, interface_port, path)
};

if full_address.ends_with('/') {
full_address
} else {
format!("{}/", full_address)
}
}

#[allow(clippy::too_many_arguments)]
fn create_new_site(
root_dir: &Path,
Expand All @@ -330,7 +381,7 @@ fn create_new_site(
include_drafts: bool,
mut no_port_append: bool,
ws_port: Option<u16>,
) -> Result<(Site, SocketAddr)> {
) -> Result<(Site, SocketAddr, String)> {
SITE_CONTENT.write().unwrap().clear();

let mut site = Site::new(root_dir, config_file)?;
Expand All @@ -345,24 +396,10 @@ fn create_new_site(
|u| u.to_string(),
);

let base_url = if base_url == "/" {
String::from("/")
} else {
let base_address = if no_port_append {
base_url.to_string()
} else {
format!("{}:{}", base_url, interface_port)
};

if site.config.base_url.ends_with('/') {
format!("http://{}/", base_address)
} else {
format!("http://{}", base_address)
}
};
let constructed_base_url = construct_url(&base_url, no_port_append, interface_port);

site.enable_serve_mode();
site.set_base_url(base_url);
site.set_base_url(constructed_base_url.clone());
if let Some(output_dir) = output_dir {
if !force && output_dir.exists() {
return Err(Error::msg(format!(
Expand All @@ -384,7 +421,7 @@ fn create_new_site(
messages::notify_site_size(&site);
messages::warn_about_ignored_pages(&site);
site.build()?;
Ok((site, address))
Ok((site, address, constructed_base_url))
}

#[allow(clippy::too_many_arguments)]
Expand All @@ -403,7 +440,7 @@ pub fn serve(
utc_offset: UtcOffset,
) -> Result<()> {
let start = Instant::now();
let (mut site, bind_address) = create_new_site(
let (mut site, bind_address, constructed_base_url) = create_new_site(
root_dir,
interface,
interface_port,
Expand All @@ -415,6 +452,11 @@ pub fn serve(
no_port_append,
None,
)?;
let base_path = match constructed_base_url.splitn(4, '/').nth(3) {
Some(path) => format!("/{}", path),
None => "/".to_string(),
};

messages::report_elapsed_time(start);

// Stop right there if we can't bind to the address
Expand Down Expand Up @@ -479,19 +521,27 @@ pub fn serve(
rt.block_on(async {
let make_service = make_service_fn(move |_| {
let static_root = static_root.clone();
let base_path = base_path.clone();

async {
Ok::<_, hyper::Error>(service_fn(move |req| {
response_error_injector(handle_request(req, static_root.clone()))
response_error_injector(handle_request(
req,
static_root.clone(),
base_path.clone(),
))
}))
}
});

let server = Server::bind(&bind_address).serve(make_service);

println!("Web server is available at http://{}\n", bind_address);
println!(
"Web server is available at {} (bound to {})\n",
&constructed_base_url, &bind_address
);
if open {
if let Err(err) = open::that(format!("http://{}", bind_address)) {
if let Err(err) = open::that(format!("{}", &constructed_base_url)) {
eprintln!("Failed to open URL in your browser: {}", err);
}
}
Expand Down Expand Up @@ -618,7 +668,7 @@ pub fn serve(
no_port_append,
ws_port,
) {
Ok((s, _)) => {
Ok((s, _, _)) => {
clear_serve_error();
rebuild_done_handling(&broadcaster, Ok(()), "/x.js");

Expand Down Expand Up @@ -801,7 +851,7 @@ fn is_folder_empty(dir: &Path) -> bool {
mod tests {
use std::path::{Path, PathBuf};

use super::{detect_change_kind, is_temp_file, ChangeKind};
use super::{construct_url, detect_change_kind, is_temp_file, ChangeKind};

#[test]
fn can_recognize_temp_files() {
Expand Down Expand Up @@ -893,4 +943,40 @@ mod tests {
let config_filename = Path::new("config.toml");
assert_eq!(expected, detect_change_kind(pwd, path, config_filename));
}

#[test]
fn test_construct_url_base_url_is_slash() {
let result = construct_url("/", false, 8080);
assert_eq!(result, "/");
}

#[test]
fn test_construct_url_http_protocol() {
let result = construct_url("http://example.com", false, 8080);
assert_eq!(result, "http://example.com:8080/");
}

#[test]
fn test_construct_url_https_protocol() {
let result = construct_url("https://example.com", false, 8080);
assert_eq!(result, "https://example.com:8080/");
}

#[test]
fn test_construct_url_no_protocol() {
let result = construct_url("example.com", false, 8080);
assert_eq!(result, "http://example.com:8080/");
}

#[test]
fn test_construct_url_no_port_append() {
let result = construct_url("https://example.com", true, 8080);
assert_eq!(result, "https://example.com/");
}

#[test]
fn test_construct_url_trailing_slash() {
let result = construct_url("http://example.com/", false, 8080);
assert_eq!(result, "http://example.com:8080/");
}
}

0 comments on commit c072b32

Please sign in to comment.