# Towards an Applicative for macro

(This post is a Jupyter notebook that you can download and run. You can download it [here]().)

Some time ago, with the release of Haskell's [Haxl](https://hackage.haskell.org/package/haxl) library, its authors realized that the `do` notation in Haskell could use some improvements. The `do` notation can be used to "flatten" deep nestings of `>>=` applications. For example, instead of writing something like this:

```haskell
findVal :: String -> Maybe Int
findVal key = ... -- some function definition

sum :: Maybe Int
sum = (findVal "key1") >>= (\val1 -> (findVal "key2") >>= (\val2 -> return (val1 + val2)))
```
you can write:

```haskell
sum :: Maybe Int
sum = do val1 <- findVal "key1"
         val2 <- findVal "key2"
         return (val1 + val2)
```

This makes the expression easier to read and understand. The Haskell compiler simply desugars this notation into `>>=` applications. The `do` notation is especially readable because you can interpret it as follows: on the left side of a `<-` you find the value "extracted" from the monadic value on the right side.

Now as you may have noticed in the previous example we are not using the whole power of the `Monad` typeclass. We are just calling `>>=` because we want to have the values inside the monadic values in the same place so we can express our desired computation: `(val1 + val2)`. But `Monad` is more powerful than that! For example we can express something like this:

```haskell
do val1 <- findVal "key1"
   val2 <- findVal ("key" ++ show val1)
   return val2
```

In this case the second monadic value depends on the result of the first monadic value. Unlike the previous example this can only be achieved with something as powerful as `Monad`.

When programming we want to use the least powerful abstraction and in functional programming this translates into using the least specific typeclass. In this case the least powerful abstraction that allows us to join values in different contexts into one is `Applicative`. We can rewrite the first example like this:

```haskell
add x y = x + y

sum = add <$> (findVal "key1") <*> (findVal "key2")
```

If you are thinking this is less readable than the `do` version you are on to something. As it turns out, for the `Maybe` monad there is not much difference between the two styles because both perform similarly. But there are other monads for which it is preferable to use `Applicative` when possible. The first one that comes to mind is precisely Haxl's `Fetch` monad. With `Fetch`, independent computations that are combined applicatively can be run concurrently. Moreover, if they access the same data source, the queries can be batched.

For instance in the following situation:

```haskell
fetchData :: String -> Fetch Data
fetchData key = ... -- build fetch value

join x y = (x,y)

dataTuple :: Fetch (Data,Data)
dataTuple = join <$> fetchData "key1" <*> fetchData "key2"
```

The same final result can be expressed with `do`:

```haskell
do data1 <- fetchData "key1"
   data2 <- fetchData "key2"
   return (data1,data2)
```

But there is a catch: because the `do` expression is desugared into `>>=` applications, the fetches are done sequentially. The first monadic value must be computed before the one in the second line. Thus, by choosing the more intuitive and readable notation we incur a performance loss.
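To make the difference concrete, here is a toy model of a batching data source, written in Scala (the language used in the rest of this post). It is only a sketch under simplifying assumptions: the `Fetch`, `Done` and `Blocked` types, the `pure`, `flatMap`, `map2` and `run` functions and the in-memory `db` are all made up for illustration and are not Haxl's actual API. A computation is either finished or blocked on a set of keys, and `run` counts how many round trips to the data source are needed.

```scala
object FetchSketch {
  // A fetch is either finished or blocked waiting for some keys.
  sealed trait Fetch[A]
  final case class Done[A](value: A) extends Fetch[A]
  final case class Blocked[A](keys: Set[String], resume: Map[String, String] => Fetch[A]) extends Fetch[A]

  def pure[A](a: A): Fetch[A] = Done(a)

  def fetchData(key: String): Fetch[String] =
    Blocked(Set(key), data => Done(data(key)))

  // Monadic bind: the continuation `f` cannot run until `fa` has a value,
  // so the requests it makes land in a later round.
  def flatMap[A, B](fa: Fetch[A])(f: A => Fetch[B]): Fetch[B] = fa match {
    case Done(a)               => f(a)
    case Blocked(keys, resume) => Blocked(keys, data => flatMap(resume(data))(f))
  }

  // Applicative combination: both sides are known up front,
  // so their pending keys can be merged into a single batch.
  def map2[A, B, C](fa: Fetch[A], fb: Fetch[B])(f: (A, B) => C): Fetch[C] = (fa, fb) match {
    case (Done(a), Done(b))                 => Done(f(a, b))
    case (Blocked(ka, ra), Blocked(kb, rb)) => Blocked(ka ++ kb, data => map2(ra(data), rb(data))(f))
    case (Blocked(ka, ra), done)            => Blocked(ka, data => map2(ra(data), done)(f))
    case (done, Blocked(kb, rb))            => Blocked(kb, data => map2(done, rb(data))(f))
  }

  // Runs a fetch against an in-memory "database" and returns the result
  // together with the number of round trips that were needed.
  def run[A](fetch: Fetch[A], db: Map[String, String]): (A, Int) = {
    @scala.annotation.tailrec
    def loop(current: Fetch[A], rounds: Int): (A, Int) = current match {
      case Done(a)               => (a, rounds)
      case Blocked(keys, resume) => loop(resume(db.filter { case (k, _) => keys(k) }), rounds + 1)
    }
    loop(fetch, 0)
  }

  val db: Map[String, String] = Map("key1" -> "a", "key2" -> "b")

  val applicative: Fetch[(String, String)] =
    map2(fetchData("key1"), fetchData("key2"))((x, y) => (x, y))
  val monadic: Fetch[(String, String)] =
    flatMap(fetchData("key1"))(x => flatMap(fetchData("key2"))(y => pure((x, y))))

  val (applicativeResult, applicativeRounds) = run(applicative, db)
  val (monadicResult, monadicRounds) = run(monadic, db)
}
```

Both versions produce the same tuple, but in this model the applicative one needs a single round trip while the monadic one needs two.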

But the problem is not just readability. Whenever you change a codebase you must stay aware of whether the change allows you to use `Applicative` instead of `Monad`. This may happen, for example, when you change a single line of a `do` expression.

Wouldn't it be cool if `do` expressions used `Monad` only when strictly necessary and `Applicative` wherever possible? This would allow us to always use the `do` notation without having to worry about whether we are using the correct typeclass. This is precisely the topic of Haxl's follow-up [paper](http://research.microsoft.com/en-us/um/people/simonpj/papers/list-comp/applicativedo.pdf). There is a [proposal](https://ghc.haskell.org/trac/ghc/wiki/ApplicativeDo) in the Haskell community to include this behaviour in GHC.

Now, let's see in Scala how useful this could be. First remember that Scala's `for` comprehensions are similar to Haskell's `do` expressions: they just call `flatMap` for each `<-` except for the last one which will be a `map` call.
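As a quick check of that desugaring, here is a minimal sketch using the standard library's `Option`. The object and value names are only for illustration.

```scala
object ForDesugaring {
  val a: Option[Int] = Some(1)
  val b: Option[Int] = Some(2)

  // The for comprehension as we would normally write it.
  val sugared: Option[Int] = for {
    x <- a
    y <- b
  } yield x + y

  // The same expression written by hand: one `flatMap` per `<-`,
  // except for the last generator, which becomes a `map`.
  val desugared: Option[Int] = a.flatMap(x => b.map(y => x + y))
}
```

Both values evaluate to `Some(3)`.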

Let's start with our very own definition of a `Validation`, which will be similar to Scalaz's `ValidationNel` or Cats' `ValidatedNel`:

In [1]:
sealed trait Validation[+A] {
    def flatMap[B](f: A => Validation[B]): Validation[B] = this match {
        case Success(value) => f(value)
        case Failure(error) => Failure(error)
    }
    def map[B](f: A => B): Validation[B] = flatMap(a => Success(f(a)))
}
case class Success[A](value: A) extends Validation[A]
case class Failure(errors: List[String]) extends Validation[Nothing]

defined [32mtrait [36mValidation[0m
defined [32mclass [36mSuccess[0m
defined [32mclass [36mFailure[0m

Now let's define the `Applicative` typeclass and then describe the instance for `Validation`:

In [2]:
trait Applicative[F[_]] {
    def pure[A](a: A): F[A]
    def map2[A,B,C](fa: F[A], fb: F[B])(f: (A,B) => C): F[C]
}

implicit object ValidationApplicative extends Applicative[Validation] {
    def pure[A](a: A) = Success(a)
    def map2[A,B,C](va: Validation[A], vb: Validation[B])(f: (A,B) => C): Validation[C] = {
        (va, vb) match {
            case (Success(a) , Success(b) ) => Success(f(a,b))
            case (Failure(ea), Failure(eb)) => Failure(ea ++ eb)
            case (Failure(ea), _          ) => Failure(ea)
            case (_          , Failure(eb)) => Failure(eb)            
        }
    }
}

defined [32mtrait [36mApplicative[0m
defined [32mobject [36mValidationApplicative[0m

Our definition of `Applicative` is a little bit different from the usual formulation which describes a function `ap` (with type `F[A => B] => F[A] => F[B]`). As it turns out both formulations are equivalent: you can convince yourself by implementing one in terms of the other.
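One way to carry out that exercise is sketched below. The two traits are hypothetical names introduced only for this comparison: `Map2Applicative` takes `map2` as primitive and derives `ap`, while `ApApplicative` takes `ap` as primitive and derives `map2`. `Option` is used for the instances to keep the example self-contained.

```scala
object ApVsMap2 {
  // Formulation used in this post: `pure` and `map2` are primitive.
  trait Map2Applicative[F[_]] {
    def pure[A](a: A): F[A]
    def map2[A, B, C](fa: F[A], fb: F[B])(f: (A, B) => C): F[C]

    // `ap` derived from `map2`: combine the wrapped function with the wrapped argument.
    def ap[A, B](ff: F[A => B])(fa: F[A]): F[B] =
      map2(ff, fa)((f, a) => f(a))
  }

  // The usual formulation: `pure` and `ap` are primitive.
  trait ApApplicative[F[_]] {
    def pure[A](a: A): F[A]
    def ap[A, B](ff: F[A => B])(fa: F[A]): F[B]

    // `map2` derived from `ap`: lift the curried function, then apply it twice.
    def map2[A, B, C](fa: F[A], fb: F[B])(f: (A, B) => C): F[C] = {
      val curried: F[A => B => C] = pure((a: A) => (b: B) => f(a, b))
      ap(ap(curried)(fa))(fb)
    }
  }

  val optionViaMap2: Map2Applicative[Option] = new Map2Applicative[Option] {
    def pure[A](a: A): Option[A] = Some(a)
    def map2[A, B, C](fa: Option[A], fb: Option[B])(f: (A, B) => C): Option[C] =
      (fa, fb) match {
        case (Some(a), Some(b)) => Some(f(a, b))
        case _                  => None
      }
  }

  val optionViaAp: ApApplicative[Option] = new ApApplicative[Option] {
    def pure[A](a: A): Option[A] = Some(a)
    def ap[A, B](ff: Option[A => B])(fa: Option[A]): Option[B] =
      (ff, fa) match {
        case (Some(f), Some(a)) => Some(f(a))
        case _                  => None
      }
  }

  // Each instance only implements one primitive and gets the other for free.
  val viaDerivedAp: Option[Int] = optionViaMap2.ap(Option((n: Int) => n + 1))(Option(41))
  val viaDerivedMap2: Option[Int] = optionViaAp.map2(Option(1), Option(2))(_ + _)
}
```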

Another thing to notice is that when a chain of `flatMap` calls fails, it fails with the errors of the first failed `Validation` value only. In contrast, the `Applicative` instance accumulates the errors when both values have failed.

Let's see an example. First with `flatMap`:

In [3]:
val v1: Validation[Int] = Failure(List("error1"))
val v2: Validation[Int] = Failure(List("error2"))

val withFlatMap = for {
    x <- v1
    y <- v2
} yield x + y

[36mv1[0m: [32mValidation[0m[[32mInt[0m] = Failure(List(error1))
[36mv2[0m: [32mValidation[0m[[32mInt[0m] = Failure(List(error2))
[36mwithFlatMap[0m: [32mValidation[0m[[32mInt[0m] = Failure(List(error1))

In this case the error in the result value is the one in the first validation. But given that the value `v2` doesn't depend on the value `x`, couldn't we also report the failure of `v2`? Let's see what `Applicative` can do:

In [4]:
val withApplicative = ValidationApplicative.map2(v1,v2)(_ + _)

withApplicative: Validation[Int] = Failure(List(error1, error2))

Now, let's imagine we have a web form with a bunch of fields, each of which has to be validated. When the form is submitted and contains errors, we don't want to bother the user by reporting just the first error. We would like to report as many independent errors as possible. A first attempt with a plain for comprehension would look like this:

```scala
for {
    okFirstName <- validateFirstNameField
    okLastName  <- validateLastNameField
    okFullName  <- validateFullName(okFirstName, okLastName)
    okAge       <- validateAge
} yield NewUserData(okFirstName, okLastName, okFullName, okAge)
```

Using a plain for comprehension like this would be a mistake: as soon as one validation fails, only that error is returned. We can use `Applicative` and some syntactic sugar like the one in Scalaz to get something like this:

```scala
for {
    (okFirstName, okLastName, okAge) <- (validateFirstNameField |@| validateLastNameField |@| validateAge).tupled
    okFullName                       <- validateFullName(okFirstName, okLastName)
} yield NewUserData(okFirstName, okLastName, okFullName, okAge)
```

This works but it may be less readable. More importantly this is more brittle to future changes and is coupled to the current computation structure. You can imagine what may happen with more fields and more complex dependencies between those fields.
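The same structure can be spelled out without Scalaz, using only `map2`. The sketch below repeats the `Validation` definition so that it is self-contained. The validators, their error messages and the `validateForm` function are hypothetical and only serve to show the shape of the code: independent fields are combined with `map2`, and the dependent full-name check is chained with `flatMap`.

```scala
object FormSketch {
  sealed trait Validation[+A] {
    def flatMap[B](f: A => Validation[B]): Validation[B] = this match {
      case Success(value)  => f(value)
      case Failure(errors) => Failure(errors)
    }
    def map[B](f: A => B): Validation[B] = flatMap(a => Success(f(a)))
  }
  case class Success[A](value: A) extends Validation[A]
  case class Failure(errors: List[String]) extends Validation[Nothing]

  def map2[A, B, C](va: Validation[A], vb: Validation[B])(f: (A, B) => C): Validation[C] =
    (va, vb) match {
      case (Success(a), Success(b))   => Success(f(a, b))
      case (Failure(ea), Failure(eb)) => Failure(ea ++ eb)
      case (Failure(ea), _)           => Failure(ea)
      case (_, Failure(eb))           => Failure(eb)
    }

  // Hypothetical validators, made up for this example.
  def validateFirstName(s: String): Validation[String] =
    if (s.nonEmpty) Success(s) else Failure(List("first name is empty"))
  def validateLastName(s: String): Validation[String] =
    if (s.nonEmpty) Success(s) else Failure(List("last name is empty"))
  def validateAge(n: Int): Validation[Int] =
    if (n >= 0) Success(n) else Failure(List("age is negative"))
  def validateFullName(first: String, last: String): Validation[String] = {
    val full = first + " " + last
    if (full.length <= 40) Success(full) else Failure(List("full name is too long"))
  }

  case class NewUserData(firstName: String, lastName: String, fullName: String, age: Int)

  def validateForm(first: String, last: String, age: Int): Validation[NewUserData] = {
    // The three field validations are independent, so their errors accumulate.
    val names = map2(validateFirstName(first), validateLastName(last))((f, l) => (f, l))
    val independent = map2(names, validateAge(age)) { case ((f, l), a) => (f, l, a) }
    // The full-name check depends on both names, so it needs flatMap.
    independent.flatMap { case (f, l, a) =>
      validateFullName(f, l).map(full => NewUserData(f, l, full, a))
    }
  }
}
```

With three invalid fields, `validateForm("", "", -1)` reports all three errors instead of only the first one. The nesting of `map2` calls is exactly the kind of bookkeeping that has to be redone by hand whenever a field or a dependency changes.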

## Towards an Applicative macro

It would be very useful if this could be done automatically by the compiler. In Scala, for comprehensions are desugared by the compiler before macros get to see the code, so by the time a macro inspects it, it has already been turned into a nested sequence of `flatMap` and `map` calls. Let's see if we can build a macro that replaces those `flatMap`s and `map`s with `map2`s when possible.

We are going to start with a very simple example:

In [5]:
val v1: Validation[Int] = Failure(List("error1"))
val v2: Validation[Int] = Failure(List("error2"))

for {
    x <- v1
    y <- v2
} yield x + y

[36mv1[0m: [32mValidation[0m[[32mInt[0m] = Failure(List(error1))
[36mv2[0m: [32mValidation[0m[[32mInt[0m] = Failure(List(error2))
[36mres4_2[0m: [32mValidation[0m[[32mInt[0m] = Failure(List(error1))

Let's inspect the tree generated by this for comprehension:

In [6]:
import scala.reflect.runtime.universe._

val tree = reify {
    for {
        x <- v1
        y <- v2
    } yield x + y
}.tree

showRaw(tree)

import scala.reflect.runtime.universe._
tree: reflect.runtime.package.universe.Tree = cmd5.$ref$cmd4.v1.flatMap(((x) => cmd5.$ref$cmd4.v2.map(((y) => x.$plus(y)))))
res5_2: String = """
Apply(Select(Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v1")), TermName("flatMap")), List(Function(List(ValDef(Modifiers(PARAM), TermName("x"), TypeTree(), EmptyTree)), Apply(Select(Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v2")), TermName("map")), List(Function(List(ValDef(Modifiers(PARAM), TermName("y"), TypeTree(), EmptyTree)), Apply(Select(Ident(TermName("x")), TermName("$plus")), List(Ident(TermName("y"))))))))))
"""

That's a lot. The part that interests us is the `flatMap` application:

In [7]:
val Apply(Select(firstMonadicValue, TermName("flatMap")), List(functionDef)) = tree

firstMonadicValue: Tree = cmd5.$ref$cmd4.v1
functionDef: Tree = ((x) => cmd5.$ref$cmd4.v2.map(((y) => x.$plus(y))))

Now we must separate the function into its argument and its body:

In [8]:
showRaw(functionDef)
val Function(List(firstArgumentTerm), functionBody) = functionDef 
//               ^ only works for functions of arity one, which works for map and flatMap

res7_0: String = """
Function(List(ValDef(Modifiers(PARAM), TermName("x"), TypeTree(), EmptyTree)), Apply(Select(Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v2")), TermName("map")), List(Function(List(ValDef(Modifiers(PARAM), TermName("y"), TypeTree(), EmptyTree)), Apply(Select(Ident(TermName("x")), TermName("$plus")), List(Ident(TermName("y"))))))))
"""
firstArgumentTerm: ValDef = val x = _
functionBody: Tree = cmd5.$ref$cmd4.v2.map(((y) => x.$plus(y)))

If the `functionBody` calls `map` or `flatMap` on some expression, then we must check whether that expression uses `firstArgumentTerm`. If it does, we can't do anything: one expression depends on the other and `flatMap` is the right choice. But if it doesn't, that's an opportunity to use `Applicative`'s `map2` instead of `flatMap`. Let's first define a function `usesTerm` that indicates whether a term is used in an expression:

In [9]:
def usesTerm(term: ValDef, exp: Tree): Boolean = {
    val ValDef(_,termName,_,_) = term
    exp.find{
        case Ident(_termName) if termName == _termName => 
            true
        case _ => 
            false
    }.isDefined
}

defined [32mfunction [36musesTerm[0m

Let's separate the `functionBody` into two parts: the second monadic value and the next function definition:

In [10]:
showRaw(functionBody)
val Apply(Select(secondMonadicValue, TermName("map")), List(secondFunctionDef)) = functionBody
showRaw(firstArgumentTerm)

res9_0: String = """
Apply(Select(Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v2")), TermName("map")), List(Function(List(ValDef(Modifiers(PARAM), TermName("y"), TypeTree(), EmptyTree)), Apply(Select(Ident(TermName("x")), TermName("$plus")), List(Ident(TermName("y")))))))
"""
secondMonadicValue: Tree = cmd5.$ref$cmd4.v2
secondFunctionDef: Tree = ((y) => x.$plus(y))
res9_2: String = """
ValDef(Modifiers(PARAM), TermName("x"), TypeTree(), EmptyTree)
"""

And we are interested in answering the question: is the "extracted" value of the first monadic value (`firstArgumentTerm`) being used when defining the second monadic value?

In [11]:
usesTerm(firstArgumentTerm, secondMonadicValue)

res10: Boolean = false

As we can see, computing the second monadic value doesn't require the function argument. So we want to transform this `flatMap`/`map` chain into a call to `Applicative`'s `map2`. For this we need to extract the innermost expression of the for comprehension, that is, the `x + y` expression. After that we have to build the `map2` call, passing the appropriate arguments. First, let's extract that expression from `secondFunctionDef`:

In [12]:
showRaw(secondFunctionDef)
val Function(List(secondArgumentTerm), innerExpr) = secondFunctionDef

res11_0: String = """
Function(List(ValDef(Modifiers(PARAM), TermName("y"), TypeTree(), EmptyTree)), Apply(Select(Ident(TermName("x")), TermName("$plus")), List(Ident(TermName("y")))))
"""
secondArgumentTerm: ValDef = val y = _
innerExpr: Tree = x.$plus(y)

These are all the ingredients we need:

In [13]:
showRaw(firstMonadicValue)
showRaw(firstArgumentTerm)
showRaw(secondMonadicValue)
showRaw(secondArgumentTerm)
showRaw(innerExpr)

res12_0: String = """
Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v1"))
"""
res12_1: String = """
ValDef(Modifiers(PARAM), TermName("x"), TypeTree(), EmptyTree)
"""
res12_2: String = """
Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v2"))
"""
res12_3: String = """
ValDef(Modifiers(PARAM), TermName("y"), TypeTree(), EmptyTree)
"""
res12_4: String = """
Apply(Select(Ident(TermName("x")), TermName("$plus")), List(Ident(TermName("y"))))
"""

And let's combine them with `ValidationApplicative.map2`:

In [14]:
val result = q"""
ValidationApplicative.map2(
    $firstMonadicValue,
    $secondMonadicValue
)(
    ${Function(List(firstArgumentTerm, secondArgumentTerm), innerExpr)}
)
"""
showRaw(result)

result: Tree = ValidationApplicative.map2(cmd5.$ref$cmd4.v1, cmd5.$ref$cmd4.v2)(((x, y) => x.$plus(y)))
res13_1: String = """
Apply(Apply(Select(Ident(TermName("ValidationApplicative")), TermName("map2")), List(Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v1")), Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v2")))), List(Function(List(ValDef(Modifiers(PARAM), TermName("x"), TypeTree(), EmptyTree), ValDef(Modifiers(PARAM), TermName("y"), TypeTree(), EmptyTree)), Apply(Select(Ident(TermName("x")), TermName("$plus")), List(Ident(TermName("y")))))))
"""

Unfortunately I haven't found a way to compile and execute expressions from a REPL. I think there is the function `compile` in the [`ToolBox`](http://www.scala-lang.org/api/2.11.0-RC4/scala-compiler/index.html#scala.tools.reflect.ToolBox) trait but I haven't found a way to use it from a REPL. That'd be useful for a faster development process. If you know how, please tell me!
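For reference, outside of a notebook a `ToolBox` can typecheck and evaluate a tree at runtime. The following is a minimal sketch for Scala 2, assuming `scala-compiler` and `scala-reflect` are on the classpath. To stay self-contained it evaluates a `flatMap`/`map` chain over the standard library's `Option` rather than our `Validation`. Whether the same approach works for trees that refer to values defined in earlier notebook cells (such as `cmd5.$ref$cmd4.v1`) is not verified here.

```scala
import scala.reflect.runtime.currentMirror
import scala.reflect.runtime.universe._
import scala.tools.reflect.ToolBox

object ToolBoxSketch {
  // A toolbox backed by the runtime mirror; it compiles trees on the fly.
  val toolbox = currentMirror.mkToolBox()

  // The same shape as a desugared two-generator for comprehension.
  val tree: Tree = q"Option(1).flatMap(x => Option(2).map(y => x + y))"

  // `eval` typechecks, compiles and runs the tree; its static result type is Any.
  val result: Any = toolbox.eval(tree)
}
```

Here `result` is `Some(3)`. `eval` is a shortcut for `compile(tree)()`, so `compile` can be used instead when the same tree needs to be run several times.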

Anyway, what comes next is just the code I wrote in an sbt project. You can find the full working code [here](https://github.com/miguel-vila/applicative-for/tree/afa1ea6316bd77e4b4401a1e97d584786e21a655). Putting it all together, here is our macro implementation:

```scala
def app_for_impl(c: Context)(valid: c.Expr[M]): c.Expr[M] = {
    import c.universe._

    val Apply(
        TypeApply(Select(firstMonadicValue, TermName("flatMap")),_), 
        List(functionDef)
    ) = valid.tree
    val Function(List(firstArgumentTerm), functionBody) = functionDef
    val Apply(
            TypeApply(Select(secondMonadicValue, TermName("map")),_), 
            List(secondFunctionDef)
    ) = functionBody
    if(usesTerm(c.universe)(firstArgumentTerm, secondMonadicValue)) {
      valid
    } else {
      val Function(List(secondArgumentTerm), innerExpr) = secondFunctionDef
      c.Expr(q"""_root_.appfor.Validation.applicativeInstance.map2(
          $firstMonadicValue,
          $secondMonadicValue
      )(
          ${Function(List(firstArgumentTerm, secondArgumentTerm), innerExpr)}
      )""")
    }
}
```

As you may have noticed, there are some differences with respect to what we did above: when matching the method calls, the selected method has a different shape (it's wrapped in a `TypeApply` node). Other than that it's mostly the same as what we did above. I don't know the cause of these differences with respect to our REPL / Jupyter session.

Let's see the macro in action:

```scala
val v1: Validation[Int] = Failure(List("error1"))
val v2: Validation[Int] = Failure(List("error2"))
  
val resultWhenApplicative = app_for {
    for {
      x <- v1
      y <- v2
    } yield x + y
}
println(resultWhenApplicative)
// > Failure(List(error1, error2))
```

This is what we wanted! We used a for comprehension but our macro detected that `map2` could be used!

Now let's test it with an expression that really needs `flatMap`:

```scala
val resultWhenMonad = app_for {
    for {
      x <- v1
      y <- if(x>0) v2 else v1
    } yield x + y
}

println(resultWhenMonad)
// Failure(List(error1))
```

It works, just like a normal for comprehension!

## What's next?

This macro is not fully functional though. It doesn't account for a lot of situations:

* The pattern matches are unsafe and don't produce good error messages when they fail. 
* This works just for two-level expressions, but for comprehensions can be longer.
* This macro doesn't account for more subtle uses of a monad result like pattern matches. 
* There is the issue of generalizing this for any type and not just for our own `Validation`.
* Also I'm not sure if the pattern matches are as general as they can be.

And I may be missing some other things.

The next step is to fix these problems one by one.
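As a possible starting point for the first item, here is a sketch of the same rewrite written with total pattern matches. It runs on runtime-reflection trees (Scala 2, `scala-reflect` on the classpath) rather than inside a macro, and the names `MethodCall` and `toMap2` are hypothetical. The extractor accepts a method call with or without the `TypeApply` wrapper, and `toMap2` returns `None` instead of throwing a `MatchError` when the tree doesn't have the expected shape.

```scala
import scala.reflect.runtime.universe._

object SafeRewrite {
  // Matches `target.method(arg)`, with or without explicit type arguments.
  object MethodCall {
    def unapply(tree: Tree): Option[(Tree, String, Tree)] = tree match {
      case Apply(TypeApply(Select(target, TermName(name)), _), List(arg)) => Some((target, name, arg))
      case Apply(Select(target, TermName(name)), List(arg))               => Some((target, name, arg))
      case _                                                              => None
    }
  }

  def usesTerm(term: ValDef, exp: Tree): Boolean =
    exp.exists {
      case Ident(name) => name == term.name
      case _           => false
    }

  // Some(map2 call) when the tree is a two-level flatMap/map chain whose
  // second value does not use the first argument; None in every other case.
  def toMap2(tree: Tree): Option[Tree] = tree match {
    case MethodCall(first, "flatMap",
           Function(List(x), MethodCall(second, "map", Function(List(y), inner))))
        if !usesTerm(x, second) =>
      Some(q"ValidationApplicative.map2($first, $second)(${Function(List(x, y), inner)})")
    case _ => None
  }

  val independent: Option[Tree] = toMap2(q"v1.flatMap(x => v2.map(y => x + y))")
  val dependent: Option[Tree] = toMap2(q"v1.flatMap(x => (if (x > 0) v2 else v1).map(y => x + y))")
  val otherShape: Option[Tree] = toMap2(q"v1.map(x => x + 1)")
}
```

In a macro, the `None` case would be the place to return the original tree unchanged or to report a proper error through the macro context.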