# XABY is functional machine learning

XABY is a machine learning library designed to be super readable and flexible. It's pronounced with a z/zed and [all caps](https://www.youtube.com/watch?v=gSJeHDlhYls), please.

It's focused on assembling pure functions into one big pure function. It's functional! It also offers a different way to develop machine learning models: PyTorch and TensorFlow 2.0 have converged on very similar APIs, so here's something else. XABY is built on top of JAX, the new cool kid in the machine learning world.

Everything in XABY is a function that takes an ArrayList as input. Many functions return other functions. There are functions that collect functions together into one higher-level function. And all of these functions take ArrayLists.

ArrayLists are the basic data structure in XABY; they're really just slightly fancy Python lists. ArrayLists collect JAX arrays and make it easy to pass data between functions. XABY also provides many functions that manipulate ArrayLists to produce new ArrayLists.
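
To make the idea of a "slightly fancy Python list" concrete, here's a minimal pure-Python sketch. It is not XABY's implementation: `ArrayList` and `pack` below are hypothetical stand-ins, and plain Python lists play the role of JAX arrays.

```python
class ArrayList(list):
    """Hypothetical stand-in for XABY's ArrayList: a list that holds arrays."""

    def __repr__(self):
        # Print one array per line, similar to the notebook output in this README
        return "\n".join(["ArrayList:"] + [repr(item) for item in self])


def pack(*arrays):
    """Hypothetical helper: collect any number of arrays into an ArrayList."""
    return ArrayList(arrays)


arr_list = pack([1.0, 2.0, 3.0], [3, 4, 5, 6])
arr_list.append([7, 8])  # it's still a list, so append, indexing, etc. work
print(len(arr_list))
print(arr_list)
```

Subclassing `list` keeps every ordinary list operation available while leaving room to add behavior, such as a custom printout.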

**Note:** XABY is very much a prototype. Expect names of things to change.

```python
import xaby as xb      # Base XABY things
import xaby.nn as xn   # For neural networks 🧠
```

To start out, let's pack up some arrays. `xb.pack` returns an ArrayList.

```python
arr_list = xb.pack(xb.array([1., 2, 3]), xb.array([3, 4, 5, 6]))
print(arr_list)
```

```
ArrayList:
DeviceArray([1., 2., 3.], dtype=float32)
DeviceArray([3, 4, 5, 6], dtype=int32)
```

An ArrayList is just a list: you can append to it, index into it, iterate over it, and so on. XABY also provides functions for working with ArrayLists. For example, `xb.select` selects specific arrays from a list and returns them in a new ArrayList. Here's some fancy syntax:

```python
arr_list >> xb.select(0)
```

```
ArrayList:
DeviceArray([1., 2., 3.], dtype=float32)
```

What's happening here is that `xb.select(0)` returns a function that selects the first item in `arr_list` and returns it packed in a new ArrayList. Every function that would return arrays always returns them packed into an ArrayList.

Note that you can also call `xb.select` like this, but it's a bit more boring in my opinion:

```python
xb.select(0)(arr_list)
```

```
ArrayList:
DeviceArray([1., 2., 3.], dtype=float32)
```

`select` also lets you repeat arrays, which is super handy.

```python
arr_list >> xb.select(0, 1, 0, 1)
```

```
ArrayList:
DeviceArray([1., 2., 3.], dtype=float32)
DeviceArray([3, 4, 5, 6], dtype=int32)
DeviceArray([1., 2., 3.], dtype=float32)
DeviceArray([3, 4, 5, 6], dtype=int32)
```
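
A function like `select` is a natural fit for a closure: the outer call fixes the indices, and the inner function does the picking. Here's a hypothetical plain-Python sketch of that pattern, not XABY's code, using ordinary lists in place of ArrayLists and JAX arrays.

```python
def select(*indices):
    """Hypothetical sketch: return a function that picks arrays by position."""
    def selector(array_list):
        # Indices may repeat, so the same array can appear more than once
        return [array_list[i] for i in indices]
    return selector


arr_list = [[1.0, 2.0, 3.0], [3, 4, 5, 6]]
first_only = select(0)
print(first_only(arr_list))
print(select(0, 1, 0, 1)(arr_list))
```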

Why the `>>` notation? When you're building deep learning models, you often end up with many nested functions. In normal Python, the innermost function is executed first and the outermost function last. So when you read a common neural network operation like

```python
x = sigmoid(fc2(relu(fc1(x))))
```

you're reading left to right, from the last function executed to the first. To make code like this more readable, I built XABY so that functions are executed in the order they are written:

```python
x = x >> fc1 >> relu >> fc2 >> sigmoid
```
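
One way to get this left-to-right behavior in Python is operator overloading. The sketch below is an assumption about the general technique, not XABY's actual code: a hypothetical `Fn` wrapper defines `__rshift__` so that `f >> g` builds a new function, and `__rrshift__` so that `data >> f` applies `f` when the left operand is plain data.

```python
class Fn:
    """Hypothetical wrapper that makes a function composable with `>>`."""

    def __init__(self, func):
        self.func = func

    def __call__(self, x):
        return self.func(x)

    def __rshift__(self, other):
        # (f >> g)(x) == g(f(x)): run self first, then other
        return Fn(lambda x: other(self(x)))

    def __rrshift__(self, data):
        # data >> f: Python falls back to this when data has no __rshift__
        return self(data)


double = Fn(lambda xs: [2 * v for v in xs])
increment = Fn(lambda xs: [v + 1 for v in xs])

x = [1, 2, 3]
print(x >> double >> increment)  # left to right: double, then increment
print(increment(double(x)))      # same result, nested style
```

Because `>>` is left-associative, `x >> double >> increment` is evaluated as `(x >> double) >> increment`, which is exactly the execution order you read.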

Let's see this in action.

```python
inputs = xb.pack(xb.array([1., 2., 3, 4]),
                 xb.array([3., 4, 5, 6]))

# This is the mean squared error!
inputs >> xb.sub >> xb.power(y=2) >> xb.mean(axis=None)
```

```
ArrayList:
DeviceArray(4., dtype=float32)
```

Let's step through this.
- `xb.sub` expects an input with two arrays, subtracts them element-wise, and returns an ArrayList with one array.
- `xb.power(y=2)` calculates $x^y$. It's actually a function that returns a function. So here, I called it with `y=2`. This returns another function that expects a single array (in an ArrayList) and returns each element raised to the power of 2.
- `xb.mean(axis=None)` returns a function that calculates the mean of an array. It also expects only one array as input.

In general, whenever you see a function called with some parameters, it returns another function that accepts an ArrayList. Functions that don't need extra parameters are used directly, without calling them first, as in `inputs >> xb.sub`.
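
The "call with parameters, get a function back" convention can be sketched in plain Python. `power` and `sub` below are hypothetical stand-ins for `xb.power` and `xb.sub`, operating on lists of plain Python lists instead of ArrayLists of JAX arrays.

```python
def power(y):
    """Hypothetical sketch: fix the exponent now, apply it later."""
    def apply(array_list):
        (array,) = array_list  # expects exactly one array
        return [[v ** y for v in array]]
    return apply


def sub(array_list):
    """No extra parameters, so it's used directly rather than called first."""
    a, b = array_list
    return [[u - v for u, v in zip(a, b)]]


squared = power(y=2)  # returns a function
print(squared(sub([[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]])))
```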

So, that's fun, yeah? Then you might like this! We can build a reusable mean squared error function like so:

```python
# Compose multiple functions into one function
mse = xb.sub >> xb.power(y=2) >> xb.mean(axis=None)

inputs >> mse
```

```
ArrayList:
DeviceArray(4., dtype=float32)
```

As you can see, `mse` is itself a function that accepts ArrayLists. Under the hood, each function is called in order, and the whole composition is compiled with JAX's just-in-time (JIT) compiler, which makes it super fast! Much of what you do with XABY is composing functions from other functions and passing in an ArrayList with the expected inputs.
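
Composing a chain of functions into one reusable function boils down to a left fold over the chain. Here's a hypothetical `sequential` sketch in plain Python, with toy `sub`, `square` and `mean` steps; it only illustrates the idea and is not XABY's implementation.

```python
from functools import reduce


def sequential(*funcs):
    """Hypothetical sketch: compose functions so they run left to right."""
    def composed(array_list):
        # Feed the output of each function into the next one
        return reduce(lambda data, func: func(data), funcs, array_list)
    return composed


def sub(arrays):
    a, b = arrays
    return [[u - v for u, v in zip(a, b)]]


def square(arrays):
    (a,) = arrays
    return [[v ** 2 for v in a]]


def mean(arrays):
    (a,) = arrays
    return [sum(a) / len(a)]


mse = sequential(sub, square, mean)  # one reusable function
print(mse([[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]]))  # [4.0]
```

In a JAX-backed library, a composed function like this could then be passed to `jax.jit` for compilation. That step is left out here so the sketch runs without dependencies.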

## What about neural networks?

Deep learning (known to many as "AI") has produced amazing results over the last 6 or 7 years. I built XABY as a deep learning framework, so now I'll show you how to implement and train deep learning models. If you don't know much about deep learning, but want to understand what's going on here, check out [this free course](https://www.udacity.com/course/deep-learning-pytorch--ud188) I helped create at Udacity.

Here's a super simple feedforward network with one hidden layer. It takes 10 input features and returns a binary classification probability for each example.

```python
model = xn.linear(10, 5) >> xn.relu >> xn.linear(5, 1) >> xn.sigmoid
```

That's it, that's the model. It's a function that accepts an ArrayList containing one array. Behind the scenes, `xn.linear` initializes parameters that are passed to its forward function. The composed function (a `sequential` function) collects these parameters in its `model.params` attribute. You call the model like any other function.

```python
# Make a random 7x10 array, pack it up
x = xb.pack(xb.random.uniform((7, 10)))

# Forward pass through the model
x >> model
```

```
ArrayList:
DeviceArray([[0.49970198],
             [0.50403184],
             [0.42720988],
             [0.5106314 ],
             [0.45148745],
             [0.46111318],
             [0.487898  ]], dtype=float32)
```
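
The parameter bookkeeping described above (each linear layer initializing its own parameters, and the composed function collecting them in `model.params`) can be sketched in plain Python. Everything here is a hypothetical illustration, not XABY's code: `Linear`, `Sequential` and the uniform initialization are our own choices, and the layers work on nested lists rather than ArrayLists.

```python
import math
import random


class Linear:
    """Hypothetical dense layer that owns its parameters in `.params`."""

    def __init__(self, n_in, n_out, rng):
        # Small random weights and zero biases (an illustrative init choice)
        self.params = {
            "w": [[rng.uniform(-0.5, 0.5) for _ in range(n_out)] for _ in range(n_in)],
            "b": [0.0] * n_out,
        }

    def __call__(self, x):
        w, b = self.params["w"], self.params["b"]
        # x is a batch of rows; compute x @ w + b one row at a time
        return [[sum(xi * w[i][j] for i, xi in enumerate(row)) + b[j]
                 for j in range(len(b))]
                for row in x]


def relu(x):
    return [[max(0.0, v) for v in row] for row in x]


def sigmoid(x):
    return [[1.0 / (1.0 + math.exp(-v)) for v in row] for row in x]


class Sequential:
    """Hypothetical composed model that gathers its layers' parameters."""

    def __init__(self, *layers):
        self.layers = layers
        # Parameter-free functions such as relu and sigmoid contribute nothing
        self.params = [layer.params for layer in layers if hasattr(layer, "params")]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


rng = random.Random(0)
model = Sequential(Linear(10, 5, rng), relu, Linear(5, 1, rng), sigmoid)
x = [[rng.random() for _ in range(10)] for _ in range(7)]
probs = model(x)
print(len(model.params), len(probs), len(probs[0]))  # 2 7 1
```

Because `model.params` holds references to each layer's own dictionaries, an optimizer that edits them in place also changes what the layers compute.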

Cool, you built a model and can get predictions out. But for it to be useful, you need to train it. For this you need two things: a loss function to measure error and a way to update the parameters so the loss is minimized.

Common loss functions are provided in `xaby.nn.losses`. The model returns probabilities for binary classification, so we should use the binary cross-entropy loss, `xaby.nn.losses.binary_cross_entropy_loss`. This is just another function like the others. It expects two inputs: an array of probabilities between 0 and 1, and an array of binary labels, 0 or 1. We need to compose a function that takes input data and returns the loss. The function should look something like this:

![Binary cross entropy loss diagram](assets/loss_diagram.png)
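
For reference, here is the standard mean binary cross-entropy written in plain Python. It's a sketch of the usual definition, not XABY's source. We assume `binary_cross_entropy_loss` averages over the batch, and the clipping constant `eps` is our own choice to keep `log` finite.

```python
import math


def binary_cross_entropy(probs, labels, eps=1e-7):
    """Mean binary cross-entropy over a batch (standard definition)."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(probs)


# Predicting 0.5 for every example costs ln(2) per example, whatever the labels
print(binary_cross_entropy([0.5, 0.5, 0.5], [0, 1, 1]))
```

A model that outputs 0.5 everywhere scores ln 2 ≈ 0.693. This explains why an untrained model whose outputs sit around 0.5, like the one above, starts with a loss in roughly that range.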

XABY has multiple functions for composing other functions in various configurations. Here is how you would compose a loss function using the model defined above and the `binary_cross_entropy_loss` function.

```python
# This produces another function, it's all functions!
loss = xb.split(model, xb.skip) >> xn.losses.binary_cross_entropy_loss()
```

Let's step through what's going on here:
- `xb.split` splits the input ArrayList into individual arrays and maps the arrays to the given functions.
- `[features, targets] >> split(model, skip)` is equivalent to `[features >> model, targets >> skip]`.
- `xb.skip` simply returns the input ArrayList; it's a no-op.
- `xb.split` packs the output of each function into an ArrayList.

Putting all this together, `xb.split(model, xb.skip)` creates a function that takes an ArrayList `[features, targets]` and returns an ArrayList `[probabilities, targets]`. Again, much of the work with XABY is composing functions like this.
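
The split-and-skip behavior in the list above can be sketched in plain Python. `split`, `skip` and `fake_model` here are hypothetical stand-ins, not XABY's code: each array is wrapped in a one-item list before being passed to its function, and the outputs are gathered into one flat list.

```python
def skip(array_list):
    """No-op: return the input unchanged."""
    return array_list


def split(*funcs):
    """Hypothetical sketch: send the i-th array to the i-th function."""
    def apply(array_list):
        if len(funcs) != len(array_list):
            raise ValueError("need exactly one function per array")
        outputs = []
        for func, array in zip(funcs, array_list):
            # Each function receives its array wrapped in a one-item list
            outputs.extend(func([array]))
        return outputs
    return apply


def fake_model(array_list):
    """Stand-in for a model: halve every feature."""
    (features,) = array_list
    return [[v / 2 for v in features]]


features, targets = [2.0, 4.0], [1, 0]
print(split(fake_model, skip)([features, targets]))  # [[1.0, 2.0], [1, 0]]
```

The explicit length check is a design choice in this sketch. Without it, `zip` would silently drop any extra arrays or functions.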

Now I'll make some fake data and calculate the loss.

```python
# Make a random 7x10 array for features and another for targets
features = xb.random.uniform((7, 10))
targets = xb.random.bernoulli((7, 1))
inputs = xb.pack(features, targets)

# You can calculate the loss just like any other function
print(f"Loss: {inputs >> loss}")
```

```
Loss: 0.7117464542388916
```


Once you have the loss, you need a way to update the parameters. In comes backpropagation. The idea is to compute the gradients of the loss with respect to all the model parameters, then use these gradients to update the parameters with stochastic gradient descent (SGD). Again, if you don't know what I'm talking about, check out [this deep learning course](https://www.udacity.com/course/deep-learning-pytorch--ud188). Getting the gradients with XABY is pretty straightforward.

```python
# Get the gradients and the loss for this batch
batch_loss, grads = loss << inputs
```

This returns the loss and the gradients for all the function parameters. We can update the `loss` function parameters with the gradients using `xaby.optim.sgd`.

```python
# This returns a function
update = xb.optim.sgd(lr=0.003)

# Update the loss function parameters with the gradients
update(loss, grads)

# Calculate the loss again, it should be lower now
print(f"Loss: {inputs >> loss}")
```

```
Loss: 0.7115735411643982
```
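
The gradient-then-update steps can be mimicked by hand on a tiny problem. This hypothetical one-feature logistic regression computes the gradients of the mean binary cross-entropy with respect to each parameter, takes one SGD step, and checks that the loss goes down, much like the small drop shown above. XABY presumably gets its gradients from JAX's automatic differentiation. The hand-derived gradient formulas here are only for illustration.

```python
import math


def predict(params, x):
    """Probability from a one-feature logistic model: sigmoid(w * x + b)."""
    return 1.0 / (1.0 + math.exp(-(params["w"] * x + params["b"])))


def bce_loss(params, data):
    """Mean binary cross-entropy over (x, y) pairs."""
    total = 0.0
    for x, y in data:
        p = predict(params, x)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(data)


def gradients(params, data):
    """Analytic gradients of bce_loss: mean((p - y) * x) and mean(p - y)."""
    errors = [(predict(params, x) - y, x) for x, y in data]
    return {
        "w": sum(e * x for e, x in errors) / len(data),
        "b": sum(e for e, _ in errors) / len(data),
    }


def sgd(lr):
    """Return an update function that steps each parameter against its gradient."""
    def update(params, grads):
        for name in params:
            params[name] -= lr * grads[name]  # in place, like update(loss, grads)
    return update


data = [(0.5, 0), (1.5, 1), (-1.0, 0), (2.0, 1)]
params = {"w": 0.0, "b": 0.0}
update = sgd(lr=0.1)

before = bce_loss(params, data)
update(params, gradients(params, data))
after = bce_loss(params, data)
print(f"Loss: {before:.4f} -> {after:.4f}")
```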


Parameter updates are propagated to every sub-function, so the model function is updated as well. In practice, you'd loop through a dataset in batches, updating the model for each batch, but this tutorial is long enough. I've created two more notebooks that show you how to train image classifiers on the MNIST dataset; [check them out!](https://github.com/mcleonard/xaby/tree/master/examples)
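
A small generator like the one below could drive that batch loop. It's a hypothetical plain-Python helper, not part of XABY. It shuffles example indices once per epoch and yields slices of `batch_size` examples, with a smaller final batch when the dataset doesn't divide evenly.

```python
import random


def batches(features, targets, batch_size, seed=None):
    """Yield shuffled (features, targets) minibatches covering the dataset once."""
    order = list(range(len(features)))
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield [features[i] for i in idx], [targets[i] for i in idx]


features = [[float(i)] for i in range(10)]
targets = [i % 2 for i in range(10)]
for epoch in range(2):
    for batch_x, batch_y in batches(features, targets, batch_size=4, seed=epoch):
        # With XABY, this is roughly where you would pack the batch,
        # get gradients with `loss << inputs`, then call `update(loss, grads)`
        print(epoch, len(batch_x))
```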