Skip to content

Commit

Permalink
Reimplement custom linearizations, respecting the Core/Simp IR distin…
Browse files Browse the repository at this point in the history
…ction.

This uses the noinline mechanism to preserve function identity, so only
functions tagged `@noinline` can be given a custom linearization rule.

Co-authored-by: Alexey Radul <axch@google.com>
  • Loading branch information
dougalm and axch committed Feb 2, 2023
1 parent aaa536e commit c996e98
Show file tree
Hide file tree
Showing 13 changed files with 542 additions and 354 deletions.
83 changes: 61 additions & 22 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ liftTopBuilderHoisted cont = do
Distinct <- getDistinct
return $ runHardFail $ runTopBuilderT env $ localTopBuilder cont

liftTopBuilderAndEmit
:: (HoistingTopBuilder TopEnvFrag m, RenameE e, SinkableE e, HoistableE e)
=> (forall l. (Mut l, DExt n l) => TopBuilderM l (e l))
-> m n (Maybe (e n))
liftTopBuilderAndEmit cont = do
liftTopBuilderHoisted cont >>= emitHoistedEnv

newtype DoubleBuilderT (r::IR) (topEmissions::B) (m::MonadKind) (n::S) (a:: *) =
DoubleBuilderT { runDoubleBuilderT' :: DoubleInplaceT Env topEmissions (BuilderEmissions r) m n a }
deriving ( Functor, Applicative, Monad, MonadFail, Fallible
Expand Down Expand Up @@ -282,8 +289,10 @@ emitTopLet hint letAnn expr = do
ty <- getType expr
emitBinding hint $ AtomNameBinding $ LetBound (DeclBinding letAnn ty expr)

emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFun n -> m n (TopFunName n)
emitTopFunBinding hint f = emitBinding hint $ TopFunBinding f
emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> LamExpr SimpIR n -> m n (TopFunName n)
emitTopFunBinding hint def f = do
ty <- getNaryLamExprType f
emitBinding hint $ TopFunBinding $ DexTopFun def ty f Waiting

updateTopFunStatus :: (Mut n, TopBuilder m) => TopFunName n -> EvalStatus (TopFunLowerings n) -> m n ()
updateTopFunStatus f status =
Expand Down Expand Up @@ -326,13 +335,24 @@ queryIxDictCache d = do
cache <- ixDictCache <$> getCache
return $ lookupEMap cache d

queryLinearizationCache :: EnvReader m => LinearizationSpec n -> m n (Maybe (TopFunName n, TopFunName n))
queryLinearizationCache s = do
cache <- linearizationCache <$> getCache
return $ fmap fromPairE $ lookupEMap cache s

extendLinearizationCache
:: (TopBuilder m, Fallible1 m, Mut n)
=> LinearizationSpec n -> (TopFunName n, TopFunName n) -> m n ()
extendLinearizationCache s fs =
extendCache $ mempty { linearizationCache = eMapSingleton s (toPairE fs) }

finishSpecializedDict :: (Mut n, TopBuilder m) => SpecDictName n -> [LamExpr SimpIR n] -> m n ()
finishSpecializedDict name methods =
emitPartialTopEnvFrag $ mempty {fragFinishSpecializedDict = toSnocList [(name, methods)]}

queryObjCache :: EnvReader m => TopFunName n -> m n (Maybe (FunObjCodeName n))
queryObjCache v = lookupEnv v >>= \case
TopFunBinding (DexTopFun _ _ (Finished impl)) -> return $ Just $ topFunObjCode impl
TopFunBinding (DexTopFun _ _ _ (Finished impl)) -> return $ Just $ topFunObjCode impl
_ -> return Nothing

emitObjFile
Expand Down Expand Up @@ -560,6 +580,13 @@ instance (SinkableE e, HoistableState e, HoistingTopBuilder frag m)
canHoistToTop e = lift11 $ canHoistToTop e
{-# INLINE canHoistToTop #-}

instance (SinkableE e, HoistingTopBuilder frag m)
=> HoistingTopBuilder frag (ReaderT1 e m) where
emitHoistedEnv ab = lift11 $ emitHoistedEnv ab
{-# INLINE emitHoistedEnv #-}
canHoistToTop e = lift11 $ canHoistToTop e
{-# INLINE canHoistToTop #-}

instance Builder r m => Builder r (MaybeT1 m) where
emitDecl hint ann expr = lift11 $ emitDecl hint ann expr
{-# INLINE emitDecl #-}
Expand Down Expand Up @@ -810,6 +837,15 @@ asNaryLam ty f = liftBuilder do
buildNaryLamExprFromPi ty \xs ->
naryApp (sink f) (map Var $ toList xs)

buildNonDepNaryLamExpr
:: ScopableBuilder r m
=> [Type r n]
-> (forall l. (Emits l, Distinct l, DExt n l) => [AtomName r l] -> m l (Atom r l))
-> m n (LamExpr r n)
buildNonDepNaryLamExpr tys cont = do
bs <- typesAsBinderNest tys
buildNaryLamExpr bs cont

buildNaryLamExpr
:: ScopableBuilder r m
=> (EmptyAbs (Nest (Binder r)) n)
Expand Down Expand Up @@ -894,22 +930,6 @@ buildSplitCase tys scrut resultTy match fallback = do
1 -> fallback v
_ -> error "should only have two cases"

buildLamExprWithRecon
:: ScopableBuilder r m
=> EmptyAbs (Nest (Binder r)) n
-> (forall l. (Emits l, DExt n l) => [AtomName r l] -> m l (e l))
-> (forall l1 l2. Nest (Binder r) n l1 -> Nest (Decl r) l1 l2 -> e l2 -> m l2 (Atom r l2, recon))
-> m n (LamExpr r n, recon)
buildLamExprWithRecon _ _ _ = undefined

buildForWithRecon
:: (Emits n, ScopableBuilder r m)
=> NameHint -> Direction -> IxType r n
-> (forall l. (Emits l, DExt n l) => AtomName r l -> m l (e l))
-> (forall l1 l2. Binder r n l1 -> Nest (Decl r) l1 l2 -> e l2 -> m l2 (Atom r l2, recon))
-> m n (Atom r n, recon)
buildForWithRecon _ _ _ _ _ = undefined

buildEffLam
:: ScopableBuilder r m
=> RWS -> NameHint -> Type r n
Expand Down Expand Up @@ -1008,10 +1028,10 @@ tangentType ty = maybeTangentType ty >>= \case
Just tanTy -> return tanTy
Nothing -> error $ "can't differentiate wrt type: " ++ pprint ty

maybeTangentType :: EnvReader m => SType n -> m n (Maybe (SType n))
maybeTangentType :: (IRRep r, EnvReader m) => Type r n -> m n (Maybe (Type r n))
maybeTangentType ty = liftEnvReaderT $ maybeTangentType' ty

maybeTangentType' :: SType n -> EnvReaderT Maybe n (SType n)
maybeTangentType' :: IRRep r => Type r n -> EnvReaderT Maybe n (Type r n)
maybeTangentType' ty = case ty of
TabTy b bodyTy -> do
refreshAbs (Abs b bodyTy) \b' bodyTy' -> do
Expand Down Expand Up @@ -1048,6 +1068,25 @@ addTangent x y = do
ty -> notTangent ty
where notTangent ty = error $ "Not a tangent type: " ++ pprint ty

symbolicTangentTy :: (EnvReader m, Fallible1 m) => CType n -> m n (CType n)
symbolicTangentTy elTy = lookupSourceMap "SymbolicTangent" >>= \case
Just (UTyConVar symTanName) -> do
TyConBinding dataDefName _ <- lookupEnv symTanName
return $ TypeCon "SymbolicTangent" dataDefName $
DataDefParams [(PlainArrow, elTy)]
Nothing -> throw UnboundVarErr $
"Can't define a custom linearization with symbolic zeros: " ++
"the SymbolicTangent type is not in scope."
Just _ -> throw TypeErr "SymbolicTangent should name a `data` type"

symbolicTangentZero :: EnvReader m => SType n -> m n (SAtom n)
symbolicTangentZero argTy = return $ SumVal [UnitTy, argTy] 0 UnitVal

symbolicTangentNonZero :: EnvReader m => SAtom n -> m n (SType n)
symbolicTangentNonZero val = do
ty <- getType val
return $ SumVal [UnitTy, ty] 1 val

-- === builder versions of common top-level emissions ===

emitDataDef :: (Mut n, TopBuilder m) => DataDef n -> m n (DataDefName n)
Expand Down Expand Up @@ -1195,7 +1234,7 @@ ieq :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
ieq x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (==) x y
ieq x y = emitOp $ BinOp (ICmp Equal) x y

fromPair :: Builder r m => Atom r n -> m n (Atom r n, Atom r n)
fromPair :: (Fallible1 m, EnvReader m, IRRep r) => Atom r n -> m n (Atom r n, Atom r n)
fromPair pair = do
~[x, y] <- getUnpacked pair
return (x, y)
Expand Down
8 changes: 6 additions & 2 deletions src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -340,13 +340,17 @@ normalizeNaryProj (i:is) x = normalizeProj i =<< normalizeNaryProj is x

-- assumes the atom is already normalized
normalizeProj :: EnvReader m => Projection -> Atom r n -> m n (Atom r n)
normalizeProj UnwrapNewtype = \case
normalizeProj UnwrapNewtype atom = case atom of
NewtypeCon _ x -> return x
x -> return $ ProjectElt UnwrapNewtype x
normalizeProj (ProjectProduct i) = \case
normalizeProj (ProjectProduct i) atom = case atom of
Con (ProdCon xs) -> return $ xs !! i
DepPair l _ _ | i == 0 -> return l
DepPair _ r _ | i == 1 -> return r
SimpInCore (LiftSimp maybeTy x) -> do
maybeTy' <- forM maybeTy \t -> projType i t atom
x' <- normalizeProj (ProjectProduct i) x
return $ SimpInCore $ LiftSimp maybeTy' x'
RepValAtom (RepVal t tree) -> do
t' <- projType i t (RepValAtom (RepVal t tree))
case tree of
Expand Down
2 changes: 1 addition & 1 deletion src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ instance IRRep r => HasType r (Expr r) where
TopApp f xs -> do
NaryPiType bs _ resultTy <- getTypeTopFun =<< renameM f
xs' <- mapM renameM xs
applySubst (bs @@> map SubstVal xs') resultTy
checkedApplyNaryAbs (Abs bs resultTy) xs'
Atom x -> getTypeE x
PrimOp op -> typeCheckPrimOp op
Hof hof -> typeCheckPrimHof hof
Expand Down
9 changes: 4 additions & 5 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,7 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of
f <- substM f'
xs <- mapM substM xs'
lookupTopFun f >>= \case
DexTopFun piTy _ _ ->
emitCall maybeDest piTy f $ toList xs
DexTopFun _ piTy _ _ -> emitCall maybeDest piTy f $ toList xs
FFITopFun _ _ -> do
resultTy <- getType $ TopApp f xs
scalarArgs <- liftM toList $ mapM fromScalarAtom xs
Expand Down Expand Up @@ -1418,8 +1417,8 @@ abstractLinktimeObjects f = do
let allVars = freeVarsE f
(funVars, funTys) <- unzip <$> forMFilter (nameSetToList @TopFunNameC allVars) \v ->
lookupTopFun v >>= \case
DexTopFun ty _ _ -> do
ty' <- getImpFunType StandardCC ty
DexTopFun _ piTy _ _ -> do
ty' <- getImpFunType StandardCC piTy
return $ Just (v, ty')
FFITopFun _ _ -> return Nothing
(ptrVars, ptrTys) <- unzip <$> forMFilter (nameSetToList @PtrNameC allVars) \v -> do
Expand Down Expand Up @@ -1470,7 +1469,7 @@ impInstrTypes instr = case instr of
DebugPrint _ _ -> return []
IQueryParallelism _ _ -> return [IIdxRepTy, IIdxRepTy]
ICall f _ -> lookupTopFun f >>= \case
DexTopFun piTy _ _ -> do
DexTopFun _ piTy _ _ -> do
IFunType _ _ resultTys <- getImpFunType StandardCC piTy
return resultTys
FFITopFun _ (IFunType _ _ resultTys) -> return resultTys
Expand Down
2 changes: 1 addition & 1 deletion src/lib/ImpToLLVM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ compileInstr instr = case instr of
return []
RenameOperandSubstVal v -> do
lookupTopFun v >>= \case
DexTopFun _ _ _ -> error "Imp functions should be abstracted at this point"
DexTopFun _ _ _ _ -> error "Imp functions should be abstracted at this point"
FFITopFun fname ty@(IFunType cc _ impResultTys) -> do
let resultTys = map scalarTy impResultTys
case cc of
Expand Down
91 changes: 32 additions & 59 deletions src/lib/Linearize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Control.Monad.Reader
import Data.Foldable (toList)
import Data.Functor
import Data.List (elemIndex)
import Data.Maybe (catMaybes, isJust)
import qualified Data.Set as S
import GHC.Stack

Expand All @@ -21,6 +22,7 @@ import IRVariants
import MTL1
import Name
import Subst
import {-# SOURCE #-} Simplify (linearizeTopFun)
import PPrint
import QueryType
import Types.Core
Expand All @@ -38,8 +40,8 @@ emptyActivePrimals = ActivePrimals [] Pure

data TangentArgs (n::S) = TangentArgs [SAtomName n]

type PrimalM = SubstReaderT Name (ReaderT1 ActivePrimals (BuilderM SimpIR)) :: MonadKind2
type TangentM = ReaderT1 TangentArgs (BuilderM SimpIR) :: MonadKind1
type PrimalM = SubstReaderT Name (ReaderT1 ActivePrimals (DoubleBuilder SimpIR)) :: MonadKind2
type TangentM = ReaderT1 TangentArgs (DoubleBuilder SimpIR) :: MonadKind1

data WithTangent (n::S) (e1::E) (e2::E) =
WithTangent (e1 n) (forall l. (Emits l, DExt n l) => TangentM l (e2 l))
Expand All @@ -52,7 +54,7 @@ pureLin x = do
x' <- renameM x
return $ WithTangent x' (sinkM x')

runPrimalM :: Subst Name i o -> ActivePrimals o -> PrimalM i o a -> BuilderM SimpIR o a
runPrimalM :: Subst Name i o -> ActivePrimals o -> PrimalM i o a -> DoubleBuilder SimpIR o a
runPrimalM subst args cont = runReaderT1 args $ runSubstReaderT subst cont

activePrimalIdx :: AtomName SimpIR o -> PrimalM i o (Maybe Int)
Expand Down Expand Up @@ -242,33 +244,23 @@ applyLinLamAbsAbs (Abs bsLam (Abs bsRecon (LamExpr bs body))) lamArgs residuals
<.> bsRecon @@> map SubstVal residualss
<.> bs @@> map Rename args) body >>= emitBlock

-- repeat the primal computation in the tangent part (this is ok if the
-- computation is cheap, e.g. the body of a table lambda)
_rematPrimal :: Emits o
=> Subst Name i o -> ActivePrimals o
-> LinM i o e1 e2 -> TangentM o (e2 o)
_rematPrimal subst wrt m = do
WithTangent _ lin <- lift11 $ runPrimalM subst wrt m
Distinct <- getDistinct
lin

-- === actual linearization passs ===

-- main API entrypoint
linearize :: (Emits n, Builder SimpIR m) => LamExpr SimpIR n -> SAtom n -> m n (SAtom n, LamExpr SimpIR n)
linearize f x = liftM fromPairE $ liftEmitBuilder $
linearize :: Emits n => SLam n -> SAtom n -> DoubleBuilder SimpIR n (SAtom n, SLam n)
linearize f x = do
runPrimalM idSubst emptyActivePrimals $
linearizeLambdaApp f x
{-# SCC linearize #-}

-- reify the tangent builder as a lambda
linearizeLambdaApp :: Emits o => LamExpr SimpIR i -> SAtom o -> PrimalM i o (PairE SAtom (LamExpr SimpIR) o)
linearizeLambdaApp :: Emits o => SLam i -> SAtom o -> PrimalM i o (SAtom o, SLam o)
linearizeLambdaApp (UnaryLamExpr b body) x = do
vp <- emitAtomToName noHint x
extendActiveSubst b vp do
WithTangent primalResult tangentAction <- linearizeBlock body
tanFun <- tangentFunAsLambda tangentAction
return $ PairE primalResult tanFun
return (primalResult, tanFun)
linearizeLambdaApp _ _ = error "not implemented"

linearizeAtom :: Emits o => Atom SimpIR i -> LinM i o SAtom SAtom
Expand Down Expand Up @@ -322,7 +314,29 @@ linearizeDecls (Nest (Let b (DeclBinding ann _ expr)) rest) cont = do
linearizeExpr :: Emits o => SExpr i -> LinM i o SAtom SAtom
linearizeExpr expr = case expr of
Atom x -> linearizeAtom x
TopApp _ _ -> error "not implemented"
TopApp f xs -> do
(xs', ts) <- unzip <$> forM xs \x -> do
x' <- renameM x
isActive x' >>= \case
True -> do
WithTangent x'' t <- dropSubst $ linearizeAtom x'
return (x'', Just (WithTangent (unitLike x'') t))
False -> return (x', Nothing)
f' <- renameM f
-- TODO(dougalm): this works, but I think that what we really want here is
-- to hoist the argument to `linearizeTopFun`, rather than the result. We
-- want to pop all the way up to the top level, hoisting the E-kinded
-- `LinearizationSpec` with us, rather than working underneath all the local
-- bindings and then only hoisting the final result.
Just (PairE fPrimal fTan) <- liftTopBuilderAndEmit $
liftM toPairE $ linearizeTopFun (sink $ LinearizationSpec f' (map isJust ts))
(ans, residuals) <- fromPair =<< naryTopApp fPrimal xs'
return $ WithTangent ans do
ts' <- forM (catMaybes ts) \(WithTangent UnitE t) -> t
naryTopApp (sink fTan) (sinkList xs' ++ [sink residuals] ++ ts')
where
unitLike :: e n -> UnitE n
unitLike _ = UnitE
TabApp x idxs -> do
zipLin (linearizeAtom x) (pureLin $ ListE $ toList idxs) `bindLin`
\(PairE x' (ListE idxs')) -> naryTabApp x' idxs'
Expand All @@ -337,7 +351,6 @@ linearizeExpr expr = case expr of
MGet -> linearizeAtom ref `bindLin` \ref' -> liftM Var $ emit $ RefOp ref' MGet
MPut x -> zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') ->
liftM Var $ emit $ RefOp ref' $ MPut x'

IndexRef i -> zipLin (la ref) (pureLin i) `bindLin`
\(PairE ref' i') -> emitExpr $ RefOp ref' $ IndexRef i'
ProjRef i -> la ref `bindLin` \ref' -> emitExpr $ RefOp ref' $ ProjRef i
Expand Down Expand Up @@ -580,46 +593,6 @@ linearizeHof hof = case hof of
return $ WithTangent ans $ applyLinLamAbs (sink linLam) (sink residuals)
_ -> error $ "not implemented: " ++ pprint hof

_applyCustomLinearization :: Emits o => AtomRules o -> [SAtom i] -> LinM i o SAtom SAtom
_applyCustomLinearization = undefined
-- applyCustomLinearization (CustomLinearize n zeros cl) xs = do
-- let (polyXs, argXs) = splitAt n $ toList xs
-- polyXs' <- mapM (renameM . injectCore) polyXs
-- (any id <$> mapM isActive polyXs') >>= \case
-- True -> error $
-- "Polymorphic arguments of custom linearization rules are " ++
-- "expected to be inactive (i.e. independent of any differentiated " ++
-- "function argument)"
-- False -> return ()
-- wts <- case zeros of
-- InstantiateZeros -> forM (toList argXs) linearizeAtom
-- SymbolicZeros -> do
-- stDefName <- lookupSourceMap "ZeroTangent" >>= \case
-- Just (UDataConVar conName) -> do
-- DataConBinding dataDefName zeroConIx _ <- lookupEnv conName
-- unless (zeroConIx == 0) $ error "Ill-defined SymbolicTangent?"
-- return dataDefName
-- _ -> error "Ill-defined SymbolicTangent?"
-- forM (toList argXs) \arg -> do
-- arg' <- renameM arg
-- argTy' <- getType arg'
-- isActive arg' >>= \case
-- False -> -- Pass in ZeroTangent as the tangent
-- return $ WithTangent (injectCore arg') $
-- return $ sink $ Con $ Newtype
-- (TypeCon "SymbolicTangent" stDefName
-- (DataDefParams [(PlainArrow, injectCore argTy')]))
-- (SumVal [UnitTy, injectCore argTy'] 0 UnitVal)
-- True -> do -- Wrap tangent in SomeTangent
-- WithTangent arg'' argLin <- dropSubst $ linearizeAtom arg'
-- return $ WithTangent arg'' $ argLin <&> \argTan ->
-- Con $ Newtype
-- (TypeCon "SymbolicTangent" (sink stDefName)
-- (DataDefParams [(PlainArrow, sink (injectIRE argTy'))]))
-- (SumVal [UnitTy, sink (injectIRE argTy')] 1 argTan)
-- (ans, flin) <- fromPair =<< naryApp cl (polyXs' ++ (wts <&> \(WithTangent p _) -> p))
-- return $ WithTangent ans $ naryApp (sink flin) =<< sequence (wts <&> \(WithTangent _ t) -> t)

linearizeEffectFun :: RWS -> SLam i -> PrimalM i o (SLam o, LinLamAbs o)
linearizeEffectFun rws (BinaryLamExpr hB refB body) = do
eff <- getAllowedEffects
Expand Down
Loading

0 comments on commit c996e98

Please sign in to comment.