Skip to content

Commit

Permalink
io-sim-classes: Eq instances for stm's mutable varibles
Browse files Browse the repository at this point in the history
  • Loading branch information
coot committed Oct 26, 2021
1 parent f96d61d commit d6b3b0b
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 2 deletions.
143 changes: 142 additions & 1 deletion io-classes/src/Control/Monad/Class/MonadSTM.hs
@@ -1,11 +1,19 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeFamilyDependencies #-}

module Control.Monad.Class.MonadSTM
( MonadSTM (..)
, MonadSTMTx (..)
, MonadSTMVars
, MonadLabelledSTM (..)
, MonadLabelledSTMTx (..)
, LazyTVar
Expand All @@ -14,6 +22,14 @@ module Control.Monad.Class.MonadSTM
, TMVar
, TQueue
, TBQueue
, EqTVar
, eqTVar
, EqTMVar
, eqTMVar
, EqTQueue
, eqTQueue
, EqTBQueue
, eqTBQueue

-- * Default 'TMVar' implementation
, TMVarDefault (..)
Expand Down Expand Up @@ -82,6 +98,7 @@ import Control.Applicative (Alternative (..))
import Control.Exception
import Control.Monad.Reader
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import GHC.Stack
import Numeric.Natural (Natural)

Expand All @@ -91,10 +108,18 @@ import Numeric.Natural (Natural)
type LazyTVar m = TVar m
type LazyTMVar m = TMVar m

-- The STM primitives
class ( forall a. Eq (tvar a)
, forall a. Eq (tmvar a)
, forall a. Eq (tqueue a)
, forall a. Eq (tbqueue a)
) => MonadSTMVars tvar tmvar tqueue tbqueue

-- | The STM primitive operations and associated types.
--
class ( Monad stm
, Alternative stm
, MonadPlus stm
, MonadSTMVars (TVar_ stm) (TMVar_ stm) (TQueue_ stm) (TBQueue_ stm)
) => MonadSTMTx stm where
type TVar_ stm :: Type -> Type

Expand Down Expand Up @@ -175,6 +200,9 @@ type TMVar m = TMVar_ (STM m)
type TQueue m = TQueue_ (STM m)
type TBQueue m = TBQueue_ (STM m)

-- | 'MonadSTM' provides the @'STM' m@ monad as well as all operations to
-- execute it.
--
class (Monad m, MonadSTMTx (STM m)) => MonadSTM m where
-- STM transactions
type STM m :: Type -> Type
Expand Down Expand Up @@ -210,6 +238,114 @@ newEmptyTMVarM :: MonadSTM m => m (TMVar m a)
newEmptyTMVarM = newEmptyTMVarIO
{-# DEPRECATED newEmptyTMVarM "Use newEmptyTMVarIO" #-}

-- context for the `eqTVar'`, `eqTMVar'`, etc.
type MonadSTMCtx stm tvar tmvar tqueue tbqueue
= ( MonadSTMVars tvar tmvar tqueue tbqueue
, TVar_ stm ~ tvar
, TMVar_ stm ~ tmvar
, TQueue_ stm ~ tqueue
, TBQueue_ stm ~ tbqueue
, MonadSTMTx stm
)

type EqTVar m tvar
= ( tvar ~ TVar m
, forall a. Eq (tvar a)
)

eqTVar' :: forall stm tvar tmvar tqueue tbqueue a.
MonadSTMCtx stm tvar tmvar tqueue tbqueue
=> Proxy stm
-> tvar a -> tvar a -> Bool
eqTVar' _ = (==)
{-# INLINE eqTVar' #-}


-- | Polymorphic equality of 'TVar's.
--
eqTVar :: forall m a.
MonadSTM m
=> Proxy m
-> TVar m a -> TVar m a -> Bool
eqTVar p = eqTVar' (f p)
where
f :: Proxy m -> Proxy (STM m)
f Proxy = Proxy


type EqTMVar m tvar tmvar tqueue tbqueuue
= ( tmvar ~ TMVar m
, forall a. Eq (tmvar a)
)

eqTMVar' :: forall stm tvar tmvar tqueue tbqueue a.
MonadSTMCtx stm tvar tmvar tqueue tbqueue
=> Proxy stm
-> tmvar a -> tmvar a -> Bool
eqTMVar' _ = (==)
{-# INLINE eqTMVar' #-}


-- | Polymorphic equality of 'TMVar's.
--
eqTMVar :: forall m a.
MonadSTM m
=> Proxy m
-> TMVar m a -> TMVar m a -> Bool
eqTMVar p = eqTMVar' (f p)
where
f :: Proxy m -> Proxy (STM m)
f Proxy = Proxy


type EqTQueue m tqueue
= ( tqueue ~ TQueue m
, forall a. Eq (tqueue a)
)

eqTQueue' :: forall stm tvar tmvar tqueue tbqueue a.
MonadSTMCtx stm tvar tmvar tqueue tbqueue
=> Proxy stm
-> tqueue a -> tqueue a -> Bool
eqTQueue' _ = (==)
{-# INLINE eqTQueue' #-}


-- | Polymorphic equality of 'TQueue's.
--
eqTQueue :: forall m a.
MonadSTM m
=> Proxy m
-> TQueue m a -> TQueue m a -> Bool
eqTQueue p = eqTQueue' (f p)
where
f :: Proxy m -> Proxy (STM m)
f Proxy = Proxy


type EqTBQueue m tbqueue
= ( tbqueue ~ TBQueue m
, forall a. Eq (tbqueue a)
)

eqTBQueue' :: forall stm tvar tmvar tqueue tbqueue a.
MonadSTMCtx stm tvar tmvar tqueue tbqueue
=> Proxy stm
-> tbqueue a -> tbqueue a -> Bool
eqTBQueue' _ = (==)
{-# INLINE eqTBQueue' #-}


-- | Polymorphic equality of 'TBQueue's.
--
eqTBQueue :: forall m a.
MonadSTM m
=> Proxy m
-> TBQueue m a -> TBQueue m a -> Bool
eqTBQueue p = eqTBQueue' (f p)
where
f :: Proxy m -> Proxy (STM m)
f Proxy = Proxy

-- | Labelled 'TVar's, 'TMVar's, 'TQueue's and 'TBQueue's.
--
Expand Down Expand Up @@ -245,6 +381,8 @@ class (MonadSTM m, MonadLabelledSTMTx (STM m))
-- Instance for IO uses the existing STM library implementations
--

instance MonadSTMVars STM.TVar STM.TMVar STM.TQueue STM.TBQueue

instance MonadSTMTx STM.STM where
type TVar_ STM.STM = STM.TVar
type TMVar_ STM.STM = STM.TMVar
Expand Down Expand Up @@ -354,6 +492,7 @@ instance MonadSTM m => MonadSTM (ReaderT r m) where

newtype TMVarDefault m a = TMVar (TVar m (Maybe a))


labelTMVarDefault
:: MonadLabelledSTM m
=> TMVarDefault m a -> String -> STM m ()
Expand Down Expand Up @@ -446,6 +585,7 @@ isEmptyTMVarDefault (TMVar t) = do
data TQueueDefault m a = TQueue !(TVar m [a])
!(TVar m [a])


labelTQueueDefault
:: MonadLabelledSTM m
=> TQueueDefault m a -> String -> STM m ()
Expand Down Expand Up @@ -518,6 +658,7 @@ data TBQueueDefault m a = TBQueue
!(TVar m [a]) -- written elements
!Natural


labelTBQueueDefault
:: MonadLabelledSTM m
=> TBQueueDefault m a -> String -> STM m ()
Expand Down
18 changes: 17 additions & 1 deletion io-classes/src/Control/Monad/Class/MonadSTM/Strict.hs
Expand Up @@ -3,7 +3,11 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
-- 'Eq' instance of 'StrictTVar'
{-# LANGUAGE UndecidableInstances #-}

-- to preserve 'HasCallstack' constraint on 'checkInvariant'
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
Expand All @@ -14,6 +18,7 @@ module Control.Monad.Class.MonadSTM.Strict
, LazyTMVar
-- * 'StrictTVar'
, StrictTVar
, eqTVar
, labelTVar
, labelTVarIO
, castStrictTVar
Expand Down Expand Up @@ -61,8 +66,10 @@ import Control.Monad.Class.MonadSTM as X hiding (LazyTMVar, LazyTVar,
newTMVarM, newTVar, newTVarIO, newTVarM, putTMVar,
readTMVar, readTVarIO, readTVar, stateTVar, swapTVar,
swapTMVar, takeTMVar, tryPutTMVar, tryReadTMVar,
tryTakeTMVar, writeTVar)
tryTakeTMVar, writeTVar, eqTVar)
import qualified Control.Monad.Class.MonadSTM as Lazy
import Data.Function (on)
import Data.Proxy (Proxy)
import GHC.Stack

{-------------------------------------------------------------------------------
Expand All @@ -88,6 +95,15 @@ newtype StrictTVar m a = StrictTVar
}
#endif

instance Eq (LazyTVar m a) => Eq (StrictTVar m a) where
(==) = on (==) tvar

eqTVar :: forall m a.
MonadSTM m
=> Proxy m
-> StrictTVar m a -> StrictTVar m a -> Bool
eqTVar p = on (Lazy.eqTVar p) tvar

labelTVar :: MonadLabelledSTM m => StrictTVar m a -> String -> STM m ()
labelTVar StrictTVar { tvar } = Lazy.labelTVar tvar

Expand Down
26 changes: 26 additions & 0 deletions io-sim/src/Control/Monad/IOSim/Internal.hs
Expand Up @@ -14,6 +14,7 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE QuantifiedConstraints #-}

{-# OPTIONS_GHC -Wno-orphans #-}
-- incomplete uni patterns in 'schedule' (when interpreting 'StmTxCommitted')
Expand Down Expand Up @@ -58,6 +59,7 @@ import Data.Bifoldable
import Data.Bifunctor
import Data.Dynamic (Dynamic, toDyn)
import Data.Foldable (traverse_)
import Data.Function (on)
import qualified Data.List as List
import qualified Data.List.Trace as Trace
import Data.Map.Strict (Map)
Expand Down Expand Up @@ -349,6 +351,27 @@ instance MonadFork (IOSim s) where
instance MonadSay (STMSim s) where
say msg = STM $ \k -> SayStm msg (k ())

instance Eq (TMVarDefault (IOSim s) a) where
(==) = (==) `on` \(TMVar v) -> v

instance Eq (TQueueDefault (IOSim s) a) where
TQueue read write == TQueue read' write' =
read == read'
&& write == write'

instance Eq (TBQueueDefault (IOSim s) a) where
TBQueue rsize read wsize write size
== TBQueue rsize' read' wsize' write' size' =
rsize == rsize'
&& read == read'
&& wsize == wsize'
&& write == write'
&& size == size'

instance ( tvar ~ TVar_ (STM s)
, forall a. Eq (tvar a)
) => MonadSTMVars tvar (TMVarDefault (IOSim s)) (TQueueDefault (IOSim s)) (TBQueueDefault (IOSim s))

instance MonadSTMTx (STM s) where
type TVar_ (STM s) = TVar s
type TMVar_ (STM s) = TMVarDefault (IOSim s)
Expand Down Expand Up @@ -1373,6 +1396,9 @@ data TVar s a = TVar {
tvarBlocked :: !(STRef s ([ThreadId], Set ThreadId))
}

instance Eq (TVar s a) where
(==) = on (==) tvarId

data StmTxResult s a =
-- | A committed transaction reports the vars that were written (in order
-- of first write) so that the scheduler can unblock other threads that
Expand Down

0 comments on commit d6b3b0b

Please sign in to comment.