From f4a967190d42c3b9ce47ffd6565255f9b58eb564 Mon Sep 17 00:00:00 2001 From: Sandy Maguire Date: Sat, 27 Feb 2021 09:09:10 -0800 Subject: [PATCH] Simplify tactics state structure (#1449) * Store total number of arguments in TopLevelArgPrv * Re-enable tracing other solutions * Document a bug I ran into while trying to dogfood * Replace the janky (Trace,) extract with something more principled * Add explicit constructors for introducing hypotheses * Split apart creation of hypothesis from introduction of it * Produce better debug output for other solns * Track all the bindings that get generated in Synthesized * Remove ts_intro_vals; use synthesized bindings instead * Track used variables in the Synthesis * Remove a debug trace * Add some documentation about what's happening here. * Minor tidying Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- .../src/Ide/Plugin/Tactic.hs | 2 + .../src/Ide/Plugin/Tactic/CodeGen.hs | 75 ++++------ .../src/Ide/Plugin/Tactic/Judgements.hs | 90 +++++------- .../Tactic/KnownStrategies/QuickCheck.hs | 10 +- .../src/Ide/Plugin/Tactic/LanguageServer.hs | 3 +- .../src/Ide/Plugin/Tactic/Machinery.hs | 40 +++--- .../src/Ide/Plugin/Tactic/Tactics.hs | 65 +++++---- .../src/Ide/Plugin/Tactic/Types.hs | 133 +++++++++++------- 8 files changed, 212 insertions(+), 206 deletions(-) diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic.hs index 1a909649e6..3fb774d4d1 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic.hs @@ -23,6 +23,7 @@ import Data.Aeson import Data.Bifunctor (Bifunctor (bimap)) import Data.Bool (bool) import Data.Data (Data) +import Data.Foldable (for_) import Data.Generics.Aliases (mkQ) import Data.Generics.Schemes (everything) import Data.Maybe @@ -144,6 +145,7 @@ mkWorkspaceEdits -> RunTacticResults -> Either ResponseError (Maybe WorkspaceEdit) mkWorkspaceEdits span dflags ccs uri pm rtr = do + for_ (rtr_other_solns rtr) $ traceMX "other solution" let g = graftHole (RealSrcSpan span) rtr response = transform dflags ccs uri g pm in case response of diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs index e3959a629b..c73b6090ff 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs @@ -7,14 +7,11 @@ module Ide.Plugin.Tactic.CodeGen , module Ide.Plugin.Tactic.CodeGen.Utils ) where -import Control.Lens ((%~), (+~), (<>~)) +import Control.Lens ((+~)) import Control.Monad.Except -import Control.Monad.State (MonadState) -import Control.Monad.State.Class (modify) -import Data.Generics.Product (field) +import Data.Generics.Product (field) import Data.List -import qualified Data.Map as M -import qualified Data.Set as S +import qualified Data.Set as S import Data.Traversable import DataCon import Development.IDE.GHC.Compat @@ -29,18 +26,9 @@ import Ide.Plugin.Tactic.Judgements import Ide.Plugin.Tactic.Machinery import Ide.Plugin.Tactic.Naming import Ide.Plugin.Tactic.Types -import Type hiding (Var) +import Type hiding (Var) -useOccName :: MonadState TacticState m => Judgement -> OccName -> m () -useOccName jdg name = - -- Only score points if this is in the local hypothesis - case M.lookup name $ hyByName $ jLocalHypothesis jdg of - Just{} -> modify - $ (withUsedVals $ S.insert name) - . (field @"ts_unused_top_vals" %~ S.delete name) - Nothing -> pure () - ------------------------------------------------------------------------------ -- | Doing recursion incurs a small penalty in the score. @@ -48,13 +36,6 @@ countRecursiveCall :: TacticState -> TacticState countRecursiveCall = field @"ts_recursion_count" +~ 1 ------------------------------------------------------------------------------- --- | Insert some values into the unused top values field. These are --- subsequently removed via 'useOccName'. -addUnusedTopVals :: MonadState TacticState m => S.Set OccName -> m () -addUnusedTopVals vals = modify $ field @"ts_unused_top_vals" <>~ vals - - destructMatches :: (DataCon -> Judgement -> Rule) -- ^ How to construct each match @@ -63,7 +44,7 @@ destructMatches -> CType -- ^ Type being destructed -> Judgement - -> RuleM (Trace, [RawMatch]) + -> RuleM (Synthesized [RawMatch]) destructMatches f scrut t jdg = do let hy = jEntireHypothesis jdg g = jGoal jdg @@ -76,16 +57,21 @@ destructMatches f scrut t jdg = do _ -> fmap unzipTrace $ for dcs $ \dc -> do let args = dataConInstOrigArgTys' dc apps names <- mkManyGoodNames (hyNamesInScope hy) args - let hy' = zip names $ coerce args - j = introducingPat scrut dc hy' + let hy' = patternHypothesis scrut dc jdg + $ zip names + $ coerce args + j = introduce hy' $ withNewGoal g jdg - (tr, sg) <- f dc j - modify $ withIntroducedVals $ mappend $ S.fromList names - pure ( rose ("match " <> show dc <> " {" <> + Synthesized tr sc uv sg <- f dc j + pure + $ Synthesized + ( rose ("match " <> show dc <> " {" <> intercalate ", " (fmap show names) <> "}") - $ pure tr - , match [mkDestructPat dc names] $ unLoc sg - ) + $ pure tr) + (sc <> hy') + uv + $ match [mkDestructPat dc names] + $ unLoc sg ------------------------------------------------------------------------------ @@ -114,10 +100,8 @@ infixifyPatIfNecessary dcon x -unzipTrace :: [(Trace, a)] -> (Trace, [a]) -unzipTrace l = - let (trs, as) = unzip l - in (rose mempty trs, as) +unzipTrace :: [Synthesized a] -> Synthesized [a] +unzipTrace = sequenceA -- | Essentially same as 'dataConInstOrigArgTys' in GHC, @@ -154,16 +138,19 @@ destruct' :: (DataCon -> Judgement -> Rule) -> HyInfo CType -> Judgement -> Rule destruct' f hi jdg = do when (isDestructBlacklisted jdg) $ throwError NoApplicableTactic let term = hi_name hi - useOccName jdg term - (tr, ms) + Synthesized tr sc uv ms <- destructMatches f (Just term) (hi_type hi) $ disallowing AlreadyDestructed [term] jdg - pure ( rose ("destruct " <> show term) $ pure tr - , noLoc $ case' (var' term) ms - ) + pure + $ Synthesized + (rose ("destruct " <> show term) $ pure tr) + sc + (S.insert term uv) + $ noLoc + $ case' (var' term) ms ------------------------------------------------------------------------------ @@ -186,10 +173,10 @@ buildDataCon :: Judgement -> DataCon -- ^ The data con to build -> [Type] -- ^ Type arguments for the data con - -> RuleM (Trace, LHsExpr GhcPs) + -> RuleM (Synthesized (LHsExpr GhcPs)) buildDataCon jdg dc tyapps = do let args = dataConInstOrigArgTys' dc tyapps - (tr, sgs) + Synthesized tr sc uv sgs <- fmap unzipTrace $ traverse ( \(arg, n) -> newSubgoal @@ -199,6 +186,6 @@ buildDataCon jdg dc tyapps = do $ CType arg ) $ zip args [0..] pure - . (rose (show dc) $ pure tr,) + $ Synthesized (rose (show dc) $ pure tr) sc uv $ mkCon dc sgs diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Judgements.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Judgements.hs index f2d830052a..c865f53650 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Judgements.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Judgements.hs @@ -1,33 +1,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} -module Ide.Plugin.Tactic.Judgements - ( blacklistingDestruct - , unwhitelistingSplit - , introducingLambda - , introducingRecursively - , introducingPat - , jGoal - , jHypothesis - , jEntireHypothesis - , jPatHypothesis - , substJdg - , unsetIsTopHole - , filterSameTypeFromOtherPositions - , isDestructBlacklisted - , withNewGoal - , jLocalHypothesis - , isSplitWhitelisted - , isPatternMatch - , filterPosition - , isTopHole - , disallowing - , mkFirstJudgement - , hypothesisFromBindings - , isTopLevel - , hyNamesInScope - , hyByName - ) where +module Ide.Plugin.Tactic.Judgements where import Control.Arrow import Control.Lens hiding (Context) @@ -89,35 +63,39 @@ withNewGoal :: a -> Judgement' a -> Judgement' a withNewGoal t = field @"_jGoal" .~ t +introduce :: Hypothesis a -> Judgement' a -> Judgement' a +introduce hy = field @"_jHypothesis" <>~ hy + + ------------------------------------------------------------------------------ -- | Helper function for implementing functions which introduce new hypotheses. -introducing - :: (Int -> Provenance) -- ^ A function from the position of the arg to its - -- provenance. +introduceHypothesis + :: (Int -> Int -> Provenance) + -- ^ A function from the total number of args and position of this arg + -- to its provenance. -> [(OccName, a)] - -> Judgement' a - -> Judgement' a -introducing f ns = - field @"_jHypothesis" <>~ (Hypothesis $ zip [0..] ns <&> - \(pos, (name, ty)) -> HyInfo name (f pos) ty) + -> Hypothesis a +introduceHypothesis f ns = + Hypothesis $ zip [0..] ns <&> \(pos, (name, ty)) -> + HyInfo name (f (length ns) pos) ty ------------------------------------------------------------------------------ -- | Introduce bindings in the context of a lamba. -introducingLambda +lambdaHypothesis :: Maybe OccName -- ^ The name of the top level function. For any other -- function, this should be 'Nothing'. -> [(OccName, a)] - -> Judgement' a - -> Judgement' a -introducingLambda func = introducing $ \pos -> - maybe UserPrv (\x -> TopLevelArgPrv x pos) func + -> Hypothesis a +lambdaHypothesis func = + introduceHypothesis $ \count pos -> + maybe UserPrv (\x -> TopLevelArgPrv x pos count) func ------------------------------------------------------------------------------ -- | Introduce a binding in a recursive context. -introducingRecursively :: [(OccName, a)] -> Judgement' a -> Judgement' a -introducingRecursively = introducing $ const RecursivePrv +recursiveHypothesis :: [(OccName, a)] -> Hypothesis a +recursiveHypothesis = introduceHypothesis $ const $ const RecursivePrv ------------------------------------------------------------------------------ @@ -176,7 +154,7 @@ findPositionVal jdg defn pos = listToMaybe $ do -- ancstry through potentially disallowed terms in the hypothesis. (name, hi) <- M.toList $ M.map (overProvenance expandDisallowed) $ hyByName $ jEntireHypothesis jdg case hi_provenance hi of - TopLevelArgPrv defn' pos' + TopLevelArgPrv defn' pos' _ | defn == defn' , pos == pos' -> pure name PatternMatchPrv pv @@ -243,26 +221,22 @@ extremelyStupid__definingFunction = fst . head . ctxDefiningFuncs ------------------------------------------------------------------------------- --- | Pattern vals are currently tracked in jHypothesis, with an extra piece of --- data sitting around in jPatternVals. -introducingPat +patternHypothesis :: Maybe OccName -> DataCon - -> [(OccName, a)] -> Judgement' a - -> Judgement' a -introducingPat scrutinee dc ns jdg - = introducing (\pos -> + -> [(OccName, a)] + -> Hypothesis a +patternHypothesis scrutinee dc jdg + = introduceHypothesis $ \_ pos -> PatternMatchPrv $ PatVal - scrutinee - (maybe mempty - (\scrut -> S.singleton scrut <> getAncestry jdg scrut) - scrutinee) - (Uniquely dc) - pos - ) ns jdg + scrutinee + (maybe mempty + (\scrut -> S.singleton scrut <> getAncestry jdg scrut) + scrutinee) + (Uniquely dc) + pos ------------------------------------------------------------------------------ diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/KnownStrategies/QuickCheck.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/KnownStrategies/QuickCheck.hs index 3fe2263995..25ba3b0832 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/KnownStrategies/QuickCheck.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/KnownStrategies/QuickCheck.hs @@ -44,8 +44,13 @@ deriveArbitrary = do terminal_expr = mkVal "terminal" oneof_expr = mkVal "oneof" pure - ( tracePrim "deriveArbitrary" - , noLoc $ + $ Synthesized (tracePrim "deriveArbitrary") + -- TODO(sandy): This thing is not actually empty! We produced + -- a bespoke binding "terminal", and a not-so-bespoke "n". + -- But maybe it's fine for known rules? + mempty + mempty + $ noLoc $ let' [valBind (fromString "terminal") $ list $ fmap genExpr terminal] $ appDollar (mkFunc "sized") $ lambda [bvar' (mkVarOcc "n")] $ case' (infixCall "<=" (mkVal "n") (int 1)) @@ -57,7 +62,6 @@ deriveArbitrary = do (list $ fmap genExpr big) terminal_expr ] - ) _ -> throwError $ GoalMismatch "deriveArbitrary" ty diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/LanguageServer.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/LanguageServer.hs index 1c672116e1..4fcbdabd5a 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/LanguageServer.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/LanguageServer.hs @@ -197,7 +197,8 @@ getRhsPosVals rss tcs , isHole $ occName hole -- and the span is a hole -> First $ do patnames <- traverse getPatName ps - pure $ zip patnames $ [0..] <&> TopLevelArgPrv name + pure $ zip patnames $ [0..] <&> \n -> + TopLevelArgPrv name n (length patnames) _ -> mempty ) tcs diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs index ac1b18463d..bdaa0aa77f 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs @@ -77,14 +77,9 @@ runTactic ctx jdg t = $ toList $ hyByName $ jHypothesis jdg - unused_topvals = M.keysSet - $ M.filter (isTopLevel . hi_provenance) - $ hyByName - $ jHypothesis jdg tacticState = defaultTacticState { ts_skolems = skolems - , ts_unused_top_vals = unused_topvals } in case partitionEithers . flip runReader ctx @@ -93,14 +88,14 @@ runTactic ctx jdg t = (errs, []) -> Left $ take 50 errs (_, fmap assoc23 -> solns) -> do let sorted = - flip sortBy solns $ comparing $ \((_, ext), (jdg, holes)) -> + flip sortBy solns $ comparing $ \(ext, (jdg, holes)) -> Down $ scoreSolution ext jdg holes case sorted of - (((tr, ext), _) : _) -> + ((syn, _) : _) -> Right $ RunTacticResults - { rtr_trace = tr - , rtr_extract = simplify ext + { rtr_trace = syn_trace syn + , rtr_extract = simplify $ syn_val syn , rtr_other_solns = reverse . fmap fst $ take 5 sorted , rtr_jdg = jdg , rtr_ctx = ctx @@ -119,11 +114,11 @@ tracePrim = flip rose [] tracing :: Functor m => String - -> TacticT jdg (Trace, ext) err s m a - -> TacticT jdg (Trace, ext) err s m a + -> TacticT jdg (Synthesized ext) err s m a + -> TacticT jdg (Synthesized ext) err s m a tracing s (TacticT m) = TacticT $ StateT $ \jdg -> - mapExtract' (first $ rose s . pure) $ runStateT m jdg + mapExtract' (mapTrace $ rose s . pure) $ runStateT m jdg ------------------------------------------------------------------------------ @@ -155,10 +150,10 @@ markStructuralySmallerRecursion pv = do -- | Given the results of running a tactic, score the solutions by -- desirability. -- --- TODO(sandy): This function is completely unprincipled and was just hacked --- together to produce the right test results. +-- NOTE: This function is completely unprincipled and was just hacked together +-- to produce the right test results. scoreSolution - :: LHsExpr GhcPs + :: Synthesized (LHsExpr GhcPs) -> TacticState -> [Judgement] -> ( Penalize Int -- number of holes @@ -171,13 +166,18 @@ scoreSolution ) scoreSolution ext TacticState{..} holes = ( Penalize $ length holes - , Reward $ S.null $ ts_intro_vals S.\\ ts_used_vals - , Penalize $ S.size ts_unused_top_vals - , Penalize $ S.size ts_intro_vals - , Reward $ S.size ts_used_vals + , Reward $ S.null $ intro_vals S.\\ used_vals + , Penalize $ S.size unused_top_vals + , Penalize $ S.size intro_vals + , Reward $ S.size used_vals , Penalize ts_recursion_count - , Penalize $ solutionSize ext + , Penalize $ solutionSize $ syn_val ext ) + where + intro_vals = M.keysSet $ hyByName $ syn_scoped ext + used_vals = S.intersection intro_vals $ syn_used_vals ext + top_vals = S.fromList . fmap hi_name . filter (isTopLevel . hi_provenance) $ unHypothesis $ syn_scoped ext + unused_top_vals = top_vals S.\\ used_vals ------------------------------------------------------------------------------ diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Tactics.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Tactics.hs index 0e3c99d016..ae1eda428c 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Tactics.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Tactics.hs @@ -9,18 +9,16 @@ module Ide.Plugin.Tactic.Tactics , runTactic ) where -import Control.Monad (when) -import Control.Monad.Except (throwError) -import Control.Monad.Reader.Class (MonadReader (ask)) +import Control.Monad.Except (throwError) +import Control.Monad.Reader.Class (MonadReader (ask)) import Control.Monad.State.Class -import Control.Monad.State.Strict (StateT (..), runStateT) -import Data.Bool (bool) +import Control.Monad.State.Strict (StateT(..), runStateT) import Data.Foldable import Data.List -import qualified Data.Map as M +import qualified Data.Map as M import Data.Maybe -import Data.Set (Set) -import qualified Data.Set as S +import Data.Set (Set) +import qualified Data.Set as S import DataCon import Development.IDE.GHC.Compat import GHC.Exts @@ -33,11 +31,11 @@ import Ide.Plugin.Tactic.Judgements import Ide.Plugin.Tactic.Machinery import Ide.Plugin.Tactic.Naming import Ide.Plugin.Tactic.Types -import Name (occNameString) +import Name (occNameString) import Refinery.Tactic import Refinery.Tactic.Internal import TcType -import Type hiding (Var) +import Type hiding (Var) ------------------------------------------------------------------------------ @@ -50,13 +48,15 @@ assumption = attemptOn (S.toList . allNames) assume -- | Use something named in the hypothesis to fill the hole. assume :: OccName -> TacticsM () assume name = rule $ \jdg -> do - let g = jGoal jdg case M.lookup name $ hyByName $ jHypothesis jdg of Just (hi_type -> ty) -> do unify ty $ jGoal jdg for_ (M.lookup name $ jPatHypothesis jdg) markStructuralySmallerRecursion - useOccName jdg name - pure $ (tracePrim $ "assume " <> occNameString name, ) $ noLoc $ var' name + pure $ Synthesized (tracePrim $ "assume " <> occNameString name) + mempty + (S.singleton name) + $ noLoc + $ var' name Nothing -> throwError $ UndefinedHypothesis name @@ -64,9 +64,14 @@ recursion :: TacticsM () recursion = requireConcreteHole $ tracing "recursion" $ do defs <- getCurrentDefinitions attemptOn (const defs) $ \(name, ty) -> do + -- TODO(sandy): When we can inspect the extract of a TacticsM bind + -- (requires refinery support), this recursion stack stuff is unnecessary. + -- We can just inspect the extract to see i we used any pattern vals, and + -- then be on our merry way. modify $ pushRecursionStack . countRecursiveCall ensure guardStructurallySmallerRecursion popRecursionStack $ do - (localTactic (apply $ HyInfo name RecursivePrv ty) $ introducingRecursively defs) + let hy' = recursiveHypothesis defs + localTactic (apply $ HyInfo name RecursivePrv ty) (introduce hy') <@> fmap (localTactic assumption . filterPosition name) [0..] @@ -74,21 +79,22 @@ recursion = requireConcreteHole $ tracing "recursion" $ do -- | Introduce a lambda binding every variable. intros :: TacticsM () intros = rule $ \jdg -> do - let hy = jHypothesis jdg - g = jGoal jdg + let g = jGoal jdg ctx <- ask case tcSplitFunTys $ unCType g of ([], _) -> throwError $ GoalMismatch "intros" g (as, b) -> do vs <- mkManyGoodNames (hyNamesInScope $ jEntireHypothesis jdg) as let top_hole = isTopHole ctx jdg - jdg' = introducingLambda top_hole (zip vs $ coerce as) + hy' = lambdaHypothesis top_hole $ zip vs $ coerce as + jdg' = introduce hy' $ withNewGoal (CType b) jdg - modify $ withIntroducedVals $ mappend $ S.fromList vs - when (isJust top_hole) $ addUnusedTopVals $ S.fromList vs - (tr, sg) <- newSubgoal jdg' + Synthesized tr sc uv sg <- newSubgoal jdg' pure - . (rose ("intros {" <> intercalate ", " (fmap show vs) <> "}") $ pure tr, ) + . Synthesized + (rose ("intros {" <> intercalate ", " (fmap show vs) <> "}") $ pure tr) + (sc <> hy') + uv . noLoc . lambda (fmap bvar' vs) $ unLoc sg @@ -148,27 +154,26 @@ homoLambdaCase = apply :: HyInfo CType -> TacticsM () apply hi = requireConcreteHole $ tracing ("apply' " <> show (hi_name hi)) $ do jdg <- goal - let hy = jHypothesis jdg - g = jGoal jdg + let g = jGoal jdg ty = unCType $ hi_type hi func = hi_name hi ty' <- freshTyvars ty let (_, _, args, ret) = tacticsSplitFunTy ty' + -- TODO(sandy): Bug here! Prevents us from doing mono-map like things + -- Don't require new holes for locally bound vars; only respect linearity + -- see https://github.com/haskell/haskell-language-server/issues/1447 requireNewHoles $ rule $ \jdg -> do unify g (CType ret) - useOccName jdg func - (tr, sgs) + Synthesized tr sc uv sgs <- fmap unzipTrace $ traverse ( newSubgoal . blacklistingDestruct . flip withNewGoal jdg . CType ) args - pure - . (tr, ) - . noLoc - . foldl' (@@) (var' func) - $ fmap unLoc sgs + pure $ Synthesized tr sc (S.insert func uv) + $ noLoc . foldl' (@@) (var' func) + $ fmap unLoc sgs ------------------------------------------------------------------------------ diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Types.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Types.hs index b42998a09f..0ea9c81c8d 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Types.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Types.hs @@ -1,5 +1,7 @@ +{-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -20,28 +22,29 @@ module Ide.Plugin.Tactic.Types , Range ) where -import Control.Lens hiding (Context, (.=)) -import Control.Monad.Reader -import Control.Monad.State -import Data.Coerce -import Data.Function -import Data.Generics.Product (field) -import Data.Set (Set) -import Data.Tree -import Development.IDE.GHC.Compat hiding (Node) -import Development.IDE.GHC.Orphans () -import Development.IDE.Types.Location -import GHC.Generics -import Ide.Plugin.Tactic.Debug -import Ide.Plugin.Tactic.FeatureSet (FeatureSet) -import OccName -import Refinery.Tactic -import System.IO.Unsafe (unsafePerformIO) -import Type -import UniqSupply (UniqSupply, mkSplitUniqSupply, - takeUniqFromSupply) -import Unique (Uniquable, Unique, getUnique, - nonDetCmpUnique) +import Control.Lens hiding (Context, (.=)) +import Control.Monad.Reader +import Control.Monad.State +import Data.Coerce +import Data.Function +import Data.Generics.Product (field) +import Data.List.NonEmpty (NonEmpty (..)) +import Data.Semigroup +import Data.Set (Set) +import Data.Tree +import Development.IDE.GHC.Compat hiding (Node) +import Development.IDE.GHC.Orphans () +import Development.IDE.Types.Location +import GHC.Generics +import GHC.SourceGen (var) +import Ide.Plugin.Tactic.Debug +import Ide.Plugin.Tactic.FeatureSet (FeatureSet) +import OccName +import Refinery.Tactic +import System.IO.Unsafe (unsafePerformIO) +import Type (TCvSubst, Var, eqType, nonDetCmpType, emptyTCvSubst) +import UniqSupply (takeUniqFromSupply, mkSplitUniqSupply, UniqSupply) +import Unique (nonDetCmpUnique, Uniquable, getUnique, Unique) ------------------------------------------------------------------------------ @@ -80,24 +83,28 @@ instance Show (Pat GhcPs) where ------------------------------------------------------------------------------ +-- | The state that should be shared between subgoals. Extracts move towards +-- the root, judgments move towards the leaves, and the state moves *sideways*. data TacticState = TacticState { ts_skolems :: !(Set TyVar) -- ^ The known skolems. , ts_unifier :: !TCvSubst -- ^ The current substitution of univars. - , ts_used_vals :: !(Set OccName) - -- ^ Set of values used by tactics. - , ts_intro_vals :: !(Set OccName) - -- ^ Set of values introduced by tactics. - , ts_unused_top_vals :: !(Set OccName) - -- ^ Set of currently unused arguments to the function being defined. , ts_recursion_stack :: ![Maybe PatVal] -- ^ Stack for tracking whether or not the current recursive call has -- used at least one smaller pat val. Recursive calls for which this -- value is 'False' are guaranteed to loop, and must be pruned. + -- + -- TODO(sandy): This thing need not exist; we should just inspect + -- 'syn_used_vals' to see if anything was a pattern val. , ts_recursion_count :: !Int -- ^ Number of calls to recursion. We penalize each. - , ts_unique_gen :: !UniqSupply + -- + -- TODO(sandy): This thing need not exist; it should just be a field + -- inside of 'Synthesized', but can't implement that without support from + -- refinery directly. Need the ability to get the extract of a TacticT + -- inside of TacticT, first. + , ts_unique_gen :: !UniqSupply } deriving stock (Show, Generic) instance Show UniqSupply where @@ -117,9 +124,6 @@ defaultTacticState = TacticState { ts_skolems = mempty , ts_unifier = emptyTCvSubst - , ts_used_vals = mempty - , ts_intro_vals = mempty - , ts_unused_top_vals = mempty , ts_recursion_stack = mempty , ts_recursion_count = 0 , ts_unique_gen = unsafeDefaultUniqueSupply @@ -147,16 +151,6 @@ popRecursionStack :: TacticState -> TacticState popRecursionStack = withRecursionStack tail -withUsedVals :: (Set OccName -> Set OccName) -> TacticState -> TacticState -withUsedVals f = - field @"ts_used_vals" %~ f - - -withIntroducedVals :: (Set OccName -> Set OccName) -> TacticState -> TacticState -withIntroducedVals f = - field @"ts_intro_vals" %~ f - - ------------------------------------------------------------------------------ -- | Describes where hypotheses came from. Used extensively to prune stupid -- solutions from the search space. @@ -167,6 +161,7 @@ data Provenance TopLevelArgPrv OccName -- ^ Binding function Int -- ^ Argument Position + Int -- ^ of how many arguments total? -- | A binding created in a pattern match. | PatternMatchPrv PatVal -- | A class method from the given context. @@ -265,8 +260,12 @@ newtype ExtractM a = ExtractM { unExtractM :: Reader Context a } ------------------------------------------------------------------------------ -- | Orphan instance for producing holes when attempting to solve tactics. -instance MonadExtract (Trace, LHsExpr GhcPs) ExtractM where - hole = pure (mempty, noLoc $ HsVar noExtField $ noLoc $ Unqual $ mkVarOcc "_") +instance MonadExtract (Synthesized (LHsExpr GhcPs)) ExtractM where + hole + = pure + . Synthesized mempty mempty mempty + . noLoc + $ var "_" ------------------------------------------------------------------------------ @@ -325,12 +324,43 @@ instance Show TacticError where ------------------------------------------------------------------------------ -type TacticsM = TacticT Judgement (Trace, LHsExpr GhcPs) TacticError TacticState ExtractM -type RuleM = RuleT Judgement (Trace, LHsExpr GhcPs) TacticError TacticState ExtractM -type Rule = RuleM (Trace, LHsExpr GhcPs) +type TacticsM = TacticT Judgement (Synthesized (LHsExpr GhcPs)) TacticError TacticState ExtractM +type RuleM = RuleT Judgement (Synthesized (LHsExpr GhcPs)) TacticError TacticState ExtractM +type Rule = RuleM (Synthesized (LHsExpr GhcPs)) type Trace = Rose String +------------------------------------------------------------------------------ +-- | The extract for refinery. Represents a "synthesized attribute" in the +-- context of attribute grammars. In essence, 'Synthesized' describes +-- information we'd like to pass from leaves of the tactics search upwards. +-- This includes the actual AST we've generated (in 'syn_val'). +data Synthesized a = Synthesized + { syn_trace :: Trace + -- ^ A tree describing which tactics were used produce the 'syn_val'. + -- Mainly for debugging when you get the wrong answer, to see the other + -- things it tried. + , syn_scoped :: Hypothesis CType + -- ^ All of the bindings created to produce the 'syn_val'. + , syn_used_vals :: Set OccName + -- ^ The values used when synthesizing the 'syn_val'. + , syn_val :: a + } + deriving (Eq, Show, Functor, Foldable, Traversable) + +mapTrace :: (Trace -> Trace) -> Synthesized a -> Synthesized a +mapTrace f (Synthesized tr sc uv a) = Synthesized (f tr) sc uv a + + +------------------------------------------------------------------------------ +-- | This might not be lawful, due to the semigroup on 'Trace' maybe not being +-- lawful. But that's only for debug output, so it's not anything I'm concerned +-- about. +instance Applicative Synthesized where + pure = Synthesized mempty mempty mempty + Synthesized tr1 sc1 uv1 f <*> Synthesized tr2 sc2 uv2 a = + Synthesized (tr1 <> tr2) (sc1 <> sc2) (uv1 <> uv2) $ f a + ------------------------------------------------------------------------------ -- | The Reader context of tactics and rules @@ -361,10 +391,13 @@ dropEveryOther [] = [] dropEveryOther [a] = [a] dropEveryOther (a : _ : as) = a : dropEveryOther as -instance Semigroup a => Semigroup (Rose a) where +------------------------------------------------------------------------------ +-- | This might not be lawful! I didn't check, and it feels sketchy. +instance (Eq a, Monoid a) => Semigroup (Rose a) where Rose (Node a as) <> Rose (Node b bs) = Rose $ Node (a <> b) (as <> bs) + sconcat (a :| as) = rose mempty $ a : as -instance Monoid a => Monoid (Rose a) where +instance (Eq a, Monoid a) => Monoid (Rose a) where mempty = Rose $ Node mempty mempty rose :: (Eq a, Monoid a) => a -> [Rose a] -> Rose a @@ -377,7 +410,7 @@ rose a rs = Rose $ Node a $ coerce rs data RunTacticResults = RunTacticResults { rtr_trace :: Trace , rtr_extract :: LHsExpr GhcPs - , rtr_other_solns :: [(Trace, LHsExpr GhcPs)] + , rtr_other_solns :: [Synthesized (LHsExpr GhcPs)] , rtr_jdg :: Judgement , rtr_ctx :: Context } deriving Show