diff --git a/clash-lib/src/CLaSH/Driver/TopWrapper.hs b/clash-lib/src/CLaSH/Driver/TopWrapper.hs index efd7e0141a..3a69054829 100644 --- a/clash-lib/src/CLaSH/Driver/TopWrapper.hs +++ b/clash-lib/src/CLaSH/Driver/TopWrapper.hs @@ -177,7 +177,7 @@ mkOutput nms (i,hwty) cnt = case hwty of netdecl = NetDecl iName hwty assigns = zipWith (\id_ n -> Assignment id_ - (Identifier iName (Just (Indexed (hwty,1,n))))) + (Identifier iName (Just (Indexed (hwty,10,n))))) ids [0..] in (nms',(ports',(netdecl:assigns ++ decls',iName))) diff --git a/clash-lib/src/CLaSH/Netlist.hs b/clash-lib/src/CLaSH/Netlist.hs index fa1c95f158..922239142e 100644 --- a/clash-lib/src/CLaSH/Netlist.hs +++ b/clash-lib/src/CLaSH/Netlist.hs @@ -20,6 +20,7 @@ import Unbound.Generics.LocallyNameless (Embed (..), name2String, unrebind) import CLaSH.Core.DataCon (DataCon (..)) +import CLaSH.Core.FreeVars (typeFreeVars) import CLaSH.Core.Literal (Literal (..)) import CLaSH.Core.Pretty (showDoc) import CLaSH.Core.Term (Pat (..), Term (..), TmName) @@ -184,10 +185,14 @@ mkDeclarations bndr e@(Case scrut _ [alt]) = do let dstId = mkBasicId . Text.pack . name2String $ varName bndr altVarId = mkBasicId . Text.pack $ name2String varTm modifier = case pat of - DataPat (Embed dc) ids -> let tms = case unrebind ids of - ([],tms') -> tms' - _ -> error $ $(curLoc) ++ "Not in normal form: Pattern binds existential variables: " ++ showDoc e - in case elemIndex (Id varTm (Embed varTy)) tms of + DataPat (Embed dc) ids -> let (exts,tms) = unrebind ids + tmsTys = map (unembed . varType) tms + tmsFVs = concatMap (Lens.toListOf typeFreeVars) tmsTys + extNms = map varName exts + tms' = if any (`elem` tmsFVs) extNms + then error $ $(curLoc) ++ "Not in normal form: Pattern binds existential variables: " ++ showDoc e + else tms + in case elemIndex (Id varTm (Embed varTy)) tms' of Nothing -> Nothing Just fI | sHwTy /= vHwTy -> Just (Indexed (sHwTy,dcTag dc - 1,fI)) diff --git a/clash-lib/src/CLaSH/Normalize/Strategy.hs b/clash-lib/src/CLaSH/Normalize/Strategy.hs index 7fb894496e..f1c209e438 100644 --- a/clash-lib/src/CLaSH/Normalize/Strategy.hs +++ b/clash-lib/src/CLaSH/Normalize/Strategy.hs @@ -25,7 +25,7 @@ constantPropgation :: NormRewrite constantPropgation = propagate >-> repeatR inlineAndPropagate >-> lifting >-> spec where propagate = innerMost (applyMany transInner) - inlineAndPropagate = bottomupR (applyMany transBUP) !-> propagate + inlineAndPropagate = (bottomupR (applyMany transBUP) >-> bottomupR (apply "reduceNonRepPrim" reduceNonRepPrim)) !-> propagate lifting = bottomupR (apply "liftNonRep" liftNonRep) -- See: [Note] bottom-up traversal for liftNonRep spec = bottomupR (applyMany specRws) diff --git a/clash-lib/src/CLaSH/Normalize/Transformations.hs b/clash-lib/src/CLaSH/Normalize/Transformations.hs index c4fb9c9f0d..94904968e0 100644 --- a/clash-lib/src/CLaSH/Normalize/Transformations.hs +++ b/clash-lib/src/CLaSH/Normalize/Transformations.hs @@ -1,5 +1,6 @@ -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE ViewPatterns #-} -- | Transformations of the Normalization process module CLaSH.Normalize.Transformations @@ -25,6 +26,7 @@ module CLaSH.Normalize.Transformations , inlineSmall , simpleCSE , reduceConst + , reduceNonRepPrim ) where @@ -37,19 +39,23 @@ import qualified Data.List as List import qualified Data.Maybe as Maybe import Unbound.Generics.LocallyNameless (Bind, Embed (..), bind, embed, rec, unbind, unembed, unrebind, - unrec, name2String) + unrec, name2String, string2Name, + rebind) import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind) -import CLaSH.Core.DataCon (DataCon, dcName, dcTag, - dcUnivTyVars) +import CLaSH.Core.DataCon (DataCon (..), dataConInstArgTys) import CLaSH.Core.FreeVars (termFreeIds, termFreeTyVars, typeFreeVars) import CLaSH.Core.Pretty (showDoc) import CLaSH.Core.Subst (substTm, substTms, substTyInTm, substTysinTm) import CLaSH.Core.Term (LetBinding, Pat (..), Term (..)) -import CLaSH.Core.Type (TypeView (..), applyFunTy, - applyTy, splitFunTy, typeKind, tyView) +import CLaSH.Core.Type (TypeView (..), Type (..), + LitTy (..), applyFunTy, + applyTy, splitFunTy, typeKind, + tyView) +import CLaSH.Core.TyCon (tyConDataCons) +import CLaSH.Core.TysPrim (typeNatKind) import CLaSH.Core.Util (collectArgs, idToVar, isCon, isFun, isLet, isPolyFun, isPrim, isVar, mkApps, mkLams, mkTmApps, @@ -699,3 +705,162 @@ reduceConst _ e@(App _ _) _ -> return e reduceConst _ e = return e + +-- | Replace primitives by their "definition" if they would lead to let-bindings +-- with a non-representable type when a function is in ANF. This happens for +-- example when CLaSH.Size.Vector.map consumes or produces a vector of +-- non-representable elements. +-- +-- Basically what this transformation does is replace a primitive the completely +-- unrolled recursive definition that it represents. e.g. +-- +-- > zipWith ($) (xs :: Vec 2 (Int -> Int)) (ys :: Vec 2 Int) +-- +-- is replaced by: +-- +-- > let (x0 :: (Int -> Int)) = case xs of (:>) _ x xr -> x +-- > (xr0 :: Vec 1 (Int -> Int)) = case xs of (:>) _ x xr -> xr +-- > (x1 :: (Int -> Int)( = case xr0 of (:>) _ x xr -> x +-- > (y0 :: Int) = case ys of (:>) _ y yr -> y +-- > (yr0 :: Vec 1 Int) = case ys of (:>) _ y yr -> xr +-- > (y1 :: Int = case yr0 of (:>) _ y yr -> y +-- > in (($) x0 y0 :> ($) x1 y1 :> Nil) +-- +-- Currently, it only handles the following functions: +-- +-- * CLaSH.Sized.Vector.map +-- * CLaSH.Sized.Vector.zipWith +reduceNonRepPrim :: NormRewrite +reduceNonRepPrim _ e@(App _ _) + | (Prim f _, args) <- collectArgs e + = case f of + "CLaSH.Sized.Vector.zipWith" | length args == 7 -> do + let [lhsElTy,rhsElty,resElTy,nTy] = Either.rights args + case nTy of + (LitTy (NumTy n)) -> do + untranslatableTys <- mapM isUntranslatableType [lhsElTy,rhsElty,resElTy] + if or untranslatableTys + then let [fun,lhsArg,rhsArg] = Either.lefts args + in reduceZipWith n lhsElTy rhsElty resElTy fun lhsArg rhsArg + else return e + _ -> return e + "CLaSH.Sized.Vector.map" | length args == 5 -> do + let [argElTy,resElTy,nTy] = Either.rights args + case nTy of + (LitTy (NumTy n)) -> do + untranslatableTys <- mapM isUntranslatableType [argElTy,resElTy] + if or untranslatableTys + then let [fun,arg] = Either.lefts args + in reduceMap n argElTy resElTy fun arg + else return e + _ -> return e + _ -> return e + +reduceNonRepPrim _ e = return e + +-- | Replace an application of @CLaSH.Sized.Vector.zipWith@ primitive on vectors +-- of a known length @n@, by the fully unrolled recursive "definition" of of +-- @CLaSH.Sized.Vector.zipWith@ +reduceZipWith :: Int -- ^ Length of the vector(s) + -> Type -- ^ Type of the lhs of the function + -> Type -- ^ Type of the rhs of the function + -> Type -- ^ Type of the result of the function + -> Term -- ^ The zipWith'd functions + -> Term -- ^ The 1st vector argument + -> Term -- ^ The 2nd vector argument + -> NormalizeSession Term +reduceZipWith n lhsElTy rhsElTy resElTy fun lhsArg rhsArg = do + tcm <- Lens.view tcCache + (TyConApp vecTcNm _) <- tyView <$> termType tcm lhsArg + let (Just vecTc) = HashMap.lookup vecTcNm tcm + [nilCon,consCon] = tyConDataCons vecTc + (varsL,elemsL) = second concat . unzip $ extractElems consCon lhsElTy 'L' n lhsArg + (varsR,elemsR) = second concat . unzip $ extractElems consCon rhsElTy 'R' n rhsArg + funApps = zipWith (\l r -> mkApps fun [Left l,Left r]) varsL varsR + lbody = mkVec nilCon consCon resElTy n funApps + lb = Letrec (bind (rec (init elemsL ++ init elemsR)) lbody) + changed lb + +-- | Replace an application of @CLaSH.Sized.Vector.map@ primitive on vectors +-- of a known length @n@, by the fully unrolled recursive "definition" of of +-- @CLaSH.Sized.Vector.map@ +reduceMap :: Int -- ^ Length of the vector + -> Type -- ^ Argument type of the function + -> Type -- ^ Result type of the function + -> Term -- ^ The map'd function + -> Term -- ^ The map'd over vector + -> NormalizeSession Term +reduceMap n argElTy resElTy fun arg = do + tcm <- Lens.view tcCache + (TyConApp vecTcNm _) <- tyView <$> termType tcm arg + let (Just vecTc) = HashMap.lookup vecTcNm tcm + [nilCon,consCon] = tyConDataCons vecTc + (vars,elems) = second concat . unzip $ extractElems consCon argElTy 'A' n arg + funApps = map (fun `App`) vars + lbody = mkVec nilCon consCon resElTy n funApps + lb = Letrec (bind (rec (init elems)) lbody) + changed lb + +-- | Create a vector of supplied elements +mkVec :: DataCon -- ^ The Nil constructor + -> DataCon -- ^ The Cons (:>) constructor + -> Type -- ^ Element type + -> Int -- ^ Length of the vector + -> [Term] -- ^ Elements to put in the vector + -> Term +mkVec nilCon consCon resTy = go + where + go _ [] = mkApps (Data nilCon) [Right (LitTy (NumTy 0)) + ,Right resTy + ,Left (Prim "_CO_" nilCoTy) + ] + + go n (x:xs) = mkApps (Data consCon) [Right (LitTy (NumTy n)) + ,Right resTy + ,Right (LitTy (NumTy (n-1))) + ,Left (Prim "_CO_" (consCoTy n)) + ,Left x + ,Left (go (n-1) xs)] + + nilCoTy = head (dataConInstArgTys nilCon [(LitTy (NumTy 0)),resTy]) + consCoTy n = head (dataConInstArgTys consCon [(LitTy (NumTy n)) + ,resTy + ,(LitTy (NumTy (n-1)))]) + +-- | Create let-bindings with case-statements that select elements out of a +-- vector. Returns both the variables to which element-selections are bound +-- and the let-bindings +extractElems :: DataCon -- ^ The Cons (:>) constructor + -> Type -- ^ The element type + -> Char -- ^ Char to append to the bound variable names + -> Int -- ^ Length of the vector + -> Term -- ^ The vector + -> [(Term,[LetBinding])] +extractElems consCon resTy s maxN = go maxN + where + go :: Int -> Term -> [(Term,[LetBinding])] + go 0 _ = [] + go n e = (elVar + ,[(Id elBNm (embed resTy) ,embed lhs) + ,(Id restBNm (embed restTy),embed rhs) + ] + ) : + go (n-1) (Var restTy restBNm) + + where + elBNm = string2Name ("el" ++ s:show (maxN-n)) + restBNm = string2Name ("rest" ++ s:show (maxN-n)) + elVar = Var resTy elBNm + pat = DataPat (embed consCon) (rebind [mTV] [co,el,rest]) + elPatNm = string2Name "el" + restPatNm = string2Name "rest" + lhs = Case e resTy [bind pat (Var resTy elPatNm)] + rhs = Case e restTy [bind pat (Var restTy restPatNm)] + + mName = string2Name "m" + mTV = TyVar mName (embed typeNatKind) + tys = [(LitTy (NumTy n)),resTy,(LitTy (NumTy (n-1)))] + idTys = dataConInstArgTys consCon tys + [co,el,rest] = zipWith Id [string2Name "_co_",elPatNm, restPatNm] + (map embed idTys) + restTy = last $ dataConInstArgTys consCon tys diff --git a/clash-lib/src/CLaSH/Rewrite/Util.hs b/clash-lib/src/CLaSH/Rewrite/Util.hs index af48a44994..538f191d7c 100644 --- a/clash-lib/src/CLaSH/Rewrite/Util.hs +++ b/clash-lib/src/CLaSH/Rewrite/Util.hs @@ -405,12 +405,24 @@ isLocalVar (Var _ name) $ Lens.use bindings isLocalVar _ = return False +{-# INLINE isUntranslatable #-} -- | Determine if a term cannot be represented in hardware isUntranslatable :: Term -> RewriteMonad extra Bool isUntranslatable tm = do tcm <- Lens.view tcCache - not <$> (representableType <$> Lens.view typeTranslator <*> pure tcm <*> termType tcm tm) + not <$> (representableType <$> Lens.view typeTranslator + <*> pure tcm + <*> termType tcm tm) + +{-# INLINE isUntranslatableType #-} +-- | Determine if a type cannot be represented in hardware +isUntranslatableType :: Type + -> RewriteMonad extra Bool +isUntranslatableType ty = + not <$> (representableType <$> Lens.view typeTranslator + <*> Lens.view tcCache + <*> pure ty) -- | Is the Context a Lambda/Term-abstraction context? isLambdaBodyCtx :: CoreContext diff --git a/clash-systemverilog/src/CLaSH/Backend/SystemVerilog.hs b/clash-systemverilog/src/CLaSH/Backend/SystemVerilog.hs index 90a05af01e..af144d5af7 100644 --- a/clash-systemverilog/src/CLaSH/Backend/SystemVerilog.hs +++ b/clash-systemverilog/src/CLaSH/Backend/SystemVerilog.hs @@ -320,7 +320,15 @@ expr_ _ (Identifier id_ (Just (Indexed (ty@(SP _ args),dcI,fI)))) = fromSLV argT end = start - argSize + 1 expr_ _ (Identifier id_ (Just (Indexed (ty@(Product _ _),_,fI)))) = text id_ <> dot <> verilogTypeMark ty <> "_sel" <> int fI -expr_ _ (Identifier id_ (Just (Indexed ((Vector _ _),_,fI)))) = text id_ <> brackets (int fI) + +expr_ _ (Identifier id_ (Just (Indexed ((Vector _ _),1,1)))) = text id_ <> brackets (int 0) +expr_ _ (Identifier id_ (Just (Indexed ((Vector n _),1,2)))) = text id_ <> brackets (int 1 <> colon <> int (n-1)) + +-- This is a HACK for CLaSH.Driver.TopWrapper.mkOutput +-- Vector's don't have a 10'th constructor, this is just so that we can +-- recognize the particular case +expr_ _ (Identifier id_ (Just (Indexed ((Vector _ _),10,fI)))) = text id_ <> brackets (int fI) + expr_ _ (Identifier id_ (Just (DC (ty@(SP _ _),_)))) = text id_ <> brackets (int start <> colon <> int end) where start = typeSize ty - 1 diff --git a/clash-verilog/src/CLaSH/Backend/Verilog.hs b/clash-verilog/src/CLaSH/Backend/Verilog.hs index fc2866cd63..7e85f2fb8e 100644 --- a/clash-verilog/src/CLaSH/Backend/Verilog.hs +++ b/clash-verilog/src/CLaSH/Backend/Verilog.hs @@ -206,7 +206,23 @@ expr_ _ (Identifier id_ (Just (Indexed (ty@(Product _ argTys),_,fI)))) = start = typeSize ty - 1 - otherSz end = start - argSize + 1 -expr_ _ (Identifier id_ (Just (Indexed (ty@(Vector _ argTy),_,fI)))) = +expr_ _ (Identifier id_ (Just (Indexed (ty@(Vector _ argTy),1,1)))) = + text id_ <> brackets (int start <> colon <> int end) + where + argSize = typeSize argTy + start = typeSize ty - 1 + end = start - argSize + 1 + +expr_ _ (Identifier id_ (Just (Indexed (ty@(Vector _ argTy),1,2)))) = + text id_ <> brackets (int start <> colon <> int 0) + where + argSize = typeSize argTy + start = typeSize ty - argSize - 1 + +-- This is a HACK for CLaSH.Driver.TopWrapper.mkOutput +-- Vector's don't have a 10'th constructor, this is just so that we can +-- recognize the particular case +expr_ _ (Identifier id_ (Just (Indexed (ty@(Vector _ argTy),10,fI)))) = text id_ <> brackets (int start <> colon <> int end) where argSize = typeSize argTy diff --git a/clash-vhdl/src/CLaSH/Backend/VHDL.hs b/clash-vhdl/src/CLaSH/Backend/VHDL.hs index 8849ac721f..9bc3167027 100644 --- a/clash-vhdl/src/CLaSH/Backend/VHDL.hs +++ b/clash-vhdl/src/CLaSH/Backend/VHDL.hs @@ -485,7 +485,15 @@ expr_ _ (Identifier id_ (Just (Indexed (ty@(SP _ args),dcI,fI)))) = fromSLV argT end = start - argSize + 1 expr_ _ (Identifier id_ (Just (Indexed (ty@(Product _ _),_,fI)))) = text id_ <> dot <> tyName ty <> "_sel" <> int fI -expr_ _ (Identifier id_ (Just (Indexed ((Vector _ _),_,fI)))) = text id_ <> parens (int fI) + +expr_ _ (Identifier id_ (Just (Indexed ((Vector _ _),1,1)))) = text id_ <> parens (int 0) +expr_ _ (Identifier id_ (Just (Indexed ((Vector n _),1,2)))) = text id_ <> parens (int 1 <+> "to" <+> int (n-1)) + +-- This is a HACK for CLaSH.Driver.TopWrapper.mkOutput +-- Vector's don't have a 10'th constructor, this is just so that we can +-- recognize the particular case +expr_ _ (Identifier id_ (Just (Indexed ((Vector _ _),10,fI)))) = text id_ <> parens (int fI) + expr_ _ (Identifier id_ (Just (DC (ty@(SP _ _),_)))) = text id_ <> parens (int start <+> "downto" <+> int end) where start = typeSize ty - 1 diff --git a/testsuite/Main.hs b/testsuite/Main.hs index 29b2b108fc..d64263e73b 100755 --- a/testsuite/Main.hs +++ b/testsuite/Main.hs @@ -65,6 +65,10 @@ main = , testGroup "Fixed" [ runTest ("tests" "shouldwork" "Fixed") Both [] "SFixedTest" (Just ("testbench",True)) ] + , testGroup "HOPrim" + [ runTest ("tests" "shouldwork" "HOPrim") Both [] "TestMap" (Just ("testbench",True)) + , runTest ("tests" "shouldwork" "HOPrim") Both [] "VecFun" (Just ("testbench",True)) + ] , testGroup "Numbers" [ runTest ("tests" "shouldwork" "Numbers") Both [] "Resize" (Just ("testbench",True)) , runTest ("tests" "shouldwork" "Numbers") Both [] "Resize2" (Just ("testbench",True))