Skip to content

Commit

Permalink
Unroll primitives for splittable types
Browse files Browse the repository at this point in the history
Fixes #1606
  • Loading branch information
christiaanb committed Mar 28, 2021
1 parent 88dcfd0 commit 279c189
Show file tree
Hide file tree
Showing 12 changed files with 187 additions and 91 deletions.
2 changes: 1 addition & 1 deletion Clash.hs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ doHDL b src = do

generateHDL (buildCustomReprs reprs) domainConfs bindingsMap (Just b) primMap tcm tupTcm
(ghcTypeToHWType WORD_SIZE_IN_BITS True) evaluator topEntities Nothing
defClashOpts{opt_cachehdl = False, opt_dbgLevel = DebugSilent}
defClashOpts{opt_cachehdl = False, opt_dbgLevel = DebugSilent, opt_clear = True}
(startTime,prepTime)

main :: IO ()
Expand Down
1 change: 1 addition & 0 deletions changelog/2021-03-28T12_39_05+02_00_fix1606
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
FIXED: Blackboxes of `Clash.Sized.Vector` functions error on vectors containing `Clocks`, `Reset`, or `Enable` [#1606](https://github.com/clash-lang/clash-compiler/issues/1606)
3 changes: 1 addition & 2 deletions clash-ghc/src-ghc/Clash/GHC/Evaluator/Primitive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,10 @@ import Clash.Core.TyCon
import Clash.Core.TysPrim
import Clash.Core.Util
(mkRTree,mkVec,tyNatSize,dataConInstArgTys,primCo,
undefinedTm)
undefinedTm, mkSelectorCase)
import Clash.Core.Var (mkLocalId, mkTyVar)
import Clash.Debug
import Clash.GHC.GHC2Core (modNameM)
import Clash.Rewrite.Util (mkSelectorCase)
import Clash.Unique (lookupUniqMap)
import Clash.Util
(MonadUnique (..), clogBase, flogBase, curLoc)
Expand Down
2 changes: 1 addition & 1 deletion clash-ghc/src-ghc/Clash/GHC/GenerateBindings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import Clash.Core.Term (Term (..), mkLams, mkTyLams)
import Clash.Core.Type (Type (..), TypeView (..), mkFunTy, splitFunForallTy, tyView)
import Clash.Core.TyCon (TyConMap, TyConName, isNewTypeTc)
import Clash.Core.TysPrim (tysPrimMap)
import Clash.Core.Util (mkInternalVar, mkSelectorCase)
import Clash.Core.Var (Var (..), Id, IdScope (..), setIdScope)
import Clash.Core.VarEnv
(InScopeSet, VarEnv, emptyInScopeSet, extendInScopeSet, mkInScopeSet
Expand All @@ -85,7 +86,6 @@ import Clash.Netlist.Types (TopEntityT(..))
import Clash.Primitives.Types
(Primitive (..), CompiledPrimMap)
import Clash.Primitives.Util (generatePrimMap)
import Clash.Rewrite.Util (mkInternalVar, mkSelectorCase)
import Clash.Unique
(listToUniqMap, lookupUniqMap, mapUniqMap, unionUniqMap, uniqMapToUniqSet)
import Clash.Util (reportTimeDiff)
Expand Down
106 changes: 99 additions & 7 deletions clash-lib/src/Clash/Core/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}

module Clash.Core.Util where

import Control.Concurrent.Supply (Supply, freshId)
import qualified Control.Lens as Lens
import Control.Monad.Trans.Except (Except, throwE)
import Control.Monad.Trans.Except (Except, throwE, runExcept)
import qualified Data.HashSet as HashSet
import qualified Data.Graph as Graph
import Data.List (foldl', mapAccumR)
Expand Down Expand Up @@ -46,6 +48,7 @@ import Clash.Core.Name
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst
import Clash.Core.Term
import Clash.Core.TermInfo (termType)
import Clash.Core.TyCon (TyConMap, tyConDataCons)
import Clash.Core.Type
import Clash.Core.TysPrim (typeNatKind)
Expand Down Expand Up @@ -458,13 +461,18 @@ tyLitShow _ (LitTy (SymTy s)) = return s
tyLitShow _ (LitTy (NumTy s)) = return (show s)
tyLitShow _ ty = throwE $ $(curLoc) ++ "Cannot reduce to a string:\n" ++ showPpr ty

data Projection where
Projection :: (forall m . (Functor m, MonadUnique m) => InScopeSet -> Term -> m [Term])
-> Projection


-- | Determine whether we should split away types from a product type, i.e.
-- clocks should always be separate arguments, and not part of a product.
shouldSplit
:: TyConMap
-> Type
-- ^ Type to examine
-> Maybe (Term,[Type])
-> Maybe ([Term] -> Term, Projection, [Type])
-- ^ If we want to split values of the given type then we have /Just/:
--
-- 1. The (type-applied) data-constructor which, when applied to values of
Expand All @@ -487,17 +495,45 @@ shouldSplit tcm ty = shouldSplit0 tcm (tyView (coreView tcm ty))
shouldSplit0
:: TyConMap
-> TypeView
-> Maybe (Term,[Type])
-> Maybe ([Term] -> Term, Projection, [Type])
shouldSplit0 tcm (TyConApp tcNm tyArgs)
| Just tc <- lookupUniqMap tcNm tcm
, [dc] <- tyConDataCons tc
, let dcArgs = substArgTys dc tyArgs
, let dcArgs = substArgTys dc tyArgs
, let dcArgsLen = length dcArgs
, dcArgsLen > 1
, let dcArgVs = map (tyView . coreView tcm) dcArgs
= if any shouldSplitTy dcArgVs && not (isHidden tcNm tyArgs) then
Just (mkApps (Data dc) (map Right tyArgs), dcArgs)
Just ( mkApps (Data dc) . (map Right tyArgs ++) . map Left
, Projection
(\is0 subj -> mapM (mkSelectorCase ($(curLoc) ++ "splitArg") is0 tcm subj 1)
[0..dcArgsLen - 1])
, dcArgs
)
else
Nothing
| "Clash.Sized.Vector.Vec" <- nameOcc tcNm
, [nTy,argTy] <- tyArgs
, Right n <- runExcept (tyNatSize tcm nTy)
, n > 1
, Just tc <- lookupUniqMap tcNm tcm
, [nil,cons] <- tyConDataCons tc
= if shouldSplitTy (tyView (coreView tcm argTy)) then
Just ( mkVec nil cons argTy n
, Projection (\is0 subj -> mapM (mkVecSelector is0 subj) [0..n-1])
, replicate (fromInteger n) argTy)
else
Nothing
where
mkVecSelector :: forall m . (Functor m, MonadUnique m) => InScopeSet -> Term -> Integer -> m Term
mkVecSelector is0 subj 0 =
mkSelectorCase ($(curLoc) ++ "mkVecSelector") is0 tcm subj 2 1

mkVecSelector is0 subj !n = do
subj1 <- mkSelectorCase ($(curLoc) ++ "mkVecSelector") is0 tcm subj 2 2
mkVecSelector is0 subj1 (n-1)


shouldSplitTy :: TypeView -> Bool
shouldSplitTy ty = isJust (shouldSplit0 tcm ty) || splitTy ty

Expand Down Expand Up @@ -554,8 +590,8 @@ splitShouldSplit
splitShouldSplit tcm = foldr go []
where
go ty rest = case shouldSplit tcm ty of
Just (_,tys) -> splitShouldSplit tcm tys ++ rest
Nothing -> ty : rest
Just (_,_,tys) -> splitShouldSplit tcm tys ++ rest
Nothing -> ty : rest

-- | Strip implicit parameter wrappers (IP)
stripIP :: Type -> Type
Expand Down Expand Up @@ -601,3 +637,59 @@ sccLetBindings =
(Set.elems (Lens.setOf freeLocalIds e) )
in ((i,e),varUniq i,fvs)))
{-# SCC sccLetBindings #-}

-- | Make a case-decomposition that extracts a field out of a (Sum-of-)Product type
mkSelectorCase
:: HasCallStack
=> (Functor m, MonadUnique m)
=> String -- ^ Name of the caller of this function
-> InScopeSet
-> TyConMap -- ^ TyCon cache
-> Term -- ^ Subject of the case-composition
-> Int -- n'th DataCon
-> Int -- n'th field
-> m Term
mkSelectorCase caller inScope tcm scrut dcI fieldI = go (termType tcm scrut)
where
go (coreView1 tcm -> Just ty') = go ty'
go scrutTy@(tyView -> TyConApp tc args) =
case tyConDataCons (lookupUniqMap' tcm tc) of
[] -> cantCreate $(curLoc) ("TyCon has no DataCons: " ++ show tc ++ " " ++ showPpr tc) scrutTy
dcs | dcI > length dcs -> cantCreate $(curLoc) "DC index exceeds max" scrutTy
| otherwise -> do
let dc = indexNote ($(curLoc) ++ "No DC with tag: " ++ show (dcI-1)) dcs (dcI-1)
let (Just fieldTys) = dataConInstArgTysE inScope tcm dc args
if fieldI >= length fieldTys
then cantCreate $(curLoc) "Field index exceed max" scrutTy
else do
wildBndrs <- mapM (mkWildValBinder inScope) fieldTys
let ty = indexNote ($(curLoc) ++ "No DC field#: " ++ show fieldI) fieldTys fieldI
selBndr <- mkInternalVar inScope "sel" ty
let bndrs = take fieldI wildBndrs ++ [selBndr] ++ drop (fieldI+1) wildBndrs
pat = DataPat dc (dcExtTyVars dc) bndrs
retVal = Case scrut ty [ (pat, Var selBndr) ]
return retVal
go scrutTy = cantCreate $(curLoc) ("Type of subject is not a datatype: " ++ showPpr scrutTy) scrutTy

cantCreate loc info scrutTy = error $ loc ++ "Can't create selector " ++ show (caller,dcI,fieldI) ++ " for: (" ++ showPpr scrut ++ " :: " ++ showPpr scrutTy ++ ")\nAdditional info: " ++ info

-- | Make a binder that should not be referenced
mkWildValBinder
:: (MonadUnique m)
=> InScopeSet
-> Type
-> m Id
mkWildValBinder is = mkInternalVar is "wild"

-- | Make a new, unique, identifier
mkInternalVar
:: (MonadUnique m)
=> InScopeSet
-> OccName
-- ^ Name of the identifier
-> KindOrType
-> m Id
mkInternalVar inScope name ty = do
i <- getUniqueM
let nm = mkUnsafeInternalName name i
return (uniqAway inScope (mkLocalId ty nm))
2 changes: 1 addition & 1 deletion clash-lib/src/Clash/Driver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ splitTopAnn tcm sp typ@(tyView -> FunTy {}) t@Synthesize{t_inputs} =
= PortName "" : go res (p:ps)
| otherwise =
case shouldSplit tcm a of
Just (_,argTys@(_:_:_)) ->
Just (_,_,argTys@(_:_:_)) ->
-- Port must be split up into 'n' pieces.. can it?
case p of
PortProduct nm portNames0 ->
Expand Down
5 changes: 2 additions & 3 deletions clash-lib/src/Clash/Normalize/DEC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,13 @@ import Clash.Core.Term
import Clash.Core.TermInfo (termType)
import Clash.Core.TyCon (tyConDataCons)
import Clash.Core.Type (Type, isPolyFunTy, mkTyConApp, splitFunForallTy)
import Clash.Core.Util (sccLetBindings)
import Clash.Core.Util (mkInternalVar, mkSelectorCase, sccLetBindings)
import Clash.Core.Var (isGlobalId)
import Clash.Core.VarEnv
(InScopeSet, elemInScopeSet, extendInScopeSetList, notElemInScopeSet, unionInScope)
import Clash.Normalize.Types (NormalizeState)
import Clash.Rewrite.Types
import Clash.Rewrite.Util (mkInternalVar, mkSelectorCase,
isUntranslatableType)
import Clash.Rewrite.Util (isUntranslatableType)
import Clash.Rewrite.WorkFree (isConstant)
import Clash.Unique (lookupUniqMap)
import Clash.Util
Expand Down
50 changes: 40 additions & 10 deletions clash-lib/src/Clash/Normalize/Transformations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ import Clash.Core.Type (Type (..), TypeView (..), applyFun
normalizeType, splitFunForallTy,
splitFunTy,
tyView, mkPolyFunTy, coreView,
LitTy (..), coreView1)
LitTy (..), coreView1, mkTyConApp)
import Clash.Core.TyCon (TyConMap, tyConDataCons)
import Clash.Core.Util
( isSignalType, mkVec, tyNatSize, undefinedTm,
shouldSplit, inverseTopSortLetBindings)
(Projection (..), isSignalType, mkVec, tyNatSize, undefinedTm,
shouldSplit, inverseTopSortLetBindings, mkInternalVar, mkSelectorCase)
import Clash.Core.Var
(Id, TyVar, Var (..), isGlobalId, isLocalId, mkLocalId)
import Clash.Core.VarEnv
Expand Down Expand Up @@ -2142,6 +2142,26 @@ reduceConst _ e = return e
-- * Clash.Sized.RTree.treplicate
-- * Clash.Sized.Internal.BitVector.split#
-- * Clash.Sized.Internal.BitVector.eq#
--
-- Note [Unroll shouldSplit types]
-- 1. Certain higher-order functions over Vec, such as map, have specialized
-- code-paths to turn them into generate-for loops in HDL, instead of having to
-- having to unroll/inline their recursive definitions, e.g. Clash.Sized.Vector.map
--
-- 2. Clash, in general, translates Haskell product types to VHDL records. This
-- mostly works out fine, there is however one exception: certain synthesis
-- tools, and some HDL simulation tools (like verilator), do not like it when
-- the clock (and certain other global control signals) is contained in a
-- record type; they want them to be separate inputs to the entity/module.
-- And Clash actually does some transformations to try to ensure that values of
-- type Clock do not end up in a VHDL record type.
--
-- The problem is that the transformations in 2. never took into account the
-- specialized code-paths in 1. Making the code-paths in 1. aware of the
-- transformations in 2. is really not worth the effort for such a niche case.
-- It's easier to just unroll the recursive definitions.
--
-- See https://github.com/clash-lang/clash-compiler/issues/1606
reduceNonRepPrim :: HasCallStack => NormRewrite
reduceNonRepPrim c@(TransformContext is0 ctx) e@(App _ _) | (Prim p, args, ticks) <- collectArgsTicks e = do
tcm <- Lens.view tcCache
Expand All @@ -2157,12 +2177,18 @@ reduceNonRepPrim c@(TransformContext is0 ctx) e@(App _ _) | (Prim p, args, ticks
tv -> let argLen = length args in case primName p of
"Clash.Sized.Vector.zipWith" | argLen == 7 -> do
let [lhsElTy,rhsElty,resElTy,nTy] = Either.rights args
TyConApp vecTcNm _ = tv
lhsTy = mkTyConApp vecTcNm [nTy,lhsElTy]
rhsTy = mkTyConApp vecTcNm [nTy,rhsElty]
case runExcept (tyNatSize tcm nTy) of
Right n -> do
shouldReduce1 <- List.orM [ pure (ultra || n < 2)
, shouldReduce ctx
, List.anyM isUntranslatableType_not_poly
[lhsElTy,rhsElty,resElTy] ]
[lhsElTy,rhsElty,resElTy]
-- Note [Unroll shouldSplit types]
, pure (any (Maybe.isJust . shouldSplit tcm)
[lhsTy,rhsTy,eTy]) ]
if shouldReduce1
then let [fun,lhsArg,rhsArg] = Either.lefts args
in (`mkTicks` ticks) <$>
Expand All @@ -2171,12 +2197,17 @@ reduceNonRepPrim c@(TransformContext is0 ctx) e@(App _ _) | (Prim p, args, ticks
_ -> return e
"Clash.Sized.Vector.map" | argLen == 5 -> do
let [argElTy,resElTy,nTy] = Either.rights args
TyConApp vecTcNm _ = tv
argTy = mkTyConApp vecTcNm [nTy,argElTy]
case runExcept (tyNatSize tcm nTy) of
Right n -> do
shouldReduce1 <- List.orM [ pure (ultra || n < 2 )
, shouldReduce ctx
, List.anyM isUntranslatableType_not_poly
[argElTy,resElTy] ]
[argElTy,resElTy]
-- Note [Unroll shouldSplit types]
, pure (any (Maybe.isJust . shouldSplit tcm)
[argTy,eTy]) ]
if shouldReduce1
then let [fun,arg] = Either.lefts args
in (`mkTicks` ticks) <$> reduceMap c p n argElTy resElTy fun arg
Expand Down Expand Up @@ -2888,12 +2919,12 @@ separateLambda
-- ^ If lambda is split up, this function returns a Just containing the new term
separateLambda tcm ctx@(TransformContext is0 _) b eb0 =
case shouldSplit tcm (varType b) of
Just (dc,argTys@(_:_:_)) ->
Just (dc, _, argTys) ->
let
nm = mkDerivedName ctx (nameOcc (varName b))
bs0 = map (`mkLocalId` nm) argTys
(is1, bs1) = List.mapAccumL newBinder is0 bs0
subst = extendIdSubst (mkSubst is1) b (mkApps dc (map (Left . Var) bs1))
subst = extendIdSubst (mkSubst is1) b (dc (map Var bs1))
eb1 = substTm "separateArguments" subst eb0
in
Just (mkLams eb1 bs1)
Expand Down Expand Up @@ -2953,9 +2984,8 @@ separateArguments (TransformContext is0 _) e@(collectArgsTicks -> (Var g, args,
tcm <- Lens.view tcCache
let argTy = termType tcm tmArg
case shouldSplit tcm argTy of
Just (_,argTys@(_:_:_)) -> do
tmArgs <- mapM (mkSelectorCase ($(curLoc) ++ "splitArg") is0 tcm tmArg 1)
[0..length argTys - 1]
Just (_,Projection proj,_) -> do
tmArgs <- proj is0 tmArg
changed (map ((ty,) . Left) tmArgs)
_ ->
return [(ty,arg)]
Expand Down
Loading

0 comments on commit 279c189

Please sign in to comment.