From feee882031114cfe9d6f3f5ba628fb0295018e58 Mon Sep 17 00:00:00 2001 From: Cody Oss Date: Tue, 9 Apr 2024 11:16:50 -0500 Subject: [PATCH 1/2] chore(compute/metadata): fixup warnings --- compute/metadata/metadata.go | 6 +- compute/metadata/metadata_go113_test.go | 111 ------------------------ compute/metadata/metadata_test.go | 92 +++++++++++++++++++- compute/metadata/retry.go | 2 +- 4 files changed, 93 insertions(+), 118 deletions(-) delete mode 100644 compute/metadata/metadata_go113_test.go diff --git a/compute/metadata/metadata.go b/compute/metadata/metadata.go index c17faa142a44..d32262d4fe6c 100644 --- a/compute/metadata/metadata.go +++ b/compute/metadata/metadata.go @@ -23,7 +23,7 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" "net" "net/http" "net/url" @@ -197,7 +197,7 @@ func systemInfoSuggestsGCE() bool { // We don't have any non-Linux clues available, at least yet. return false } - slurp, _ := ioutil.ReadFile("/sys/class/dmi/id/product_name") + slurp, _ := os.ReadFile("/sys/class/dmi/id/product_name") name := strings.TrimSpace(string(slurp)) return name == "Google" || name == "Google Compute Engine" } @@ -336,7 +336,7 @@ func (c *Client) getETag(suffix string) (value, etag string, err error) { if res.StatusCode == http.StatusNotFound { return "", "", NotDefinedError(suffix) } - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) if err != nil { return "", "", err } diff --git a/compute/metadata/metadata_go113_test.go b/compute/metadata/metadata_go113_test.go deleted file mode 100644 index c9d32bf94931..000000000000 --- a/compute/metadata/metadata_go113_test.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2016 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build go1.13 -// +build go1.13 - -package metadata - -import ( - "io" - "io/ioutil" - "net/http" - "strings" - "testing" -) - -func TestRetry(t *testing.T) { - tests := []struct { - name string - timesToFail int - failCode int - failErr error - response string - expectError bool - }{ - { - name: "no retries", - response: "test", - }, - { - name: "retry 500 once", - response: "test", - failCode: 500, - timesToFail: 1, - }, - { - name: "retry io.ErrUnexpectedEOF once", - response: "test", - failErr: io.ErrUnexpectedEOF, - timesToFail: 1, - }, - { - name: "retry io.ErrUnexpectedEOF permanent", - failErr: io.ErrUnexpectedEOF, - timesToFail: maxRetryAttempts + 1, - expectError: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ft := &failingTransport{ - timesToFail: tt.timesToFail, - failCode: tt.failCode, - failErr: tt.failErr, - response: tt.response, - } - c := NewClient(&http.Client{Transport: ft}) - s, err := c.Get("") - if tt.expectError && err == nil { - t.Fatalf("did not receive expected error") - } else if !tt.expectError && err != nil { - t.Fatalf("unexpected error: %v", err) - } - - expectedCount := ft.failedAttempts + 1 - if tt.expectError { - expectedCount = ft.failedAttempts - } else if s != tt.response { - // Responses are only meaningful if err == nil - t.Fatalf("c.Get() = %q, want %q", s, tt.response) - } - - if ft.called != expectedCount { - t.Fatalf("failed %d times, want %d", ft.called, expectedCount) - } - }) - } -} - -type failingTransport struct { - timesToFail int - failCode int - failErr error - response string - - failedAttempts int - called int -} - -func (r *failingTransport) RoundTrip(req *http.Request) (*http.Response, error) { - r.called++ - if r.failedAttempts < r.timesToFail { - r.failedAttempts++ - if r.failErr != nil { - return nil, r.failErr - } - return &http.Response{StatusCode: r.failCode}, nil - } - return &http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(strings.NewReader(r.response))}, nil -} diff --git a/compute/metadata/metadata_test.go b/compute/metadata/metadata_test.go index f779c18b7198..281ca1c4d7b1 100644 --- a/compute/metadata/metadata_test.go +++ b/compute/metadata/metadata_test.go @@ -16,10 +16,11 @@ package metadata import ( "bytes" - "io/ioutil" + "io" "log" "net/http" "os" + "strings" "sync" "testing" ) @@ -106,7 +107,7 @@ type captureTransport struct { func (ct *captureTransport) RoundTrip(req *http.Request) (*http.Response, error) { ct.url = req.URL.String() - return &http.Response{Body: ioutil.NopCloser(bytes.NewReader(nil))}, nil + return &http.Response{Body: io.NopCloser(io.Reader(bytes.NewReader(nil)))}, nil } type userAgentTransport struct { @@ -125,5 +126,90 @@ type rrt struct { func (r *rrt) RoundTrip(req *http.Request) (*http.Response, error) { r.gotUserAgent = req.Header.Get("User-Agent") - return &http.Response{Body: ioutil.NopCloser(bytes.NewReader(nil))}, nil + return &http.Response{Body: io.NopCloser(bytes.NewReader(nil))}, nil +} + +func TestRetry(t *testing.T) { + tests := []struct { + name string + timesToFail int + failCode int + failErr error + response string + expectError bool + }{ + { + name: "no retries", + response: "test", + }, + { + name: "retry 500 once", + response: "test", + failCode: 500, + timesToFail: 1, + }, + { + name: "retry io.ErrUnexpectedEOF once", + response: "test", + failErr: io.ErrUnexpectedEOF, + timesToFail: 1, + }, + { + name: "retry io.ErrUnexpectedEOF permanent", + failErr: io.ErrUnexpectedEOF, + timesToFail: maxRetryAttempts + 1, + expectError: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ft := &failingTransport{ + timesToFail: tt.timesToFail, + failCode: tt.failCode, + failErr: tt.failErr, + response: tt.response, + } + c := NewClient(&http.Client{Transport: ft}) + s, err := c.Get("") + if tt.expectError && err == nil { + t.Fatalf("did not receive expected error") + } else if !tt.expectError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expectedCount := ft.failedAttempts + 1 + if tt.expectError { + expectedCount = ft.failedAttempts + } else if s != tt.response { + // Responses are only meaningful if err == nil + t.Fatalf("c.Get() = %q, want %q", s, tt.response) + } + + if ft.called != expectedCount { + t.Fatalf("failed %d times, want %d", ft.called, expectedCount) + } + }) + } +} + +type failingTransport struct { + timesToFail int + failCode int + failErr error + response string + + failedAttempts int + called int +} + +func (r *failingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + r.called++ + if r.failedAttempts < r.timesToFail { + r.failedAttempts++ + if r.failErr != nil { + return nil, r.failErr + } + return &http.Response{StatusCode: r.failCode}, nil + } + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(r.response))}, nil } diff --git a/compute/metadata/retry.go b/compute/metadata/retry.go index 0f18f3cda1e2..3d4bc75ddf26 100644 --- a/compute/metadata/retry.go +++ b/compute/metadata/retry.go @@ -27,7 +27,7 @@ const ( ) var ( - syscallRetryable = func(err error) bool { return false } + syscallRetryable = func(error) bool { return false } ) // defaultBackoff is basically equivalent to gax.Backoff without the need for From 7d2ab7ff7c7e93089c4e13cc66f4e4de96ccb1a1 Mon Sep 17 00:00:00 2001 From: Cody Oss Date: Tue, 9 Apr 2024 12:08:30 -0500 Subject: [PATCH 2/2] feat(compute/metadata): add context aware functions This change adds the minimal amount of context aware functionality so that users can pass a context to all metadata requests. This does not re-expose all the helper methods this package provides. We can add context varients for all of these in the future and/or if we ever create a v2 of this package. Fixes: #4483 --- compute/metadata/examples_test.go | 6 +- compute/metadata/metadata.go | 98 +++++++++++++++++++++---------- compute/metadata/metadata_test.go | 59 +++++++++++++++++-- 3 files changed, 126 insertions(+), 37 deletions(-) diff --git a/compute/metadata/examples_test.go b/compute/metadata/examples_test.go index 6c546a63c935..d043354c252f 100644 --- a/compute/metadata/examples_test.go +++ b/compute/metadata/examples_test.go @@ -15,6 +15,7 @@ package metadata_test import ( + "context" "net/http" "cloud.google.com/go/compute/metadata" @@ -22,15 +23,16 @@ import ( // This example demonstrates how to use your own transport when using this package. func ExampleNewClient() { + ctx := context.Background() c := metadata.NewClient(&http.Client{Transport: userAgentTransport{ userAgent: "my-user-agent", base: http.DefaultTransport, }}) - p, err := c.ProjectID() + pID, err := c.GetWithContext(ctx, "project/project-id") if err != nil { // TODO: Handle error. } - _ = p // TODO: Use p. + _ = pID // TODO: Use p. } // userAgentTransport sets the User-Agent header before calling base. diff --git a/compute/metadata/metadata.go b/compute/metadata/metadata.go index d32262d4fe6c..f67e3c7eeae0 100644 --- a/compute/metadata/metadata.go +++ b/compute/metadata/metadata.go @@ -95,9 +95,9 @@ func (c *cachedValue) get(cl *Client) (v string, err error) { return c.v, nil } if c.trim { - v, err = cl.getTrimmed(c.k) + v, err = cl.getTrimmed(context.Background(), c.k) } else { - v, err = cl.Get(c.k) + v, err = cl.GetWithContext(context.Background(), c.k) } if err == nil { c.v = v @@ -202,13 +202,27 @@ func systemInfoSuggestsGCE() bool { return name == "Google" || name == "Google Compute Engine" } -// Subscribe calls Client.Subscribe on the default client. +// Subscribe calls Client.SubscribeWithContext on the default client. func Subscribe(suffix string, fn func(v string, ok bool) error) error { - return defaultClient.Subscribe(suffix, fn) + return defaultClient.SubscribeWithContext(context.Background(), suffix, func(ctx context.Context, v string, ok bool) error { return fn(v, ok) }) } -// Get calls Client.Get on the default client. -func Get(suffix string) (string, error) { return defaultClient.Get(suffix) } +// SubscribeWithContext calls Client.SubscribeWithContext on the default client. +func SubscribeWithContext(ctx context.Context, suffix string, fn func(ctx context.Context, v string, ok bool) error) error { + return defaultClient.SubscribeWithContext(ctx, suffix, fn) +} + +// Get calls Client.GetWithContext on the default client. +// +// Deprecated: Please use the context aware variant [GetWithContext]. +func Get(suffix string) (string, error) { + return defaultClient.GetWithContext(context.Background(), suffix) +} + +// GetWithContext calls Client.GetWithContext on the default client. +func GetWithContext(ctx context.Context, suffix string) (string, error) { + return defaultClient.GetWithContext(ctx, suffix) +} // ProjectID returns the current instance's project ID string. func ProjectID() (string, error) { return defaultClient.ProjectID() } @@ -288,8 +302,7 @@ func NewClient(c *http.Client) *Client { // getETag returns a value from the metadata service as well as the associated ETag. // This func is otherwise equivalent to Get. -func (c *Client) getETag(suffix string) (value, etag string, err error) { - ctx := context.TODO() +func (c *Client) getETag(ctx context.Context, suffix string) (value, etag string, err error) { // Using a fixed IP makes it very difficult to spoof the metadata service in // a container, which is an important use-case for local testing of cloud // deployments. To enable spoofing of the metadata service, the environment @@ -306,7 +319,7 @@ func (c *Client) getETag(suffix string) (value, etag string, err error) { } suffix = strings.TrimLeft(suffix, "/") u := "http://" + host + "/computeMetadata/v1/" + suffix - req, err := http.NewRequest("GET", u, nil) + req, err := http.NewRequestWithContext(ctx, "GET", u, nil) if err != nil { return "", "", err } @@ -354,19 +367,33 @@ func (c *Client) getETag(suffix string) (value, etag string, err error) { // // If the requested metadata is not defined, the returned error will // be of type NotDefinedError. +// +// Deprecated: Please use the context aware variant [Client.GetWithContext]. func (c *Client) Get(suffix string) (string, error) { - val, _, err := c.getETag(suffix) + return c.GetWithContext(context.Background(), suffix) +} + +// GetWithContext returns a value from the metadata service. +// The suffix is appended to "http://${GCE_METADATA_HOST}/computeMetadata/v1/". +// +// If the GCE_METADATA_HOST environment variable is not defined, a default of +// 169.254.169.254 will be used instead. +// +// If the requested metadata is not defined, the returned error will +// be of type NotDefinedError. +func (c *Client) GetWithContext(ctx context.Context, suffix string) (string, error) { + val, _, err := c.getETag(ctx, suffix) return val, err } -func (c *Client) getTrimmed(suffix string) (s string, err error) { - s, err = c.Get(suffix) +func (c *Client) getTrimmed(ctx context.Context, suffix string) (s string, err error) { + s, err = c.GetWithContext(ctx, suffix) s = strings.TrimSpace(s) return } func (c *Client) lines(suffix string) ([]string, error) { - j, err := c.Get(suffix) + j, err := c.GetWithContext(context.Background(), suffix) if err != nil { return nil, err } @@ -388,7 +415,7 @@ func (c *Client) InstanceID() (string, error) { return instID.get(c) } // InternalIP returns the instance's primary internal IP address. func (c *Client) InternalIP() (string, error) { - return c.getTrimmed("instance/network-interfaces/0/ip") + return c.getTrimmed(context.Background(), "instance/network-interfaces/0/ip") } // Email returns the email address associated with the service account. @@ -398,25 +425,25 @@ func (c *Client) Email(serviceAccount string) (string, error) { if serviceAccount == "" { serviceAccount = "default" } - return c.getTrimmed("instance/service-accounts/" + serviceAccount + "/email") + return c.getTrimmed(context.Background(), "instance/service-accounts/"+serviceAccount+"/email") } // ExternalIP returns the instance's primary external (public) IP address. func (c *Client) ExternalIP() (string, error) { - return c.getTrimmed("instance/network-interfaces/0/access-configs/0/external-ip") + return c.getTrimmed(context.Background(), "instance/network-interfaces/0/access-configs/0/external-ip") } // Hostname returns the instance's hostname. This will be of the form // ".c..internal". func (c *Client) Hostname() (string, error) { - return c.getTrimmed("instance/hostname") + return c.getTrimmed(context.Background(), "instance/hostname") } // InstanceTags returns the list of user-defined instance tags, // assigned when initially creating a GCE instance. func (c *Client) InstanceTags() ([]string, error) { var s []string - j, err := c.Get("instance/tags") + j, err := c.GetWithContext(context.Background(), "instance/tags") if err != nil { return nil, err } @@ -428,12 +455,12 @@ func (c *Client) InstanceTags() ([]string, error) { // InstanceName returns the current VM's instance ID string. func (c *Client) InstanceName() (string, error) { - return c.getTrimmed("instance/name") + return c.getTrimmed(context.Background(), "instance/name") } // Zone returns the current VM's zone, such as "us-central1-b". func (c *Client) Zone() (string, error) { - zone, err := c.getTrimmed("instance/zone") + zone, err := c.getTrimmed(context.Background(), "instance/zone") // zone is of the form "projects//zones/". if err != nil { return "", err @@ -460,7 +487,7 @@ func (c *Client) ProjectAttributes() ([]string, error) { return c.lines("project // InstanceAttributeValue may return ("", nil) if the attribute was // defined to be the empty string. func (c *Client) InstanceAttributeValue(attr string) (string, error) { - return c.Get("instance/attributes/" + attr) + return c.GetWithContext(context.Background(), "instance/attributes/"+attr) } // ProjectAttributeValue returns the value of the provided @@ -472,7 +499,7 @@ func (c *Client) InstanceAttributeValue(attr string) (string, error) { // ProjectAttributeValue may return ("", nil) if the attribute was // defined to be the empty string. func (c *Client) ProjectAttributeValue(attr string) (string, error) { - return c.Get("project/attributes/" + attr) + return c.GetWithContext(context.Background(), "project/attributes/"+attr) } // Scopes returns the service account scopes for the given account. @@ -489,21 +516,30 @@ func (c *Client) Scopes(serviceAccount string) ([]string, error) { // The suffix is appended to "http://${GCE_METADATA_HOST}/computeMetadata/v1/". // The suffix may contain query parameters. // -// Subscribe calls fn with the latest metadata value indicated by the provided -// suffix. If the metadata value is deleted, fn is called with the empty string -// and ok false. Subscribe blocks until fn returns a non-nil error or the value -// is deleted. Subscribe returns the error value returned from the last call to -// fn, which may be nil when ok == false. +// Deprecated: Please use the context aware variant [Client.SubscribeWithContext]. func (c *Client) Subscribe(suffix string, fn func(v string, ok bool) error) error { + return c.SubscribeWithContext(context.Background(), suffix, func(ctx context.Context, v string, ok bool) error { return fn(v, ok) }) +} + +// SubscribeWithContext subscribes to a value from the metadata service. +// The suffix is appended to "http://${GCE_METADATA_HOST}/computeMetadata/v1/". +// The suffix may contain query parameters. +// +// SubscribeWithContext calls fn with the latest metadata value indicated by the +// provided suffix. If the metadata value is deleted, fn is called with the +// empty string and ok false. Subscribe blocks until fn returns a non-nil error +// or the value is deleted. Subscribe returns the error value returned from the +// last call to fn, which may be nil when ok == false. +func (c *Client) SubscribeWithContext(ctx context.Context, suffix string, fn func(ctx context.Context, v string, ok bool) error) error { const failedSubscribeSleep = time.Second * 5 // First check to see if the metadata value exists at all. - val, lastETag, err := c.getETag(suffix) + val, lastETag, err := c.getETag(ctx, suffix) if err != nil { return err } - if err := fn(val, true); err != nil { + if err := fn(ctx, val, true); err != nil { return err } @@ -514,7 +550,7 @@ func (c *Client) Subscribe(suffix string, fn func(v string, ok bool) error) erro suffix += "?wait_for_change=true&last_etag=" } for { - val, etag, err := c.getETag(suffix + url.QueryEscape(lastETag)) + val, etag, err := c.getETag(ctx, suffix+url.QueryEscape(lastETag)) if err != nil { if _, deleted := err.(NotDefinedError); !deleted { time.Sleep(failedSubscribeSleep) @@ -524,7 +560,7 @@ func (c *Client) Subscribe(suffix string, fn func(v string, ok bool) error) erro } lastETag = etag - if err := fn(val, ok); err != nil || !ok { + if err := fn(ctx, val, ok); err != nil || !ok { return err } } diff --git a/compute/metadata/metadata_test.go b/compute/metadata/metadata_test.go index 281ca1c4d7b1..16b90db5bf2c 100644 --- a/compute/metadata/metadata_test.go +++ b/compute/metadata/metadata_test.go @@ -16,6 +16,7 @@ package metadata import ( "bytes" + "context" "io" "log" "net/http" @@ -23,6 +24,7 @@ import ( "strings" "sync" "testing" + "time" ) func TestOnGCE_Stress(t *testing.T) { @@ -53,21 +55,23 @@ func TestOnGCE_Force(t *testing.T) { } func TestOverrideUserAgent(t *testing.T) { + ctx := context.Background() const userAgent = "my-user-agent" rt := &rrt{} c := NewClient(&http.Client{Transport: userAgentTransport{userAgent, rt}}) - c.Get("foo") + c.GetWithContext(ctx, "foo") if got, want := rt.gotUserAgent, userAgent; got != want { t.Errorf("got %q, want %q", got, want) } } func TestGetFailsOnBadURL(t *testing.T) { + ctx := context.Background() c := NewClient(http.DefaultClient) old := os.Getenv(metadataHostEnv) defer os.Setenv(metadataHostEnv, old) os.Setenv(metadataHostEnv, "host:-1") - _, err := c.Get("suffix") + _, err := c.GetWithContext(ctx, "suffix") log.Printf("%v", err) if err == nil { t.Errorf("got %v, want non-nil error", err) @@ -91,9 +95,10 @@ func TestGet_LeadingSlash(t *testing.T) { } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() ct := &captureTransport{} c := NewClient(&http.Client{Transport: ct}) - c.Get(tc.suffix) + c.GetWithContext(ctx, tc.suffix) if ct.url != want { t.Fatalf("got %v, want %v", ct.url, want) } @@ -163,6 +168,7 @@ func TestRetry(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() ft := &failingTransport{ timesToFail: tt.timesToFail, failCode: tt.failCode, @@ -170,7 +176,7 @@ func TestRetry(t *testing.T) { response: tt.response, } c := NewClient(&http.Client{Transport: ft}) - s, err := c.Get("") + s, err := c.GetWithContext(ctx, "") if tt.expectError && err == nil { t.Fatalf("did not receive expected error") } else if !tt.expectError && err != nil { @@ -192,6 +198,51 @@ func TestRetry(t *testing.T) { } } +func TestClientGetWithContext(t *testing.T) { + tests := []struct { + name string + ctxTimeout time.Duration + wantErr bool + }{ + { + name: "ok", + ctxTimeout: 1 * time.Second, + }, + { + name: "times out", + ctxTimeout: 200 * time.Millisecond, + wantErr: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), tc.ctxTimeout) + defer cancel() + c := NewClient(&http.Client{Transport: sleepyTransport{}}) + _, err := c.GetWithContext(ctx, "foo") + if tc.wantErr && err == nil { + t.Fatal("c.GetWithContext() == nil, want an error") + } + if !tc.wantErr && err != nil { + t.Fatalf("c.GetWithContext() = %v, want nil", err) + } + }) + } +} + +type sleepyTransport struct { +} + +func (s sleepyTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.Context().Done() + select { + case <-req.Context().Done(): + return nil, req.Context().Err() + case <-time.After(500 * time.Millisecond): + } + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("I woke up"))}, nil +} + type failingTransport struct { timesToFail int failCode int