diff --git a/cryptosigner/cryptosigner.go b/cryptosigner/cryptosigner.go index dcdbbd9..ddad5c9 100644 --- a/cryptosigner/cryptosigner.go +++ b/cryptosigner/cryptosigner.go @@ -23,6 +23,7 @@ import ( "crypto" "crypto/ecdsa" "crypto/ed25519" + "crypto/elliptic" "crypto/rand" "crypto/rsa" "encoding/asn1" @@ -51,12 +52,20 @@ func (s *cryptoSigner) Public() *jose.JSONWebKey { } func (s *cryptoSigner) Algs() []jose.SignatureAlgorithm { - switch s.signer.Public().(type) { + switch key := s.signer.Public().(type) { case ed25519.PublicKey: return []jose.SignatureAlgorithm{jose.EdDSA} case *ecdsa.PublicKey: - // This could be more precise - return []jose.SignatureAlgorithm{jose.ES256, jose.ES384, jose.ES512} + switch key.Curve { + case elliptic.P256(): + return []jose.SignatureAlgorithm{jose.ES256} + case elliptic.P384(): + return []jose.SignatureAlgorithm{jose.ES384} + case elliptic.P521(): + return []jose.SignatureAlgorithm{jose.ES512} + default: + return nil + } case *rsa.PublicKey: return []jose.SignatureAlgorithm{jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512} default: diff --git a/cryptosigner/cryptosigner_test.go b/cryptosigner/cryptosigner_test.go index be7c4de..28397f2 100644 --- a/cryptosigner/cryptosigner_test.go +++ b/cryptosigner/cryptosigner_test.go @@ -24,7 +24,10 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/rsa" + "errors" "fmt" + "io" + "reflect" "testing" "github.com/go-jose/go-jose/v3" @@ -135,3 +138,74 @@ func generateSigningTestKey(sigAlg jose.SignatureAlgorithm) (sig, ver interface{ } return } + +type fakeSigner struct{} + +func (fakeSigner) Public() crypto.PublicKey { + return []byte("fake-key") +} + +func (fakeSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + return nil, errors.New("not a signer") +} + +func Test_cryptoSigner_Algs(t *testing.T) { + _, edKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + p224, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + p256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + p384, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + p521, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + type fields struct { + signer crypto.Signer + } + + tests := []struct { + name string + fields fields + want []jose.SignatureAlgorithm + }{ + {"EdDSA", fields{edKey}, []jose.SignatureAlgorithm{jose.EdDSA}}, + {"ES256", fields{p256}, []jose.SignatureAlgorithm{jose.ES256}}, + {"ES384", fields{p384}, []jose.SignatureAlgorithm{jose.ES384}}, + {"ES512", fields{p521}, []jose.SignatureAlgorithm{jose.ES512}}, + {"RSA", fields{rsaKey}, []jose.SignatureAlgorithm{jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512}}, + {"fail P-224", fields{p224}, nil}, + {"fail other", fields{fakeSigner{}}, nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs := &cryptoSigner{ + signer: tt.fields.signer, + } + if got := cs.Algs(); !reflect.DeepEqual(tt.want, got) { + t.Errorf("cryptoSigner.Algs() got = %v, want %v", got, tt.want) + } + }) + } +}