/
signature_validation.go
152 lines (127 loc) · 3.32 KB
/
signature_validation.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package aws
import (
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
"net/http"
)
// base64Decode decodes msg.Signature using standard, padded base64.
// The decoder's output and error are passed through unchanged.
func base64Decode(msg *SNSMessage) (b []byte, err error) {
	return base64.StdEncoding.DecodeString(msg.Signature)
}
// getPemFile downloads the PEM-encoded certificate at address and returns the
// raw response body.
//
// A forged message can name any SigningCertURL. If an attacker-controlled
// certificate were fetched, any signature would verify. So address is only
// fetched when it uses the https scheme and its host passes isSNSCertHost.
//
// A response other than 200 OK is reported as an error rather than handed to
// the PEM parser. If v.client is nil, http.DefaultClient is used.
func (v *Validator) getPemFile(address string) (body []byte, err error) {
	req, err := http.NewRequest(http.MethodGet, address, nil)
	if err != nil {
		return nil, err
	}
	// req.URL is already parsed (scheme lower-cased), so no extra URL parsing is needed.
	if req.URL.Scheme != "https" || !isSNSCertHost(req.URL.Hostname()) {
		return nil, fmt.Errorf("refusing to fetch signing certificate from untrusted URL %q", address)
	}
	client := v.client
	if client == nil {
		client = http.DefaultClient
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching signing certificate %q: unexpected status %s", address, resp.Status)
	}
	return ioutil.ReadAll(resp.Body)
}

// isSNSCertHost reports whether host has the form sns.<region>.amazonaws.com
// or sns.<region>.amazonaws.com.cn. Here <region> must be a single DNS label
// of at least three ASCII letters, digits or hyphens.
//
// This mirrors the host pattern used by AWS's SNS message validator
// libraries. The single-label rule rejects look-alike hosts under other
// amazonaws.com services, such as sns.attacker.s3.amazonaws.com.
func isSNSCertHost(host string) bool {
	const prefix = "sns."
	if len(host) < len(prefix) || host[:len(prefix)] != prefix {
		return false
	}
	rest := host[len(prefix):]
	region := ""
	for _, suffix := range [...]string{".amazonaws.com", ".amazonaws.com.cn"} {
		if len(rest) > len(suffix) && rest[len(rest)-len(suffix):] == suffix {
			region = rest[:len(rest)-len(suffix)]
			break
		}
	}
	if len(region) < 3 {
		return false
	}
	for i := 0; i < len(region); i++ {
		c := region[i]
		isAlnum := 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9'
		if !isAlnum && c != '-' {
			return false
		}
	}
	return true
}
// getCerticate parses the first PEM block in b as an X.509 certificate.
//
// It returns an error, and a nil certificate, in three cases:
//   - b contains no PEM block;
//   - the block's type is not "CERTIFICATE";
//   - the block's DER bytes cannot be parsed.
//
// Callers can therefore rely on a non-nil certificate whenever err is nil.
func getCerticate(b []byte) (cert *x509.Certificate, err error) {
	block, _ := pem.Decode(b)
	if block == nil {
		// Previously this returned (nil, nil), which let callers dereference
		// a nil certificate.
		return nil, errors.New("no PEM data found in signing certificate")
	}
	if block.Type != "CERTIFICATE" {
		return nil, fmt.Errorf("unexpected PEM block type %q, want CERTIFICATE", block.Type)
	}
	return x509.ParseCertificate(block.Bytes)
}
// formatSignature rebuilds the string that Amazon SNS signed for msg.
//
// Each included field contributes two newline-terminated lines: its name,
// then its value. Fields appear in byte order of their names:
//   - Notification: Message, MessageId, Subject (only when non-empty),
//     Timestamp, TopicArn, Type.
//   - SubscriptionConfirmation and UnsubscribeConfirmation: Message,
//     MessageId, SubscribeURL, Timestamp, Token, TopicArn, Type.
//
// Any other msg.Type yields an empty string and an error.
func formatSignature(msg *SNSMessage) (formated string, err error) {
	var pairs []interface{}
	switch msg.Type {
	case "Notification":
		pairs = []interface{}{"Message", msg.Message, "MessageId", msg.MessageId}
		if msg.Subject != "" {
			pairs = append(pairs, "Subject", msg.Subject)
		}
		pairs = append(pairs,
			"Timestamp", msg.Timestamp,
			"TopicArn", msg.TopicArn,
			"Type", msg.Type,
		)
	case "SubscriptionConfirmation", "UnsubscribeConfirmation":
		pairs = []interface{}{
			"Message", msg.Message,
			"MessageId", msg.MessageId,
			"SubscribeURL", msg.SubscribeURL,
			"Timestamp", msg.Timestamp,
			"Token", msg.Token,
			"TopicArn", msg.TopicArn,
			"Type", msg.Type,
		}
	default:
		return "", errors.New("Unable to determine SNSMessage type")
	}
	// Formatting each item with "%s\n" yields exactly what one long
	// fmt.Sprintf("%s\n%s\n...", ...) call would produce.
	for _, item := range pairs {
		formated += fmt.Sprintf("%s\n", item)
	}
	return formated, nil
}
// Validator verifies Amazon SNS message signatures. It downloads each
// message's signing certificate with client. Construct one with
// NewValidator.
type Validator struct {
	client *http.Client
}

// SNSValidator is implemented by types that can verify the signature of an
// SNSMessage.
type SNSValidator interface {
	Validate(*SNSMessage) (bool, error)
}

// defaultCertFetchTimeout bounds the whole certificate download (connect,
// headers and body) for the client NewValidator creates when given nil.
// http.Client.Timeout is a time.Duration counted in nanoseconds, so 10e9 is
// 10 seconds.
const defaultCertFetchTimeout = 10e9

// NewValidator returns a Validator that fetches signing certificates with
// client. If client is nil, a new http.Client with a 10 second timeout is
// used, so a stalled certificate host cannot block Validate indefinitely.
// (The previous default had no timeout at all.)
func NewValidator(client *http.Client) *Validator {
	if client == nil {
		client = &http.Client{Timeout: defaultCertFetchTimeout}
	}
	return &Validator{client: client}
}

// NewSNSValidator returns an SNSValidator backed by NewValidator(nil).
func NewSNSValidator() SNSValidator {
	return NewValidator(nil)
}
// Validate reports whether msg carries a valid Amazon SNS signature.
//
// It works in five steps:
//  1. base64-decode msg.Signature;
//  2. download the certificate named by msg.SigningCertURL with v's HTTP client;
//  3. parse that certificate;
//  4. rebuild the string SNS signed for msg.Type;
//  5. check the signature over that string with the certificate's public
//     key, using SHA1WithRSA.
//
// ok is true only when every step succeeds. Otherwise ok is false and err
// names the failing step, wrapping the underlying error so errors.Is and
// errors.As still work.
//
// NOTE(review): only SHA1WithRSA (SNS SignatureVersion "1") is checked.
// Topics configured for SignatureVersion "2" sign with SHA256WithRSA and
// will fail here. Supporting them needs the message's SignatureVersion, which
// is not visible from this file; confirm whether SNSMessage carries it.
func (v *Validator) Validate(msg *SNSMessage) (ok bool, err error) {
	if msg == nil {
		return false, errors.New("validating SNS message: nil message")
	}
	signature, err := base64Decode(msg)
	if err != nil {
		return false, fmt.Errorf("decoding signature: %w", err)
	}
	pemData, err := v.getPemFile(msg.SigningCertURL)
	if err != nil {
		return false, fmt.Errorf("fetching signing certificate: %w", err)
	}
	cert, err := getCerticate(pemData)
	if err != nil {
		return false, fmt.Errorf("parsing signing certificate: %w", err)
	}
	// Never hand a nil certificate to CheckSignature: dereferencing it would
	// panic. A payload without a PEM block must surface as an error instead.
	if cert == nil {
		return false, errors.New("parsing signing certificate: no PEM certificate found")
	}
	stringToSign, err := formatSignature(msg)
	if err != nil {
		return false, fmt.Errorf("building string to sign: %w", err)
	}
	if err = cert.CheckSignature(x509.SHA1WithRSA, []byte(stringToSign), signature); err != nil {
		return false, fmt.Errorf("verifying signature: %w", err)
	}
	return true, nil
}