diff --git a/pkg/ghet/download/checksums.go b/pkg/ghet/download/checksums.go new file mode 100644 index 0000000..5b877ba --- /dev/null +++ b/pkg/ghet/download/checksums.go @@ -0,0 +1,267 @@ +package download + +import ( + "bufio" + "context" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" + "fmt" + "hash" + "io" + "os" + "path" + "regexp" + "strconv" + "strings" + + githubapi "github.com/cardil/ghet/pkg/github/api" + "github.com/cardil/ghet/pkg/output" + "github.com/cardil/ghet/pkg/output/tui" + slog "github.com/go-eden/slf4go" + "github.com/gookit/color" + "github.com/pkg/errors" +) + +// ErrTooManyChecksums is returned when there are more than one checksums. +var ErrTooManyChecksums = errors.New("too many checksums") + +// ErrUnknownChecksumAlgorithm is returned when the checksum algorithm is unknown. +var ErrUnknownChecksumAlgorithm = errors.New("unknown checksum algorithm") + +// ErrChecksumMismatch is returned when the checksum does not match. +var ErrChecksumMismatch = errors.New("checksum mismatch") + +// ErrNotVerifiedAssets is returned when there are no verified assets. +var ErrNotVerifiedAssets = errors.New("not verified assets") + +var bsdStyleChecksums = regexp.MustCompile(`^(SHA[0-9]{1,3})\s+\([^)]+\)\s+=\s+([a-fA-F0-9]{32,128})$`) + +func verifyChecksums(ctx context.Context, assets []githubapi.Asset, args Args) error { + l := output.LoggerFrom(ctx) + widgets := tui.WidgetsFrom(ctx) + index := githubapi.CreateIndex(assets) + if len(index.Checksums) == 0 { + l.Debug("No checksums to verify") + widgets.Printf(ctx, "🕵 No checksums to verify") + return nil + } + if len(index.Checksums) > 1 { + l.Errorf("Number of checksums is %d. Expected just one.", len(index.Checksums)) + return fmt.Errorf("%w: %d", ErrTooManyChecksums, len(index.Checksums)) + } + + ca := index.Checksums[0] + l = l.WithFields(slog.Fields{"checksum": ca.Name}) + l.Debug("Verifying checksum") + + csp := checksumParser{Asset: ca, Args: args} + if cs, err := csp.parse(ctx); err != nil { + return err + } else { + err = cs.verify(ctx, append(index.Archives, index.Other...), args.Destination) + if err != nil { + return err + } + } + + widgets.Printf(ctx, "✅ All checksums match the downloaded assets") + + l.Debugf("Deleting the checksums file(s): %q", index.Checksums) + for _, c := range index.Checksums { + if err := os.Remove(path.Join(args.Destination, c.Name)); err != nil { + return unexpected(err) + } + } + + return nil +} + +type checksumParser struct { + githubapi.Asset + Args + *checksums +} + +func (p *checksumParser) parse(ctx context.Context) (*checksums, error) { + l := output.LoggerFrom(ctx) + fp := path.Join(p.Destination, p.Name) + l.Debugf("Parsing checksum: %s", fp) + if _, ferr := os.Stat(fp); ferr != nil { + return nil, unexpected(ferr) + } + file, ferr := os.Open(fp) + if ferr != nil { + return nil, unexpected(ferr) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + p.checksums = &checksums{ + entries: make([]checksumEntry, 0, 1), + } + for scanner.Scan() { + if err := p.parseLine(ctx, scanner.Text()); err != nil { + return nil, err + } + } + + if err := scanner.Err(); err != nil { + return nil, unexpected(err) + } + + return p.checksums, nil +} + +func (p *checksumParser) parseLine(ctx context.Context, line string) error { + var entry checksumEntry + if bsdStyleChecksums.MatchString(line) { + entry = p.parseBSDStyleChecksum(ctx, line) + } else { + if e, err := p.parseRegularChecksum(line); err != nil { + return err + } else { + entry = e + } + } + p.checksums.entries = append(p.checksums.entries, entry) + return nil +} + +func (p *checksumParser) parseRegularChecksum(line string) (checksumEntry, error) { + fields := strings.Fields(line) + if len(fields) != 2 { + return checksumEntry{}, unexpected(fmt.Errorf("invalid checksum line: %s", line)) + } + + entry := checksumEntry{ + hash: fields[0], + filename: fields[1], + } + if algo, err := checksumAlgorithmForHash(entry.hash); err != nil { + return checksumEntry{}, err + } else { + entry.checksumAlgorithm = algo + } + return entry, nil +} + +func (p *checksumParser) parseBSDStyleChecksum(_ context.Context, line string) checksumEntry { + match := bsdStyleChecksums.FindStringSubmatch(line) + return checksumEntry{ + hash: match[3], + filename: match[2], + checksumAlgorithm: checksumAlgorithm(match[1]), + } +} + +type checksumAlgorithm string + +const ( + checksumAlgorithmSHA1 checksumAlgorithm = "SHA1" + checksumAlgorithmSHA224 checksumAlgorithm = "SHA224" + checksumAlgorithmSHA256 checksumAlgorithm = "SHA256" + checksumAlgorithmSHA384 checksumAlgorithm = "SHA384" + checksumAlgorithmSHA512 checksumAlgorithm = "SHA512" +) + +func (a checksumAlgorithm) bytesLen() int { + if a == checksumAlgorithmSHA1 { + return 20 + } + + i, err := strconv.Atoi(strings.TrimPrefix(string(a), "SHA")) + if err != nil { + panic(err) + } + return i / 8 +} + +func (a checksumAlgorithm) newDigest() hash.Hash { + switch a { + case checksumAlgorithmSHA1: + return sha1.New() + case checksumAlgorithmSHA224: + return sha256.New224() + case checksumAlgorithmSHA256: + return sha256.New() + case checksumAlgorithmSHA384: + return sha512.New384() + case checksumAlgorithmSHA512: + return sha512.New() + } + panic("unexpected checksum algorithm: " + a) +} + +func checksumAlgorithmForHash(hash string) (checksumAlgorithm, error) { + algs := []checksumAlgorithm{ + checksumAlgorithmSHA1, checksumAlgorithmSHA224, checksumAlgorithmSHA256, + checksumAlgorithmSHA384, checksumAlgorithmSHA512, + } + for _, alg := range algs { + if alg.bytesLen()*2 == len(hash) { + return alg, nil + } + } + return "", fmt.Errorf("%w: %s", ErrUnknownChecksumAlgorithm, hash) +} + +type checksumEntry struct { + checksumAlgorithm + hash string + filename string +} + +func (e checksumEntry) verify(_ context.Context, asset githubapi.Asset, dest string) error { + dig := e.newDigest() + fp := path.Join(dest, asset.Name) + var reader io.Reader + if f, err := os.Open(fp); err != nil { + return unexpected(err) + } else { + defer f.Close() + reader = bufio.NewReader(f) + } + if _, err := io.Copy(dig, reader); err != nil { + return unexpected(err) + } + actual := hex.EncodeToString(dig.Sum(nil)) + if actual != e.hash { + return fmt.Errorf("%w: %s, %s != %s", + ErrChecksumMismatch, asset.Name, actual, e.hash) + } + return nil +} + +type checksums struct { + entries []checksumEntry +} + +func (c checksums) verify(ctx context.Context, assets []githubapi.Asset, dest string) error { + widgets := tui.WidgetsFrom(ctx) + for _, entry := range c.entries { + for i, curr := range assets { + if entry.filename == curr.Name { + spin := widgets.NewSpinner(ctx, fmt.Sprintf("🔍 Verifying checksum for %s", + color.Cyan.Sprintf(curr.Name))) + if err := spin.With(func(_ tui.Spinner) error { + if err := entry.verify(ctx, curr, dest); err != nil { + return err + } + return nil + }); err != nil { + return err + } + assets = append(assets[:i], assets[i+1:]...) + break + } + } + } + + if len(assets) > 0 { + return errors.WithStack(fmt.Errorf("%w: %q", ErrNotVerifiedAssets, assets)) + } + + return nil +} diff --git a/pkg/ghet/download/download.go b/pkg/ghet/download/download.go index 3faee28..790199a 100644 --- a/pkg/ghet/download/download.go +++ b/pkg/ghet/download/download.go @@ -43,7 +43,7 @@ func downloadAsset(ctx context.Context, asset assetInfo, args Args) error { if fileExists(l, cachePath, asset.Size) { l.WithFields(slog.Fields{"cachePath": cachePath}). Debug("Asset already downloaded") - return copyFile(cachePath, asset.Asset, args) + return moveFile(cachePath, asset.Asset, args) } l.Debug("Downloading asset") @@ -84,10 +84,10 @@ func downloadAsset(ctx context.Context, asset assetInfo, args Args) error { }); perr != nil { return perr } - return copyFile(cachePath, asset.Asset, args) + return moveFile(cachePath, asset.Asset, args) } -func copyFile(cachePath string, asset githubapi.Asset, args Args) error { +func moveFile(cachePath string, asset githubapi.Asset, args Args) error { if err := os.MkdirAll(args.Destination, executableMode); err != nil { return errors.WithStack(err) } diff --git a/pkg/ghet/download/errors.go b/pkg/ghet/download/errors.go new file mode 100644 index 0000000..cdefdbf --- /dev/null +++ b/pkg/ghet/download/errors.go @@ -0,0 +1,14 @@ +package download + +import ( + "fmt" + + "github.com/pkg/errors" +) + +// ErrUnexpected is returned when an unexpected error occurs. +var ErrUnexpected = errors.New("unexpected error") + +func unexpected(err error) error { + return errors.WithStack(fmt.Errorf("%w: %v", ErrUnexpected, err)) +} diff --git a/pkg/ghet/download/plan.go b/pkg/ghet/download/plan.go index 1765017..0dd34cc 100644 --- a/pkg/ghet/download/plan.go +++ b/pkg/ghet/download/plan.go @@ -72,7 +72,8 @@ func CreatePlan(ctx context.Context, args Args) (*Plan, error) { } plan := &Plan{Assets: assets} log.WithFields(slog.Fields{"plan": plan}).Debug("Plan created") - widgets.Printf(ctx, "🎉 Found %s matching assets", color.Cyan.Sprint(len(assets))) + widgets.Printf(ctx, "🎉 Found %s matching assets for %s", + color.Cyan.Sprint(len(assets)), color.Cyan.Sprintf(rr.GetTagName())) return plan, nil } @@ -100,6 +101,9 @@ func (p Plan) Download(ctx context.Context, args Args) error { return err } } + if err := verifyChecksums(ctx, p.Assets, args); err != nil { + return err + } return nil }