Skip to content

Commit

Permalink
Implement a floor function (from Prob to Nat in this case).
Browse files Browse the repository at this point in the history
  • Loading branch information
JacquesCarette committed Feb 20, 2018
1 parent 44429bd commit 62c662f
Show file tree
Hide file tree
Showing 12 changed files with 35 additions and 15 deletions.
1 change: 1 addition & 0 deletions haskell/Language/Hakaru/Disintegrate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,7 @@ constrainPrimOp v0 = go
go Asinh = \(e1 :* End) -> error_TODO "Asinh"
go Acosh = \(e1 :* End) -> error_TODO "Acosh"
go Atanh = \(e1 :* End) -> error_TODO "Atanh"
go Floor = \(e1 :* End) -> error_TODO "Floor"
go RealPow = \(e1 :* e2 :* End) ->
-- TODO: There's a discrepancy between @(**)@ and @pow_@ in
-- the old code...
Expand Down
1 change: 1 addition & 0 deletions haskell/Language/Hakaru/Evaluation/Lazy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ evaluatePrimOp evaluate_ = go
go Asinh (e1 :* End) = neu1 P.asinh e1
go Acosh (e1 :* End) = neu1 P.acosh e1
go Atanh (e1 :* End) = neu1 P.atanh e1
go Floor (e1 :* End) = neu1 P.floor e1

-- TODO: deal with how we have better types for these three ops than Haskell does...
-- go RealPow (e1 :* e2 :* End) = rr2 (**) (P.**) e1 e2
Expand Down
1 change: 1 addition & 0 deletions haskell/Language/Hakaru/Parser/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ data PrimOp
| Equal | Less
| Negate | Recip
| Abs | Signum | NatRoot | Erf
| Floor
deriving (Eq, Show)

data SomeOp op where
Expand Down
1 change: 1 addition & 0 deletions haskell/Language/Hakaru/Parser/SymbolResolve.hs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ primTable =
,("asinh", primPrimOp1 U.Asinh)
,("acosh", primPrimOp1 U.Acosh)
,("atanh", primPrimOp1 U.Atanh)
,("floor", primPrimOp1 U.Floor)
-- ArrayOps
,("size", TLam $ \x -> TNeu . syn $ U.ArrayOp_ U.Size [x])
,("reduce", t3 $ \x y z -> syn $ U.ArrayOp_ U.Reduce [x, y, z])
Expand Down
1 change: 1 addition & 0 deletions haskell/Language/Hakaru/Pretty/Concrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ ppPrimOp p (Signum _) (e1 :* End) = ppApply1 p "signum" e1
ppPrimOp p (Recip _) (e1 :* End) = ppRecip p e1
ppPrimOp p (NatRoot _) (e1 :* e2 :* End) = ppNatRoot p e1 e2
ppPrimOp p (Erf _) (e1 :* End) = ppApply1 p "erf" e1
ppPrimOp p Floor (e1 :* End) = ppApply1 p "floor" e1

ppNegate :: (ABT Term abt) => Int -> abt '[] a -> Doc
ppNegate p e = parensIf (p > 6) $
Expand Down
3 changes: 2 additions & 1 deletion haskell/Language/Hakaru/Pretty/Haskell.hs
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,8 @@ ppPrimOp p (Recip _) = \(e1 :* End) -> ppApply1 p "recip" e1
ppPrimOp p (NatRoot _) = \(e1 :* e2 :* End) ->
-- N.B., argument order is swapped!
ppBinop "`thRootOf`" 9 LeftAssoc p e2 e1
ppPrimOp p (Erf _) = \(e1 :* End) -> ppApply1 p "erf" e1
ppPrimOp p (Erf _) = \(e1 :* End) -> ppApply1 p "erf" e1
ppPrimOp p Floor = \(e1 :* End) -> ppApply1 p "floor" e1


-- | Pretty-print a 'ArrayOp' @(:$)@ node in the AST.
Expand Down
1 change: 1 addition & 0 deletions haskell/Language/Hakaru/Pretty/Maple.hs
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ maplePrimOp (Negate _) (e1 :* End) = parens (app1 "-" e1)
maplePrimOp (Abs _) (e1 :* End) = app1 "abs" e1
maplePrimOp (Recip _) (e1 :* End) = app1 "1/" e1
maplePrimOp (NatRoot _) (e1 :* e2 :* End) = app2 "root" e1 e2
maplePrimOp Floor (e1 :* End) = app1 "floor" e1
maplePrimOp x _ =
error $ "TODO: maplePrimOp{" ++ show x ++ "}"

Expand Down
21 changes: 9 additions & 12 deletions haskell/Language/Hakaru/Pretty/SExpression.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,17 @@ import Data.Foldable (foldMap)
import Control.Applicative ((<$>))
#endif

import System.IO (stderr)
import Data.Ratio
import Data.Text (Text)
import Data.Sequence (Seq)

import qualified Data.Text as Text
import Data.Number.Nat (fromNat)
import Data.Number.Natural (fromNatural, fromNonNegativeRational)
import Data.Ratio
import Data.Number.Natural (fromNonNegativeRational)
import qualified Data.List.NonEmpty as L
import Data.Sequence (Seq)
import Data.Text.IO as IO
import Language.Hakaru.Command (parseAndInfer)
import Language.Hakaru.Syntax.IClasses (fmap11, foldMap11, jmEq1, TypeEq(..))
import Language.Hakaru.Syntax.IClasses (jmEq1, TypeEq(..))
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.HClasses
Expand All @@ -43,10 +40,6 @@ import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.Reducer
import Language.Hakaru.Syntax.TypeCheck
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Types.Sing
import Text.PrettyPrint (Doc, (<>), (<+>))
import Text.PrettyPrint as PP

Expand Down Expand Up @@ -117,7 +110,7 @@ prettyReducer (Red_Split i red_a red_b) =
PP.parens (PP.text "r_split" <+> prettyViewABT i <+>
prettyReducer red_a <+> prettyReducer red_b)
prettyReducer (Red_Nop) = PP.text "r_nop"
prettyReducer (Red_Add s a) =
prettyReducer (Red_Add _ a) =
PP.parens (PP.text "r_add" <+> prettyViewABT a)

prettyBranch :: (ABT Term abt) => Branch a abt b -> Doc
Expand All @@ -136,7 +129,7 @@ goCode c = PP.parens $ case c of
goStruct :: PDatumStruct xs vars a -> Doc
goStruct s = PP.parens $ case s of
(PDone) -> PP.text "ps_done"
(PEt f s) -> PP.text "ps_et" <+> goFun f <+> goStruct s
(PEt f s') -> PP.text "ps_et" <+> goFun f <+> goStruct s'
goFun :: PDatumFun x vars a -> Doc
goFun f = PP.parens $ case f of
(PKonst p) -> PP.text "pf_konst" <+> prettyPattern p
Expand Down Expand Up @@ -241,6 +234,7 @@ prettyNary (Sum _) es = PP.text "+" <+> foldMap pretty es
prettyNary (Prod _) es = PP.text "*" <+> foldMap pretty es
prettyNary (Min _) es = PP.text "min" <+> foldMap pretty es
prettyNary (Max _) es = PP.text "max" <+> foldMap pretty es
prettyNary _ _ = error "Pretty.SExpression - prettyNary missing cases"

prettyType :: Sing (a :: Hakaru) -> Doc
prettyType SNat = PP.text "nat"
Expand Down Expand Up @@ -286,12 +280,15 @@ prettyPrimOp (Negate _) (e1 :* End) = PP.text "negate" <+> pretty e1
prettyPrimOp (Abs _) (e1 :* End) = PP.text "abs" <+> pretty e1
prettyPrimOp (Recip _) (e1 :* End) = PP.text "recip" <+> pretty e1
prettyPrimOp (NatRoot _) (e1 :* e2 :* End) = PP.text "root" <+> pretty e1 <+> pretty e2
prettyPrimOp Floor (e1 :* End) = PP.text "floor" <+> pretty e1
prettyPrimOp _ _ = error "prettyPrimop: a bunch of cases still need done!"

prettyArrayOp
:: (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
=> ArrayOp typs a -> SArgs abt args -> Doc
prettyArrayOp (Index _) (e1 :* e2 :* End) = PP.text "index" <+> pretty e1 <+> pretty e2
prettyArrayOp (Size _) (e1 :* End) = PP.text "size" <+> pretty e1
prettyArrayOp (Reduce _) _ = error "prettyArrayOp doesn't know how to print Reduce"

prettyFile' :: [Char] -> [Char] -> IO ()
prettyFile' fname outFname = do
Expand All @@ -303,7 +300,7 @@ prettyFile' fname outFname = do
runPretty' :: Text -> IO String
runPretty' prog =
case parseAndInfer prog of
Left err -> return "err"
Left _ -> return "err"
Right (TypedAST _ ast) -> do
summarised <- summary . expandTransformations $ ast
return . render . pretty $ summarised
Expand Down
9 changes: 7 additions & 2 deletions haskell/Language/Hakaru/Syntax/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ import Data.Traversable
import Control.Arrow ((***))
import Data.Ratio (numerator, denominator)

import Data.Data (Data, Typeable)
import Data.Data ()

import Data.Number.Natural
import Language.Hakaru.Syntax.IClasses
Expand Down Expand Up @@ -116,7 +116,8 @@ instance Eq1 Literal where
eq1 (LInt x) (LInt y) = x == y
eq1 (LProb x) (LProb y) = x == y
eq1 (LReal x) (LReal y) = x == y
eq1 _ _ = False
-- Because of GADTs, the following is apparently redundant
-- eq1 _ _ = False

instance Eq (Literal a) where
(==) = eq1
Expand Down Expand Up @@ -408,6 +409,10 @@ data PrimOp :: [Hakaru] -> Hakaru -> * where
-- do not have all units and thus do not support signum\/normalize?


-- Coecion-like operations that are computations
-- we only implement Floor for Prob for now?
Floor :: PrimOp '[ 'HProb ] 'HNat

-- -- HFractional operators
Recip :: !(HFractional a) -> PrimOp '[ a ] a
-- generates macro: IntPow
Expand Down
1 change: 1 addition & 0 deletions haskell/Language/Hakaru/Syntax/AST/Sing.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ sing_PrimOp Atanh = (sing `Cons1` Nil1, sing)
sing_PrimOp RealPow = (sing `Cons1` sing `Cons1` Nil1, sing)
sing_PrimOp Exp = (sing `Cons1` Nil1, sing)
sing_PrimOp Log = (sing `Cons1` Nil1, sing)
sing_PrimOp Floor = (sing `Cons1` Nil1, sing)
sing_PrimOp (Infinity h) = (Nil1, sing_HIntegrable h)
sing_PrimOp GammaFunc = (sing `Cons1` Nil1, sing)
sing_PrimOp BetaFunc = (sing `Cons1` sing `Cons1` Nil1, sing)
Expand Down
4 changes: 4 additions & 0 deletions haskell/Language/Hakaru/Syntax/Prelude.hs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ module Language.Hakaru.Syntax.Prelude
, negativeInfinity
-- *** Trig
, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh
-- *** coercions-than-compute
, floor

-- * Measures
-- ** Abstract nonsense
Expand Down Expand Up @@ -717,6 +719,8 @@ asinh = primOp1_ Asinh
acosh = primOp1_ Acosh
atanh = primOp1_ Atanh

floor :: (ABT Term abt) => abt '[] 'HProb -> abt '[] 'HNat
floor = primOp1_ Floor

----------------------------------------------------------------
datum_
Expand Down
6 changes: 6 additions & 0 deletions haskell/Language/Hakaru/Syntax/TypeCheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,12 @@ inferType = inferType_
syn (PrimOp_ y :$ e' :* End)
_ -> argumentNumberError

inferPrimOp U.Floor es =
case es of
[e] -> do e' <- checkType_ SProb e
return . TypedAST SNat $ syn (PrimOp_ Floor :$ e' :* End)
_ -> argumentNumberError

inferPrimOp x _ = error ("TODO: inferPrimOp: " ++ show x)


Expand Down

0 comments on commit 62c662f

Please sign in to comment.