Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 80 additions & 4 deletions pkg/config/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ import (
"log"

"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)

// ProviderType storage provider type
Expand All @@ -35,7 +38,7 @@ const (
OBS ProviderType = "obs"
// BOS baidu bos
BOS ProviderType = "bos"
// GCS google gcs fixme:not tested
// GCS google gcs
GCS ProviderType = "gcs"
// KS3 kingsoft ks3
KS3 ProviderType = "ks3"
Expand Down Expand Up @@ -63,6 +66,16 @@ var endpointResolverFunc = func(urlTemplate, signingMethod string) s3.EndpointRe
}
}

var endpointResolverFuncGCS = func(urlTemplate, signingMethod string) s3.EndpointResolverFunc {
return func(region string, options s3.EndpointResolverOptions) (aws.Endpoint, error) {
return aws.Endpoint{
URL: urlTemplate,
SigningRegion: region,
SigningMethod: signingMethod,
}, nil
}
}

// Storage storage
type Storage struct {
Type ProviderType `yaml:"type"`
Expand All @@ -77,11 +90,17 @@ type Storage struct {
// Init init
func (o *Storage) Init() {
var endpointResolver s3.EndpointResolver
if o.Type != S3 {
switch o.Type {
case GCS:
endpointResolver = s3.EndpointResolverFromURL(URLTemplate[GCS])
o.Region = "auto"
case S3:
default:
if urlTemplate, exist := URLTemplate[o.Type]; exist && urlTemplate != "" {
endpointResolver = endpointResolverFunc(urlTemplate, o.SigningMethod)
}
}

if o.Region == "" || o.AccessKeyID == "" || o.SecretAccessKey == "" {
//use default config
cfg, err := config.LoadDefaultConfig(context.TODO())
Expand All @@ -101,12 +120,69 @@ func (o *Storage) Init() {
}, nil
}),
EndpointResolver: endpointResolver,
}, func(o *s3.Options) {
o.EndpointOptions.DisableHTTPS = true
}, func(s3Options *s3.Options) {
switch o.Type {
case GCS:
s3Options.APIOptions = append(s3Options.APIOptions, func(stack *middleware.Stack) error {
if err := stack.Finalize.Insert(dropAcceptEncodingHeader, "Signing", middleware.Before); err != nil {
return err
}

if err := stack.Finalize.Insert(replaceAcceptEncodingHeader, "Signing", middleware.After); err != nil {
return err
}

return nil
})
}
})
}

// GetClient get client
func (o *Storage) GetClient() *s3.Client {
return o.client
}

const acceptEncodingHeader = "Accept-Encoding"

type acceptEncodingKey struct{}

func GetAcceptEncodingKey(ctx context.Context) (v string) {
v, _ = middleware.GetStackValue(ctx, acceptEncodingKey{}).(string)
return v
}

func SetAcceptEncodingKey(ctx context.Context, value string) context.Context {
return middleware.WithStackValue(ctx, acceptEncodingKey{}, value)
}

var dropAcceptEncodingHeader = middleware.FinalizeMiddlewareFunc("DropAcceptEncodingHeader",
func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, &v4.SigningError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
}

ae := req.Header.Get(acceptEncodingHeader)
ctx = SetAcceptEncodingKey(ctx, ae)
req.Header.Del(acceptEncodingHeader)
in.Request = req

return next.HandleFinalize(ctx, in)
},
)

var replaceAcceptEncodingHeader = middleware.FinalizeMiddlewareFunc("ReplaceAcceptEncodingHeader",
func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, &v4.SigningError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
}

ae := GetAcceptEncodingKey(ctx)
req.Header.Set(acceptEncodingHeader, ae)
in.Request = req

return next.HandleFinalize(ctx, in)
},
)
63 changes: 54 additions & 9 deletions pkg/config/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package config
import (
"bytes"
"context"
"fmt"
"os"
"testing"

Expand All @@ -31,14 +32,47 @@ func TestStorage_Init(t *testing.T) {
fields fields
}{
{
name: "test",
name: "test-oss",
fields: fields{
Type: ProviderType(os.Getenv("s3_provider")),
Type: OSS,
SigningMethod: "v4",
Region: os.Getenv("s3_region"),
Bucket: os.Getenv("s3_bucket"),
AccessKeyID: os.Getenv("s3_access_key_id"),
SecretAccessKey: os.Getenv("s3_secret_access_key"),
Region: os.Getenv("oss_region"),
Bucket: os.Getenv("oss_bucket"),
AccessKeyID: os.Getenv("oss_access_key_id"),
SecretAccessKey: os.Getenv("oss_secret_access_key"),
},
},
{
name: "test-bos",
fields: fields{
Type: BOS,
SigningMethod: "v4",
Region: os.Getenv("bos_region"),
Bucket: os.Getenv("bos_bucket"),
AccessKeyID: os.Getenv("bos_access_key_id"),
SecretAccessKey: os.Getenv("bos_secret_access_key"),
},
},
{
name: "test-ks3",
fields: fields{
Type: KS3,
SigningMethod: "v4",
Region: os.Getenv("ks3_region"),
Bucket: os.Getenv("ks3_bucket"),
AccessKeyID: os.Getenv("ks3_access_key_id"),
SecretAccessKey: os.Getenv("ks3_secret_access_key"),
},
},
{
name: "test-kodo",
fields: fields{
Type: KODO,
SigningMethod: "v4",
Region: os.Getenv("kodo_region"),
Bucket: os.Getenv("kodo_bucket"),
AccessKeyID: os.Getenv("kodo_access_key_id"),
SecretAccessKey: os.Getenv("kodo_secret_access_key"),
},
},
}
Expand All @@ -53,13 +87,24 @@ func TestStorage_Init(t *testing.T) {
SecretAccessKey: tt.fields.SecretAccessKey,
}
o.Init()
_, err := o.GetClient().PutObject(context.TODO(), &s3.PutObjectInput{
Bucket: aws.String("mss-boot-io"),
res, err := o.GetClient().ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{
Bucket: aws.String(tt.fields.Bucket),
MaxKeys: 10,
})
if err != nil {
t.Fatalf("failed to list items: %v", err)
}

for _, o := range res.Contents {
fmt.Println(">>> ", *o.Key)
}
_, err = o.GetClient().PutObject(context.TODO(), &s3.PutObjectInput{
Bucket: aws.String(tt.fields.Bucket),
Key: aws.String("test.json"),
Body: bytes.NewBuffer([]byte(`{"name": "lwx"}`)),
})
if err != nil {
t.Errorf("failed to put object: %v", err)
t.Fatalf("failed to put object: %v", err)
}
})
}
Expand Down