-
Notifications
You must be signed in to change notification settings - Fork 4
/
TRef.scala
70 lines (49 loc) · 2.22 KB
/
TRef.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
package com.olegpy.stm
import cats.InvariantMonoidal
import cats.data.State
import cats.effect.Sync
import internal.TRefImpl
import cats.implicits._
import cats.effect.concurrent.Ref
/** A transactional mutable reference, exposed through the cats-effect `Ref` interface
  * with `STM` as the effect type.
  *
  * Only `get`, `set` and `unsafeLastValue` are abstract; every other operation is
  * derived from `get` and `set`. Because all derived operations are themselves `STM`
  * programs, they compose into a single transaction with whatever surrounds them.
  *
  * @tparam A type of the stored value
  */
trait TRef[A] extends Ref[STM, A] {

  /** Reads the current value as an `STM` step. */
  def get: STM[A]

  /** Writes `a` as the new value as an `STM` step. */
  def set(a: A): STM[Unit]

  /** Reads the value, applies `f`, and writes the result back. */
  def update(f: A => A): STM[Unit] =
    get.flatMap(current => set(f(current)))

  /** Like [[update]], but the new value is produced by another `STM` step. */
  def updateF(f: A => STM[A]): STM[Unit] =
    get.flatMap(f).flatMap(set)

  /** Writes `f(current)` when `f` is defined at the current value.
    *
    * NOTE(review): what happens when `f` is NOT defined depends on `STM`'s `collect`,
    * which is not visible here; the name suggests the transaction retries — confirm.
    */
  def updOrRetry(f: PartialFunction[A, A]): STM[Unit] =
    get.collect(f).flatMap(set)

  /** Writes `a` and yields the value that was stored before the write. */
  def getAndSet(a: A): STM[A] =
    get.flatMap(previous => set(a).as(previous))

  /** Yields a snapshot of the value together with a setter.
    *
    * Unlike a plain `Ref`, the setter does not compare against the snapshot: it always
    * writes and always reports `true`.
    */
  def access: STM[(A, A => STM[Boolean])] =
    get.map(snapshot => (snapshot, (next: A) => set(next).as(true)))

  /** Same as [[update]], always reporting success (`true`). */
  def tryUpdate(f: A => A): STM[Boolean] =
    update(f).map(_ => true)

  /** Same as [[modify]], always reporting success by wrapping the result in `Some`. */
  def tryModify[B](f: A => (A, B)): STM[Option[B]] =
    modify(f).map(result => Some(result): Option[B])

  /** Applies `f` to the current value, stores the first component of its result and
    * yields the second.
    */
  def modify[B](f: A => (A, B)): STM[B] =
    get.map(f).flatMap { case (next, result) => set(next).as(result) }

  /** Same as [[modifyState]], always reporting success by wrapping the result in `Some`. */
  def tryModifyState[B](state: State[A, B]): STM[Option[B]] =
    modifyState(state).map(result => Some(result): Option[B])

  /** Runs the `State` action against the current value via [[modify]]. */
  def modifyState[B](state: State[A, B]): STM[B] =
    modify(current => state.run(current).value)

  // NOTE(review): presumably the most recently stored value, read outside any
  // transaction; within this file it is used only for `toString` and by the derived
  // instances in the companion object — confirm semantics in the implementation.
  protected[stm] def unsafeLastValue: A

  override def toString: String = s"TRef($unsafeLastValue)"
}
/** Constructors and type-class instances for [[TRef]]. */
object TRef {

  /** Allocates a new [[TRef]] holding `initial`, as an `STM` step. */
  def apply[A](initial: A): STM[TRef[A]] = STM.delay(new TRefImpl(initial))

  /** Entry point for allocating a [[TRef]] directly in an effect `F`,
    * e.g. `TRef.in[F](initial)`.
    */
  def in[F[_]]: InPartiallyApplied[F] = new InPartiallyApplied[F]

  /** Partially-applied builder returned by [[TRef.in]].
    *
    * The unused `dummy` field exists only so this can be a value class (`AnyVal`),
    * which needs exactly one field.
    */
  final class InPartiallyApplied[F[_]](private val dummy: Boolean = false) extends AnyVal {

    /** Allocates a [[TRef]] holding `initial` by converting the `STM` allocation
      * into `F` with `STM.unsafeToSync`.
      */
    def apply[A](initial: A)(implicit F: Sync[F]): F[TRef[A]] = {
      val allocation: STM[TRef[A]] = TRef(initial)
      STM.unsafeToSync(allocation)
    }
  }

  /** `TRef` is an invariant monoidal functor: refs can be mapped through an
    * isomorphism (`imap`) and paired (`product`), with a stateless `unit` ref.
    */
  implicit val invariantMonoidal: InvariantMonoidal[TRef] = new InvariantMonoidal[TRef] {

    // Holds no state: reading yields `()` and writing is a no-op.
    val unit: TRef[Unit] = new TRef[Unit] {
      def get: STM[Unit] = STM.unit
      def set(a: Unit): STM[Unit] = STM.unit
      protected[stm] def unsafeLastValue: Unit = ()
    }

    // A view of `fa`: reads go through `f`, writes go back through `g`.
    def imap[A, B](fa: TRef[A])(f: A => B)(g: B => A): TRef[B] = new TRef[B] {
      def get: STM[B] = fa.get.map(f)
      def set(a: B): STM[Unit] = fa.set(g(a))
      protected[stm] def unsafeLastValue: B = f(fa.unsafeLastValue)
    }

    // Reads both refs together; writes the first component to `fa`, then the second to `fb`.
    def product[A, B](fa: TRef[A], fb: TRef[B]): TRef[(A, B)] = new TRef[(A, B)] {
      def get: STM[(A, B)] = fa.get.product(fb.get)
      def set(a: (A, B)): STM[Unit] = {
        val (left, right) = a
        fa.set(left) *> fb.set(right)
      }
      protected[stm] def unsafeLastValue: (A, B) = fa.unsafeLastValue -> fb.unsafeLastValue
    }
  }
}