Commit

Add digest functions flag
AlessandroPatti committed May 13, 2024
1 parent 340c988 commit 9c40d67
Showing 11 changed files with 184 additions and 26 deletions.
2 changes: 1 addition & 1 deletion cache/grpcproxy/grpcproxy_test.go
@@ -239,7 +239,7 @@ func newFixture(t *testing.T, proxy cache.Proxy, storageMode string) *fixture {
}
grpcServer := grpc.NewServer()
go func() {
err := server.ServeGRPC(listener, grpcServer, false, false, true, diskCache, logger, logger)
err := server.ServeGRPC(listener, grpcServer, false, false, true, diskCache, logger, logger, hashing.DigestFunctions())
if err != nil {
logger.Printf(err.Error())
}
2 changes: 2 additions & 0 deletions config/BUILD.bazel
@@ -17,8 +17,10 @@ go_library(
"//cache/azblobproxy:go_default_library",
"//cache/gcsproxy:go_default_library",
"//cache/grpcproxy:go_default_library",
"//cache/hashing:go_default_library",
"//cache/httpproxy:go_default_library",
"//cache/s3proxy:go_default_library",
"//genproto/build/bazel/remote/execution/v2:go_default_library",
"@com_github_azure_azure_sdk_for_go_sdk_azcore//:go_default_library",
"@com_github_azure_azure_sdk_for_go_sdk_azidentity//:go_default_library",
"@com_github_grpc_ecosystem_go_grpc_prometheus//:go_default_library",
42 changes: 41 additions & 1 deletion config/config.go
@@ -17,7 +17,9 @@ import (

"github.com/buchgr/bazel-remote/v2/cache"
"github.com/buchgr/bazel-remote/v2/cache/azblobproxy"
"github.com/buchgr/bazel-remote/v2/cache/hashing"
"github.com/buchgr/bazel-remote/v2/cache/s3proxy"
pb "github.com/buchgr/bazel-remote/v2/genproto/build/bazel/remote/execution/v2"

"github.com/urfave/cli/v2"
yaml "gopkg.in/yaml.v3"
@@ -114,6 +116,7 @@ type Config struct {
LogTimezone string `yaml:"log_timezone"`
MaxBlobSize int64 `yaml:"max_blob_size"`
MaxProxyBlobSize int64 `yaml:"max_proxy_blob_size"`
DigestFunctions []pb.DigestFunction_Value

// Fields that are created by combinations of the flags above.
ProxyBackend cache.Proxy
@@ -125,6 +128,9 @@ type YamlConfig struct {
type YamlConfig struct {
Config `yaml:",inline"`

// Complex types that are converted later
DigestFunctionNames []string `yaml:"digest_functions"`

// Deprecated fields, retained for backwards compatibility when
// parsing config files.

@@ -169,7 +175,8 @@ func newFromArgs(dir string, maxSize int, storageMode string, zstdImplementation
accessLogLevel string,
logTimezone string,
maxBlobSize int64,
maxProxyBlobSize int64) (*Config, error) {
maxProxyBlobSize int64,
digestFunctions []pb.DigestFunction_Value) (*Config, error) {

c := Config{
HTTPAddress: httpAddress,
@@ -205,6 +212,7 @@ func newFromArgs(dir string, maxSize int, storageMode string, zstdImplementation
LogTimezone: logTimezone,
MaxBlobSize: maxBlobSize,
MaxProxyBlobSize: maxProxyBlobSize,
DigestFunctions: digestFunctions,
}

err := validateConfig(&c)
@@ -234,6 +242,7 @@ func newFromYamlFile(path string) (*Config, error) {

func newFromYaml(data []byte) (*Config, error) {
yc := YamlConfig{
DigestFunctionNames: []string{"sha256"},
Config: Config{
StorageMode: "zstd",
ZstdImplementation: "go",
@@ -270,6 +279,16 @@ func newFromYaml(data []byte) (*Config, error) {
sort.Float64s(c.MetricsDurationBuckets)
}

dfs := make([]pb.DigestFunction_Value, 0)
for _, dfn := range yc.DigestFunctionNames {
df := hashing.DigestFunction(dfn)
if df == pb.DigestFunction_UNKNOWN {
return nil, fmt.Errorf("unknown digest function %s", dfn)
}
dfs = append(dfs, df)
}
c.DigestFunctions = dfs

err = validateConfig(&c)
if err != nil {
return nil, err
@@ -462,6 +481,15 @@ func validateConfig(c *Config) error {
return errors.New("'log_timezone' must be set to either \"UTC\", \"local\" or \"none\"")
}

if c.DigestFunctions == nil {
return errors.New("at least on digest function must be supported")
}
for _, df := range c.DigestFunctions {
if !hashing.Supported(df) {
return fmt.Errorf("unsupported hashing function %s", df)
}
}

return nil
}

@@ -590,6 +618,17 @@ func get(ctx *cli.Context) (*Config, error) {
}
}

dfs := make([]pb.DigestFunction_Value, 0)
if ctx.String("digest_functions") != "" {
for _, dfn := range strings.Split(ctx.String("digest_functions"), ",") {
df := hashing.DigestFunction(dfn)
if df == pb.DigestFunction_UNKNOWN {
return nil, fmt.Errorf("unknown digest function %s", dfn)
}
dfs = append(dfs, df)
}
}

return newFromArgs(
ctx.String("dir"),
ctx.Int("max_size"),
@@ -623,5 +662,6 @@ func get(ctx *cli.Context) (*Config, error) {
ctx.String("log_timezone"),
ctx.Int64("max_blob_size"),
ctx.Int64("max_proxy_blob_size"),
dfs,
)
}
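
For readers of this diff: a minimal, self-contained sketch of the name-to-enum conversion performed by the two new code paths above (the YAML list in newFromYaml and the comma-separated flag in get), combined with the support check that validateConfig applies. It reuses only helpers visible in this commit (hashing.DigestFunction, hashing.Supported); the standalone helper name parseDigestFunctions is hypothetical and not part of the change.

package config

import (
	"fmt"
	"strings"

	"github.com/buchgr/bazel-remote/v2/cache/hashing"
	pb "github.com/buchgr/bazel-remote/v2/genproto/build/bazel/remote/execution/v2"
)

// parseDigestFunctions (hypothetical) maps each configured name to the
// REAPI enum value, rejecting names the hashing package does not know
// about and values it cannot serve.
func parseDigestFunctions(commaSeparated string) ([]pb.DigestFunction_Value, error) {
	dfs := make([]pb.DigestFunction_Value, 0)
	for _, name := range strings.Split(commaSeparated, ",") {
		df := hashing.DigestFunction(name)
		if df == pb.DigestFunction_UNKNOWN {
			return nil, fmt.Errorf("unknown digest function %s", name)
		}
		if !hashing.Supported(df) {
			return nil, fmt.Errorf("unsupported digest function %s", name)
		}
		dfs = append(dfs, df)
	}
	return dfs, nil
}

With the defaults in this commit, the configured list reduces to []pb.DigestFunction_Value{pb.DigestFunction_SHA256}, which is what the tests below assert.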
81 changes: 81 additions & 0 deletions config/config_test.go
@@ -9,6 +9,8 @@ import (
"testing"
"time"

pb "github.com/buchgr/bazel-remote/v2/genproto/build/bazel/remote/execution/v2"

"github.com/google/go-cmp/cmp"
)

@@ -60,6 +62,7 @@ log_timezone: local
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
AccessLogLevel: "none",
LogTimezone: "local",
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
}

if !reflect.DeepEqual(config, expectedConfig) {
@@ -103,6 +106,7 @@ gcs_proxy:
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
AccessLogLevel: "all",
LogTimezone: "UTC",
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
}

if !cmp.Equal(config, expectedConfig) {
@@ -147,6 +151,7 @@ http_proxy:
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
AccessLogLevel: "all",
LogTimezone: "UTC",
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
}

if !cmp.Equal(config, expectedConfig) {
@@ -224,6 +229,7 @@ s3_proxy:
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
AccessLogLevel: "all",
LogTimezone: "UTC",
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
}

if !cmp.Equal(config, expectedConfig) {
@@ -258,6 +264,7 @@ profile_address: :7070
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
AccessLogLevel: "all",
LogTimezone: "UTC",
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
}

if !cmp.Equal(config, expectedConfig) {
@@ -306,6 +313,7 @@ endpoint_metrics_duration_buckets: [.005, .1, 5]
MetricsDurationBuckets: []float64{0.005, 0.1, 5},
AccessLogLevel: "all",
LogTimezone: "UTC",
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
}

if !cmp.Equal(config, expectedConfig) {
@@ -438,6 +446,7 @@ storage_mode: zstd
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
AccessLogLevel: "all",
LogTimezone: "UTC",
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
}

if !cmp.Equal(config, expectedConfig) {
@@ -472,6 +481,7 @@ storage_mode: zstd
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
AccessLogLevel: "all",
LogTimezone: "UTC",
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
}

if !cmp.Equal(config, expectedConfig) {
@@ -495,3 +505,74 @@ func TestSocketPathMissing(t *testing.T) {
t.Fatal("Expected the error message to mention the missing 'http_address' key/flag")
}
}

func TestDigestFunctions(t *testing.T) {
t.Run("Default", func(t *testing.T) {
yaml := `dir: /opt/cache-dir
max_size: 42
`
config, err := newFromYaml([]byte(yaml))
if err != nil {
t.Fatal(err)
}
if len(config.DigestFunctions) != 1 {
t.Fatal("Expected exactly one digest function")
}
if config.DigestFunctions[0] != pb.DigestFunction_SHA256 {
t.Fatal("Expected sha256 digest function")
}
err = validateConfig(config)
if err != nil {
t.Fatal(err)
}
})

t.Run("Success", func(t *testing.T) {
yaml := `dir: /opt/cache-dir
max_size: 42
digest_functions: [sha256]
`
config, err := newFromYaml([]byte(yaml))
if err != nil {
t.Fatal(err)
}
if len(config.DigestFunctions) != 1 {
t.Fatal("Expected exactly one digest function")
}
if config.DigestFunctions[0] != pb.DigestFunction_SHA256 {
t.Fatal("Expected sha256 digest function")
}
err = validateConfig(config)
if err != nil {
t.Fatal(err)
}
})

t.Run("UnknownFunction", func(t *testing.T) {
yaml := `dir: /opt/cache-dir
max_size: 42
digest_functions: [sha256, foo]
`
_, err := newFromYaml([]byte(yaml))
if err == nil {
t.Fatal("Expected error")
}
if !strings.Contains(err.Error(), "unknown") {
t.Fatalf("Unexpected error: %s", err.Error())
}
})

t.Run("UnsupportedFunction", func(t *testing.T) {
yaml := `dir: /opt/cache-dir
max_size: 42
digest_functions: [md5]
`
_, err := newFromYaml([]byte(yaml))
if err == nil {
t.Fatal("Expected error")
}
if !strings.Contains(err.Error(), "unsupported") {
t.Fatalf("Unexpected error: %s", err.Error())
}
})
}
5 changes: 3 additions & 2 deletions main.go
@@ -239,7 +239,7 @@ func startHttpServer(c *config.Config, httpServer **http.Server,
checkClientCertForWrites := c.TLSCaFile != ""
validateAC := !c.DisableHTTPACValidation
h := server.NewHTTPCache(diskCache, c.AccessLogger, c.ErrorLogger, validateAC,
c.EnableACKeyInstanceMangling, checkClientCertForReads, checkClientCertForWrites, gitCommit)
c.EnableACKeyInstanceMangling, checkClientCertForReads, checkClientCertForWrites, gitCommit, c.DigestFunctions)

cacheHandler := h.CacheHandler
var basicAuthenticator auth.BasicAuth
@@ -429,7 +429,8 @@ func startGrpcServer(c *config.Config, grpcServer **grpc.Server,
validateAC,
c.EnableACKeyInstanceMangling,
enableRemoteAssetAPI,
diskCache, c.AccessLogger, c.ErrorLogger)
diskCache, c.AccessLogger, c.ErrorLogger,
c.DigestFunctions)
}

// A http.HandlerFunc wrapper which requires successful basic
33 changes: 23 additions & 10 deletions server/grpc.go
@@ -2,6 +2,7 @@ package server

import (
"context"
"fmt"
"net"
"net/http"

@@ -30,11 +31,12 @@ import (
const grpcHealthServiceName = "/grpc.health.v1.Health/Check"

type grpcServer struct {
cache disk.Cache
accessLogger cache.Logger
errorLogger cache.Logger
depsCheck bool
mangleACKeys bool
cache disk.Cache
accessLogger cache.Logger
errorLogger cache.Logger
depsCheck bool
mangleACKeys bool
digestFunctions map[pb.DigestFunction_Value]bool
}

var readOnlyMethods = map[string]struct{}{
@@ -55,26 +57,33 @@ func ListenAndServeGRPC(
validateACDeps bool,
mangleACKeys bool,
enableRemoteAssetAPI bool,
c disk.Cache, a cache.Logger, e cache.Logger) error {
c disk.Cache, a cache.Logger, e cache.Logger,
digestFunctions []pb.DigestFunction_Value) error {

listener, err := net.Listen(network, addr)
if err != nil {
return err
}

return ServeGRPC(listener, srv, validateACDeps, mangleACKeys, enableRemoteAssetAPI, c, a, e)
return ServeGRPC(listener, srv, validateACDeps, mangleACKeys, enableRemoteAssetAPI, c, a, e, digestFunctions)
}

func ServeGRPC(l net.Listener, srv *grpc.Server,
validateACDepsCheck bool,
mangleACKeys bool,
enableRemoteAssetAPI bool,
c disk.Cache, a cache.Logger, e cache.Logger) error {
c disk.Cache, a cache.Logger, e cache.Logger,
digestFunctions []pb.DigestFunction_Value) error {

dfs := make(map[pb.DigestFunction_Value]bool)
for _, df := range digestFunctions {
dfs[df] = true
}
s := &grpcServer{
cache: c, accessLogger: a, errorLogger: e,
depsCheck: validateACDepsCheck,
mangleACKeys: mangleACKeys,
depsCheck: validateACDepsCheck,
mangleACKeys: mangleACKeys,
digestFunctions: dfs,
}
pb.RegisterActionCacheServer(srv, s)
pb.RegisterCapabilitiesServer(srv, s)
@@ -129,10 +138,14 @@ func (s *grpcServer) GetCapabilities(ctx context.Context,
func (s *grpcServer) getHasher(df pb.DigestFunction_Value) (hashing.Hasher, error) {
var err error
var hasher hashing.Hasher

switch df {
case pb.DigestFunction_UNKNOWN:
hasher, err = hashing.Get(hashing.LegacyFn)
default:
if _, ok := s.digestFunctions[df]; !ok {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("unsupported digest function %s", df))
}
hasher, err = hashing.Get(df)
}
if err != nil {
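
As an aside (not part of the commit), a minimal package-internal test sketch of the gating that getHasher now performs. It assumes the grpcServer fields shown above and that hashing.Get can serve pb.DigestFunction_SHA256; the test name TestGetHasherGating is hypothetical.

package server

import (
	"testing"

	pb "github.com/buchgr/bazel-remote/v2/genproto/build/bazel/remote/execution/v2"
)

func TestGetHasherGating(t *testing.T) {
	// A server configured to accept only sha256.
	s := &grpcServer{
		digestFunctions: map[pb.DigestFunction_Value]bool{
			pb.DigestFunction_SHA256: true,
		},
	}

	// A configured digest function resolves to a hasher.
	if _, err := s.getHasher(pb.DigestFunction_SHA256); err != nil {
		t.Fatalf("expected sha256 to be accepted: %v", err)
	}

	// A function outside the configured set is rejected before any
	// hasher lookup, with an InvalidArgument status.
	if _, err := s.getHasher(pb.DigestFunction_SHA512); err == nil {
		t.Fatal("expected an error for an unconfigured digest function")
	}
}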
2 changes: 1 addition & 1 deletion server/grpc_test.go
@@ -105,7 +105,7 @@ func grpcTestSetupInternal(t *testing.T, mangleACKeys bool) (tc grpcTestFixture)
validateAC,
mangleACKeys,
enableRemoteAssetAPI,
diskCache, accessLogger, errorLogger)
diskCache, accessLogger, errorLogger, hashing.DigestFunctions())
if err2 != nil {
fmt.Println(err2)
os.Exit(1)
