-
Notifications
You must be signed in to change notification settings - Fork 137
/
statestore.go
158 lines (136 loc) · 3.39 KB
/
statestore.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
package oidc
import (
"context"
"encoding/hex"
"encoding/json"
"sync"
"time"
"github.com/cozy/cozy-stack/pkg/config/config"
"github.com/cozy/cozy-stack/pkg/crypto"
"github.com/cozy/cozy-stack/pkg/logger"
"github.com/redis/go-redis/v9"
)
// stateTTL is how long an OIDC state saved by a stateStorage's Add remains
// retrievable: memStateStorage stamps it into expiresAt, redisStateStorage
// uses it as the key expiration.
const stateTTL = 15 * time.Minute

// codeTTL is the expiration used when a delegated code is stored in redis by
// CreateCode.
const codeTTL = 180 * time.Minute
// stateHolder keeps the data of one OIDC flow that the server needs to
// remember (presumably until the provider redirects back — confirm with the
// callers). redisStateStorage serializes it with encoding/json, so only the
// exported fields survive a round-trip: id and expiresAt are never part of the
// stored payload.
type stateHolder struct {
// id is the random hex identifier the state is stored under (generated by
// newStateHolder).
id string
// expiresAt is a Unix timestamp in seconds (UTC), set by memStateStorage.Add;
// redisStateStorage relies on the redis key expiration instead.
expiresAt int64
// Provider tells which kind of OIDC provider the flow targets.
Provider ProviderOIDC
// Instance is the domain given to newStateHolder.
Instance string
// Redirect is the redirect value given to newStateHolder.
// NOTE(review): presumably where the user is sent once the flow completes.
Redirect string
// Nonce is a random hex value generated by newStateHolder.
Nonce string
// Confirm is copied from newStateHolder's confirm argument.
// NOTE(review): its meaning is not visible in this file — document it.
Confirm string
}
// ProviderOIDC identifies which kind of OpenID Connect provider a flow uses.
type ProviderOIDC int

// Known providers. The numeric value is what ends up in a serialized
// stateHolder (its Provider field is JSON-encoded as an integer), so these
// values must never be renumbered.
const (
	// GenericProvider designates a generic OIDC provider.
	GenericProvider ProviderOIDC = 0
	// FranceConnectProvider designates FranceConnect.
	FranceConnectProvider ProviderOIDC = 1
)
// newStateHolder prepares the state of a new OIDC flow for the instance at
// domain. The id and the nonce are independent random values (24 random
// bytes each, hex-encoded). expiresAt is left at zero here; memStateStorage.Add
// fills it in when the state is saved.
func newStateHolder(domain, redirect, confirm string, provider ProviderOIDC) *stateHolder {
	randomHex := func() string {
		return hex.EncodeToString(crypto.GenerateRandomBytes(24))
	}
	holder := &stateHolder{
		Provider: provider,
		Instance: domain,
		Redirect: redirect,
		Confirm:  confirm,
	}
	// Same generation order as before: id first, then nonce.
	holder.id = randomHex()
	holder.Nonce = randomHex()
	return holder
}
// stateStorage keeps the short-lived data of OIDC flows: the states created
// when a flow starts, and the delegated codes bound to a sub.
type stateStorage interface {
	// Add saves state under its id.
	Add(state *stateHolder) error
	// Find returns the state saved under stateID, or nil when none can be
	// returned (unknown, expired, or unreadable).
	Find(stateID string) *stateHolder
	// CreateCode generates a new delegated code bound to sub and returns it.
	CreateCode(sub string) string
	// GetSub returns the sub bound to code, or "" when there is none.
	GetSub(code string) string
}
// memStateStorage is the in-memory stateStorage used when no redis client is
// configured (see getStorage). The single instance is shared process-wide and
// can be used from several goroutines at once, so every access to its maps is
// guarded by mu. The zero value is ready to use: the maps are created lazily.
//
// Expired entries are only purged lazily, when they are looked up.
type memStateStorage struct {
	mu         sync.Mutex
	states     map[string]*stateHolder
	codes      map[string]string // delegated code -> sub
	codeExpiry map[string]int64  // delegated code -> expiration (Unix seconds, UTC)
}

// Add saves state under its id and stamps it with an expiration of stateTTL
// from now. It never fails.
func (store *memStateStorage) Add(state *stateHolder) error {
	store.mu.Lock()
	defer store.mu.Unlock()
	if store.states == nil {
		store.states = make(map[string]*stateHolder)
	}
	state.expiresAt = time.Now().UTC().Add(stateTTL).Unix()
	store.states[state.id] = state
	return nil
}

// Find returns the state saved under id, or nil if it is unknown or expired.
// An expired state is removed as a side effect.
func (store *memStateStorage) Find(id string) *stateHolder {
	store.mu.Lock()
	defer store.mu.Unlock()
	state, ok := store.states[id]
	if !ok {
		return nil
	}
	if state.expiresAt < time.Now().UTC().Unix() {
		delete(store.states, id)
		return nil
	}
	return state
}

// CreateCode generates a new delegated code bound to sub. The code expires
// after codeTTL, matching the redis-backed storage.
func (store *memStateStorage) CreateCode(sub string) string {
	code := makeCode()
	store.mu.Lock()
	defer store.mu.Unlock()
	if store.codes == nil {
		store.codes = make(map[string]string)
	}
	if store.codeExpiry == nil {
		store.codeExpiry = make(map[string]int64)
	}
	store.codes[code] = sub
	store.codeExpiry[code] = time.Now().UTC().Add(codeTTL).Unix()
	return code
}

// GetSub returns the sub bound to code, or "" if the code is unknown or has
// expired. An expired code is removed as a side effect.
func (store *memStateStorage) GetSub(code string) string {
	store.mu.Lock()
	defer store.mu.Unlock()
	sub, ok := store.codes[code]
	if !ok {
		return ""
	}
	// A code without a recorded expiration is kept as-is (never expires).
	if exp, tracked := store.codeExpiry[code]; tracked && exp < time.Now().UTC().Unix() {
		delete(store.codes, code)
		delete(store.codeExpiry, code)
		return ""
	}
	return sub
}
// subRedisInterface is the subset of the go-redis client API that
// redisStateStorage relies on. getStorage stores
// config.GetConfig().OauthStateStorage in a field of this type; any value
// offering these two commands (presumably including a test double) can be
// used instead.
type subRedisInterface interface {
	// Set stores value under key with the given expiration.
	Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd
	// Get reads the value stored under key.
	Get(ctx context.Context, key string) *redis.StringCmd
}
// redisStateStorage is the stateStorage used when a redis client is configured
// (see getStorage). A state is stored under its id and a delegated code under
// the code itself; expiration is left to redis (stateTTL / codeTTL).
type redisStateStorage struct {
// cl is the client every command is sent through.
cl subRedisInterface
// ctx is passed to every redis command. getStorage sets it to
// context.Background(), so commands are never cancelled through it.
ctx context.Context
}
// Add JSON-encodes s and stores it under s.id, expiring after stateTTL. It
// returns the encoding error or the redis error, if any. Only exported fields
// of s are encoded.
func (store *redisStateStorage) Add(s *stateHolder) error {
	payload, err := json.Marshal(s)
	if err == nil {
		err = store.cl.Set(store.ctx, s.id, payload, stateTTL).Err()
	}
	return err
}
// Find loads the state stored under id. It returns nil when the key does not
// exist (never stored, or expired), when redis cannot be queried, or when the
// stored payload is not valid JSON. Errors other than a missing key are logged.
func (store *redisStateStorage) Find(id string) *stateHolder {
	serialized, err := store.cl.Get(store.ctx, id).Bytes()
	if err != nil {
		// redis.Nil only means "no such key"; anything else is an
		// infrastructure failure that should not vanish silently.
		if err != redis.Nil {
			logger.WithNamespace("redis-state").Errorf(
				"Cannot fetch state from redis: %s", err)
		}
		return nil
	}
	var s stateHolder
	if err := json.Unmarshal(serialized, &s); err != nil {
		logger.WithNamespace("redis-state").Errorf(
			"Bad state in redis %s", string(serialized))
		return nil
	}
	// id is unexported and thus absent from the JSON payload; restore it so
	// the returned holder is complete, as with memStateStorage.Find.
	s.id = id
	return &s
}
// CreateCode generates a new delegated code, stores sub under it with a
// codeTTL expiration, and returns the code. The interface cannot report
// errors, so a failed write is logged; the code is still returned, but it will
// not resolve through GetSub.
func (store *redisStateStorage) CreateCode(sub string) string {
	code := makeCode()
	if err := store.cl.Set(store.ctx, code, sub, codeTTL).Err(); err != nil {
		logger.WithNamespace("redis-state").Errorf(
			"Cannot save delegated code in redis: %s", err)
	}
	return code
}
// GetSub returns the sub stored under code, or "" if the code is unknown,
// expired, or redis cannot be queried. Errors other than a missing key are
// logged.
func (store *redisStateStorage) GetSub(code string) string {
	sub, err := store.cl.Get(store.ctx, code).Result()
	if err != nil {
		if err != redis.Nil {
			logger.WithNamespace("redis-state").Errorf(
				"Cannot fetch delegated code from redis: %s", err)
		}
		return ""
	}
	return sub
}
// globalStorage is the process-wide stateStorage, created on first use by
// getStorage; globalStorageMutex serializes that creation.
var (
	globalStorage      stateStorage
	globalStorageMutex sync.Mutex
)

// getStorage returns the process-wide stateStorage and builds it on the first
// call. If the configuration provides an OAuth state storage client, the
// storage is a redisStateStorage using it; otherwise it is an in-memory
// memStateStorage. Safe to call from several goroutines.
func getStorage() stateStorage {
	globalStorageMutex.Lock()
	defer globalStorageMutex.Unlock()
	if globalStorage == nil {
		if cli := config.GetConfig().OauthStateStorage; cli != nil {
			globalStorage = &redisStateStorage{cl: cli, ctx: context.Background()}
		} else {
			globalStorage = &memStateStorage{
				states: make(map[string]*stateHolder),
				codes:  make(map[string]string),
			}
		}
	}
	return globalStorage
}
// makeCode returns a fresh delegated code: 12 random bytes hex-encoded into a
// 24-character string.
func makeCode() string {
	raw := crypto.GenerateRandomBytes(12)
	return hex.EncodeToString(raw)
}