Skip to content

Commit

Permalink
Unroll unconcat in more situations than m=0
Browse files Browse the repository at this point in the history
Fixes #1756
  • Loading branch information
christiaanb committed Apr 9, 2021
1 parent ecd01cb commit 40dcdc4
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 11 deletions.
1 change: 1 addition & 0 deletions changelog/2021-04-08T17_46_59+02_00_fix1756
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
FIXED: `unconcat` cannot be used as initial/reset value for a `register` [#1756](https://github.com/clash-lang/clash-compiler/issues/1756)
54 changes: 45 additions & 9 deletions clash-lib/src/Clash/Normalize/PrimitiveReductions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
* Clash.Sized.Vector.dtfold
* Clash.Sized.RTree.tfold
* Clash.Sized.Vector.reverse
* Clash.Sized.Vector.unconcat
Partially handles:
* Clash.Sized.Vector.unconcat
* Clash.Sized.Vector.transpose
-}

Expand Down Expand Up @@ -914,12 +914,15 @@ reduceAppend inScope n m aTy lArg rArg = do
-- | Replace an application of the @Clash.Sized.Vector.unconcat@ primitive on
-- vectors of a known length @n@, by the fully unrolled recursive "definition"
-- of @Clash.Sized.Vector.unconcat@
reduceUnconcat :: Integer -- ^ Length of the result vector
reduceUnconcat :: InScopeSet
-> PrimInfo -- ^ Unconcat primitive info
-> Integer -- ^ Length of the result vector
-> Integer -- ^ Length of the elements of the result vector
-> Type -- ^ Element type
-> Term -- ^ SNat "Length of the elements of the result vector"
-> Term -- ^ Argument vector
-> NormalizeSession Term
reduceUnconcat n 0 aTy arg = do
reduceUnconcat inScope unconcatPrimInfo n m aTy sm arg = do
tcm <- Lens.view tcCache
let ty = termType tcm arg
go tcm ty
Expand All @@ -929,14 +932,47 @@ reduceUnconcat n 0 aTy arg = do
| (Just vecTc) <- lookupUniqMap vecTcNm tcm
, nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
, [nilCon,consCon] <- tyConDataCons vecTc
= let nilVec = mkVec nilCon consCon aTy 0 []
innerVecTy = mkTyConApp vecTcNm [LitTy (NumTy 0), aTy]
retVec = mkVec nilCon consCon innerVecTy n (replicate (fromInteger n) nilVec)
in changed retVec
, let innerVecTy = mkTyConApp vecTcNm [LitTy (NumTy m), aTy]
= if n == 0 then
changed (mkVecNil nilCon innerVecTy)
else if m == 0 then do
let
nilVec = mkVecNil nilCon aTy
retVec = mkVec nilCon consCon innerVecTy n (replicate (fromInteger n) nilVec)
changed retVec
else do
uniqs0 <- Lens.use uniqSupply
let
(uniqs1,(vars,headsAndTails)) =
second (second concat . unzip)
(extractElems uniqs0 inScope consCon aTy 'U' (n*m) arg)
-- Build a vector out of the first m elements
mvec = mkVec nilCon consCon aTy m (take (fromInteger m) vars)
-- Get the vector representing the next ((n-1)*m) elements
-- N.B. `extractElems (xs :: Vec 2 a)` creates:
-- x0 = head xs
-- xs0 = tail xs
-- x1 = head xs0
-- xs1 = tail xs0
(lbs,head -> nextVec) = splitAt ((2*fromInteger m)-1) headsAndTails
-- recursively call unconcat
nextUnconcat = mkApps (Prim unconcatPrimInfo)
[ Right (LitTy (NumTy (n-1)))
, Right (LitTy (NumTy m))
, Right aTy
, Left (Literal (NaturalLiteral (n-1)))
, Left sm
, Left (snd nextVec)
]
-- let (mvec,nextVec) = splitAt sm arg
-- in Cons mvec (unconcat sm nextVec)
lBody = mkVecCons consCon innerVecTy n mvec nextUnconcat
lb = Letrec lbs lBody

uniqSupply Lens..= uniqs1
changed lb
go _ ty = error $ $(curLoc) ++ "reduceUnconcat: argument does not have a vector type: " ++ showPpr ty

reduceUnconcat _ _ _ _ = error $ $(curLoc) ++ "reduceUnconcat: unimplemented"

-- | Replace an application of the @Clash.Sized.Vector.transpose@ primitive on
-- vectors of a known length @n@, by the fully unrolled recursive "definition"
-- of @Clash.Sized.Vector.transpose@
Expand Down
15 changes: 13 additions & 2 deletions clash-lib/src/Clash/Normalize/Transformations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2327,9 +2327,20 @@ reduceNonRepPrim c@(TransformContext is0 ctx) e@(App _ _) | (Prim p, args, ticks
else return e
_ -> return e
"Clash.Sized.Vector.unconcat" | argLen == 6 -> do
let ([_knN,_sm,arg],[mTy,nTy,aTy]) = Either.partitionEithers args
let ([_knN,sm,arg],[nTy,mTy,aTy]) = Either.partitionEithers args
argTy = termType tcm arg
case (runExcept (tyNatSize tcm nTy), runExcept (tyNatSize tcm mTy)) of
(Right n, Right 0) -> (`mkTicks` ticks) <$> reduceUnconcat n 0 aTy arg
(Right n, Right m) -> do
shouldReduce1 <- List.orM [ pure (m==0)
, shouldReduce ctx
, isUntranslatableType_not_poly aTy
-- Note [Unroll shouldSplit types]
, pure (Maybe.isJust (shouldSplit tcm argTy))
]
if shouldReduce1 then
(`mkTicks` ticks) <$> reduceUnconcat is0 p n m aTy sm arg
else
return e
_ -> return e
"Clash.Sized.Vector.transpose" | argLen == 5 -> do
let ([_knN,arg],[mTy,nTy,aTy]) = Either.partitionEithers args
Expand Down
5 changes: 5 additions & 0 deletions tests/shouldwork/Issues/T1756.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module T1756 where

import Clash.Prelude

topEntity a = register @System (unconcat d2 (replicate d8 False)) a
1 change: 1 addition & 0 deletions testsuite/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ runClashTest = defaultMain $ clashTestRoot
, runTest "T1606A" def{hdlSim=False}
, runTest "T1606B" def{hdlSim=False}
, runTest "T1742" def{hdlSim=False, buildTargets=BuildSpecific ["shell"]}
, runTest "T1756" def{hdlSim=False}
] <>
if compiledWith == Cabal then
-- This tests fails without environment files present, which are only
Expand Down

0 comments on commit 40dcdc4

Please sign in to comment.