diff --git a/core/src/main/scala/sttp/client4/RequestOptions.scala b/core/src/main/scala/sttp/client4/RequestOptions.scala index 32c3805e6c..8ed0cb1f99 100644 --- a/core/src/main/scala/sttp/client4/RequestOptions.scala +++ b/core/src/main/scala/sttp/client4/RequestOptions.scala @@ -1,10 +1,13 @@ package sttp.client4 +import sttp.client4.internal.ContentEncoding + import scala.concurrent.duration.Duration case class RequestOptions( followRedirects: Boolean, readTimeout: Duration, // TODO: Use FiniteDuration while migrating to sttp-4 maxRedirects: Int, - redirectToGet: Boolean + redirectToGet: Boolean, + encoding: List[ContentEncoding] = List.empty ) diff --git a/core/src/main/scala/sttp/client4/SttpClientException.scala b/core/src/main/scala/sttp/client4/SttpClientException.scala index 2290950030..1410074436 100644 --- a/core/src/main/scala/sttp/client4/SttpClientException.scala +++ b/core/src/main/scala/sttp/client4/SttpClientException.scala @@ -28,6 +28,8 @@ object SttpClientException extends SttpClientExceptionExtensions { class TimeoutException(request: GenericRequest[_, _], cause: Exception) extends ReadException(request, cause) + class EncodingException(request: GenericRequest[_, _], cause: Exception) extends SttpClientException(request, cause) + def adjustExceptions[F[_], T]( monadError: MonadError[F] )(t: => F[T])(usingFn: Exception => Option[Exception]): F[T] = diff --git a/core/src/main/scala/sttp/client4/internal/ContentEncoding.scala b/core/src/main/scala/sttp/client4/internal/ContentEncoding.scala new file mode 100644 index 0000000000..8dcc9b64fb --- /dev/null +++ b/core/src/main/scala/sttp/client4/internal/ContentEncoding.scala @@ -0,0 +1,32 @@ +package sttp.client4.internal + +sealed trait ContentEncoding { + def name: String +} + +object ContentEncoding { + + val gzip = Gzip() + val deflate = Deflate() + + case class Gzip() extends ContentEncoding { + override def name: String = "gzip" + } + + case class Compress() extends ContentEncoding { + override def name: String = "compress" + } + + case class Deflate() extends ContentEncoding { + override def name: String = "deflate" + } + + case class Br() extends ContentEncoding { + override def name: String = "br" + } + + case class Zstd() extends ContentEncoding { + override def name: String = "zstd" + } + +} diff --git a/core/src/main/scala/sttp/client4/internal/encoders/ContentCodec.scala b/core/src/main/scala/sttp/client4/internal/encoders/ContentCodec.scala new file mode 100644 index 0000000000..2e84003f26 --- /dev/null +++ b/core/src/main/scala/sttp/client4/internal/encoders/ContentCodec.scala @@ -0,0 +1,83 @@ +package sttp.client4.internal.encoders + +import sttp.client4.internal.ContentEncoding +import sttp.client4.internal.ContentEncoding.{Deflate, Gzip} +import sttp.client4.internal.encoders.EncoderError.UnsupportedEncoding +import sttp.client4.{BasicBodyPart, ByteArrayBody, ByteBufferBody, FileBody, InputStreamBody, StringBody} +import sttp.model.MediaType + +import scala.annotation.tailrec + + +trait ContentCodec[C <: ContentEncoding] { + + type BodyWithLength = (BasicBodyPart, Int) + + def encode(body: BasicBodyPart): Either[EncoderError, BodyWithLength] + + def decode(body: BasicBodyPart): Either[EncoderError, BodyWithLength] + + def encoding: C + +} + +abstract class AbstractContentCodec[C <: ContentEncoding] extends ContentCodec[C] { + + override def encode(body: BasicBodyPart): Either[EncoderError, BodyWithLength] = { + body match { + case StringBody(s, encoding, ct) => encode(s.getBytes(encoding), ct) + case ByteArrayBody(b, ct) => encode(b, ct) + case ByteBufferBody(b, ct) => encode(b.array(), ct) + case InputStreamBody(b, ct) => encode(b.readAllBytes(), ct) + case FileBody(f, ct) => encode(f.readAsByteArray, ct) + } + } + + private def encode(bytes: Array[Byte], ct: MediaType): Either[EncoderError, BodyWithLength] = { + encode(bytes).map(r => ByteArrayBody(r, ct) -> r.length) + } + + override def decode(body: BasicBodyPart): Either[EncoderError, BodyWithLength] = body match { + case StringBody(s, encoding, ct) => decode(s.getBytes(encoding), ct) + case ByteArrayBody(b, ct) => decode(b, ct) + case ByteBufferBody(b, ct) => decode(b.array(), ct) + case InputStreamBody(b, ct) => decode(b.readAllBytes(), ct) + case FileBody(f, ct) => decode(f.readAsByteArray, ct) + } + + private def decode(bytes: Array[Byte], ct: MediaType): Either[EncoderError, BodyWithLength] = { + decode(bytes).map(r => ByteArrayBody(r, ct) -> r.length) + } + + def encode(bytes: Array[Byte]): Either[EncoderError, Array[Byte]] + def decode(bytes: Array[Byte]): Either[EncoderError, Array[Byte]] +} + +object ContentCodec { + + private val gzipCodec = new GzipContentCodec + + private val deflateCodec = new DeflateContentCodec + + def encode(b: BasicBodyPart, codec: List[ContentEncoding]): Either[EncoderError, (BasicBodyPart, Int)] = { + foldLeftInEither(codec, b -> 0) { case ((l,_), r) => + r match { + case _: Gzip => gzipCodec.encode(l) + case _: Deflate => deflateCodec.encode(l) + case e => Left(UnsupportedEncoding(e)) + } + } + } + + @tailrec + private def foldLeftInEither[T, R, E](elems: List[T], zero: R)(f: (R, T) => Either[E, R]): Either[E, R] = { + elems match { + case Nil => Right[E,R](zero) + case head :: tail => f(zero, head) match { + case l :Left[E, R] => l + case Right(v) => foldLeftInEither(tail, v)(f) + } + } + } + +} \ No newline at end of file diff --git a/core/src/main/scala/sttp/client4/internal/encoders/DeflateContentCodec.scala b/core/src/main/scala/sttp/client4/internal/encoders/DeflateContentCodec.scala new file mode 100644 index 0000000000..9e23ee4b77 --- /dev/null +++ b/core/src/main/scala/sttp/client4/internal/encoders/DeflateContentCodec.scala @@ -0,0 +1,36 @@ +package sttp.client4.internal.encoders + +import sttp.client4.internal.ContentEncoding +import sttp.client4.internal.ContentEncoding.Deflate + +import java.io.ByteArrayOutputStream +import java.util.zip.{Deflater, Inflater} +import scala.util.{Try, Using} + +class DeflateContentCodec extends AbstractContentCodec[Deflate] { + + override def encode(bytes: Array[Byte]): Either[EncoderError, Array[Byte]] = + Try { + val deflater: Deflater = new Deflater() + deflater.setInput(bytes) + deflater.finish() + val compressedData = new Array[Byte](bytes.length * 2) + val count: Int = deflater.deflate(compressedData) + compressedData.take(count) + }.toEither.left.map(ex => EncoderError.EncodingFailure(encoding, ex.getMessage)) + + override def decode(bytes: Array[Byte]): Either[EncoderError, Array[Byte]] = + Using(new ByteArrayOutputStream()){ bos => + val buf = new Array[Byte](1024) + val decompresser = new Inflater() + decompresser.setInput(bytes, 0, bytes.length) + while (!decompresser.finished) { + val resultLength = decompresser.inflate(buf) + bos.write(buf, 0, resultLength) + } + decompresser.end() + bos.toByteArray + }.toEither.left.map(ex => EncoderError.EncodingFailure(encoding, ex.getMessage)) + + override def encoding: Deflate = ContentEncoding.deflate +} diff --git a/core/src/main/scala/sttp/client4/internal/encoders/EncoderError.scala b/core/src/main/scala/sttp/client4/internal/encoders/EncoderError.scala new file mode 100644 index 0000000000..36e75d1c40 --- /dev/null +++ b/core/src/main/scala/sttp/client4/internal/encoders/EncoderError.scala @@ -0,0 +1,20 @@ +package sttp.client4.internal.encoders + +import sttp.client4.internal.ContentEncoding + +import scala.util.control.NoStackTrace + +sealed trait EncoderError extends Exception with NoStackTrace { + def reason: String +} + +object EncoderError { + case class UnsupportedEncoding(encoding: ContentEncoding) extends EncoderError { + override def reason: String = s"${encoding.name} is unsupported with this body" + } + + case class EncodingFailure(encoding: ContentEncoding, msg: String) extends EncoderError { + + override def reason: String = s"Can`t encode $encoding for body $msg" + } +} diff --git a/core/src/main/scala/sttp/client4/internal/encoders/GzipContentCodec.scala b/core/src/main/scala/sttp/client4/internal/encoders/GzipContentCodec.scala new file mode 100644 index 0000000000..d6015d8986 --- /dev/null +++ b/core/src/main/scala/sttp/client4/internal/encoders/GzipContentCodec.scala @@ -0,0 +1,31 @@ +package sttp.client4.internal.encoders + +import sttp.client4.internal.ContentEncoding +import sttp.client4.internal.ContentEncoding.Gzip +import sttp.client4.internal.encoders.EncoderError.EncodingFailure + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import scala.util.Using + +class GzipContentCodec extends AbstractContentCodec[Gzip] { + + override def encode(bytes: Array[Byte]): Either[EncodingFailure, Array[Byte]] = { + Using(new ByteArrayOutputStream){ baos => + Using(new GZIPOutputStream(baos)){ gzos => + gzos.write(bytes) + gzos.finish() + baos.toByteArray + } + }.flatMap(identity).toEither.left.map(ex => EncodingFailure(encoding, ex.getMessage)) + } + + override def decode(bytes: Array[Byte]): Either[EncodingFailure, Array[Byte]] = { + Using(new GZIPInputStream(new ByteArrayInputStream(bytes))) { b => + b.readAllBytes() + }.toEither.left.map(ex => EncodingFailure(encoding, ex.getMessage)) + } + + override def encoding: Gzip = ContentEncoding.gzip + +} diff --git a/core/src/main/scala/sttp/client4/request.scala b/core/src/main/scala/sttp/client4/request.scala index acd298695b..c95ae77a80 100644 --- a/core/src/main/scala/sttp/client4/request.scala +++ b/core/src/main/scala/sttp/client4/request.scala @@ -2,6 +2,8 @@ package sttp.client4 import sttp.model.{Header, Method, Part, RequestMetadata, Uri} import sttp.capabilities.{Effect, Streams, WebSockets} +import sttp.client4.SttpClientException.EncodingException +import sttp.client4.internal.encoders.ContentCodec import sttp.client4.internal.{ToCurlConverter, ToRfc2616Converter} import scala.collection.immutable.Seq @@ -143,7 +145,20 @@ case class Request[T]( * Known exceptions are converted by backends to one of [[SttpClientException]]. Other exceptions are thrown * unchanged. */ - def send[F[_]](backend: Backend[F]): F[Response[T]] = backend.send(this) + def send[F[_]](backend: Backend[F]): F[Response[T]] = { + (this.options.encoding, this.body) match { + case (Nil, _) => backend.send(this) + case (codecs, b: BasicBodyPart) if codecs.nonEmpty => + val (newBody, newLength) = ContentCodec.encode(b, codecs) match { + case Left(err) => throw new EncodingException(this, err) + case Right(v) => v + } + val newReq = this.contentLength(newLength.toLong).copyWithBody(newBody) + backend.send(newReq) + + case _ => backend.send(this) + } + } /** Sends the request synchronously, using the given backend. * @@ -155,7 +170,19 @@ case class Request[T]( * Known exceptions are converted by backends to one of [[SttpClientException]]. Other exceptions are thrown * unchanged. */ - def send(backend: SyncBackend): Response[T] = backend.send(this) + def send(backend: SyncBackend): Response[T] = { + (this.options.encoding, this.body) match { + case (codecs, b: BasicBodyPart) if codecs.nonEmpty => + val (newBody, newLength) = ContentCodec.encode(b, codecs) match { + case Left(err) => throw new EncodingException(this, err) + case Right(v) => v + } + val newReq = this.contentLength(newLength.toLong).copyWithBody(newBody) + backend.send(newReq) + + case _ => backend.send(this) + } + } } object Request { diff --git a/core/src/main/scala/sttp/client4/requestBuilder.scala b/core/src/main/scala/sttp/client4/requestBuilder.scala index f845cd0205..1418c60876 100644 --- a/core/src/main/scala/sttp/client4/requestBuilder.scala +++ b/core/src/main/scala/sttp/client4/requestBuilder.scala @@ -1,8 +1,6 @@ package sttp.client4 -import sttp.client4.internal.SttpFile -import sttp.client4.internal.Utf8 -import sttp.client4.internal.contentTypeWithCharset +import sttp.client4.internal.{ContentEncoding, SttpFile, Utf8, contentTypeWithCharset} import sttp.client4.logging.LoggingOptions import sttp.client4.wrappers.DigestAuthenticationBackend import sttp.model.HasHeaders @@ -77,6 +75,10 @@ trait PartialRequestBuilder[+PR <: PartialRequestBuilder[PR, R], +R] def contentLength(l: Long): PR = header(HeaderNames.ContentLength, l.toString, replaceExisting = true) + def contentEncoding(encoding: ContentEncoding): PR = + header(HeaderNames.ContentEncoding, encoding.name, replaceExisting = false) + .withOptions(options.copy(encoding = options.encoding :+ encoding)) + /** Adds the given header to the end of the headers sequence. * @param replaceExisting * If there's already a header with the same name, should it be replaced? diff --git a/core/src/test/scala/sttp/client4/testing/HttpTest.scala b/core/src/test/scala/sttp/client4/testing/HttpTest.scala index 4e7665911a..4288d301cd 100644 --- a/core/src/test/scala/sttp/client4/testing/HttpTest.scala +++ b/core/src/test/scala/sttp/client4/testing/HttpTest.scala @@ -3,15 +3,18 @@ package sttp.client4.testing import org.scalatest._ import org.scalatest.freespec.AsyncFreeSpec import org.scalatest.matchers.should.Matchers -import sttp.client4.internal.{Iso88591, Utf8} +import sttp.client4.internal.{ContentEncoding, Iso88591, Utf8} import sttp.client4.testing.HttpTest.endpoint import sttp.client4._ +import sttp.client4.internal.encoders.EncoderError.EncodingFailure +import sttp.client4.internal.encoders.{DeflateContentCodec, GzipContentCodec} import sttp.model.StatusCode import sttp.monad.MonadError import sttp.monad.syntax._ import java.io.{ByteArrayInputStream, UnsupportedEncodingException} import java.nio.ByteBuffer +import scala.Right import scala.concurrent.Future import scala.concurrent.duration._ @@ -21,6 +24,7 @@ trait HttpTest[F[_]] with Matchers with ToFutureWrapper with OptionValues + with EitherValues with HttpTestExtensions[F] with AsyncRetries { @@ -421,6 +425,52 @@ trait HttpTest[F[_]] req.send(backend).toFuture().map(resp => resp.code shouldBe StatusCode.Ok) } + "should compress request body gzip" in { + val codec = new GzipContentCodec + val req = basicRequest.contentEncoding(ContentEncoding.gzip) + .response(asByteArrayAlways) + .post(uri"$endpoint/echo/exact") + .body("I`m not compressed") + req.send(backend).toFuture().map{resp => + resp.code shouldBe StatusCode.Ok + val res = codec.decode(resp.body) + res.isRight shouldBe true + res.right.value shouldBe "I`m not compressed".getBytes() + } + } + + "should compress request body deflate" in { + val codec = new DeflateContentCodec + val req = basicRequest.contentEncoding(ContentEncoding.deflate) + .response(asByteArrayAlways) + .post(uri"$endpoint/echo/exact") + .body("I`m not compressed") + req.send(backend).toFuture().map{resp => + resp.code shouldBe StatusCode.Ok + val res = codec.decode(resp.body) + res.isRight shouldBe true + res.right.value shouldBe "I`m not compressed".getBytes() + } + } + + "should compress request body multiple codecs" in { + val codecDeflate = new DeflateContentCodec + val codecGzip = new GzipContentCodec + val req = basicRequest + .contentEncoding(ContentEncoding.gzip) + .contentEncoding(ContentEncoding.deflate) + .response(asByteArrayAlways) + .post(uri"$endpoint/echo/exact") + .body("I`m not compressed") + req.send(backend).toFuture().map{resp => + resp.code shouldBe StatusCode.Ok + val res = codecDeflate.decode(resp.body) + .flatMap(b => codecGzip.decode(b)) + res.isRight shouldBe true + res.right.value shouldBe "I`m not compressed".getBytes() + } + } + if (supportsCustomContentEncoding) { "decompress using custom content encoding" in { val req =