Skip to content

Commit

Permalink
Unroll HO primitives when they would lead to non-representable let-bi…
Browse files Browse the repository at this point in the history
…ndings

Fixes #78
Fixes #25
  • Loading branch information
christiaanb committed Sep 11, 2015
1 parent 51e520a commit f6673fd
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 17 deletions.
2 changes: 1 addition & 1 deletion clash-lib/src/CLaSH/Driver/TopWrapper.hs
Expand Up @@ -177,7 +177,7 @@ mkOutput nms (i,hwty) cnt = case hwty of
netdecl = NetDecl iName hwty
assigns = zipWith
(\id_ n -> Assignment id_
(Identifier iName (Just (Indexed (hwty,1,n)))))
(Identifier iName (Just (Indexed (hwty,10,n)))))
ids
[0..]
in (nms',(ports',(netdecl:assigns ++ decls',iName)))
Expand Down
13 changes: 9 additions & 4 deletions clash-lib/src/CLaSH/Netlist.hs
Expand Up @@ -20,6 +20,7 @@ import Unbound.Generics.LocallyNameless (Embed (..), name2String,
unrebind)

import CLaSH.Core.DataCon (DataCon (..))
import CLaSH.Core.FreeVars (typeFreeVars)
import CLaSH.Core.Literal (Literal (..))
import CLaSH.Core.Pretty (showDoc)
import CLaSH.Core.Term (Pat (..), Term (..), TmName)
Expand Down Expand Up @@ -184,10 +185,14 @@ mkDeclarations bndr e@(Case scrut _ [alt]) = do
let dstId = mkBasicId . Text.pack . name2String $ varName bndr
altVarId = mkBasicId . Text.pack $ name2String varTm
modifier = case pat of
DataPat (Embed dc) ids -> let tms = case unrebind ids of
([],tms') -> tms'
_ -> error $ $(curLoc) ++ "Not in normal form: Pattern binds existential variables: " ++ showDoc e
in case elemIndex (Id varTm (Embed varTy)) tms of
DataPat (Embed dc) ids -> let (exts,tms) = unrebind ids
tmsTys = map (unembed . varType) tms
tmsFVs = concatMap (Lens.toListOf typeFreeVars) tmsTys
extNms = map varName exts
tms' = if any (`elem` tmsFVs) extNms
then error $ $(curLoc) ++ "Not in normal form: Pattern binds existential variables: " ++ showDoc e
else tms
in case elemIndex (Id varTm (Embed varTy)) tms' of
Nothing -> Nothing
Just fI
| sHwTy /= vHwTy -> Just (Indexed (sHwTy,dcTag dc - 1,fI))
Expand Down
2 changes: 1 addition & 1 deletion clash-lib/src/CLaSH/Normalize/Strategy.hs
Expand Up @@ -25,7 +25,7 @@ constantPropgation :: NormRewrite
constantPropgation = propagate >-> repeatR inlineAndPropagate >-> lifting >-> spec
where
propagate = innerMost (applyMany transInner)
inlineAndPropagate = bottomupR (applyMany transBUP) !-> propagate
inlineAndPropagate = (bottomupR (applyMany transBUP) >-> bottomupR (apply "reduceNonRepPrim" reduceNonRepPrim)) !-> propagate
lifting = bottomupR (apply "liftNonRep" liftNonRep) -- See: [Note] bottom-up traversal for liftNonRep
spec = bottomupR (applyMany specRws)

Expand Down
179 changes: 172 additions & 7 deletions clash-lib/src/CLaSH/Normalize/Transformations.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}

-- | Transformations of the Normalization process
module CLaSH.Normalize.Transformations
Expand All @@ -25,6 +26,7 @@ module CLaSH.Normalize.Transformations
, inlineSmall
, simpleCSE
, reduceConst
, reduceNonRepPrim
)
where

Expand All @@ -37,19 +39,23 @@ import qualified Data.List as List
import qualified Data.Maybe as Maybe
import Unbound.Generics.LocallyNameless (Bind, Embed (..), bind, embed,
rec, unbind, unembed, unrebind,
unrec, name2String)
unrec, name2String, string2Name,
rebind)
import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind)

import CLaSH.Core.DataCon (DataCon, dcName, dcTag,
dcUnivTyVars)
import CLaSH.Core.DataCon (DataCon (..), dataConInstArgTys)
import CLaSH.Core.FreeVars (termFreeIds, termFreeTyVars,
typeFreeVars)
import CLaSH.Core.Pretty (showDoc)
import CLaSH.Core.Subst (substTm, substTms, substTyInTm,
substTysinTm)
import CLaSH.Core.Term (LetBinding, Pat (..), Term (..))
import CLaSH.Core.Type (TypeView (..), applyFunTy,
applyTy, splitFunTy, typeKind, tyView)
import CLaSH.Core.Type (TypeView (..), Type (..),
LitTy (..), applyFunTy,
applyTy, splitFunTy, typeKind,
tyView)
import CLaSH.Core.TyCon (tyConDataCons)
import CLaSH.Core.TysPrim (typeNatKind)
import CLaSH.Core.Util (collectArgs, idToVar, isCon,
isFun, isLet, isPolyFun, isPrim,
isVar, mkApps, mkLams, mkTmApps,
Expand Down Expand Up @@ -699,3 +705,162 @@ reduceConst _ e@(App _ _)
_ -> return e

reduceConst _ e = return e

-- | Replace primitives by their "definition" if they would lead to let-bindings
-- with a non-representable type when a function is in ANF. This happens for
-- example when CLaSH.Size.Vector.map consumes or produces a vector of
-- non-representable elements.
--
-- Basically what this transformation does is replace a primitive the completely
-- unrolled recursive definition that it represents. e.g.
--
-- > zipWith ($) (xs :: Vec 2 (Int -> Int)) (ys :: Vec 2 Int)
--
-- is replaced by:
--
-- > let (x0 :: (Int -> Int)) = case xs of (:>) _ x xr -> x
-- > (xr0 :: Vec 1 (Int -> Int)) = case xs of (:>) _ x xr -> xr
-- > (x1 :: (Int -> Int)( = case xr0 of (:>) _ x xr -> x
-- > (y0 :: Int) = case ys of (:>) _ y yr -> y
-- > (yr0 :: Vec 1 Int) = case ys of (:>) _ y yr -> xr
-- > (y1 :: Int = case yr0 of (:>) _ y yr -> y
-- > in (($) x0 y0 :> ($) x1 y1 :> Nil)
--
-- Currently, it only handles the following functions:
--
-- * CLaSH.Sized.Vector.map
-- * CLaSH.Sized.Vector.zipWith
reduceNonRepPrim :: NormRewrite
reduceNonRepPrim _ e@(App _ _)
| (Prim f _, args) <- collectArgs e
= case f of
"CLaSH.Sized.Vector.zipWith" | length args == 7 -> do
let [lhsElTy,rhsElty,resElTy,nTy] = Either.rights args
case nTy of
(LitTy (NumTy n)) -> do
untranslatableTys <- mapM isUntranslatableType [lhsElTy,rhsElty,resElTy]
if or untranslatableTys
then let [fun,lhsArg,rhsArg] = Either.lefts args
in reduceZipWith n lhsElTy rhsElty resElTy fun lhsArg rhsArg
else return e
_ -> return e
"CLaSH.Sized.Vector.map" | length args == 5 -> do
let [argElTy,resElTy,nTy] = Either.rights args
case nTy of
(LitTy (NumTy n)) -> do
untranslatableTys <- mapM isUntranslatableType [argElTy,resElTy]
if or untranslatableTys
then let [fun,arg] = Either.lefts args
in reduceMap n argElTy resElTy fun arg
else return e
_ -> return e
_ -> return e

reduceNonRepPrim _ e = return e

-- | Replace an application of @CLaSH.Sized.Vector.zipWith@ primitive on vectors
-- of a known length @n@, by the fully unrolled recursive "definition" of of
-- @CLaSH.Sized.Vector.zipWith@
reduceZipWith :: Int -- ^ Length of the vector(s)
-> Type -- ^ Type of the lhs of the function
-> Type -- ^ Type of the rhs of the function
-> Type -- ^ Type of the result of the function
-> Term -- ^ The zipWith'd functions
-> Term -- ^ The 1st vector argument
-> Term -- ^ The 2nd vector argument
-> NormalizeSession Term
reduceZipWith n lhsElTy rhsElTy resElTy fun lhsArg rhsArg = do
tcm <- Lens.view tcCache
(TyConApp vecTcNm _) <- tyView <$> termType tcm lhsArg
let (Just vecTc) = HashMap.lookup vecTcNm tcm
[nilCon,consCon] = tyConDataCons vecTc
(varsL,elemsL) = second concat . unzip $ extractElems consCon lhsElTy 'L' n lhsArg
(varsR,elemsR) = second concat . unzip $ extractElems consCon rhsElTy 'R' n rhsArg
funApps = zipWith (\l r -> mkApps fun [Left l,Left r]) varsL varsR
lbody = mkVec nilCon consCon resElTy n funApps
lb = Letrec (bind (rec (init elemsL ++ init elemsR)) lbody)
changed lb

-- | Replace an application of @CLaSH.Sized.Vector.map@ primitive on vectors
-- of a known length @n@, by the fully unrolled recursive "definition" of of
-- @CLaSH.Sized.Vector.map@
reduceMap :: Int -- ^ Length of the vector
-> Type -- ^ Argument type of the function
-> Type -- ^ Result type of the function
-> Term -- ^ The map'd function
-> Term -- ^ The map'd over vector
-> NormalizeSession Term
reduceMap n argElTy resElTy fun arg = do
tcm <- Lens.view tcCache
(TyConApp vecTcNm _) <- tyView <$> termType tcm arg
let (Just vecTc) = HashMap.lookup vecTcNm tcm
[nilCon,consCon] = tyConDataCons vecTc
(vars,elems) = second concat . unzip $ extractElems consCon argElTy 'A' n arg
funApps = map (fun `App`) vars
lbody = mkVec nilCon consCon resElTy n funApps
lb = Letrec (bind (rec (init elems)) lbody)
changed lb

-- | Create a vector of supplied elements
mkVec :: DataCon -- ^ The Nil constructor
-> DataCon -- ^ The Cons (:>) constructor
-> Type -- ^ Element type
-> Int -- ^ Length of the vector
-> [Term] -- ^ Elements to put in the vector
-> Term
mkVec nilCon consCon resTy = go
where
go _ [] = mkApps (Data nilCon) [Right (LitTy (NumTy 0))
,Right resTy
,Left (Prim "_CO_" nilCoTy)
]

go n (x:xs) = mkApps (Data consCon) [Right (LitTy (NumTy n))
,Right resTy
,Right (LitTy (NumTy (n-1)))
,Left (Prim "_CO_" (consCoTy n))
,Left x
,Left (go (n-1) xs)]

nilCoTy = head (dataConInstArgTys nilCon [(LitTy (NumTy 0)),resTy])
consCoTy n = head (dataConInstArgTys consCon [(LitTy (NumTy n))
,resTy
,(LitTy (NumTy (n-1)))])

-- | Create let-bindings with case-statements that select elements out of a
-- vector. Returns both the variables to which element-selections are bound
-- and the let-bindings
extractElems :: DataCon -- ^ The Cons (:>) constructor
-> Type -- ^ The element type
-> Char -- ^ Char to append to the bound variable names
-> Int -- ^ Length of the vector
-> Term -- ^ The vector
-> [(Term,[LetBinding])]
extractElems consCon resTy s maxN = go maxN
where
go :: Int -> Term -> [(Term,[LetBinding])]
go 0 _ = []
go n e = (elVar
,[(Id elBNm (embed resTy) ,embed lhs)
,(Id restBNm (embed restTy),embed rhs)
]
) :
go (n-1) (Var restTy restBNm)

where
elBNm = string2Name ("el" ++ s:show (maxN-n))
restBNm = string2Name ("rest" ++ s:show (maxN-n))
elVar = Var resTy elBNm
pat = DataPat (embed consCon) (rebind [mTV] [co,el,rest])
elPatNm = string2Name "el"
restPatNm = string2Name "rest"
lhs = Case e resTy [bind pat (Var resTy elPatNm)]
rhs = Case e restTy [bind pat (Var restTy restPatNm)]

mName = string2Name "m"
mTV = TyVar mName (embed typeNatKind)
tys = [(LitTy (NumTy n)),resTy,(LitTy (NumTy (n-1)))]
idTys = dataConInstArgTys consCon tys
[co,el,rest] = zipWith Id [string2Name "_co_",elPatNm, restPatNm]
(map embed idTys)
restTy = last $ dataConInstArgTys consCon tys
14 changes: 13 additions & 1 deletion clash-lib/src/CLaSH/Rewrite/Util.hs
Expand Up @@ -405,12 +405,24 @@ isLocalVar (Var _ name)
$ Lens.use bindings
isLocalVar _ = return False

{-# INLINE isUntranslatable #-}
-- | Determine if a term cannot be represented in hardware
isUntranslatable :: Term
-> RewriteMonad extra Bool
isUntranslatable tm = do
tcm <- Lens.view tcCache
not <$> (representableType <$> Lens.view typeTranslator <*> pure tcm <*> termType tcm tm)
not <$> (representableType <$> Lens.view typeTranslator
<*> pure tcm
<*> termType tcm tm)

{-# INLINE isUntranslatableType #-}
-- | Determine if a type cannot be represented in hardware
isUntranslatableType :: Type
-> RewriteMonad extra Bool
isUntranslatableType ty =
not <$> (representableType <$> Lens.view typeTranslator
<*> Lens.view tcCache
<*> pure ty)

-- | Is the Context a Lambda/Term-abstraction context?
isLambdaBodyCtx :: CoreContext
Expand Down
10 changes: 9 additions & 1 deletion clash-systemverilog/src/CLaSH/Backend/SystemVerilog.hs
Expand Up @@ -320,7 +320,15 @@ expr_ _ (Identifier id_ (Just (Indexed (ty@(SP _ args),dcI,fI)))) = fromSLV argT
end = start - argSize + 1

expr_ _ (Identifier id_ (Just (Indexed (ty@(Product _ _),_,fI)))) = text id_ <> dot <> verilogTypeMark ty <> "_sel" <> int fI
expr_ _ (Identifier id_ (Just (Indexed ((Vector _ _),_,fI)))) = text id_ <> brackets (int fI)

expr_ _ (Identifier id_ (Just (Indexed ((Vector _ _),1,1)))) = text id_ <> brackets (int 0)
expr_ _ (Identifier id_ (Just (Indexed ((Vector n _),1,2)))) = text id_ <> brackets (int 1 <> colon <> int (n-1))

-- This is a HACK for CLaSH.Driver.TopWrapper.mkOutput
-- Vector's don't have a 10'th constructor, this is just so that we can
-- recognize the particular case
expr_ _ (Identifier id_ (Just (Indexed ((Vector _ _),10,fI)))) = text id_ <> brackets (int fI)

expr_ _ (Identifier id_ (Just (DC (ty@(SP _ _),_)))) = text id_ <> brackets (int start <> colon <> int end)
where
start = typeSize ty - 1
Expand Down
18 changes: 17 additions & 1 deletion clash-verilog/src/CLaSH/Backend/Verilog.hs
Expand Up @@ -206,7 +206,23 @@ expr_ _ (Identifier id_ (Just (Indexed (ty@(Product _ argTys),_,fI)))) =
start = typeSize ty - 1 - otherSz
end = start - argSize + 1

expr_ _ (Identifier id_ (Just (Indexed (ty@(Vector _ argTy),_,fI)))) =
expr_ _ (Identifier id_ (Just (Indexed (ty@(Vector _ argTy),1,1)))) =
text id_ <> brackets (int start <> colon <> int end)
where
argSize = typeSize argTy
start = typeSize ty - 1
end = start - argSize + 1

expr_ _ (Identifier id_ (Just (Indexed (ty@(Vector _ argTy),1,2)))) =
text id_ <> brackets (int start <> colon <> int 0)
where
argSize = typeSize argTy
start = typeSize ty - argSize - 1

-- This is a HACK for CLaSH.Driver.TopWrapper.mkOutput
-- Vector's don't have a 10'th constructor, this is just so that we can
-- recognize the particular case
expr_ _ (Identifier id_ (Just (Indexed (ty@(Vector _ argTy),10,fI)))) =
text id_ <> brackets (int start <> colon <> int end)
where
argSize = typeSize argTy
Expand Down
10 changes: 9 additions & 1 deletion clash-vhdl/src/CLaSH/Backend/VHDL.hs
Expand Up @@ -485,7 +485,15 @@ expr_ _ (Identifier id_ (Just (Indexed (ty@(SP _ args),dcI,fI)))) = fromSLV argT
end = start - argSize + 1

expr_ _ (Identifier id_ (Just (Indexed (ty@(Product _ _),_,fI)))) = text id_ <> dot <> tyName ty <> "_sel" <> int fI
expr_ _ (Identifier id_ (Just (Indexed ((Vector _ _),_,fI)))) = text id_ <> parens (int fI)

expr_ _ (Identifier id_ (Just (Indexed ((Vector _ _),1,1)))) = text id_ <> parens (int 0)
expr_ _ (Identifier id_ (Just (Indexed ((Vector n _),1,2)))) = text id_ <> parens (int 1 <+> "to" <+> int (n-1))

-- This is a HACK for CLaSH.Driver.TopWrapper.mkOutput
-- Vector's don't have a 10'th constructor, this is just so that we can
-- recognize the particular case
expr_ _ (Identifier id_ (Just (Indexed ((Vector _ _),10,fI)))) = text id_ <> parens (int fI)

expr_ _ (Identifier id_ (Just (DC (ty@(SP _ _),_)))) = text id_ <> parens (int start <+> "downto" <+> int end)
where
start = typeSize ty - 1
Expand Down
4 changes: 4 additions & 0 deletions testsuite/Main.hs
Expand Up @@ -65,6 +65,10 @@ main =
, testGroup "Fixed"
[ runTest ("tests" </> "shouldwork" </> "Fixed") Both [] "SFixedTest" (Just ("testbench",True))
]
, testGroup "HOPrim"
[ runTest ("tests" </> "shouldwork" </> "HOPrim") Both [] "TestMap" (Just ("testbench",True))
, runTest ("tests" </> "shouldwork" </> "HOPrim") Both [] "VecFun" (Just ("testbench",True))
]
, testGroup "Numbers"
[ runTest ("tests" </> "shouldwork" </> "Numbers") Both [] "Resize" (Just ("testbench",True))
, runTest ("tests" </> "shouldwork" </> "Numbers") Both [] "Resize2" (Just ("testbench",True))
Expand Down

0 comments on commit f6673fd

Please sign in to comment.