diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DelegatingDownloader.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DelegatingDownloader.java index 249cc22c7c0d19..129a730bcfc8b6 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DelegatingDownloader.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DelegatingDownloader.java @@ -47,6 +47,7 @@ public void setDelegate(@Nullable Downloader delegate) { @Override public void download( List urls, + Map> headers, Credentials credentials, Optional checksum, String canonicalId, @@ -60,6 +61,6 @@ public void download( downloader = delegate; } downloader.download( - urls, credentials, checksum, canonicalId, destination, eventHandler, clientEnv, type); + urls, headers, credentials, checksum, canonicalId, destination, eventHandler, clientEnv, type); } } diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java index 43853fdfdd6701..a56ce634c24307 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java @@ -116,6 +116,7 @@ public void setCredentialFactory(CredentialFactory credentialFactory) { public Future startDownload( List originalUrls, + Map> headers, Map>> authHeaders, Optional checksum, String canonicalId, @@ -129,6 +130,7 @@ public Future startDownload( try (SilentCloseable c = Profiler.instance().profile("fetching: " + context)) { return downloadInExecutor( originalUrls, + headers, authHeaders, checksum, canonicalId, @@ -154,6 +156,7 @@ public Path finalizeDownload(Future download) throws IOException, Interrup public Path download( List originalUrls, + Map> headers, Map>> authHeaders, Optional checksum, String canonicalId, @@ -166,6 +169,7 @@ public Path download( Future future = startDownload( originalUrls, + headers, authHeaders, checksum, canonicalId, @@ -197,6 +201,7 @@ public Path download( */ private Path downloadInExecutor( List originalUrls, + Map> headers, Map>> authHeaders, Optional checksum, String canonicalId, @@ -339,6 +344,7 @@ private Path downloadInExecutor( try { downloader.download( rewrittenUrls, + headers, credentialFactory.create(rewrittenAuthHeaders), checksum, canonicalId, diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/Downloader.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/Downloader.java index 1e8fc932b43b08..79a0076a3f56e1 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/Downloader.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/Downloader.java @@ -42,6 +42,7 @@ public interface Downloader { */ void download( List urls, + Map> headers, Credentials credentials, Optional checksum, String canonicalId, diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java index b770d5b4731aea..8062bddc51095e 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java @@ -75,7 +75,7 @@ final class HttpConnectorMultiplexer { } public HttpStream connect(URL url, Optional checksum) throws IOException { - return connect(url, checksum, StaticCredentials.EMPTY, Optional.empty()); + return connect(url, checksum, ImmutableMap.of(), StaticCredentials.EMPTY, Optional.empty()); } /** @@ -96,14 +96,19 @@ public HttpStream connect(URL url, Optional checksum) throws IOExcepti * @throws IllegalArgumentException if {@code urls} is empty or has an unsupported protocol */ public HttpStream connect( - URL url, Optional checksum, Credentials credentials, Optional type) + URL url, Optional checksum, Map> headers, Credentials credentials, Optional type) throws IOException { Preconditions.checkArgument(HttpUtils.isUrlSupportedByDownloader(url)); if (Thread.interrupted()) { throw new InterruptedIOException(); } + ImmutableMap.Builder> baseHeaders = new ImmutableMap.Builder(); + baseHeaders.putAll(headers); + // REQUEST_HEADERS should not be overridable by user provided headers + baseHeaders.putAll(REQUEST_HEADERS); + Function>> headerFunction = - getHeaderFunction(REQUEST_HEADERS, credentials, eventHandler); + getHeaderFunction(baseHeaders.buildKeepingLast(), credentials, eventHandler); URLConnection connection = connector.connect(url, headerFunction); return httpStreamFactory.create( connection, diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java index 3e9e0f150a6ac8..fbd1f0b8eaceac 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java @@ -15,6 +15,7 @@ package com.google.devtools.build.lib.bazel.repository.downloader; import com.google.auth.Credentials; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.common.io.ByteStreams; @@ -75,6 +76,7 @@ public void setMaxRetryTimeout(Duration maxRetryTimeout) { @Override public void download( List urls, + Map> headers, Credentials credentials, Optional checksum, String canonicalId, @@ -94,7 +96,7 @@ public void download( for (URL url : urls) { SEMAPHORE.acquire(); - try (HttpStream payload = multiplexer.connect(url, checksum, credentials, type); + try (HttpStream payload = multiplexer.connect(url, checksum, headers, credentials, type); OutputStream out = destination.getOutputStream()) { try { ByteStreams.copy(payload, out); @@ -153,7 +155,7 @@ public byte[] downloadAndReadOneUrl( ByteArrayOutputStream out = new ByteArrayOutputStream(); SEMAPHORE.acquire(); try (HttpStream payload = - multiplexer.connect(url, Optional.empty(), credentials, Optional.empty())) { + multiplexer.connect(url, Optional.empty(), ImmutableMap.of(), credentials, Optional.empty())) { ByteStreams.copy(payload, out); } catch (SocketTimeoutException e) { // SocketTimeoutExceptions are InterruptedIOExceptions; however they do not signify diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/starlark/StarlarkBaseExternalContext.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/starlark/StarlarkBaseExternalContext.java index 954b74fdbb4fba..53cdf3932f6472 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/starlark/StarlarkBaseExternalContext.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/starlark/StarlarkBaseExternalContext.java @@ -80,6 +80,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.stream.Collectors; import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; @@ -281,6 +282,27 @@ private static ImmutableMap>> getAuthHeaders( return res; } + private static ImmutableMap> getHeaderContents(Dict x, String what) + throws EvalException { + Dict headersUnchecked = (Dict) Dict.cast(x, String.class, Object.class, what); + ImmutableMap.Builder> headers = new ImmutableMap.Builder<>(); + + for (Map.Entry entry : headersUnchecked.entrySet()) { + List headerValue; + Object valueUnchecked = entry.getValue(); + if (valueUnchecked instanceof Sequence) { + headerValue = Sequence.cast(valueUnchecked, String.class, "header values").getImmutableList(); + } else if (valueUnchecked instanceof String) { + headerValue = List.of(valueUnchecked.toString()); + } else { + throw new EvalException( + String.format("%s argument must be a dict whose keys are string and whose values are either string or sequence of string", what)); + } + headers.put(entry.getKey(), headerValue); + } + return headers.buildOrThrow(); + } + private static ImmutableList checkAllUrls(Iterable urlList) throws EvalException { ImmutableList.Builder result = ImmutableList.builder(); @@ -577,6 +599,11 @@ private StructImpl completeDownload(PendingDownload pendingDownload) defaultValue = "{}", named = true, doc = "An optional dict specifying authentication information for some of the URLs."), + @Param( + name = "headers", + defaultValue = "{}", + named = true, + doc = "An optional dict specifying http headers for all URLs."), @Param( name = "integrity", defaultValue = "''", @@ -606,7 +633,8 @@ public Object download( Boolean executable, Boolean allowFail, String canonicalId, - Dict authUnchecked, // expected + Dict authUnchecked, // expected + Dict headersUnchecked, // | String> expected String integrity, Boolean block, StarlarkThread thread) @@ -615,6 +643,8 @@ public Object download( ImmutableMap>> authHeaders = getAuthHeaders(getAuthContents(authUnchecked, "auth")); + ImmutableMap> headers = getHeaderContents(headersUnchecked, "headers"); + ImmutableList urls = getUrls( url, @@ -660,6 +690,7 @@ public Object download( Future downloadFuture = downloadManager.startDownload( urls, + headers, authHeaders, checksum, canonicalId, @@ -768,6 +799,11 @@ public Object download( defaultValue = "{}", named = true, doc = "An optional dict specifying authentication information for some of the URLs."), + @Param( + name = "headers", + defaultValue = "{}", + named = true, + doc = "An optional dict specifying http headers for all URLs."), @Param( name = "integrity", defaultValue = "''", @@ -799,13 +835,16 @@ public StructImpl downloadAndExtract( String stripPrefix, Boolean allowFail, String canonicalId, - Dict auth, // expected + Dict authUnchecked, // expected + Dict headersUnchecked, // | String> expected String integrity, Dict renameFiles, // expected StarlarkThread thread) throws RepositoryFunctionException, InterruptedException, EvalException { ImmutableMap>> authHeaders = - getAuthHeaders(getAuthContents(auth, "auth")); + getAuthHeaders(getAuthContents(authUnchecked, "auth")); + + ImmutableMap> headers = getHeaderContents(headersUnchecked, "headers"); ImmutableList urls = getUrls( @@ -852,6 +891,7 @@ public StructImpl downloadAndExtract( Future pendingDownload = downloadManager.startDownload( urls, + headers, authHeaders, checksum, canonicalId, diff --git a/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java b/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java index 29466a5e9d9407..44ff796b529dc3 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java +++ b/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java @@ -110,6 +110,7 @@ public void close() { @Override public void download( List urls, + Map> headers, Credentials credentials, Optional checksum, String canonicalId, @@ -154,7 +155,7 @@ public void download( eventHandler.handle( Event.warn("Remote Cache: " + Utils.grpcAwareErrorMessage(e, verboseFailures))); fallbackDownloader.download( - urls, credentials, checksum, canonicalId, destination, eventHandler, clientEnv, type); + urls, headers, credentials, checksum, canonicalId, destination, eventHandler, clientEnv, type); } } diff --git a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloaderTest.java b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloaderTest.java index 2451af6027aabf..0e59d412abb975 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloaderTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloaderTest.java @@ -117,6 +117,7 @@ public void downloadFrom1UrlOk() throws IOException, InterruptedException { Collections.singletonList( new URL(String.format("http://localhost:%d/foo", server.getLocalPort()))), Collections.emptyMap(), + Collections.emptyMap(), Optional.empty(), "testCanonicalId", Optional.empty(), @@ -181,6 +182,7 @@ public void downloadFrom2UrlsFirstOk() throws IOException, InterruptedException downloadManager.download( urls, Collections.emptyMap(), + Collections.emptyMap(), Optional.empty(), "testCanonicalId", Optional.empty(), @@ -248,6 +250,7 @@ public void downloadFrom2UrlsFirstSocketTimeoutOnBodyReadSecondOk() downloadManager.download( urls, Collections.emptyMap(), + Collections.emptyMap(), Optional.empty(), "testCanonicalId", Optional.empty(), @@ -317,6 +320,7 @@ public void downloadFrom2UrlsBothSocketTimeoutDuringBodyRead() downloadManager.download( urls, Collections.emptyMap(), + Collections.emptyMap(), Optional.empty(), "testCanonicalId", Optional.empty(), @@ -371,6 +375,7 @@ public void downloadOneUrl_ok() throws IOException, InterruptedException { httpDownloader.download( Collections.singletonList( new URL(String.format("http://localhost:%d/foo", server.getLocalPort()))), + Collections.emptyMap(), StaticCredentials.EMPTY, Optional.empty(), "testCanonicalId", @@ -410,6 +415,7 @@ public void downloadOneUrl_notFound() throws IOException, InterruptedException { httpDownloader.download( Collections.singletonList( new URL(String.format("http://localhost:%d/foo", server.getLocalPort()))), + Collections.emptyMap(), StaticCredentials.EMPTY, Optional.empty(), "testCanonicalId", @@ -470,6 +476,7 @@ public void downloadTwoUrls_firstNotFoundAndSecondOk() throws IOException, Inter Path destination = fs.getPath(workingDir.newFile().getAbsolutePath()); httpDownloader.download( urls, + Collections.emptyMap(), StaticCredentials.EMPTY, Optional.empty(), "testCanonicalId", @@ -564,13 +571,14 @@ public void download_contentLengthMismatch_propagateErrorIfNotRetry() throws Exc throw new ContentLengthMismatchException(0, data.length); }) .when(downloader) - .download(any(), any(), any(), any(), any(), any(), any(), any()); + .download(any(), any(), any(), any(), any(), any(), any(), any(), any()); assertThrows( ContentLengthMismatchException.class, () -> downloadManager.download( ImmutableList.of(new URL("http://localhost")), + Collections.emptyMap(), ImmutableMap.of(), Optional.empty(), "testCanonicalId", @@ -597,7 +605,7 @@ public void download_contentLengthMismatch_retries() throws Exception { if (times.getAndIncrement() < 3) { throw new ContentLengthMismatchException(0, data.length); } - Path output = invocationOnMock.getArgument(4, Path.class); + Path output = invocationOnMock.getArgument(5, Path.class); try (OutputStream outputStream = output.getOutputStream()) { ByteStreams.copy(new ByteArrayInputStream(data), outputStream); } @@ -605,12 +613,13 @@ public void download_contentLengthMismatch_retries() throws Exception { return null; }) .when(downloader) - .download(any(), any(), any(), any(), any(), any(), any(), any()); + .download(any(), any(), any(), any(), any(), any(), any(), any(), any()); Path result = downloadManager.download( ImmutableList.of(new URL("http://localhost")), ImmutableMap.of(), + ImmutableMap.of(), Optional.empty(), "testCanonicalId", Optional.empty(), diff --git a/src/test/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloaderTest.java index ced0fe561c830a..f881d862800e83 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloaderTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloaderTest.java @@ -178,6 +178,7 @@ private static byte[] downloadBlob( final Path destination = scratch.resolve("output file path"); downloader.download( urls, + ImmutableMap.of(), StaticCredentials.EMPTY, checksum, canonicalId, @@ -240,13 +241,13 @@ public void fetchBlob( invocation -> { List urls = invocation.getArgument(0); if (urls.equals(ImmutableList.of(new URL("http://example.com/content.txt")))) { - Path output = invocation.getArgument(4); + Path output = invocation.getArgument(5); FileSystemUtils.writeContent(output, content); } return null; }) .when(fallbackDownloader) - .download(any(), any(), any(), any(), any(), any(), any(), any()); + .download(any(), any(), any(), any(), any(), any(), any(), any(), any()); final GrpcRemoteDownloader downloader = newDownloader(cacheClient, fallbackDownloader); final byte[] downloaded = diff --git a/src/test/shell/bazel/remote_helpers.sh b/src/test/shell/bazel/remote_helpers.sh index 6a66398d3eabd9..66b3899794bc9d 100755 --- a/src/test/shell/bazel/remote_helpers.sh +++ b/src/test/shell/bazel/remote_helpers.sh @@ -179,6 +179,28 @@ function serve_timeout() { cd - } +# Serves a HTTP 200 Ok response with headers dumped into the file +# Args: +# $1: required; path to the file +# $2: optional; path to the file where headers will be written to. +function serve_file_header_dump() { + file_name=served_file.$$ + cat $1 > "${TEST_TMPDIR}/$file_name" + nc_log="${TEST_TMPDIR}/nc.log" + rm -f $nc_log + touch $nc_log + cd "${TEST_TMPDIR}" + port_file=server-port.$$ + rm -f $port_file + python3 $python_server always $file_name --dump_headers ${2:-"headers.json"} > $port_file & + nc_pid=$! + while ! grep started $port_file; do sleep 1; done + nc_port=$(head -n 1 $port_file) + fileserver_port=$nc_port + wait_for_server_startup + cd - +} + # Waits for the SimpleHTTPServer to actually start up before the test is run. # Otherwise the entire test can run before the server starts listening for # connections, which causes flakes. diff --git a/src/test/shell/bazel/starlark_repository_test.sh b/src/test/shell/bazel/starlark_repository_test.sh index 4a4fe82597c3bb..fcf705ee5f54ca 100755 --- a/src/test/shell/bazel/starlark_repository_test.sh +++ b/src/test/shell/bazel/starlark_repository_test.sh @@ -2354,8 +2354,8 @@ genrule( cmd = "cp $< $@", ) EOF - - bazel build --repository_disable_download //:it || fail "Failed to build" + # for some reason --repository_disable_download fails with bzlmod trying to download @platforms. + bazel build --repository_disable_download --noenable_bzlmod //:it || fail "Failed to build" } function test_no_restarts_fetching_with_worker_thread() { @@ -2408,4 +2408,179 @@ EOF || fail "Expected build to succeed" } + +function test_cred_helper_overrides_starlark_headers() { + if "$is_windows"; then + # Skip on Windows: credential helper is a Python script. + return + fi + + setup_credential_helper + + filename="cred_helper_starlark.txt" + echo $filename > $filename + sha256="$(sha256sum $filename | head -c 64)" + serve_file_header_dump $filename credhelper_headers.json + + setup_starlark_repository + + cat > test.bzl < $filename + sha256="$(sha256sum $filename | head -c 64)" + serve_file_header_dump $filename netrc_headers.json + + setup_starlark_repository + + cat > .netrc < test.bzl < $filename + sha256="$(sha256sum $filename | head -c 64)" + serve_file_header_dump $filename default_headers.json + + setup_starlark_repository + + cat > test.bzl < $filename + sha256="$(sha256sum $filename | head -c 64)" + serve_file_header_dump $filename invalid_headers.json + + setup_starlark_repository + + cat > test.bzl <& $TEST_log && fail "expected bazel to fail" || : + expect_log "headers argument must be a dict whose keys are string and whose values are either string or sequence of string" +} + +function test_string_starlark_headers() { + + filename="string_headers.txt" + echo $filename > $filename + sha256="$(sha256sum $filename | head -c 64)" + serve_file_header_dump $filename string_headers.json + + setup_starlark_repository + + cat > test.bzl <