diff --git a/changelog/2021-04-08T17_46_59+02_00_fix1756 b/changelog/2021-04-08T17_46_59+02_00_fix1756 new file mode 100644 index 0000000000..4362ee312d --- /dev/null +++ b/changelog/2021-04-08T17_46_59+02_00_fix1756 @@ -0,0 +1 @@ +FIXED: `unconcat` cannot be used as initial/reset value for a `register` [#1756](https://github.com/clash-lang/clash-compiler/issues/1756) diff --git a/clash-lib/src/Clash/Normalize/PrimitiveReductions.hs b/clash-lib/src/Clash/Normalize/PrimitiveReductions.hs index f15f73e391..12947ce657 100644 --- a/clash-lib/src/Clash/Normalize/PrimitiveReductions.hs +++ b/clash-lib/src/Clash/Normalize/PrimitiveReductions.hs @@ -23,10 +23,10 @@ * Clash.Sized.Vector.dtfold * Clash.Sized.RTree.tfold * Clash.Sized.Vector.reverse + * Clash.Sized.Vector.unconcat Partially handles: - * Clash.Sized.Vector.unconcat * Clash.Sized.Vector.transpose -} @@ -914,12 +914,15 @@ reduceAppend inScope n m aTy lArg rArg = do -- | Replace an application of the @Clash.Sized.Vector.unconcat@ primitive on -- vectors of a known length @n@, by the fully unrolled recursive "definition" -- of @Clash.Sized.Vector.unconcat@ -reduceUnconcat :: Integer -- ^ Length of the result vector +reduceUnconcat :: InScopeSet + -> PrimInfo -- ^ Unconcat primitive info + -> Integer -- ^ Length of the result vector -> Integer -- ^ Length of the elements of the result vector -> Type -- ^ Element type + -> Term -- ^ SNat "Length of the elements of the result vector" -> Term -- ^ Argument vector -> NormalizeSession Term -reduceUnconcat n 0 aTy arg = do +reduceUnconcat inScope unconcatPrimInfo n m aTy sm arg = do tcm <- Lens.view tcCache let ty = termType tcm arg go tcm ty @@ -929,14 +932,47 @@ reduceUnconcat n 0 aTy arg = do | (Just vecTc) <- lookupUniqMap vecTcNm tcm , nameOcc vecTcNm == "Clash.Sized.Vector.Vec" , [nilCon,consCon] <- tyConDataCons vecTc - = let nilVec = mkVec nilCon consCon aTy 0 [] - innerVecTy = mkTyConApp vecTcNm [LitTy (NumTy 0), aTy] - retVec = mkVec nilCon consCon innerVecTy n (replicate (fromInteger n) nilVec) - in changed retVec + , let innerVecTy = mkTyConApp vecTcNm [LitTy (NumTy m), aTy] + = if n == 0 then + changed (mkVecNil nilCon innerVecTy) + else if m == 0 then do + let + nilVec = mkVecNil nilCon aTy + retVec = mkVec nilCon consCon innerVecTy n (replicate (fromInteger n) nilVec) + changed retVec + else do + uniqs0 <- Lens.use uniqSupply + let + (uniqs1,(vars,headsAndTails)) = + second (second concat . unzip) + (extractElems uniqs0 inScope consCon aTy 'U' (n*m) arg) + -- Build a vector out of the first m elements + mvec = mkVec nilCon consCon aTy m (take (fromInteger m) vars) + -- Get the vector representing the next ((n-1)*m) elements + -- N.B. `extractElems (xs :: Vec 2 a)` creates: + -- x0 = head xs + -- xs0 = tail xs + -- x1 = head xs0 + -- xs1 = tail xs0 + (lbs,head -> nextVec) = splitAt ((2*fromInteger m)-1) headsAndTails + -- recursively call unconcat + nextUnconcat = mkApps (Prim unconcatPrimInfo) + [ Right (LitTy (NumTy (n-1))) + , Right (LitTy (NumTy m)) + , Right aTy + , Left (Literal (NaturalLiteral (n-1))) + , Left sm + , Left (snd nextVec) + ] + -- let (mvec,nextVec) = splitAt sm arg + -- in Cons mvec (unconcat sm nextVec) + lBody = mkVecCons consCon innerVecTy n mvec nextUnconcat + lb = Letrec lbs lBody + + uniqSupply Lens..= uniqs1 + changed lb go _ ty = error $ $(curLoc) ++ "reduceUnconcat: argument does not have a vector type: " ++ showPpr ty -reduceUnconcat _ _ _ _ = error $ $(curLoc) ++ "reduceUnconcat: unimplemented" - -- | Replace an application of the @Clash.Sized.Vector.transpose@ primitive on -- vectors of a known length @n@, by the fully unrolled recursive "definition" -- of @Clash.Sized.Vector.transpose@ diff --git a/clash-lib/src/Clash/Normalize/Transformations.hs b/clash-lib/src/Clash/Normalize/Transformations.hs index 87a8784dd6..dbce760d08 100644 --- a/clash-lib/src/Clash/Normalize/Transformations.hs +++ b/clash-lib/src/Clash/Normalize/Transformations.hs @@ -2327,9 +2327,20 @@ reduceNonRepPrim c@(TransformContext is0 ctx) e@(App _ _) | (Prim p, args, ticks else return e _ -> return e "Clash.Sized.Vector.unconcat" | argLen == 6 -> do - let ([_knN,_sm,arg],[mTy,nTy,aTy]) = Either.partitionEithers args + let ([_knN,sm,arg],[nTy,mTy,aTy]) = Either.partitionEithers args + argTy = termType tcm arg case (runExcept (tyNatSize tcm nTy), runExcept (tyNatSize tcm mTy)) of - (Right n, Right 0) -> (`mkTicks` ticks) <$> reduceUnconcat n 0 aTy arg + (Right n, Right m) -> do + shouldReduce1 <- List.orM [ pure (m==0) + , shouldReduce ctx + , isUntranslatableType_not_poly aTy + -- Note [Unroll shouldSplit types] + , pure (Maybe.isJust (shouldSplit tcm argTy)) + ] + if shouldReduce1 then + (`mkTicks` ticks) <$> reduceUnconcat is0 p n m aTy sm arg + else + return e _ -> return e "Clash.Sized.Vector.transpose" | argLen == 5 -> do let ([_knN,arg],[mTy,nTy,aTy]) = Either.partitionEithers args diff --git a/tests/shouldwork/Issues/T1756.hs b/tests/shouldwork/Issues/T1756.hs new file mode 100644 index 0000000000..41aa675071 --- /dev/null +++ b/tests/shouldwork/Issues/T1756.hs @@ -0,0 +1,5 @@ +module T1756 where + +import Clash.Prelude + +topEntity a = register @System (unconcat d2 (replicate d8 False)) a diff --git a/testsuite/Main.hs b/testsuite/Main.hs index 5c58867719..f57b8fd8a0 100755 --- a/testsuite/Main.hs +++ b/testsuite/Main.hs @@ -535,6 +535,7 @@ runClashTest = defaultMain $ clashTestRoot , runTest "T1606A" def{hdlSim=False} , runTest "T1606B" def{hdlSim=False} , runTest "T1742" def{hdlSim=False, buildTargets=BuildSpecific ["shell"]} + , runTest "T1756" def{hdlSim=False} ] <> if compiledWith == Cabal then -- This tests fails without environment files present, which are only