diff --git a/Package.swift b/Package.swift index 1d0257e..23d3248 100644 --- a/Package.swift +++ b/Package.swift @@ -1,24 +1,19 @@ -// swift-tools-version:4.2 +// swift-tools-version:5.1 import PackageDescription let package = Package( name: "Gatekeeper", + platforms: [ + .macOS(.v10_14) + ], products: [ - .library( - name: "Gatekeeper", - targets: ["Gatekeeper"]), + .library(name: "Gatekeeper", targets: ["Gatekeeper"]), ], dependencies: [ - .package(url: "https://github.com/vapor/vapor.git", from: "3.0.0"), + .package(url: "https://github.com/vapor/vapor.git", from: "4.0.0-beta") ], targets: [ - .target( - name: "Gatekeeper", - dependencies: [ - "Vapor" - ]), - .testTarget( - name: "GatekeeperTests", - dependencies: ["Gatekeeper"]), + .target(name: "Gatekeeper", dependencies: ["Vapor"]), + .testTarget(name: "GatekeeperTests", dependencies: ["Gatekeeper"]), ] ) diff --git a/README.md b/README.md index f7f876e..fad99f4 100644 --- a/README.md +++ b/README.md @@ -18,15 +18,15 @@ It works by adding the clients IP address to the cache and count how many reques Update your `Package.swift` dependencies: ```swift -.package(url: "https://github.com/nodes-vapor/gatekeeper.git", from: "3.0.0"), + .package(url: "https://github.com/nodes-vapor/gatekeeper.git", from: "4.0.0"), ``` as well as to your target (e.g. "App"): ```swift -targets: [ - .target(name: "App", dependencies: [..., "Gatekeeper", ...]), -// ... + targets: [ + .target(name: "App", dependencies: [..., "Gatekeeper", ...]), + // ... ] ``` @@ -34,43 +34,25 @@ targets: [ ### Configuration -in configure.swift: -```swift -import Gatekeeper - -// [...] - -// Register providers first -try services.register( - GatekeeperProvider( - config: GatekeeperConfig(maxRequests: 10, per: .second), - cacheFactory: { container -> KeyedCache in - return try container.make() - } - ) -) -``` +**Cache** -### Add to routes +You must implement the protocol GateKeeperCache and register it with the application before using GateKeeper -You can add the `GatekeeperMiddleware` to specific routes or to all. - -**Specific routes** -in routes.swift: ```swift -let protectedRoutes = router.grouped(GatekeeperMiddleware.self) -protectedRoutes.get("protected/hello") { req in - return "Protected Hello, World!" -} + app.register(GateKeeperCache.self) { (app: Application) -> GateKeeperCache in + return MyGateKeeperCache() + } ``` + **For all requests** in configure.swift: ```swift -// Register middleware -var middlewares = MiddlewareConfig() // Create _empty_ middleware config -middlewares.use(GatekeeperMiddleware.self) -services.register(middlewares) + +// Register providers first + let gateKeeperConfig = GatekeeperConfig(maxRequests: 10, per: .second) + app.provider(GatekeeperProvider(config: gateKeeperConfig)( + ``` ## Credits 🏆 diff --git a/Sources/Gatekeeper/GateKeeperCache.swift b/Sources/Gatekeeper/GateKeeperCache.swift new file mode 100644 index 0000000..163484c --- /dev/null +++ b/Sources/Gatekeeper/GateKeeperCache.swift @@ -0,0 +1,20 @@ +// +// File.swift +// +// +// Created by Tommy Hinrichsen on 02/12/2019. +// + +import Foundation +import Vapor + +public protocol GateKeeperCache { + + /// Gets key as a decodable type. + func get(_ key: String, as type: D.Type) -> EventLoopFuture where D: Decodable + + /// Sets key to an encodable item. + func set(_ key: String, to entity: E) -> EventLoopFuture where E: Encodable +} + + diff --git a/Sources/Gatekeeper/GateKeeperError.swift b/Sources/Gatekeeper/GateKeeperError.swift new file mode 100644 index 0000000..ed45f7c --- /dev/null +++ b/Sources/Gatekeeper/GateKeeperError.swift @@ -0,0 +1,13 @@ +// +// File.swift +// +// +// Created by Tommy Hinrichsen on 04/12/2019. +// + +import Foundation + +enum GateKeeperError: Swift.Error { + case forbidden + case tooManyRequests +} diff --git a/Sources/Gatekeeper/Gatekeeper.swift b/Sources/Gatekeeper/Gatekeeper.swift index 202d0e5..27548bc 100644 --- a/Sources/Gatekeeper/Gatekeeper.swift +++ b/Sources/Gatekeeper/Gatekeeper.swift @@ -1,45 +1,32 @@ import Vapor -public struct Gatekeeper: Service { +public struct Gatekeeper { internal let config: GatekeeperConfig - internal let cacheFactory: ((Container) throws -> KeyedCache) + internal let cache: GateKeeperCache - public init( - config: GatekeeperConfig, - cacheFactory: @escaping ((Container) throws -> KeyedCache) = { container in try container.make() } - ) { + public init(config: GatekeeperConfig, cache: GateKeeperCache) { self.config = config - self.cacheFactory = cacheFactory + self.cache = cache } - public func accessEndpoint( - on request: Request - ) throws -> Future { + internal func accessEndpoint(on request: Request) throws -> EventLoopFuture { - guard let peerHostName = request.http.remotePeer.hostname else { - throw Abort( - .forbidden, - reason: "Unable to verify peer" - ) + guard let ipAddress = request.remoteAddress?.ipAddress else { + return request.eventLoop.makeFailedFuture(GateKeeperError.forbidden) } - let peerCacheKey = cacheKey(for: peerHostName) - let cache = try cacheFactory(request) + let peerCacheKey = self.cacheKey(for: ipAddress) - return cache.get(peerCacheKey, as: Entry.self) - .map(to: Entry.self) { entry in + return self.cache.get(peerCacheKey, as: Entry.self) + .map({ entry -> Gatekeeper.Entry in if let entry = entry { return entry } else { - return Entry( - peerHostname: peerHostName, - createdAt: Date(), - requestsLeft: self.config.limit - ) + return Entry(ipAddress: ipAddress, createdAt: Date(), requestsLeft: self.config.limit) } - } - .map(to: Entry.self) { entry in + }) + .map({ entry -> Gatekeeper.Entry in let now = Date() var mutableEntry = entry @@ -49,28 +36,22 @@ public struct Gatekeeper: Service { } mutableEntry.requestsLeft -= 1 return mutableEntry - }.then { entry in - return cache.set(peerCacheKey, to: entry).transform(to: entry) - }.map(to: Entry.self) { entry in - - if entry.requestsLeft < 0 { - throw Abort( - .tooManyRequests, - reason: "Slow down. You sent too many requests." - ) - } + }) + .flatMap( { entry -> EventLoopFuture in + return self.cache.set(peerCacheKey, to: entry).map { entry } + }) + .flatMapThrowing({ entry in + if entry.requestsLeft < 0 { throw GateKeeperError.tooManyRequests } return entry - } + }) } - private func cacheKey(for hostname: String) -> String { - return "gatekeeper_\(hostname)" - } + private func cacheKey(for hostname: String) -> String { return "gatekeeper_\(hostname)" } } extension Gatekeeper { public struct Entry: Codable { - let peerHostname: String + let ipAddress: String var createdAt: Date var requestsLeft: Int } diff --git a/Sources/Gatekeeper/GatekeeperConfig.swift b/Sources/Gatekeeper/GatekeeperConfig.swift index 79597d6..2468b2e 100644 --- a/Sources/Gatekeeper/GatekeeperConfig.swift +++ b/Sources/Gatekeeper/GatekeeperConfig.swift @@ -1,6 +1,6 @@ import Vapor -public struct GatekeeperConfig: Service { +public struct GatekeeperConfig { public enum Interval { case second diff --git a/Sources/Gatekeeper/GatekeeperMiddleware.swift b/Sources/Gatekeeper/GatekeeperMiddleware.swift index 3f91575..bb0280a 100644 --- a/Sources/Gatekeeper/GatekeeperMiddleware.swift +++ b/Sources/Gatekeeper/GatekeeperMiddleware.swift @@ -5,19 +5,12 @@ public struct GatekeeperMiddleware { } extension GatekeeperMiddleware: Middleware { - public func respond( - to request: Request, - chainingTo next: Responder - ) throws -> Future { - - return try gatekeeper.accessEndpoint(on: request).flatMap { _ in - return try next.respond(to: request) + public func respond(to request: Request, chainingTo next: Responder) -> EventLoopFuture { + do { + let response = try gatekeeper.accessEndpoint(on: request).flatMap { _ in return next.respond(to: request) } + return response + } catch { + return request.eventLoop.makeFailedFuture(error) } } } - -extension GatekeeperMiddleware: ServiceType { - public static func makeService(for container: Container) throws -> GatekeeperMiddleware { - return try .init(gatekeeper: container.make()) - } -} diff --git a/Sources/Gatekeeper/GatekeeperProvider.swift b/Sources/Gatekeeper/GatekeeperProvider.swift index 551f7d6..1bb9c2a 100644 --- a/Sources/Gatekeeper/GatekeeperProvider.swift +++ b/Sources/Gatekeeper/GatekeeperProvider.swift @@ -3,31 +3,23 @@ import Vapor public final class GatekeeperProvider { internal let config: GatekeeperConfig - internal let cacheFactory: ((Container) throws -> KeyedCache) - public init( - config: GatekeeperConfig, - cacheFactory: @escaping ((Container) throws -> KeyedCache) = { container in try container.make() } - ) { + public init(config: GatekeeperConfig = GatekeeperConfig(maxRequests: 10, per: .second)) { self.config = config - self.cacheFactory = cacheFactory } } extension GatekeeperProvider: Provider { - public func register(_ services: inout Services) throws { - services.register(config) - services.register( - Gatekeeper( - config: config, - cacheFactory: cacheFactory - ), - as: Gatekeeper.self - ) - services.register(GatekeeperMiddleware.self) - } - public func didBoot(_ container: Container) throws -> EventLoopFuture { - return .done(on: container) + public func register(_ app: Application) { + + app.register(extension: MiddlewareConfiguration.self) { (configuration: inout MiddlewareConfiguration, app: Application) in + + let cache: GateKeeperCache = app.make() + let gateKeeper = Gatekeeper(config: self.config, cache: cache) + let middleware = GatekeeperMiddleware(gatekeeper: gateKeeper) + configuration.use(middleware) + } } + } diff --git a/Tests/GatekeeperTests/GatekeeperTests.swift b/Tests/GatekeeperTests/GatekeeperTests.swift index 60b1bf3..7316352 100644 --- a/Tests/GatekeeperTests/GatekeeperTests.swift +++ b/Tests/GatekeeperTests/GatekeeperTests.swift @@ -11,7 +11,7 @@ class GatekeeperTests: XCTestCase { peerName: "::1" ) - let gatekeeperMiddleware = try request.make(GatekeeperMiddleware.self) + let gatekeeperMiddleware = request.make(GatekeeperMiddleware.self) for i in 1...11 { do { @@ -39,7 +39,7 @@ class GatekeeperTests: XCTestCase { peerName: nil ) - let gatekeeperMiddleware = try request.make(GatekeeperMiddleware.self) + let gatekeeperMiddleware = request.make(GatekeeperMiddleware.self) do { _ = try gatekeeperMiddleware.respond(to: request, chainingTo: TestResponder()).wait() @@ -64,7 +64,7 @@ class GatekeeperTests: XCTestCase { peerName: "192.168.1.2" ) - let gatekeeperMiddleware = try request.make(GatekeeperMiddleware.self) + let gatekeeperMiddleware = request.make(GatekeeperMiddleware.self) for _ in 0..<50 { do { @@ -75,7 +75,7 @@ class GatekeeperTests: XCTestCase { } } - let cache = try request.make(KeyedCache.self) + let cache = request.make(GateKeeperCache.self) var entry = try cache.get("gatekeeper_192.168.1.2", as: Gatekeeper.Entry.self).wait() XCTAssertEqual(entry!.requestsLeft, 50) @@ -90,34 +90,33 @@ class GatekeeperTests: XCTestCase { XCTAssertEqual(entry!.requestsLeft, 99, "Requests left should've reset") } - func testGateKeeperWithCacheFactory() throws { - - let request = try Request.test( - gatekeeperConfig: GatekeeperConfig(maxRequests: 10, per: .minute), - peerName: "::1", - cacheFactory: { try $0.make(KeyedCache.self) } - ) - - let gatekeeperMiddleware = try request.make(GatekeeperMiddleware.self) - - for i in 1...11 { - do { - _ = try gatekeeperMiddleware.respond(to: request, chainingTo: TestResponder()).wait() - XCTAssertTrue(i <= 10, "ran \(i) times.") - } catch let error as Abort { - switch error.status { - case .tooManyRequests: - //success - XCTAssertEqual(i, 11, "Should've failed after the 11th attempt.") - break - default: - XCTFail("Expected too many request: \(error)") - } - } catch { - XCTFail("Caught wrong error: \(error)") - } - } - } +// func testGateKeeperWithCacheFactory() throws { +// +// let request = try Request.test( +// gatekeeperConfig: GatekeeperConfig(maxRequests: 10, per: .minute), +// peerName: "::1", +// ) +// +// let gatekeeperMiddleware = try request.make(GatekeeperMiddleware.self) +// +// for i in 1...11 { +// do { +// _ = try gatekeeperMiddleware.respond(to: request, chainingTo: TestResponder()).wait() +// XCTAssertTrue(i <= 10, "ran \(i) times.") +// } catch let error as Abort { +// switch error.status { +// case .tooManyRequests: +// //success +// XCTAssertEqual(i, 11, "Should've failed after the 11th attempt.") +// break +// default: +// XCTFail("Expected too many request: \(error)") +// } +// } catch { +// XCTFail("Caught wrong error: \(error)") +// } +// } +// } func testRefreshIntervalValues() { let expected: [(GatekeeperConfig.Interval, Double)] = [ diff --git a/Tests/GatekeeperTests/Utilities/GateKeeperTestError.swift b/Tests/GatekeeperTests/Utilities/GateKeeperTestError.swift new file mode 100644 index 0000000..4f17a91 --- /dev/null +++ b/Tests/GatekeeperTests/Utilities/GateKeeperTestError.swift @@ -0,0 +1,12 @@ +// +// GateKeeperTestError.swift +// GatekeeperTests +// +// Created by Tommy Hinrichsen on 04/12/2019. +// + +import Foundation + +enum GateKeeperTestError: Swift.Error { + case notFound +} diff --git a/Tests/GatekeeperTests/Utilities/MemoryCache.swift b/Tests/GatekeeperTests/Utilities/MemoryCache.swift new file mode 100644 index 0000000..c92f7a0 --- /dev/null +++ b/Tests/GatekeeperTests/Utilities/MemoryCache.swift @@ -0,0 +1,53 @@ +// +// MemoryCache.swift +// GatekeeperTests +// +// Created by Tommy Hinrichsen on 04/12/2019. +// + +import Foundation +import Gatekeeper +import Vapor + +final class GateKeeperCacheMemoryCache: GateKeeperCache { + + var storage: [String: Any] + var lock: Lock + + init() { + self.storage = [:] + self.lock = .init() + + } + + public func get(_ key: String, as type: D.Type) -> EventLoopFuture where D : Decodable { + + let eventLoop = MultiThreadedEventLoopGroup(numberOfThreads: 1).next() + if let value: D = self.get(key) { + let future = eventLoop.makeSucceededFuture(value) + return future.map{ $0 } + } else { + let future: EventLoopFuture = eventLoop.makeFailedFuture(GateKeeperTestError.notFound) + return future + } + } + + public func set(_ key: String, to entity: E) -> EventLoopFuture where E : Encodable { + let eventLoop = MultiThreadedEventLoopGroup(numberOfThreads: 1).next() + let future = eventLoop.makeSucceededFuture(Void()) + return future + } + + func get(_ key: String) -> E? where E : Decodable { + self.lock.lock() + defer { self.lock.unlock() } + return self.storage[key] as? E + } + + func set(_ key: String, to value: E?) where E : Encodable { + self.lock.lock() + defer { self.lock.unlock() } + self.storage[key] = value + } +} + diff --git a/Tests/GatekeeperTests/Utilities/Request+test.swift b/Tests/GatekeeperTests/Utilities/Request+test.swift index 04d8088..3294170 100644 --- a/Tests/GatekeeperTests/Utilities/Request+test.swift +++ b/Tests/GatekeeperTests/Utilities/Request+test.swift @@ -1,56 +1,29 @@ import Gatekeeper -import HTTP import Vapor extension Request { static func test( - gatekeeperConfig: GatekeeperConfig, - url: URLRepresentable = "http://localhost:8080/test", - peerName: String? = "::1", - cacheFactory: ((Container) throws -> KeyedCache)? = nil + url: URI = URI(string: "http://localhost:8080/test"), + gatekeeperConfig: GatekeeperConfig = GatekeeperConfig(maxRequests: 10, per: .second), + peerName: String? = "::1" ) throws -> Request { - let config = Config() - var services = Services() - services.register(KeyedCache.self) { container in - return MemoryKeyedCache() + let app = Application(environment: .development) + app.register(GateKeeperCache.self) { (app: Application) in + return GateKeeperCacheMemoryCache() } + app.provider(GatekeeperProvider(config: gatekeeperConfig)) - if let cacheFactory = cacheFactory { - try services.register( - GatekeeperProvider( - config: gatekeeperConfig, - cacheFactory: cacheFactory - ) - ) - } else { - try services.register( - GatekeeperProvider( - config: gatekeeperConfig - ) - ) - } - - services.register(GatekeeperMiddleware.self) - - let sharedThreadPool = BlockingIOThreadPool(numberOfThreads: 2) - sharedThreadPool.start() - services.register(sharedThreadPool) - - let app = try Application(config: config, environment: .testing, services: services) + let eventLoop = MultiThreadedEventLoopGroup(numberOfThreads: 1) //??? let request = Request( - http: HTTPRequest( - method: .GET, - url: url - ), - using: app + application: app, + url: url, + on: eventLoop.next() ) - var http = request.http if let peerName = peerName { - http.headers.add(name: .init("X-Forwarded-For"), value: peerName) + request.headers.add(name: "X-Forwarded-For", value: peerName) } - request.http = http return request } diff --git a/Tests/GatekeeperTests/Utilities/TestResponder.swift b/Tests/GatekeeperTests/Utilities/TestResponder.swift index 1c53793..d6a3ccd 100644 --- a/Tests/GatekeeperTests/Utilities/TestResponder.swift +++ b/Tests/GatekeeperTests/Utilities/TestResponder.swift @@ -1,7 +1,9 @@ import Vapor public struct TestResponder: Responder { - public func respond(to req: Request) throws -> EventLoopFuture { - return req.future(req.response()) + + public func respond(to req: Request) -> EventLoopFuture { + let response = Response() + return req.eventLoop.future(response) } }