**A simple bandit algorithm**

*Initialize, for $a = 1$ to $k$:*

   $$Q(a) \leftarrow 0$$
   $$N(a) \leftarrow 0$$

*Repeat forever*:

$$
\begin{aligned}
A &\leftarrow
\begin{cases}
\operatorname{argmax}_a Q(a), & \text{with probability } 1 - \epsilon \text{ (breaking ties randomly)}\\
\text{a random action}, & \text{with probability } \epsilon
\end{cases}\\
R &\leftarrow \text{bandit}(A)\\
N(A) &\leftarrow N(A) + 1\\
Q(A) &\leftarrow Q(A) + \frac{1}{N(A)}\bigl[R - Q(A)\bigr]
\end{aligned}
$$


In [2]:
type Action = Int     // index of an arm, in 0 until nActions
type Reward = Int
type NthTry = Int     // a step / pull count
type Estimate = Double
type Epsilon = Double // probability of exploring on a given step

/**
 * Epsilon-greedy learner with sample-average value estimates (the "simple bandit algorithm").
 *
 * Each step picks a uniformly random arm with probability `eps`, otherwise the arm with the
 * highest current estimate (ties broken uniformly at random). It pulls that arm and updates
 * only that arm's estimate incrementally:
 *   N(A) <- N(A) + 1;  Q(A) <- Q(A) + (R - Q(A)) / N(A)
 *
 * Override `eps`, `nActions`, `epoch` or `bandit` to configure an experiment.
 */
trait SimpleBanditAlgorithm {

  /** Steps to simulate: `results` applies `step` once per element. */
  def epoch = (1 to 2000)

  /** Probability of choosing a random (exploratory) action instead of the greedy one. */
  def eps: Epsilon = 0.1

  /** Number of arms; valid actions are `0 until nActions`. */
  def nActions = 10

  import scala.util.Random

  /** Reward for pulling arm `a`; currently a fixed table (stochastic variants left commented out). */
  def bandit(a: Action): Reward = a match {
    case 0 => 5 //Random.nextInt(20)
    case 5 => 98 //Random.nextInt(maxReward)
    case 7 => 14
    case 9 => 70 //Random.nextInt(30)
    case _ => 60 //Random.nextInt(maxReward)
  }

  /** A uniformly random arm in `0 until nActions`. */
  def randomAction: Action = Random.nextInt(nActions)

  /** True with probability `eps`, i.e. "explore on this step". */
  def distribution(eps: Epsilon): Boolean = Random.nextDouble() < eps

  /**
   * Greedy choice: an arm with the largest estimate, ties broken uniformly at random.
   *
   * Arms missing from `values` count as estimate 0.0, matching the Q(a) <- 0 initialisation,
   * so an untried arm can still win against arms whose estimates are negative.
   */
  def argmax(values: Map[Action, Estimate]): Action = {
    val estimates = (0 until nActions).map(a => a -> 0.0).toMap ++ values
    if (estimates.isEmpty) randomAction
    else {
      val best = estimates.valuesIterator.max
      // Sorted so that the pick depends only on the RNG, not on Map iteration order.
      val tied = estimates.collect { case (a, q) if q == best => a }.toVector.sorted
      tied(Random.nextInt(tied.size))
    }
  }

  /**
   * Learner state.
   *
   * @param n      total number of steps taken so far
   * @param accR   every reward received, most recent first
   * @param q      current value estimate Q(a) for each arm pulled at least once
   * @param counts N(a): how many times each arm has been pulled
   */
  final case class State(
      n: NthTry = 0,
      accR: List[Reward] = List.empty,
      q: Map[Action, Estimate] = Map.empty,
      counts: Map[Action, NthTry] = Map.empty
  )

  /** Performs one select-pull-update iteration and returns the successor state. */
  def step(state: State): State = {
    val action = if (distribution(eps)) randomAction else argmax(state.q)
    val reward = bandit(action)
    // The sample-average step size is 1 / N(A) for the pulled arm, not 1 / (total steps).
    val pulls = state.counts.getOrElse(action, 0) + 1
    val previousEstimate = state.q.getOrElse(action, 0.0)
    val updatedEstimate = previousEstimate + (reward - previousEstimate) / pulls
    State(
      n = state.n + 1,
      accR = reward :: state.accR,
      q = state.q.updated(action, updatedEstimate),
      counts = state.counts.updated(action, pulls)
    )
  }

  /**
   * The initial state followed by the state after each step (`epoch.size + 1` entries).
   * Lazy so that members overridden in a subclass are initialised before the simulation runs.
   */
  lazy val results = epoch
    .scanLeft(State())((state, _) => step(state))

}

defined type Action
defined type Reward
defined type NthTry
defined type Estimate
defined type Epsilon
defined trait SimpleBanditAlgorithm

Average reward per step for different values of $\epsilon$:

In [9]:
//------------------------Plotly---------------------------

import $ivy.`org.plotly-scala::plotly-jupyter-scala:0.3.0`

import plotly._
import plotly.element._
import plotly.layout._
import plotly.JupyterScala._

// Initialise plotly-scala's Jupyter support before any chart is drawn
// (NOTE(review): presumably loads the plotly.js front-end; confirm in the plotly-scala docs).
plotly.JupyterScala.init()

/** Creates a fresh bandit experiment that explores with probability `epsilon`. */
def algByEps(epsilon: Epsilon) = new SimpleBanditAlgorithm {
  override def eps = epsilon
}

/**
 * One Plotly trace per epsilon: x = step number, y = mean reward received up to that step.
 *
 * The mean uses Double division (Reward is Int, so `sum / size` would truncate) over the number
 * of rewards actually received. The initial state has no rewards yet and is plotted as 0.0.
 */
val plot = Seq(0.1, 0.01, 0.0).map { eps =>
  val states = algByEps(eps).results
  val steps = states.map(_.n)
  val meanRewards = states.map { s =>
    if (s.accR.isEmpty) 0.0 else s.accR.sum.toDouble / s.accR.size
  }
  Scatter(steps, meanRewards, name = eps.toString)
}

/** Renders the Plotly chart of all traces in the notebook. */
def drawPlot(): Unit = plot.plot()

//----------------------Vegas---------------------------------

import $ivy.`org.vegas-viz::vegas:0.3.6`

import vegas._
import vegas.render.HTMLRenderer._

// NOTE(review): presumably picked up implicitly by `.show` (via vegas.render.HTMLRenderer._)
// to publish the rendered chart HTML as notebook output; confirm against the Vegas docs.
implicit val displayer: String => Unit = publish.html(_)


/**
 * Long-format rows for Vegas: one row per (epsilon, step) holding the mean reward so far.
 *
 * The mean uses Double division over the rewards actually received (0.0 before the first step).
 * `algByEps` starts a new random run, so these series are an independent sample rather than
 * the exact data plotted with Plotly.
 */
val vegasPlotData = for {
  eps <- Seq(0.1, 0.01, 0.0)
  s <- algByEps(eps).results
} yield {
  val meanReward = if (s.accR.isEmpty) 0.0 else s.accR.sum.toDouble / s.accR.size
  Map(
    "eps" -> eps.toString,
    "step" -> s.n,
    "average reward" -> meanReward
  )
}

/** Line chart of mean reward against step, one coloured line per epsilon. */
def drawPlotWithVegas(): Unit = Vegas("Plot")
  .withData(vegasPlotData)
  .encodeX("step", Quant)
  .encodeY("average reward", Quant)
  .encodeDetailFields(Field(field = "eps", dataType = Nominal))
  .encodeColor(
       field = "eps",
       dataType = Nominal,
       legend = vegas.Legend(orient = "left", title = "epsilon"))
  .mark(vegas.Line)
  .show

[32mimport [39m[36m$ivy.$                                             

[39m
[32mimport [39m[36mplotly._
[39m
[32mimport [39m[36mplotly.element._
[39m
[32mimport [39m[36mplotly.layout._
[39m
[32mimport [39m[36mplotly.JupyterScala._

[39m
defined [32mfunction[39m [36malgByEps[39m
[36mplot[39m: [32mSeq[39m[[32mScatter[39m] = [33mList[39m(
  [33mScatter[39m(
    [33mSome[39m(
      [33mDoubles[39m(
        [33mVector[39m(
          [32m0.0[39m,
          [32m1.0[39m,
          [32m2.0[39m,
          [32m3.0[39m,
          [32m4.0[39m,
          [32m5.0[39m,
          [32m6.0[39m,
[33m...[39m
defined [32mfunction[39m [36mdrawPlot[39m
[32mimport [39m[36m$ivy.$                           

[39m
[32mimport [39m[36mvegas._
[39m
[32mimport [39m[36mvegas.render.HTMLRenderer._

[39m
[36mdisplayer[39m: [32mString[39m => [32mUnit[39m = <function1>
[36mvegasPlotData[39m: [32mSeq[39m[[32mMap[39m[[32mString[39m, [32m

In [4]:
// Render the Plotly chart built above (one trace per epsilon value).
drawPlot()

The same plot rendered with the Vegas library:

In [10]:
// Render the Vegas line chart of average reward by step, coloured by epsilon.
drawPlotWithVegas()