Skip to content
Browse files

Added a CRUDify trait for Squeryl. Fixed #967.

  • Loading branch information...
1 parent c6cafb6 commit 07d6b0b91a97f0a7ebe58137ad8983d054ceca27 @davewhittaker davewhittaker committed Apr 26, 2011
View
80 persistence/squeryl-record/src/main/scala/net/liftweb/squerylrecord/CRUDify.scala
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2006-2011 WorldWide Conferencing, LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package net.liftweb
+package squerylrecord
+
+import net.liftweb.record.{Record, MetaRecord}
+import net.liftweb.proto.Crudify
+import org.squeryl._
+import net.liftweb.squerylrecord.RecordTypeMode._
+import net.liftweb.record.Field
+import net.liftweb.common.{Box, Empty, Full}
+import scala.xml.NodeSeq
+
+trait CRUDify[K, T <: Record[T] with KeyedEntity[K]] extends Crudify {
+ self: MetaRecord[T] =>
+
+ type TheCrudType = T
+
+ type FieldPointerType = Field[_, TheCrudType]
+
+ def table: Table[TheCrudType]
+
+ def idFromString(in: String): K
+
+ override def calcPrefix = table.name :: Nil
+
+ override def fieldsForDisplay: List[FieldPointerType] = metaFields.filter(_.shouldDisplay_?)
+
+ override def computeFieldFromPointer(instance: TheCrudType, pointer: FieldPointerType): Box[FieldPointerType] = instance.fieldByName(pointer.name)
+
+ override def findForParam(in: String): Box[TheCrudType] = {
+ table.lookup(idFromString(in))
+ }
+
+ override def findForList(start: Long, count: Int) = from(table)(t => select(t)).page(start.toInt, count).toList
+
+ override def create = createRecord
+
+ override def buildBridge(in: TheCrudType) = new SquerylBridge(in)
+
+ protected class SquerylBridge(in: TheCrudType) extends CrudBridge {
+
+ def delete_! = table.delete(in.id)
+
+ def save = {
+ if (in.isPersisted) {
+ table.update(in)
+ }
+ else {
+ table.insert(in)
+ }
+ true
+ }
+
+ def validate = in.validate
+
+ def primaryKeyFieldAsString = in.id.toString
+ }
+
+ def buildFieldBridge(from: FieldPointerType): FieldPointerBridge = new SquerylFieldBridge(from)
+
+ protected class SquerylFieldBridge(in: FieldPointerType) extends FieldPointerBridge {
+ def displayHtml: NodeSeq = in.displayHtml
+ }
+
+}
View
8 persistence/squeryl-record/src/test/scala/net/liftweb/squerylrecord/Fixtures.scala
@@ -83,7 +83,13 @@ class Company private () extends Record[Company] with KeyedRecord[Long] {
lazy val employees = MySchema.companyToEmployees.left(this)
}
-object Company extends Company with MetaRecord[Company]
+object Company extends Company with MetaRecord[Company] with CRUDify[Long, Company]{
+
+ def table = MySchema.companies
+
+ def idFromString(in: String) = in.toLong
+
+}
object EmployeeRole extends Enumeration {
type EmployeeRole = Value
View
26 persistence/squeryl-record/src/test/scala/net/liftweb/squerylrecord/SquerylRecordSpec.scala
@@ -16,6 +16,7 @@ package squerylrecord
import org.specs.Specification
import record.{BaseField, Record}
+import record.field._
import RecordTypeMode._
import MySchema.{TestData => td, _}
import java.util.Calendar
@@ -283,6 +284,31 @@ object SquerylRecordSpec extends Specification("SquerylRecord Specification") {
}
}
+ forExample("support the CRUDify trait") >> {
+ transaction{
+ val company = Company.create.name("CRUDify Company").
+ created(Calendar.getInstance()).
+ country(Countries.USA).postCode("90210")
+ val bridge = Company.buildBridge(company)
+ bridge.save
+ val id = company.id
+ company.isPersisted must_== true
+ id must be_>(0l)
+ company.postCode("10001")
+ bridge.save
+ val company2 = Company.findForParam(id.toString)
+ company2.isDefined must_== true
+ company2.foreach(c2 => {
+ c2.postCode.get must_== "10001"
+ })
+ val allCompanies = Company.findForList(0, 1000)
+ allCompanies.size must be_>(0)
+ bridge.delete_!
+ val allCompanies2 = Company.findForList(0, 1000)
+ allCompanies2.size must_== (allCompanies.size - 1)
+ }
+ }
+
}
class TransactionRollbackException extends RuntimeException

0 comments on commit 07d6b0b

Please sign in to comment.
Something went wrong with that request. Please try again.