-
Notifications
You must be signed in to change notification settings - Fork 46
/
installer.go
350 lines (316 loc) · 10.8 KB
/
installer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
// Package installer provides functionality to install binary components of supported kubernetes versions.
package installer
import (
"archive/tar"
"bytes"
"compress/gzip"
"context"
"crypto/sha256"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"net/url"
"os"
"path"
"slices"
"time"
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/edgelesssys/constellation/v2/internal/versions/components"
"github.com/spf13/afero"
"github.com/vincent-petithory/dataurl"
"k8s.io/utils/clock"
)
// downloadInterval is the pause the retrier in retryDownloadToTempDir waits between two
// download attempts.
const downloadInterval = 10 * time.Millisecond

// executablePerm (0o544, r-xr--r--) holds the permission bits Install passes down when it
// installs a component.
const executablePerm = 0o544
// OsInstaller installs binary components of supported kubernetes versions.
// Its collaborators (filesystem, HTTP client, clock, retry classifier) are held in fields
// so they can be replaced; NewOSInstaller wires up the real implementations.
type OsInstaller struct {
	// fs is the filesystem downloads are staged on and components are installed to.
	fs *afero.Afero
	// hClient performs the GET requests for non-data component URLs.
	hClient httpClient
	// clock is needed for testing purposes
	clock clock.WithTicker
	// retriable is the function used to check if an error is retriable. Needed for testing.
	retriable func(error) bool
}
// NewOSInstaller returns an OsInstaller backed by the real environment: the operating-system
// filesystem, a zero-value http.Client (no client-level timeout is configured), the wall clock,
// and isRetriable as the retry classifier.
func NewOSInstaller() *OsInstaller {
	osFs := &afero.Afero{Fs: afero.NewOsFs()}
	inst := &OsInstaller{
		retriable: isRetriable,
		clock:     clock.RealClock{},
		hClient:   &http.Client{},
		fs:        osFs,
	}
	return inst
}
// Install downloads kubernetesComponent.Url (retrying retriable failures) into a temporary file.
// If kubernetesComponent.Hash is non-empty, it verifies the file's sha256 against it
// (format "sha256:<hex>"). It then either extracts the file as a tar.gz archive below
// kubernetesComponent.InstallPath (when Extract is set) or copies it to InstallPath.
// The temporary download is removed on every return path after a successful download,
// including a hash mismatch.
func (i *OsInstaller) Install(ctx context.Context, kubernetesComponent *components.Component) error {
	tempPath, err := i.retryDownloadToTempDir(ctx, kubernetesComponent.Url)
	if err != nil {
		return err
	}
	// Register cleanup immediately so failed verification does not leak the download.
	defer func() {
		_ = i.fs.Remove(tempPath)
	}()

	file, err := i.fs.OpenFile(tempPath, os.O_RDONLY, 0)
	if err != nil {
		return fmt.Errorf("opening file %q: %w", tempPath, err)
	}
	sha := sha256.New()
	_, err = io.Copy(sha, file)
	// The file was only read, so its Close error carries nothing worth reporting; closing it here
	// (instead of never) avoids leaking the descriptor through extraction/copy below.
	_ = file.Close()
	if err != nil {
		return fmt.Errorf("reading file %q: %w", tempPath, err)
	}

	calculatedHash := fmt.Sprintf("sha256:%x", sha.Sum(nil))
	if len(kubernetesComponent.Hash) > 0 && calculatedHash != kubernetesComponent.Hash {
		return fmt.Errorf("hash of file %q %s does not match expected hash %s", tempPath, calculatedHash, kubernetesComponent.Hash)
	}

	if kubernetesComponent.Extract {
		err = i.extractArchive(tempPath, kubernetesComponent.InstallPath, executablePerm)
	} else {
		err = i.copy(tempPath, kubernetesComponent.InstallPath, executablePerm)
	}
	if err != nil {
		return fmt.Errorf("installing from %q: copying to destination %q: %w", kubernetesComponent.Url, kubernetesComponent.InstallPath, err)
	}
	return nil
}
// extractArchive extracts the gzip-compressed tar archive at archivePath below the directory
// prefix, creating prefix first if needed.
//
// Only directory, regular-file and symlink entries are supported; any other entry type is an
// error. Entry names and symlink targets containing a ".." element are rejected. Directories are
// created with the archive mode masked by perm; regular files are created with the mode recorded
// in the archive, and existing files are truncated before being overwritten.
func (i *OsInstaller) extractArchive(archivePath, prefix string, perm fs.FileMode) error {
	archiveFile, err := i.fs.Open(archivePath)
	if err != nil {
		return fmt.Errorf("opening archive file: %w", err)
	}
	defer archiveFile.Close()

	gzReader, err := gzip.NewReader(archiveFile)
	if err != nil {
		return fmt.Errorf("reading archive file as gzip: %w", err)
	}
	defer gzReader.Close()

	if err := i.fs.MkdirAll(prefix, fs.ModePerm); err != nil {
		return fmt.Errorf("creating prefix folder: %w", err)
	}

	tarReader := tar.NewReader(gzReader)
	for {
		header, err := tarReader.Next()
		if errors.Is(err, io.EOF) {
			return nil
		}
		if err != nil {
			return fmt.Errorf("parsing tar header: %w", err)
		}
		if err := verifyTarPath(header.Name); err != nil {
			return fmt.Errorf("verifying tar path %q: %w", header.Name, err)
		}

		switch header.Typeflag {
		case tar.TypeDir:
			if len(header.Name) == 0 {
				return errors.New("cannot create dir for empty path")
			}
			prefixedPath := path.Join(prefix, header.Name)
			if err := i.fs.Mkdir(prefixedPath, fs.FileMode(header.Mode)&perm); err != nil && !errors.Is(err, os.ErrExist) {
				return fmt.Errorf("creating folder %q: %w", prefixedPath, err)
			}
		case tar.TypeReg:
			if len(header.Name) == 0 {
				return errors.New("cannot create file for empty path")
			}
			// Written via a helper so each file is closed before the next entry is processed,
			// rather than piling up deferred Close calls until the whole archive is done.
			if err := i.extractArchiveFile(path.Join(prefix, header.Name), fs.FileMode(header.Mode), tarReader); err != nil {
				return err
			}
		case tar.TypeSymlink:
			if err := verifyTarPath(header.Linkname); err != nil {
				return fmt.Errorf("invalid tar path %q: %w", header.Linkname, err)
			}
			if len(header.Name) == 0 {
				return errors.New("cannot create symlink for empty path")
			}
			if len(header.Linkname) == 0 {
				return errors.New("cannot create symlink with empty target")
			}
			if err := i.extractArchiveSymlink(prefix, header.Name, header.Linkname); err != nil {
				return err
			}
		default:
			return fmt.Errorf("unsupported tar record: %v", header.Typeflag)
		}
	}
}

// extractArchiveFile writes the contents of r to dst. dst is created with mode if it does not
// exist and truncated if it does. The file is closed before returning, and a failed Close is
// reported as an error because it may mean written data was not persisted.
func (i *OsInstaller) extractArchiveFile(dst string, mode fs.FileMode, r io.Reader) (err error) {
	out, err := i.fs.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode)
	if err != nil {
		return fmt.Errorf("creating file %q for writing: %w", dst, err)
	}
	defer func() {
		if closeErr := out.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("closing extracted file %q: %w", dst, closeErr)
		}
	}()
	if _, copyErr := io.Copy(out, r); copyErr != nil {
		return fmt.Errorf("writing extracted file contents: %w", copyErr)
	}
	return nil
}

// extractArchiveSymlink creates the symlink described by a tar entry below prefix.
// In a tar archive, name is the path of the link itself and linkname is what it points to.
// A relative linkname is resolved against the link's own directory (ordinary symlink semantics);
// an absolute one is re-rooted under prefix. The caller has already rejected ".." elements in
// both, so the resolved target stays lexically inside prefix.
func (i *OsInstaller) extractArchiveSymlink(prefix, name, linkname string) error {
	symlinker, ok := i.fs.Fs.(afero.Symlinker)
	if !ok {
		return errors.New("fs does not support symlinks")
	}
	target := linkname
	if !path.IsAbs(target) {
		target = path.Join(path.Dir(name), target)
	}
	target = path.Join(prefix, target)
	// SymlinkIfPossible(oldname, newname) creates newname as a link pointing at oldname.
	if err := symlinker.SymlinkIfPossible(target, path.Join(prefix, name)); err != nil {
		return fmt.Errorf("creating symlink: %w", err)
	}
	return nil
}
// retryDownloadToTempDir downloads url into a temporary file, retrying every downloadInterval
// while the failure is classified as retriable by i.retriable, and returns the file's path.
// Retrying stops as soon as ctx is canceled.
func (i *OsInstaller) retryDownloadToTempDir(ctx context.Context, url string) (fileName string, someError error) {
	// The doer records the path of the attempt that succeeded so it can be returned afterwards.
	d := &downloadDoer{downloader: i, url: url}

	// The installer's clock is handed to the retrier so tests can substitute a fake one.
	if err := retry.NewIntervalRetrier(d, downloadInterval, i.retriable, i.clock).Do(ctx); err != nil {
		return "", fmt.Errorf("retrying downloadToTempDir: %w", err)
	}
	return d.path, nil
}
// retriableHTTPStatusCodes lists the HTTP status codes after which a download is worth retrying,
// because a later attempt may be answered with 200 OK. Whether that happens depends on the web
// server; the selection follows https://stackoverflow.com/a/74627395.
var retriableHTTPStatusCodes = []int{
	// 4xx codes that mean "not now" rather than "never".
	http.StatusRequestTimeout,  // 408
	http.StatusTooEarly,        // 425
	http.StatusTooManyRequests, // 429
	// 5xx codes that are often transient.
	http.StatusBadGateway,         // 502
	http.StatusServiceUnavailable, // 503
	http.StatusGatewayTimeout,     // 504
}
// downloadHTTP GETs url with the installer's HTTP client (bound to ctx) and streams the response
// body to out.
//
// Transport errors, errors while streaming the body, and responses whose status code is listed in
// retriableHTTPStatusCodes are wrapped in *retriableError. Any other non-200 status is returned
// as a plain error.
func (i *OsInstaller) downloadHTTP(ctx context.Context, url string, out io.Writer) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return fmt.Errorf("request to download %q: %w", url, err)
	}
	resp, err := i.hClient.Do(req)
	if err != nil {
		// A failure at this point might be transient, such as network connectivity.
		return fmt.Errorf("request to download %q: %w", url, &retriableError{err: err})
	}
	// Close the body on every path, including non-200 responses; otherwise the underlying
	// connection is leaked on each failed (and possibly retried) attempt.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The HTTP request went through, but the result is not what we expected.
		// Wrap the error in case we think the request could be retried.
		err = fmt.Errorf("request to download %q failed with status code: %v", url, resp.Status)
		if slices.Contains(retriableHTTPStatusCodes, resp.StatusCode) {
			err = &retriableError{err: err}
		}
		return err
	}

	if _, err = io.Copy(out, resp.Body); err != nil {
		return fmt.Errorf("downloading %q: %w", url, &retriableError{err: err})
	}
	return nil
}
// unpackData decodes the data URL given in url and writes its decoded payload to out.
// Decoding failures and write failures are returned as (non-retriable) errors.
func (i *OsInstaller) unpackData(url string, out io.Writer) error {
	decoded, err := dataurl.DecodeString(url)
	if err != nil {
		return fmt.Errorf("parsing data URL: %w", err)
	}
	if _, err := io.Copy(out, bytes.NewReader(decoded.Data)); err != nil {
		return fmt.Errorf("writing content of data URL %q: %w", url, err)
	}
	return nil
}
// downloadToTempDir fetches the component at u into a new temporary file and returns that file's
// path. "data:" URLs are decoded locally; every other scheme goes through downloadHTTP.
// On failure the temporary file is removed and the error (joined with any removal error) is
// returned, preserving any *retriableError in the chain.
func (i *OsInstaller) downloadToTempDir(ctx context.Context, u string) (string, error) {
	// Named parsedURL so the net/url package is not shadowed inside this function.
	parsedURL, err := url.Parse(u)
	if err != nil {
		return "", fmt.Errorf("parsing component URL: %w", err)
	}
	out, err := afero.TempFile(i.fs, "", "")
	if err != nil {
		return "", fmt.Errorf("creating destination temp file: %w", err)
	}

	if parsedURL.Scheme == "data" {
		err = i.unpackData(u, out)
	} else {
		err = i.downloadHTTP(ctx, u, out)
	}
	// A failed Close on a file we just wrote may mean its contents are incomplete, so it must not
	// be reported as a successful download.
	if closeErr := out.Close(); closeErr != nil && err == nil {
		err = fmt.Errorf("closing temp file %q: %w", out.Name(), closeErr)
	}
	if err != nil {
		removeErr := i.fs.Remove(out.Name())
		return "", errors.Join(err, removeErr)
	}
	return out.Name(), nil
}
// copy copies the file oldname to newname, creating newname's parent directories as needed.
// newname is created with perm if it does not exist and truncated if it does. If anything fails
// after newname has been opened (copying or closing it), the partially written newname is removed.
func (i *OsInstaller) copy(oldname, newname string, perm fs.FileMode) (err error) {
	old, openOldErr := i.fs.OpenFile(oldname, os.O_RDONLY, fs.ModePerm)
	if openOldErr != nil {
		return fmt.Errorf("copying %q to %q: cannot open source file for reading: %w", oldname, newname, openOldErr)
	}
	// The source is only read; its Close error cannot affect the result of the copy.
	defer func() { _ = old.Close() }()

	// create destination path if not exists
	if mkdirErr := i.fs.MkdirAll(path.Dir(newname), fs.ModePerm); mkdirErr != nil {
		return fmt.Errorf("copying %q to %q: unable to create destination folder: %w", oldname, newname, mkdirErr)
	}

	newFile, openNewErr := i.fs.OpenFile(newname, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, perm)
	if openNewErr != nil {
		return fmt.Errorf("copying %q to %q: cannot open destination file for writing: %w", oldname, newname, openNewErr)
	}
	defer func() {
		// A failed Close on the destination may mean written data was not persisted, so it
		// fails the copy (and triggers cleanup) instead of being silently dropped.
		if closeErr := newFile.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("copying %q to %q: closing destination file: %w", oldname, newname, closeErr)
		}
		if err != nil {
			_ = i.fs.Remove(newname)
		}
	}()

	if _, copyErr := io.Copy(newFile, old); copyErr != nil {
		return fmt.Errorf("copying %q to %q: copying file contents: %w", oldname, newname, copyErr)
	}
	return nil
}
// downloadDoer adapts a downloader to the retrier used by retryDownloadToTempDir: every Do call
// is one download attempt of url, and path holds whatever that attempt returned.
type downloadDoer struct {
	url        string
	downloader downloader
	path       string
}

// downloader fetches a URL into a temporary file and returns the file's path.
type downloader interface {
	downloadToTempDir(ctx context.Context, url string) (string, error)
}

// Do performs a single download attempt. The returned path is stored even when the attempt
// fails, exactly as the downloader reported it.
func (d *downloadDoer) Do(ctx context.Context) (err error) {
	d.path, err = d.downloader.downloadToTempDir(ctx, d.url)
	return err
}
// retriableError marks an error whose operation may succeed if it is attempted again.
type retriableError struct{ err error }

// Error returns the wrapped error's message prefixed with "retriable error: ".
func (e *retriableError) Error() string {
	return "retriable error: " + e.err.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *retriableError) Unwrap() error { return e.err }

// isRetriable reports whether err, or any error in its wrap chain, is a *retriableError.
func isRetriable(err error) bool {
	var target *retriableError
	return errors.As(err, &target)
}
// verifyTarPath returns an error if any element of pat (the text between path separators, as
// defined by os.IsPathSeparator) is exactly "..". Empty and "." elements are accepted, as are
// absolute paths and names that merely contain dots (".a", "a..", "...").
func verifyTarPath(pat string) error {
	start := 0
	// end == len(pat) is visited once more so the final element is checked too.
	for end := 0; end <= len(pat); end++ {
		if end < len(pat) && !os.IsPathSeparator(pat[end]) {
			continue
		}
		if pat[start:end] == ".." {
			return errors.New("path contains \"..\"")
		}
		start = end + 1
	}
	return nil
}
// httpClient is the single method of *http.Client that OsInstaller relies on, declared as an
// interface so the client can be substituted.
type httpClient interface {
	Do(*http.Request) (*http.Response, error)
}