forked from vapor/vapor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
HTTPServerUpgradeHandler.swift
118 lines (100 loc) · 4.43 KB
/
HTTPServerUpgradeHandler.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import NIO
import NIOHTTP1
import NIOWebSocket
/// Watches HTTP requests for a `Connection: upgrade` token and, when the matching
/// response is `101 Switching Protocols` carrying a WebSocket upgrader, replaces
/// this handler and `httpHandlers` in the pipeline with WebSocket handlers.
///
/// While an upgrade decision is pending, an `UpgradeBufferHandler` is inserted in
/// front of `httpRequestDecoder` so bytes that follow the upgrade request are held
/// back rather than parsed as HTTP. When the response is written, the buffer is
/// removed and its bytes flow into whatever handlers are then in the pipeline.
final class HTTPServerUpgradeHandler: ChannelDuplexHandler, RemovableChannelHandler {
    typealias InboundIn = Request
    typealias OutboundIn = Response
    typealias OutboundOut = Response

    /// Where the connection is in the (possible) upgrade process.
    private enum UpgradeState {
        /// No upgrade has been requested.
        case ready
        /// An upgrade request was read; its bytes-in-flight are held by the buffer handler.
        case pending(Request, UpgradeBufferHandler)
        /// A 101 response with an upgrader was written; the pipeline is being switched over.
        case upgraded
    }

    private var upgradeState: UpgradeState
    /// The HTTP decoder the buffer handler is inserted in front of.
    let httpRequestDecoder: ByteToMessageHandler<HTTPRequestDecoder>
    /// HTTP handlers removed from the pipeline (together with `self`) on upgrade.
    let httpHandlers: [RemovableChannelHandler]

    init(
        httpRequestDecoder: ByteToMessageHandler<HTTPRequestDecoder>,
        httpHandlers: [RemovableChannelHandler]
    ) {
        self.upgradeState = .ready
        self.httpRequestDecoder = httpRequestDecoder
        self.httpHandlers = httpHandlers
    }

    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let req = self.unwrapInboundIn(data)

        // Header tokens are compared case-insensitively.
        let connectionHeaders = Set(req.headers[canonicalForm: "connection"].map { $0.lowercased() })
        if connectionHeaders.contains("upgrade") {
            // Hold back any bytes that arrive after this request until the
            // response tells us whether the protocol is switching.
            let buffer = UpgradeBufferHandler()
            _ = context.channel.pipeline.addHandler(buffer, position: .before(self.httpRequestDecoder))
            self.upgradeState = .pending(req, buffer)
        }

        context.fireChannelRead(data)
    }

    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
        let res = self.unwrapOutboundIn(data)

        switch self.upgradeState {
        case .pending(let req, let buffer):
            if res.status == .switchingProtocols, let upgrader = res.upgrader {
                self.upgradeState = .upgraded
                switch upgrader {
                case .webSocket(let maxFrameSize, let onUpgrade):
                    // Fall back to 1 << 14 (16384 bytes) when no max frame size was given.
                    let maxFrameSize = maxFrameSize ?? 1 << 14
                    let webSocketUpgrader = NIOWebSocketServerUpgrader(
                        maxFrameSize: maxFrameSize,
                        automaticErrorHandling: false,
                        shouldUpgrade: { channel, _ in
                            return channel.eventLoop.makeSucceededFuture([:])
                        },
                        upgradePipelineHandler: { channel, _ in
                            return WebSocket.server(on: channel, onUpgrade: onUpgrade)
                        }
                    )
                    var head = HTTPRequestHead(
                        version: req.version,
                        method: req.method,
                        uri: req.url.string
                    )
                    head.headers = req.headers

                    // Captured up front so failure handling does not depend on this
                    // handler's context after it has removed itself from the pipeline.
                    let channel = context.channel

                    let upgradeResult = webSocketUpgrader.buildUpgradeResponse(
                        channel: channel,
                        upgradeRequest: head,
                        initialResponseHeaders: [:]
                    ).map { headers in
                        res.headers = headers
                        context.write(self.wrapOutboundOut(res), promise: promise)
                    }.flatMap {
                        let handlers: [RemovableChannelHandler] = [self] + self.httpHandlers
                        return .andAllComplete(handlers.map { handler in
                            return context.pipeline.removeHandler(handler)
                        }, on: context.eventLoop)
                    }.flatMap {
                        return webSocketUpgrader.upgrade(context: context, upgradeRequest: head)
                    }.flatMap {
                        // Replays any held bytes into the newly installed handlers.
                        return context.pipeline.removeHandler(buffer)
                    }

                    upgradeResult.cascadeFailure(to: promise)
                    // Once the chain has failed, the connection cannot be used:
                    // either no response was sent, or a 101 was sent but the
                    // pipeline was only partly switched, and the buffer handler may
                    // still be swallowing inbound bytes. Close it rather than hang.
                    upgradeResult.whenFailure { _ in
                        channel.close(promise: nil)
                    }
                }
            } else {
                self.upgradeState = .ready
                // Write the response BEFORE removing the buffer. Removing it
                // replays the held bytes into the HTTP decoder, possibly
                // synchronously when on the event loop. A request decoded from
                // those bytes must not be answered ahead of `res`.
                context.write(self.wrapOutboundOut(res), promise: promise)
                context.channel.pipeline.removeHandler(buffer, promise: nil)
            }
        case .ready, .upgraded:
            context.write(self.wrapOutboundOut(res), promise: promise)
        }
    }
}
/// Inbound handler that holds raw `ByteBuffer`s while an upgrade decision is
/// pending and re-emits them, in arrival order, when it is removed from the pipeline.
private final class UpgradeBufferHandler: ChannelInboundHandler, RemovableChannelHandler {
    typealias InboundIn = ByteBuffer

    /// Bytes received since this handler was added, oldest first.
    var buffer: [ByteBuffer]

    init() {
        self.buffer = []
    }

    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let data = self.unwrapInboundIn(data)
        self.buffer.append(data)
    }

    func handlerRemoved(context: ChannelHandlerContext) {
        guard !self.buffer.isEmpty else { return }

        // Take ownership of the held bytes and release our storage before
        // re-emitting them to the handler(s) that now follow us.
        let pending = self.buffer
        self.buffer.removeAll()
        for data in pending {
            context.fireChannelRead(NIOAny(data))
        }
        // This handler does not intercept channelReadComplete, so any such event
        // seen while buffering was forwarded before this data reached downstream.
        // End the replayed batch with a readComplete of its own.
        context.fireChannelReadComplete()
    }
}