Skip to content

Commit

Permalink
Hack for avoiding division zero in segreduce-iota
Browse files Browse the repository at this point in the history
  • Loading branch information
HnimNart committed Aug 13, 2020
1 parent d2a34aa commit c313ca9
Showing 1 changed file with 32 additions and 36 deletions.
68 changes: 32 additions & 36 deletions src/Futhark/CodeGen/ImpGen/Multicore/SegRed.hs
Expand Up @@ -200,45 +200,41 @@ compileSegRedBody n_segments pat space reds kbody = do
ns' <- mapM toExp ns
let inner_bound = last ns'

n_segments' <- toExp $ Var n_segments

let per_red_pes = segBinOpChunks reds $ patternValueElements pat
-- Perform sequential reduce on inner most dimension
collect $ do
flat_idx <- dPrimV "flat_idx" (n_segments' * inner_bound)
zipWithM_ dPrimV_ is $ unflattenIndex ns' $ Imp.vi32 flat_idx
sComment "neutral-initialise the accumulators" $
forM_ (zip per_red_pes reds) $ \(pes, red) ->
forM_ (zip pes (segBinOpNeutral red)) $ \(pe, ne) ->
sLoopNest (segBinOpShape red) $ \vec_is ->
copyDWIMFix (patElemName pe) (map Imp.vi32 (init is) ++ vec_is) ne []

sComment "main body" $ do
dScope Nothing $ scopeOfLParams $ concatMap (lambdaParams . segBinOpLambda) reds
sFor "i" inner_bound $ \i -> do
forM_ (zip (init is) $ unflattenIndex (init ns') n_segments') $ uncurry (<--)
dPrimV_ (last is) i
kbody $ \all_red_res -> do
let red_res' = chunks (map (length . segBinOpNeutral) reds) all_red_res
forM_ (zip3 per_red_pes reds red_res') $ \(pes, red, res') ->
sLoopNest (segBinOpShape red) $ \vec_is -> do

sComment "load accum" $ do
let acc_params = take (length (segBinOpNeutral red)) $ (lambdaParams . segBinOpLambda) red
forM_ (zip acc_params pes) $ \(p, pe) ->
copyDWIMFix (paramName p) [] (Var $ patElemName pe) (map Imp.vi32 (init is) ++ vec_is)

sComment "load new val" $ do
let next_params = drop (length (segBinOpNeutral red)) $ (lambdaParams . segBinOpLambda) red
forM_ (zip next_params res') $ \(p, (res, res_is)) ->
copyDWIMFix (paramName p) [] res (res_is ++ vec_is)

sComment "apply reduction" $ do
let lbody = (lambdaBody . segBinOpLambda) red
compileStms mempty (bodyStms lbody) $
sComment "write back to res" $
forM_ (zip pes (bodyResult lbody)) $
\(pe, se') -> copyDWIMFix (patElemName pe) (map Imp.vi32 (init is) ++ vec_is) se' []
dScope Nothing $ scopeOfLParams $ concatMap (lambdaParams . segBinOpLambda) reds
sFor "i" inner_bound $ \i -> do
zipWithM_ dPrimV_ (init is) $ unflattenIndex (init ns') $ Imp.vi32 n_segments
dPrimV_ (last is) i
sWhen (i .==. 0) $
sComment "neutral-initialise the accumulator" $
forM_ (zip per_red_pes reds) $ \(pes, red) ->
forM_ (zip pes (segBinOpNeutral red)) $ \(pe, ne) ->
sLoopNest (segBinOpShape red) $ \vec_is ->
copyDWIMFix (patElemName pe) (map Imp.vi32 (init is) ++ vec_is) ne []

kbody $ \all_red_res -> do
let red_res' = chunks (map (length . segBinOpNeutral) reds) all_red_res
forM_ (zip3 per_red_pes reds red_res') $ \(pes, red, res') ->
sLoopNest (segBinOpShape red) $ \vec_is -> do

sComment "load accum" $ do
let acc_params = take (length (segBinOpNeutral red)) $ (lambdaParams . segBinOpLambda) red
forM_ (zip acc_params pes) $ \(p, pe) ->
copyDWIMFix (paramName p) [] (Var $ patElemName pe) (map Imp.vi32 (init is) ++ vec_is)

sComment "load new val" $ do
let next_params = drop (length (segBinOpNeutral red)) $ (lambdaParams . segBinOpLambda) red
forM_ (zip next_params res') $ \(p, (res, res_is)) ->
copyDWIMFix (paramName p) [] res (res_is ++ vec_is)

sComment "apply reduction" $ do
let lbody = (lambdaBody . segBinOpLambda) red
compileStms mempty (bodyStms lbody) $
sComment "write back to res" $
forM_ (zip pes (bodyResult lbody)) $
\(pe, se') -> copyDWIMFix (patElemName pe) (map Imp.vi32 (init is) ++ vec_is) se' []



Expand Down

0 comments on commit c313ca9

Please sign in to comment.