From c9af1f4baf48a85f5708894afa80210397dcfc70 Mon Sep 17 00:00:00 2001 From: zachmann Date: Tue, 5 Aug 2025 11:24:17 +0200 Subject: [PATCH 1/4] [utils] refactor ReflectSliceCast; extract type conversion logic --- internal/utils/slices.go | 101 +++++++++++++++++++++++---------------- 1 file changed, 60 insertions(+), 41 deletions(-) diff --git a/internal/utils/slices.go b/internal/utils/slices.go index 4a92442..6582600 100644 --- a/internal/utils/slices.go +++ b/internal/utils/slices.go @@ -4,59 +4,78 @@ import ( "reflect" ) -// ReflectSliceCast casts a slice to another type using reflection +// ReflectSliceCast converts a slice to another type using reflection. +// Parameters: +// - slice: source slice to convert +// - newType: target type for the slice elements +// +// Retur +// Returns: +// - converted slice or original value if input is not a slice func ReflectSliceCast(slice, newType any) any { if !IsSlice(slice) { return slice } + typeType := reflect.TypeOf(newType) sliceV := reflect.ValueOf(slice) out := reflect.MakeSlice(typeType, sliceV.Len(), sliceV.Len()) + for i := 0; i < sliceV.Len(); i++ { - vv := sliceV.Index(i) - var v reflect.Value - // This is stupid and has faults, but I did not find a better way - switch typeType.Elem().Kind() { - case reflect.Bool: - v = reflect.ValueOf(vv.Interface().(bool)) - case reflect.Int: - v = reflect.ValueOf(vv.Interface().(int)) - case reflect.Int8: - v = reflect.ValueOf(vv.Interface().(int8)) - case reflect.Int16: - v = reflect.ValueOf(vv.Interface().(int16)) - case reflect.Int32: - v = reflect.ValueOf(vv.Interface().(int32)) - case reflect.Int64: - v = reflect.ValueOf(vv.Interface().(int64)) - case reflect.Uint: - v = reflect.ValueOf(vv.Interface().(uint)) - case reflect.Uint8: - v = reflect.ValueOf(vv.Interface().(uint8)) - case reflect.Uint16: - v = reflect.ValueOf(vv.Interface().(uint16)) - case reflect.Uint32: - v = reflect.ValueOf(vv.Interface().(uint32)) - case reflect.Uint64: - v = reflect.ValueOf(vv.Interface().(uint64)) - case reflect.Uintptr: - v = reflect.ValueOf(vv.Interface().(*uint)) - case reflect.Float32: - v = reflect.ValueOf(vv.Interface().(float32)) - case reflect.Float64: - v = reflect.ValueOf(vv.Interface().(float64)) - case reflect.Interface: - v = vv - case reflect.String: - v = reflect.ValueOf(vv.Interface().(string)) - default: - v = vv.Convert(typeType.Elem()) - } - out.Index(i).Set(v) + sourceVal := sliceV.Index(i) + convertedVal := convertToTargetType(sourceVal, typeType.Elem()) + out.Index(i).Set(convertedVal) } + return out.Interface() } +// convertToTargetType converts a reflect.Value to the target type. +// It handles primitive types explicitly and falls back to generic conversion for other types. +func convertToTargetType(val reflect.Value, targetType reflect.Type) reflect.Value { + if targetType.Kind() == reflect.Interface { + return val + } + + // Get the underlying interface value + srcInterface := val.Interface() + + // Handle primitive types + switch targetType.Kind() { + case reflect.Bool: + return reflect.ValueOf(srcInterface.(bool)) + case reflect.Int: + return reflect.ValueOf(srcInterface.(int)) + case reflect.Int8: + return reflect.ValueOf(srcInterface.(int8)) + case reflect.Int16: + return reflect.ValueOf(srcInterface.(int16)) + case reflect.Int32: + return reflect.ValueOf(srcInterface.(int32)) + case reflect.Int64: + return reflect.ValueOf(srcInterface.(int64)) + case reflect.Uint: + return reflect.ValueOf(srcInterface.(uint)) + case reflect.Uint8: + return reflect.ValueOf(srcInterface.(uint8)) + case reflect.Uint16: + return reflect.ValueOf(srcInterface.(uint16)) + case reflect.Uint32: + return reflect.ValueOf(srcInterface.(uint32)) + case reflect.Uint64: + return reflect.ValueOf(srcInterface.(uint64)) + case reflect.Float32: + return reflect.ValueOf(srcInterface.(float32)) + case reflect.Float64: + return reflect.ValueOf(srcInterface.(float64)) + case reflect.String: + return reflect.ValueOf(srcInterface.(string)) + default: + // For other types, try to convert using reflection + return val.Convert(targetType) + } +} + // ReflectSliceContains checks if a slice contains a value using reflection func ReflectSliceContains(v, slice any) bool { if !IsSlice(slice) { From a2822d9cc620661f99d441709c1643ad70876500 Mon Sep 17 00:00:00 2001 From: zachmann Date: Tue, 5 Aug 2025 12:00:08 +0200 Subject: [PATCH 2/4] [trust resolver] refactor authority resolution; extract helper methods --- trustresolver.go | 149 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 106 insertions(+), 43 deletions(-) diff --git a/trustresolver.go b/trustresolver.go index 4e2bb13..5b938ef 100644 --- a/trustresolver.go +++ b/trustresolver.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/pkg/errors" "github.com/scylladb/go-set/strset" "github.com/vmihailenco/msgpack/v5" "github.com/zachmann/go-utils/sliceutils" @@ -317,60 +318,122 @@ func (t *trustTree) resolve(anchors TrustAnchors) { if t.Entity == nil { return } - if t.Entity.ExpiresAt.Before(t.expiresAt.Time) { - t.expiresAt = t.Entity.ExpiresAt - } + + t.updateExpirationTime() + + // Early return if entity is issued by a trust anchor if sliceutils.SliceContains(t.Entity.Issuer, anchors.EntityIDs()) { return } + + t.resolveAuthorities(anchors) +} + +func (t *trustTree) updateExpirationTime() { + if t.Entity.ExpiresAt.Before(t.expiresAt.Time) { + t.expiresAt = t.Entity.ExpiresAt + } +} + +func (t *trustTree) resolveAuthorities(anchors TrustAnchors) { if len(t.Entity.AuthorityHints) > 0 { t.Authorities = make([]trustTree, len(t.Entity.AuthorityHints)) } - for i, aID := range t.Entity.AuthorityHints { - if t.subordinateIDs.Has(aID) { - // loop prevention - continue - } - aStmt, err := GetEntityConfiguration(aID) - if err != nil { - continue - } - if !utils.Equal(aStmt.Issuer, aStmt.Subject, aID) || !aStmt.TimeValid() { - continue - } - if aStmt.Metadata == nil || aStmt.Metadata.FederationEntity == nil || aStmt.Metadata.FederationEntity. - FederationFetchEndpoint == "" { - continue + + for i, authorityID := range t.Entity.AuthorityHints { + if t.subordinateIDs.Has(authorityID) { + continue // Loop prevention } - subordinateStmt, err := FetchEntityStatement( - aStmt.Metadata.FederationEntity.FederationFetchEndpoint, t.Entity.Issuer, aID, - ) + + authority, err := t.resolveAuthority(authorityID, anchors) if err != nil { continue } - if subordinateStmt.Issuer != aID || subordinateStmt.Subject != t.Entity.Issuer || !subordinateStmt.TimeValid() { - continue - } - if !t.checkConstraints(subordinateStmt.Constraints) { - continue - } - if subordinateStmt.ExpiresAt.Before(t.expiresAt.Time) { - t.expiresAt = subordinateStmt.ExpiresAt - } - entityTypes := t.includedEntityTypes.Copy() - entityTypes.Add(aStmt.Metadata.GuessEntityTypes()...) - subordinates := t.subordinateIDs.Copy() - subordinates.Add(aID) - tt := trustTree{ - Entity: aStmt, - Subordinate: subordinateStmt, - depth: t.depth + 1, - includedEntityTypes: entityTypes, - subordinateIDs: subordinates, - } - tt.resolve(anchors) - t.Authorities[i] = tt + + t.Authorities[i] = authority + } +} + +func (t *trustTree) resolveAuthority(authorityID string, anchors TrustAnchors) (trustTree, error) { + authorityStmt, err := GetEntityConfiguration(authorityID) + if err != nil { + return trustTree{}, err + } + + if !isValidAuthorityStatement(authorityStmt, authorityID) { + return trustTree{}, errors.New("invalid authority statement") } + + subordinateStmt, err := t.fetchAndValidateSubordinateStatement(authorityStmt, authorityID) + if err != nil { + return trustTree{}, err + } + + if !t.checkConstraints(subordinateStmt.Constraints) { + return trustTree{}, errors.New("constraints check failed") + } + + t.updateExpirationTimeFromSubordinate(subordinateStmt) + + return t.createAuthorityTrustTree(authorityStmt, subordinateStmt, authorityID, anchors), nil +} + +func isValidAuthorityStatement(stmt *EntityStatement, authorityID string) bool { + return utils.Equal(stmt.Issuer, stmt.Subject, authorityID) && + stmt.TimeValid() && + stmt.Metadata != nil && + stmt.Metadata.FederationEntity != nil && + stmt.Metadata.FederationEntity.FederationFetchEndpoint != "" +} + +func (t *trustTree) fetchAndValidateSubordinateStatement( + authorityStmt *EntityStatement, authorityID string, +) (*EntityStatement, error) { + subordinateStmt, err := FetchEntityStatement( + authorityStmt.Metadata.FederationEntity.FederationFetchEndpoint, t.Entity.Issuer, authorityID, + ) + if err != nil { + return nil, err + } + + if !isValidSubordinateStatement(subordinateStmt, authorityID, t.Entity.Issuer) { + return nil, errors.New("invalid subordinate statement") + } + + return subordinateStmt, nil +} + +func isValidSubordinateStatement(stmt *EntityStatement, authorityID, entityIssuer string) bool { + return stmt.Issuer == authorityID && + stmt.Subject == entityIssuer && + stmt.TimeValid() +} + +func (t *trustTree) updateExpirationTimeFromSubordinate(subordinateStmt *EntityStatement) { + if subordinateStmt.ExpiresAt.Before(t.expiresAt.Time) { + t.expiresAt = subordinateStmt.ExpiresAt + } +} + +func (t *trustTree) createAuthorityTrustTree( + authorityStmt, subordinateStmt *EntityStatement, authorityID string, anchors TrustAnchors, +) trustTree { + entityTypes := t.includedEntityTypes.Copy() + entityTypes.Add(authorityStmt.Metadata.GuessEntityTypes()...) + + subordinates := t.subordinateIDs.Copy() + subordinates.Add(authorityID) + + newTree := trustTree{ + Entity: authorityStmt, + Subordinate: subordinateStmt, + depth: t.depth + 1, + includedEntityTypes: entityTypes, + subordinateIDs: subordinates, + } + newTree.resolve(anchors) + + return newTree } func (t *trustTree) checkConstraints(constraints *ConstraintSpecification) bool { From 62618c3f8f5a7c2b68840eda608ce84182accf24 Mon Sep 17 00:00:00 2001 From: zachmann Date: Tue, 5 Aug 2025 12:07:36 +0200 Subject: [PATCH 3/4] [deepsource] add test_patterns configuration for Go files --- .deepsource.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.deepsource.toml b/.deepsource.toml index d6a9253..b65b585 100644 --- a/.deepsource.toml +++ b/.deepsource.toml @@ -1,5 +1,9 @@ version = 1 +test_patterns = [ + "*_test.go" +] + [[analyzers]] name = "go" From c244ee2ef899c17d8a398355981ad86f9dc0593a Mon Sep 17 00:00:00 2001 From: zachmann Date: Tue, 5 Aug 2025 15:24:53 +0200 Subject: [PATCH 4/4] [privateKeyStorage] refactor Load method; extract signer loading logic to reusable function --- jwx/privateKeyStorageMultiAlg.go | 91 ++++++++++++++++++-------------- 1 file changed, 52 insertions(+), 39 deletions(-) diff --git a/jwx/privateKeyStorageMultiAlg.go b/jwx/privateKeyStorageMultiAlg.go index b7c22ce..57ac842 100644 --- a/jwx/privateKeyStorageMultiAlg.go +++ b/jwx/privateKeyStorageMultiAlg.go @@ -76,56 +76,25 @@ func (sks *privateKeyStorageMultiAlg) initKeyRotation(pks *pkCollection, pksOnCh // Load loads the private keys from disk and if necessary generates missing keys func (sks *privateKeyStorageMultiAlg) Load(pks *pkCollection, pksOnChange func() error) error { - populatePKFromSK := false + addPublicKeysToJWKS := false if sks.signers == nil { sks.signers = make(map[jwa.SignatureAlgorithm]crypto.Signer) } if len(pks.jwks) == 0 { pks.jwks = []JWKS{NewJWKS()} - populatePKFromSK = true + addPublicKeysToJWKS = true } pksChanged := false - // load oidc keys + for _, alg := range sks.algs { - filePath := sks.keyFilePath(alg, false) - signer, err := readSignerFromFile(filePath, alg) + signer, changed, err := sks.loadOrGenerateSigner(alg, pks, addPublicKeysToJWKS) if err != nil { - // could not load key, generating a new one for this alg - sk, pk, err := generateKeyPair( - alg, sks.rsaKeyLen, keyLifetimeConf{ - NowIssued: true, - Expires: sks.rollover.Enabled, - Lifetime: sks.rollover.Interval.Duration(), - }, - ) - if err != nil { - return err - } - if err = writeSignerToFile(sk, sks.keyFilePath(alg, false)); err != nil { - return err - } - if err = pks.jwks[0].AddKey(pk); err != nil { - return errors.WithStack(err) - } - pksChanged = true - signer = sk - } else if populatePKFromSK { - pk, err := signerToPublicJWK( - signer, alg, keyLifetimeConf{ - NowIssued: false, - Expires: sks.rollover.Enabled, - Lifetime: sks.rollover.Interval.Duration(), - }, - ) - if err != nil { - return err - } - if err = pks.jwks[0].AddKey(pk); err != nil { - return errors.WithStack(err) - } + return err } + pksChanged = pksChanged || changed sks.signers[alg] = signer + // Ensure the next key file exists for rollover if !fileutils.FileExists(sks.keyFilePath(alg, true)) { _, err = generateStoreAndSetNextPrivateKey( pks, alg, sks.rsaKeyLen, keyLifetimeConf{ @@ -140,7 +109,8 @@ func (sks *privateKeyStorageMultiAlg) Load(pks *pkCollection, pksOnChange func() } } } - if populatePKFromSK || pksChanged { + + if addPublicKeysToJWKS || pksChanged { if err := pksOnChange(); err != nil { return err } @@ -149,6 +119,49 @@ func (sks *privateKeyStorageMultiAlg) Load(pks *pkCollection, pksOnChange func() return nil } +// loadOrGenerateSigner loads a signer from disk or generates a new one if it doesn't exist. +// If addPublicKeysToJWKS is true, it also adds the public key to the pkCollection. +func (sks *privateKeyStorageMultiAlg) loadOrGenerateSigner( + alg jwa.SignatureAlgorithm, pks *pkCollection, addPublicKeysToJWKS bool, +) (crypto.Signer, bool, error) { + filePath := sks.keyFilePath(alg, false) + signer, err := readSignerFromFile(filePath, alg) + if err != nil { + // Could not load key, generating a new one for this alg + sk, pk, err := generateKeyPair( + alg, + sks.rsaKeyLen, + keyLifetimeConf{ + NowIssued: true, + Expires: sks.rollover.Enabled, + Lifetime: sks.rollover.Interval.Duration(), + }, + ) + if err != nil { + return nil, false, err + } + if err = writeSignerToFile(sk, filePath); err != nil { + return nil, false, err + } + pks.addCurrentJWK(pk) + return sk, true, nil + } + if addPublicKeysToJWKS { + pk, err := signerToPublicJWK( + signer, alg, keyLifetimeConf{ + NowIssued: false, + Expires: sks.rollover.Enabled, + Lifetime: sks.rollover.Interval.Duration(), + }, + ) + if err != nil { + return nil, false, err + } + pks.addCurrentJWK(pk) + } + return signer, addPublicKeysToJWKS, nil +} + // GenerateNewKeys generates a new set of keys func (sks *privateKeyStorageMultiAlg) GenerateNewKeys(pks *pkCollection, pksOnChange func() error) error { futureKeys := NewJWKS()