Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove recursion tracking from TacticState #1453

Merged
merged 6 commits into from
Mar 1, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 19 additions & 31 deletions plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}

Expand All @@ -7,9 +8,10 @@ module Ide.Plugin.Tactic.CodeGen
, module Ide.Plugin.Tactic.CodeGen.Utils
) where

import Control.Lens ((+~))

import Control.Lens ((%~), (<>~), (&))
import Control.Monad.Except
import Data.Generics.Product (field)
import Data.Generics.Labels ()
import Data.List
import qualified Data.Set as S
import Data.Traversable
Expand All @@ -29,13 +31,6 @@ import Ide.Plugin.Tactic.Types
import Type hiding (Var)



------------------------------------------------------------------------------
-- | Doing recursion incurs a small penalty in the score.
countRecursiveCall :: TacticState -> TacticState
countRecursiveCall = field @"ts_recursion_count" +~ 1


destructMatches
:: (DataCon -> Judgement -> Rule)
-- ^ How to construct each match
Expand All @@ -62,16 +57,12 @@ destructMatches f scrut t jdg = do
$ coerce args
j = introduce hy'
$ withNewGoal g jdg
Synthesized tr sc uv sg <- f dc j
pure
$ Synthesized
( rose ("match " <> show dc <> " {" <>
intercalate ", " (fmap show names) <> "}")
$ pure tr)
(sc <> hy')
uv
$ match [mkDestructPat dc names]
$ unLoc sg
ext <- f dc j
pure $ ext
& #syn_trace %~ rose ("match " <> show dc <> " {" <> intercalate ", " (fmap show names) <> "}")
. pure
& #syn_scoped <>~ hy'
& #syn_val %~ match [mkDestructPat dc names] . unLoc


------------------------------------------------------------------------------
Expand Down Expand Up @@ -138,19 +129,16 @@ destruct' :: (DataCon -> Judgement -> Rule) -> HyInfo CType -> Judgement -> Rule
destruct' f hi jdg = do
when (isDestructBlacklisted jdg) $ throwError NoApplicableTactic
let term = hi_name hi
Synthesized tr sc uv ms
ext
<- destructMatches
f
(Just term)
(hi_type hi)
$ disallowing AlreadyDestructed [term] jdg
pure
$ Synthesized
(rose ("destruct " <> show term) $ pure tr)
sc
(S.insert term uv)
$ noLoc
$ case' (var' term) ms
pure $ ext
& #syn_trace %~ rose ("destruct " <> show term) . pure
& #syn_used_vals %~ S.insert term
& #syn_val %~ noLoc . case' (var' term)


------------------------------------------------------------------------------
Expand All @@ -176,7 +164,7 @@ buildDataCon
-> RuleM (Synthesized (LHsExpr GhcPs))
buildDataCon jdg dc tyapps = do
let args = dataConInstOrigArgTys' dc tyapps
Synthesized tr sc uv sgs
ext
<- fmap unzipTrace
$ traverse ( \(arg, n) ->
newSubgoal
Expand All @@ -185,7 +173,7 @@ buildDataCon jdg dc tyapps = do
. flip withNewGoal jdg
$ CType arg
) $ zip args [0..]
pure
$ Synthesized (rose (show dc) $ pure tr) sc uv
$ mkCon dc sgs
pure $ ext
& #syn_trace %~ rose (show dc) . pure
& #syn_val %~ mkCon dc

Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ deriveArbitrary = do
-- But maybe it's fine for known rules?
mempty
mempty
mempty
$ noLoc $
let' [valBind (fromString "terminal") $ list $ fmap genExpr terminal] $
appDollar (mkFunc "sized") $ lambda [bvar' (mkVarOcc "n")] $
Expand Down
115 changes: 68 additions & 47 deletions plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,35 @@
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Ide.Plugin.Tactic.Machinery
( module Ide.Plugin.Tactic.Machinery
) where
module Ide.Plugin.Tactic.Machinery where

import Class (Class (classTyVars))
import Control.Arrow
import Class (Class (classTyVars))
import Control.Lens ((<>~))
import Control.Monad.Error.Class
import Control.Monad.Reader
import Control.Monad.State (MonadState (..))
import Control.Monad.State.Class (gets, modify)
import Control.Monad.State.Strict (StateT (..))
import Data.Bool (bool)
import Control.Monad.State.Class (gets, modify)
import Control.Monad.State.Strict (StateT (..))
import Data.Bool (bool)
import Data.Coerce
import Data.Either
import Data.Foldable
import Data.Functor ((<&>))
import Data.Generics (everything, gcount, mkQ)
import Data.List (sortBy)
import qualified Data.Map as M
import Data.Ord (Down (..), comparing)
import Data.Set (Set)
import qualified Data.Set as S
import Data.Functor ((<&>))
import Data.Generics (everything, gcount, mkQ)
import Data.Generics.Product (field')
import Data.List (sortBy)
import qualified Data.Map as M
import Data.Monoid (getSum)
import Data.Ord (Down (..), comparing)
import Data.Set (Set)
import qualified Data.Set as S
import Development.IDE.GHC.Compat
import Ide.Plugin.Tactic.Judgements
import Ide.Plugin.Tactic.Simplify (simplify)
import Ide.Plugin.Tactic.Simplify (simplify)
import Ide.Plugin.Tactic.Types
import OccName (HasOccName (occName))
import OccName (HasOccName (occName))
import Refinery.ProofState
import Refinery.Tactic
import Refinery.Tactic.Internal
Expand Down Expand Up @@ -88,8 +87,8 @@ runTactic ctx jdg t =
(errs, []) -> Left $ take 50 errs
(_, fmap assoc23 -> solns) -> do
let sorted =
flip sortBy solns $ comparing $ \(ext, (jdg, holes)) ->
Down $ scoreSolution ext jdg holes
flip sortBy solns $ comparing $ \(ext, (_, holes)) ->
Down $ scoreSolution ext holes
case sorted of
((syn, _) : _) ->
Right $
Expand All @@ -111,39 +110,37 @@ tracePrim :: String -> Trace
tracePrim = flip rose []


------------------------------------------------------------------------------
-- | Mark that a tactic used the given string in its extract derivation. Mainly
-- used for debugging the search when things go terribly wrong.
tracing
:: Functor m
=> String
-> TacticT jdg (Synthesized ext) err s m a
-> TacticT jdg (Synthesized ext) err s m a
tracing s (TacticT m)
= TacticT $ StateT $ \jdg ->
mapExtract' (mapTrace $ rose s . pure) $ runStateT m jdg
tracing s = mappingExtract (mapTrace $ rose s . pure)


------------------------------------------------------------------------------
-- | Recursion is allowed only when we can prove it is on a structurally
-- smaller argument. The top of the 'ts_recursion_stack' witnesses the smaller
-- pattern val.
guardStructurallySmallerRecursion
:: TacticState
-> Maybe TacticError
guardStructurallySmallerRecursion s =
case head $ ts_recursion_stack s of
Just _ -> Nothing
Nothing -> Just NoProgress
-- | Mark that a tactic performed recursion. Doing so incurs a small penalty in
-- the score.
markRecursion
:: Functor m
=> TacticT jdg (Synthesized ext) err s m a
-> TacticT jdg (Synthesized ext) err s m a
markRecursion = mappingExtract (field' @"syn_recursion_count" <>~ 1)


------------------------------------------------------------------------------
-- | Mark that the current recursive call is structurally smaller, due to
-- having been matched on a pattern value.
--
-- Implemented by setting the top of the 'ts_recursion_stack'.
markStructuralySmallerRecursion :: MonadState TacticState m => PatVal -> m ()
markStructuralySmallerRecursion pv = do
modify $ withRecursionStack $ \case
(_ : bs) -> Just pv : bs
[] -> []
-- | Map a function over the extract created by a tactic.
mappingExtract
:: Functor m
=> (ext -> ext)
-> TacticT jdg ext err s m a
-> TacticT jdg ext err s m a
mappingExtract f (TacticT m)
= TacticT $ StateT $ \jdg ->
mapExtract' f $ runStateT m jdg


------------------------------------------------------------------------------
Expand All @@ -154,7 +151,6 @@ markStructuralySmallerRecursion pv = do
-- to produce the right test results.
scoreSolution
:: Synthesized (LHsExpr GhcPs)
-> TacticState
-> [Judgement]
-> ( Penalize Int -- number of holes
, Reward Bool -- all bindings used
Expand All @@ -164,19 +160,23 @@ scoreSolution
, Penalize Int -- number of recursive calls
, Penalize Int -- size of extract
)
scoreSolution ext TacticState{..} holes
scoreSolution ext holes
= ( Penalize $ length holes
, Reward $ S.null $ intro_vals S.\\ used_vals
, Penalize $ S.size unused_top_vals
, Penalize $ S.size intro_vals
, Reward $ S.size used_vals
, Penalize ts_recursion_count
, Penalize $ getSum $ syn_recursion_count ext
, Penalize $ solutionSize $ syn_val ext
)
where
intro_vals = M.keysSet $ hyByName $ syn_scoped ext
used_vals = S.intersection intro_vals $ syn_used_vals ext
top_vals = S.fromList . fmap hi_name . filter (isTopLevel . hi_provenance) $ unHypothesis $ syn_scoped ext
top_vals = S.fromList
. fmap hi_name
. filter (isTopLevel . hi_provenance)
. unHypothesis
$ syn_scoped ext
unused_top_vals = top_vals S.\\ used_vals


Expand Down Expand Up @@ -240,6 +240,26 @@ methodHypothesis ty = do
)


------------------------------------------------------------------------------
-- | Mystical time-traveling combinator for inspecting the extracts produced by
-- a tactic. We can use it to guard that extracts match certain predicates, for
-- example.
--
-- Note, that this thing is WEIRD. To illustrate:
--
-- @@
-- peek f
-- blah
-- @@
--
-- Here, @f@ can inspect the extract _produced by @blah@,_ which means the
-- causality appears to go backwards.
--
-- 'peek' should be exposed directly by @refinery@ in the next release.
peek :: (ext -> TacticT jdg ext err s m ()) -> TacticT jdg ext err s m ()
peek k = tactic $ \j -> Subgoal ((), j) $ \e -> proofState (k e) j


------------------------------------------------------------------------------
-- | Run the given tactic iff the current hole contains no univars. Skolems and
-- already decided univars are OK though.
Expand All @@ -251,3 +271,4 @@ requireConcreteHole m = do
case S.size $ vars S.\\ skolems of
0 -> m
_ -> throwError TooPolymorphic