-
Notifications
You must be signed in to change notification settings - Fork 9
/
Solver.hs
348 lines (316 loc) · 13.6 KB
/
Solver.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
{-|
Copyright : (C) 2015-2016, University of Twente
License : BSD2 (see the file LICENSE)
Maintainer : Christiaan Baaij <christiaan.baaij@gmail.com>
To use the plugin, add the
@
{\-\# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver \#-\}
@
pragma to the header of your file
-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module GHC.TypeLits.Extra.Solver
( plugin )
where
-- external
import Control.Monad.Trans.Maybe (MaybeT (..))
import Data.Maybe (catMaybes)
import GHC.TcPluginM.Extra (evByFiat, tracePlugin, newWanted)
import qualified Data.Type.Ord
import qualified GHC.TypeError
-- GHC API
import GHC.Builtin.Names (eqPrimTyConKey, hasKey, getUnique)
import GHC.Builtin.Types (promotedTrueDataCon, promotedFalseDataCon)
import GHC.Builtin.Types (boolTy, naturalTy, cTupleDataCon, cTupleTyCon)
import GHC.Builtin.Types.Literals (typeNatDivTyCon, typeNatModTyCon, typeNatCmpTyCon)
import GHC.Core.Coercion (mkUnivCo)
import GHC.Core.DataCon (dataConWrapId)
import GHC.Core.Predicate (EqRel (NomEq), Pred (EqPred, IrredPred), classifyPredType)
import GHC.Core.Reduction (Reduction(..))
import GHC.Core.TyCon (TyCon)
import GHC.Core.TyCo.Rep (Type (..), TyLit (..), UnivCoProvenance (PluginProv))
import GHC.Core.Type (Kind, mkTyConApp, splitTyConApp_maybe, typeKind)
#if MIN_VERSION_ghc(9,6,0)
import GHC.Core.TyCo.Compare (eqType)
#else
import GHC.Core.Type (eqType)
#endif
import GHC.Data.IOEnv (getEnv)
import GHC.Driver.Env (hsc_NC)
import GHC.Driver.Plugins (Plugin (..), defaultPlugin, purePlugin)
import GHC.Plugins (thNameToGhcNameIO)
import GHC.Tc.Plugin (TcPluginM, tcLookupTyCon, tcPluginTrace, tcPluginIO, unsafeTcPluginTcM)
import GHC.Tc.Types (TcPlugin(..), TcPluginSolveResult (..), TcPluginRewriter, TcPluginRewriteResult (..), Env (env_top))
import GHC.Tc.Types.Constraint
(Ct, ctEvidence, ctEvPred, ctLoc, isWantedCt)
#if MIN_VERSION_ghc(9,8,0)
import GHC.Tc.Types.Constraint (Ct (..), DictCt(..), EqCt(..), IrredCt(..), qci_ev)
#else
import GHC.Tc.Types.Constraint (Ct (CQuantCan), qci_ev, cc_ev)
#endif
import GHC.Tc.Types.Evidence (EvTerm, EvBindsVar, Role(..), evCast, evId)
import GHC.Types.Unique.FM (UniqFM, listToUFM)
import GHC.Utils.Outputable (Outputable (..), (<+>), ($$), text)
import GHC (Name)
-- template-haskell
import qualified Language.Haskell.TH as TH
-- internal
import GHC.TypeLits.Extra.Solver.Operations
import GHC.TypeLits.Extra.Solver.Unify
import GHC.TypeLits.Extra
-- | A solver implemented as a type-checker plugin for:
--
-- * 'Div': type-level 'div'
--
-- * 'Mod': type-level 'mod'
--
-- * 'FLog': type-level equivalent of <https://hackage.haskell.org/package/base-4.17.0.0/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
--   i.e. the exact integer equivalent to "@'floor' ('logBase' x y)@"
--
-- * 'CLog': type-level equivalent of /the ceiling of/ <https://hackage.haskell.org/package/base-4.17.0.0/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
--   i.e. the exact integer equivalent to "@'ceiling' ('logBase' x y)@"
--
-- * 'Log': type-level equivalent of <https://hackage.haskell.org/package/base-4.17.0.0/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
--   where the operation only reduces when "@'floor' ('logBase' b x) ~ 'ceiling' ('logBase' b x)@"
--
-- * 'GCD': a type-level 'gcd'
--
-- * 'LCM': a type-level 'lcm'
--
-- To use the plugin, add
--
-- @
-- {\-\# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver \#-\}
-- @
--
-- to the header of your file.
plugin :: Plugin
plugin = defaultPlugin
  { -- The argument handed to 'tcPlugin' is ignored; the same type-checker
    -- plugin is always installed.
    tcPlugin        = \_ -> Just normalisePlugin
  , pluginRecompile = purePlugin
  }
-- | The type-checker plugin proper, wrapped in 'tracePlugin' under the name
-- @ghc-typelits-extra@.
--
-- Initialisation looks up the 'ExtraDefs', solving is delegated to
-- 'decideEqualSOP', rewriting to 'extraRewrite', and stopping does nothing.
normalisePlugin :: TcPlugin
normalisePlugin = tracePlugin "ghc-typelits-extra" extraPlugin
  where
    extraPlugin = TcPlugin
      { tcPluginInit    = lookupExtraDefs
      , tcPluginSolve   = decideEqualSOP
      , tcPluginRewrite = extraRewrite
      , tcPluginStop    = \_ -> return ()
      }
-- | Rewriters keyed by the 'GCD' and 'LCM' type constructors taken from the
-- given 'ExtraDefs'.
--
-- An application whose argument list is exactly two numeric literals is
-- rewritten to the literal result of 'gcd' (respectively 'lcm'); every other
-- argument shape is left alone ('TcPluginNoRewrite').
extraRewrite :: ExtraDefs -> UniqFM TyCon TcPluginRewriter
extraRewrite defs = listToUFM
  [ (gcdTyCon defs, literalRewrite (gcdTyCon defs) gcd)
  , (lcmTyCon defs, literalRewrite (lcmTyCon defs) lcm)
  ]
  where
    -- Shared shape of both rewriters: the first two arguments are not
    -- inspected, only the type arguments of the application.
    literalRewrite tc op _ _ args = case args of
      [LitTy (NumTyLit i), LitTy (NumTyLit j)] ->
        pure (TcPluginRewriteTo (justify tc args (LitTy (NumTyLit (i `op` j)))) [])
      _ ->
        pure TcPluginNoRewrite

    -- Pair the result with a plugin-provenance, nominal-role coercion from
    -- the original application to that result.
    justify tc args res =
      Reduction
        (mkUnivCo (PluginProv "ghc-typelits-extra") Nominal (mkTyConApp tc args) res)
        res
-- | Solver entry point of the plugin.
--
-- Returns an empty 'TcPluginOk' when there are no wanteds, or when none of
-- the wanteds is recognised by 'toSolverConstraint'. Otherwise the recognised
-- givens and wanteds are handed to 'simplifyExtra' (givens first):
--
-- * on 'Simplified', only the solved entries whose constraint is a wanted are
--   reported, together with the newly created constraints;
--
-- * on 'Impossible', the offending constraint is reported as a contradiction.
decideEqualSOP :: ExtraDefs -> EvBindsVar -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult
decideEqualSOP _ _ _givens [] = return (TcPluginOk [] [])
decideEqualSOP defs _ givens wanteds = do
  unitWanteds <- recognised wanteds
  if null unitWanteds
    then return (TcPluginOk [] [])
    else do
      unitGivens <- recognised givens
      outcome <- simplifyExtra defs (unitGivens ++ unitWanteds)
      tcPluginTrace "normalised" (ppr outcome)
      case outcome of
        Simplified evs new ->
          return (TcPluginOk [ solved | solved@(_, ct) <- evs, isWantedCt ct ] new)
        Impossible eq ->
          return (TcPluginContradiction [fromSolverConstraint eq])
  where
    -- Keep only the constraints that 'toSolverConstraint' can translate.
    recognised :: [Ct] -> TcPluginM [SolverConstraint]
    recognised = fmap catMaybes . mapM (runMaybeT . toSolverConstraint defs)
-- | A constraint in the form the solver works on, tagged with the 'Ct' it
-- was derived from and with whether normalisation changed its operands.
data SolverConstraint
  = NatEquality Ct ExtraOp ExtraOp Normalised
    -- ^ The two operations must be equal.
  | NatInequality Ct ExtraOp ExtraOp Bool Normalised
    -- ^ Comparing the first operation to the second with @<=@ must give the
    -- 'Bool'.

-- Note that the originating 'Ct' is only printed for 'NatEquality'.
instance Outputable SolverConstraint where
  ppr sct = case sct of
    NatEquality ct lhs rhs norm ->
      text "NatEquality" $$ ppr ct $$ ppr lhs $$ ppr rhs $$ ppr norm
    NatInequality _ lhs rhs expected norm ->
      text "NatInequality" $$ ppr lhs $$ ppr rhs $$ ppr expected $$ ppr norm
-- | Outcome of 'simplifyExtra'.
data SimplifyResult
  = Simplified [(EvTerm,Ct)] [Ct]
    -- ^ Solved constraints with their evidence, plus newly created
    -- constraints.
  | Impossible SolverConstraint
    -- ^ The constraint that was found not to hold.

instance Outputable SimplifyResult where
  ppr result = case result of
    Simplified solved fresh ->
      text "Simplified" $$ text "Solved:" $$ ppr solved $$ text "New:" $$ ppr fresh
    Impossible sct ->
      text "Impossible" <+> ppr sct
-- | Walk over the constraints one by one, accumulating evidence for those
-- that can be discharged and fresh wanteds for normalised wanteds that could
-- not be decided directly. Stops early with 'Impossible' when a constraint is
-- refuted.
simplifyExtra :: ExtraDefs -> [SolverConstraint] -> TcPluginM SimplifyResult
simplifyExtra defs eqs = do
  tcPluginTrace "simplifyExtra" (ppr eqs)
  simples [] [] eqs
  where
    -- Evidence for a constraint (when 'evMagic' can build it), paired with
    -- the constraint itself.
    discharge :: Ct -> Maybe (EvTerm, Ct)
    discharge ct = (, ct) <$> evMagic ct

    -- A wanted whose operands were changed by normalisation.
    normalisedWanted :: Normalised -> Ct -> Bool
    normalisedWanted norm ct = norm == Normalised && isWantedCt ct

    -- Discharge the original constraint and emit its normalised form as a
    -- new wanted in its place.
    reemit
      :: [Maybe (EvTerm, Ct)] -> [Ct] -> SolverConstraint -> Ct
      -> [SolverConstraint] -> TcPluginM SimplifyResult
    reemit evs news eq ct rest = do
      newCt <- createWantedFromNormalised defs eq
      simples (discharge ct : evs) (newCt : news) rest

    simples :: [Maybe (EvTerm, Ct)] -> [Ct] -> [SolverConstraint] -> TcPluginM SimplifyResult
    simples evs news [] = return (Simplified (catMaybes evs) news)
    simples evs news (eq@(NatEquality ct u v norm) : rest) = do
      ur <- unifyExtra ct u v
      tcPluginTrace "unifyExtra result" (ppr ur)
      case ur of
        Win -> simples (discharge ct : evs) news rest
        -- Only reported as impossible when nothing has been collected so far
        -- and no constraints remain.
        Lose | null evs, null rest -> return (Impossible eq)
        _ | normalisedWanted norm ct -> reemit evs news eq ct rest
        Lose -> simples evs news rest
        Draw -> simples evs news rest
    simples evs news (eq@(NatInequality ct u v b norm) : rest) = do
      tcPluginTrace "unifyExtra leq result" (ppr (u,v,b))
      case (u, v) of
        (I i, I j)
          | (i <= j) == b -> simples (discharge ct : evs) news rest
          | otherwise     -> return (Impossible eq)
        (p, Max x y)
          | b, p == x || p == y -> simples (discharge ct : evs) news rest
        -- transform:  q ~ Max x y => (p <=? q ~ True)
        -- to:         (p <=? Max x y) ~ True
        -- and try to solve that along with the remaining constraints
        (p, q@(V _))
          | b -> maybe
                   (simples evs news rest)
                   (\m -> simples evs news (NatInequality ct p m b norm : rest))
                   (findMax q eqs)
        _ | normalisedWanted norm ct -> reemit evs news eq ct rest
        _ -> simples evs news rest

    -- Look for the first non-wanted equality of the form: c ~ Max x y
    -- (in either orientation) and return its 'Max' side.
    findMax :: ExtraOp -> [SolverConstraint] -> Maybe ExtraOp
    findMax c = foldr (\sct later -> maybe later Just (maxSide sct)) Nothing
      where
        maxSide (NatEquality ct lhs rhs _)
          | isWantedCt ct            = Nothing
          | Max _ _ <- rhs, c == lhs = Just rhs
          | Max _ _ <- lhs, c == rhs = Just lhs
        maxSide _ = Nothing
-- | Translate a constraint into a 'SolverConstraint', failing (in 'MaybeT')
-- for anything the solver does not recognise.
--
-- Recognised shapes, tried in this order:
--
-- 1. a nominal equality where either side has kind 'naturalTy';
--
-- 2. a nominal equality between @OrdCond (CmpNat x y) True True False@ and a
--    promoted @True@ or @False@;
--
-- 3. a nominal equality between an @Assert@ applied to that same @OrdCond@
--    form and the nullary constraint tuple;
--
-- 4. an irreducible predicate that is an @Assert@ applied to that same
--    @OrdCond@ form.
--
-- Operands are normalised with 'normaliseNat' before being stored.
toSolverConstraint :: ExtraDefs -> Ct -> MaybeT TcPluginM SolverConstraint
toSolverConstraint defs ct = case classifyPredType (ctEvPred (ctEvidence ct)) of
  EqPred NomEq lhs rhs
    | isNatKind (typeKind lhs) || isNatKind (typeKind rhs)
    -> do
      (lhs', rhs', norm) <- normaliseBoth lhs rhs
      pure (NatEquality ct lhs' rhs' norm)
    | Just (x, y) <- leqOperands lhs
    , TyConApp resTc [] <- rhs
    -> do
      -- Normalise first, then decide: this keeps the order of effects
      -- independent of which result constructor is found.
      (x', y', norm) <- normaliseBoth x y
      case () of
        _ | resTc == promotedTrueDataCon  -> pure (NatInequality ct x' y' True norm)
          | resTc == promotedFalseDataCon -> pure (NatInequality ct x' y' False norm)
          | otherwise                     -> fail "Nothing"
    | Just (x, y) <- assertedLeq lhs
    , TyConApp unitTc [] <- rhs
    , unitTc == cTupleTyCon 0
    -> do
      (x', y', norm) <- normaliseBoth x y
      pure (NatInequality ct x' y' True norm)
  IrredPred p
    | Just (x, y) <- assertedLeq p
    -> do
      (x', y', norm) <- normaliseBoth x y
      pure (NatInequality ct x' y' True norm)
  _ -> fail "Nothing"
  where
    isNatKind :: Kind -> Bool
    isNatKind k = k `eqType` naturalTy

    -- Normalise two operands, left one first, and merge their flags.
    normaliseBoth :: Type -> Type -> MaybeT TcPluginM (ExtraOp, ExtraOp, Normalised)
    normaliseBoth a b = do
      (a', nA) <- normaliseNat defs a
      (b', nB) <- normaliseNat defs b
      pure (a', b', mergeNormalised nA nB)

    -- Is the type exactly the given nullary type constructor?
    isNullary :: TyCon -> Type -> Bool
    isNullary tc (TyConApp tc' []) = tc == tc'
    isNullary _  _                 = False

    -- Match @OrdCond _ (CmpNat x y) True True False@ and return @(x, y)@.
    leqOperands :: Type -> Maybe (Type, Type)
    leqOperands ty
      | TyConApp condTc [_, cmp, lt, eq, gt] <- ty
      , condTc == ordTyCon defs
      , TyConApp cmpTc [x, y] <- cmp
      , cmpTc == typeNatCmpTyCon
      , isNullary promotedTrueDataCon lt
      , isNullary promotedTrueDataCon eq
      , isNullary promotedFalseDataCon gt
      = Just (x, y)
      | otherwise
      = Nothing

    -- Match a two-argument @Assert@ whose first argument has the shape
    -- accepted by 'leqOperands'; the second argument is ignored.
    assertedLeq :: Type -> Maybe (Type, Type)
    assertedLeq ty
      | TyConApp tc [cond, _] <- ty
      , tc == assertTC defs
      = leqOperands cond
      | otherwise
      = Nothing
-- | Create a new wanted constraint that restates the given
-- 'SolverConstraint' with its (normalised) operands reified back to types.
--
-- The new predicate reuses the head of the original predicate: either the
-- primitive equality (keeping its first two arguments) or @Assert@ (keeping
-- its second argument). Any other original predicate is a hard 'error'.
-- The returned 'Ct' is the original one with its evidence replaced by the
-- evidence of the fresh wanted.
createWantedFromNormalised :: ExtraDefs -> SolverConstraint -> TcPluginM Ct
createWantedFromNormalised defs sct = do
  newPredTy <- case splitTyConApp_maybe (ctEvPred (ctEvidence ct)) of
    Just (tc, [arg1, arg2, _, _])
      | tc `hasKey` eqPrimTyConKey -> pure (mkTyConApp tc [arg1, arg2, lhs, rhs])
    Just (tc, [_, arg2])
      | tc `hasKey` getUnique (assertTC defs) -> pure (mkTyConApp tc [lhs, arg2])
    _ -> error "Impossible: neither (<=?) nor Assert"
  ev <- newWanted (ctLoc ct) newPredTy
  return $ case ct of
    CQuantCan qc -> CQuantCan (qc { qci_ev = ev})
#if MIN_VERSION_ghc(9,8,0)
    CDictCan di -> CDictCan (di { di_ev = ev})
    CIrredCan ir -> CIrredCan (ir { ir_ev = ev})
    CEqCan eq -> CEqCan (eq { eq_ev = ev})
    CNonCanonical _ -> CNonCanonical ev
#else
    ctX -> ctX { cc_ev = ev }
#endif
  where
    -- The originating constraint and the two sides of the restated predicate.
    (ct, lhs, rhs) = case sct of
      NatEquality origCt a b _ ->
        (origCt, reifyEOP defs a, reifyEOP defs b)
      NatInequality origCt x y expected _ ->
        (origCt, ordCond x y, TyConApp (boolCon expected) [])

    boolCon flag = if flag then promotedTrueDataCon else promotedFalseDataCon

    -- @OrdCond Bool (CmpNat x y) True True False@ over the reified operands.
    ordCond x y = TyConApp (ordTyCon defs)
      [ boolTy
      , TyConApp typeNatCmpTyCon [reifyEOP defs x, reifyEOP defs y]
      , TyConApp promotedTrueDataCon []
      , TyConApp promotedTrueDataCon []
      , TyConApp promotedFalseDataCon []
      ]
-- | Recover the 'Ct' a 'SolverConstraint' was built from.
fromSolverConstraint :: SolverConstraint -> Ct
fromSolverConstraint sct = case sct of
  NatEquality   ct _ _ _   -> ct
  NatInequality ct _ _ _ _ -> ct
-- | Resolve every type constructor the solver needs and bundle them in an
-- 'ExtraDefs'.
--
-- The division and modulo constructors are GHC's built-in ones; all others
-- are looked up from their Template Haskell names.
lookupExtraDefs :: TcPluginM ExtraDefs
lookupExtraDefs = do
  maxTc     <- tyConFor ''GHC.TypeLits.Extra.Max
  minTc     <- tyConFor ''GHC.TypeLits.Extra.Min
  flogTc    <- tyConFor ''GHC.TypeLits.Extra.FLog
  clogTc    <- tyConFor ''GHC.TypeLits.Extra.CLog
  logTc     <- tyConFor ''GHC.TypeLits.Extra.Log
  gcdTc     <- tyConFor ''GHC.TypeLits.Extra.GCD
  lcmTc     <- tyConFor ''GHC.TypeLits.Extra.LCM
  ordCondTc <- tyConFor ''Data.Type.Ord.OrdCond
  assertTc  <- tyConFor ''GHC.TypeError.Assert
  -- Positional construction: the argument order below is the field order
  -- the constructor expects.
  pure $ ExtraDefs
    maxTc
    minTc
    typeNatDivTyCon
    typeNatModTyCon
    flogTc
    clogTc
    logTc
    gcdTc
    lcmTc
    ordCondTc
    assertTc
  where
    tyConFor :: TH.Name -> TcPluginM TyCon
    tyConFor nm = lookupTHName nm >>= tcLookupTyCon
-- | Convert a Template Haskell name into a GHC 'Name' using the name cache
-- of the top-level environment; fails in 'TcPluginM' when the name cannot be
-- resolved.
lookupTHName :: TH.Name -> TcPluginM Name
lookupTHName th = do
  topEnv <- unsafeTcPluginTcM (fmap env_top getEnv)
  found  <- tcPluginIO (thNameToGhcNameIO (hsc_NC topEnv) th)
  case found of
    Just name -> return name
    Nothing   -> fail ("Failed to lookup " ++ show th)
-- Utils
-- | Fabricate evidence for a constraint the solver has decided holds.
--
-- * Nominal equalities get evidence from 'evByFiat'.
--
-- * Irreducible predicates get the unit constraint-tuple constructor, cast
--   to the predicate with a plugin-provenance coercion.
--
-- * Every other predicate yields 'Nothing'.
evMagic :: Ct -> Maybe EvTerm
evMagic ct
  | EqPred NomEq lhs rhs <- predTree
  = Just (evByFiat "ghc-typelits-extra" lhs rhs)
  | IrredPred p <- predTree
  = Just (evCast unitTerm (unitTo p))
  | otherwise
  = Nothing
  where
    predTree = classifyPredType (ctEvPred (ctEvidence ct))

    -- The data constructor of the nullary constraint tuple, as a term.
    unitTerm = evId (dataConWrapId (cTupleDataCon 0))

    -- Representational coercion from the nullary constraint tuple type to
    -- the given predicate.
    unitTo p =
      mkUnivCo
        (PluginProv "ghc-typelits-extra")
        Representational
        (mkTyConApp (cTupleTyCon 0) [])
        p