Skip to content

Commit

Permalink
MonadSTM: added TArray
Browse files Browse the repository at this point in the history
  • Loading branch information
coot committed Sep 23, 2022
1 parent 41f7240 commit 38a24f0
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 10 deletions.
1 change: 1 addition & 0 deletions io-classes/io-classes.cabal
Expand Up @@ -52,6 +52,7 @@ library
ScopedTypeVariables
RankNTypes
build-depends: base >=4.9 && <4.18,
array,
async >=2.1,
bytestring,
deque,
Expand Down
101 changes: 92 additions & 9 deletions io-classes/src/Control/Monad/Class/MonadSTM.hs
Expand Up @@ -2,6 +2,7 @@
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
Expand Down Expand Up @@ -31,6 +32,8 @@ module Control.Monad.Class.MonadSTM
, TQueueDefault (..)
-- * Default 'TBQueue' implementation
, TBQueueDefault (..)
-- * Default 'TArray' implementation
, TArrayDefault (..)
-- * MonadThrow aliases
, throwSTM
, catchSTM
Expand All @@ -46,6 +49,7 @@ module Control.Monad.Class.MonadSTM

import Prelude hiding (read)

import qualified Control.Concurrent.STM.TArray as STM
import qualified Control.Concurrent.STM.TBQueue as STM
import qualified Control.Concurrent.STM.TMVar as STM
import qualified Control.Concurrent.STM.TQueue as STM
Expand All @@ -65,7 +69,13 @@ import qualified Control.Monad.Class.MonadThrow as MonadThrow

import Control.Applicative (Alternative (..))
import Control.Exception
import Data.Array (Array, bounds)
import qualified Data.Array as Array
import Data.Array.Base (IArray (numElements), MArray (..),
arrEleBottom, listArray, unsafeAt)
import Data.Foldable (traverse_)
import Data.Function (on)
import Data.Ix (Ix, rangeSize)
import Data.Kind (Type)
import Data.Typeable (Typeable)
import GHC.Stack
Expand Down Expand Up @@ -148,6 +158,8 @@ class ( Monad m
isFullTBQueue :: TBQueue m a -> STM m Bool
unGetTBQueue :: TBQueue m a -> a -> STM m ()

type TArray m :: Type -> Type -> Type

-- Helpful derived functions with default implementations

newTVarIO :: a -> m (TVar m a)
Expand Down Expand Up @@ -314,15 +326,19 @@ newEmptyTMVarM = newEmptyTMVarIO
--
class MonadSTM m
=> MonadLabelledSTM m where
labelTVar :: TVar m a -> String -> STM m ()
labelTMVar :: TMVar m a -> String -> STM m ()
labelTQueue :: TQueue m a -> String -> STM m ()
labelTBQueue :: TBQueue m a -> String -> STM m ()

labelTVarIO :: TVar m a -> String -> m ()
labelTMVarIO :: TMVar m a -> String -> m ()
labelTQueueIO :: TQueue m a -> String -> m ()
labelTBQueueIO :: TBQueue m a -> String -> m ()
labelTVar :: TVar m a -> String -> STM m ()
labelTMVar :: TMVar m a -> String -> STM m ()
labelTQueue :: TQueue m a -> String -> STM m ()
labelTBQueue :: TBQueue m a -> String -> STM m ()
labelTArray :: (Ix i, Show i)
=> TArray m i e -> String -> STM m ()

labelTVarIO :: TVar m a -> String -> m ()
labelTMVarIO :: TMVar m a -> String -> m ()
labelTQueueIO :: TQueue m a -> String -> m ()
labelTBQueueIO :: TBQueue m a -> String -> m ()
labelTArrayIO :: (Ix i, Show i)
=> TArray m i e -> String -> m ()

--
-- default implementations
Expand All @@ -340,6 +356,13 @@ class MonadSTM m
=> TBQueue m a -> String -> STM m ()
labelTBQueue = labelTBQueueDefault

default labelTArray :: ( TArray m ~ TArrayDefault m
, Ix i
, Show i
)
=> TArray m i e -> String -> STM m ()
labelTArray = labelTArrayDefault

default labelTVarIO :: TVar m a -> String -> m ()
labelTVarIO = \v l -> atomically (labelTVar v l)

Expand All @@ -352,6 +375,10 @@ class MonadSTM m
default labelTBQueueIO :: TBQueue m a -> String -> m ()
labelTBQueueIO = \v l -> atomically (labelTBQueue v l)

default labelTArrayIO :: (Ix i, Show i)
=> TArray m i e -> String -> m ()
labelTArrayIO = \v l -> atomically (labelTArray v l)


-- | This type class is indented for 'io-sim', where one might want to access
-- 'TVar' in the underlying 'ST' monad.
Expand Down Expand Up @@ -511,6 +538,7 @@ instance MonadSTM IO where
type TMVar IO = STM.TMVar
type TQueue IO = STM.TQueue
type TBQueue IO = STM.TBQueue
type TArray IO = STM.TArray

newTVar = STM.newTVar
readTVar = STM.readTVar
Expand Down Expand Up @@ -566,11 +594,13 @@ instance MonadLabelledSTM IO where
labelTMVar = \_ _ -> return ()
labelTQueue = \_ _ -> return ()
labelTBQueue = \_ _ -> return ()
labelTArray = \_ _ -> return ()

labelTVarIO = \_ _ -> return ()
labelTMVarIO = \_ _ -> return ()
labelTQueueIO = \_ _ -> return ()
labelTBQueueIO = \_ _ -> return ()
labelTArrayIO = \_ _ -> return ()

-- | noop instance
--
Expand Down Expand Up @@ -910,6 +940,47 @@ unGetTBQueueDefault (TBQueue rsize read wsize _write _size) a = do
writeTVar read (a:xs)


--
-- Default `TArray` implementation
--

-- | Default implementation of 'TArray'.
--
data TArrayDefault m i e = TArray (Array i (TVar m e))
deriving Typeable

deriving instance (Eq (TVar m e), Ix i) => Eq (TArrayDefault m i e)

instance (Monad stm, MonadSTM m, stm ~ STM m)
=> MArray (TArrayDefault m) e stm where
getBounds (TArray a) = return (bounds a)
newArray b e = do
a <- rep (rangeSize b) (newTVar e)
return $ TArray (listArray b a)
newArray_ b = do
a <- rep (rangeSize b) (newTVar arrEleBottom)
return $ TArray (listArray b a)
unsafeRead (TArray a) i = readTVar $ unsafeAt a i
unsafeWrite (TArray a) i e = writeTVar (unsafeAt a i) e
getNumElements (TArray a) = return (numElements a)

rep :: Monad m => Int -> m a -> m [a]
rep n m = go n []
where
go 0 xs = return xs
go i xs = do
x <- m
go (i-1) (x:xs)

labelTArrayDefault :: ( MonadLabelledSTM m
, Ix i
, Show i
)
=> TArrayDefault m i e -> String -> STM m ()
labelTArrayDefault (TArray arr) name = do
let as = Array.assocs arr
traverse_ (\(i, v) -> labelTVar v (name ++ ":" ++ show i)) as

-- | 'throwIO' specialised to @stm@ monad.
--
throwSTM :: (MonadSTM m, MonadThrow.MonadThrow (STM m), Exception e)
Expand Down Expand Up @@ -1021,6 +1092,8 @@ instance MonadSTM m => MonadSTM (ContT r m) where
isFullTBQueue = WrappedSTM . isFullTBQueue
unGetTBQueue = WrappedSTM .: unGetTBQueue

type TArray (ContT r m) = TArray m


instance MonadSTM m => MonadSTM (ReaderT r m) where
type STM (ReaderT r m) = WrappedSTM Reader r m
Expand Down Expand Up @@ -1074,6 +1147,8 @@ instance MonadSTM m => MonadSTM (ReaderT r m) where
isFullTBQueue = WrappedSTM . isFullTBQueue
unGetTBQueue = WrappedSTM .: unGetTBQueue

type TArray (ReaderT r m) = TArray m


instance (Monoid w, MonadSTM m) => MonadSTM (WriterT w m) where
type STM (WriterT w m) = WrappedSTM Writer w m
Expand Down Expand Up @@ -1127,6 +1202,8 @@ instance (Monoid w, MonadSTM m) => MonadSTM (WriterT w m) where
isFullTBQueue = WrappedSTM . isFullTBQueue
unGetTBQueue = WrappedSTM .: unGetTBQueue

type TArray (WriterT w m) = TArray m


instance MonadSTM m => MonadSTM (StateT s m) where
type STM (StateT s m) = WrappedSTM State s m
Expand Down Expand Up @@ -1180,6 +1257,8 @@ instance MonadSTM m => MonadSTM (StateT s m) where
isFullTBQueue = WrappedSTM . isFullTBQueue
unGetTBQueue = WrappedSTM .: unGetTBQueue

type TArray (StateT s m) = TArray m


instance MonadSTM m => MonadSTM (ExceptT e m) where
type STM (ExceptT e m) = WrappedSTM Except e m
Expand Down Expand Up @@ -1233,6 +1312,8 @@ instance MonadSTM m => MonadSTM (ExceptT e m) where
isFullTBQueue = WrappedSTM . isFullTBQueue
unGetTBQueue = WrappedSTM .: unGetTBQueue

type TArray (ExceptT e m) = TArray m


instance (Monoid w, MonadSTM m) => MonadSTM (RWST r w s m) where
type STM (RWST r w s m) = WrappedSTM RWS (r, w, s) m
Expand Down Expand Up @@ -1286,6 +1367,8 @@ instance (Monoid w, MonadSTM m) => MonadSTM (RWST r w s m) where
isFullTBQueue = WrappedSTM . isFullTBQueue
unGetTBQueue = WrappedSTM .: unGetTBQueue

type TArray (RWST r w s m) = TArray m


(.:) :: (c -> d) -> (a -> b -> c) -> (a -> b -> d)
(f .: g) x y = f (g x y)
3 changes: 2 additions & 1 deletion io-sim/src/Control/Monad/IOSim/Types.hs
Expand Up @@ -76,7 +76,7 @@ import Control.Monad.Class.MonadMVar
import Control.Monad.Class.MonadST
import Control.Monad.Class.MonadSTM (MonadInspectSTM (..),
MonadLabelledSTM (..), MonadSTM, MonadTraceSTM (..),
TMVarDefault, TraceValue)
TArrayDefault, TMVarDefault, TraceValue)
import qualified Control.Monad.Class.MonadSTM as MonadSTM
import Control.Monad.Class.MonadSay
import Control.Monad.Class.MonadTest
Expand Down Expand Up @@ -398,6 +398,7 @@ instance MonadSTM (IOSim s) where
type TMVar (IOSim s) = TMVarDefault (IOSim s)
type TQueue (IOSim s) = TQueueDefault (IOSim s)
type TBQueue (IOSim s) = TBQueueDefault (IOSim s)
type TArray (IOSim s) = TArrayDefault (IOSim s)

atomically action = IOSim $ oneShot $ \k -> Atomically action k

Expand Down

0 comments on commit 38a24f0

Please sign in to comment.