Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 73 additions & 88 deletions Sources/GRPCHTTP2Core/Compression/Zlib.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,18 @@ enum Zlib {
extension Zlib {
/// Creates a new compressor for the given compression format.
///
/// This compressor is only suitable for compressing whole messages at a time. Callers
/// must ``initialize()`` the compressor before using it.
/// This compressor is only suitable for compressing whole messages at a time.
struct Compressor {
private var stream: z_stream
// TODO: Make this ~Copyable when 5.9 is the lowest supported Swift version.

private var stream: UnsafeMutablePointer<z_stream>
private let method: Method
private var isInitialized = false

init(method: Method) {
self.method = method
self.stream = z_stream()
}

/// Initialize the compressor.
mutating func initialize() {
precondition(!self.isInitialized)
self.stream = .allocate(capacity: 1)
self.stream.initialize(to: z_stream())
self.stream.deflateInit(windowBits: self.method.windowBits)
self.isInitialized = true
}

static func initialized(_ method: Method) -> Self {
var compressor = Compressor(method: method)
compressor.initialize()
return compressor
}

/// Compresses the data in `input` into the `output` buffer.
Expand All @@ -68,77 +57,73 @@ extension Zlib {
/// - Parameter output: The `ByteBuffer` into which the compressed message should be written.
/// - Returns: The number of bytes written into the `output` buffer.
@discardableResult
mutating func compress(_ input: [UInt8], into output: inout ByteBuffer) throws -> Int {
precondition(self.isInitialized)
func compress(_ input: [UInt8], into output: inout ByteBuffer) throws -> Int {
defer { self.reset() }
let upperBound = self.stream.deflateBound(inputBytes: input.count)
return try self.stream.deflate(input, into: &output, upperBound: upperBound)
}

/// Resets compression state.
private mutating func reset() {
private func reset() {
do {
try self.stream.deflateReset()
} catch {
self.end()
self.stream = z_stream()
self.stream.initialize(to: z_stream())
self.stream.deflateInit(windowBits: self.method.windowBits)
}
}

/// Deallocates any resources allocated by Zlib.
mutating func end() {
func end() {
self.stream.deflateEnd()
self.stream.deallocate()
}
}
}

extension Zlib {
/// Creates a new decompressor for the given compression format.
///
/// This decompressor is only suitable for compressing whole messages at a time. Callers
/// must ``initialize()`` the decompressor before using it.
/// This decompressor is only suitable for compressing whole messages at a time.
struct Decompressor {
private var stream: z_stream
// TODO: Make this ~Copyable when 5.9 is the lowest supported Swift version.

private var stream: UnsafeMutablePointer<z_stream>
private let method: Method
private var isInitialized = false

init(method: Method) {
self.method = method
self.stream = z_stream()
}

mutating func initialize() {
precondition(!self.isInitialized)
self.stream = UnsafeMutablePointer.allocate(capacity: 1)
self.stream.initialize(to: z_stream())
self.stream.inflateInit(windowBits: self.method.windowBits)
self.isInitialized = true
}

/// Returns the decompressed bytes from ``input``.
///
/// - Parameters:
/// - input: The buffer read compressed bytes from.
/// - limit: The largest size a decompressed payload may be.
mutating func decompress(_ input: inout ByteBuffer, limit: Int) throws -> [UInt8] {
precondition(self.isInitialized)
func decompress(_ input: inout ByteBuffer, limit: Int) throws -> [UInt8] {
defer { self.reset() }
return try self.stream.inflate(input: &input, limit: limit)
}

/// Resets decompression state.
private mutating func reset() {
private func reset() {
do {
try self.stream.inflateReset()
} catch {
self.end()
self.stream = z_stream()
self.stream.initialize(to: z_stream())
self.stream.inflateInit(windowBits: self.method.windowBits)
}
}

/// Deallocates any resources allocated by Zlib.
mutating func end() {
func end() {
self.stream.inflateEnd()
self.stream.deallocate()
}
}
}
Expand All @@ -155,13 +140,13 @@ struct ZlibError: Error, Hashable {
}
}

extension z_stream {
mutating func inflateInit(windowBits: Int32) {
self.zfree = nil
self.zalloc = nil
self.opaque = nil
extension UnsafeMutablePointer<z_stream> {
func inflateInit(windowBits: Int32) {
self.pointee.zfree = nil
self.pointee.zalloc = nil
self.pointee.opaque = nil

let rc = CGRPCZlib_inflateInit2(&self, windowBits)
let rc = CGRPCZlib_inflateInit2(self, windowBits)
// Possible return codes:
// - Z_OK
// - Z_MEM_ERROR: not enough memory
Expand All @@ -171,8 +156,8 @@ extension z_stream {
precondition(rc == Z_OK, "inflateInit2 failed with error (\(rc)) \(self.lastError ?? "")")
}

mutating func inflateReset() throws {
let rc = CGRPCZlib_inflateReset(&self)
func inflateReset() throws {
let rc = CGRPCZlib_inflateReset(self)

// Possible return codes:
// - Z_OK
Expand All @@ -187,17 +172,17 @@ extension z_stream {
}
}

mutating func inflateEnd() {
_ = CGRPCZlib_inflateEnd(&self)
func inflateEnd() {
_ = CGRPCZlib_inflateEnd(self)
}

mutating func deflateInit(windowBits: Int32) {
self.zfree = nil
self.zalloc = nil
self.opaque = nil
func deflateInit(windowBits: Int32) {
self.pointee.zfree = nil
self.pointee.zalloc = nil
self.pointee.opaque = nil

let rc = CGRPCZlib_deflateInit2(
&self,
self,
Z_DEFAULT_COMPRESSION, // compression level
Z_DEFLATED, // compression method (this must be Z_DEFLATED)
windowBits, // window size, i.e. deflate/gzip
Expand All @@ -215,8 +200,8 @@ extension z_stream {
precondition(rc == Z_OK, "deflateInit2 failed with error (\(rc)) \(self.lastError ?? "")")
}

mutating func deflateReset() throws {
let rc = CGRPCZlib_deflateReset(&self)
func deflateReset() throws {
let rc = CGRPCZlib_deflateReset(self)

// Possible return codes:
// - Z_OK
Expand All @@ -231,87 +216,87 @@ extension z_stream {
}
}

mutating func deflateEnd() {
_ = CGRPCZlib_deflateEnd(&self)
func deflateEnd() {
_ = CGRPCZlib_deflateEnd(self)
}

mutating func deflateBound(inputBytes: Int) -> Int {
let bound = CGRPCZlib_deflateBound(&self, UInt(inputBytes))
func deflateBound(inputBytes: Int) -> Int {
let bound = CGRPCZlib_deflateBound(self, UInt(inputBytes))
return Int(bound)
}

mutating func setNextInputBuffer(_ buffer: UnsafeMutableBufferPointer<UInt8>) {
func setNextInputBuffer(_ buffer: UnsafeMutableBufferPointer<UInt8>) {
if let baseAddress = buffer.baseAddress {
self.next_in = baseAddress
self.avail_in = UInt32(buffer.count)
self.pointee.next_in = baseAddress
self.pointee.avail_in = UInt32(buffer.count)
} else {
self.next_in = nil
self.avail_in = 0
self.pointee.next_in = nil
self.pointee.avail_in = 0
}
}

mutating func setNextInputBuffer(_ buffer: UnsafeMutableRawBufferPointer?) {
func setNextInputBuffer(_ buffer: UnsafeMutableRawBufferPointer?) {
if let buffer = buffer, let baseAddress = buffer.baseAddress {
self.next_in = CGRPCZlib_castVoidToBytefPointer(baseAddress)
self.avail_in = UInt32(buffer.count)
self.pointee.next_in = CGRPCZlib_castVoidToBytefPointer(baseAddress)
self.pointee.avail_in = UInt32(buffer.count)
} else {
self.next_in = nil
self.avail_in = 0
self.pointee.next_in = nil
self.pointee.avail_in = 0
}
}

mutating func setNextOutputBuffer(_ buffer: UnsafeMutableBufferPointer<UInt8>) {
func setNextOutputBuffer(_ buffer: UnsafeMutableBufferPointer<UInt8>) {
if let baseAddress = buffer.baseAddress {
self.next_out = baseAddress
self.avail_out = UInt32(buffer.count)
self.pointee.next_out = baseAddress
self.pointee.avail_out = UInt32(buffer.count)
} else {
self.next_out = nil
self.avail_out = 0
self.pointee.next_out = nil
self.pointee.avail_out = 0
}
}

mutating func setNextOutputBuffer(_ buffer: UnsafeMutableRawBufferPointer?) {
func setNextOutputBuffer(_ buffer: UnsafeMutableRawBufferPointer?) {
if let buffer = buffer, let baseAddress = buffer.baseAddress {
self.next_out = CGRPCZlib_castVoidToBytefPointer(baseAddress)
self.avail_out = UInt32(buffer.count)
self.pointee.next_out = CGRPCZlib_castVoidToBytefPointer(baseAddress)
self.pointee.avail_out = UInt32(buffer.count)
} else {
self.next_out = nil
self.avail_out = 0
self.pointee.next_out = nil
self.pointee.avail_out = 0
}
}

/// Number of bytes available to read `self.nextInputBuffer`. See also: `z_stream.avail_in`.
var availableInputBytes: Int {
get {
Int(self.avail_in)
Int(self.pointee.avail_in)
}
set {
self.avail_in = UInt32(newValue)
self.pointee.avail_in = UInt32(newValue)
}
}

/// The remaining writable space in `nextOutputBuffer`. See also: `z_stream.avail_out`.
var availableOutputBytes: Int {
get {
Int(self.avail_out)
Int(self.pointee.avail_out)
}
set {
self.avail_out = UInt32(newValue)
self.pointee.avail_out = UInt32(newValue)
}
}

/// The total number of bytes written to the output buffer. See also: `z_stream.total_out`.
var totalOutputBytes: Int {
Int(self.total_out)
Int(self.pointee.total_out)
}

/// The last error message that zlib wrote. No message is guaranteed on error, however, `nil` is
/// guaranteed if there is no error. See also `z_stream.msg`.
var lastError: String? {
self.msg.map { String(cString: $0) }
self.pointee.msg.map { String(cString: $0) }
}

mutating func inflate(input: inout ByteBuffer, limit: Int) throws -> [UInt8] {
func inflate(input: inout ByteBuffer, limit: Int) throws -> [UInt8] {
return try input.readWithUnsafeMutableReadableBytes { inputPointer in
self.setNextInputBuffer(inputPointer)
defer {
Expand Down Expand Up @@ -342,7 +327,7 @@ extension z_stream {
//
// Note that Z_OK is not okay here since we always flush with Z_FINISH and therefore
// use Z_STREAM_END as our success criteria.
let rc = CGRPCZlib_inflate(&self, Z_FINISH)
let rc = CGRPCZlib_inflate(self, Z_FINISH)
switch rc {
case Z_STREAM_END:
finished = true
Expand Down Expand Up @@ -377,7 +362,7 @@ extension z_stream {
}
}

mutating func deflate(
func deflate(
_ input: [UInt8],
into output: inout ByteBuffer,
upperBound: Int
Expand All @@ -394,7 +379,7 @@ extension z_stream {
return try output.writeWithUnsafeMutableBytes(minimumWritableBytes: upperBound) { output in
self.setNextOutputBuffer(output)

let rc = CGRPCZlib_deflate(&self, Z_FINISH)
let rc = CGRPCZlib_deflate(self, Z_FINISH)

// Possible return codes:
// - Z_OK: some progress has been made
Expand Down
15 changes: 5 additions & 10 deletions Tests/GRPCHTTP2CoreTests/Server/Compression/ZlibTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ final class ZlibTests: XCTestCase {
"""

private func compress(_ input: [UInt8], method: Zlib.Method) throws -> ByteBuffer {
var compressor = Zlib.Compressor(method: method)
compressor.initialize()
let compressor = Zlib.Compressor(method: method)
defer { compressor.end() }

var buffer = ByteBuffer()
Expand All @@ -45,8 +44,7 @@ final class ZlibTests: XCTestCase {
method: Zlib.Method,
limit: Int = .max
) throws -> [UInt8] {
var decompressor = Zlib.Decompressor(method: method)
decompressor.initialize()
let decompressor = Zlib.Decompressor(method: method)
defer { decompressor.end() }

var input = input
Expand All @@ -69,8 +67,7 @@ final class ZlibTests: XCTestCase {

func testRepeatedCompresses() throws {
let original = Array(self.text.utf8)
var compressor = Zlib.Compressor(method: .deflate)
compressor.initialize()
let compressor = Zlib.Compressor(method: .deflate)
defer { compressor.end() }

var compressed = ByteBuffer()
Expand All @@ -86,8 +83,7 @@ final class ZlibTests: XCTestCase {

func testRepeatedDecompresses() throws {
let original = Array(self.text.utf8)
var decompressor = Zlib.Decompressor(method: .deflate)
decompressor.initialize()
let decompressor = Zlib.Decompressor(method: .deflate)
defer { decompressor.end() }

let compressed = try self.compress(original, method: .deflate)
Expand Down Expand Up @@ -123,8 +119,7 @@ final class ZlibTests: XCTestCase {
}

func testCompressAppendsToBuffer() throws {
var compressor = Zlib.Compressor(method: .deflate)
compressor.initialize()
let compressor = Zlib.Compressor(method: .deflate)
defer { compressor.end() }

var buffer = ByteBuffer()
Expand Down