-
Notifications
You must be signed in to change notification settings - Fork 787
/
FrameTranscoder.scala
189 lines (151 loc) · 5.38 KB
/
FrameTranscoder.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
/*
* Copyright 2013 http4s.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.http4s.websocket
import java.nio.ByteBuffer
import java.security.SecureRandom
import scodec.bits.ByteVector
/** Stateless helpers for parsing RFC 6455 frame headers.
  *
  * The header helpers read `in` with absolute indices (`in.get(1)`, `in.getLong(2)`, ...) but compare
  * sizes against `in.remaining`, so they expect the frame to start at index 0 with `in.position == 0`.
  */
private[http4s] object FrameTranscoder {

  /** Raised when data cannot be encoded as, or decoded from, a valid WebSocket frame. */
  final class TranscodeError(val message: String) extends Exception(message)

  /** Reads every remaining byte of `in` into a fresh array, unmasking it when a key is given.
    *
    * @param in   payload bytes between `in.position` and `in.limit`; all of them are consumed
    * @param mask the 4-byte masking key, or `null` when the payload is not masked
    * @return the (unmasked) payload
    */
  private def decodeBinary(in: ByteBuffer, mask: Array[Byte]): Array[Byte] = {
    val data = new Array[Byte](in.remaining)
    in.get(data)
    if (mask != null)
      for (i <- data.indices)
        data(i) = (data(i) ^ mask(i & 0x3)).toByte // i & 0x3 == i % 4 for non-negative i
    data
  }

  /** Index of the first byte after the length fields, i.e. where the masking key (if any) starts:
    * 2 for a 7-bit length, 4 for a 16-bit extended length, 10 for a 64-bit extended length.
    *
    * @throws TranscodeError if the 7-bit length field holds an unexpected value
    */
  private def lengthOffset(in: ByteBuffer): Int = {
    val len = in.get(1) & LENGTH
    if (len < 126) 2
    else if (len == 126) 4
    else if (len == 127) 10
    else throw new TranscodeError("Length error!")
  }

  /** Reads the 4-byte masking key that follows the length fields.
    *
    * Uses absolute reads, so neither the position nor the mark of `in` is disturbed.
    */
  private def getMask(in: ByteBuffer): Array[Byte] = {
    val start = lengthOffset(in)
    Array.tabulate[Byte](4)(i => in.get(start + i))
  }

  /** Decodes the payload length from the 7-bit field and, if present, its 16- or 64-bit extension.
    *
    * @throws TranscodeError if the 64-bit length has its most significant bit set (forbidden by
    *                        RFC 6455 §5.2) or does not fit in an `Int`
    */
  private def bodyLength(in: ByteBuffer): Int = {
    val len = in.get(1) & LENGTH
    if (len < 126) len
    else if (len == 126) (in.get(2) << 8 & 0xff00) | (in.get(3) & 0xff)
    else if (len == 127) {
      val l = in.getLong(2)
      // A set top bit reads as a negative Long; without this check `l.toInt` yields a bogus length.
      if (l < 0) throw new TranscodeError("Length error")
      else if (l > Integer.MAX_VALUE) throw new TranscodeError("Frame is too long")
      else l.toInt
    } else throw new TranscodeError("Length error")
  }

  /** Total size of the frame at the start of `in`: header, masking key (if any) and payload.
    *
    * @return the frame size in bytes, or -1 if `in` does not yet hold the whole frame
    * @throws TranscodeError if the payload length is invalid (see `bodyLength`)
    */
  private def getMsgLength(in: ByteBuffer): Int = {
    val lenByte = in.get(1)
    val maskLen = if ((lenByte & MASK) != 0) 4 else 0
    val extLen = (lenByte & LENGTH) match {
      case 126 => 2
      case 127 => 8
      case _ => 0
    }
    val headerLen = 2 + maskLen + extLen
    if (in.remaining < headerLen) -1
    else {
      // Sum as Long: a payload close to Int.MaxValue would otherwise wrap the total negative.
      val total = headerLen.toLong + bodyLength(in)
      if (in.remaining < total) -1 else total.toInt
    }
  }
}
/** Converts between [[WebSocketFrame]]s and their RFC 6455 wire representation.
  *
  * @param isClient whether this end is the client: clients mask every outgoing frame and
  *                 reject incoming masked frames
  */
class FrameTranscoder(val isClient: Boolean) {

  // Source of client masking keys. RFC 6455 §5.3 requires each key to be unpredictable, which
  // `Math.random` does not guarantee. Lazy so that server-side transcoders never create one.
  private lazy val maskSource: SecureRandom = new SecureRandom()

  /** Encodes `in` as a single WebSocket frame.
    *
    * For a server, returns the header followed by a view of the payload. For a client, returns a
    * single buffer holding the header, a fresh 4-byte masking key and the masked payload.
    *
    * @throws FrameTranscoder.TranscodeError if `in` is a PING, PONG or CLOSE frame with a payload
    *                                        longer than 125 bytes
    */
  def frameToBuffer(in: WebSocketFrame): Array[ByteBuffer] = {
    val len = in.length
    val opcode = in.opcode
    // RFC 6455 §5.5: control frame payloads are limited to 125 bytes.
    if (len > 125 && (opcode == PING || opcode == PONG || opcode == CLOSE))
      throw new FrameTranscoder.TranscodeError(s"Invalid control frame: frame too long: $len")

    // 7-bit length field value, or the marker announcing a 16-bit (126) or 64-bit (127) extension.
    val lenField = if (len < 126) len else if (len <= 0xffff) 126 else 127
    val extLenBytes = if (lenField == 126) 2 else if (lenField == 127) 8 else 0
    val headerSize = 2 + extLenBytes + (if (isClient) 4 else 0)
    // Clients copy the masked payload into this buffer; servers send the payload separately.
    val buff = ByteBuffer.allocate(if (isClient) headerSize + len else headerSize)

    // First byte: FIN flag and opcode (reserved bits left at 0).
    val b1 = if (in.last) opcode | FINISHED else opcode
    buff.put(b1.byteValue)

    // Second byte: MASK flag and the 7-bit length field.
    val b2 = if (isClient) MASK | lenField else lenField
    buff.put(b2.toByte)

    // Extended payload length, big-endian.
    if (extLenBytes == 2) {
      buff.put((len >>> 8 & 0xff).toByte).put((len & 0xff).toByte)
      ()
    } else if (extLenBytes == 8) {
      buff.putLong(len.toLong)
      ()
    }

    if (isClient) {
      // The MASK bit is always set for clients, so the key must be written even for an empty
      // payload; omitting it yields a frame whose receiver reads the following bytes as the key.
      val maskBits = new Array[Byte](4)
      maskSource.nextBytes(maskBits)
      buff.put(maskBits)
      val data = in.data
      for (i <- 0 until len)
        buff.put((data(i.toLong) ^ maskBits(i & 0x3)).toByte) // i & 0x3 == i % 4
      buff.flip()
      Array[ByteBuffer](buff)
    } else {
      buff.flip()
      Array[ByteBuffer](buff, in.data.toByteBuffer)
    }
  }

  /** Decodes one frame from the unread bytes of `in`.
    *
    * On success, `in.position` is advanced past the decoded frame and any bytes after it are left
    * for the next call. If `in` does not yet hold a complete frame, `null` is returned and the
    * position and limit of `in` are left unchanged.
    *
    * @param in ByteBuffer of immediately available data
    * @return the decoded frame, or `null` if more data is needed
    * @throws FrameTranscoder.TranscodeError if a client receives a masked frame, or the frame's
    *                                        length is invalid
    */
  def bufferToFrame(in: ByteBuffer): WebSocketFrame = {
    // Index 0 of `frame` is the first unread byte of `in`. The companion helpers index absolutely
    // from 0, so parsing this view keeps them correct even when `in.position()` is not 0.
    val frame = in.slice()
    if (frame.remaining < 2 || FrameTranscoder.getMsgLength(frame) < 0) null
    else {
      val header0 = frame.get(0)
      val opcode = header0 & OP_CODE
      val finished = (header0 & FINISHED) != 0
      val masked = (frame.get(1) & MASK) != 0
      if (masked && isClient)
        throw new FrameTranscoder.TranscodeError("Client received a masked message")
      val mask = if (masked) FrameTranscoder.getMask(frame) else null
      // The payload starts after the length fields and, when masked, the 4-byte key.
      val bodyOffset = FrameTranscoder.lengthOffset(frame) + (if (masked) 4 else 0)
      val bodyLen = FrameTranscoder.bodyLength(frame)
      frame.position(bodyOffset)
      frame.limit(bodyOffset + bodyLen)
      val payload = FrameTranscoder.decodeBinary(frame, mask)
      // Consume exactly this frame from `in`; bytes of any following frame stay unread.
      in.position(in.position() + bodyOffset + bodyLen)
      makeFrame(opcode, ByteVector.view(payload), finished)
    }
  }
}