# 2022-02-07

```python
import jax.numpy as jnp
from jax import jit, grad, vmap, pmap, devices
```

## Large Buffer Allocation

```python
devices()
```

```text
[GpuDevice(id=0, process_index=0)]
```

```python
try:
    # 2 * 1024**3 float32 elements x 4 bytes each = 8 GiB
    buffer = jnp.zeros(2 * 1024 ** 3)
except RuntimeError as error:
    print(f"Error: {error}")
```

```text
Error: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 8589934592 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:         4B
              constant allocation:         0B
        maybe_live_out allocation:    8.00GiB
     preallocated temp allocation:         0B
                 total allocation:    8.00GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 8.00GiB
		Operator: op_name="jit(broadcast_in_dim)/jit(main)/broadcast_in_dim[\n  broadcast_dimensions=()\n  shape=(2147483648,)\n]" source_file="/tmp/ipykernel_19209/4271869147.py" source_line=3
		XLA Label: broadcast
		Shape: f32[2147483648]

	Buffer 2:
		Size: 4B
		Entry Parameter Subshape: f32[]

2022-02-07 11:34:27.637587: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:462] Allocator (GPU_0_bfc) ran out of memory trying to allocate 8.00GiB (rounded to 8589934592)requested by op
```

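The 8 GiB figure follows directly from the shape: `2 * 1024 ** 3` elements of `float32` at 4 bytes each is 8,589,934,592 bytes, the number reported in the error. A minimal sketch for checking this before allocating anything, assuming JAX's default 32-bit mode (no `jax_enable_x64`): `jax.eval_shape` traces the `jnp.zeros` call abstractly, so no device memory is requested.

```python
import math

import jax
import jax.numpy as jnp

n = 2 * 1024 ** 3  # number of elements requested above

# Trace the allocation abstractly: returns a ShapeDtypeStruct and allocates nothing.
spec = jax.eval_shape(lambda: jnp.zeros(n))

nbytes = math.prod(spec.shape) * spec.dtype.itemsize
print(spec.shape, spec.dtype, nbytes)
print(f"{nbytes / 1024 ** 3:.2f} GiB")
```

The same check works for any function, so it can be used to estimate the size of large intermediate results before running them on a GPU.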
## Just-in-Time Tracing

```python
@jit
def square(x):
    print(f"x = {x}")
    return x ** 2

print(square(1.0))
print("-------------------------")
print(square(1.0))
print("-------------------------")
print(square(jnp.arange(10)))
```

```text
x = Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
1.0
-------------------------
1.0
-------------------------
x = Traced<ShapedArray(int32[10])>with<DynamicJaxprTrace(level=0/1)>
[ 0  1  4  9 16 25 36 49 64 81]
```
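The output shows that `print` runs only while JAX traces the function. The second call, with another weak-typed `float32` scalar, reuses the cached compiled version, and the `int32[10]` argument triggers a new trace. A minimal sketch that makes this cache behaviour countable, using a Python-level counter as a trace-time side effect (`trace_count` is a hypothetical name for illustration):

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def square(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return x ** 2

square(1.0)             # traced and compiled for float32[] (weak type)
square(2.0)             # same shape and dtype: cached executable, no new trace
square(jnp.arange(10))  # new shape and dtype (int32[10]): traced again
print(trace_count)      # → 2
```

To print runtime values instead of tracers inside a jitted function, newer JAX releases provide `jax.debug.print`.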


## JAX Misconceptions

Automatic differentiation is actually done analytically: rather than approximating derivatives numerically (e.g. with finite differences), JAX knows the derivative rule of every primitive operation behind the `jax.numpy` namespace and applies the chain rule to differentiate user code, including code with control flow. ([Source](https://youtu.be/z-WSrQDXkuM))
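A small check of this claim, comparing `jax.grad` with a hand-derived derivative and with a central finite difference. The function `f` and the step size `h` are illustrative choices. Because `jax.grad` (without `jit`) works on concrete values, the Python `if` is allowed and only the branch actually taken is differentiated.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Python control flow: grad follows whichever branch is taken for this x.
    if x > 0:
        return jnp.sin(x) * x ** 2
    return -x

df = jax.grad(f)

x = 1.5
exact = jnp.cos(x) * x ** 2 + 2 * x * jnp.sin(x)  # derivative worked out by hand
h = 1e-3
finite_diff = (f(x + h) - f(x - h)) / (2 * h)     # numerical approximation

print(df(x), exact, finite_diff)
print(df(-2.0))  # branch -x: derivative is -1
```

`df(x)` matches the hand-derived value up to float32 rounding, while the finite difference also carries truncation and cancellation error that depends on `h`.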

## MPI4JAX

Supported operations:

- `send`
- `recv`
- `sendrecv`
- `bcast`
- `gather` (!!!)
- `scatter`  (!!!)
- `reduce`
- `allgather`
- `allreduce`
- `alltoall`
- `scan`
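mpi4jax needs an MPI installation and several processes launched with `mpirun`, so it is hard to demo in a notebook. As a single-process analogy (not mpi4jax's API), JAX's own named-axis collectives show what `allreduce` and `allgather` compute: under `vmap` with an `axis_name`, each mapped index plays the role of an MPI rank.

```python
import jax
import jax.numpy as jnp

# Row i is the local data of simulated rank i (4 ranks, 2 values each).
local = jnp.arange(8.0).reshape(4, 2)

def collectives(x):
    total = jax.lax.psum(x, axis_name="ranks")           # analogous to allreduce with SUM
    gathered = jax.lax.all_gather(x, axis_name="ranks")  # analogous to allgather
    return total, gathered

# vmap with a named axis lets the collectives see all "ranks" on one device.
total, gathered = jax.vmap(collectives, axis_name="ranks")(local)
print(total)           # every rank holds [12., 16.]
print(gathered.shape)  # (4, 4, 2): every rank holds all four local arrays
```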

### Implementation

1. A Python module that registers a new primitive with JAX.
   - **Abstract evaluation rules** let the compiler infer the output shapes and data types without running the actual computation.
   - **Translation rules** select the computational kernel to call and prepare the input buffers.
2. A Cython function that casts the raw input arguments passed by XLA to their true C types, so they can be passed on to MPI.
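What abstract evaluation computes is visible from user code through `jax.eval_shape`, which propagates shapes and dtypes through a function without executing it. This is only an illustration of the concept: it does not register a primitive the way mpi4jax does, and `layer` is a hypothetical function.

```python
import jax
import jax.numpy as jnp

def layer(w, x):
    # Hypothetical function: a dense layer followed by tanh.
    return jnp.tanh(x @ w)

# Abstract inputs carry only shape and dtype; no buffers are allocated.
w_spec = jax.ShapeDtypeStruct((512, 128), jnp.float32)
x_spec = jax.ShapeDtypeStruct((64, 512), jnp.float32)

out_spec = jax.eval_shape(layer, w_spec, x_spec)
print(out_spec.shape, out_spec.dtype)  # (64, 128) float32
```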

## FEDJAX

> Federated learning is a machine learning setting where many clients collaboratively train a model
under the orchestration of a central server, while keeping the training data decentralized. Clients
can be either mobile devices or whole organizations depending on the task at hand.

Client/server architecture with naive data partitioning:

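```python
import fedjax
import jax
import jax.numpy as jnp

# Hypothetical loss (an assumption): least squares on batch['x'] and batch['y'].
grad_fn = jax.grad(lambda params, batch: jnp.mean((batch['x'] @ params - batch['y']) ** 2))

for_each_client = fedjax.for_each_client(
    client_init=lambda server_params, _: server_params,  # start from server params
    client_step=lambda params, batch: params - grad_fn(params, batch) * 0.1,  # one SGD step
    client_final=lambda server_params, params: server_params - params)  # return model delta
```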
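The snippet above covers only the client side. A minimal plain-JAX sketch of a full federated-averaging round, without fedjax: each client runs local SGD starting from the server parameters (as `client_init`/`client_step` do), returns its model delta (as `client_final` does), and the server subtracts the mean delta. The least-squares model, learning rate, round count, and synthetic client data are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)  # hypothetical least-squares model

grad_fn = jax.grad(loss_fn)

def client_update(server_params, client_batches, lr=0.1):
    # Local SGD over this client's batches, starting from the server parameters.
    def step(params, batch):
        return params - lr * grad_fn(params, batch), None
    params, _ = jax.lax.scan(step, server_params, client_batches)
    return server_params - params  # model delta sent back to the server

@jax.jit
def fedavg_round(server_params, all_batches):
    # vmap runs every client at once; server_params are shared (in_axes=None).
    deltas = jax.vmap(client_update, in_axes=(None, 0))(server_params, all_batches)
    return server_params - deltas.mean(axis=0)

# Synthetic data: 4 clients x 5 local steps x 8 examples x 3 features.
x = jax.random.normal(jax.random.PRNGKey(0), (4, 5, 8, 3))
w_true = jnp.array([1.0, -2.0, 0.5])
y = x @ w_true

params = jnp.zeros(3)
initial_loss = loss_fn(params, (x, y))
for _ in range(20):
    params = fedavg_round(params, (x, y))
final_loss = loss_fn(params, (x, y))
print(initial_loss, final_loss)
```

Keeping the data on the clients and exchanging only deltas is what makes this "federated"; `vmap` here merely simulates many clients in one process.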

## JAX PJIT

> It takes in an XLA program that represents the complete neural net, as if there is only one giant virtual device. In addition to the program, it also takes in partitioning specifications for both function inputs and outputs. The output of the XLA SPMD partitioner is an identical program for N devices that performs communications between devices through collective operations. The program only compiles once per host. **Pjit is the API exposed for the XLA SPMD partitioner in JAX**. ([Source](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html))

![XLA SPMD](https://jax.readthedocs.io/en/latest/_images/xla_spmd.jpg)
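A minimal sketch of the same idea from user code, assuming a recent JAX release (0.4 or later) where pjit's partitioning is exposed through `jax.jit` with `in_shardings`/`out_shardings` and the `jax.sharding` module. Older releases exposed it as `jax.experimental.pjit.pjit` with `in_axis_resources`/`out_axis_resources` instead. The function is written once for a single logical array, and the sharding specs tell the SPMD partitioner to split axis 0 over all available devices; on a single device the split is trivial but the code still runs. `normalize` and the mesh axis name `"data"` are hypothetical choices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D logical mesh over every available device, with a named axis "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
row_sharding = NamedSharding(mesh, P("data"))  # split axis 0 across "data"

def normalize(x):
    # Written as if for one device; the partitioner inserts the cross-device sum.
    return x / jnp.sum(x)

sharded_normalize = jax.jit(
    normalize, in_shardings=row_sharding, out_shardings=row_sharding)

n = 4 * len(jax.devices())  # divisible by the number of devices
x = jax.device_put(jnp.arange(1, n + 1, dtype=jnp.float32), row_sharding)
y = sharded_normalize(x)
print(y.sharding)
```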

## JAX PMAP

TODO!