Skip to content

Commit

Permalink
Fix chars escaping, pass error to ConnectionProvider.release
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Hasselbach committed Jan 16, 2017
1 parent 1c857c5 commit 5a4b818
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Scala 2.12 and 2.11 is supported. Tested on PostgreSQL 9.4

## Installation
`libraryDependencies ++= "ru.arigativa" %% "akka-streams-postgresql-copy" % "0.2.1"`
`libraryDependencies ++= "ru.arigativa" %% "akka-streams-postgresql-copy" % "0.3.1"`

## Usage

Expand Down
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name := "akka-streams-postgresql-copy"
organization := "ru.arigativa"


version := "0.2.1"
version := "0.3.1"

scalaVersion := "2.12.1"
crossScalaVersions := Seq("2.12.1", "2.11.8")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,24 @@ import scala.util.{Failure, Success, Try}
*/
trait ConnectionProvider {
def acquire(): Try[PGConnection]
def release(): Unit
def release(exOpt: Option[Throwable]): Unit
}

object ConnectionProvider {
implicit def pgConnectionGetterToCloseableProvider(getConn: () => BaseConnection): ConnectionProvider =
new ConnectionProvider {
private var conn: Try[BaseConnection] = Failure(new RuntimeException("Connection is not acquired"))
def acquire(): Try[PGConnection] = {
release()
release(None)
conn = Try(getConn())
conn
}
def release(): Unit = conn.foreach(_.close())
def release(exOpt: Option[Throwable]): Unit = conn.foreach(_.close())
}

implicit def pgConnectionToWrapperProvider(conn: PGConnection): ConnectionProvider =
new ConnectionProvider {
def acquire(): Try[PGConnection] = Success(conn)
def release(): Unit = ()
def release(exOpt: Option[Throwable]): Unit = ()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ private[streams] class PgCopySinkStage(connectionProvider: ConnectionProvider, q
Try(copyIn.endCopy()) match {
case Success(rowsCopied) =>
completePromise.trySuccess(rowsCopied)
connectionProvider.release(None)
completeStage()
case Failure(ex) => fail(ex)
}
Expand All @@ -62,7 +63,7 @@ private[streams] class PgCopySinkStage(connectionProvider: ConnectionProvider, q
}

private def fail(ex: Throwable): Unit = {
connectionProvider.release()
connectionProvider.release(Some(ex))
completePromise.tryFailure(ex)
failStage(ex)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ import scala.concurrent.Future
*/
object PgCopyStreamConverters {

private val escapeSpecialChars: String => String =
Seq(
"\\" -> "\\\\", // escape `escape` character is first
"\b" -> "\\b", "\f" -> "\\f", "\n" -> "\\n",
"\r" -> "\\r", "\t" -> "\\t", "\u0011" -> "\\v"
).foldLeft(identity[String] _) {
case (resultFunction, (sFrom, sTo)) =>
resultFunction.andThen(_.replace(sFrom, sTo))
}

def sink(connectionProvider: ConnectionProvider, query: String, encoding: String = "UTF-8"): Sink[Product, Future[Long]] =
encodeTuples(encoding)
.toMat(bytesSink(connectionProvider, query))(Keep.right)
Expand All @@ -26,13 +36,11 @@ object PgCopyStreamConverters {
_.productIterator
.map {
case None | null => """\N"""
case Some(value) => esc(value.toString)
case value => esc(value.toString)
case Some(value) => escapeSpecialChars(value.toString)
case value => escapeSpecialChars(value.toString)
}
.mkString("", "\t", "\n")
.getBytes(encoding)
}
.map(ByteString.fromArray)

private def esc(s: String) = s.replace("""\""", """\\""")
}
22 changes: 22 additions & 0 deletions src/test/scala/it/CopySinkSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import ru.arigativa.akka.streams.ConnectionProvider._
import ru.arigativa.akka.streams.PgCopyStreamConverters
import util.PostgresFixture

import scala.util.Random

/**
* Check for integration with Postgres in Docker is working
*/
Expand Down Expand Up @@ -64,4 +66,24 @@ class CopySinkSpec extends AsyncFlatSpec with Matchers with PostgresFixture with
}
}
}

it should "encode special characters correctly" in {
val actualFirstPeople = (1L, "Alex\r\n\t Hasselbach\\", 25)
val actualSecondPeople = (2L, "Lisa", 21)
withPostgres("people_empty") { conn =>
Source.fromIterator(() => Iterator(actualFirstPeople, actualSecondPeople))
.runWith(PgCopyStreamConverters.sink(conn, "COPY people (id, name, age) FROM STDIN"))
.map { rowCount =>
rowCount shouldBe 2

val rs = conn.execSQLQuery("SELECT id, name, age FROM people")
val firstPeople = fetchPeople(rs)
val secondPeople = fetchPeople(rs)
conn.close()

firstPeople shouldBe actualFirstPeople
secondPeople shouldBe actualSecondPeople
}
}
}
}

0 comments on commit 5a4b818

Please sign in to comment.