-
Notifications
You must be signed in to change notification settings - Fork 63
/
Http4sWSStage.scala
202 lines (180 loc) · 6.91 KB
/
Http4sWSStage.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
/*
* Copyright 2014 http4s.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.http4s
package blazecore
package websocket
import cats.effect._
import cats.effect.std.Dispatcher
import cats.effect.std.Semaphore
import cats.syntax.all._
import fs2._
import fs2.concurrent.SignallingRef
import org.http4s.blaze.pipeline.Command.EOF
import org.http4s.blaze.pipeline.LeafBuilder
import org.http4s.blaze.pipeline.TailStage
import org.http4s.blaze.pipeline.TrunkBuilder
import org.http4s.blaze.util.Execution.directec
import org.http4s.blaze.util.Execution.trampoline
import org.http4s.websocket.ReservedOpcodeException
import org.http4s.websocket.UnknownOpcodeException
import org.http4s.websocket.WebSocket
import org.http4s.websocket.WebSocketCombinedPipe
import org.http4s.websocket.WebSocketFrame
import org.http4s.websocket.WebSocketFrame._
import org.http4s.websocket.WebSocketSeparatePipe
import java.net.ProtocolException
import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.ExecutionContext
import scala.util.Failure
import scala.util.Success
/** Blaze tail stage that connects a blaze websocket pipeline to an http4s [[WebSocket]].
  *
  * Frames read from the pipeline are fed to the user's receive side. Frames produced by the
  * user's send side are written back to the pipeline.
  *
  * @param ws the user-supplied websocket (separate or combined pipes)
  * @param sentClose flag recording whether this side has already sent a Close frame
  * @param deadSignal setting it to `true` interrupts the websocket stream started in `stageStartup`
  * @param writeSemaphore a permit is held around every `channelWrite`; the companion `apply`
  *                       creates it with a single permit, so writes do not overlap
  * @param dispatcher runs effects from blaze's side-effecting lifecycle callbacks
  */
private[http4s] class Http4sWSStage[F[_]](
    ws: WebSocket[F],
    sentClose: AtomicBoolean,
    deadSignal: SignallingRef[F, Boolean],
    writeSemaphore: Semaphore[F],
    dispatcher: Dispatcher[F],
)(implicit F: Async[F])
    extends TailStage[WebSocketFrame] {

  def name: String = "Http4s WebSocket Stage"

  // ////////////////////// Source and Sink generators ////////////////////////

  /** Yields `true` once this side has claimed the right to send its Close frame. */
  val isClosed: F[Boolean] = F.delay(sentClose.get())

  /** Atomically claims the right to send the Close frame; yields `true` only for the first caller. */
  val setClosed: F[Boolean] = F.delay(sentClose.compareAndSet(false, true))

  /** Writes one frame produced by the user's send side.
    *
    * A Close frame is written only if no Close has been claimed yet, so this stage never
    * writes more than one Close frame through this path.
    */
  def evalFrame(frame: WebSocketFrame): F[Unit] = frame match {
    case c: Close => setClosed.ifM(writeFrame(c, directec), F.unit)
    case _ => writeFrame(frame, directec)
  }

  /** Sink for outgoing frames: once a Close has been sent, every further frame is dropped. */
  def snkFun(frame: WebSocketFrame): F[Unit] = isClosed.ifM(F.unit, evalFrame(frame))

  /** Writes a single frame to the pipeline while holding a `writeSemaphore` permit.
    *
    * @param ec execution context on which the write's completion callback is run
    */
  private[this] def writeFrame(frame: WebSocketFrame, ec: ExecutionContext): F[Unit] =
    writeSemaphore.permit.use { _ =>
      F.async_[Unit] { cb =>
        channelWrite(frame).onComplete {
          case Success(res) => cb(Right(res))
          case Failure(t) => cb(Left(t))
        }(ec)
      }
    }

  /** Reads one frame from the pipeline; the completion callback runs on `trampoline`. */
  private[this] def readFrameTrampoline: F[WebSocketFrame] =
    F.async_[WebSocketFrame] { cb =>
      channelRead().onComplete {
        // named `frame` rather than `ws` to avoid shadowing the constructor parameter
        case Success(frame) => cb(Right(frame))
        case Failure(exception) => cb(Left(exception))
      }(trampoline)
    }

  /** Read from our websocket.
    *
    * To stay faithful to the RFC, the following must hold:
    *
    *  - If we receive a ping frame, we MUST reply with a pong frame. The ping is then still
    *    passed on to the caller.
    *  - Pong frames (and all other non-Close frames) need no reply and are passed through as-is.
    *  - If we receive a close frame, it means one of two things:
    *    - We sent a close frame earlier, so we do not reply with one. We just end the stream.
    *    - We are the first to see a close, so we atomically claim the `sentClose` flag (to
    *      avoid sending two close frames) and echo the received close back.
    *    In both cases the termination signal is set afterwards.
    *  - A frame with a reserved opcode is logged and treated as if `Close(1003)` had been
    *    received. An unknown opcode or a protocol violation is treated as `Close(1002)`.
    *
    * @return A websocket frame, or a possible IO error.
    */
  private[this] def handleRead(): F[WebSocketFrame] = {
    // Echo `c` back unless a Close was already claimed, then signal termination.
    def maybeSendClose(c: Close): F[Unit] =
      setClosed.ifM(writeFrame(c, trampoline), F.unit) >> deadSignal.set(true)

    readFrameTrampoline
      .recoverWith {
        case t: ReservedOpcodeException =>
          F.delay(logger.error(t)("Decoded a websocket frame with a reserved opcode")) *>
            F.fromEither(Close(1003))
        case t: UnknownOpcodeException =>
          F.delay(logger.error(t)("Decoded a websocket frame with an unknown opcode")) *>
            F.fromEither(Close(1002))
        case t: ProtocolException =>
          F.delay(logger.error(t)("Websocket protocol violation")) *> F.fromEither(Close(1002))
      }
      .flatMap {
        case c: Close =>
          // If we sent a close signal, we don't need to reply with one
          isClosed.ifM(deadSignal.set(true), maybeSendClose(c)).as(c)
        case p @ Ping(d) =>
          // Reply to ping frame immediately
          writeFrame(Pong(d), trampoline).as(p)
        case rest =>
          F.pure(rest)
      }
  }

  /** The websocket input stream.
    *
    * Note: On receiving a close, we MUST send a close back, as stated in section
    * 5.5.1 of the websocket spec: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1
    * That reply is performed by `handleRead`.
    *
    * @return frames read one at a time via `handleRead`. The stream has no natural end. It
    *         stops on a read error, or when the consuming stream is interrupted (in
    *         `stageStartup` that happens via `deadSignal`).
    */
  def inputstream: Stream[F, WebSocketFrame] =
    Stream.repeatEval(handleRead())

  // ////////////////////// Startup and Shutdown ////////////////////////

  /** Starts the websocket stream in the background via `dispatcher`. */
  override protected def stageStartup(): Unit = {
    super.stageStartup()

    // Effect that closes the underlying blaze pipeline
    val sendClose: F[Unit] = F.delay(closePipeline(None))

    val receiveSent: Stream[F, WebSocketFrame] =
      ws match {
        case WebSocketSeparatePipe(send, receive, _) =>
          // `receive` runs in the background. Per fs2 `concurrently` semantics, the combined
          // stream ends when `send` ends, and an early end of `receive` does not end it.
          send.concurrently(receive(inputstream))
        case WebSocketCombinedPipe(receiveSend, _) =>
          receiveSend(inputstream)
      }

    val wsStream =
      receiveSent
        .evalMap(snkFun)
        .drain
        .interruptWhen(deadSignal)
        // Errors raised by the user's `onClose` are discarded so they cannot surface as
        // stream failures.
        .onFinalizeWeak(
          ws.onClose.attempt.void
        )
        // Close the pipeline however the stream finishes: completion, error, or interruption.
        .onFinalizeWeak(sendClose)
        .compile
        .drain

    val result = F.handleErrorWith(wsStream) {
      case EOF =>
        F.delay(stageShutdown())
      case t =>
        F.delay(logger.error(t)("Error closing Web Socket"))
    }

    dispatcher.unsafeRunAndForget(result)
  }

  /** Signals the websocket stream to stop, then performs blaze's own stage shutdown. */
  override protected def stageShutdown(): Unit = {
    val fa = F.handleError(deadSignal.set(true)) { t =>
      logger.error(t)("Error setting dead signal")
    }
    // `unsafeRunAndForget` throws if the dispatcher has already been shut down. Run the parent
    // shutdown regardless; any such exception still propagates to the caller.
    try dispatcher.unsafeRunAndForget(fa)
    finally super.stageShutdown()
  }
}
/** Constructors for [[Http4sWSStage]]. */
object Http4sWSStage {

  /** Builds a pipeline segment in which a `SerializingStage` is the trunk and `stage` is the
    * leaf that caps it.
    */
  def bufferingSegment[F[_]](stage: Http4sWSStage[F]): LeafBuilder[WebSocketFrame] = {
    val serializer = new SerializingStage[WebSocketFrame]
    TrunkBuilder(serializer).cap(stage)
  }

  /** Allocates a single-permit write semaphore and wraps everything in a new stage.
    *
    * @return an effect that, when run, creates the semaphore and the stage
    */
  def apply[F[_]](
      ws: WebSocket[F],
      sentClose: AtomicBoolean,
      deadSignal: SignallingRef[F, Boolean],
      dispatcher: Dispatcher[F],
  )(implicit F: Async[F]): F[Http4sWSStage[F]] =
    for (writeLock <- Semaphore[F](1L))
      yield new Http4sWSStage(ws, sentClose, deadSignal, writeLock, dispatcher)
}