Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions Sources/GRPC/ClientConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ public class ClientConnection {
// Configure the channel with the correct handlers and connect to our target.
let configuredChannel = ClientConnection.initializeChannel(
channel,
tls: configuration.tls,
tls: configuration.tls?.configuration,
serverHostname: configuration.tls?.hostnameOverride ?? configuration.target.host,
errorDelegate: configuration.errorDelegate
).flatMap {
channel.connect(to: socketAddress)
Expand Down Expand Up @@ -325,21 +326,34 @@ extension ClientConnection {
///
/// - Parameter configuration: The configuration to prepare the bootstrap with.
/// - Parameter group: The `EventLoopGroup` to use for the bootstrap.
/// - Parameter timeout: The connection timeout in seconds.
/// - Parameter timeout: The connection timeout in seconds.
private class func makeBootstrap(
configuration: Configuration,
group: EventLoopGroup,
timeout: TimeInterval?,
logger: Logger
) -> ClientBootstrapProtocol {
// Provide a server hostname if we're using TLS. Prefer the override.
let serverHostname: String? = configuration.tls.map {
if let hostnameOverride = $0.hostnameOverride {
logger.debug("using hostname override for TLS", metadata: ["hostname-override": "\(hostnameOverride)"])
return hostnameOverride
} else {
let host = configuration.target.host
logger.debug("using host connection target for TLS", metadata: ["hostname-override": "\(host)"])
return host
}
}

let bootstrap = PlatformSupport.makeClientBootstrap(group: group)
// Enable SO_REUSEADDR and TCP_NODELAY.
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.channelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
.channelInitializer { channel in
initializeChannel(
channel,
tls: configuration.tls,
tls: configuration.tls?.configuration,
serverHostname: serverHostname,
errorDelegate: configuration.errorDelegate
)
}
Expand All @@ -356,14 +370,16 @@ extension ClientConnection {
///
/// - Parameter channel: The channel to initialize.
/// - Parameter tls: The optional TLS configuration for the channel.
/// - Parameter serverHostname: The hostname of the server to use for TLS.
/// - Parameter errorDelegate: Optional client error delegate.
private class func initializeChannel(
_ channel: Channel,
tls: Configuration.TLS?,
tls: TLSConfiguration?,
serverHostname: String?,
errorDelegate: ClientErrorDelegate?
) -> EventLoopFuture<Void> {
let tlsConfigured = tls.map { tlsConfiguration in
channel.configureTLS(tlsConfiguration, errorDelegate: errorDelegate)
let tlsConfigured = tls.map {
channel.configureTLS($0, serverHostname: serverHostname, errorDelegate: errorDelegate)
}

return (tlsConfigured ?? channel.eventLoop.makeSucceededFuture(())).flatMap {
Expand Down Expand Up @@ -484,15 +500,18 @@ fileprivate extension Channel {
/// the `TLSVerificationHandler` which verifies that a successful handshake was completed.
///
/// - Parameter configuration: The configuration to configure the channel with.
/// - Parameter serverHostname: The server hostname to use if the hostname should be verified.
/// - Parameter errorDelegate: The error delegate to use for the TLS verification handler.
func configureTLS(
_ configuration: ClientConnection.Configuration.TLS,
_ configuration: TLSConfiguration,
serverHostname: String?,
errorDelegate: ClientErrorDelegate?
) -> EventLoopFuture<Void> {
do {
let sslClientHandler = try NIOSSLClientHandler(
context: try NIOSSLContext(configuration: configuration.configuration),
serverHostname: configuration.hostnameOverride)
context: try NIOSSLContext(configuration: configuration),
serverHostname: serverHostname
)

return self.pipeline.addHandlers(sslClientHandler, TLSVerificationHandler())
} catch {
Expand Down
254 changes: 156 additions & 98 deletions Sources/GRPCSampleData/GRPCSwiftCertificate.swift

Large diffs are not rendered by default.

8 changes: 3 additions & 5 deletions Tests/GRPCTests/BasicEchoTestCase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,14 @@ extension TransportSecurity {
return nil

case .anonymousClient:
return .init(
trustRoots: .certificates([self.caCert]),
certificateVerification: .noHostnameVerification)
return .init(trustRoots: .certificates([self.caCert]))

case .mutualAuthentication:
return .init(
certificateChain: [.certificate(self.clientCert)],
privateKey: .privateKey(SamplePrivateKey.client),
trustRoots: .certificates([self.caCert]),
certificateVerification: .noHostnameVerification)
trustRoots: .certificates([self.caCert])
)
}
}
}
Expand Down
101 changes: 101 additions & 0 deletions Tests/GRPCTests/ClientTLSTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import Foundation
import GRPC
import GRPCSampleData
import NIO
import NIOSSL
import XCTest

class ClientTLSHostnameOverrideTests: GRPCTestCase {
var eventLoopGroup: EventLoopGroup!
var server: Server!
var connection: ClientConnection!

override func setUp() {
super.setUp()
self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
}

override func tearDown() {
super.tearDown()
XCTAssertNoThrow(try self.server.close().wait())
XCTAssertNoThrow(try connection.close().wait())
XCTAssertNoThrow(try self.eventLoopGroup.syncShutdownGracefully())
}

func makeEchoServer(tls: Server.Configuration.TLS) throws -> Server {
let configuration: Server.Configuration = .init(
target: .hostAndPort("localhost", 0),
eventLoopGroup: self.eventLoopGroup,
serviceProviders: [EchoProvider()],
tls: tls
)

return try Server.start(configuration: configuration).wait()
}

func makeConnection(port: Int, tls: ClientConnection.Configuration.TLS) -> ClientConnection {
let configuration: ClientConnection.Configuration = .init(
target: .hostAndPort("localhost", port),
eventLoopGroup: self.eventLoopGroup,
tls: tls
)

return ClientConnection(configuration: configuration)
}

func doTestUnary() throws {
let client = Echo_EchoServiceClient(connection: self.connection)
let get = client.get(.with { $0.text = "foo" })

let response = try get.response.wait()
XCTAssertEqual(response.text, "Swift echo get: foo")

let status = try get.status.wait()
XCTAssertEqual(status.code, .ok)
}

func testTLSWithHostnameOverride() throws {
// Run a server presenting a certificate for example.com on localhost.
let serverTLS: Server.Configuration.TLS = .init(
certificateChain: [.certificate(SampleCertificate.exampleServer.certificate)],
privateKey: .privateKey(SamplePrivateKey.exampleServer),
trustRoots: .certificates([SampleCertificate.ca.certificate])
)

self.server = try makeEchoServer(tls: serverTLS)
guard let port = self.server.channel.localAddress?.port else {
XCTFail("could not get server port")
return
}

let clientTLS: ClientConnection.Configuration.TLS = .init(
trustRoots: .certificates([SampleCertificate.ca.certificate]),
hostnameOverride: "example.com"
)

self.connection = self.makeConnection(port: port, tls: clientTLS)
try self.doTestUnary()
}

func testTLSWithoutHostnameOverride() throws {
// Run a server presenting a certificate for localhost on localhost.
let serverTLS: Server.Configuration.TLS = .init(
certificateChain: [.certificate(SampleCertificate.server.certificate)],
privateKey: .privateKey(SamplePrivateKey.server),
trustRoots: .certificates([SampleCertificate.ca.certificate])
)

self.server = try makeEchoServer(tls: serverTLS)
guard let port = self.server.channel.localAddress?.port else {
XCTFail("could not get server port")
return
}

let clientTLS: ClientConnection.Configuration.TLS = .init(
trustRoots: .certificates([SampleCertificate.ca.certificate])
)

self.connection = self.makeConnection(port: port, tls: clientTLS)
try self.doTestUnary()
}
}
11 changes: 11 additions & 0 deletions Tests/GRPCTests/XCTestManifests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ extension ClientTLSFailureTests {
]
}

extension ClientTLSHostnameOverrideTests {
// DO NOT MODIFY: This is autogenerated, use:
// `swift test --generate-linuxmain`
// to regenerate.
static let __allTests__ClientTLSHostnameOverrideTests = [
("testTLSWithHostnameOverride", testTLSWithHostnameOverride),
("testTLSWithoutHostnameOverride", testTLSWithoutHostnameOverride),
]
}

extension ClientThrowingWhenServerReturningErrorTests {
// DO NOT MODIFY: This is autogenerated, use:
// `swift test --generate-linuxmain`
Expand Down Expand Up @@ -418,6 +428,7 @@ public func __allTests() -> [XCTestCaseEntry] {
testCase(ClientClosedChannelTests.__allTests__ClientClosedChannelTests),
testCase(ClientConnectionBackoffTests.__allTests__ClientConnectionBackoffTests),
testCase(ClientTLSFailureTests.__allTests__ClientTLSFailureTests),
testCase(ClientTLSHostnameOverrideTests.__allTests__ClientTLSHostnameOverrideTests),
testCase(ClientThrowingWhenServerReturningErrorTests.__allTests__ClientThrowingWhenServerReturningErrorTests),
testCase(ClientTimeoutTests.__allTests__ClientTimeoutTests),
testCase(ConnectionBackoffTests.__allTests__ConnectionBackoffTests),
Expand Down
16 changes: 0 additions & 16 deletions Tests/ca.crt

This file was deleted.

16 changes: 0 additions & 16 deletions Tests/client.crt

This file was deleted.

28 changes: 0 additions & 28 deletions Tests/client.pem

This file was deleted.

16 changes: 0 additions & 16 deletions Tests/server.crt

This file was deleted.

28 changes: 0 additions & 28 deletions Tests/server.pem

This file was deleted.

17 changes: 0 additions & 17 deletions Tests/ssl.crt

This file was deleted.

27 changes: 0 additions & 27 deletions Tests/ssl.key

This file was deleted.