Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nested records in deriveEsqueletoRecord #324

Merged
merged 8 commits into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
3.5.6.1
=======
- @9999years
- [#324](https://github.com/bitemyapp/esqueleto/pull/324)
- Add ability to use nested records with `deriveEsqueletoRecord`

3.5.6.0
=======
- @9999years
Expand Down
2 changes: 1 addition & 1 deletion esqueleto.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ cabal-version: 1.12

name: esqueleto

version: 3.5.6.0
version: 3.5.6.1
synopsis: Type-safe EDSL for SQL queries on persistent backends.
description: @esqueleto@ is a bare bones, type-safe EDSL for SQL queries that works with unmodified @persistent@ SQL backends. Its language closely resembles SQL, so you don't have to learn new concepts, just new syntax, and it's fairly easy to predict the generated SQL and optimize it for your backend. Most kinds of errors committed when writing SQL are caught as compile-time errors---although it is possible to write type-checked @esqueleto@ queries that fail at runtime.
.
Expand Down
175 changes: 118 additions & 57 deletions src/Database/Esqueleto/Record.hs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ViewPatterns #-}

module Database.Esqueleto.Record
( deriveEsqueletoRecord
Expand All @@ -16,11 +17,13 @@ import Database.Esqueleto.Experimental
(Entity, PersistValue, SqlExpr, Value(..), (:&)(..))
import Database.Esqueleto.Internal.Internal (SqlSelect(..))
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import Data.Bifunctor (first)
import Data.Text (Text)
import Control.Monad (forM)
import Data.Foldable (foldl')
import GHC.Exts (IsString(fromString))
import Data.Maybe (mapMaybe, fromMaybe, listToMaybe)

-- | Takes the name of a Haskell record type and creates a variant of that
-- record prefixed with @Sql@ which can be used in esqueleto expressions. This
Expand Down Expand Up @@ -169,16 +172,19 @@ getRecordInfo name = do
RecC name' _fields -> name'
con -> error $ nonRecordConstructorMessage con
fields = getFields constructor
sqlFields = toSqlField `map` fields
sqlName = makeSqlName name

sqlFields <- mapM toSqlField fields

pure RecordInfo {..}
where
getFields :: Con -> [(Name, Type)]
getFields (RecC _name fields) = [(fieldName', fieldType') | (fieldName', _bang, fieldType') <- fields]
getFields con = error $ nonRecordConstructorMessage con

toSqlField (fieldName', ty) = (fieldName', sqlFieldType ty)
toSqlField (fieldName', ty) = do
sqlTy <- sqlFieldType ty
pure (fieldName', sqlTy)

-- | Create a new name by prefixing @Sql@ to a given name.
makeSqlName :: Name -> Name
Expand All @@ -189,17 +195,28 @@ makeSqlName name = mkName $ "Sql" ++ nameBase name
-- * @'Entity' x@ is transformed into @'SqlExpr' ('Entity' x)@.
-- * @'Maybe' ('Entity' x)@ is transformed into @'SqlExpr' ('Maybe' ('Entity' x))@.
-- * @x@ is transformed into @'SqlExpr' ('Value' x)@.
sqlFieldType :: Type -> Type
sqlFieldType fieldType =
case fieldType of
-- Entity x -> SqlExpr (Entity x)
AppT (ConT ((==) ''Entity -> True)) _innerType -> AppT (ConT ''SqlExpr) fieldType
-- Maybe (Entity x) -> SqlExpr (Maybe (Entity x))
AppT
(ConT ((==) ''Maybe -> True))
(AppT (ConT ((==) ''Entity -> True)) _innerType) -> AppT (ConT ''SqlExpr) fieldType
-- x -> SqlExpr (Value x)
_ -> AppT (ConT ''SqlExpr) (AppT (ConT ''Value) fieldType)
-- * If there exists an instance @'SqlSelect' sql x@, then @x@ is transformed into @sql@.
--
-- This function should match `sqlSelectProcessRowPat`.
sqlFieldType :: Type -> Q Type
sqlFieldType fieldType = do
maybeSqlType <- reifySqlSelectType fieldType

pure $
flip fromMaybe maybeSqlType $
case fieldType of
-- Entity x -> SqlExpr (Entity x)
AppT (ConT ((==) ''Entity -> True)) _innerType -> AppT (ConT ''SqlExpr) fieldType

-- Maybe (Entity x) -> SqlExpr (Maybe (Entity x))
(ConT ((==) ''Maybe -> True))
`AppT` ((ConT ((==) ''Entity -> True))
`AppT` _innerType) -> AppT (ConT ''SqlExpr) fieldType

-- x -> SqlExpr (Value x)
_ -> (ConT ''SqlExpr)
`AppT` ((ConT ''Value)
`AppT` fieldType)

-- | Generates the declaration for an @Sql@-prefixed record, given the original
-- record's information.
Expand All @@ -222,9 +239,9 @@ makeSqlSelectInstance info@RecordInfo {..} = do
let overlap = Nothing
instanceConstraints = []
instanceType =
AppT
(AppT (ConT ''SqlSelect) (ConT sqlName))
(ConT name)
(ConT ''SqlSelect)
`AppT` (ConT sqlName)
`AppT` (ConT name)

pure $ InstanceD overlap instanceConstraints instanceType [sqlSelectColsDec', sqlSelectColCountDec', sqlSelectProcessRowDec']

Expand Down Expand Up @@ -265,9 +282,9 @@ sqlSelectColsDec RecordInfo {..} = do
, RecP sqlName fieldPatterns
]
( NormalB $
AppE
(AppE (VarE 'sqlSelectCols) (VarE identInfo))
(ParensE joinedFields)
(VarE 'sqlSelectCols)
`AppE` (VarE identInfo)
`AppE` (ParensE joinedFields)
)
-- `where` clause.
[]
Expand Down Expand Up @@ -318,9 +335,10 @@ sqlSelectProcessRowDec RecordInfo {..} = do
(statements, fieldExps) <-
unzip <$> forM (zip fields sqlFields) (\((fieldName', fieldType), (_, sqlType')) -> do
valueName <- newName (nameBase fieldName')
pattern <- sqlSelectProcessRowPat fieldType valueName
pure
( BindS
(sqlSelectProcessRowPat fieldType valueName)
pattern
(AppTypeE (VarE 'takeColumns) sqlType')
, (mkName $ nameBase fieldName', VarE valueName)
))
Expand All @@ -334,31 +352,17 @@ sqlSelectProcessRowDec RecordInfo {..} = do
-- (evalStateT $processName $colsName)
-- where $processName = do $statements
-- pure $name {$fieldExps}
bodyExp <- [e|
first (fromString ("Failed to parse " ++ $(lift $ nameBase name) ++ ": ") <>)
(evalStateT $(varE processName) $(varE colsName))
|]

pure $
FunD
'sqlSelectProcessRow
[ Clause
[VarP colsName]
( NormalB $
AppE
( AppE
(VarE 'first)
( InfixE
(Just $ AppE
(VarE 'fromString)
(LitE $ StringL $ "Failed to parse " ++ nameBase name ++ ": "))
(VarE '(<>))
Nothing
)
)
( AppE
( AppE
(VarE 'evalStateT)
(VarE processName)
)
(VarE colsName)
)
)
(NormalB bodyExp)
-- `where` clause
[ ValD
(VarP processName)
Expand All @@ -379,22 +383,79 @@ sqlSelectProcessRowDec RecordInfo {..} = do
-- * A type of @'Entity' x@ gives a pattern of @var@.
-- * A type of @'Maybe' ('Entity' x)@ gives a pattern of @var@.
-- * A type of @x@ gives a pattern of @'Value' var@.
sqlSelectProcessRowPat :: Type -> Name -> Pat
sqlSelectProcessRowPat fieldType var =
case fieldType of
-- Entity x -> var
AppT (ConT ((==) ''Entity -> True)) _innerType -> VarP var
-- Maybe (Entity x) -> var
AppT
(ConT ((==) ''Maybe -> True))
(AppT (ConT ((==) ''Entity -> True)) _innerType) -> VarP var
-- x -> Value var
-- * If there exists an instance @'SqlSelect' sql x@, then a type of @x@ gives a pattern of @var@.
--
-- This function should match `sqlFieldType`.
sqlSelectProcessRowPat :: Type -> Name -> Q Pat
sqlSelectProcessRowPat fieldType var = do
maybeSqlType <- reifySqlSelectType fieldType

case maybeSqlType of
Just _ -> pure $ VarP var
Nothing -> case fieldType of
-- Entity x -> var
AppT (ConT ((==) ''Entity -> True)) _innerType -> pure $ VarP var
-- Maybe (Entity x) -> var
(ConT ((==) ''Maybe -> True))
`AppT` ((ConT ((==) ''Entity -> True))
`AppT` _innerType) -> pure $ VarP var
-- x -> Value var
#if MIN_VERSION_template_haskell(2,18,0)
_ -> ConP 'Value [] [VarP var]
_ -> pure $ ConP 'Value [] [VarP var]
#else
_ -> ConP 'Value [VarP var]
_ -> pure $ ConP 'Value [VarP var]
#endif

-- Given a type, find the corresponding SQL type.
--
-- If there exists an instance `SqlSelect sql ty`, then the SQL type for `ty`
-- is `sql`.
--
-- This function definitely works for records and instances generated by this
-- module, and might work for instances outside of it.
reifySqlSelectType :: Type -> Q (Maybe Type)
reifySqlSelectType originalType = do
-- Here we query the compiler for Instances of `SqlSelect a $(originalType)`;
-- the API for this is super weird, it interprets a list of types as being
-- applied as successive arguments to the typeclass name.
--
-- See: https://gitlab.haskell.org/ghc/ghc/-/issues/21825
--
-- >>> reifyInstances ''SqlSelect [VarT (mkName "a"), ConT ''MyRecord]
-- [ InstanceD Nothing
-- []
-- (AppT (AppT (ConT Database.Esqueleto.Internal.Internal.SqlSelect)
-- (ConT Ghci3.SqlMyRecord))
-- (ConT Ghci3.MyRecord))
-- []
-- ]
tyVarName <- newName "a"
instances <- reifyInstances ''SqlSelect [VarT tyVarName, originalType]

-- Given the original type (`originalType`) and an instance type for a
-- `SqlSelect` instance, get the SQL type which corresponds to the original
-- type.
let extractSqlRecord :: Type -> Type -> Maybe Type
extractSqlRecord originalTy instanceTy =
case instanceTy of
(ConT ((==) ''SqlSelect -> True))
`AppT` sqlTy
`AppT` ((==) originalTy -> True) -> Just sqlTy
_ -> Nothing

-- Filter `instances` to the instances which match `originalType`.
filteredInstances :: [Type]
filteredInstances =
flip mapMaybe instances
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm, listToMaybe . mapMaybe f is also find (isJust . f) - may be easier or more concise to write like that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, because find only returns the element but I also need the mapMaybe part to transform and extract the sql type.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, good point!

(\case InstanceD _overlap
_constraints
(extractSqlRecord originalType -> Just sqlRecord)
_decs ->
Just sqlRecord
_ -> Nothing)

pure $ listToMaybe filteredInstances

-- | Statefully parse some number of columns from a list of `PersistValue`s,
-- where the number of columns to parse is determined by `sqlSelectColCount`
-- for @a@.
Expand Down
65 changes: 63 additions & 2 deletions test/Common/Record.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
module Common.Record (testDeriveEsqueletoRecord) where

import Common.Test.Import hiding (from, on)
import Database.Esqueleto.Record (deriveEsqueletoRecord)
import Database.Esqueleto.Experimental
import Data.List (sortOn)
import Database.Esqueleto.Experimental
import Database.Esqueleto.Record (deriveEsqueletoRecord)

data MyRecord =
MyRecord
Expand Down Expand Up @@ -50,6 +50,33 @@ myRecordQuery = do
, myAddress = address
}

data MyNestedRecord = MyNestedRecord
{ myName :: Text
, myRecord :: MyRecord
}
deriving (Show, Eq)

$(deriveEsqueletoRecord ''MyNestedRecord)

myNestedRecordQuery :: SqlQuery SqlMyNestedRecord
myNestedRecordQuery = do
user :& address <-
from $
table @User
`leftJoin` table @Address
`on` (do \(user :& address) -> user ^. #address ==. address ?. #id)
pure
SqlMyNestedRecord
{ myName = castString $ user ^. #name
, myRecord =
SqlMyRecord
{ myName = castString $ user ^. #name
, myAge = val $ Just 10
, myUser = user
, myAddress = address
}
}

testDeriveEsqueletoRecord :: SpecDb
testDeriveEsqueletoRecord = describe "deriveEsqueletoRecord" $ do
let setup :: MonadIO m => SqlPersistT m ()
Expand Down Expand Up @@ -83,3 +110,37 @@ testDeriveEsqueletoRecord = describe "deriveEsqueletoRecord" $ do
, myAddress = Just (Entity addr2 Address {addressAddress = "30-50 Feral Hogs Rd"})
} -> addr1 == addr2 -- The keys should match.
_ -> False)

itDb "can select nested records" $ do
setup
records <- select myNestedRecordQuery
let sortedRecords = sortOn (\MyNestedRecord {myName} -> myName) records
liftIO $ sortedRecords !! 0
`shouldSatisfy`
(\case MyNestedRecord
{ myName = "Rebecca"
, myRecord =
MyRecord { myName = "Rebecca"
, myAge = Just 10
, myUser = Entity _ User { userAddress = Nothing
, userName = "Rebecca"
}
, myAddress = Nothing
}
} -> True
_ -> False)

liftIO $ sortedRecords !! 1
`shouldSatisfy`
(\case MyNestedRecord
{ myName = "Some Guy"
, myRecord =
MyRecord { myName = "Some Guy"
, myAge = Just 10
, myUser = Entity _ User { userAddress = Just addr1
, userName = "Some Guy"
}
, myAddress = Just (Entity addr2 Address {addressAddress = "30-50 Feral Hogs Rd"})
}
} -> addr1 == addr2 -- The keys should match.
_ -> False)