Skip to content

Commit

Permalink
Added a CRUDify trait for Squeryl. Fixed #967.
Browse files Browse the repository at this point in the history
  • Loading branch information
davewhittaker committed May 23, 2011
1 parent c6cafb6 commit 07d6b0b
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 1 deletion.
Original file line number Original file line Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright 2006-2011 WorldWide Conferencing, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package net.liftweb
package squerylrecord

import net.liftweb.record.{Record, MetaRecord}
import net.liftweb.proto.Crudify
import org.squeryl._
import net.liftweb.squerylrecord.RecordTypeMode._
import net.liftweb.record.Field
import net.liftweb.common.{Box, Empty, Full}
import scala.xml.NodeSeq

trait CRUDify[K, T <: Record[T] with KeyedEntity[K]] extends Crudify {
self: MetaRecord[T] =>

type TheCrudType = T

type FieldPointerType = Field[_, TheCrudType]

def table: Table[TheCrudType]

def idFromString(in: String): K

override def calcPrefix = table.name :: Nil

override def fieldsForDisplay: List[FieldPointerType] = metaFields.filter(_.shouldDisplay_?)

override def computeFieldFromPointer(instance: TheCrudType, pointer: FieldPointerType): Box[FieldPointerType] = instance.fieldByName(pointer.name)

override def findForParam(in: String): Box[TheCrudType] = {
table.lookup(idFromString(in))
}

override def findForList(start: Long, count: Int) = from(table)(t => select(t)).page(start.toInt, count).toList

override def create = createRecord

override def buildBridge(in: TheCrudType) = new SquerylBridge(in)

protected class SquerylBridge(in: TheCrudType) extends CrudBridge {

def delete_! = table.delete(in.id)

def save = {
if (in.isPersisted) {
table.update(in)
}
else {
table.insert(in)
}
true
}

def validate = in.validate

def primaryKeyFieldAsString = in.id.toString
}

def buildFieldBridge(from: FieldPointerType): FieldPointerBridge = new SquerylFieldBridge(from)

protected class SquerylFieldBridge(in: FieldPointerType) extends FieldPointerBridge {
def displayHtml: NodeSeq = in.displayHtml
}

}
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -83,7 +83,13 @@ class Company private () extends Record[Company] with KeyedRecord[Long] {
lazy val employees = MySchema.companyToEmployees.left(this) lazy val employees = MySchema.companyToEmployees.left(this)


} }
object Company extends Company with MetaRecord[Company] object Company extends Company with MetaRecord[Company] with CRUDify[Long, Company]{

def table = MySchema.companies

def idFromString(in: String) = in.toLong

}


object EmployeeRole extends Enumeration { object EmployeeRole extends Enumeration {
type EmployeeRole = Value type EmployeeRole = Value
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package squerylrecord


import org.specs.Specification import org.specs.Specification
import record.{BaseField, Record} import record.{BaseField, Record}
import record.field._
import RecordTypeMode._ import RecordTypeMode._
import MySchema.{TestData => td, _} import MySchema.{TestData => td, _}
import java.util.Calendar import java.util.Calendar
Expand Down Expand Up @@ -283,6 +284,31 @@ object SquerylRecordSpec extends Specification("SquerylRecord Specification") {
} }
} }


forExample("support the CRUDify trait") >> {
transaction{
val company = Company.create.name("CRUDify Company").
created(Calendar.getInstance()).
country(Countries.USA).postCode("90210")
val bridge = Company.buildBridge(company)
bridge.save
val id = company.id
company.isPersisted must_== true
id must be_>(0l)
company.postCode("10001")
bridge.save
val company2 = Company.findForParam(id.toString)
company2.isDefined must_== true
company2.foreach(c2 => {
c2.postCode.get must_== "10001"
})
val allCompanies = Company.findForList(0, 1000)
allCompanies.size must be_>(0)
bridge.delete_!
val allCompanies2 = Company.findForList(0, 1000)
allCompanies2.size must_== (allCompanies.size - 1)
}
}

} }


class TransactionRollbackException extends RuntimeException class TransactionRollbackException extends RuntimeException
Expand Down

0 comments on commit 07d6b0b

Please sign in to comment.