Skip to content

Commit

Permalink
Fix insertion of default values since broken by #85 (#121)
Browse files Browse the repository at this point in the history
When printing the SQL, we analyse the `PrimExpr`s and `PrimQuery` to detect the case where `rows` is set to `rows = values _` and, if so, we output a bare `VALUES` statement instead of a `SELECT`, because `DEFAULT` is only syntactically valid in a bare `VALUES` clause.
  • Loading branch information
shane-circuithub committed Jul 17, 2021
1 parent 75d7b32 commit add7e77
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/Rel8/Statement/Insert.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import Rel8.Schema.Name ( Name, Selects, ppColumn )
import Rel8.Schema.Table ( TableSchema(..), ppTable )
import Rel8.Statement.OnConflict ( OnConflict, ppOnConflict )
import Rel8.Statement.Returning ( Returning, decodeReturning, ppReturning )
import Rel8.Statement.Select ( ppSelect )
import Rel8.Statement.Select ( ppRows )
import Rel8.Table ( Table )
import Rel8.Table.Name ( showNames )

Expand Down Expand Up @@ -66,7 +66,7 @@ ppInsert :: Insert a -> Doc
ppInsert Insert {..} =
text "INSERT INTO" <+>
ppInto into $$
ppSelect rows $$
ppRows rows $$
ppOnConflict into onConflict $$
ppReturning into returning

Expand Down
34 changes: 31 additions & 3 deletions src/Rel8/Statement/Select.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ module Rel8.Statement.Select

, Optimized(..)
, ppPrimSelect
, ppRows
)
where

-- base
import Data.Foldable ( toList )
import Data.List.NonEmpty ( NonEmpty( (:|) ) )
import Data.Void ( Void )
import Prelude hiding ( undefined )

Expand All @@ -24,12 +27,15 @@ import qualified Hasql.Encoders as Hasql
import qualified Hasql.Statement as Hasql

-- opaleye
import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye
import qualified Opaleye.Internal.HaskellDB.Sql as Opaleye
import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye
import qualified Opaleye.Internal.PrimQuery as Opaleye
import qualified Opaleye.Internal.Print as Opaleye
import qualified Opaleye.Internal.Optimize as Opaleye
import qualified Opaleye.Internal.QueryArr as Opaleye hiding ( Select )
import qualified Opaleye.Internal.Sql as Opaleye
import qualified Opaleye.Internal.Sql as Opaleye hiding ( Values )
import qualified Opaleye.Internal.Tag as Opaleye

-- pretty
import Text.PrettyPrint ( Doc )
Expand All @@ -45,6 +51,7 @@ import Rel8.Table ( Table )
import Rel8.Table.Cols ( toCols )
import Rel8.Table.Name ( namesFromLabels )
import Rel8.Table.Opaleye ( castTable, exprsWithNames )
import qualified Rel8.Table.Opaleye as T
import Rel8.Table.Serialize ( Serializable, parse )
import Rel8.Table.Undefined ( undefined )

Expand Down Expand Up @@ -80,12 +87,33 @@ ppSelect query =
never = pure (toPrimExpr false)


ppRows :: Table Expr a => Query a -> Doc
ppRows query = case optimize primQuery of
-- Special case VALUES because we can't use DEFAULT inside a SELECT
Optimized (Opaleye.Product ((_, Opaleye.Values symbols rows) :| []) [])
| eqSymbols symbols (toList (T.exprs a)) ->
Opaleye.ppValues_ (map Opaleye.sqlExpr <$> toList rows)
_ -> ppSelect query
where
(a, primQuery, _) = Opaleye.runSimpleQueryArrStart (toOpaleye query) ()

eqSymbols (symbol : symbols) (Opaleye.AttrExpr symbol' : exprs)
| eqSymbol symbol symbol' = eqSymbols symbols exprs
| otherwise = False
eqSymbols [] [] = True
eqSymbols _ _ = False

eqSymbol
(Opaleye.Symbol name (Opaleye.UnsafeTag tag))
(Opaleye.Symbol name' (Opaleye.UnsafeTag tag'))
= name == name' && tag == tag'


ppPrimSelect :: Query a -> (Optimized Doc, a)
ppPrimSelect query =
(Opaleye.ppSql . primSelect <$> optimize primQuery, a)
where
(a, primQuery, _) =
Opaleye.runSimpleQueryArrStart (toOpaleye query) ()
(a, primQuery, _) = Opaleye.runSimpleQueryArrStart (toOpaleye query) ()


data Optimized a = Empty | Unit | Optimized a
Expand Down

0 comments on commit add7e77

Please sign in to comment.