Skip to content

Commit

Permalink
Detect identical expressions in extractAggregateFields to allow `di…
Browse files Browse the repository at this point in the history
…stinctAggregator` to be used with `orderAggregate`
  • Loading branch information
shane-circuithub committed Oct 5, 2023
1 parent e429309 commit c7f0c83
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 44 deletions.
3 changes: 2 additions & 1 deletion opaleye.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ library
aeson >= 0.6 && < 2.3
, base >= 4.9 && < 4.19
, base16-bytestring >= 0.1.1.6 && < 1.1
, case-insensitive >= 1.2 && < 1.3
, bytestring >= 0.10 && < 0.12
, case-insensitive >= 1.2 && < 1.3
, containers >= 0.5 && < 0.8
, contravariant >= 1.2 && < 1.6
, postgresql-simple >= 0.6 && < 0.8
, pretty >= 1.1.1.0 && < 1.2
Expand Down
78 changes: 52 additions & 26 deletions src/Opaleye/Internal/Aggregate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@
module Opaleye.Internal.Aggregate where

import Control.Applicative (Applicative, liftA2, pure, (<*>))
import Control.Arrow ((***))
import Data.Foldable (toList)
import Data.Traversable (for)

import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

import qualified Data.Profunctor as P
import qualified Data.Profunctor.Product as PP

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (StateT, gets, modify, runStateT)

import qualified Opaleye.Field as F
import qualified Opaleye.Internal.Column as C
import qualified Opaleye.Internal.Order as O
Expand Down Expand Up @@ -130,42 +137,61 @@ aggregatorApply = Aggregator $ PM.PackMap $ \f (agg, a) ->
--
-- Instead of detecting when we are aggregating over a field from a
-- previous query we just create new names for all fields before we
-- aggregate. On the other hand, referring to a field from a previous
-- query in an ORDER BY expression is totally fine!
-- aggregate.
--
-- Additionally, PostgreSQL imposes a limitation on aggregations using ORDER
-- BY in combination with DISTINCT - essentially the expression you pass to
-- ORDER BY must also be present in the argument list to the aggregation
-- function. This means that not only do we also have to create new
-- names for the ORDER BY expressions (if we only rewrite the function
-- arguments then they can't match and therefore ORDER BY can never be used
-- with DISTINCT), but these names actually have to match the names
-- created for the aggregation function arguments. To accomplish this, when
-- traversing over the aggregations, we keep track of all the expressions
-- we've encountered so far, and only create new names for new expressions,
-- reusing old names where possible.
aggregateU :: Aggregator a b
-> (a, PQ.PrimQuery, T.Tag) -> (b, PQ.PrimQuery)
aggregateU agg (c0, primQ, t0) = (c1, primQ')
where (c1, projPEs_inners) =
PM.run (runAggregator agg (extractAggregateFields t0) c0)

projPEs = map fst projPEs_inners
inners = concatMap snd projPEs_inners
aggregateU agg (a, primQ, tag) = (b, primQ')
where
(inners, outers, b) =
runSymbols (runAggregator agg (extractAggregateFields tag) a)

primQ' = PQ.Aggregate projPEs (PQ.Rebind True inners primQ)
primQ' = PQ.Aggregate outers (PQ.Rebind True inners primQ)

extractAggregateFields
:: Traversable t
=> T.Tag
-> (t HPQ.PrimExpr)
-> PM.PM [((HPQ.Symbol,
t HPQ.Symbol),
PQ.Bindings HPQ.PrimExpr)]
HPQ.PrimExpr
-> t HPQ.PrimExpr
-> Symbols HPQ.PrimExpr (PQ.Bindings (t HPQ.Symbol)) HPQ.PrimExpr
extractAggregateFields tag agg = do
i <- PM.new

let souter = HPQ.Symbol ("result" ++ i) tag

bindings <- for agg $ \pe -> do
j <- PM.new
let sinner = HPQ.Symbol ("inner" ++ j) tag
pure (sinner, pe)

let agg' = fmap fst bindings
result <- mkSymbol "result" <$> lift PM.new
agg' <- traverse (symbolize (mkSymbol "inner")) agg
lift $ PM.write (result, agg')
pure $ HPQ.AttrExpr result
where
mkSymbol name i = HPQ.Symbol (name ++ i) tag

PM.write ((souter, agg'), toList bindings)
-- | Computations that hand out symbols for expressions of type @e@ on top
-- of a 'PM.PM' that collects values of type @s@.
--
-- The state is a pair:
--
-- * a map from each expression seen so far to the symbol chosen for it
--   (consulted and extended by 'symbolize'), and
--
-- * a function on bindings that 'symbolize' extends by composition and
--   that 'runSymbols' finally applies to the empty list.
type Symbols e s =
StateT
(Map e HPQ.Symbol, PQ.Bindings e -> PQ.Bindings e)
(PM.PM s)

pure (HPQ.AttrExpr souter)
-- | Run a 'Symbols' computation from an initial state consisting of an
-- empty expression map and the identity bindings function.
--
-- The result is a triple of:
--
-- * the bindings obtained by applying the final bindings function to @[]@,
--
-- * the second component returned by 'PM.run' (presumably everything the
--   computation emitted with @PM.write@ -- confirm against "PackMap"), and
--
-- * the value returned by the computation itself.
--
-- The final expression map is discarded.
runSymbols :: Symbols e [s] a -> (PQ.Bindings e, [s], a)
runSymbols m =
  -- A lazy @let@ pattern, so nothing is forced until a component of the
  -- returned triple is demanded.
  let ((result, (_seen, mkBindings)), written) =
        PM.run (runStateT m (Map.empty, id))
  in (mkBindings [], written, result)

-- | Return the symbol associated with an expression, allocating a new one
-- only the first time that expression is encountered.
--
-- The map in the first component of the state is consulted with
-- 'Map.lookup'. On a hit the stored symbol is returned and the state is
-- left untouched. On a miss a fresh name is drawn with 'PM.new' and turned
-- into a symbol by the supplied function; the symbol is then inserted into
-- the map under the expression, and the @(symbol, expression)@ pair is
-- composed onto the right of the bindings function held in the second
-- component of the state.
symbolize :: Ord e =>
  (String -> HPQ.Symbol) -> e -> Symbols e s HPQ.Symbol
symbolize mkName expr =
  gets (Map.lookup expr . fst) >>= maybe allocate pure
  where
    allocate = do
      fresh <- fmap mkName (lift PM.new)
      -- The irrefutable pattern keeps this update exactly as lazy in the
      -- state pair as an arrow-style @f *** g@ update would be.
      modify $ \ ~(seen, bindings) ->
        ( Map.insert expr fresh seen
        , bindings . ((fresh, expr) :)
        )
      pure fresh

unsafeMax :: Aggregator (C.Field a) (C.Field a)
unsafeMax = makeAggr HPQ.AggrMax
Expand Down
32 changes: 16 additions & 16 deletions src/Opaleye/Internal/HaskellDB/PrimQuery.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type Name = String
type Scheme = [Attribute]
type Assoc = [(Attribute,PrimExpr)]

data Symbol = Symbol String T.Tag deriving (Read, Show)
data Symbol = Symbol String T.Tag deriving (Eq, Ord, Read, Show)

data PrimExpr = AttrExpr Symbol
| BaseTableAttrExpr Attribute
Expand All @@ -40,7 +40,7 @@ data PrimExpr = AttrExpr Symbol
| ArrayExpr [PrimExpr] -- ^ ARRAY[..]
| RangeExpr String BoundExpr BoundExpr
| ArrayIndex PrimExpr PrimExpr
deriving (Read,Show)
deriving (Eq, Ord, Read, Show)

data Literal = NullLit
| DefaultLit -- ^ represents a default value
Expand All @@ -51,7 +51,7 @@ data Literal = NullLit
| DoubleLit Double
| NumericLit Sci.Scientific
| OtherLit String -- ^ used for hacking in custom SQL
deriving (Read,Show)
deriving (Eq, Ord, Read, Show)

data BinOp = (:==) | (:<) | (:<=) | (:>) | (:>=) | (:<>)
| OpAnd | OpOr
Expand All @@ -66,7 +66,7 @@ data BinOp = (:==) | (:<) | (:<=) | (:>) | (:>=) | (:<>)
| (:->) | (:->>) | (:#>) | (:#>>)
| (:@>) | (:<@) | (:?) | (:?|) | (:?&)
| (:&&) | (:<<) | (:>>) | (:&<) | (:&>) | (:-|-)
deriving (Show,Read)
deriving (Eq, Ord, Read, Show)

data UnOp = OpNot
| OpIsNull
Expand All @@ -77,22 +77,22 @@ data UnOp = OpNot
| OpLower
| OpUpper
| UnOpOther String
deriving (Show,Read)
deriving (Eq, Ord, Read, Show)

data AggrOp = AggrCount | AggrSum | AggrAvg | AggrMin | AggrMax
| AggrStdDev | AggrStdDevP | AggrVar | AggrVarP
| AggrBoolOr | AggrBoolAnd | AggrArr | JsonArr
| AggrStringAggr
| AggrOther String
deriving (Show,Read)
deriving (Eq, Ord, Read, Show)

data AggrDistinct = AggrDistinct | AggrAll
deriving (Eq,Show,Read)
deriving (Eq, Ord, Read, Show)

type Aggregate = Aggregate' PrimExpr

data Aggregate' a = GroupBy a | Aggregate (Aggr' a)
deriving (Functor, Foldable, Traversable, Show, Read)
deriving (Functor, Foldable, Traversable, Eq, Ord, Read, Show)

type Aggr = Aggr' PrimExpr

Expand All @@ -103,25 +103,25 @@ data Aggr' a = Aggr
, aggrDistinct :: !AggrDistinct
, aggrFilter :: !(Maybe PrimExpr)
}
deriving (Functor, Foldable, Traversable, Show, Read)
deriving (Functor, Foldable, Traversable, Eq, Ord, Read, Show)

type OrderExpr = OrderExpr' PrimExpr

data OrderExpr' a = OrderExpr OrderOp a
deriving (Functor, Foldable, Traversable, Show, Read)
deriving (Functor, Foldable, Traversable, Eq, Ord, Read, Show)

data OrderNulls = NullsFirst | NullsLast
deriving (Show,Read)
deriving (Eq, Ord, Read, Show)

data OrderDirection = OpAsc | OpDesc
deriving (Show,Read)
deriving (Eq, Ord, Read, Show)

data OrderOp = OrderOp { orderDirection :: OrderDirection
, orderNulls :: OrderNulls }
deriving (Show,Read)
deriving (Eq, Ord, Read, Show)

data BoundExpr = Inclusive PrimExpr | Exclusive PrimExpr | PosInfinity | NegInfinity
deriving (Show,Read)
deriving (Eq, Ord, Read, Show)

data WndwOp
= WndwRowNumber
Expand All @@ -136,10 +136,10 @@ data WndwOp
| WndwLastValue PrimExpr
| WndwNthValue PrimExpr PrimExpr
| WndwAggregate AggrOp [PrimExpr]
deriving (Show,Read)
deriving (Eq, Ord, Read, Show)

data Partition = Partition
{ partitionBy :: [PrimExpr]
, orderBy :: [OrderExpr]
}
deriving (Read, Show)
deriving (Eq, Ord, Read, Show)
2 changes: 1 addition & 1 deletion src/Opaleye/Internal/Tag.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Opaleye.Internal.Tag where
import Control.Monad.Trans.State.Strict ( get, modify', State )

-- | Tag is for use as a source of unique IDs in QueryArr
newtype Tag = UnsafeTag Int deriving (Read, Show)
newtype Tag = UnsafeTag Int deriving (Eq, Ord, Read, Show)

start :: Tag
start = UnsafeTag 1
Expand Down

0 comments on commit c7f0c83

Please sign in to comment.