Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions cmd/nginx-ingress/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/marketplacemetering"
"github.com/aws/aws-sdk-go-v2/service/marketplacemetering/types"

"github.com/golang-jwt/jwt/v4"
"github.com/golang-jwt/jwt/v5"
)

var (
Expand All @@ -24,6 +24,12 @@ var (
pubKeyString string
)

var (
ErrMissingProductCode = errors.New("token doesn't include the ProductCode")
ErrMissingNonce = errors.New("token doesn't include the Nonce")
ErrMissingKeyVersion = errors.New("token doesn't include the PublicKeyVersion")
)

func init() {
startupCheckFn = checkAWSEntitlement
}
Expand Down Expand Up @@ -95,21 +101,18 @@ type claims struct {
jwt.RegisteredClaims
}

func (c claims) Valid() error {
var _ jwt.ClaimsValidator = (*claims)(nil)

func (c claims) Validate() error {
if c.Nonce == "" {
return jwt.NewValidationError("token doesn't include the Nonce", jwt.ValidationErrorClaimsInvalid)
return ErrMissingNonce
}
if c.ProductCode == "" {
return jwt.NewValidationError("token doesn't include the ProductCode", jwt.ValidationErrorClaimsInvalid)
return ErrMissingProductCode
}
if c.PublicKeyVersion == 0 {
return jwt.NewValidationError("token doesn't include the PublicKeyVersion", jwt.ValidationErrorClaimsInvalid)
return ErrMissingKeyVersion
}

if err := c.RegisteredClaims.Valid(); err != nil {
return err
}

return nil
}

Expand Down
101 changes: 52 additions & 49 deletions cmd/nginx-ingress/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"testing"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/golang-jwt/jwt/v5"
)

func TestValidClaims(t *testing.T) {
Expand All @@ -21,69 +21,72 @@ func TestValidClaims(t *testing.T) {
IssuedAt: &iat,
},
}
if err := c.Valid(); err != nil {
v := jwt.NewValidator(
jwt.WithIssuedAt(),
)
if err := v.Validate(c); err != nil {
t.Fatalf("Failed to verify claims, wanted: %v got %v", nil, err)
}
}

func TestInvalidClaims(t *testing.T) {
badClaims := []struct {
c claims
expectedError error
type fields struct {
leeway time.Duration
timeFunc func() time.Time
expectedAud string
expectAllAud []string
expectedIss string
expectedSub string
}
type args struct {
claims jwt.Claims
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
claims{
"",
1,
"nonce",
jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour * -1)),
},
},
errors.New("token doesn't include the ProductCode"),
name: "missing ProductCode",
fields: fields{},
args: args{jwt.RegisteredClaims{}},
wantErr: ErrMissingProductCode,
},
{
claims{
"productCode",
1,
"",
jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour * -1)),
},
},
errors.New("token doesn't include the Nonce"),
name: "missing Nonce",
fields: fields{},
args: args{jwt.RegisteredClaims{}},
wantErr: ErrMissingNonce,
},
{
claims{
"productCode",
0,
"nonce",
jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour * -1)),
},
},
errors.New("token doesn't include the PublicKeyVersion"),
name: "missing PublicKeyVersion",
fields: fields{},
args: args{jwt.RegisteredClaims{}},
wantErr: ErrMissingKeyVersion,
},
{
claims{
"test",
1,
"nonce",
jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour * +2)),
},
},
errors.New("token used before issued"),
name: "iat is in the future",
fields: fields{},
args: args{jwt.RegisteredClaims{IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour * +2))}},
wantErr: jwt.ErrTokenUsedBeforeIssued,
},
}

for _, badC := range badClaims {

err := badC.c.Valid()
if err == nil {
t.Errorf("Valid() returned no error when it should have returned error %q", badC.expectedError)
} else if err.Error() != badC.expectedError.Error() {
t.Errorf("Valid() returned error %q when it should have returned error %q", err, badC.expectedError)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := jwt.NewValidator(
jwt.WithLeeway(tt.fields.leeway),
jwt.WithTimeFunc(tt.fields.timeFunc),
jwt.WithIssuedAt(),
jwt.WithAudience(tt.fields.expectedAud),
jwt.WithAllAudiences(tt.fields.expectAllAud...),
jwt.WithIssuer(tt.fields.expectedIss),
jwt.WithSubject(tt.fields.expectedSub),
)
if err := v.Validate(tt.args.claims); (err != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("validator.Validate() error = %v, wantErr = %v", err, tt.wantErr)
}
})
}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/cert-manager/cert-manager v1.18.2
github.com/dlclark/regexp2 v1.11.5
github.com/gkampitakis/go-snaps v0.5.15
github.com/golang-jwt/jwt/v4 v4.5.2
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/go-cmp v0.7.0
github.com/gruntwork-io/terratest v0.50.0
github.com/jinzhu/copier v0.4.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/gonvenience/bunt v1.3.5 h1:wSQquifvwEWtzn27k1ngLfeLaStyt0k1b/K6TrlCNAs=
Expand Down