Skip to content

Commit

Permalink
Add tests for the registry package;
Browse files Browse the repository at this point in the history
  • Loading branch information
ycliuhw committed Sep 6, 2021
1 parent 06a5476 commit b444778
Show file tree
Hide file tree
Showing 19 changed files with 1,644 additions and 537 deletions.
21 changes: 0 additions & 21 deletions docker/auth.go
Expand Up @@ -22,20 +22,6 @@ import (

var logger = loggo.GetLogger("juju.docker")

// APIVersion is the API version type.
type APIVersion string

const (
// APIVersionV1 is the API version v1.
APIVersionV1 APIVersion = "v1"
// APIVersionV2 is the API version v2.
APIVersionV2 APIVersion = "v2"
)

func (v APIVersion) String() string {
return string(v)
}

// TokenAuthConfig contains authorization information for token auth.
// Juju does not support the docker credential helper because k8s does not support it either.
// https://kubernetes.io/docs/concepts/containers/images/#configuring-nodes-to-authenticate-to-a-private-registry
Expand Down Expand Up @@ -173,13 +159,6 @@ func (rid ImageRepoDetails) Empty() bool {
return rid == ImageRepoDetails{}
}

func (rid ImageRepoDetails) APIVersion() APIVersion {
if rid.IsPrivate() {
return APIVersionV2
}
return APIVersionV1
}

func fileExists(p string) (bool, error) {
info, err := os.Stat(p)
if os.IsNotExist(err) {
Expand Down
69 changes: 44 additions & 25 deletions docker/registry/acr.go
Expand Up @@ -11,18 +11,20 @@ import (
"github.com/juju/errors"

"github.com/juju/juju/docker"
"github.com/juju/juju/tools"
)

type acr struct {
// TODO(ycliuhw): test and verify azureContainerRegistry integration further.
type azureContainerRegistry struct {
*baseClient
}

func newACR(repoDetails docker.ImageRepoDetails, transport http.RoundTripper) RegistryInternal {
func newAzureContainerRegistry(repoDetails docker.ImageRepoDetails, transport http.RoundTripper) RegistryInternal {
c := newBase(repoDetails, DefaultTransport)
return &acr{c}
return &azureContainerRegistry{c}
}

func (c *acr) Match() bool {
func (c *azureContainerRegistry) Match() bool {
return strings.Contains(c.repoDetails.ServerAddress, "azurecr.io")
}

Expand All @@ -38,30 +40,47 @@ func getUserNameFromAuthForACR(auth string) (string, error) {
return parts[0], nil
}

func (c *acr) WrapTransport() error {
if !c.repoDetails.IsPrivate() {
return nil
}
func (c *azureContainerRegistry) WrapTransport() error {
transport := c.client.Transport
if !c.repoDetails.TokenAuthConfig.Empty() {
username := c.repoDetails.Username
if username == "" {
var err error
username, err = getUserNameFromAuthForACR(c.repoDetails.Auth)
if err != nil {
return errors.Trace(err)
if c.repoDetails.IsPrivate() {
if !c.repoDetails.TokenAuthConfig.Empty() {
username := c.repoDetails.Username
if username == "" {
var err error
username, err = getUserNameFromAuthForACR(c.repoDetails.Auth)
if err != nil {
return errors.Trace(err)
}
}
password := c.repoDetails.Password
if password == "" {
password = c.repoDetails.IdentityToken
}
transport = newTokenTransport(
transport,
username, password,
"", "",
)
}
password := c.repoDetails.Password
if password == "" {
password = c.repoDetails.IdentityToken
}
transport = newTokenTransport(
transport,
username, password,
"", "",
)
}
c.client.Transport = errorTransport{transport}
c.client.Transport = newErrorTransport(transport)
return nil
}

// Tags fetches tags for an OCI image.
func (c azureContainerRegistry) Tags(imageName string) (versions tools.Versions, err error) {
apiVersion := c.APIVersion()

if apiVersion == APIVersionV1 {
url := c.url("/repositories/%s/tags", imageName)
var response tagsResponseV1
return c.fetchTags(url, &response)
}
if apiVersion == APIVersionV2 {
url := c.url("/%s/tags/list", imageName)
var response tagsResponseV2
return c.fetchTags(url, &response)
}
// This should never happen.
return nil, nil
}
196 changes: 196 additions & 0 deletions docker/registry/base.go
@@ -0,0 +1,196 @@
// Copyright 2021 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.

package registry

import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"path"
"regexp"
"strings"
"time"

"github.com/juju/errors"
"github.com/juju/loggo"

"github.com/juju/juju/docker"
)

var logger = loggo.GetLogger("juju.docker.registry")

const (
defaultTimeout = 15 * time.Second
)

// APIVersion is the API version type.
type APIVersion string

const (
// APIVersionV1 is the API version v1.
APIVersionV1 APIVersion = "v1"
// APIVersionV2 is the API version v2.
APIVersionV2 APIVersion = "v2"
)

func (v APIVersion) String() string {
return string(v)
}

var (
// Override for testing.
DefaultTransport = http.DefaultTransport
)

type baseClient struct {
baseURL *url.URL
client *http.Client
repoDetails *docker.ImageRepoDetails
}

func newBase(repoDetails docker.ImageRepoDetails, transport http.RoundTripper) *baseClient {
c := &baseClient{
repoDetails: &repoDetails,
client: &http.Client{
Transport: transport,
Timeout: defaultTimeout,
},
}
return c
}

func (c *baseClient) Match() bool {
return false
}

// APIVersion returns the registry API version to use.
func (c *baseClient) APIVersion() APIVersion {
if c.repoDetails.IsPrivate() {
return APIVersionV2
}
return APIVersionV1
}

func (c *baseClient) WrapTransport() error {
logger.Criticalf("baseClient.WrapTransport")
transport := c.client.Transport
if c.repoDetails.IsPrivate() {
if !c.repoDetails.BasicAuthConfig.Empty() {
transport = newTokenTransport(
transport, c.repoDetails.Username, c.repoDetails.Password, c.repoDetails.Auth, "",
)
}
if !c.repoDetails.TokenAuthConfig.Empty() {
return errors.New(
fmt.Sprintf(
`only "username" and "password" or "auth" token authorization is supported for registry %q`,
c.repoDetails.ServerAddress,
),
)
}
}
c.client.Transport = newErrorTransport(transport)
return nil
}

func (c *baseClient) DecideBaseURL() error {
logger.Criticalf("baseClient.DecideBaseURL")
addr := c.repoDetails.ServerAddress
if addr == "" {
return errors.NotValidf("empty server address for %q", c.repoDetails.Repository)
}
url, err := url.Parse(addr)
if err != nil {
return errors.Trace(err)
}
serverAddressURL := *url
apiVersion := c.APIVersion().String()
if !strings.Contains(url.Path, "/"+apiVersion) {
url.Path = path.Join(url.Path, apiVersion)
}
if url.Scheme == "" {
url.Scheme = "https"
}
c.baseURL = url

serverAddressURL.Scheme = ""
c.repoDetails.ServerAddress = serverAddressURL.String()
logger.Criticalf("baseClient.DecideBaseURL c.baseURL %q, r.repoDetails.ServerAddress %q", c.baseURL, c.repoDetails.ServerAddress)
return nil
}

func (c baseClient) url(pathTemplate string, args ...interface{}) string {
pathSuffix := fmt.Sprintf(pathTemplate, args...)
url := *c.baseURL
ver := c.APIVersion().String()
if !strings.HasSuffix(strings.TrimRight(url.Path, "/"), ver) {
url.Path = path.Join(url.Path, ver)
}
if url.Scheme == "" {
url.Scheme = "https"
}
url.Path = path.Join(url.Path, pathSuffix)
logger.Criticalf("baseClient url.Path ===> %q, %q", url.Path, pathSuffix)
logger.Criticalf("baseClient c.baseURL ===> %q, url.String() ===> %q", c.baseURL, url.String())
return url.String()
}

// Ping pings the baseClient endpoint.
func (c baseClient) Ping() error {
url := c.url("/")
logger.Debugf("baseClient ping %q", url)
resp, err := c.client.Get(url)
if resp != nil {
defer resp.Body.Close()
}
return errors.Trace(err)
}

func (c baseClient) ImageRepoDetails() (o docker.ImageRepoDetails) {
if c.repoDetails != nil {
return *c.repoDetails
}
return o
}

// Close closes the transport used by the client.
func (c *baseClient) Close() error {
if t, ok := c.client.Transport.(*http.Transport); ok {
t.CloseIdleConnections()
}
return nil
}

func (c baseClient) getPaginatedJSON(url string, response interface{}) (string, error) {
logger.Criticalf("baseClient.getPaginatedJSON url ===> %q", url)
resp, err := c.client.Get(url)
logger.Criticalf("getPaginatedJSON err %#v", err)
if err != nil {
return "", errors.Trace(err)
}
defer resp.Body.Close()

decoder := json.NewDecoder(resp.Body)
err = decoder.Decode(response)
if err != nil {
return "", errors.Trace(err)
}
return getNextLink(resp)
}

var (
nextLinkRE = regexp.MustCompile(`^ *<?([^;>]+)>? *(?:;[^;]*)*; *rel="?next"?(?:;.*)?`)
errNoMorePages = errors.New("no more pages")
)

func getNextLink(resp *http.Response) (string, error) {
for _, link := range resp.Header[http.CanonicalHeaderKey("Link")] {
parts := nextLinkRE.FindStringSubmatch(link)
if parts != nil {
return parts[1], nil
}
}
return "", errNoMorePages
}

0 comments on commit b444778

Please sign in to comment.