Skip to content

Commit

Permalink
improve builder
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg committed Nov 5, 2019
1 parent 35633bc commit f864f62
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 49 deletions.
117 changes: 68 additions & 49 deletions build.go
@@ -1,7 +1,6 @@
package jwt

import (
"encoding"
"encoding/base64"
"encoding/json"
)
Expand All @@ -19,16 +18,26 @@ type TokenBuilder struct {
header Header
}

// BinaryMarshaler a marshaling interface for user claims.
type BinaryMarshaler interface {
MarshalBinary() (data []byte, err error)
}

// BuildBytes is used to create and encode JWT with a provided claims.
func BuildBytes(signer Signer, claims BinaryMarshaler) ([]byte, error) {
return NewTokenBuilder(signer).BuildBytes(claims)
}

// Build is used to create and encode JWT with a provided claims.
func Build(signer Signer, claims encoding.BinaryMarshaler) (*Token, error) {
func Build(signer Signer, claims BinaryMarshaler) (*Token, error) {
return NewTokenBuilder(signer).Build(claims)
}

// BuildWithHeader is used to create and encode JWT with a provided claims.
func BuildWithHeader(signer Signer, header *Header, claims encoding.BinaryMarshaler) (*Token, error) {
func BuildWithHeader(signer Signer, header Header, claims BinaryMarshaler) (*Token, error) {
b := &TokenBuilder{
signer: signer,
header: *header,
header: header,
}
return b.Build(claims)
}
Expand All @@ -46,18 +55,26 @@ func NewTokenBuilder(signer Signer) *TokenBuilder {
return b
}

// Build used to create and encode JWT with a provided claims.
func (b *TokenBuilder) Build(claims encoding.BinaryMarshaler) (*Token, error) {
encodedHeader := b.encodeHeader()
// BuildBytes used to create and encode JWT with a provided claims.
func (b *TokenBuilder) BuildBytes(claims BinaryMarshaler) ([]byte, error) {
token, err := b.Build(claims)
if err != nil {
return nil, err
}
return token.Raw(), nil
}

// Build used to create and encode JWT with a provided claims.
func (b *TokenBuilder) Build(claims BinaryMarshaler) (*Token, error) {
rawClaims, encodedClaims, err := encodeClaims(claims)
if err != nil {
return nil, err
}

encodedHeader := encodeHeader(&b.header)
payload := concatParts(encodedHeader, encodedClaims)

signed, signature, err := b.signPayload(payload)
signed, signature, err := signPayload(b.signer, payload)
if err != nil {
return nil, err
}
Expand All @@ -72,55 +89,57 @@ func (b *TokenBuilder) Build(claims encoding.BinaryMarshaler) (*Token, error) {
return token, nil
}

func (b *TokenBuilder) encodeHeader() []byte {
switch b.signer.Algorithm() {
case NoEncryption:
return []byte("eyJhbGciOiJub25lIn0")
case EdDSA:
return []byte("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9")

case HS256:
return []byte("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9")
case HS384:
return []byte("eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9")
case HS512:
return []byte("eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9")

case RS256:
return []byte("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9")
case RS384:
return []byte("eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9")
case RS512:
return []byte("eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9")

case ES256:
return []byte("eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9")
case ES384:
return []byte("eyJhbGciOiJFUzM4NCIsInR5cCI6IkpXVCJ9")
case ES512:
return []byte("eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCJ9")

case PS256:
return []byte("eyJhbGciOiJQUzI1NiIsInR5cCI6IkpXVCJ9")
case PS384:
return []byte("eyJhbGciOiJQUzM4NCIsInR5cCI6IkpXVCJ9")
case PS512:
return []byte("eyJhbGciOiJQUzUxMiIsInR5cCI6IkpXVCJ9")

default:
// another algorithm? encode below
func encodeHeader(header *Header) []byte {
if header.Type == "JWT" {
switch header.Algorithm {
case NoEncryption:
return []byte("eyJhbGciOiJub25lIn0")
case EdDSA:
return []byte("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9")

case HS256:
return []byte("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9")
case HS384:
return []byte("eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9")
case HS512:
return []byte("eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9")

case RS256:
return []byte("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9")
case RS384:
return []byte("eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9")
case RS512:
return []byte("eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9")

case ES256:
return []byte("eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9")
case ES384:
return []byte("eyJhbGciOiJFUzM4NCIsInR5cCI6IkpXVCJ9")
case ES512:
return []byte("eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCJ9")

case PS256:
return []byte("eyJhbGciOiJQUzI1NiIsInR5cCI6IkpXVCJ9")
case PS384:
return []byte("eyJhbGciOiJQUzM4NCIsInR5cCI6IkpXVCJ9")
case PS512:
return []byte("eyJhbGciOiJQUzUxMiIsInR5cCI6IkpXVCJ9")

default:
// another algorithm? encode below
}
}

// returned err is always nil, see *Header.MarshalJSON
buf, _ := json.Marshal(b.header)
buf, _ := json.Marshal(header)

encoded := make([]byte, base64EncodedLen(len(buf)))
base64Encode(encoded, buf)

return encoded
}

func encodeClaims(claims encoding.BinaryMarshaler) (raw, encoded []byte, err error) {
func encodeClaims(claims BinaryMarshaler) (raw, encoded []byte, err error) {
raw, err = claims.MarshalBinary()
if err != nil {
return nil, nil, err
Expand All @@ -132,8 +151,8 @@ func encodeClaims(claims encoding.BinaryMarshaler) (raw, encoded []byte, err err
return raw, encoded, nil
}

func (b *TokenBuilder) signPayload(payload []byte) (signed, signature []byte, err error) {
signature, err = b.signer.Sign(payload)
func signPayload(signer Signer, payload []byte) (signed, signature []byte, err error) {
signature, err = signer.Sign(payload)
if err != nil {
return nil, nil, err
}
Expand Down
68 changes: 68 additions & 0 deletions build_test.go
@@ -0,0 +1,68 @@
package jwt

import (
"encoding/base64"
"fmt"
"testing"
)

func TestBuild(t *testing.T) {
signer := NewHS256([]byte(`secret`))
builder := NewTokenBuilder(signer)

claims := &StandardClaims{
Audience: []string{"admin"},
ID: "random-unique-string",
}
token, _ := builder.Build(claims)

fmt.Printf("Algorithm %v\n", token.Header().Algorithm)
fmt.Printf("Type %v\n", token.Header().Type)
fmt.Printf("Claims %v\n", string(token.RawClaims()))
fmt.Printf("Payload %v\n", string(token.Payload()))
fmt.Printf("Token %v\n", string(token.Raw()))
}

func TestBuildWithHeader(t *testing.T) {
f := func(signer Signer, header Header, want string) {
t.Helper()

token, err := BuildWithHeader(signer, header, &StandardClaims{})
if err != nil {
t.Error(err)
}

want = toBase64(want)
raw := string(token.RawHeader())
if raw != want {
t.Errorf("want %v, got %v", want, raw)
}
}

f(
NewHS256(nil),
Header{Algorithm: HS256, Type: "JWT"},
`{"alg":"HS256","typ":"JWT"}`,
)
f(
NewHS512(nil),
Header{Algorithm: HS512, Type: "jit"},
`{"alg":"HS512","typ":"jit"}`,
)
f(
NewHS512(nil),
Header{Algorithm: Algorithm("OwO"), Type: "JWT"},
`{"alg":"OwO","typ":"JWT"}`,
)
f(
NewHS512(nil),
Header{Algorithm: Algorithm("UwU"), Type: "jit"},
`{"alg":"UwU","typ":"jit"}`,
)
}

func toBase64(s string) string {
buf := make([]byte, base64EncodedLen(len(s)))
base64.RawURLEncoding.Encode(buf, []byte(s))
return string(buf)
}

0 comments on commit f864f62

Please sign in to comment.