/
OAuthSignatureMethod.scala
154 lines (131 loc) · 5.65 KB
/
OAuthSignatureMethod.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
/*
* Copyright 2010-2011 WorldWide Conferencing, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.liftweb
package oauth
import java.net.URI
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import common._
import util.Helpers
import net.liftweb.http._
/**
 * Base class for OAuth signature methods.
 *
 * Builds the signature base string for a message and delegates the actual
 * signing / comparison to the concrete subclass via `getSignature` and
 * `isValid`. The consumer secret and token secret are read from the supplied
 * accessor.
 */
abstract class OAuthSignatureMethod(accessor: OAuthAccessor) {

  /**
   * Checks the signature carried by `message`.
   *
   * @return `Full(message)` when `isValid` accepts the signature. When it is
   *         rejected, a failure annotated with an `OAuthProblem` for
   *         `SIGNATURE_INVALID` that carries the received signature, the
   *         computed base string and the signature method. If the message has
   *         no signature or no signature method, whatever non-full Box those
   *         getters returned is propagated.
   */
  def validate(message: OAuthMessage): Box[OAuthMessage] =
    for {
      signature <- message.getSignature
      sigMethod <- message.getSignatureMethod
      baseString = getBaseString(message)
      _ <- Full(baseString).filter(bs => isValid(signature.value, bs)) ?~
        OAuthUtil.Problems.SIGNATURE_INVALID._1 ~>
        OAuthProblem(OAuthUtil.Problems.SIGNATURE_INVALID,
          ("oauth_signature", signature.value) ::
            ("oauth_signature_base_string", baseString) ::
            ("oauth_signature_method", sigMethod.value) :: Nil)
    } yield message

  /**
   * Builds the signature base string: the percent-encoded HTTP method, the
   * percent-encoded normalized URL and the percent-encoded normalized
   * parameters, joined with `&`.
   */
  def getBaseString(message: OAuthMessage): String =
    OAuthUtil.percentEncode(message.method.method) + "&" +
      OAuthUtil.percentEncode(normalizeUrl(message.uri)) + "&" +
      OAuthUtil.percentEncode(normalizeParameters(message.parameters))

  /**
   * Normalizes a request URL to `scheme://authority/path`.
   *
   * Scheme and authority are lower-cased, an explicit default port (80 for
   * http, 443 for https) is removed, an empty path becomes "/", and query and
   * fragment are not included.
   *
   * The URL must be absolute: `java.net.URI` reports a missing scheme or
   * authority as `null`, which makes this method throw a
   * `NullPointerException`.
   */
  private[oauth] def normalizeUrl(url: String): String = {
    val uri = new URI(url)
    // Use a fixed locale: with the JVM default locale, lower-casing depends on
    // the host configuration (e.g. under a Turkish locale 'I' becomes a
    // dotless i), which would yield a different base string and a signature
    // mismatch.
    val scheme = uri.getScheme.toLowerCase(java.util.Locale.ENGLISH)
    val fullAuthority = uri.getAuthority.toLowerCase(java.util.Locale.ENGLISH)

    val port = uri.getPort
    val isDefaultPort = (scheme == "http" && port == 80) || (scheme == "https" && port == 443)

    val authority =
      if (isDefaultPort) {
        // getPort only reports a value when the port was written explicitly,
        // so the last ':' in the authority is the port separator.
        val colon = fullAuthority.lastIndexOf(":")
        if (colon >= 0) fullAuthority.substring(0, colon) else fullAuthority
      } else fullAuthority

    val rawPath = uri.getRawPath
    // An empty path is represented as "/" (RFC 2616 section 3.2.2).
    val path = if (rawPath == null || rawPath.length == 0) "/" else rawPath

    // Query and fragment are deliberately left out.
    scheme + "://" + authority + path
  }

  /**
   * Drops the `oauth_signature` parameter, sorts the remaining parameters by
   * their percent-encoded name and then value, and form-encodes the result.
   */
  private def normalizeParameters(parameters: List[OAuthUtil.Parameter]) = {
    val unsigned = parameters.filter(_.name != OAuthUtil.OAUTH_SIGNATURE)
    // The sort key is computed once per parameter. sortBy compares the keys
    // with String's natural ordering, which is a strict ordering; the earlier
    // `<=` predicate passed to sortWith reported equal keys as "less than" in
    // both directions, violating the comparator contract.
    // The ' ' separator keeps name and value apart in the key
    // (presumably chosen because it sorts below any character percentEncode
    // emits -- TODO confirm against OAuthUtil.percentEncode).
    val sorted = unsigned.sortBy { p =>
      OAuthUtil.percentEncode(p.name) + ' ' + OAuthUtil.percentEncode(p.value)
    }
    OAuthUtil.formEncode(sorted)
  }

  /** The consumer secret held by the accessor. */
  def getConsumerSecret: String = accessor.consumerSecret

  /** The token secret held by the accessor, if any. */
  def getTokenSecret: Box[String] = accessor.tokenSecret

  /** Returns true when `signature` is a valid signature for `baseString`. */
  def isValid(signature: String, baseString: String): Boolean

  /** Computes the signature for `baseString`; not full when it cannot be computed. */
  def getSignature(baseString: String): Box[String]
}
/**
 * Registry of the supported signature methods and factory for signers.
 */
object OAuthSignatureMethod {

  /** Supported signature methods, keyed by their upper-case OAuth name. */
  val SIGNATURE_METHODS: Map[String, OAuthSignatureMethodBuilder] =
    Map("HMAC-SHA1" -> HMAC_SHA1,
      "PLAINTEXT" -> PLAINTEXT)

  /**
   * Creates the signer matching the signature method named in `message`.
   *
   * The method name is matched case-insensitively against
   * `SIGNATURE_METHODS`.
   *
   * @return the signer built for `accessor`, or a failure annotated with an
   *         `OAuthProblem` for `SIGNATURE_METHOD_REJECTED` that lists the
   *         acceptable method names when the requested method is unknown. If
   *         the message names no signature method, the non-full Box from
   *         `getSignatureMethod` is propagated.
   */
  def newSigner(message: OAuthMessage, accessor: OAuthAccessor): Box[OAuthSignatureMethod] =
    for {
      sigMeth <- message.getSignatureMethod
      // Upper-case with a fixed locale: with the JVM default locale the result
      // depends on the host configuration (e.g. "plaintext" becomes
      // "PLA\u0130NTEXT" under a Turkish locale) and the lookup would fail.
      meth <- Box(SIGNATURE_METHODS.get(sigMeth.value.toUpperCase(java.util.Locale.ENGLISH))) ?~
        OAuthUtil.Problems.SIGNATURE_METHOD_REJECTED._1 ~>
        OAuthProblem(OAuthUtil.Problems.SIGNATURE_METHOD_REJECTED,
          (OAuthUtil.ProblemParams.OAUTH_ACCEPTABLE_SIGNATURE_METHODS,
            OAuthUtil.percentEncode(SIGNATURE_METHODS.keySet.toList)))
    } yield meth(accessor)
}
/**
 * Factory for a signature method bound to a particular accessor.
 */
trait OAuthSignatureMethodBuilder {

  /** Creates the signature method that uses the secrets of `accessor`. */
  def apply(accessor: OAuthAccessor): OAuthSignatureMethod
}
/**
 * HMAC_SHA1 signature generator.
 *
 * The MAC key is the percent-encoded consumer secret and the percent-encoded
 * token secret joined with `&`; the signature is the Base64 encoding of the
 * HmacSHA1 of the base string.
 */
class HMAC_SHA1(accessor: OAuthAccessor) extends OAuthSignatureMethod(accessor) {
  private val ENCODING = OAuthUtil.ENCODING
  private val MAC_NAME = "HmacSHA1"

  /**
   * Returns true only when a signature could be computed for `baseString` and
   * it equals `signature`.
   */
  override def isValid(signature: String, baseString: String): Boolean = {
    Thread.sleep(Helpers.randomLong(10)) // Avoid a timing attack
    val expected = getSignature(baseString)
    // The comparison is still run when no signature could be computed, so the
    // work done is similar in both cases. Its result must not be trusted in
    // that case though: the stand-in value is the reversed input, which equals
    // the input for any palindrome (including the empty string).
    val matches = Helpers.secureEquals(expected openOr signature.reverse.toString, signature)
    expected.isDefined && matches
  }

  /** Base64-encoded HMAC of `baseString`; not full when there is no token secret. */
  override def getSignature(baseString: String): Box[String] =
    computeSignature(baseString).map(raw => Helpers.base64Encode(raw))

  /**
   * Computes the raw HmacSHA1 of `baseString`.
   *
   * @return the MAC bytes, or the non-full Box returned by `getTokenSecret`
   *         when no token secret is available.
   */
  def computeSignature(baseString: String): Box[Array[Byte]] =
    getTokenSecret.map { tokenSecret =>
      val keyString =
        OAuthUtil.percentEncode(getConsumerSecret) + '&' + OAuthUtil.percentEncode(tokenSecret)
      val key = new SecretKeySpec(keyString.getBytes(ENCODING), MAC_NAME)
      val mac = Mac.getInstance(MAC_NAME)
      mac.init(key)
      mac.doFinal(baseString.getBytes(ENCODING))
    }
}
/** Builder that creates [[HMAC_SHA1]] signers. */
object HMAC_SHA1 extends OAuthSignatureMethodBuilder {

  /** Creates an HMAC_SHA1 signer that uses the secrets of `accessor`. */
  def apply(accessor: OAuthAccessor): OAuthSignatureMethod = {
    val signer = new HMAC_SHA1(accessor)
    signer
  }
}
/**
 * Plaintext signature generator.
 *
 * The signature is the percent-encoded consumer secret and the
 * percent-encoded token secret joined with `&`; the base string is not used.
 */
class PLAINTEXT(accessor: OAuthAccessor) extends OAuthSignatureMethod(accessor) {

  /**
   * Returns true only when a signature could be computed and it equals
   * `signature`.
   */
  override def isValid(signature: String, baseString: String): Boolean = {
    Thread.sleep(Helpers.randomLong(10)) // Avoid a timing attack
    val expected = getSignature(baseString)
    // The comparison is still run when no signature could be computed, so the
    // work done is similar in both cases. Its result must not be trusted in
    // that case though: the stand-in value is the reversed input, which equals
    // the input for any palindrome (including the empty string).
    val matches = Helpers.secureEquals(expected openOr signature.reverse.toString, signature)
    expected.isDefined && matches
  }

  /**
   * Builds the plaintext signature from the two secrets.
   *
   * @return the signature, or the non-full Box returned by `getTokenSecret`
   *         when no token secret is available.
   */
  override def getSignature(baseString: String): Box[String] =
    getTokenSecret.map { tokenSecret =>
      OAuthUtil.percentEncode(getConsumerSecret) + '&' + OAuthUtil.percentEncode(tokenSecret)
    }
}
/** Builder that creates [[PLAINTEXT]] signers. */
object PLAINTEXT extends OAuthSignatureMethodBuilder {

  /** Creates a PLAINTEXT signer that uses the secrets of `accessor`. */
  def apply(accessor: OAuthAccessor): OAuthSignatureMethod = {
    val signer = new PLAINTEXT(accessor)
    signer
  }
}