Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions algorithms.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,10 @@ func (v nonevalidator) sign(jwt *jwt) error {
return nil
}

func (jwt *jwt) rawEncode() {
func (jwt *jwt) rawEncode() (header, payload []byte) {
headerBuf := bytes.NewBuffer(nil)
payloadBuf := bytes.NewBuffer(nil)

// TODO: Determine if errors here are possible/relevant
json.NewEncoder(headerBuf).Encode(jwt.Header)
json.NewEncoder(payloadBuf).Encode(jwt.Payload)

Expand All @@ -86,12 +85,14 @@ func (jwt *jwt) rawEncode() {
json.Compact(compactHeaderBuf, headerBuf.Bytes())
json.Compact(compactPayloadBuf, payloadBuf.Bytes())

jwt.headerRaw = make([]byte, base64.URLEncoding.EncodedLen(len(compactHeaderBuf.Bytes())))
jwt.payloadRaw = make([]byte, base64.URLEncoding.EncodedLen(len(compactPayloadBuf.Bytes())))
header = make([]byte, base64.URLEncoding.EncodedLen(len(compactHeaderBuf.Bytes())))
payload = make([]byte, base64.URLEncoding.EncodedLen(len(compactPayloadBuf.Bytes())))

base64.URLEncoding.Encode(jwt.headerRaw, compactHeaderBuf.Bytes())
base64.URLEncoding.Encode(jwt.payloadRaw, compactPayloadBuf.Bytes())
base64.URLEncoding.Encode(header, compactHeaderBuf.Bytes())
base64.URLEncoding.Encode(payload, compactPayloadBuf.Bytes())

jwt.headerRaw = []byte(strings.Trim(string(jwt.headerRaw), "="))
jwt.payloadRaw = []byte(strings.Trim(string(jwt.payloadRaw), "="))
header = []byte(strings.Trim(string(header), "="))
payload = []byte(strings.Trim(string(payload), "="))

return header, payload
}
4 changes: 2 additions & 2 deletions es.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ func (v ESValidator) sign(jwt *jwt) (err error) {
}

jwt.Header.Algorithm = v.algorithm
jwt.rawEncode()
header, payload := jwt.rawEncode()

// TODO: This block is general. Refactor it out of RS and ES validators
hsh := v.hashType.New()
hsh.Write([]byte(string(jwt.headerRaw) + "." + string(jwt.payloadRaw)))
hsh.Write([]byte(string(header) + "." + string(payload)))
hash := hsh.Sum(nil)

r, s, err := ecdsa.Sign(v.rand, v.PrivateKey, hash)
Expand Down
4 changes: 2 additions & 2 deletions hs.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ func (v hsValidator) validate(jwt *jwt) (bool, error) {
func (v hsValidator) sign(jwt *jwt) error {

jwt.Header.Algorithm = v.algorithm
jwt.rawEncode()
header, payload := jwt.rawEncode()

mac := hmac.New(v.hashFunc, v.Key)
mac.Write([]byte(strings.Trim(string(jwt.headerRaw), "=") + "." + strings.Trim(string(jwt.payloadRaw), "=")))
mac.Write([]byte(string(header) + "." + string(payload)))

jwt.Signature = []byte(base64.URLEncoding.EncodeToString(mac.Sum(nil)))
return nil
Expand Down
57 changes: 19 additions & 38 deletions jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ type Payload struct {
NotBefore *time.Time `json:"nbf,omitempty"`
IssuedAt *time.Time `json:"iat,omitempty"`
JWTId string `json:"jti,omitempty"`
raw []byte
}

// A Decoder is a centeralized reader and key used to consume and verify a
Expand All @@ -66,19 +65,16 @@ type Encoder struct {
type header struct {
Algorithm Algorithm `json:"alg"`
ContentType string `json:"typ"`
raw []byte
}

// A jwt is a unified structure of the components of a jwt. This structure is
//used internally to aggregate components during encoding and decoding.
type jwt struct {
Header *header
headerRaw []byte
Payload interface{}
claimsPayload *Payload
payloadRaw []byte
registeredPayload Payload
Signature []byte
Header *header
headerRaw []byte
Payload interface{}
payloadRaw []byte
Signature []byte
}

// NewDecoder creates an underlying Decoder with a given key and input reader
Expand All @@ -91,7 +87,8 @@ func NewDecoder(r io.Reader, v Validator) *Decoder {
// of the given token is verified and will return an error if a bad signature is
// found. In addition if the jwt is using an unimplemented algorithm an error will
// be returned as well.
func (dec *Decoder) Decode(v interface{}) error {
func (dec *Decoder) Decode(v interface{}) (err error) {
var valid bool

buf := bufio.NewReader(dec.reader)
input, err := buf.ReadString(byte(' '))
Expand All @@ -102,16 +99,16 @@ func (dec *Decoder) Decode(v interface{}) error {
return err
}

if valid, err := dec.validator.validate(jwt); !valid || err != nil {
if valid, err = dec.validator.validate(jwt); !valid || err != nil {

if err != nil {
return err
}

return ErrBadSignature
err = ErrBadSignature
}

return nil
return err
}

// NewEncoder creates an underlying Encoder with a given key and output writer
Expand All @@ -135,7 +132,9 @@ func (enc *Encoder) Encode(v interface{}) error {
return err
}

fmt.Fprintf(enc.writer, "%s", jwt.token())
header, payload := jwt.rawEncode()

fmt.Fprintf(enc.writer, "%s.%s.%s", string(header), string(payload), string(jwt.Signature))

return nil
}
Expand All @@ -144,7 +143,7 @@ func (jwt *jwt) parseHeader(raw string) error {
var err error
var value []byte

if value, err = parseField(raw); err != nil {
if value, err = padB64String(raw); err != nil {
return err
}

Expand All @@ -160,8 +159,7 @@ func (jwt *jwt) parseHeader(raw string) error {
func parseJWT(input string, payload interface{}) (*jwt, error) {
var err error
jwt := &jwt{
Header: &header{},
claimsPayload: &Payload{},
Header: &header{},
}

fields := strings.Split(input, ".")
Expand All @@ -183,35 +181,18 @@ func parseJWT(input string, payload interface{}) (*jwt, error) {
return jwt, nil
}

func (jwt *jwt) token() string {
header := strings.Trim(string(jwt.headerRaw), "=")
payload := strings.Trim(string(jwt.payloadRaw), "=")
signature := strings.Trim(string(jwt.Signature), "=")

return fmt.Sprintf("%s.%s.%s", header, payload, signature)
}

func (jwt *jwt) parsePayload(raw string, v interface{}) error {
func (jwt *jwt) parsePayload(raw string, v interface{}) (err error) {
jwt.payloadRaw = []byte(raw)
value, err := parseField(raw)
value, err := padB64String(raw)

if err != nil {
return err
}

// TODO: How to deal with json encoder errors?
err = json.NewDecoder(bytes.NewReader(value)).Decode(v)

if err != nil {
return err
}

json.NewDecoder(bytes.NewReader(value)).Decode(jwt.claimsPayload)

return nil
return json.NewDecoder(bytes.NewReader(value)).Decode(v)
}

func parseField(b64Value string) ([]byte, error) {
func padB64String(b64Value string) ([]byte, error) {
if m := len(b64Value) % 4; m != 0 {
b64Value += strings.Repeat("=", 4-m)
}
Expand Down
10 changes: 5 additions & 5 deletions jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ func TestEncodingSigning(t *testing.T) {
Payload interface{}
ValidToken string
}{
{HS256, struct{}{}, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.UGgJ_8f7TlqazSojqRAKzMJ0SUWJCJJ_9jDHe5nrhto"},
{HS256, struct{}{}, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.UGgJ_8f7TlqazSojqRAKzMJ0SUWJCJJ_9jDHe5nrhto="},
{HS384, struct{}{}, "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9.e30.YGfeZ7CN9vKz4M2SINxTixlpUEDqsCZNx4LMJK62Lr_Eiptnikcf5XfgDd_7eVWe"},
{HS512, struct{}{}, "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.e30.wHUM-6oRBExIgOk9MLOQ_80WqbuOmXXNuyTy4WmM_0WBM6pXld0mru8rZbc9-E314K9UhMkDNHbg2MRjIsCR3g"},
{HS512, struct{}{}, "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.e30.wHUM-6oRBExIgOk9MLOQ_80WqbuOmXXNuyTy4WmM_0WBM6pXld0mru8rZbc9-E314K9UhMkDNHbg2MRjIsCR3g=="},
}

for _, c := range cases {
Expand Down Expand Up @@ -178,11 +178,11 @@ func ExampleEncoder() {
}

fmt.Println(tokenBuffer.String())
// Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJCZW4gQ2FtcGJlbGwiLCJhZG1pbiI6dHJ1ZSwidXNlcl9pZCI6MTIzNH0.r4W8qDl8i8cUcRUxtA3hM0SZsLScHiBgBKZc_n_GrXI
// Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJCZW4gQ2FtcGJlbGwiLCJhZG1pbiI6dHJ1ZSwidXNlcl9pZCI6MTIzNH0.r4W8qDl8i8cUcRUxtA3hM0SZsLScHiBgBKZc_n_GrXI=
}

func ExampleDecoder() {
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJCZW4gQ2FtcGJlbGwiLCJhZG1pbiI6dHJ1ZSwidXNlcl9pZCI6MTIzNH0.r4W8qDl8i8cUcRUxtA3hM0SZsLScHiBgBKZc_n_GrXI"
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJCZW4gQ2FtcGJlbGwiLCJhZG1pbiI6dHJ1ZSwidXNlcl9pZCI6MTIzNH0.r4W8qDl8i8cUcRUxtA3hM0SZsLScHiBgBKZc_n_GrXI="

payload := &struct {
Payload
Expand All @@ -200,5 +200,5 @@ func ExampleDecoder() {
}

fmt.Printf("%+v\n", payload)
// Output: &{Payload:{Issuer:Ben Campbell Subject: Audience: ExpirationTime:<nil> NotBefore:<nil> IssuedAt:<nil> JWTId: raw:[]} Admin:true UserID:1234}
// Output: &{Payload:{Issuer:Ben Campbell Subject: Audience: ExpirationTime:<nil> NotBefore:<nil> IssuedAt:<nil> JWTId:} Admin:true UserID:1234}
}
10 changes: 5 additions & 5 deletions rs.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ func (v RSValidator) validate(jwt *jwt) (bool, error) {
}

jwt.Header.Algorithm = v.algorithm
jwt.rawEncode()
header, payload := jwt.rawEncode()

signature, err := base64.URLEncoding.DecodeString(string(jwt.Signature))

if err != nil {
return false, err
return false, ErrMalformedToken
}

hsh := v.hashType.New()
hsh.Write([]byte(string(jwt.headerRaw) + "." + string(jwt.payloadRaw)))
hsh.Write([]byte(string(header) + "." + string(payload)))
hash := hsh.Sum(nil)

err = rsa.VerifyPKCS1v15(v.PublicKey, v.hashType, hash, signature)
Expand All @@ -81,10 +81,10 @@ func (v RSValidator) validate(jwt *jwt) (bool, error) {

func (v RSValidator) sign(jwt *jwt) (err error) {
jwt.Header.Algorithm = v.algorithm
jwt.rawEncode()
header, payload := jwt.rawEncode()

hsh := v.hashType.New()
hsh.Write([]byte(string(jwt.headerRaw) + "." + string(jwt.payloadRaw)))
hsh.Write([]byte(string(header) + "." + string(payload)))
hash := hsh.Sum(nil)

signature, _ := rsa.SignPKCS1v15(v.randReader, v.PrivateKey, v.hashType, hash)
Expand Down