# Meeting 2022-02-14

In [1]:
import jax
import jax.numpy as jnp

from jax import (
    # Transforms
    jit,
    grad,
    pmap,
    vmap,
    make_jaxpr,
    # Random numbers
    random,
    # Internals
    xla_computation,
)

def hr(width: int = 80, char: str = "-") -> None:
    """Print a horizontal separator line.

    Args:
        width: How many times ``char`` is repeated (default 80).
        char: String used to draw the line (default ``"-"``).
    """
    print(char * width)

In [2]:
# Devices JAX can dispatch to in this runtime
visible_devices = jax.devices()
visible_devices

[GpuDevice(id=0, process_index=0)]

## Intermediate Representations

A simple example function, used for all the intermediate-representation examples below:

$$f(x) = x^2 + e^{-x}$$

In [12]:
@jit
def f(x):
    """Compute f(x) = x**2 + exp(-x) elementwise.

    The ``print`` runs only while JAX traces the function. It therefore shows an
    abstract tracer rather than a concrete value, and it is skipped on later
    calls that reuse the compiled version.
    """
    # Trace-time side effect: demonstrates that jit executes Python code only during tracing
    print(f"x = {x}")
    # The float exponent `2.` is kept on purpose: it lowers to a `pow` primitive,
    # which the IR cells below inspect
    return x ** 2. + jnp.exp(-x)

# First call: f is traced (its inner print fires), compiled, then executed
first_call = f(1.0)
print(first_call)
hr()
# Second call with the same input type reuses the compiled code, so no trace-time print
second_call = f(1.0)
print(second_call)

x = Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
1.3678795
--------------------------------------------------------------------------------
1.3678795


### JAXPR

In [4]:
# Build a jaxpr-producing wrapper for f, then trace it with a scalar float
trace_f = make_jaxpr(f)
trace_f(1.0)

{ lambda ; a:f32[]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[]. let
          d:f32[] = pow c 2.0
          e:f32[] = neg c
          f:f32[] = exp e
          g:f32[] = add d f
        in (g,) }
      name=f
    ] a
  in (b,) }

### HLO IR

In [5]:
# Lower f for a scalar float argument, then request the HLO dialect of the lowered program
lowered_f = f.lower(1.0)
ir = lowered_f.compiler_ir(dialect='hlo')

print(repr(ir))
hr()
hlo_text = ir.as_hlo_text()
print(hlo_text)

<jaxlib.xla_extension.XlaComputation object at 0x7f5a2d5dd1b0>
--------------------------------------------------------------------------------
HloModule jit_f.0

ENTRY main.7 {
  Arg_0.1 = f32[] parameter(0)
  constant.2 = f32[] constant(2)
  power.3 = f32[] power(Arg_0.1, constant.2)
  negate.4 = f32[] negate(Arg_0.1)
  exponential.5 = f32[] exponential(negate.4)
  ROOT add.6 = f32[] add(power.3, exponential.5)
}




In [6]:
import jax.lib.xla_bridge

# Trace f into an XlaComputation (not yet compiled)
f_xla = xla_computation(f)(1.0)
# Compile it with the default XLA backend; the printed module is the post-compilation HLO
backend = jax.lib.xla_bridge.get_backend()
executable = backend.compile(f_xla)

print(repr(executable))
hr()
compiled_module = executable.hlo_modules()[0]
print(compiled_module.to_string())

<jaxlib.xla_extension.Executable object at 0x7f5a925a62b0>
--------------------------------------------------------------------------------
HloModule xla_computation_f.13

%fused_computation (param_0.1: f32[]) -> f32[] {
  %param_0.1 = f32[] parameter(0)
  %multiply.1 = f32[] multiply(f32[] %param_0.1, f32[] %param_0.1), metadata={op_type="pow" op_name="xla_computation(f)/jit(f)/pow" source_file="/tmp/ipykernel_5746/3066188181.py" source_line=4}
  %negate.1 = f32[] negate(f32[] %param_0.1), metadata={op_type="neg" op_name="xla_computation(f)/jit(f)/neg" source_file="/tmp/ipykernel_5746/3066188181.py" source_line=4}
  %exponential.1 = f32[] exponential(f32[] %negate.1), metadata={op_type="exp" op_name="xla_computation(f)/jit(f)/exp" source_file="/tmp/ipykernel_5746/3066188181.py" source_line=4}
  ROOT %add.1 = f32[] add(f32[] %multiply.1, f32[] %exponential.1), metadata={op_type="add" op_name="xla_computation(f)/jit(f)/add" source_file="/tmp/ipykernel_5746/3066188181.py" source_line=4}


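1. What is the difference between the two cells above? The first prints the HLO produced by lowering `f`, before any optimization. The second prints the HLO module of the *compiled* executable, after XLA's optimization passes. There, `pow(x, 2)` has been rewritten to `multiply(x, x)`, and the element-wise ops are wrapped in a `%fused_computation`.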

Relevant discussion [here](https://github.com/google/jax/discussions/7068).

### MHLO - MLIR Dialect

In [7]:
# Get MHLO intermediate representation (the MLIR dialect); the dialect name is spelled 'mhlo'
ir = f.lower(1.0).compiler_ir('mhlo')

print(repr(ir))
hr()
print(ir)

<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f5a2adbd8f0>
--------------------------------------------------------------------------------
module @jit_f.1 {
  func public @main(%arg0: tensor<f32>) -> tensor<f32> {
    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
    %1 = mhlo.power %arg0, %0 : tensor<f32>
    %2 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    %3 = "mhlo.exponential"(%2) : (tensor<f32>) -> tensor<f32>
    %4 = mhlo.add %1, %3 : tensor<f32>
    return %4 : tensor<f32>
  }
}



In [8]:
# Create a length-3 float vector
v = jnp.array([1.0, 2.0, 3.0])
# Get MHLO intermediate representation (compiler_ir() with no dialect returns the MHLO module)
ir = f.lower(v).compiler_ir()

print(repr(ir))
hr()
print(ir)

<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f5a2adbc870>
--------------------------------------------------------------------------------
module @jit_f.2 {
  func public @main(%arg0: tensor<3xf32>) -> tensor<3xf32> {
    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
    %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<3xf32>
    %2 = mhlo.power %arg0, %1 : tensor<3xf32>
    %3 = "mhlo.negate"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
    %4 = "mhlo.exponential"(%3) : (tensor<3xf32>) -> tensor<3xf32>
    %5 = mhlo.add %2, %4 : tensor<3xf32>
    return %5 : tensor<3xf32>
  }
}



In [9]:
# Input: the 2x2 identity matrix
I = jnp.eye(2)
# No dialect argument, so the MHLO module is returned (as in the vector case)
lowered_I = f.lower(I)
ir = lowered_I.compiler_ir()

print(repr(ir))
hr()
print(ir)

<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f5a2adf96f0>
--------------------------------------------------------------------------------
module @jit_f.8 {
  func public @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
    %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<2x2xf32>
    %2 = mhlo.power %arg0, %1 : tensor<2x2xf32>
    %3 = "mhlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
    %4 = "mhlo.exponential"(%3) : (tensor<2x2xf32>) -> tensor<2x2xf32>
    %5 = mhlo.add %2, %4 : tensor<2x2xf32>
    return %5 : tensor<2x2xf32>
  }
}



## PMAP attempt

In [13]:
# Generate input arrays: two length-10 vectors
A = jnp.ones(10)
B = jnp.ones(10)

try:
    # pmap maps over the leading axis (size 10 here), so this needs 10 devices;
    # on a single-device machine it fails, and the error is printed instead
    dot_pmap = pmap(jnp.dot)(A, B)
except ValueError as error:
    print(error)

compiling computation that requires 10 logical devices, but only 1 XLA devices are available (num_replicas=10, num_partitions=1)


## MPI4JAX & JAX Primitives

In [14]:
import mpi4jax
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()


@jax.jit
def foo(arr):
    """Offset ``arr`` by this process's MPI rank, then allreduce-sum it over ``comm``.

    ``rank`` is a Python int read at trace time. It is therefore baked into the
    traced program as a constant (the jaxpr below shows ``add c 0.0``).
    """
    shifted = arr + rank
    # allreduce also returns a token (shown as `_:Tok` in the jaxpr); it is unused here
    total, _token = mpi4jax.allreduce(shifted, op=MPI.SUM, comm=comm)
    return total


make_jaxpr(foo)(1.0)

{ lambda ; a:f32[]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[]. let
          d:f32[] = add c 0.0
          e:Tok = create_token 
          f:f32[] _:Tok = allreduce_mpi[
            comm=<mpi4jax._src.utils.HashableMPIType object at 0x7f5a1876c280>
            op=<mpi4jax._src.utils.HashableMPIType object at 0x7f5a1876c130>
            transpose=False
          ] d e
        in (f,) }
      name=foo
    ] a
  in (b,) }

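## Stax, Flax, Trax

- [Stax][1]: Experimental neural-network library shipped inside the JAX repository (`jax.example_libraries`).
- [Flax][2]: Neural-network library with a functional style, built on JAX by the Google Brain team (~2.6k GitHub stars as of Feb 2022).
- [Trax][3]: Fully featured neural-network library built on JAX by the Google Brain team (~6.8k GitHub stars as of Feb 2022).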

[1]: https://github.com/google/jax/tree/main/jax/example_libraries
[2]: https://github.com/google/flax
[3]: https://github.com/google/trax

## 1D Convolution

In [None]:
# Input signal of ten ones and a 3-tap kernel
array = jnp.ones((10,))
mask = jnp.array([1.0, 2.0, 1.0])

# Default mode='full': output length is len(array) + len(mask) - 1 = 12
jnp.convolve(array, mask)

## Plans

- [ ] Continue studying JAX implementation
- [ ] Study basics of XLA/HLO
- [ ] Write minimal example using PMAP/PJIT @ local/cluster
- [ ] Write minimal example using PMAP/PJIT @ Google Colab