Skip to content
This repository has been archived by the owner on Oct 24, 2023. It is now read-only.

Commit

Permalink
Develop (#23)
Browse files Browse the repository at this point in the history
* add jwt
* add postgresql for orm
  • Loading branch information
markus621 committed Apr 22, 2023
1 parent b877cc4 commit c05c9ac
Show file tree
Hide file tree
Showing 16 changed files with 564 additions and 65 deletions.
176 changes: 176 additions & 0 deletions auth/jwt/jwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package jwt

//go:generate easyjson

import (
"crypto/hmac"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/json"
"fmt"
"hash"
"strings"
"time"
)

const (
AlgHS256 = "HS256"
AlgHS384 = "HS384"
AlgHS512 = "HS512"
)

type Config struct {
ID string `yaml:"id"`
Key string `yaml:"key"`
Algorithm string `yaml:"alg"`
}

//easyjson:json
type Header struct {
Kid string `json:"kid"`
Alg string `json:"alg"`
IssuedAt int64 `json:"iat"`
ExpiresAt int64 `json:"eat"`
}

type (
JWT struct {
pool map[string]*Pool
}

Pool struct {
conf Config
hash func() hash.Hash
key []byte
}
)

func New(conf []Config) (*JWT, error) {
obj := &JWT{pool: make(map[string]*Pool)}

for _, c := range conf {
var h func() hash.Hash
switch c.Algorithm {
case AlgHS256:
h = sha256.New
case AlgHS384:
h = sha512.New384
case AlgHS512:
h = sha512.New
default:
return nil, fmt.Errorf("jwt algorithm not supported")
}
obj.pool[c.ID] = &Pool{conf: c, hash: h, key: []byte(c.Key)}
}

return obj, nil
}

func (v *JWT) rndPool() (*Pool, error) {
for _, p := range v.pool {
return p, nil
}
return nil, fmt.Errorf("jwt pool is empty")
}

func (v *JWT) getPool(id string) (*Pool, error) {
p, ok := v.pool[id]
if ok {
return p, nil
}
return nil, fmt.Errorf("jwt pool not found")
}

func (v *JWT) calcHash(hash func() hash.Hash, key []byte, data []byte) ([]byte, error) {
mac := hmac.New(hash, key)
if _, err := mac.Write(data); err != nil {
return nil, err
}
result := mac.Sum(nil)
return result, nil
}

func (v *JWT) Sign(payload interface{}, ttl time.Duration) (string, error) {
pool, err := v.rndPool()
if err != nil {
return "", err
}

rh := &Header{
Kid: pool.conf.ID,
Alg: pool.conf.Algorithm,
IssuedAt: time.Now().Unix(),
ExpiresAt: time.Now().Add(ttl).Unix(),
}
h, err := json.Marshal(rh)
if err != nil {
return "", err
}
result := base64.StdEncoding.EncodeToString(h)

p, err := json.Marshal(payload)
if err != nil {
return "", err
}
result += "." + base64.StdEncoding.EncodeToString(p)

s, err := v.calcHash(pool.hash, pool.key, []byte(result))
if err != nil {
return "", err
}
result += "." + base64.StdEncoding.EncodeToString(s)

return result, nil
}

func (v *JWT) Verify(token string, payload interface{}) (*Header, error) {
data := strings.Split(token, ".")
if len(data) != 3 {
return nil, fmt.Errorf("invalid jwt format")
}

h, err := base64.StdEncoding.DecodeString(data[0])
if err != nil {
return nil, err
}
header := &Header{}
if err = json.Unmarshal(h, header); err != nil {
return nil, err
}

pool, err := v.getPool(header.Kid)
if err != nil {
return nil, err
}

if header.Alg != pool.conf.Algorithm {
return nil, fmt.Errorf("invalid jwt algorithm")
}
if header.ExpiresAt < time.Now().Unix() {
return nil, fmt.Errorf("jwt expired")
}

expected, err := base64.StdEncoding.DecodeString(data[2])
if err != nil {
return nil, err
}
actual, err := v.calcHash(pool.hash, pool.key, []byte(data[0]+"."+data[1]))
if err != nil {
return nil, err
}
if !hmac.Equal(expected, actual) {
return nil, fmt.Errorf("invalid jwt signature")
}

p, err := base64.StdEncoding.DecodeString(data[1])
if err != nil {
return nil, err
}

if err = json.Unmarshal(p, payload); err != nil {
return nil, err
}

return header, nil
}
106 changes: 106 additions & 0 deletions auth/jwt/jwt_easyjson.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions auth/jwt/jwt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package jwt_test

import (
"testing"
"time"

"github.com/deweppro/go-sdk/auth/jwt"
"github.com/stretchr/testify/require"
)

type demoJwtPayload struct {
ID int `json:"id"`
}

func TestUnit_NewJWT(t *testing.T) {
conf := make([]jwt.Config, 0)
conf = append(conf, jwt.Config{ID: "789456", Key: "123456789123456789123456789", Algorithm: jwt.AlgHS256})
j, err := jwt.New(conf)
require.NoError(t, err)

payload1 := demoJwtPayload{ID: 159}
token, err := j.Sign(&payload1, time.Hour)
require.NoError(t, err)

payload2 := demoJwtPayload{}
head1, err := j.Verify(token, &payload2)
require.NoError(t, err)

require.Equal(t, payload1, payload2)

head2, err := j.Verify(token, &payload2)
require.NoError(t, err)
require.Equal(t, head1, head2)
}
14 changes: 7 additions & 7 deletions auth/isp.go → auth/oauth/isp.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package auth
package oauth

import (
"context"
Expand All @@ -15,22 +15,22 @@ var (
)

type (
UserOAuth interface {
User interface {
GetName() string
GetEmail() string
GetIcon() string
}

OAuthProvider interface {
Provider interface {
Code() string
Config(conf ConfigOAuthItem)
Config(conf ConfigItem)
AuthCodeURL() string
AuthCodeKey() string
Exchange(ctx context.Context, code string) (UserOAuth, error)
Exchange(ctx context.Context, code string) (User, error)
}
)

func (v *OAuth) AddProviders(p ...OAuthProvider) {
func (v *OAuth) AddProviders(p ...Provider) {
v.mux.Lock()
defer v.mux.Unlock()

Expand All @@ -44,7 +44,7 @@ func (v *OAuth) AddProviders(p ...OAuthProvider) {
}
}

func (v *OAuth) GetProvider(name string) (OAuthProvider, error) {
func (v *OAuth) GetProvider(name string) (Provider, error) {
v.mux.RLock()
defer v.mux.RUnlock()

Expand Down
6 changes: 3 additions & 3 deletions auth/isp_google.go → auth/oauth/isp_google.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package auth
package oauth

//go:generate easyjson

Expand Down Expand Up @@ -66,7 +66,7 @@ func (v *IspGoogle) Code() string {
return CodeGoogle
}

func (v *IspGoogle) Config(c ConfigOAuthItem) {
func (v *IspGoogle) Config(c ConfigItem) {
v.oauth = &oauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Expand All @@ -92,7 +92,7 @@ func (v *IspGoogle) AuthCodeKey() string {
return v.config.AuthCodeKey
}

func (v *IspGoogle) Exchange(ctx context.Context, code string) (UserOAuth, error) {
func (v *IspGoogle) Exchange(ctx context.Context, code string) (User, error) {
m := &UserGoogle{}
if err := oauth2ExchangeContext(ctx, code, v.config.RequestURL, v.oauth, m); err != nil {
return nil, err
Expand Down
Loading

0 comments on commit c05c9ac

Please sign in to comment.