Skip to content

Commit

Permalink
feat: add token requests client
Browse files Browse the repository at this point in the history
Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com>
  • Loading branch information
aramase committed Jan 24, 2022
1 parent 73bc800 commit e9116d9
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 9 deletions.
19 changes: 17 additions & 2 deletions cmd/secrets-store-csi-driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ import (
"net/http"
_ "net/http/pprof" // #nosec
"os"
"strings"
"time"

secretsstorev1 "sigs.k8s.io/secrets-store-csi-driver/apis/v1"
"sigs.k8s.io/secrets-store-csi-driver/controllers"
"sigs.k8s.io/secrets-store-csi-driver/pkg/k8s"
"sigs.k8s.io/secrets-store-csi-driver/pkg/metrics"
"sigs.k8s.io/secrets-store-csi-driver/pkg/rotation"
secretsstore "sigs.k8s.io/secrets-store-csi-driver/pkg/secrets-store"
Expand Down Expand Up @@ -66,6 +68,12 @@ var (
providerHealthCheck = flag.Bool("provider-health-check", false, "Enable health check for configured providers")
providerHealthCheckInterval = flag.Duration("provider-health-check-interval", 2*time.Minute, "Provider healthcheck interval duration")

// Token request flags for the CSI driver
// Token request is beta in 1.21: https://kubernetes-csi.github.io/docs/token-requests.html. When we only support 1.21 and above, we can remove this flag
// as kubelet will send the token as part of the mount request.
audiences = flag.String("audiences", "kube-apiserver", "CSI token request audiences delimited by semi-colon. Audiences should be distinct, otherwise the validation will fail")
expirationSeconds = flag.Int64("expiration-seconds", 3600, "CSI token request expiration seconds")

scheme = runtime.NewScheme()
)

Expand Down Expand Up @@ -177,17 +185,24 @@ func main() {
reconciler.RunPatcher(ctx)
}()

// token request client
tokenClient, err := k8s.NewTokenClient(cfg, strings.Split(strings.TrimSpace(*audiences), ";"), *expirationSeconds)
if err != nil {
klog.ErrorS(err, "failed to create token client")
os.Exit(1)
}

// Secret rotation
if *enableSecretRotation {
rec, err := rotation.NewReconciler(mgr.GetCache(), scheme, *providerVolumePath, *nodeID, *rotationPollInterval, providerClients)
rec, err := rotation.NewReconciler(mgr.GetCache(), scheme, *providerVolumePath, *nodeID, *rotationPollInterval, providerClients, tokenClient)
if err != nil {
klog.ErrorS(err, "failed to initialize rotation reconciler")
os.Exit(1)
}
go rec.Run(ctx.Done())
}

driver := secretsstore.NewSecretsStoreDriver(*driverName, *nodeID, *endpoint, *providerVolumePath, providerClients, mgr.GetClient(), mgr.GetAPIReader())
driver := secretsstore.NewSecretsStoreDriver(*driverName, *nodeID, *endpoint, *providerVolumePath, providerClients, mgr.GetClient(), mgr.GetAPIReader(), tokenClient)
driver.Run(ctx)
}

Expand Down
138 changes: 138 additions & 0 deletions pkg/k8s/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package k8s

import (
"context"
"encoding/json"
"fmt"
"time"

authenticationv1 "k8s.io/api/authentication/v1"
storagev1 "k8s.io/api/storage/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
)

// TokenClient is a client for Kubernetes Token API
type TokenClient struct {
client kubernetes.Interface
tokenRequests []storagev1.TokenRequest
}

// NewTokenClient creates a new TokenClient
// The client will be used to request a token for the preconfigured audiences (--audiences) and expiration time.
func NewTokenClient(config *rest.Config, audiences []string, expirationSeconds int64) (*TokenClient, error) {
kubeClient := kubernetes.NewForConfigOrDie(config)

tc := &TokenClient{
client: kubeClient,
tokenRequests: []storagev1.TokenRequest{},
}

for _, audience := range audiences {
tokenRequest := storagev1.TokenRequest{
Audience: audience,
ExpirationSeconds: &expirationSeconds,
}
tc.tokenRequests = append(tc.tokenRequests, tokenRequest)
}

errs := validateTokenRequests(tc.tokenRequests)
if len(errs) > 0 {
return nil, fmt.Errorf("failed to validate token requests: %v", errs)
}
return tc, nil
}

func (tc *TokenClient) PodServiceAccountTokenAttrs(ctx context.Context, podName, podNamespace, serviceAccountName string, podUID types.UID) (map[string]string, error) {
if len(tc.tokenRequests) == 0 {
return nil, nil
}

outputs := map[string]authenticationv1.TokenRequestStatus{}
for _, tokenRequest := range tc.tokenRequests {
audience := tokenRequest.Audience

tr, err := tc.client.CoreV1().
ServiceAccounts(podNamespace).
CreateToken(ctx, serviceAccountName,
&authenticationv1.TokenRequest{
Spec: authenticationv1.TokenRequestSpec{
ExpirationSeconds: tokenRequest.ExpirationSeconds,
Audiences: []string{audience},
BoundObjectRef: &authenticationv1.BoundObjectReference{
Kind: "Pod",
APIVersion: "v1",
Name: podName,
UID: podUID,
},
},
},
metav1.CreateOptions{},
)

if err != nil {
return nil, fmt.Errorf("failed to create token request for %s: %w", audience, err)
}
outputs[audience] = tr.Status
}

tokens, err := json.Marshal(outputs)
if err != nil {
return nil, fmt.Errorf("failed to marshal token request status: %w", err)
}

return map[string]string{
"csi.storage.k8s.io/serviceAccount.tokens": string(tokens),
}, nil
}

// Vendored from kubernetes/pkg/apis/storage/validation/validation.go
// * tag: v1.23.0-alpha.1,
// * commit: dfaeacb51f9e68f7730d9e400c7f19ddb08c0087
// * link: https://github.com/kubernetes/kubernetes/blob/dfaeacb51f9e68f7730d9e400c7f19ddb08c0087/pkg/apis/storage/validation/validation.go

// validateTokenRequests tests if the Audience in each TokenRequest are different.
// Besides, at most one TokenRequest can ignore Audience.
func validateTokenRequests(tokenRequests []storagev1.TokenRequest) []error {
const min = 10 * time.Minute
var allErrs []error
audiences := make(map[string]bool)
for _, tokenRequest := range tokenRequests {
audience := tokenRequest.Audience
if _, ok := audiences[audience]; ok {
allErrs = append(allErrs, fmt.Errorf("duplicate audience %s", audience))
continue
}
audiences[audience] = true

if tokenRequest.ExpirationSeconds == nil {
continue
}
if *tokenRequest.ExpirationSeconds < int64(min.Seconds()) {
allErrs = append(allErrs, fmt.Errorf("expirationSeconds %d must be greater than %f", *tokenRequest.ExpirationSeconds, min.Seconds()))
}
if *tokenRequest.ExpirationSeconds > 1<<32 {
allErrs = append(allErrs, fmt.Errorf("expirationSeconds %d must be less than %d", *tokenRequest.ExpirationSeconds, 1<<32))
}
}

return allErrs
}
85 changes: 85 additions & 0 deletions pkg/k8s/token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package k8s

import (
"testing"

storagev1 "k8s.io/api/storage/v1"
"k8s.io/utils/pointer"
)

func TestValidateTokenRequests(t *testing.T) {
test := []struct {
name string
tokenRequests []storagev1.TokenRequest
wantErrorsLen int
}{
{
name: "duplicate audience",
tokenRequests: []storagev1.TokenRequest{
{
Audience: "aud1",
ExpirationSeconds: pointer.Int64Ptr(3600),
},
{
Audience: "aud1",
ExpirationSeconds: pointer.Int64Ptr(1 << 33),
},
},
wantErrorsLen: 1,
},
{
name: "expiration seconds < 10m",
tokenRequests: []storagev1.TokenRequest{
{
Audience: "aud1",
ExpirationSeconds: pointer.Int64Ptr(599),
},
},
wantErrorsLen: 1,
},
{
name: "expiration seconds > 1<<32",
tokenRequests: []storagev1.TokenRequest{
{
Audience: "aud1",
ExpirationSeconds: pointer.Int64Ptr(1<<32 + 1),
},
},
wantErrorsLen: 1,
},
{
name: "token request has at most one token with empty string audience",
tokenRequests: []storagev1.TokenRequest{
{
Audience: "",
ExpirationSeconds: pointer.Int64Ptr(3600),
},
},
wantErrorsLen: 0,
},
{
name: "token request with different audiences",
tokenRequests: []storagev1.TokenRequest{
{
Audience: "aud1",
ExpirationSeconds: pointer.Int64Ptr(3600),
},
{
Audience: "aud2",
ExpirationSeconds: pointer.Int64Ptr(3600),
},
},
wantErrorsLen: 0,
},
}

for _, test := range test {
t.Run(test.name, func(t *testing.T) {
errs := validateTokenRequests(test.tokenRequests)
t.Log(errs)
if len(errs) != test.wantErrorsLen {
t.Errorf("validateTokenRequests() expected %v errors, got %v", test.wantErrorsLen, len(errs))
}
})
}
}
23 changes: 22 additions & 1 deletion pkg/rotation/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,20 @@ type Reconciler struct {
cache client.Reader
// secretStore stores Secret (filtered on secrets-store.csi.k8s.io/used=true)
secretStore k8s.Store
tokenClient *k8s.TokenClient
}

// +kubebuilder:rbac:groups="",resources=secrets,verbs=get;list;watch
// These permissions are required for secret rotation + nodePublishSecretRef
// TODO (aramase) remove this as part of https://github.com/kubernetes-sigs/secrets-store-csi-driver/issues/585

// NewReconciler returns a new reconciler for rotation
func NewReconciler(client client.Reader, s *runtime.Scheme, providerVolumePath, nodeName string, rotationPollInterval time.Duration, providerClients *secretsstore.PluginClientBuilder) (*Reconciler, error) {
func NewReconciler(client client.Reader,
s *runtime.Scheme,
providerVolumePath, nodeName string,
rotationPollInterval time.Duration,
providerClients *secretsstore.PluginClientBuilder,
tokenClient *k8s.TokenClient) (*Reconciler, error) {
config, err := buildConfig()
if err != nil {
return nil, err
Expand Down Expand Up @@ -120,6 +126,7 @@ func NewReconciler(client client.Reader, s *runtime.Scheme, providerVolumePath,
// cache store Pod,
cache: client,
secretStore: secretStore,
tokenClient: tokenClient,
}, nil
}

Expand Down Expand Up @@ -310,6 +317,20 @@ func (r *Reconciler) reconcile(ctx context.Context, spcps *secretsstorev1.Secret
parameters[csipodnamespace] = pod.Namespace
parameters[csipoduid] = string(pod.UID)
parameters[csipodsa] = pod.Spec.ServiceAccountName
// csi.storage.k8s.io/serviceAccount.tokens is empty for Kubernetes version < 1.20.
// For 1.20+, if tokenRequests is set in the CSI driver spec, kubelet will generate
// a token for the pod and send it to the CSI driver.
// This check is done for backward compatibility to support passing token from driver
// to provider irrespective of the Kubernetes version. If the token doesn't exist in the
// volume request context, the CSI driver will generate the token for the configured audience
// and send it to the provider in the parameters.
serviceAccountTokenAttrs, err := r.tokenClient.PodServiceAccountTokenAttrs(ctx, pod.Namespace, pod.Name, pod.Spec.ServiceAccountName, pod.UID)
if err != nil {
return fmt.Errorf("failed to get service account token attrs, err: %w", err)
}
for k, v := range serviceAccountTokenAttrs {
parameters[k] = v
}

paramsJSON, err := json.Marshal(parameters)
if err != nil {
Expand Down
23 changes: 22 additions & 1 deletion pkg/secrets-store/nodeserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ import (
"runtime"

internalerrors "sigs.k8s.io/secrets-store-csi-driver/pkg/errors"
"sigs.k8s.io/secrets-store-csi-driver/pkg/k8s"

"github.com/container-storage-interface/spec/lib/go/csi"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
mount "k8s.io/mount-utils"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand All @@ -45,6 +47,7 @@ type nodeServer struct {
// This should be used sparingly and only when the client does not fit the use case.
reader client.Reader
providerClients *PluginClientBuilder
tokenClient *k8s.TokenClient
}

const (
Expand All @@ -62,7 +65,7 @@ const (
func (ns *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (npvr *csi.NodePublishVolumeResponse, err error) {
var parameters map[string]string
var providerName string
var podName, podNamespace, podUID string
var podName, podNamespace, podUID, serviceAccountName string
var targetPath string
var mounted bool
errorReason := internalerrors.FailedToMount
Expand Down Expand Up @@ -109,6 +112,7 @@ func (ns *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublis
podName = attrib[csipodname]
podNamespace = attrib[csipodnamespace]
podUID = attrib[csipoduid]
serviceAccountName = attrib[csipodsa]

mounted, err = ns.ensureMountPoint(targetPath)
if err != nil {
Expand Down Expand Up @@ -170,6 +174,23 @@ func (ns *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublis
parameters[csipoduid] = attrib[csipoduid]
parameters[csipodsa] = attrib[csipodsa]
parameters[csipodsatokens] = attrib[csipodsatokens] //nolint
// csi.storage.k8s.io/serviceAccount.tokens is empty for Kubernetes version < 1.20.
// For 1.20+, if tokenRequests is set in the CSI driver spec, kubelet will generate
// a token for the pod and send it to the CSI driver.
// This check is done for backward compatibility to support passing token from driver
// to provider irrespective of the Kubernetes version. If the token doesn't exist in the
// volume request context, the CSI driver will generate the token for the configured audience
// and send it to the provider in the parameters.
if parameters[csipodsatokens] == "" {
serviceAccountTokenAttrs, err := ns.tokenClient.PodServiceAccountTokenAttrs(ctx, podNamespace, podName, serviceAccountName, types.UID(podUID))
if err != nil {
klog.ErrorS(err, "failed to get service account token attrs", "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})
return nil, err
}
for k, v := range serviceAccountTokenAttrs {
parameters[k] = v
}
}

// ensure it's read-only
if !req.GetReadonly() {
Expand Down
2 changes: 1 addition & 1 deletion pkg/secrets-store/nodeserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import (
func testNodeServer(t *testing.T, tmpDir string, mountPoints []mount.MountPoint, client client.Client, reporter StatsReporter) (*nodeServer, error) {
t.Helper()
providerClients := NewPluginClientBuilder(tmpDir)
return newNodeServer(tmpDir, "testnode", mount.NewFakeMounter(mountPoints), providerClients, client, client, reporter)
return newNodeServer(tmpDir, "testnode", mount.NewFakeMounter(mountPoints), providerClients, client, client, reporter, nil)
}

func TestNodePublishVolume(t *testing.T) {
Expand Down

0 comments on commit e9116d9

Please sign in to comment.