Skip to content

Commit

Permalink
Use less Opaleye internals in aggregation (#204)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomjaguarpaw committed Nov 30, 2022
1 parent 42c1e1c commit 6c038ec
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 90 deletions.
2 changes: 1 addition & 1 deletion rel8.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ library
, comonad
, contravariant
, hasql ^>= 1.4.5.1 || ^>= 1.5.0.0 || ^>= 1.6.0.0
, opaleye ^>= 0.9.6.0
, opaleye ^>= 0.9.6.1
, pretty
, profunctors
, product-profunctors
Expand Down
30 changes: 11 additions & 19 deletions src/Rel8/Aggregate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
{-# language FlexibleContexts #-}
{-# language FlexibleInstances #-}
{-# language MultiParamTypeClasses #-}
{-# language NamedFieldPuns #-}
{-# language RankNTypes #-}
{-# language StandaloneKindSignatures #-}
{-# language TypeFamilies #-}
{-# language UndecidableInstances #-}

module Rel8.Aggregate
( Aggregate(..), zipOutputs
, Aggregator(..), unsafeMakeAggregate
, unsafeMakeAggregate
, Aggregates
)
where
Expand All @@ -21,10 +20,13 @@ import Data.Functor.Identity ( Identity( Identity ) )
import Data.Kind ( Constraint, Type )
import Prelude

-- profunctors
import Data.Profunctor ( dimap )

-- opaleye
import qualified Opaleye.Internal.Aggregate as Opaleye
import qualified Opaleye.Aggregate as Opaleye
import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye
import qualified Opaleye.Internal.PackMap as Opaleye
import qualified Opaleye.Internal.Column as Opaleye

-- rel8
import Rel8.Expr ( Expr )
Expand Down Expand Up @@ -69,23 +71,13 @@ zipOutputs :: ()
zipOutputs f (Aggregate a) (Aggregate b) = Aggregate (liftA2 f a b)


type Aggregator :: Type
data Aggregator = Aggregator
{ operation :: Opaleye.AggrOp
, ordering :: [Opaleye.OrderExpr]
, distinction :: Opaleye.AggrDistinct
}


unsafeMakeAggregate :: forall (input :: Type) (output :: Type). ()
unsafeMakeAggregate :: forall (input :: Type) (output :: Type) n n' a a'. ()
=> (Expr input -> Opaleye.PrimExpr)
-> (Opaleye.PrimExpr -> Expr output)
-> Maybe Aggregator
-> Opaleye.Aggregator (Opaleye.Field_ n a) (Opaleye.Field_ n' a')
-> Expr input
-> Aggregate output
unsafeMakeAggregate input output aggregator expr =
Aggregate $ Opaleye.Aggregator $ Opaleye.PackMap $ \f _ ->
output <$> f (tuplize <$> aggregator, input expr)
where
tuplize Aggregator {operation, ordering, distinction} =
(operation, ordering, distinction)
Aggregate $ dimap in_ out aggregator
where out = output . Opaleye.unColumn
in_ = Opaleye.Column . input . const expr
86 changes: 16 additions & 70 deletions src/Rel8/Expr/Aggregate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ import Data.List.NonEmpty ( NonEmpty )
import Prelude hiding ( and, max, min, null, or, sum )

-- opaleye
import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye
import qualified Opaleye.Internal.Aggregate as Opaleye
import Opaleye.Internal.Column ( Field_( Column ) )
import qualified Opaleye.Aggregate as Opaleye

-- rel8
import Rel8.Aggregate ( Aggregate, Aggregator(..), unsafeMakeAggregate )
import Rel8.Aggregate ( Aggregate, unsafeMakeAggregate )
import Rel8.Expr ( Expr )
import Rel8.Expr.Bool ( caseExpr )
import Rel8.Expr.Opaleye
Expand All @@ -51,23 +53,13 @@ import Rel8.Type.Sum ( DBSum )

-- | Count the occurances of a single column. Corresponds to @COUNT(a)@
count :: Expr a -> Aggregate Int64
count = unsafeMakeAggregate toPrimExpr fromPrimExpr $
Just Aggregator
{ operation = Opaleye.AggrCount
, ordering = []
, distinction = Opaleye.AggrAll
}

count = unsafeMakeAggregate toPrimExpr fromPrimExpr Opaleye.count

-- | Count the number of distinct occurances of a single column. Corresponds to
-- @COUNT(DISTINCT a)@
countDistinct :: Sql DBEq a => Expr a -> Aggregate Int64
countDistinct = unsafeMakeAggregate toPrimExpr fromPrimExpr $
Just Aggregator
{ operation = Opaleye.AggrCount
, ordering = []
, distinction = Opaleye.AggrDistinct
}
Opaleye.distinctAggregator Opaleye.count


-- | Corresponds to @COUNT(*)@.
Expand All @@ -82,55 +74,30 @@ countWhere condition = count (caseExpr [(condition, litExpr (Just True))] null)

-- | Corresponds to @bool_and@.
and :: Expr Bool -> Aggregate Bool
and = unsafeMakeAggregate toPrimExpr fromPrimExpr $
Just Aggregator
{ operation = Opaleye.AggrBoolAnd
, ordering = []
, distinction = Opaleye.AggrAll
}
and = unsafeMakeAggregate toPrimExpr fromPrimExpr Opaleye.boolAnd


-- | Corresponds to @bool_or@.
or :: Expr Bool -> Aggregate Bool
or = unsafeMakeAggregate toPrimExpr fromPrimExpr $
Just Aggregator
{ operation = Opaleye.AggrBoolOr
, ordering = []
, distinction = Opaleye.AggrAll
}
or = unsafeMakeAggregate toPrimExpr fromPrimExpr Opaleye.boolOr


-- | Produce an aggregation for @Expr a@ using the @max@ function.
max :: Sql DBMax a => Expr a -> Aggregate a
max = unsafeMakeAggregate toPrimExpr fromPrimExpr $
Just Aggregator
{ operation = Opaleye.AggrMax
, ordering = []
, distinction = Opaleye.AggrAll
}
max = unsafeMakeAggregate toPrimExpr fromPrimExpr Opaleye.unsafeMax


-- | Produce an aggregation for @Expr a@ using the @max@ function.
min :: Sql DBMin a => Expr a -> Aggregate a
min = unsafeMakeAggregate toPrimExpr fromPrimExpr $
Just Aggregator
{ operation = Opaleye.AggrMin
, ordering = []
, distinction = Opaleye.AggrAll
}
min = unsafeMakeAggregate toPrimExpr fromPrimExpr Opaleye.unsafeMin

-- | Corresponds to @sum@. Note that in SQL, @sum@ is type changing - for
-- example the @sum@ of @integer@ returns a @bigint@. Rel8 doesn't support
-- this, and will add explicit casts back to the original input type. This can
-- lead to overflows, and if you anticipate very large sums, you should upcast
-- your input.
sum :: Sql DBSum a => Expr a -> Aggregate a
sum = unsafeMakeAggregate toPrimExpr (castExpr . fromPrimExpr) $
Just Aggregator
{ operation = Opaleye.AggrSum
, ordering = []
, distinction = Opaleye.AggrAll
}
sum = unsafeMakeAggregate toPrimExpr (castExpr . fromPrimExpr) Opaleye.unsafeSum


-- | Corresponds to @avg@. Note that in SQL, @avg@ is type changing - for
Expand All @@ -139,13 +106,7 @@ sum = unsafeMakeAggregate toPrimExpr (castExpr . fromPrimExpr) $
-- need a fractional result on an integral column, you should cast your input
-- to 'Double' or 'Data.Scientific.Scientific' before calling 'avg'.
avg :: Sql DBSum a => Expr a -> Aggregate a
avg = unsafeMakeAggregate toPrimExpr (castExpr . fromPrimExpr) $
Just Aggregator
{ operation = Opaleye.AggrAvg
, ordering = []
, distinction = Opaleye.AggrAll
}

avg = unsafeMakeAggregate toPrimExpr (castExpr . fromPrimExpr) Opaleye.unsafeAvg

-- | Take the sum of all expressions that satisfy a predicate.
sumWhere :: (Sql DBNum a, Sql DBSum a)
Expand All @@ -157,17 +118,12 @@ sumWhere condition a = sum (caseExpr [(condition, a)] 0)
stringAgg :: Sql DBString a
=> Expr db -> Expr a -> Aggregate a
stringAgg delimiter =
unsafeMakeAggregate toPrimExpr (castExpr . fromPrimExpr) $
Just Aggregator
{ operation = Opaleye.AggrStringAggr (toPrimExpr delimiter)
, ordering = []
, distinction = Opaleye.AggrAll
}
unsafeMakeAggregate toPrimExpr (castExpr . fromPrimExpr) (Opaleye.stringAgg (Column (toPrimExpr delimiter)))


-- | Aggregate a value by grouping by it.
groupByExpr :: Sql DBEq a => Expr a -> Aggregate a
groupByExpr = unsafeMakeAggregate toPrimExpr fromPrimExpr Nothing
groupByExpr = unsafeMakeAggregate toPrimExpr fromPrimExpr Opaleye.groupBy


-- | Collect expressions values as a list.
Expand All @@ -182,23 +138,13 @@ nonEmptyAggExpr = snonEmptyAggExpr typeInformation

slistAggExpr :: ()
=> TypeInformation (Unnullify a) -> Expr a -> Aggregate [a]
slistAggExpr info = unsafeMakeAggregate to fromPrimExpr $ Just
Aggregator
{ operation = Opaleye.AggrArr
, ordering = []
, distinction = Opaleye.AggrAll
}
slistAggExpr info = unsafeMakeAggregate to fromPrimExpr Opaleye.arrayAgg
where
to = encodeArrayElement info . toPrimExpr


snonEmptyAggExpr :: ()
=> TypeInformation (Unnullify a) -> Expr a -> Aggregate (NonEmpty a)
snonEmptyAggExpr info = unsafeMakeAggregate to fromPrimExpr $ Just
Aggregator
{ operation = Opaleye.AggrArr
, ordering = []
, distinction = Opaleye.AggrAll
}
snonEmptyAggExpr info = unsafeMakeAggregate to fromPrimExpr Opaleye.arrayAgg
where
to = encodeArrayElement info . toPrimExpr

0 comments on commit 6c038ec

Please sign in to comment.