-
Notifications
You must be signed in to change notification settings - Fork 0
/
check.go
77 lines (67 loc) · 2.28 KB
/
check.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
package blocklist
import (
	"errors"

	"github.com/divergentcodes/jwt-block/internal/crypto"
	"github.com/redis/go-redis/v9"
)
// A CheckResult contains the result of checking for a token in the blocklist.
//
// The json struct tags define the field names used when the result is
// encoded as JSON.
type CheckResult struct {
	// Message is a short message summarizing the result.
	Message string `json:"message"`
	// IsBlocked reports whether the token is blocked, i.e. present in the blocklist.
	IsBlocked bool `json:"blocked"`
	// TTL is the remaining time-to-live of the token's blocklist entry, in
	// whole seconds. CheckBySha256 sets -1 when the token is not in the
	// blocklist and 0 when it is blocked without an expiry.
	TTL int `json:"block_ttl_sec"`
	// TTLString is a human-readable form of the remaining time-to-live.
	// CheckBySha256 sets "" when the token is not in the blocklist and
	// "Inf" when it is blocked without an expiry.
	TTLString string `json:"block_ttl_str"`
	// IsError reports whether the result represents an error.
	// NOTE(review): the lookup functions in this file always leave this
	// false; presumably callers elsewhere set it — confirm.
	IsError bool `json:"error"`
}
// CheckByJwt reports whether a raw JWT is present in the blocklist.
//
// The token is first parsed, validated and verified with
// crypto.RunJwtChecks; if that fails, its error is returned unchanged along
// with a zero CheckResult. Otherwise the token string is hashed with
// crypto.Sha256FromString and the digest is looked up via CheckBySha256.
func CheckByJwt(redisDB *redis.Client, tokenString string) (CheckResult, error) {
	if _, jwtErr := crypto.RunJwtChecks(tokenString); jwtErr != nil {
		return CheckResult{}, jwtErr
	}
	return CheckBySha256(redisDB, crypto.Sha256FromString(tokenString))
}
// CheckBySha256 checks if the hash value of a token is in the blocklist.
//
// sha256 is the token's hash as produced by crypto.Sha256FromString; it is
// validated with crypto.IsValidSha256 first, and a validation error is
// returned unchanged along with a zero CheckResult.
//
// The hash is used as the Redis key, and its remaining TTL is mapped to a
// result:
//
//   - key absent: allowed, TTL -1, TTLString "".
//   - key present without expiry: blocked, TTL 0, TTLString "Inf".
//   - key present with expiry: blocked, TTL in whole seconds, TTLString from
//     time.Duration.String.
//
// Any other Redis error is returned unchanged along with a zero CheckResult.
func CheckBySha256(redisDB *redis.Client, sha256 string) (CheckResult, error) {
	// Sentinels reported by go-redis v9 TTL as raw (unscaled) durations:
	// -2 when the key does not exist, -1 when it exists without an expiry.
	const (
		ttlKeyMissing = -2
		ttlNoExpiry   = -1
	)

	if err := crypto.IsValidSha256(sha256); err != nil {
		return CheckResult{}, err
	}

	ttl, err := redisDB.TTL(redisContext, sha256).Result()
	if errors.Is(err, redis.Nil) {
		// TTL normally always replies with an integer. If a nil reply ever
		// surfaces, treat it as "key absent" so callers get the same fully
		// populated "allowed" result rather than an empty CheckResult.
		ttl, err = ttlKeyMissing, nil
	}
	if err != nil {
		return CheckResult{}, err
	}

	switch ttl {
	case ttlKeyMissing:
		// Not found in the blocklist.
		return CheckResult{
			Message:   SuccessTokenIsAllowed,
			IsBlocked: false,
			TTL:       -1,
			TTLString: "",
		}, nil
	case ttlNoExpiry:
		// Found, but the entry never expires.
		return CheckResult{
			Message:   SuccessTokenIsBlocked,
			IsBlocked: true,
			TTL:       0,
			TTLString: "Inf",
		}, nil
	default:
		// Found with a defined expiry.
		return CheckResult{
			Message:   SuccessTokenIsBlocked,
			IsBlocked: true,
			TTL:       int(ttl.Seconds()),
			TTLString: ttl.String(),
		}, nil
	}
}