Skip to content

Commit

Permalink
feature/softwaremill#1918 gzip and deflate content codecs
Browse files Browse the repository at this point in the history
  • Loading branch information
ifedorov committed Apr 23, 2024
1 parent 72d32b8 commit 737d90a
Show file tree
Hide file tree
Showing 10 changed files with 293 additions and 7 deletions.
5 changes: 4 additions & 1 deletion core/src/main/scala/sttp/client4/RequestOptions.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package sttp.client4

import sttp.client4.internal.ContentEncoding

import scala.concurrent.duration.Duration

case class RequestOptions(
followRedirects: Boolean,
readTimeout: Duration, // TODO: Use FiniteDuration while migrating to sttp-4
maxRedirects: Int,
redirectToGet: Boolean
redirectToGet: Boolean,
encoding: List[ContentEncoding] = List.empty
)
2 changes: 2 additions & 0 deletions core/src/main/scala/sttp/client4/SttpClientException.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ object SttpClientException extends SttpClientExceptionExtensions {

class TimeoutException(request: GenericRequest[_, _], cause: Exception) extends ReadException(request, cause)

class EncodingException(request: GenericRequest[_, _], cause: Exception) extends SttpClientException(request, cause)

def adjustExceptions[F[_], T](
monadError: MonadError[F]
)(t: => F[T])(usingFn: Exception => Option[Exception]): F[T] =
Expand Down
32 changes: 32 additions & 0 deletions core/src/main/scala/sttp/client4/internal/ContentEncoding.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package sttp.client4.internal

sealed trait ContentEncoding {
def name: String
}

object ContentEncoding {

val gzip = Gzip()
val deflate = Deflate()

case class Gzip() extends ContentEncoding {
override def name: String = "gzip"
}

case class Compress() extends ContentEncoding {
override def name: String = "compress"
}

case class Deflate() extends ContentEncoding {
override def name: String = "deflate"
}

case class Br() extends ContentEncoding {
override def name: String = "br"
}

case class Zstd() extends ContentEncoding {
override def name: String = "zstd"
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package sttp.client4.internal.encoders

import sttp.client4.internal.ContentEncoding
import sttp.client4.internal.ContentEncoding.{Deflate, Gzip}
import sttp.client4.internal.encoders.EncoderError.UnsupportedEncoding
import sttp.client4.{BasicBodyPart, ByteArrayBody, ByteBufferBody, FileBody, InputStreamBody, StringBody}
import sttp.model.MediaType

import scala.annotation.tailrec


trait ContentCodec[C <: ContentEncoding] {

type BodyWithLength = (BasicBodyPart, Int)

def encode(body: BasicBodyPart): Either[EncoderError, BodyWithLength]

def decode(body: BasicBodyPart): Either[EncoderError, BodyWithLength]

def encoding: C

}

abstract class AbstractContentCodec[C <: ContentEncoding] extends ContentCodec[C] {

override def encode(body: BasicBodyPart): Either[EncoderError, BodyWithLength] = {
body match {
case StringBody(s, encoding, ct) => encode(s.getBytes(encoding), ct)
case ByteArrayBody(b, ct) => encode(b, ct)
case ByteBufferBody(b, ct) => encode(b.array(), ct)
case InputStreamBody(b, ct) => encode(b.readAllBytes(), ct)
case FileBody(f, ct) => encode(f.readAsByteArray, ct)
}
}

private def encode(bytes: Array[Byte], ct: MediaType): Either[EncoderError, BodyWithLength] = {
encode(bytes).map(r => ByteArrayBody(r, ct) -> r.length)
}

override def decode(body: BasicBodyPart): Either[EncoderError, BodyWithLength] = body match {
case StringBody(s, encoding, ct) => decode(s.getBytes(encoding), ct)
case ByteArrayBody(b, ct) => decode(b, ct)
case ByteBufferBody(b, ct) => decode(b.array(), ct)
case InputStreamBody(b, ct) => decode(b.readAllBytes(), ct)
case FileBody(f, ct) => decode(f.readAsByteArray, ct)
}

private def decode(bytes: Array[Byte], ct: MediaType): Either[EncoderError, BodyWithLength] = {
decode(bytes).map(r => ByteArrayBody(r, ct) -> r.length)
}

def encode(bytes: Array[Byte]): Either[EncoderError, Array[Byte]]
def decode(bytes: Array[Byte]): Either[EncoderError, Array[Byte]]
}

object ContentCodec {

private val gzipCodec = new GzipContentCodec

private val deflateCodec = new DeflateContentCodec

def encode(b: BasicBodyPart, codec: List[ContentEncoding]): Either[EncoderError, (BasicBodyPart, Int)] = {
foldLeftInEither(codec, b -> 0) { case ((l,_), r) =>
r match {
case _: Gzip => gzipCodec.encode(l)
case _: Deflate => deflateCodec.encode(l)
case e => Left(UnsupportedEncoding(e))
}
}
}

@tailrec
private def foldLeftInEither[T, R, E](elems: List[T], zero: R)(f: (R, T) => Either[E, R]): Either[E, R] = {
elems match {
case Nil => Right[E,R](zero)
case head :: tail => f(zero, head) match {
case l :Left[E, R] => l
case Right(v) => foldLeftInEither(tail, v)(f)
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package sttp.client4.internal.encoders

import sttp.client4.internal.ContentEncoding
import sttp.client4.internal.ContentEncoding.Deflate

import java.io.ByteArrayOutputStream
import java.util.zip.{Deflater, Inflater}
import scala.util.{Try, Using}

class DeflateContentCodec extends AbstractContentCodec[Deflate] {

override def encode(bytes: Array[Byte]): Either[EncoderError, Array[Byte]] =
Try {
val deflater: Deflater = new Deflater()
deflater.setInput(bytes)
deflater.finish()
val compressedData = new Array[Byte](bytes.length * 2)
val count: Int = deflater.deflate(compressedData)
compressedData.take(count)
}.toEither.left.map(ex => EncoderError.EncodingFailure(encoding, ex.getMessage))

override def decode(bytes: Array[Byte]): Either[EncoderError, Array[Byte]] =
Using(new ByteArrayOutputStream()){ bos =>
val buf = new Array[Byte](1024)
val decompresser = new Inflater()
decompresser.setInput(bytes, 0, bytes.length)
while (!decompresser.finished) {
val resultLength = decompresser.inflate(buf)
bos.write(buf, 0, resultLength)
}
decompresser.end()
bos.toByteArray
}.toEither.left.map(ex => EncoderError.EncodingFailure(encoding, ex.getMessage))

override def encoding: Deflate = ContentEncoding.deflate
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package sttp.client4.internal.encoders

import sttp.client4.internal.ContentEncoding

import scala.util.control.NoStackTrace

sealed trait EncoderError extends Exception with NoStackTrace {
def reason: String
}

object EncoderError {
case class UnsupportedEncoding(encoding: ContentEncoding) extends EncoderError {
override def reason: String = s"${encoding.name} is unsupported with this body"
}

case class EncodingFailure(encoding: ContentEncoding, msg: String) extends EncoderError {

override def reason: String = s"Can`t encode $encoding for body $msg"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package sttp.client4.internal.encoders

import sttp.client4.internal.ContentEncoding
import sttp.client4.internal.ContentEncoding.Gzip
import sttp.client4.internal.encoders.EncoderError.EncodingFailure

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.util.Using

class GzipContentCodec extends AbstractContentCodec[Gzip] {

override def encode(bytes: Array[Byte]): Either[EncodingFailure, Array[Byte]] = {
Using(new ByteArrayOutputStream){ baos =>
Using(new GZIPOutputStream(baos)){ gzos =>
gzos.write(bytes)
gzos.finish()
baos.toByteArray
}
}.flatMap(identity).toEither.left.map(ex => EncodingFailure(encoding, ex.getMessage))
}

override def decode(bytes: Array[Byte]): Either[EncodingFailure, Array[Byte]] = {
Using(new GZIPInputStream(new ByteArrayInputStream(bytes))) { b =>
b.readAllBytes()
}.toEither.left.map(ex => EncodingFailure(encoding, ex.getMessage))
}

override def encoding: Gzip = ContentEncoding.gzip

}
31 changes: 29 additions & 2 deletions core/src/main/scala/sttp/client4/request.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package sttp.client4

import sttp.model.{Header, Method, Part, RequestMetadata, Uri}
import sttp.capabilities.{Effect, Streams, WebSockets}
import sttp.client4.SttpClientException.EncodingException
import sttp.client4.internal.encoders.ContentCodec
import sttp.client4.internal.{ToCurlConverter, ToRfc2616Converter}

import scala.collection.immutable.Seq
Expand Down Expand Up @@ -143,7 +145,20 @@ case class Request[T](
* Known exceptions are converted by backends to one of [[SttpClientException]]. Other exceptions are thrown
* unchanged.
*/
def send[F[_]](backend: Backend[F]): F[Response[T]] = backend.send(this)
def send[F[_]](backend: Backend[F]): F[Response[T]] = {
(this.options.encoding, this.body) match {
case (Nil, _) => backend.send(this)
case (codecs, b: BasicBodyPart) if codecs.nonEmpty =>
val (newBody, newLength) = ContentCodec.encode(b, codecs) match {
case Left(err) => throw new EncodingException(this, err)
case Right(v) => v
}
val newReq = this.contentLength(newLength.toLong).copyWithBody(newBody)
backend.send(newReq)

case _ => backend.send(this)
}
}

/** Sends the request synchronously, using the given backend.
*
Expand All @@ -155,7 +170,19 @@ case class Request[T](
* Known exceptions are converted by backends to one of [[SttpClientException]]. Other exceptions are thrown
* unchanged.
*/
def send(backend: SyncBackend): Response[T] = backend.send(this)
def send(backend: SyncBackend): Response[T] = {
(this.options.encoding, this.body) match {
case (codecs, b: BasicBodyPart) if codecs.nonEmpty =>
val (newBody, newLength) = ContentCodec.encode(b, codecs) match {
case Left(err) => throw new EncodingException(this, err)
case Right(v) => v
}
val newReq = this.contentLength(newLength.toLong).copyWithBody(newBody)
backend.send(newReq)

case _ => backend.send(this)
}
}
}

object Request {
Expand Down
8 changes: 5 additions & 3 deletions core/src/main/scala/sttp/client4/requestBuilder.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package sttp.client4

import sttp.client4.internal.SttpFile
import sttp.client4.internal.Utf8
import sttp.client4.internal.contentTypeWithCharset
import sttp.client4.internal.{ContentEncoding, SttpFile, Utf8, contentTypeWithCharset}
import sttp.client4.logging.LoggingOptions
import sttp.client4.wrappers.DigestAuthenticationBackend
import sttp.model.HasHeaders
Expand Down Expand Up @@ -77,6 +75,10 @@ trait PartialRequestBuilder[+PR <: PartialRequestBuilder[PR, R], +R]
def contentLength(l: Long): PR =
header(HeaderNames.ContentLength, l.toString, replaceExisting = true)

def contentEncoding(encoding: ContentEncoding): PR =
header(HeaderNames.ContentEncoding, encoding.name, replaceExisting = false)
.withOptions(options.copy(encoding = options.encoding :+ encoding))

/** Adds the given header to the end of the headers sequence.
* @param replaceExisting
* If there's already a header with the same name, should it be replaced?
Expand Down
52 changes: 51 additions & 1 deletion core/src/test/scala/sttp/client4/testing/HttpTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@ package sttp.client4.testing
import org.scalatest._
import org.scalatest.freespec.AsyncFreeSpec
import org.scalatest.matchers.should.Matchers
import sttp.client4.internal.{Iso88591, Utf8}
import sttp.client4.internal.{ContentEncoding, Iso88591, Utf8}
import sttp.client4.testing.HttpTest.endpoint
import sttp.client4._
import sttp.client4.internal.encoders.EncoderError.EncodingFailure
import sttp.client4.internal.encoders.{DeflateContentCodec, GzipContentCodec}
import sttp.model.StatusCode
import sttp.monad.MonadError
import sttp.monad.syntax._

import java.io.{ByteArrayInputStream, UnsupportedEncodingException}
import java.nio.ByteBuffer
import scala.Right
import scala.concurrent.Future
import scala.concurrent.duration._

Expand All @@ -21,6 +24,7 @@ trait HttpTest[F[_]]
with Matchers
with ToFutureWrapper
with OptionValues
with EitherValues
with HttpTestExtensions[F]
with AsyncRetries {

Expand Down Expand Up @@ -421,6 +425,52 @@ trait HttpTest[F[_]]
req.send(backend).toFuture().map(resp => resp.code shouldBe StatusCode.Ok)
}

"should compress request body gzip" in {
val codec = new GzipContentCodec
val req = basicRequest.contentEncoding(ContentEncoding.gzip)
.response(asByteArrayAlways)
.post(uri"$endpoint/echo/exact")
.body("I`m not compressed")
req.send(backend).toFuture().map{resp =>
resp.code shouldBe StatusCode.Ok
val res = codec.decode(resp.body)
res.isRight shouldBe true
res.right.value shouldBe "I`m not compressed".getBytes()
}
}

"should compress request body deflate" in {
val codec = new DeflateContentCodec
val req = basicRequest.contentEncoding(ContentEncoding.deflate)
.response(asByteArrayAlways)
.post(uri"$endpoint/echo/exact")
.body("I`m not compressed")
req.send(backend).toFuture().map{resp =>
resp.code shouldBe StatusCode.Ok
val res = codec.decode(resp.body)
res.isRight shouldBe true
res.right.value shouldBe "I`m not compressed".getBytes()
}
}

"should compress request body multiple codecs" in {
val codecDeflate = new DeflateContentCodec
val codecGzip = new GzipContentCodec
val req = basicRequest
.contentEncoding(ContentEncoding.gzip)
.contentEncoding(ContentEncoding.deflate)
.response(asByteArrayAlways)
.post(uri"$endpoint/echo/exact")
.body("I`m not compressed")
req.send(backend).toFuture().map{resp =>
resp.code shouldBe StatusCode.Ok
val res = codecDeflate.decode(resp.body)
.flatMap(b => codecGzip.decode(b))
res.isRight shouldBe true
res.right.value shouldBe "I`m not compressed".getBytes()
}
}

if (supportsCustomContentEncoding) {
"decompress using custom content encoding" in {
val req =
Expand Down

0 comments on commit 737d90a

Please sign in to comment.