Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[launcher] Add custom token support #367

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
168 changes: 0 additions & 168 deletions go.work.sum

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions launcher/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ type principalIDTokenFetcher func(audience string) ([][]byte, error)
// struct to make testing easier.
type AttestationAgent interface {
MeasureEvent(cel.Content) error
Attest(context.Context) ([]byte, error)
Attest(context.Context, AttestAgentOpts) ([]byte, error)
}

// AttestAgentOpts contains user generated options when calling the
// VerifyAttestation API
type AttestAgentOpts struct {
Aud string
Nonces []string
}

type agent struct {
Expand Down Expand Up @@ -76,7 +83,7 @@ func (a *agent) MeasureEvent(event cel.Content) error {
// Attest fetches the nonce and connection ID from the Attestation Service,
// creates an attestation message, and returns the resultant
// principalIDTokens and Metadata Server-generated ID tokens for the instance.
func (a *agent) Attest(ctx context.Context) ([]byte, error) {
func (a *agent) Attest(ctx context.Context, opts AttestAgentOpts) ([]byte, error) {
challenge, err := a.client.CreateChallenge(ctx)
if err != nil {
return nil, err
Expand All @@ -96,6 +103,8 @@ func (a *agent) Attest(ctx context.Context) ([]byte, error) {
Challenge: challenge,
GcpCredentials: principalTokens,
Attestation: attestation,
CustomAudience: opts.Aud,
CustomNonce: opts.Nonces,
}

if a.launchSpec.Experiments.EnableSignedContainerImage {
Expand Down
2 changes: 1 addition & 1 deletion launcher/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestAttest(t *testing.T) {

agent := CreateAttestationAgent(tpm, client.AttestationKeyECC, verifierClient, tc.principalIDTokenFetcher, tc.containerSignaturesFetcher, tc.launchSpec, log.Default())

tokenBytes, err := agent.Attest(context.Background())
tokenBytes, err := agent.Attest(context.Background(), AttestAgentOpts{})
if err != nil {
t.Errorf("failed to attest to Attestation Service: %v", err)
}
Expand Down
40 changes: 17 additions & 23 deletions launcher/container_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ import (
"github.com/google/go-tpm-tools/launcher/internal/signaturediscovery"
"github.com/google/go-tpm-tools/launcher/launcherfile"
"github.com/google/go-tpm-tools/launcher/spec"
"github.com/google/go-tpm-tools/launcher/teeserver"
"github.com/google/go-tpm-tools/launcher/verifier"
"github.com/google/go-tpm-tools/launcher/verifier/rest"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/impersonate"
"google.golang.org/api/option"
)

Expand All @@ -53,6 +53,8 @@ type ContainerRunner struct {

const tokenFileTmp = ".token.tmp"

const teeServerSocket = "teeserver.sock"

// Since we only allow one container on a VM, using a deterministic id is probably fine
const (
containerID = "tee-container"
Expand All @@ -74,26 +76,6 @@ const (
defaultRefreshJitter = 0.1
)

func fetchImpersonatedToken(ctx context.Context, serviceAccount string, audience string, opts ...option.ClientOption) ([]byte, error) {
config := impersonate.IDTokenConfig{
Audience: audience,
TargetPrincipal: serviceAccount,
IncludeEmail: true,
}

tokenSource, err := impersonate.IDTokenSource(ctx, config, opts...)
if err != nil {
return nil, fmt.Errorf("error creating token source: %v", err)
}

token, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("error retrieving token: %v", err)
}

return []byte(token.AccessToken), nil
}

// NewRunner returns a runner.
func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.Token, launchSpec spec.LaunchSpec, mdsClient *metadata.Client, tpm io.ReadWriteCloser, logger *log.Logger, serialConsole *os.File) (*ContainerRunner, error) {
image, err := initImage(ctx, cdClient, launchSpec, token)
Expand All @@ -103,6 +85,7 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To

mounts := make([]specs.Mount, 0)
mounts = appendTokenMounts(mounts)

envs, err := formatEnvVars(launchSpec.Envs)
if err != nil {
return nil, err
Expand Down Expand Up @@ -214,7 +197,7 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To

// Fetch impersonated ID tokens.
for _, sa := range launchSpec.ImpersonateServiceAccounts {
idToken, err := fetchImpersonatedToken(ctx, sa, audience)
idToken, err := FetchImpersonatedToken(ctx, sa, audience)
if err != nil {
return nil, fmt.Errorf("failed to get impersonated token for %v: %w", sa, err)
}
Expand Down Expand Up @@ -360,7 +343,8 @@ func (r *ContainerRunner) measureContainerClaims(ctx context.Context) error {
// The token file will be written to a tmp file and then renamed.
func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, error) {
r.logger.Print("refreshing attestation verifier OIDC token")
token, err := r.attestAgent.Attest(ctx)
// request a default token
token, err := r.attestAgent.Attest(ctx, agent.AttestAgentOpts{})
if err != nil {
return 0, fmt.Errorf("failed to retrieve attestation service token: %v", err)
}
Expand Down Expand Up @@ -512,6 +496,16 @@ func (r *ContainerRunner) Run(ctx context.Context) error {
}

r.logger.Printf("EnableTestFeatureForImage is set to %v\n", r.launchSpec.Experiments.EnableTestFeatureForImage)
// create and start the TEE server behind the experiment
if r.launchSpec.Experiments.EnableOnDemandAttestation {
r.logger.Println("EnableOnDemandAttestation is enabled: initializing TEE server.")
teeServer, err := teeserver.New(path.Join(launcherfile.HostTmpPath, teeServerSocket), r.attestAgent, r.logger)
if err != nil {
return fmt.Errorf("failed to create the TEE server: %v", err)
}
go teeServer.Serve()
defer teeServer.Shutdown(ctx)
}

var streamOpt cio.Opt
switch r.launchSpec.LogRedirect {
Expand Down
82 changes: 13 additions & 69 deletions launcher/container_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path"
"strconv"
Expand All @@ -23,10 +20,10 @@ import (
"github.com/containerd/containerd/namespaces"
"github.com/golang-jwt/jwt/v4"
"github.com/google/go-tpm-tools/cel"
"github.com/google/go-tpm-tools/launcher/agent"
"github.com/google/go-tpm-tools/launcher/launcherfile"
"github.com/google/go-tpm-tools/launcher/spec"
"golang.org/x/oauth2"
"google.golang.org/api/option"
)

const (
Expand All @@ -36,7 +33,7 @@ const (
// Fake attestation agent.
type fakeAttestationAgent struct {
measureEventFunc func(cel.Content) error
attestFunc func(context.Context) ([]byte, error)
attestFunc func(context.Context, agent.AttestAgentOpts) ([]byte, error)
}

func (f *fakeAttestationAgent) MeasureEvent(event cel.Content) error {
Expand All @@ -47,9 +44,9 @@ func (f *fakeAttestationAgent) MeasureEvent(event cel.Content) error {
return fmt.Errorf("unimplemented")
}

func (f *fakeAttestationAgent) Attest(ctx context.Context) ([]byte, error) {
func (f *fakeAttestationAgent) Attest(ctx context.Context, _ agent.AttestAgentOpts) ([]byte, error) {
if f.attestFunc != nil {
return f.attestFunc(ctx)
return f.attestFunc(ctx, agent.AttestAgentOpts{})
}

return nil, fmt.Errorf("unimplemented")
Expand Down Expand Up @@ -102,7 +99,7 @@ func TestRefreshToken(t *testing.T) {

runner := ContainerRunner{
attestAgent: &fakeAttestationAgent{
attestFunc: func(context.Context) ([]byte, error) {
attestFunc: func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return expectedToken, nil
},
},
Expand Down Expand Up @@ -146,15 +143,15 @@ func TestRefreshTokenError(t *testing.T) {
{
name: "Attest fails",
agent: &fakeAttestationAgent{
attestFunc: func(context.Context) ([]byte, error) {
attestFunc: func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return nil, errors.New("attest error")
},
},
},
{
name: "Attest returns expired token",
agent: &fakeAttestationAgent{
attestFunc: func(context.Context) ([]byte, error) {
attestFunc: func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return createJWT(t, -5*time.Second), nil
},
},
Expand Down Expand Up @@ -184,7 +181,7 @@ func TestFetchAndWriteTokenSucceeds(t *testing.T) {

runner := ContainerRunner{
attestAgent: &fakeAttestationAgent{
attestFunc: func(context.Context) ([]byte, error) {
attestFunc: func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return expectedToken, nil
},
},
Expand Down Expand Up @@ -212,11 +209,11 @@ func TestTokenIsNotChangedIfRefreshFails(t *testing.T) {

expectedToken := createJWT(t, 5*time.Second)
ttl := 5 * time.Second
successfulAttestFunc := func(context.Context) ([]byte, error) {
successfulAttestFunc := func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return expectedToken, nil
}

errorAttestFunc := func(context.Context) ([]byte, error) {
errorAttestFunc := func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return nil, errors.New("attest unsuccessful")
}

Expand Down Expand Up @@ -289,7 +286,7 @@ func testRetryPolicyWithNTries(t *testing.T, numTries int, expectRefresh bool) {
// Wait the initial token's 5s plus a second per retry (MaxInterval).
ttl := time.Duration(numTries)*time.Second + 5*time.Second
retry := -1
attestFunc := func(context.Context) ([]byte, error) {
attestFunc := func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
retry++
// Success on the initial fetch (subsequent calls use refresher goroutine).
if retry == 0 {
Expand Down Expand Up @@ -350,7 +347,7 @@ func TestFetchAndWriteTokenWithTokenRefresh(t *testing.T) {

runner := ContainerRunner{
attestAgent: &fakeAttestationAgent{
attestFunc: func(context.Context) ([]byte, error) {
attestFunc: func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return expectedToken, nil
},
},
Expand All @@ -374,7 +371,7 @@ func TestFetchAndWriteTokenWithTokenRefresh(t *testing.T) {
// Change attest agent to return new token.
expectedRefreshedToken := createJWT(t, 10*time.Second)
runner.attestAgent = &fakeAttestationAgent{
attestFunc: func(context.Context) ([]byte, error) {
attestFunc: func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return expectedRefreshedToken, nil
},
}
Expand Down Expand Up @@ -402,59 +399,6 @@ func TestFetchAndWriteTokenWithTokenRefresh(t *testing.T) {
}
}

type testRoundTripper struct {
roundTripFunc func(*http.Request) *http.Response
}

func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return t.roundTripFunc(req), nil
}

type idTokenResp struct {
Token string `json:"token"`
}

func TestFetchImpersonatedToken(t *testing.T) {
expectedEmail := "test2@google.com"

expectedToken := []byte("test_token")

expectedURL := fmt.Sprintf(idTokenEndpoint, expectedEmail)
client := &http.Client{
Transport: &testRoundTripper{
roundTripFunc: func(req *http.Request) *http.Response {
if req.URL.String() != expectedURL {
t.Errorf("HTTP call was not made to a endpoint: got %v, want %v", req.URL.String(), expectedURL)
}

resp := idTokenResp{
Token: string(expectedToken),
}

respBody, err := json.Marshal(resp)
if err != nil {
t.Fatalf("Unable to marshal HTTP response: %v", err)
}

return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewBuffer(respBody)),
}
},
},
}

token, err := fetchImpersonatedToken(context.Background(), expectedEmail, "test_aud", option.WithHTTPClient(client))
if err != nil {
t.Fatalf("fetchImpersonatedToken returned error: %v", err)
}

if !bytes.Equal(token, expectedToken) {
t.Errorf("fetchImpersonatedToken did not return expected token: got %v, want %v", token, expectedToken)
}
}

func TestGetNextRefresh(t *testing.T) {
// 0 <= random < 1.
for _, randNum := range []float64{0, .1415926, .5, .75, .999999999} {
Expand Down
1 change: 1 addition & 0 deletions launcher/internal/experiments/experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
type Experiments struct {
EnableTestFeatureForImage bool
EnableSignedContainerImage bool
EnableOnDemandAttestation bool
}

// New takes a filepath, opens the file, and calls ReadJsonInput with the contents
Expand Down