Skip to content

Commit

Permalink
Use Shapeless 2.0 records.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joni Freeman committed Sep 11, 2013
1 parent 6571cd7 commit cdb0cf2
Show file tree
Hide file tree
Showing 13 changed files with 173 additions and 368 deletions.
19 changes: 8 additions & 11 deletions README.md
Expand Up @@ -35,28 +35,25 @@ Start console: ```sbt```, then ```project sqltyped``` and ```test:console```.
import java.sql._
import sqltyped._
Class.forName("com.mysql.jdbc.Driver")
object Columns { object name; object age; object salary; }
implicit val c = Configuration(Columns)
implicit val c = Configuration()
implicit def conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/sqltyped",
"root", "")
import Tables._
import Columns._
```

Now we are ready to query the data.

```scala
scala> val q = sql("select name, age from person")
scala> q() map (_ get age)
scala> q() map (_ get "age")
res0: List[Int] = List(36, 14)
```

Notice how the type of 'age' was infered to be Int.

```scala
scala> q() map (_ get salary)
<console>:24: error: No such key Columns.salary.type
q() map (_ get salary)
scala> q() map (_ get "salary")
<console>:24: error: No field String("salary") in record ...
q() map (_ get "salary")
```

Oops, a compilation failure. Can't access 'salary', it was not selected in the query.
Expand All @@ -79,13 +76,13 @@ Input parameters are parsed and typed.
```scala
scala> val q = sql("select name, age from person where age > ?")

scala> q("30") map (_ get name)
scala> q("30") map (_ get "name")
<console>:24: error: type mismatch;
found : String("30")
required: Int
q("30") map (_ get name)

scala> q(30) map (_ get name)
scala> q(30) map (_ get "name")
res4: List[String] = List(joe)
```

Expand All @@ -105,7 +102,7 @@ column it is typed as String => String etc.

```scala
scala> val q = sql("select max(name) as name, max(age) as age from person where age > ?")
scala> q(10).tuples.head
scala> q(10).tupled
res6: (Option[String], Option[Int]) = (Some(moe),Some(36))
```

Expand Down
9 changes: 5 additions & 4 deletions TODO
@@ -1,19 +1,20 @@
======== 0.4 ========

- Shapeless 2.0
- Bug: select from derived join

======== 0.5 ========

- Schemaprefixes
- Recursive unions
- Recursive unions (postgresql)
- Views?
- Bug: select max(id) from jackpot where `group` = ? ('group' works)

======== Backlog ========

- 2.11: Generate meta types http://stackoverflow.com/questions/12295971/will-it-be-possible-to-generate-several-top-level-classes-with-one-macro-invocat
- 2.11: Replace Query(n) with SqlF
- 2.11: Batch insert and update (after SqlF)
- 2.11: A function to merge records
- Infer function types from schema
- Handle dialect specific keywords properly (e.g. http://www.postgresql.org/docs/9.2/static/sql-keywords-appendix.html)
- 2.11: Conversions between records and normal scala types
- 2.11: Better interpolation support (table + column names etc.)
- DB REPL
4 changes: 2 additions & 2 deletions core/src/main/scala/csv.scala
@@ -1,6 +1,6 @@
package sqltyped

import shapeless._, TypeOperators.@@
import shapeless._, ops.hlist._, tag.@@

trait Show[A] {
def show(a: A): String
Expand Down Expand Up @@ -40,6 +40,6 @@ object CSV {
private def escape(s: String) = "\"" + s.replaceAll("\"","\"\"") + "\""
}

object toCSV extends Pullback1[List[String]] {
object toCSV extends Poly1 {
implicit def valueToCsv[V: Show] = at[V](v => List(implicitly[Show[V]].show(v)))
}
144 changes: 79 additions & 65 deletions core/src/main/scala/macro.scala
Expand Up @@ -5,12 +5,7 @@ import schemacrawler.schema.Schema
import NumOfResults._

trait ConfigurationName
case class Configuration[A, B](tables: A, columns: B)

object NoTables
object Configuration {
def apply[B](columns: B) = new Configuration(NoTables, columns)
}
case class Configuration()

object SqlMacro {
import shapeless._
Expand Down Expand Up @@ -43,34 +38,34 @@ object SqlMacro {
stmt.close
}

def sqlImpl[A: c.WeakTypeTag, B: c.WeakTypeTag]
def sqlImpl
(c: Context)
(s: c.Expr[String])
(config: c.Expr[Configuration[A, B]]): c.Expr[Any] =
(config: c.Expr[Configuration]): c.Expr[Any] =
sqlImpl0(c)(s)(config)

def sqltImpl[A: c.WeakTypeTag, B: c.WeakTypeTag]
def sqltImpl
(c: Context)
(s: c.Expr[String])
(config: c.Expr[Configuration[A, B]]): c.Expr[Any] =
(config: c.Expr[Configuration]): c.Expr[Any] =
sqlImpl0(c, useInputTags = true)(s)(config)

def sqlkImpl[A: c.WeakTypeTag, B: c.WeakTypeTag]
def sqlkImpl
(c: Context)
(s: c.Expr[String])
(config: c.Expr[Configuration[A, B]]): c.Expr[Any] =
(config: c.Expr[Configuration]): c.Expr[Any] =
sqlImpl0(c, useInputTags = false, keys = true)(s)(config)

def sqljImpl[A: c.WeakTypeTag, B: c.WeakTypeTag]
def sqljImpl
(c: Context)
(s: c.Expr[String])
(config: c.Expr[Configuration[A, B]]): c.Expr[Any] =
(config: c.Expr[Configuration]): c.Expr[Any] =
sqlImpl0(c, useInputTags = false, keys = false, jdbcOnly = true)(s)(config)

def sqlImpl0[A: c.WeakTypeTag, B: c.WeakTypeTag]
def sqlImpl0
(c: Context, useInputTags: Boolean = false, keys: Boolean = false, jdbcOnly: Boolean = false)
(s: c.Expr[String])
(config: c.Expr[Configuration[A, B]]): c.Expr[Any] = {
(config: c.Expr[Configuration]): c.Expr[Any] = {

import c.universe._

Expand All @@ -83,9 +78,9 @@ object SqlMacro {
sql, (p, s) => p.parseAllWith(p.stmt, s))(config, Literal(Constant(sql)))
}

def dynsqlImpl[A: c.WeakTypeTag, B: c.WeakTypeTag]
def dynsqlImpl
(c: Context)(exprs: c.Expr[Any]*)
(config: c.Expr[Configuration[A, B]]): c.Expr[Any] = {
(config: c.Expr[Configuration]): c.Expr[Any] = {

import c.universe._

Expand All @@ -108,10 +103,10 @@ object SqlMacro {
sql, (p, s) => p.parseWith(p.selectStmt, s))(config, sqlExpr)
}

def compile[A: c.WeakTypeTag, B: c.WeakTypeTag]
def compile
(c: Context, useInputTags: Boolean, keys: Boolean, jdbcOnly: Boolean, inputsInferred: Boolean, validate: Boolean, analyze: Boolean,
sql: String, parse: (SqlParser, String) => ?[Ast.Statement[Option[String]]])
(config: c.Expr[Configuration[A, B]], sqlExpr: c.Tree): c.Expr[Any] = {
(config: c.Expr[Configuration], sqlExpr: c.Tree): c.Expr[Any] = {

import c.universe._
import scala.util.Properties
Expand Down Expand Up @@ -186,9 +181,9 @@ object SqlMacro {
)
}

def codeGen[A: c.WeakTypeTag, B: c.WeakTypeTag]
def codeGen[A: c.WeakTypeTag]
(meta: TypedStatement, sql: String, c: Context, keys: Boolean, inputsInferred: Boolean)
(config: c.Expr[Configuration[A, B]], sqlExpr: c.Tree): c.Expr[Any] = {
(config: c.Expr[Configuration], sqlExpr: c.Tree): c.Expr[Any] = {

import c.universe._

Expand All @@ -205,11 +200,11 @@ object SqlMacro {
def baseValue = Typed(Apply(Select(Ident(newTermName("rs")), newTermName(rsGetterName(x))),
List(Literal(Constant(pos)))), scalaBaseType(x))

x.tag flatMap(t => tagType(t)) map (tagged =>
x.tag map(t => tagType(t)) map (tagged =>
Apply(
Select(
TypeApply(
Select(Select(Ident(newTermName("shapeless")), newTermName("TypeOperators")), newTermName("tag")),
Select(Select(Ident(newTermName("shapeless")), newTermName("tag")), newTermName("apply")),
List(tagged)), newTermName("apply")), List(baseValue))
) getOrElse baseValue
}
Expand All @@ -229,30 +224,14 @@ object SqlMacro {
} else Ident(c.mirror.staticClass(x.tpe._1.typeSymbol.fullName))

def scalaType(x: TypedValue) = {
x.tag flatMap (t => tagType(t)) map (tagged =>
x.tag map (t => tagType(t)) map (tagged =>
AppliedTypeTree(
Select(Select(Ident(newTermName("shapeless")), newTermName("TypeOperators")), newTypeName("$at$at")),
Select(Select(Ident(newTermName("shapeless")), newTermName("tag")), newTypeName("$at$at")),
List(scalaBaseType(x), tagged))
) getOrElse scalaBaseType(x)
}

def colKey(name: String) =
if (c.typeCheck(Select(Select(config.tree, newTermName("columns")), newTermName(name)), silent = true) == EmptyTree)
Literal(Constant(name))
else Select(Select(config.tree, newTermName("columns")), newTermName(name))

def colKeyType(name: String) =
if (c.typeCheck(Select(Select(config.tree, newTermName("columns")), newTermName(name)), silent = true) == EmptyTree)
Ident(newTypeName("String"))
else SingletonTypeTree(Select(Select(config.tree, newTermName("columns")), newTermName(name)))

def tagType(tag: String) = try {
// FIXME: any better ways to check that type exists?
c.typeCheck(ValDef(Modifiers(), newTermName("xxxxx"), SelectFromTypeTree(SingletonTypeTree(Select(config.tree, newTermName("tables"))), tag), Literal(Constant(null))))
Some(SelectFromTypeTree(SingletonTypeTree(Select(config.tree, newTermName("tables"))), tag))
} catch {
case e: TypecheckException => None
}
def tagType(tag: String) = Select(Ident(newTermName(tag)), newTypeName("T"))

def stmtSetterName(x: TypedValue) = "set" + TypeMappings.setterGetterNames(x.tpe._2)
def rsGetterName(x: TypedValue) = "get" + TypeMappings.setterGetterNames(x.tpe._2)
Expand Down Expand Up @@ -306,9 +285,7 @@ object SqlMacro {
def returnTypeSigRecord = List(meta.output.foldRight(Ident(c.mirror.staticClass("shapeless.HNil")): Tree) { (x, sig) =>
AppliedTypeTree(
Ident(c.mirror.staticClass("shapeless.$colon$colon")),
List(AppliedTypeTree(
Ident(c.mirror.staticClass("scala.Tuple2")),
List(colKeyType(x.name), possiblyOptional(x, scalaType(x)))), sig)
List(AppliedTypeTree(Select(Ident(c.mirror.staticModule("shapeless.record")), newTypeName("FieldType")), List(Select(Ident(newTermName(x.name)), newTypeName("T")), possiblyOptional(x, scalaType(x)))), sig)
)
})

Expand All @@ -320,14 +297,14 @@ object SqlMacro {
newTermName("x$" + (i+1)),
TypeTree(),
Apply(Select(Apply(
Select(Ident(c.mirror.staticModule("scala.Predef")), newTermName("any2ArrowAssoc")),
List(colKey(x.name))), newTermName("$minus$greater")), List(rs(x, meta.output.length - i))))
Select(Ident(c.mirror.staticModule("shapeless.syntax.singleton")), newTermName("mkSingletonOps")),
List(Literal(Constant(x.name)))), newTermName("->>").encoded), List(rs(x, meta.output.length - i))))

val init: Tree =
Block(List(
processRow(meta.output.last, 0)),
Apply(
Select(Select(Ident(newTermName("shapeless")), newTermName("HNil")), newTermName("$colon$colon")),
Select(Select(Ident(newTermName("shapeless")), newTermName("HNil")), newTermName("::").encoded),
List(Ident(newTermName("x$1")))
))

Expand All @@ -340,7 +317,7 @@ object SqlMacro {
Select(Ident(c.mirror.staticModule("shapeless.HList")), newTermName("hlistOps")),
List(Block(ast))
),
newTermName("$colon$colon")), List(Ident(newTermName("x$" + (i+2))))))
newTermName("::").encoded), List(Ident(newTermName("x$" + (i+2))))))
})
}

Expand Down Expand Up @@ -460,24 +437,50 @@ object SqlMacro {
)
}

/* Generates following code:
new Query1[I1, (name.type, String) :: (age.type, Int) :: HNil] {
def apply(i1: I1)(implicit conn: Connection) = {
val stmt = conn.prepareStatement(sql)
stmt.setInt(1, i1)
withResultSet(stmt) { rs =>
val rows = collection.mutable.ListBuffer[(name.type, String) :: (age.type, Int) :: HNil]()
while (rs.next) {
rows.append((name -> rs.getString(1)) :: (age -> rs.getInt(2)) :: HNil)
}
rows.toList
}
}
}
/*
sql("select name, age from person where age > ?")
Generates following code:
identity {
val name = Witness("name")
val age = Witness("age")
new Query1[Int, FieldType[name.T, String] :: FieldType[age.T, Int] :: HNil] {
def apply(i1: Int)(implicit conn: Connection) = {
val stmt = conn.prepareStatement("select name, age from person where age > ?")
stmt.setInt(1, i1)
withResultSet(stmt) { rs =>
val rows = collection.mutable.ListBuffer[FieldType[name.T, String] :: FieldType[age.T, Int] :: HNil]()
while (rs.next) {
rows.append("name" ->> rs.getString(1) :: "age" ->> rs.getInt(2) :: HNil)
}
rows.toList
}
}
}
}
*/

val inputLen = if (inputsInferred) meta.input.length else 1

c.Expr {
def witnesses = (
(meta.output map (_.name)) :::
(meta.input flatMap (_.tag)) :::
(meta.output flatMap (_.tag)) :::
(if (keys) { meta.generatedKeyTypes flatMap (_.tag) } else Nil)
).distinct

def mkWitness(name: String) =
ValDef(
Modifiers(),
newTermName(name),
TypeTree(),
Apply(Select(Ident(c.mirror.staticModule("shapeless.Witness")), newTermName("apply")),
List(Literal(Constant(name)))))

def mkQuery =
Block(
List(
ClassDef(Modifiers(Flag.FINAL), newTypeName("$anon"), List(),
Expand All @@ -499,6 +502,17 @@ object SqlMacro {
)),
Apply(Select(New(Ident(newTypeName("$anon"))), nme.CONSTRUCTOR), List())
)

c.Expr {
Apply(
Select(Ident(c.mirror.staticModule("scala.Predef")), newTermName("identity")),
List(
Block(
witnesses map (i => mkWitness(i)),
mkQuery
)
)
)
}
}
}
Expand Down
12 changes: 5 additions & 7 deletions core/src/main/scala/package.scala
Expand Up @@ -3,22 +3,20 @@ import shapeless._
package object sqltyped {
import language.experimental.macros

def sql[A, B](s: String)(implicit config: Configuration[A, B]) = macro SqlMacro.sqlImpl[A, B]
def sql(s: String)(implicit config: Configuration) = macro SqlMacro.sqlImpl

def sqlt[A, B](s: String)(implicit config: Configuration[A, B]) = macro SqlMacro.sqltImpl[A, B]
def sqlt(s: String)(implicit config: Configuration) = macro SqlMacro.sqltImpl

// FIXME switch to sql("select ...", keys = true) after;
// https://issues.scala-lang.org/browse/SI-5920
def sqlk[A, B](s: String)(implicit config: Configuration[A, B]) = macro SqlMacro.sqlkImpl[A, B]
def sqlk(s: String)(implicit config: Configuration) = macro SqlMacro.sqlkImpl

def sqlj[A, B](s: String)(implicit config: Configuration[A, B]) = macro SqlMacro.sqljImpl[A, B]
def sqlj(s: String)(implicit config: Configuration) = macro SqlMacro.sqljImpl

implicit class DynSQLContext(sc: StringContext) {
def sql[A, B](exprs: Any*)(implicit config: Configuration[A, B]) = macro SqlMacro.dynsqlImpl[A, B]
def sql(exprs: Any*)(implicit config: Configuration) = macro SqlMacro.dynsqlImpl
}

implicit def recordOps[L <: HList](l: L): RecordOps[L] = new RecordOps(l)

implicit def listOps[L <: HList](l: List[L]): ListOps[L] = new ListOps(l)

implicit def optionOps[L <: HList](l: Option[L]): OptionOps[L] = new OptionOps(l)
Expand Down

0 comments on commit cdb0cf2

Please sign in to comment.