Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .deepsource.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
version = 1

test_patterns = [
"*_test.go"
]

[[analyzers]]
name = "go"

Expand Down
101 changes: 60 additions & 41 deletions internal/utils/slices.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,59 +4,78 @@ import (
"reflect"
)

// ReflectSliceCast casts a slice to another type using reflection
// ReflectSliceCast converts a slice to another type using reflection.
// Parameters:
// - slice: source slice to convert
// - newType: target type for the slice elements
//
// Retur
// Returns:
// - converted slice or original value if input is not a slice
func ReflectSliceCast(slice, newType any) any {
if !IsSlice(slice) {
return slice
}

typeType := reflect.TypeOf(newType)
sliceV := reflect.ValueOf(slice)
out := reflect.MakeSlice(typeType, sliceV.Len(), sliceV.Len())

for i := 0; i < sliceV.Len(); i++ {
vv := sliceV.Index(i)
var v reflect.Value
// This is stupid and has faults, but I did not find a better way
switch typeType.Elem().Kind() {
case reflect.Bool:
v = reflect.ValueOf(vv.Interface().(bool))
case reflect.Int:
v = reflect.ValueOf(vv.Interface().(int))
case reflect.Int8:
v = reflect.ValueOf(vv.Interface().(int8))
case reflect.Int16:
v = reflect.ValueOf(vv.Interface().(int16))
case reflect.Int32:
v = reflect.ValueOf(vv.Interface().(int32))
case reflect.Int64:
v = reflect.ValueOf(vv.Interface().(int64))
case reflect.Uint:
v = reflect.ValueOf(vv.Interface().(uint))
case reflect.Uint8:
v = reflect.ValueOf(vv.Interface().(uint8))
case reflect.Uint16:
v = reflect.ValueOf(vv.Interface().(uint16))
case reflect.Uint32:
v = reflect.ValueOf(vv.Interface().(uint32))
case reflect.Uint64:
v = reflect.ValueOf(vv.Interface().(uint64))
case reflect.Uintptr:
v = reflect.ValueOf(vv.Interface().(*uint))
case reflect.Float32:
v = reflect.ValueOf(vv.Interface().(float32))
case reflect.Float64:
v = reflect.ValueOf(vv.Interface().(float64))
case reflect.Interface:
v = vv
case reflect.String:
v = reflect.ValueOf(vv.Interface().(string))
default:
v = vv.Convert(typeType.Elem())
}
out.Index(i).Set(v)
sourceVal := sliceV.Index(i)
convertedVal := convertToTargetType(sourceVal, typeType.Elem())
out.Index(i).Set(convertedVal)
}

return out.Interface()
}

// convertToTargetType converts a reflect.Value to the target type.
// It handles primitive types explicitly and falls back to generic conversion for other types.
func convertToTargetType(val reflect.Value, targetType reflect.Type) reflect.Value {
if targetType.Kind() == reflect.Interface {
return val
}

// Get the underlying interface value
srcInterface := val.Interface()

// Handle primitive types
switch targetType.Kind() {
case reflect.Bool:
return reflect.ValueOf(srcInterface.(bool))
case reflect.Int:
return reflect.ValueOf(srcInterface.(int))
case reflect.Int8:
return reflect.ValueOf(srcInterface.(int8))
case reflect.Int16:
return reflect.ValueOf(srcInterface.(int16))
case reflect.Int32:
return reflect.ValueOf(srcInterface.(int32))
case reflect.Int64:
return reflect.ValueOf(srcInterface.(int64))
case reflect.Uint:
return reflect.ValueOf(srcInterface.(uint))
case reflect.Uint8:
return reflect.ValueOf(srcInterface.(uint8))
case reflect.Uint16:
return reflect.ValueOf(srcInterface.(uint16))
case reflect.Uint32:
return reflect.ValueOf(srcInterface.(uint32))
case reflect.Uint64:
return reflect.ValueOf(srcInterface.(uint64))
case reflect.Float32:
return reflect.ValueOf(srcInterface.(float32))
case reflect.Float64:
return reflect.ValueOf(srcInterface.(float64))
case reflect.String:
return reflect.ValueOf(srcInterface.(string))
default:
// For other types, try to convert using reflection
return val.Convert(targetType)
}
}

// ReflectSliceContains checks if a slice contains a value using reflection
func ReflectSliceContains(v, slice any) bool {
if !IsSlice(slice) {
Expand Down
91 changes: 52 additions & 39 deletions jwx/privateKeyStorageMultiAlg.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,56 +76,25 @@ func (sks *privateKeyStorageMultiAlg) initKeyRotation(pks *pkCollection, pksOnCh

// Load loads the private keys from disk and if necessary generates missing keys
func (sks *privateKeyStorageMultiAlg) Load(pks *pkCollection, pksOnChange func() error) error {
populatePKFromSK := false
addPublicKeysToJWKS := false
if sks.signers == nil {
sks.signers = make(map[jwa.SignatureAlgorithm]crypto.Signer)
}
if len(pks.jwks) == 0 {
pks.jwks = []JWKS{NewJWKS()}
populatePKFromSK = true
addPublicKeysToJWKS = true
}
pksChanged := false
// load oidc keys

for _, alg := range sks.algs {
filePath := sks.keyFilePath(alg, false)
signer, err := readSignerFromFile(filePath, alg)
signer, changed, err := sks.loadOrGenerateSigner(alg, pks, addPublicKeysToJWKS)
if err != nil {
// could not load key, generating a new one for this alg
sk, pk, err := generateKeyPair(
alg, sks.rsaKeyLen, keyLifetimeConf{
NowIssued: true,
Expires: sks.rollover.Enabled,
Lifetime: sks.rollover.Interval.Duration(),
},
)
if err != nil {
return err
}
if err = writeSignerToFile(sk, sks.keyFilePath(alg, false)); err != nil {
return err
}
if err = pks.jwks[0].AddKey(pk); err != nil {
return errors.WithStack(err)
}
pksChanged = true
signer = sk
} else if populatePKFromSK {
pk, err := signerToPublicJWK(
signer, alg, keyLifetimeConf{
NowIssued: false,
Expires: sks.rollover.Enabled,
Lifetime: sks.rollover.Interval.Duration(),
},
)
if err != nil {
return err
}
if err = pks.jwks[0].AddKey(pk); err != nil {
return errors.WithStack(err)
}
return err
}
pksChanged = pksChanged || changed
sks.signers[alg] = signer

// Ensure the next key file exists for rollover
if !fileutils.FileExists(sks.keyFilePath(alg, true)) {
_, err = generateStoreAndSetNextPrivateKey(
pks, alg, sks.rsaKeyLen, keyLifetimeConf{
Expand All @@ -140,7 +109,8 @@ func (sks *privateKeyStorageMultiAlg) Load(pks *pkCollection, pksOnChange func()
}
}
}
if populatePKFromSK || pksChanged {

if addPublicKeysToJWKS || pksChanged {
if err := pksOnChange(); err != nil {
return err
}
Expand All @@ -149,6 +119,49 @@ func (sks *privateKeyStorageMultiAlg) Load(pks *pkCollection, pksOnChange func()
return nil
}

// loadOrGenerateSigner loads a signer from disk or generates a new one if it doesn't exist.
// If addPublicKeysToJWKS is true, it also adds the public key to the pkCollection.
func (sks *privateKeyStorageMultiAlg) loadOrGenerateSigner(
alg jwa.SignatureAlgorithm, pks *pkCollection, addPublicKeysToJWKS bool,
) (crypto.Signer, bool, error) {
filePath := sks.keyFilePath(alg, false)
signer, err := readSignerFromFile(filePath, alg)
if err != nil {
// Could not load key, generating a new one for this alg
sk, pk, err := generateKeyPair(
alg,
sks.rsaKeyLen,
keyLifetimeConf{
NowIssued: true,
Expires: sks.rollover.Enabled,
Lifetime: sks.rollover.Interval.Duration(),
},
)
if err != nil {
return nil, false, err
}
if err = writeSignerToFile(sk, filePath); err != nil {
return nil, false, err
}
pks.addCurrentJWK(pk)
return sk, true, nil
}
if addPublicKeysToJWKS {
pk, err := signerToPublicJWK(
signer, alg, keyLifetimeConf{
NowIssued: false,
Expires: sks.rollover.Enabled,
Lifetime: sks.rollover.Interval.Duration(),
},
)
if err != nil {
return nil, false, err
}
pks.addCurrentJWK(pk)
}
return signer, addPublicKeysToJWKS, nil
}

// GenerateNewKeys generates a new set of keys
func (sks *privateKeyStorageMultiAlg) GenerateNewKeys(pks *pkCollection, pksOnChange func() error) error {
futureKeys := NewJWKS()
Expand Down
Loading