Skip to content

Commit

Permalink
chech docker image platform when pulling docker image
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrey Balakshiy committed Mar 15, 2024
1 parent 50e1686 commit ef7bc11
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 12 deletions.
12 changes: 12 additions & 0 deletions lzy/execution-env/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>5.11.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>5.2.0</version>
<scope>test</scope>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package ai.lzy.env.base;

import com.github.dockerjava.core.DockerClientConfig;
import jakarta.annotation.Nonnull;
import jakarta.annotation.Nullable;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public record DockerEnvDescription(
Expand All @@ -18,7 +22,9 @@ public record DockerEnvDescription(
List<String> envVars, // In format <NAME>=<value>
@Nullable
String networkMode,
DockerClientConfig dockerClientConfig
DockerClientConfig dockerClientConfig,
@Nonnull
Set<String> allowedPlatforms // In format os/arch like "linux/amd64". Empty means all allowed
) {

public static Builder newBuilder() {
Expand All @@ -32,6 +38,7 @@ public String toString() {
", image='" + image + '\'' +
", needGpu=" + needGpu +
", networkMode=" + networkMode +
", allowedPlatforms=" + String.join(", ", allowedPlatforms) +
", mounts=[" + mounts.stream()
.map(it -> it.source() + " -> " + it.target() + (it.isRshared() ? " (R_SHARED)" : ""))
.collect(Collectors.joining(", ")) + "]" +
Expand All @@ -52,6 +59,7 @@ public static class Builder {
List<String> envVars = new ArrayList<>();
String networkMode = null;
DockerClientConfig dockerClientConfig;
Set<String> allowedPlatforms = new HashSet<>();

public Builder withName(String name) {
this.name = name;
Expand Down Expand Up @@ -93,13 +101,18 @@ public Builder withDockerClientConfig(DockerClientConfig dockerClientConfig) {
return this;
}

public Builder withAllowedPlatforms(Collection<String> allowedPlatforms) {
this.allowedPlatforms.addAll(allowedPlatforms);
return this;
}

public DockerEnvDescription build() {
if (StringUtils.isBlank(name)) {
name = "job-" + RandomStringUtils.randomAlphanumeric(5);
}
return new DockerEnvDescription(name, image, mounts, gpu, envVars, networkMode, dockerClientConfig);
return new DockerEnvDescription(name, image, mounts, gpu, envVars, networkMode, dockerClientConfig,
allowedPlatforms);
}

}

public record ContainerRegistryCredentials(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,27 @@
import ai.lzy.env.EnvironmentInstallationException;
import ai.lzy.env.logs.LogStream;
import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.async.ResultCallback;
import com.github.dockerjava.api.async.ResultCallbackTemplate;
import com.github.dockerjava.api.command.ExecCreateCmd;
import com.github.dockerjava.api.command.ExecCreateCmdResponse;
import com.github.dockerjava.api.command.InspectImageResponse;
import com.github.dockerjava.api.command.PullImageResultCallback;
import com.github.dockerjava.api.exception.DockerClientException;
import com.github.dockerjava.api.exception.DockerException;
import com.github.dockerjava.api.exception.NotFoundException;
import com.github.dockerjava.api.model.*;
import com.github.dockerjava.api.model.BindOptions;
import com.github.dockerjava.api.model.BindPropagation;
import com.github.dockerjava.api.model.DeviceRequest;
import com.github.dockerjava.api.model.Frame;
import com.github.dockerjava.api.model.HostConfig;
import com.github.dockerjava.api.model.Mount;
import com.github.dockerjava.api.model.MountType;
import com.github.dockerjava.api.model.PruneType;
import com.github.dockerjava.api.model.PullResponseItem;
import com.github.dockerjava.core.DockerClientImpl;
import com.github.dockerjava.httpclient5.ApacheDockerHttpClient;
import com.google.common.annotations.VisibleForTesting;
import io.github.resilience4j.core.IntervalFunction;
import io.github.resilience4j.retry.Retry;
import io.github.resilience4j.retry.RetryConfig;
Expand All @@ -28,6 +39,7 @@
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -37,6 +49,7 @@ public class DockerEnvironment extends BaseEnvironment {
private static final Logger LOG = LogManager.getLogger(DockerEnvironment.class);
private static final long GB_AS_BYTES = 1073741824;
private static final String ROOT_USER_UID = "0";
private static final String NO_MATCHING_MANIFEST_ERROR = "no matching manifest";

@Nullable
public String containerId = null;
Expand Down Expand Up @@ -275,12 +288,14 @@ public void close() throws Exception {
}
}

private void prepareImage(String image, LogStream out) throws Exception {
@VisibleForTesting
void prepareImage(String image, LogStream out) throws Exception {
try {
client.inspectImageCmd(image).exec();
var inspectImageResponse = client.inspectImageCmd(image).exec();
var msg = "Image %s exists".formatted(image);
LOG.info(msg);
out.log(msg);
checkPlatform(inspectImageResponse, out);
return;
} catch (NotFoundException ignored) {
var msg = "Image %s not found in cached images".formatted(image);
Expand All @@ -291,16 +306,63 @@ private void prepareImage(String image, LogStream out) throws Exception {
var msg = "Pulling image %s ...".formatted(image);
LOG.info(msg);
out.log(msg);
Set<String> allowedPlatforms = config.allowedPlatforms();
AtomicInteger pullingAttempt = new AtomicInteger(0);
retry.executeCallable(() -> {
try (var pullResponseItem = retry.executeCallable(() -> {
LOG.info("Pulling image {}, attempt {}", image, pullingAttempt.incrementAndGet());
final var pullingImage = client
.pullImageCmd(config.image())
.exec(new PullImageResultCallback());
return pullingImage.awaitCompletion();
});
if (allowedPlatforms.isEmpty()) {
return pullWithPlatform(image, null);
} else {
for (String platform : config.allowedPlatforms()) {
try {
return pullWithPlatform(image, platform);
} catch (DockerClientException e) {
if (e.getMessage().contains(NO_MATCHING_MANIFEST_ERROR)) {
LOG.info("Cannot find image = {} for platform = {}", image, platform);
} else {
throw e;
}
}
}
}
return null;
})
) {
if (pullResponseItem == null) {
throw new RuntimeException("Cannot pull image for allowed platforms = %s".formatted(String.join(", ", allowedPlatforms)));
}
}

msg = "Pulling image %s done".formatted(image);
LOG.info(msg);
out.log(msg);
}

private ResultCallback.Adapter<PullResponseItem> pullWithPlatform(String image, @Nullable String platform)
throws InterruptedException {
var pullingImage = client.pullImageCmd(image);
if (platform != null) {
pullingImage = pullingImage.withPlatform(platform);
}
return pullingImage.exec(new PullImageResultCallback()).awaitCompletion();
}

private void checkPlatform(InspectImageResponse inspectImageResponse, LogStream out) {
Set<String> allowedPlatforms = config.allowedPlatforms();
if (allowedPlatforms.isEmpty()) {
return;
}

String platform = inspectImageResponse.getOs() + "/" + inspectImageResponse.getArch();
if (!allowedPlatforms.contains(platform)) {
var allowedPlatformsStr = String.join(", ", allowedPlatforms);
var msg = "Image %s platform = %s is not in allowed platforms = %s".formatted(
config.image(), platform, allowedPlatformsStr);
LOG.info(msg);
out.log(msg);

throw new RuntimeException("Cached image platform = %s is not in allowed platforms = %s".formatted(
platform, allowedPlatformsStr));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package ai.lzy.env.base;

import ai.lzy.env.logs.LogStream;
import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.async.ResultCallback;
import com.github.dockerjava.api.command.InspectImageCmd;
import com.github.dockerjava.api.command.InspectImageResponse;
import com.github.dockerjava.api.command.PullImageCmd;
import com.github.dockerjava.api.command.PullImageResultCallback;
import com.github.dockerjava.api.exception.DockerClientException;
import com.github.dockerjava.api.exception.NotFoundException;
import com.github.dockerjava.api.model.PullResponseItem;
import com.github.dockerjava.core.DefaultDockerClientConfig;
import com.github.dockerjava.core.DockerClientConfig;
import com.github.dockerjava.core.DockerClientImpl;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class DockerEnvironmentTest {
private final DockerClient dockerClient = mock(DockerClientImpl.class);

private final InspectImageCmd inspectImageCmd = mock(InspectImageCmd.class);

private final InspectImageResponse inspectImageResponse = mock(InspectImageResponse.class);

private final PullImageCmd pullImageCmd = mock(PullImageCmd.class);

private final PullImageCmd pullImageCmdForRightPlatform = mock(PullImageCmd.class);

private final ResultCallback.Adapter<PullResponseItem> callbackAdapter = mock(ResultCallback.Adapter.class);

private static final String IMAGE = RandomStringUtils.randomAlphanumeric(20).toLowerCase();

private final LogStream logStream = mock(LogStream.class);

@Before
public void setUp() throws Exception {
when(dockerClient.inspectImageCmd(IMAGE)).thenReturn(inspectImageCmd);
when(inspectImageCmd.exec()).thenReturn(inspectImageResponse);
when(dockerClient.pullImageCmd(IMAGE)).thenReturn(pullImageCmd);
when(pullImageCmd.withPlatform("linux/amd64")).thenReturn(pullImageCmdForRightPlatform);
when(pullImageCmd.withPlatform(anyString())).thenAnswer((arg) -> {
if ("linux/amd64".equals(arg.getArguments()[0])) {
return pullImageCmdForRightPlatform;
} else {
return pullImageCmd;
}
});
PullImageResultCallback pullImageResultCallback = mock(PullImageResultCallback.class);
when(pullImageCmdForRightPlatform.exec(any())).thenReturn(pullImageResultCallback);
when(pullImageCmd.exec(any())).thenThrow(new DockerClientException(
"Could not pull image: no matching manifest for %s in the manifest list entries"));

when(pullImageResultCallback.awaitCompletion()).thenReturn(callbackAdapter);
}

@Test
public void testPrepareImageCachedImage() throws Exception {
executeTest(this::doTestPrepareImageCachedImage);
}

@Test
public void testPrepareImageCachedImageWithNotAllowedPlatform() throws Exception {
executeTest(this::doTestPrepareImageCachedImageWithNotAllowedPlatform);
}

@Test
public void testPrepareImageNodCachedImageWithRightPlatform() throws Exception {
executeTest(this::doTestPrepareImageNodCachedImageWithRightPlatform);
}

@Test
public void testPrepareImageNodCachedImageWithoutRightPlatform() throws Exception {
executeTest(this::doTestPrepareImageNodCachedImageWithoutRightPlatform);
}

private void doTestPrepareImageCachedImage() throws Exception {
when(inspectImageResponse.getArch()).thenReturn("amd64");
when(inspectImageResponse.getOs()).thenReturn("linux");

DockerEnvironment environment = new DockerEnvironment(createDockerEnvDescription(
List.of("darwin/arm64", "linux/amd64")));

environment.prepareImage(IMAGE, logStream);

verify(dockerClient, never()).pullImageCmd(any());
}

private void doTestPrepareImageCachedImageWithNotAllowedPlatform() {
when(inspectImageResponse.getOs()).thenReturn("not_existed_os");
when(inspectImageResponse.getArch()).thenReturn("not_existed_arch");

DockerEnvironment environment = new DockerEnvironment(createDockerEnvDescription(
List.of("darwin/arm64", "linux/amd64")));

RuntimeException exception = Assert.assertThrows(RuntimeException.class,
() -> environment.prepareImage(IMAGE, logStream));
assertNotNull(exception);
assertEquals("Cached image platform = not_existed_os/not_existed_arch is not in allowed platforms = darwin/arm64, linux/amd64",
exception.getMessage());

verify(dockerClient, never()).pullImageCmd(any());
}

private void doTestPrepareImageNodCachedImageWithRightPlatform() throws Exception {
when(inspectImageCmd.exec()).thenThrow(new NotFoundException("com.github.dockerjava.api.exception.NotFoundException:" +
" Status 404: {\"message\":\"No such image: %s\"}\n".formatted(IMAGE)));
DockerEnvironment environment = new DockerEnvironment(createDockerEnvDescription(
List.of("darwin/arm64", "linux/amd64")));

environment.prepareImage(IMAGE, logStream);
verify(dockerClient, times(2)).pullImageCmd(IMAGE);
}

private void doTestPrepareImageNodCachedImageWithoutRightPlatform() throws Exception {
when(inspectImageCmd.exec()).thenThrow(new NotFoundException("com.github.dockerjava.api.exception.NotFoundException:" +
" Status 404: {\"message\":\"No such image: %s\"}\n".formatted(IMAGE)));
DockerEnvironment environment = new DockerEnvironment(createDockerEnvDescription(
List.of("darwin/arm64", "linux/win32")));

RuntimeException exception = Assert.assertThrows(RuntimeException.class,
() -> environment.prepareImage(IMAGE, logStream));
assertNotNull(exception);
assertEquals("Cannot pull image for allowed platforms = linux/win32, darwin/arm64", exception.getMessage());

verify(dockerClient, times(2)).pullImageCmd(IMAGE);
}

private DockerEnvDescription createDockerEnvDescription(List<String> allowedPlatforms) {
DockerClientConfig dockerClientConfig = DefaultDockerClientConfig.createDefaultConfigBuilder().build();
return DockerEnvDescription.newBuilder()
.withDockerClientConfig(dockerClientConfig)
.withAllowedPlatforms(allowedPlatforms)
.build();
}


private void executeTest(Executable test) throws Exception {
try (var mockedDockerClient = mockStatic(DockerClientImpl.class)) {
mockedDockerClient.when(() -> DockerClientImpl.getInstance(any(), any())).thenReturn(dockerClient);
test.execute();
}
}


@FunctionalInterface
private interface Executable {
void execute() throws Exception;
}
}

0 comments on commit ef7bc11

Please sign in to comment.