## Matrix multiplication 

We are allowed to use:

- Python
- matplotlib
- The Python standard library
- JAX standard library
- Jupyter notebooks and nbdev

We can also use `miniai` and `torch` for comparison.

In [1]:
from pathlib import Path
import pickle, gzip, math, os, time, shutil, matplotlib as mpl, matplotlib.pyplot as plt, lovely_tensors as lt
import torch

In [2]:
# lovely_tensors and torch are already imported in the previous cell
lt.monkey_patch()

Idea 💡: someone should port `lovely_tensors` to `jax`

Let's create a utility function that reports how much GPU memory is free.

In [3]:
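import pynvml
def get_memory_free_MiB(gpu_index):
    "Return the free memory of GPU `gpu_index` in MiB, as reported by NVML."
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(int(gpu_index))
    mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    # mem_info.free is in bytes; convert to MiB
    return mem_info.free // 1024 ** 2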

In [4]:
get_memory_free_MiB(0)

11176

## Get data

In [5]:
MNIST_URL='https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz?raw=true'
path_data = Path('data')
path_data.mkdir(exist_ok=True)
path_gz = path_data/'mnist.pkl.gz'

In [6]:
from urllib.request import urlretrieve
if not path_gz.exists(): urlretrieve(MNIST_URL, path_gz)

In [7]:
with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')

In [8]:
(x_train.shape, y_train.shape), (x_valid.shape, y_valid.shape), type(x_train)

(((50000, 784), (50000,)), ((10000, 784), (10000,)), numpy.ndarray)

In [9]:
get_memory_free_MiB(0)

11176

## Random array initialization 

### PyTorch

In [10]:
tw = torch.randn(784,10)
%timeit -n 10 tw = torch.randn(784,10)

84.3 µs ± 39.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
tw.dtype

torch.float32

In [12]:
get_memory_free_MiB(0)

11176

Since we never explicitly moved the tensor to CUDA, it lives on the CPU, which is why the free GPU memory is unchanged.

### JAX

By default JAX grabs most of the GPU memory up front. From the [JAX docs](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html): "JAX will preallocate 90% of the total GPU memory when the first JAX operation is run."
Let's change the default to 50%. The variable has to be set before JAX runs its first operation.

In [13]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".50"
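The same documentation page also describes turning preallocation off entirely. As a rough sketch, the two settings could be wrapped in a small helper. The function name `configure_jax_gpu_memory` and its parameters are hypothetical, made up for illustration. It assumes it is called before JAX initializes its backend (safest before `import jax`).

```python
import os

def configure_jax_gpu_memory(fraction=None, preallocate=True):
    """Hypothetical helper: set JAX GPU allocator environment variables.

    Must run before JAX initializes its backend (safest: before `import jax`).
    """
    if not preallocate:
        # Allocate GPU memory on demand instead of grabbing a big block up front.
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    if fraction is not None:
        # Fraction of total GPU memory to preallocate, e.g. 0.5 for 50%.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"

# Same effect as the cell above: cap preallocation at 50%.
configure_jax_gpu_memory(fraction=0.5)
print(os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"])
```

Calling `configure_jax_gpu_memory(preallocate=False)` instead would make the `get_memory_free_MiB` readings track actual usage more closely, probably at some cost in allocation speed.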

As far as I know, this is how to check whether JAX is running on GPU, TPU or CPU:

In [14]:
import jax
import jax.numpy as jnp 
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu
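A shorter alternative that avoids importing `xla_bridge` is the public `jax.default_backend()` function. This is a minimal sketch; the printed values depend on the machine, so no output is shown.

```python
import jax

# Name of the default backend as a string.
backend = jax.default_backend()
print(backend)

# All devices JAX can see on the default backend.
devices = jax.devices()
print(devices)
```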


In [15]:
cpus = jax.devices("cpu")
gpus = jax.devices("gpu")

In [16]:
get_memory_free_MiB(0)

11036

In [17]:
key = jax.random.PRNGKey(42)
jw = jax.random.normal(key, shape=(784,10), dtype=jax.numpy.float32)
get_memory_free_MiB(0)

5412

In [18]:
%timeit -n 10 jw = jax.random.normal(key, shape=(784,10), dtype=jax.numpy.float32)

534 µs ± 93.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [19]:
jw.shape, jw.device(), (jw.min(), jw.max()), (jw.mean(), jw.std())

((784, 10),
 StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 (Array(-3.7302895, dtype=float32), Array(3.9276645, dtype=float32)),
 (Array(0.01403106, dtype=float32), Array(0.9955147, dtype=float32)))

In [20]:
534/84.3

6.334519572953737

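In this run, `PyTorch` random initialization was 6.33 times faster than `JAX`. Keep in mind that the `PyTorch` tensor was created on the CPU while the `JAX` array was created on the GPU, so this is not a like-for-like comparison.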
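JAX dispatches work asynchronously, so a timing that does not wait for the result may not reflect the full cost. The sketch below times the same initialization while calling `block_until_ready()` on each result. The helper name `time_random_init` and the run count are illustrative choices, not part of the original notebook.

```python
import time
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

def time_random_init(n_runs=10):
    """Average wall time (seconds) of one random init, waiting for the result."""
    # Warm-up run so one-off setup cost is not included in the timing.
    jax.random.normal(key, shape=(784, 10), dtype=jnp.float32).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_runs):
        # block_until_ready() waits for the asynchronously dispatched computation.
        jax.random.normal(key, shape=(784, 10), dtype=jnp.float32).block_until_ready()
    return (time.perf_counter() - start) / n_runs

elapsed = time_random_init()
print(f"{elapsed * 1e6:.1f} µs per call")
```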

## Matrix multiplication

In [21]:
# works in pytorch, jax and numpy
def matmul(x,y):
    return x@y
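To make explicit what `x@y` computes, here is a small sanity check against a triple-loop reference in NumPy. The data is random stand-in data (an assumption, not MNIST), with only 5 rows so the Python loops stay fast. `matmul_loops` is a hypothetical name used only for this check.

```python
import numpy as np

def matmul(x, y):
    return x @ y

def matmul_loops(a, b):
    """Reference implementation with three nested loops (slow, for checking only)."""
    n, k = a.shape
    k2, m = b.shape
    assert k == k2, "inner dimensions must match"
    out = np.zeros((n, m), dtype=a.dtype)
    for i in range(n):
        for j in range(m):
            for p in range(k):
                out[i, j] += a[i, p] * b[p, j]
    return out

rng = np.random.default_rng(0)
a = rng.standard_normal((5, 784)).astype(np.float32)   # 5 stand-in "images"
b = rng.standard_normal((784, 10)).astype(np.float32)  # weights, same shape as `tw`/`jw`

fast, slow = matmul(a, b), matmul_loops(a, b)
print(fast.shape)                           # (rows of a, columns of b)
print(np.allclose(fast, slow, atol=1e-2))   # loose tolerance for float32 rounding
```

The same shape rule applies to the full data: a `(50000, 784)` input times a `(784, 10)` weight matrix gives a `(50000, 10)` result.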

### PyTorch + CUDA

In [22]:
tw = torch.randn(784,10)
tx_train = torch.tensor(x_train)
tx_train.shape, type(tx_train)

(torch.Size([50000, 784]), torch.Tensor)

In [23]:
%time txc, twc = tx_train.cuda(), tw.cuda() # measure PyTorch device transfer time

CPU times: user 1.54 s, sys: 768 ms, total: 2.31 s
Wall time: 3.37 s


In [24]:
txc.shape, twc.shape, txc.device, twc.device

(torch.Size([50000, 784]),
 torch.Size([784, 10]),
 device(type='cuda', index=0),
 device(type='cuda', index=0))

In [25]:
%timeit -n 1000 matmul(txc, twc)

565 µs ± 61 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
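CUDA kernels are launched asynchronously, so a timing loop without synchronization may partly measure launch overhead rather than the finished computation. The sketch below shows one way to time the matmul with explicit `torch.cuda.synchronize()` calls. It uses random stand-in tensors with the same shapes as `txc` and `twc` (an assumption) and falls back to the CPU when no GPU is available.

```python
import time
import torch

def matmul(x, y):
    return x @ y

device = "cuda" if torch.cuda.is_available() else "cpu"
# Stand-in data with the same shapes as `txc` and `twc` (assumption: random values).
x = torch.randn(50000, 784, device=device)
w = torch.randn(784, 10, device=device)

def sync():
    # Wait for queued CUDA kernels to finish before reading the clock.
    if device == "cuda":
        torch.cuda.synchronize()

matmul(x, w)  # warm-up
sync()
n_runs = 20
start = time.perf_counter()
for _ in range(n_runs):
    out = matmul(x, w)
sync()
elapsed = (time.perf_counter() - start) / n_runs
print(f"{elapsed * 1e6:.1f} µs per matmul on {device}")
```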


### JAX

In [26]:
from jax import jit

In [28]:
%time jx_train = jnp.array(x_train) # measure JAX device transfer time

CPU times: user 6.57 ms, sys: 106 ms, total: 113 ms
Wall time: 109 ms


In [29]:
jx_train.shape, type(jx_train), jx_train.device(), jw.shape, type(jw), jw.device()

((50000, 784),
 jaxlib.xla_extension.Array,
 StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 (784, 10),
 jaxlib.xla_extension.Array,
 StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0))

In [30]:
matmul_jit = jax.jit(matmul)
%time matmul_jit(jx_train, jw).block_until_ready() # measure JAX compilation time

CPU times: user 571 ms, sys: 502 ms, total: 1.07 s
Wall time: 2.37 s


Array([[ -0.97251654,  -6.734766  ,  -1.7148048 , ...,  -7.3052707 ,
          4.7699347 ,  -9.978209  ],
       [ -3.7587132 , -15.646385  ,   4.6581864 , ..., -12.779354  ,
          8.308025  ,   1.2251611 ],
       [ -8.072983  ,  -1.9900167 ,   4.9306164 , ...,  -1.5675635 ,
          9.269543  ,   5.2164707 ],
       ...,
       [  5.3260746 ,   2.7311437 ,  12.856297  , ...,   5.9363794 ,
         -5.454983  , -11.519852  ],
       [  9.113716  ,  -7.591759  ,   4.178947  , ...,  10.593879  ,
         -3.862001  ,  11.375757  ],
       [ 11.984879  ,  -3.51389   ,  15.2327    , ...,  16.66311   ,
         -2.9007533 ,  -5.151217  ]], dtype=float32)

In [31]:
%timeit -n 100 matmul_jit(jx_train, jw).block_until_ready()

905 µs ± 107 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [32]:
905/565

1.6017699115044248

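In this test, PyTorch (565 µs) was about 1.6 times faster than jitted JAX (905 µs) for matrix multiplication. Note that the PyTorch timing does not synchronize with the GPU, while the JAX timing calls `block_until_ready()`, so the two measurements are not strictly like-for-like.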

Note to self: I'm still trying to figure out how I got 44 µs for PyTorch in a previous run. It was probably a measurement error.