Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ libraryDependencies ++= Seq(
### Connecting to the DB

```scala
import com.twitter.finagle.postgres.Client
val client = Client(host, username, password, database)
```

Expand Down
23 changes: 3 additions & 20 deletions src/main/scala/com/twitter/finagle/postgres/Client.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ class Client(factory: ServiceFactory[PgRequest, PgResponse], id:String) {
/*
* Run a single SELECT query and wrap the results with the provided function.
*/
def select[T](sql: String)(f: Row => T): Future[Seq[T]] = fetch(sql) map {
rs =>
extractRows(rs).map(f)
def select[T](sql: String)(f: Row => T): Future[Seq[T]] = fetch(sql) map { rs =>
rs.toRowList(customTypes).map(f)
}

/*
Expand Down Expand Up @@ -120,7 +119,7 @@ class Client(factory: ServiceFactory[PgRequest, PgResponse], id:String) {
optionalService: Option[Service[PgRequest, PgResponse]] = None
): Future[(IndexedSeq[String], IndexedSeq[ChannelBuffer => Value[Any]])] = {
send(PgRequest(Describe(portal = true, name = name), flush = true), optionalService) {
case RowDescriptions(fields) => Future.value(processFields(fields))
case RowDescriptions(fields) => Future.value(Field.processFields(fields, customTypes))
}
}

Expand Down Expand Up @@ -158,22 +157,6 @@ class Client(factory: ServiceFactory[PgRequest, PgResponse], id:String) {
})
}

private[this] def processFields(
fields: IndexedSeq[Field]): (IndexedSeq[String], IndexedSeq[ChannelBuffer => Value[Any]]) = {
val names = fields.map(f => f.name)
val parsers = fields.map(f => ValueParser.parserOf(f.format, f.dataType, customTypes))

(names, parsers)
}

private[this] def extractRows(rs: SelectResult): List[Row] = {
val (fieldNames, fieldParsers) = processFields(rs.fields)

rs.rows.map(dataRow => new Row(fieldNames, dataRow.data.zip(fieldParsers).map {
case (d, p) => if (d == null) null else p(d)
}))
}

private[this] class PreparedStatementImpl(
name: String,
service: Service[PgRequest, PgResponse]) extends PreparedStatement {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package com.twitter.finagle.postgres.messages

import com.twitter.finagle.postgres.Row
import com.twitter.finagle.postgres.values.{ValueParser, Value}

import org.jboss.netty.buffer.ChannelBuffer

/*
* Response message types.
*/
Expand Down Expand Up @@ -31,6 +36,26 @@ case class AuthenticatedResponse(params: Map[String, String], processId: Int, se

case class Rows(rows: List[DataRow], completed: Boolean) extends PgResponse

object Field {
/*
* Extract an `IndexSeq[Field]` into a tuple containing
* corresponding field-names and field-parsing functions.
*
* @param fields The `Field`s to be processed.
* @param customTypes A `Map` containing name->type pairs representing custom
* value types.
*/
private[postgres] def processFields(
fields: IndexedSeq[Field],
customTypes: Map[String, String]
): (IndexedSeq[String], IndexedSeq[ChannelBuffer => Value[Any]]) = {
val names = fields.map(f => f.name)
val parsers = fields.map(f => ValueParser.parserOf(f.format, f.dataType, customTypes))

(names, parsers)
}
}

case class Field(name: String, format: Int, dataType: Int)

case class RowDescriptions(fields: IndexedSeq[Field]) extends PgResponse
Expand All @@ -39,6 +64,20 @@ case class Descriptions(params: IndexedSeq[Int], fields: IndexedSeq[Field]) exte

case class ParamsResponse(types: IndexedSeq[Int]) extends PgResponse

case class SelectResult(fields: IndexedSeq[Field], rows: List[DataRow]) extends PgResponse
case class SelectResult(fields: IndexedSeq[Field], rows: List[DataRow]) extends PgResponse {
/*
* Returns this `SelectResult` as a list of `Row`s.
*
* @param customTypes A `Map` containing name->type pairs representing custom
* value types.
*/
def toRowList(customTypes: Map[String, String] = Map.empty): List[Row] = {
val (fieldNames, fieldParsers) = Field.processFields(fields, customTypes)

rows.map(dataRow => new Row(fieldNames, dataRow.data.zip(fieldParsers).map {
case (d, p) => if (d == null) null else p(d)
}))
}
}

case class CommandCompleteResponse(affectedRows: Int) extends PgResponse
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.twitter.finagle.postgres.messages

import com.twitter.finagle.postgres.Spec
import com.twitter.finagle.postgres.values.{Charsets, Type, Value}
import org.jboss.netty.buffer.ChannelBuffers

class PgResponsesSpec extends Spec {
"Field.processFields" should {
"Extract names from fields" in {
val fields = IndexedSeq(Field("foo", 0, Type.BOOL), Field("bar", 0, 9999))
val customTypes = Map("9999" -> "hstore")
val (fieldNames, _) = Field.processFields(fields, customTypes)

fieldNames must equal(Seq("foo", "bar"))
}
}

"SelectResult.toRowList" should {
"Return a `Row` with correct data" in {
val fields = IndexedSeq(Field("email", 0, Type.VAR_CHAR))
val row1 = DataRow(IndexedSeq(ChannelBuffers.copiedBuffer("donald@duck.com".getBytes(Charsets.Utf8))))
val row2 = DataRow(IndexedSeq(ChannelBuffers.copiedBuffer("daisy@duck.com".getBytes(Charsets.Utf8))))
val rowList = SelectResult(fields, List(row1, row2)).toRowList()

rowList.size must equal(2)
rowList(0).fields must equal(Seq("email"))
rowList(0).vals must equal(Seq((Value("donald@duck.com"))))
rowList(1).fields must equal(Seq("email"))
rowList(1).vals must equal(Seq(Value("daisy@duck.com")))
}
}
}