Skip to content

Commit

Permalink
Simplify tactics state structure (#1449)
Browse files Browse the repository at this point in the history
* Store total number of arguments in TopLevelArgPrv

* Re-enable tracing other solutions

* Document a bug I ran into while trying to dogfood

* Replace the janky (Trace,) extract with something more principled

* Add explicit constructors for introducing hypotheses

* Split apart creation of hypothesis from introduction of it

* Produce better debug output for other solns

* Track all the bindings that get generated in Synthesized

* Remove ts_intro_vals; use synthesized bindings instead

* Track used variables in the Synthesis

* Remove a debug trace

* Add some documentation about what's happening here.

* Minor tidying

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
isovector and mergify[bot] committed Feb 27, 2021
1 parent 7d416f0 commit f4a9671
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 206 deletions.
2 changes: 2 additions & 0 deletions plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic.hs
Expand Up @@ -23,6 +23,7 @@ import Data.Aeson
import Data.Bifunctor (Bifunctor (bimap))
import Data.Bool (bool)
import Data.Data (Data)
import Data.Foldable (for_)
import Data.Generics.Aliases (mkQ)
import Data.Generics.Schemes (everything)
import Data.Maybe
Expand Down Expand Up @@ -144,6 +145,7 @@ mkWorkspaceEdits
-> RunTacticResults
-> Either ResponseError (Maybe WorkspaceEdit)
mkWorkspaceEdits span dflags ccs uri pm rtr = do
for_ (rtr_other_solns rtr) $ traceMX "other solution"
let g = graftHole (RealSrcSpan span) rtr
response = transform dflags ccs uri g pm
in case response of
Expand Down
75 changes: 31 additions & 44 deletions plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs
Expand Up @@ -7,14 +7,11 @@ module Ide.Plugin.Tactic.CodeGen
, module Ide.Plugin.Tactic.CodeGen.Utils
) where

import Control.Lens ((%~), (+~), (<>~))
import Control.Lens ((+~))
import Control.Monad.Except
import Control.Monad.State (MonadState)
import Control.Monad.State.Class (modify)
import Data.Generics.Product (field)
import Data.Generics.Product (field)
import Data.List
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Set as S
import Data.Traversable
import DataCon
import Development.IDE.GHC.Compat
Expand All @@ -29,32 +26,16 @@ import Ide.Plugin.Tactic.Judgements
import Ide.Plugin.Tactic.Machinery
import Ide.Plugin.Tactic.Naming
import Ide.Plugin.Tactic.Types
import Type hiding (Var)
import Type hiding (Var)


useOccName :: MonadState TacticState m => Judgement -> OccName -> m ()
useOccName jdg name =
-- Only score points if this is in the local hypothesis
case M.lookup name $ hyByName $ jLocalHypothesis jdg of
Just{} -> modify
$ (withUsedVals $ S.insert name)
. (field @"ts_unused_top_vals" %~ S.delete name)
Nothing -> pure ()


------------------------------------------------------------------------------
-- | Doing recursion incurs a small penalty in the score.
countRecursiveCall :: TacticState -> TacticState
countRecursiveCall = field @"ts_recursion_count" +~ 1


------------------------------------------------------------------------------
-- | Insert some values into the unused top values field. These are
-- subsequently removed via 'useOccName'.
addUnusedTopVals :: MonadState TacticState m => S.Set OccName -> m ()
addUnusedTopVals vals = modify $ field @"ts_unused_top_vals" <>~ vals


destructMatches
:: (DataCon -> Judgement -> Rule)
-- ^ How to construct each match
Expand All @@ -63,7 +44,7 @@ destructMatches
-> CType
-- ^ Type being destructed
-> Judgement
-> RuleM (Trace, [RawMatch])
-> RuleM (Synthesized [RawMatch])
destructMatches f scrut t jdg = do
let hy = jEntireHypothesis jdg
g = jGoal jdg
Expand All @@ -76,16 +57,21 @@ destructMatches f scrut t jdg = do
_ -> fmap unzipTrace $ for dcs $ \dc -> do
let args = dataConInstOrigArgTys' dc apps
names <- mkManyGoodNames (hyNamesInScope hy) args
let hy' = zip names $ coerce args
j = introducingPat scrut dc hy'
let hy' = patternHypothesis scrut dc jdg
$ zip names
$ coerce args
j = introduce hy'
$ withNewGoal g jdg
(tr, sg) <- f dc j
modify $ withIntroducedVals $ mappend $ S.fromList names
pure ( rose ("match " <> show dc <> " {" <>
Synthesized tr sc uv sg <- f dc j
pure
$ Synthesized
( rose ("match " <> show dc <> " {" <>
intercalate ", " (fmap show names) <> "}")
$ pure tr
, match [mkDestructPat dc names] $ unLoc sg
)
$ pure tr)
(sc <> hy')
uv
$ match [mkDestructPat dc names]
$ unLoc sg


------------------------------------------------------------------------------
Expand Down Expand Up @@ -114,10 +100,8 @@ infixifyPatIfNecessary dcon x



unzipTrace :: [(Trace, a)] -> (Trace, [a])
unzipTrace l =
let (trs, as) = unzip l
in (rose mempty trs, as)
unzipTrace :: [Synthesized a] -> Synthesized [a]
unzipTrace = sequenceA


-- | Essentially same as 'dataConInstOrigArgTys' in GHC,
Expand Down Expand Up @@ -154,16 +138,19 @@ destruct' :: (DataCon -> Judgement -> Rule) -> HyInfo CType -> Judgement -> Rule
destruct' f hi jdg = do
when (isDestructBlacklisted jdg) $ throwError NoApplicableTactic
let term = hi_name hi
useOccName jdg term
(tr, ms)
Synthesized tr sc uv ms
<- destructMatches
f
(Just term)
(hi_type hi)
$ disallowing AlreadyDestructed [term] jdg
pure ( rose ("destruct " <> show term) $ pure tr
, noLoc $ case' (var' term) ms
)
pure
$ Synthesized
(rose ("destruct " <> show term) $ pure tr)
sc
(S.insert term uv)
$ noLoc
$ case' (var' term) ms


------------------------------------------------------------------------------
Expand All @@ -186,10 +173,10 @@ buildDataCon
:: Judgement
-> DataCon -- ^ The data con to build
-> [Type] -- ^ Type arguments for the data con
-> RuleM (Trace, LHsExpr GhcPs)
-> RuleM (Synthesized (LHsExpr GhcPs))
buildDataCon jdg dc tyapps = do
let args = dataConInstOrigArgTys' dc tyapps
(tr, sgs)
Synthesized tr sc uv sgs
<- fmap unzipTrace
$ traverse ( \(arg, n) ->
newSubgoal
Expand All @@ -199,6 +186,6 @@ buildDataCon jdg dc tyapps = do
$ CType arg
) $ zip args [0..]
pure
. (rose (show dc) $ pure tr,)
$ Synthesized (rose (show dc) $ pure tr) sc uv
$ mkCon dc sgs

90 changes: 32 additions & 58 deletions plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Judgements.hs
@@ -1,33 +1,7 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}

module Ide.Plugin.Tactic.Judgements
( blacklistingDestruct
, unwhitelistingSplit
, introducingLambda
, introducingRecursively
, introducingPat
, jGoal
, jHypothesis
, jEntireHypothesis
, jPatHypothesis
, substJdg
, unsetIsTopHole
, filterSameTypeFromOtherPositions
, isDestructBlacklisted
, withNewGoal
, jLocalHypothesis
, isSplitWhitelisted
, isPatternMatch
, filterPosition
, isTopHole
, disallowing
, mkFirstJudgement
, hypothesisFromBindings
, isTopLevel
, hyNamesInScope
, hyByName
) where
module Ide.Plugin.Tactic.Judgements where

import Control.Arrow
import Control.Lens hiding (Context)
Expand Down Expand Up @@ -89,35 +63,39 @@ withNewGoal :: a -> Judgement' a -> Judgement' a
withNewGoal t = field @"_jGoal" .~ t


introduce :: Hypothesis a -> Judgement' a -> Judgement' a
introduce hy = field @"_jHypothesis" <>~ hy


------------------------------------------------------------------------------
-- | Helper function for implementing functions which introduce new hypotheses.
introducing
:: (Int -> Provenance) -- ^ A function from the position of the arg to its
-- provenance.
introduceHypothesis
:: (Int -> Int -> Provenance)
-- ^ A function from the total number of args and position of this arg
-- to its provenance.
-> [(OccName, a)]
-> Judgement' a
-> Judgement' a
introducing f ns =
field @"_jHypothesis" <>~ (Hypothesis $ zip [0..] ns <&>
\(pos, (name, ty)) -> HyInfo name (f pos) ty)
-> Hypothesis a
introduceHypothesis f ns =
Hypothesis $ zip [0..] ns <&> \(pos, (name, ty)) ->
HyInfo name (f (length ns) pos) ty


------------------------------------------------------------------------------
-- | Introduce bindings in the context of a lamba.
introducingLambda
lambdaHypothesis
:: Maybe OccName -- ^ The name of the top level function. For any other
-- function, this should be 'Nothing'.
-> [(OccName, a)]
-> Judgement' a
-> Judgement' a
introducingLambda func = introducing $ \pos ->
maybe UserPrv (\x -> TopLevelArgPrv x pos) func
-> Hypothesis a
lambdaHypothesis func =
introduceHypothesis $ \count pos ->
maybe UserPrv (\x -> TopLevelArgPrv x pos count) func


------------------------------------------------------------------------------
-- | Introduce a binding in a recursive context.
introducingRecursively :: [(OccName, a)] -> Judgement' a -> Judgement' a
introducingRecursively = introducing $ const RecursivePrv
recursiveHypothesis :: [(OccName, a)] -> Hypothesis a
recursiveHypothesis = introduceHypothesis $ const $ const RecursivePrv


------------------------------------------------------------------------------
Expand Down Expand Up @@ -176,7 +154,7 @@ findPositionVal jdg defn pos = listToMaybe $ do
-- ancstry through potentially disallowed terms in the hypothesis.
(name, hi) <- M.toList $ M.map (overProvenance expandDisallowed) $ hyByName $ jEntireHypothesis jdg
case hi_provenance hi of
TopLevelArgPrv defn' pos'
TopLevelArgPrv defn' pos' _
| defn == defn'
, pos == pos' -> pure name
PatternMatchPrv pv
Expand Down Expand Up @@ -243,26 +221,22 @@ extremelyStupid__definingFunction =
fst . head . ctxDefiningFuncs


------------------------------------------------------------------------------
-- | Pattern vals are currently tracked in jHypothesis, with an extra piece of
-- data sitting around in jPatternVals.
introducingPat
patternHypothesis
:: Maybe OccName
-> DataCon
-> [(OccName, a)]
-> Judgement' a
-> Judgement' a
introducingPat scrutinee dc ns jdg
= introducing (\pos ->
-> [(OccName, a)]
-> Hypothesis a
patternHypothesis scrutinee dc jdg
= introduceHypothesis $ \_ pos ->
PatternMatchPrv $
PatVal
scrutinee
(maybe mempty
(\scrut -> S.singleton scrut <> getAncestry jdg scrut)
scrutinee)
(Uniquely dc)
pos
) ns jdg
scrutinee
(maybe mempty
(\scrut -> S.singleton scrut <> getAncestry jdg scrut)
scrutinee)
(Uniquely dc)
pos


------------------------------------------------------------------------------
Expand Down
Expand Up @@ -44,8 +44,13 @@ deriveArbitrary = do
terminal_expr = mkVal "terminal"
oneof_expr = mkVal "oneof"
pure
( tracePrim "deriveArbitrary"
, noLoc $
$ Synthesized (tracePrim "deriveArbitrary")
-- TODO(sandy): This thing is not actually empty! We produced
-- a bespoke binding "terminal", and a not-so-bespoke "n".
-- But maybe it's fine for known rules?
mempty
mempty
$ noLoc $
let' [valBind (fromString "terminal") $ list $ fmap genExpr terminal] $
appDollar (mkFunc "sized") $ lambda [bvar' (mkVarOcc "n")] $
case' (infixCall "<=" (mkVal "n") (int 1))
Expand All @@ -57,7 +62,6 @@ deriveArbitrary = do
(list $ fmap genExpr big)
terminal_expr
]
)
_ -> throwError $ GoalMismatch "deriveArbitrary" ty


Expand Down
Expand Up @@ -197,7 +197,8 @@ getRhsPosVals rss tcs
, isHole $ occName hole -- and the span is a hole
-> First $ do
patnames <- traverse getPatName ps
pure $ zip patnames $ [0..] <&> TopLevelArgPrv name
pure $ zip patnames $ [0..] <&> \n ->
TopLevelArgPrv name n (length patnames)
_ -> mempty
) tcs

Expand Down

0 comments on commit f4a9671

Please sign in to comment.