Skip to content

Commit

Permalink
Rewrite only references in aggregation, not all values
Browse files Browse the repository at this point in the history
To recap, PostgreSQL does not allow lateral references in aggregation functions, e.g., the following SQL is invalid:

```sql
SELECT
  result
FROM
  (
    VALUES
      (1),
      (2),
      (3)
  ) _(x),
  LATERAL (
    SELECT
      sum(x) AS result
  ) __
```

It fails with the error message `ERROR:  aggregate functions are not allowed in FROM clause of their own query level`. Opaleye works around this limitation by rewriting the above query as follows:

```sql
SELECT
  result
FROM
  (
    VALUES
      (1),
      (2),
      (3)
  ) _(x),
  LATERAL (
    SELECT
      sum(inner1) AS result
    FROM
      (
        SELECT
          x AS inner1
      ) _
  ) __
```

The current implementation of this rewriting rewrites all arguments of aggregation functions regardless of whether they contain (potentially lateral) references or not. However, the same effect could be achieved if we only rewrote references and not all expressions. That's what this PR does.

This change is beneficial when using ordered-set aggregate functions such as `percentile_cont`. With the current rewriting scheme, Opaleye will generate:

```sql
SELECT
  result
FROM
  (
    VALUES
      (1),
      (2),
      (3)
  ) _(x),
  LATERAL (
    SELECT
      percentile_cont(inner1) WITHIN GROUP (ORDER BY inner2) AS result
    FROM
      (
        SELECT
          0.5 AS inner1,
          x AS inner2
      ) _
  ) __
```

Which fails with the error message:

```
ERROR:  column "_.inner1" must appear in the GROUP BY clause or be used in an aggregate function
LINE 12:       percentile_cont(inner1) WITHIN GROUP (ORDER BY inner2)
                               ^
DETAIL:  Direct arguments of an ordered-set aggregate must use only grouped columns.
```

After the change in this PR, this instead becomes:

```sql
SELECT
  result
FROM
  (
    VALUES
      (1),
      (2),
      (3)
  ) _(x),
  LATERAL (
    SELECT
      percentile_cont(0.5) WITHIN GROUP (ORDER BY inner1) AS result
    FROM
      (
        SELECT
          x AS inner1
      ) _
  ) __
```

Which works.
  • Loading branch information
shane-circuithub committed Jan 9, 2024
1 parent a7b0270 commit efb6549
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 51 deletions.
3 changes: 2 additions & 1 deletion opaleye.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ library
aeson >= 0.6 && < 2.3
, base >= 4.9 && < 4.20
, base16-bytestring >= 0.1.1.6 && < 1.1
, case-insensitive >= 1.2 && < 1.3
, bytestring >= 0.10 && < 0.12
, case-insensitive >= 1.2 && < 1.3
, containers >= 0.5 && < 0.8
, contravariant >= 1.2 && < 1.6
, postgresql-simple >= 0.6 && < 0.8
, pretty >= 1.1.1.0 && < 1.2
Expand Down
78 changes: 53 additions & 25 deletions src/Opaleye/Internal/Aggregate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@
module Opaleye.Internal.Aggregate where

import Control.Applicative (liftA2)
import Control.Arrow ((***))
import Data.Foldable (toList)
import Data.Traversable (for)

import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

import qualified Data.Profunctor as P
import qualified Data.Profunctor.Product as PP

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (StateT, gets, modify, runStateT)

import qualified Opaleye.Field as F
import qualified Opaleye.Internal.Column as C
import qualified Opaleye.Internal.Order as O
Expand Down Expand Up @@ -130,42 +137,63 @@ aggregatorApply = Aggregator $ PM.PackMap $ \f (agg, a) ->
--
-- Instead of detecting when we are aggregating over a field from a
-- previous query we just create new names for all fields before we
-- aggregate. On the other hand, referring to a field from a previous
-- query in an ORDER BY expression is totally fine!
-- aggregate.
--
-- Additionally, PostgreSQL imposes a limitation on aggregations using ORDER
-- BY in combination with DISTINCT - essentially the expression you pass to
-- ORDER BY must also be present in the argument list to the aggregation
-- function. This means that not only do we also have to create new
-- names for the ORDER BY expressions (if we only rewrite the function
-- arguments then they can't match and therefore ORDER BY can never be used
-- with DISTINCT), but that these names actually have to match the names
-- created for the aggregation function arguments. To accomplish this, when
-- traversing over the aggregations, we keep track of all the expressions
-- we've encountered so far, and only create new names for new expressions,
-- reusing old names where possible.
aggregateU :: Aggregator a b
-> (a, PQ.PrimQuery, T.Tag) -> (b, PQ.PrimQuery)
aggregateU agg (c0, primQ, t0) = (c1, primQ')
where (c1, projPEs_inners) =
PM.run (runAggregator agg (extractAggregateFields t0) c0)
aggregateU agg (a, primQ, tag) = (b, primQ')
where
(inners, outers, b) =
runSymbols (runAggregator agg (extractAggregateFields tag) a)

projPEs = map fst projPEs_inners
inners = concatMap snd projPEs_inners
inners' = fmap (fmap HPQ.AttrExpr) inners

primQ' = PQ.Aggregate projPEs (PQ.Rebind True inners primQ)
primQ' = PQ.Aggregate outers (PQ.Rebind True inners' primQ)

extractAggregateFields
:: Traversable t
=> T.Tag
-> (t HPQ.PrimExpr)
-> PM.PM [((HPQ.Symbol,
t HPQ.Symbol),
PQ.Bindings HPQ.PrimExpr)]
HPQ.PrimExpr
-> t HPQ.PrimExpr
-> Symbols HPQ.Symbol (PQ.Bindings (t HPQ.PrimExpr)) HPQ.PrimExpr
extractAggregateFields tag agg = do
i <- PM.new

let souter = HPQ.Symbol ("result" ++ i) tag

bindings <- for agg $ \pe -> do
j <- PM.new
let sinner = HPQ.Symbol ("inner" ++ j) tag
pure (sinner, pe)

let agg' = fmap fst bindings
result <- mkSymbol "result" <$> lift PM.new
agg' <- traverse (HPQ.traverseSymbols (symbolize (mkSymbol "inner"))) agg
lift $ PM.write (result, agg')
pure $ HPQ.AttrExpr result
where
mkSymbol name i = HPQ.Symbol (name ++ i) tag

PM.write ((souter, agg'), toList bindings)
-- | State carried while assigning symbols to expressions of type @e@: a map
-- from each expression seen so far to the symbol chosen for it, paired with
-- a function that prepends the accumulated @(symbol, expression)@ bindings
-- to a list. The state is layered on top of 'PM.PM' @s@.
type Symbols e s =
StateT
(Map e HPQ.Symbol, PQ.Bindings e -> PQ.Bindings e)
(PM.PM s)

pure (HPQ.AttrExpr souter)
-- | Run a 'Symbols' computation from an empty expression map and an empty
-- binding accumulator. Returns the bindings collected in the state, the
-- values written to the underlying 'PM.PM', and the computation's result.
runSymbols :: Symbols e [s] a -> (PQ.Bindings e, [s], a)
runSymbols action =
  let (stateResult, written) = PM.run (runStateT action (Map.empty, id))
      (result, (_, prependBindings)) = stateResult
   in (prependBindings [], written, result)

-- | Look up the symbol already assigned to @expr@. If there is none, build
-- one by applying @f@ to a fresh name from 'PM.new', remember it in the
-- expression map, and add the @(symbol, expr)@ pair after the bindings
-- collected so far.
symbolize :: Ord e =>
  (String -> HPQ.Symbol) -> e -> Symbols e s HPQ.Symbol
symbolize f expr = gets (Map.lookup expr . fst) >>= maybe assignFresh pure
  where
    assignFresh = do
      fresh <- f <$> lift PM.new
      modify $ \(seen, prependBindings) ->
        ( Map.insert expr fresh seen
        , prependBindings . ((fresh, expr) :)
        )
      pure fresh

unsafeMax :: Aggregator (C.Field a) (C.Field a)
unsafeMax = makeAggr HPQ.AggrMax
Expand Down
64 changes: 47 additions & 17 deletions src/Opaleye/Internal/HaskellDB/PrimQuery.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
-- License : BSD-style

{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE LambdaCase #-}

module Opaleye.Internal.HaskellDB.PrimQuery where

Expand All @@ -17,7 +18,7 @@ type Name = String
type Scheme = [Attribute]
type Assoc = [(Attribute,PrimExpr)]

data Symbol = Symbol String T.Tag deriving (Read, Show)
data Symbol = Symbol String T.Tag deriving (Eq, Ord, Read, Show)

data PrimExpr = AttrExpr Symbol
| BaseTableAttrExpr Attribute
Expand All @@ -42,6 +43,29 @@ data PrimExpr = AttrExpr Symbol
| ArrayIndex PrimExpr PrimExpr
deriving (Read,Show)

-- | Apply an effectful function to every 'Symbol' found in an 'AttrExpr'
-- node of a 'PrimExpr', rebuilding the expression around the results.
-- Sub-expressions are visited left to right. Nodes with no sub-expression
-- ('BaseTableAttrExpr', 'ConstExpr', 'DefaultInsertExpr') are returned
-- unchanged.
traverseSymbols :: Applicative f => (Symbol -> f Symbol) -> PrimExpr -> f PrimExpr
traverseSymbols onSymbol = walk
  where
    walk (AttrExpr sym) = AttrExpr <$> onSymbol sym
    walk (BaseTableAttrExpr attr) = pure (BaseTableAttrExpr attr)
    walk (CompositeExpr inner attr) = CompositeExpr <$> walk inner <*> pure attr
    walk (BinExpr op lhs rhs) = BinExpr op <$> walk lhs <*> walk rhs
    walk (UnExpr op operand) = UnExpr op <$> walk operand
    walk (AggrExpr aggr) = AggrExpr <$> traverse walk aggr
    walk (WndwExpr wndw part) = WndwExpr <$> traverse walk wndw <*> traverse walk part
    walk (ConstExpr lit) = pure (ConstExpr lit)
    walk (CaseExpr branches fallback) = CaseExpr <$> traverse walkBranch branches <*> walk fallback
    walk (ListExpr items) = ListExpr <$> traverse walk items
    walk (ParamExpr name inner) = ParamExpr name <$> walk inner
    walk (FunExpr name args) = FunExpr name <$> traverse walk args
    walk (CastExpr name inner) = CastExpr name <$> walk inner
    walk DefaultInsertExpr = pure DefaultInsertExpr
    walk (ArrayExpr items) = ArrayExpr <$> traverse walk items
    walk (RangeExpr s lower upper) = RangeExpr s <$> traverse walk lower <*> traverse walk upper
    walk (ArrayIndex array index) = ArrayIndex <$> walk array <*> walk index

    -- A CASE branch is a (condition, result) pair; visit the condition first.
    walkBranch (cond, result) = (,) <$> walk cond <*> walk result

data Literal = NullLit
| DefaultLit -- ^ represents a default value
| BoolLit Bool
Expand Down Expand Up @@ -119,26 +143,32 @@ data OrderOp = OrderOp { orderDirection :: OrderDirection
, orderNulls :: OrderNulls }
deriving (Show,Read)

data BoundExpr = Inclusive PrimExpr | Exclusive PrimExpr | PosInfinity | NegInfinity
deriving (Show,Read)
type BoundExpr = BoundExpr' PrimExpr

data BoundExpr' a = Inclusive a | Exclusive a | PosInfinity | NegInfinity
deriving (Foldable, Functor, Traversable, Read, Show)

type WndwOp = WndwOp' PrimExpr

data WndwOp
data WndwOp' a
= WndwRowNumber
| WndwRank
| WndwDenseRank
| WndwPercentRank
| WndwCumeDist
| WndwNtile PrimExpr
| WndwLag PrimExpr PrimExpr PrimExpr
| WndwLead PrimExpr PrimExpr PrimExpr
| WndwFirstValue PrimExpr
| WndwLastValue PrimExpr
| WndwNthValue PrimExpr PrimExpr
| WndwAggregate AggrOp [PrimExpr]
deriving (Show,Read)

data Partition = Partition
{ partitionBy :: [PrimExpr]
, orderBy :: [OrderExpr]
| WndwNtile a
| WndwLag a a a
| WndwLead a a a
| WndwFirstValue a
| WndwLastValue a
| WndwNthValue a a
| WndwAggregate AggrOp [a]
deriving (Foldable, Functor, Traversable, Show, Read)

type Partition = Partition' PrimExpr

data Partition' a = Partition
{ partitionBy :: [a]
, orderBy :: [OrderExpr' a]
}
deriving (Read, Show)
deriving (Foldable, Functor, Traversable, Read, Show)
4 changes: 2 additions & 2 deletions src/Opaleye/Internal/PrimQuery.hs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ data PrimQuery' a = Unit
| Product (NEL.NonEmpty (Lateral, PrimQuery' a)) [HPQ.PrimExpr]
-- | The subqueries to take the product of and the
-- restrictions to apply
| Aggregate (Bindings (HPQ.Aggregate' HPQ.Symbol))
| Aggregate (Bindings (HPQ.Aggregate))
(PrimQuery' a)
| Window (Bindings (HPQ.WndwOp, HPQ.Partition)) (PrimQuery' a)
-- | Represents both @DISTINCT ON@ and @ORDER BY@
Expand Down Expand Up @@ -178,7 +178,7 @@ data PrimQueryFoldP a p p' = PrimQueryFold
, empty :: a -> p'
, baseTable :: TableIdentifier -> Bindings HPQ.PrimExpr -> p'
, product :: NEL.NonEmpty (Lateral, p) -> [HPQ.PrimExpr] -> p'
, aggregate :: Bindings (HPQ.Aggregate' HPQ.Symbol)
, aggregate :: Bindings HPQ.Aggregate
-> p
-> p'
, window :: Bindings (HPQ.WndwOp, HPQ.Partition) -> p -> p'
Expand Down
7 changes: 2 additions & 5 deletions src/Opaleye/Internal/Sql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,10 @@ product ss pes = SelectFrom $
PQ.Lateral -> Lateral
PQ.NonLateral -> NonLateral

aggregate :: PQ.Bindings (HPQ.Aggregate' HPQ.Symbol)
aggregate :: PQ.Bindings HPQ.Aggregate
-> Select
-> Select
aggregate aggrs' s =
aggregate aggrs s =
SelectFrom $ newSelect { attrs = SelectAttrs (ensureColumns (map attr aggrs))
, tables = oneTable s
, groupBy = Just (groupBy' aggrs) }
Expand All @@ -190,9 +190,6 @@ aggregate aggrs' s =
handleEmpty :: [HSql.SqlExpr] -> NEL.NonEmpty HSql.SqlExpr
handleEmpty = ensureColumnsGen SP.deliteral

aggrs :: [(Symbol, HPQ.Aggregate)]
aggrs = (map . Arr.second . fmap) HPQ.AttrExpr aggrs'

groupBy' :: [(symbol, HPQ.Aggregate)]
-> NEL.NonEmpty HSql.SqlExpr
groupBy' aggs = handleEmpty $ do
Expand Down
2 changes: 1 addition & 1 deletion src/Opaleye/Internal/Tag.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Opaleye.Internal.Tag where
import Control.Monad.Trans.State.Strict ( get, modify', State )

-- | Tag is for use as a source of unique IDs in QueryArr
newtype Tag = UnsafeTag Int deriving (Read, Show)
newtype Tag = UnsafeTag Int deriving (Eq, Ord, Read, Show)

start :: Tag
start = UnsafeTag 1
Expand Down

0 comments on commit efb6549

Please sign in to comment.