Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/app.zig
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub const App = struct {
tls_verify_host: bool = true,
http_proxy: ?std.Uri = null,
proxy_type: ?http.ProxyType = null,
proxy_auth: ?http.ProxyAuth = null,
};

pub fn init(allocator: Allocator, config: Config) !*App {
Expand Down Expand Up @@ -58,6 +59,7 @@ pub const App = struct {
.max_concurrent = 3,
.http_proxy = config.http_proxy,
.proxy_type = config.proxy_type,
.proxy_auth = config.proxy_auth,
.tls_verify_host = config.tls_verify_host,
}),
.config = config,
Expand Down
97 changes: 93 additions & 4 deletions src/http/client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,34 @@ pub const ProxyType = enum {
connect,
};

pub const ProxyAuth = union(enum) {
basic: struct { user_pass: []const u8 },
bearer: struct { token: []const u8 },

pub fn header_value(self: ProxyAuth, allocator: Allocator) ![]const u8 {
switch (self) {
.basic => |*auth| {
if (std.mem.indexOfScalar(u8, auth.user_pass, ':') == null) return error.InvalidProxyAuth;

const prefix = "Basic ";
var encoder = std.base64.standard.Encoder;
const size = encoder.calcSize(auth.user_pass.len);
var buffer = try allocator.alloc(u8, size + prefix.len);
@memcpy(buffer[0..prefix.len], prefix);
_ = std.base64.standard.Encoder.encode(buffer[prefix.len..], auth.user_pass);
return buffer;
},
.bearer => |*auth| {
const prefix = "Bearer ";
var buffer = try allocator.alloc(u8, auth.token.len + prefix.len);
@memcpy(buffer[0..prefix.len], prefix);
@memcpy(buffer[prefix.len..], auth.token);
return buffer;
},
}
}
};

// Thread-safe. Holds our root certificate, connection pool and state pool
// Used to create Requests.
pub const Client = struct {
Expand All @@ -54,6 +82,7 @@ pub const Client = struct {
state_pool: StatePool,
http_proxy: ?Uri,
proxy_type: ?ProxyType,
proxy_auth: ?[]const u8, // Basic <user:pass; base64> or Bearer <token>
root_ca: tls.config.CertBundle,
tls_verify_host: bool = true,
connection_manager: ConnectionManager,
Expand All @@ -63,6 +92,7 @@ pub const Client = struct {
max_concurrent: usize = 3,
http_proxy: ?std.Uri = null,
proxy_type: ?ProxyType = null,
proxy_auth: ?ProxyAuth = null,
tls_verify_host: bool = true,
max_idle_connection: usize = 10,
};
Expand All @@ -71,10 +101,10 @@ pub const Client = struct {
var root_ca: tls.config.CertBundle = if (builtin.is_test) .{} else try tls.config.CertBundle.fromSystem(allocator);
errdefer root_ca.deinit(allocator);

const state_pool = try StatePool.init(allocator, opts.max_concurrent);
var state_pool = try StatePool.init(allocator, opts.max_concurrent);
errdefer state_pool.deinit(allocator);

const connection_manager = ConnectionManager.init(allocator, opts.max_idle_connection);
var connection_manager = ConnectionManager.init(allocator, opts.max_idle_connection);
errdefer connection_manager.deinit();

return .{
Expand All @@ -84,6 +114,7 @@ pub const Client = struct {
.state_pool = state_pool,
.http_proxy = opts.http_proxy,
.proxy_type = if (opts.http_proxy == null) null else (opts.proxy_type orelse .connect),
.proxy_auth = if (opts.proxy_auth) |*auth| try auth.header_value(allocator) else null,
.tls_verify_host = opts.tls_verify_host,
.connection_manager = connection_manager,
.request_pool = std.heap.MemoryPool(Request).init(allocator),
Expand All @@ -98,6 +129,10 @@ pub const Client = struct {
self.state_pool.deinit(allocator);
self.connection_manager.deinit();
self.request_pool.deinit();

if (self.proxy_auth) |auth| {
allocator.free(auth);
}
}

pub fn request(self: *Client, method: Request.Method, uri: *const Uri) !*Request {
Expand Down Expand Up @@ -763,6 +798,13 @@ pub const Request = struct {

try self.headers.append(arena, .{ .name = "User-Agent", .value = "Lightpanda/1.0" });
try self.headers.append(arena, .{ .name = "Accept", .value = "*/*" });

if (self._client.isSimpleProxy()) {
if (self._client.proxy_auth) |proxy_auth| {
try self.headers.append(arena, .{ .name = "Proxy-Authorization", .value = proxy_auth });
}
}

self.requestStarting();
}

Expand Down Expand Up @@ -887,7 +929,13 @@ pub const Request = struct {
var writer = fbs.writer();

try writer.print("CONNECT {s}:{d} HTTP/1.1\r\n", .{ self._request_host, self._request_port });
try writer.print("Host: {s}:{d}\r\n\r\n", .{ self._request_host, self._request_port });
try writer.print("Host: {s}:{d}\r\n", .{ self._request_host, self._request_port });

if (self._client.proxy_auth) |proxy_auth| {
try writer.print("Proxy-Authorization: {s}\r\n", .{proxy_auth});
}

_ = try writer.write("\r\n");
return buf[0..fbs.pos];
}

Expand Down Expand Up @@ -3030,15 +3078,56 @@ test "HttpClient: sync with body proxy CONNECT" {
}
try testing.expectEqual("over 9000!", try res.next());
try testing.expectEqual(201, res.header.status);
try testing.expectEqual(5, res.header.count());
try testing.expectEqual(6, res.header.count());
try testing.expectEqual("Close", res.header.get("connection"));
try testing.expectEqual("10", res.header.get("content-length"));
try testing.expectEqual("127.0.0.1", res.header.get("_host"));
try testing.expectEqual("Lightpanda/1.0", res.header.get("_user-agent"));
try testing.expectEqual("*/*", res.header.get("_accept"));
// Proxy headers
try testing.expectEqual("127.0.0.1:9582", res.header.get("__host"));
}
}

test "HttpClient: basic authentication CONNECT" {
const proxy_uri = try Uri.parse("http://127.0.0.1:9582/");
var client = try testClient(.{ .proxy_type = .connect, .http_proxy = proxy_uri, .proxy_auth = .{ .basic = .{ .user_pass = "user:pass" } } });
defer client.deinit();

const uri = try Uri.parse("http://127.0.0.1:9582/http_client/echo");
var req = try client.request(.GET, &uri);
defer req.deinit();

var res = try req.sendSync(.{});

try testing.expectEqual(201, res.header.status);
// Destination headers
try testing.expectEqual(null, res.header.get("_authorization"));
try testing.expectEqual(null, res.header.get("_proxy-authorization"));
// Proxy headers
try testing.expectEqual(null, res.header.get("__authorization"));
try testing.expectEqual("Basic dXNlcjpwYXNz", res.header.get("__proxy-authorization"));
}
test "HttpClient: bearer authentication CONNECT" {
const proxy_uri = try Uri.parse("http://127.0.0.1:9582/");
var client = try testClient(.{ .proxy_type = .connect, .http_proxy = proxy_uri, .proxy_auth = .{ .bearer = .{ .token = "fruitsalad" } } });
defer client.deinit();

const uri = try Uri.parse("http://127.0.0.1:9582/http_client/echo");
var req = try client.request(.GET, &uri);
defer req.deinit();

var res = try req.sendSync(.{});

try testing.expectEqual(201, res.header.status);
// Destination headers
try testing.expectEqual(null, res.header.get("_authorization"));
try testing.expectEqual(null, res.header.get("_proxy-authorization"));
// Proxy headers
try testing.expectEqual(null, res.header.get("__authorization"));
try testing.expectEqual("Bearer fruitsalad", res.header.get("__proxy-authorization"));
}

test "HttpClient: sync with gzip body" {
for (0..2) |i| {
var client = try testClient(.{});
Expand Down
58 changes: 58 additions & 0 deletions src/main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ fn run(alloc: Allocator) !void {
.run_mode = args.mode,
.http_proxy = args.httpProxy(),
.proxy_type = args.proxyType(),
.proxy_auth = args.proxyAuth(),
.tls_verify_host = args.tlsVerifyHost(),
});
defer app.deinit();
Expand Down Expand Up @@ -164,6 +165,13 @@ const Command = struct {
};
}

fn proxyAuth(self: *const Command) ?http.ProxyAuth {
return switch (self.mode) {
inline .serve, .fetch => |opts| opts.common.proxy_auth,
else => unreachable,
};
}

fn logLevel(self: *const Command) ?log.Level {
return switch (self.mode) {
inline .serve, .fetch => |opts| opts.common.log_level,
Expand Down Expand Up @@ -208,6 +216,7 @@ const Command = struct {
const Common = struct {
http_proxy: ?std.Uri = null,
proxy_type: ?http.ProxyType = null,
proxy_auth: ?http.ProxyAuth = null,
tls_verify_host: bool = true,
log_level: ?log.Level = null,
log_format: ?log.Format = null,
Expand All @@ -233,6 +242,14 @@ const Command = struct {
\\ and expects the proxy to MITM the request.
\\ Defaults to connect when --http_proxy is set.
\\
\\--proxy_bearer_token
\\ The token to send for bearer authentication with the proxy
\\ Proxy-Authorization: Bearer <token>
\\
\\--proxy_basic_auth
\\ The user:password to send for basic authentication with the proxy
\\ Proxy-Authorization: Basic <base64(user:password)>
\\
\\--log_level The log level: debug, info, warn, error or fatal.
\\ Defaults to
++ (if (builtin.mode == .Debug) " info." else "warn.") ++
Expand Down Expand Up @@ -492,6 +509,31 @@ fn parseCommonArg(
return true;
}

if (std.mem.eql(u8, "--proxy_bearer_token", opt)) {
if (common.proxy_auth != null) {
log.fatal(.app, "proxy auth already set", .{ .arg = "--proxy_bearer_token" });
return error.InvalidArgument;
}
const str = args.next() orelse {
log.fatal(.app, "missing argument value", .{ .arg = "--proxy_bearer_token" });
return error.InvalidArgument;
};
common.proxy_auth = .{ .bearer = .{ .token = str } };
return true;
}
if (std.mem.eql(u8, "--proxy_basic_auth", opt)) {
if (common.proxy_auth != null) {
log.fatal(.app, "proxy auth already set", .{ .arg = "--proxy_basic_auth" });
return error.InvalidArgument;
}
const str = args.next() orelse {
log.fatal(.app, "missing argument value", .{ .arg = "--proxy_basic_auth" });
return error.InvalidArgument;
};
common.proxy_auth = .{ .basic = .{ .user_pass = str } };
return true;
}

if (std.mem.eql(u8, "--log_level", opt)) {
const str = args.next() orelse {
log.fatal(.app, "missing argument value", .{ .arg = "--log_level" });
Expand Down Expand Up @@ -606,6 +648,7 @@ fn serveHTTP(address: std.net.Address) !void {
var conn = try listener.accept();
defer conn.stream.close();
var http_server = std.http.Server.init(conn, &read_buffer);
var connect_headers: std.ArrayListUnmanaged(std.http.Header) = .{};
REQUEST: while (true) {
var request = http_server.receiveHead() catch |err| switch (err) {
error.HttpConnectionClosing => continue :ACCEPT,
Expand All @@ -617,6 +660,16 @@ fn serveHTTP(address: std.net.Address) !void {

if (request.head.method == .CONNECT) {
try request.respond("", .{ .status = .ok });

// Proxy headers and destination headers are separated in the case of a CONNECT proxy
// We store the CONNECT headers, then continue with the request for the destination
var it = request.iterateHeaders();
while (it.next()) |hdr| {
try connect_headers.append(aa, .{
.name = try std.fmt.allocPrint(aa, "__{s}", .{hdr.name}),
.value = try aa.dupe(u8, hdr.value),
});
}
continue :REQUEST;
}

Expand Down Expand Up @@ -657,6 +710,11 @@ fn serveHTTP(address: std.net.Address) !void {
.value = hdr.value,
});
}

if (connect_headers.items.len > 0) {
try headers.appendSlice(aa, connect_headers.items);
connect_headers.clearRetainingCapacity();
}
try headers.append(aa, .{ .name = "Connection", .value = "Close" });

try request.respond("over 9000!", .{
Expand Down
5 changes: 4 additions & 1 deletion src/testing.zig
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ pub fn expectEqual(expected: anytype, actual: anytype) !void {
if (@typeInfo(@TypeOf(expected)) == .null) {
return std.testing.expectEqual(null, actual);
}
return expectEqual(expected, actual.?);
if (actual) |_actual| {
return expectEqual(expected, _actual);
}
return std.testing.expectEqual(expected, null);
},
.@"union" => |union_info| {
if (union_info.tag_type == null) {
Expand Down