<a href="https://colab.research.google.com/github/christianmerkwirth/colabs/blob/master/Python_HPC_Cupy%2C_JAX%2C_Numba%2C_Cython.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# HPC Python code with JAX, Numba & Cupy

This Colab notebook shows examples of high-performance numerical code in Python. It is closely based on a notebook published by Wolfgang Resch (https://github.com/NIH-HPC/python-in-hpc), which in turn follows a [gist](https://gist.github.com/jfpuget/60e07a82dece69b011bb) by Jean-François Puget fairly closely; that gist was inspired by an IBM developerWorks [article](https://www.ibm.com/developerworks/community/blogs/jfp/entry/How_To_Compute_Mandelbrodt_Set_Quickly?lang=en) from 2015. The idea is to start with a plain Python implementation of the [Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set), profile the code, and evaluate different approaches to optimizing its performance.

To run all the implementations, the Colab notebook needs to be connected to a GPU runtime. Nothing else should be required.

Author: Christian Merkwirth

## Setup

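We'll need NumPy and Numba for the CPU implementations, Cython and f2py for the compiled extensions, and JAX, TensorFlow and CuPy for the GPU implementations. Matplotlib is used for plotting the sets.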

In [1]:
import sys
import os
import copy
import inspect
import logging
import math
from types import ModuleType
from typing import Optional, Any, List, Union
import collections
from dataclasses import dataclass, asdict

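# Limit JAX's GPU memory preallocation to 25% so that TensorFlow and CuPy can share the GPU.
# This has to be set before jax is imported.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = ".25"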

%load_ext cython
import pyximport; pyximport.install(reload_support=True)

%tensorflow_version 2.x
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import layers
import tensorflow_addons as tfa

print(tf.__version__)

import jax
import jax.numpy as jnp
from jax import pmap, vmap

import numpy as np
import pandas as pd
import random

np.set_printoptions(linewidth=120, precision=2, suppress=True)

import numba
from numba import jit, njit, vectorize, cuda, prange, guvectorize, int32, int64, float32, float64

import tensorflow.experimental.numpy as tnp

from pandas.plotting import register_matplotlib_converters
pd.options.display.max_columns = 999
pd.options.display.max_rows = 100

%matplotlib inline
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import colors
import matplotlib.dates as mdates
import seaborn as sns

register_matplotlib_converters()

sns.set_context("notebook", font_scale=1.)
sns.set_style("whitegrid")
%config InlineBackend.figure_format = 'retina'

from IPython.core.pylabtools import figsize
figsize(12, 11)

!/usr/local/cuda/bin/nvcc --version

2.4.0
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Sun_Jul_28_19:07:16_PDT_2019
Cuda compilation tools, release 10.1, V10.1.243


# Challenge 1: Mandelbrot Set



The plotting function draws a picture, which is a quick way to verify that the results returned by the various `mandel_set` implementations are correct. The first argument is a `mandel_set` function that returns a 2D array(-like) object that can be plotted; any keyword arguments are passed on to that function and are also used for the plot title.

In [2]:
def mandelbrot_image(fun, **kwargs):
    cmap='magma'
    z = fun(**kwargs)
    
    fig, ax = plt.subplots(figsize=(7,7), dpi=150)
    plt.xticks([])
    plt.yticks([])

    if len(kwargs):
      plt.title("[{xmin}, {ymin}] to [{xmax}, {ymax}]".format(**kwargs))
    
    norm = colors.PowerNorm(0.3)
    ax.imshow(z,cmap=cmap,origin='lower',norm=norm)


# Argument sets for the mandelbrot function
fast_args = {
    'xmin': -2.0,
    'xmax': 0.5,
    'ymin': -1.25,
    'ymax': 1.25,
    'width': 1024,
    'height': 1024,
    'maxiter': 80
}

slow_args = {
    'xmin': -0.74877,
    'xmax': -0.74872,
    'ymin': 0.06505,
    'ymax': 0.06510,
    'width': 1024,
    'height': 1024,
    'maxiter': 2048
}

# Let's store benchmarking results for later use.
@dataclass
class TimingResult:
  name: str
  uses_gpu: bool
  timing_result1: Union[List, float]
  timing_result2: Union[List, float]


def get_timing(timeit_res):
  return timeit_res.best

all_results = []
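Outside IPython, roughly the same measurement can be taken with the standard-library `timeit` module. `%timeit -o` returns a result object whose `.best` attribute (used by `get_timing` above) is the fastest per-loop time. The sketch below mimics that by taking the minimum over several repeats. The helper `best_time` and the toy workload are illustrative names, not part of the notebook.

```python
import timeit


def best_time(func, repeat=3, number=1, **kwargs):
    """Return the fastest per-call time in seconds over `repeat` runs.

    Hypothetical helper mirroring `%timeit -o ...` followed by `.best`:
    each run calls `func(**kwargs)` `number` times, and the fastest run
    is divided by `number`.
    """
    runs = timeit.repeat(lambda: func(**kwargs), repeat=repeat, number=number)
    return min(runs) / number


def toy_workload(n):
    # Small stand-in for a mandel_set function.
    return sum(i * i for i in range(n))


t = best_time(toy_workload, repeat=5, number=10, n=10_000)
print(f"best: {t * 1e3:.3f} ms per call")
```

Taking the minimum instead of the mean follows the same reasoning as `%timeit`'s "best of": slower runs are usually caused by interference from other processes rather than by the code itself.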

## CPU based implementations

### Pure python

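This is the baseline: a pure Python implementation of the Mandelbrot set. `mandel1` returns the iteration at which a point escapes (or `maxiter - 1` if it never does), and `mandel_set1` uses a nested list as a 2D array-like object.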

In [3]:
def linspace(start, stop, n):
    step = float(stop - start) / (n - 1)
    return [start + i * step for i in range(n)]

def mandel1(c, maxiter):
    z = c
    for n in range(maxiter):
        if abs(z) > 2:
            return n
        z = z*z + c
    return n

def mandel_set1(xmin, xmax, ymin, ymax, width, height, maxiter):
    r = linspace(xmin, xmax, width)
    i = linspace(ymin, ymax, height)
    n = [[0]*width for _ in range(height)]
    for x in range(width):
        for y in range(height):
            n[y][x] = mandel1(complex(r[x], i[y]), maxiter)
    return n

We'll use two areas of the Mandelbrot set as benchmarks for the code in this notebook:

 - the first, `fast_args` (`xmin=-2.0, xmax=0.5, ymin=-1.25, ymax=1.25, width=1024, height=1024, maxiter=80`), runs quicker and is used in each section of the notebook
 - the second, `slow_args` (`xmin=-0.74877, xmax=-0.74872, ymin=0.06505, ymax=0.06510, width=1024, height=1024, maxiter=2048`), requires considerably more compute time
 - Other sets covering the same area as the second one but at a higher resolution may be used in some circumstances
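A quick back-of-the-envelope calculation shows why the second area is so much slower, and also how fine its pixel grid is. The sketch copies the relevant values from `fast_args` and `slow_args`. It computes the worst case in which every pixel runs all `maxiter` iterations; actual work depends on how many points escape early.

```python
# Worst-case number of inner-loop iterations and the pixel spacing of both areas.
# The values are copied from fast_args and slow_args above.
areas = {
    "fast": dict(xmin=-2.0, xmax=0.5, width=1024, height=1024, maxiter=80),
    "slow": dict(xmin=-0.74877, xmax=-0.74872, width=1024, height=1024, maxiter=2048),
}


def budget(area):
    pixels = area["width"] * area["height"]
    worst_case = pixels * area["maxiter"]  # every pixel runs all maxiter iterations
    dx = (area["xmax"] - area["xmin"]) / (area["width"] - 1)  # spacing between neighbouring pixels
    return pixels, worst_case, dx


for name, area in areas.items():
    pixels, worst_case, dx = budget(area)
    print(f"{name}: {pixels} pixels, at most {worst_case:.3g} iterations, dx = {dx:.3g}")
```

Both areas have the same number of pixels, but the slow one allows 25.6 times as many iterations per pixel (2048 vs 80). Because it is a deep zoom into the boundary of the set, many of its points are likely to use a large part of that budget. Its pixel spacing of roughly 4.9e-8 is also small enough for floating-point precision to matter.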

In [4]:
t1 = %timeit -o mandel_set1(**fast_args)

1 loop, best of 3: 3.88 s per loop


In [None]:
t2 = %timeit -o mandel_set1(**slow_args)
#t2 = t1

In [None]:
all_results.append(
    TimingResult(
        name='Pure Python',
        uses_gpu = False,
        timing_result1 = get_timing(t1),
        timing_result2 = get_timing(t2)))

Let's plot the first area. Plotting the second area with the pure Python code would take a long time, so that call is commented out.

In [None]:
mandelbrot_image(mandel_set1, **fast_args)

In [None]:
#mandelbrot_image(mandel_set1, **slow_args)

So what is taking so long? Let's profile the code:

In [None]:
p = %prun -r -q mandel_set1(**fast_args)
p.stream = sys.stdout
p.sort_stats('cumulative').print_stats(5)

A simpler alternative, with less control over the output:

In [None]:
#%prun -s cumulative mandel_set1(**fast_args)

The corresponding command line would be

```shell
$ python -m cProfile -s cumulative mandel01.py
```

However, for this to work, the code has to be saved as a standalone script (`mandel01.py`) that runs the computation when executed.
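The profiler can also be driven from inside a script using the standard-library `cProfile` and `pstats` modules, which avoids the separate command line. This is a minimal sketch. `escape_time` repeats the logic of `mandel1`, and `small_set` is a hypothetical, reduced-size stand-in for `mandel_set1` so that the example runs quickly.

```python
import cProfile
import io
import pstats


def escape_time(c, maxiter):
    # Same logic as mandel1 in the notebook.
    z = c
    for n in range(maxiter):
        if abs(z) > 2:
            return n
        z = z * z + c
    return n


def small_set(width=60, height=60, maxiter=40):
    # Reduced-size stand-in for mandel_set1 so the profile runs quickly.
    return [[escape_time(complex(-2.0 + 2.5 * x / (width - 1), -1.25 + 2.5 * y / (height - 1)), maxiter)
             for x in range(width)] for y in range(height)]


profiler = cProfile.Profile()
profiler.enable()
small_set()
profiler.disable()

# Print the five most expensive entries by cumulative time into a string buffer.
out = io.StringIO()
stats = pstats.Stats(profiler, stream=out).sort_stats("cumulative")
stats.print_stats(5)
print(out.getvalue())
```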

The profile above shows that most of the runtime is spent in the `mandel1` function. Let's get a line-by-line profile of that function. The `%lprun` magic requires the `line_profiler` package, loaded with `%load_ext line_profiler`:

In [None]:
#%lprun -f mandel1 mandel_set1(**fast_args)

There is an algorithmic improvement that could be made here: most of the time in this function is spent calculating the absolute value and the next iterate. Both involve redundant steps that can be factored out. We will implement this later on; for now, let's try the simplest approach.

### Implementation: JIT compiling with numba

[Numba](https://numba.pydata.org/) is a NumPy-aware dynamic Python compiler using LLVM. It can speed up math-heavy, array-oriented code with just a few annotations. At the top of the notebook we imported the `jit` decorator from the numba package. The only modification to the code is decorating the `mandel` function with `@jit(nopython=True)`. Note that numba cannot jit compile the current implementation of the `mandel_set` function because it uses nested lists.

In [None]:
@jit(nopython=True)
def mandel2(c, maxiter):
    z = c
    for n in range(maxiter):
        if abs(z) > 2:
            return n
        z = z*z + c
    return n

def mandel_set2(xmin, xmax, ymin, ymax, width, height, maxiter):
    r = linspace(xmin, xmax, width)
    i = linspace(ymin, ymax, height)
    n = [[0]*width for _ in range(height)]
    for x in range(width):
        for y in range(height):
            n[y][x] = mandel2(complex(r[x], i[y]), maxiter)
    return n

# warm up jit
_ = mandel_set2(**fast_args)

In [None]:
t1 = %timeit -o mandel_set2(**fast_args)

In [None]:
t2 = %timeit -o mandel_set2(**slow_args)

In [None]:
all_results.append(
    TimingResult(
        name='Numba Simple',
        uses_gpu = False,
        timing_result1 = get_timing(t1),
        timing_result2 = get_timing(t2)))

Now, reprofiling shows that the majority of the remaining runtime is spent in the `mandel_set2` function.

In [None]:
%prun -s cumulative mandel_set2(**fast_args)

Let's do a line profile of that:

In [None]:
#%lprun -f mandel_set2 mandel_set2(**fast_args)

### Implementation: Numpy arrays

[Numpy](http://www.numpy.org/) is the fundamental Python array computation package. For vectorized numerical work, NumPy arrays are much more efficient than Python lists or arrays. So does changing from lists to NumPy arrays in the `mandel_set` function improve performance?

In [None]:
@njit(fastmath=True)
def mandel3(c, maxiter):
    z = c
    for n in range(maxiter):
        if abs(z) > 2:
            return n
        z = z*z + c
    return n  # same result as the other versions for points that never escape

def mandel_set3(xmin, xmax, ymin, ymax, width, height, maxiter):
    r = np.linspace(xmin, xmax, width)
    i = np.linspace(ymin, ymax, height)
    n = np.empty((height, width), dtype=int)
    for x in range(width):
        for y in range(height):
            n[y, x] = mandel3(complex(r[x], i[y]), maxiter)
    return n

# warm up jit
_ = mandel_set3(**fast_args)

In [None]:
t1 = %timeit -o mandel_set3(**fast_args)

In [None]:
t2 = %timeit -o mandel_set3(**slow_args)

In [None]:
all_results.append(
    TimingResult(
        name='Numba & Numpy',
        uses_gpu = False,
        timing_result1 = get_timing(t1),
        timing_result2 = get_timing(t2)))

No, it did not: it is actually a bit slower. This might change with larger arrays. Either way, since `mandel_set` no longer uses nested lists, we can now jit compile it as well.
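One likely contributor to the slowdown is that `mandel_set3` itself is still plain Python. Reading a NumPy array element by element from Python creates a new NumPy scalar object on every access, whereas a list simply returns the float object it already holds. This sketch shows the difference. The timing numbers depend on the machine, so only the types are asserted.

```python
import timeit

import numpy as np

values_list = [0.5 * k for k in range(1000)]
values_array = np.array(values_list)

# List indexing returns the stored Python float; array indexing boxes a new numpy.float64.
print(type(values_list[10]), type(values_array[10]))

# Element-wise access from Python: list vs. NumPy array.
t_list = min(timeit.repeat(lambda: [values_list[k] for k in range(1000)], number=200, repeat=3))
t_array = min(timeit.repeat(lambda: [values_array[k] for k in range(1000)], number=200, repeat=3))
print(f"list: {t_list:.4f} s, array: {t_array:.4f} s")
```

NumPy pays off when whole arrays are processed at once, or when the loop is compiled so that element access no longer goes through Python objects.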

### Implementation: Numba jit compile the second function

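Again, the only change needed in the code is the decorator: `@njit` (shorthand for `@jit(nopython=True)`) is now also applied to `mandel_set4`. Note that `dtype=int` failed for `np.empty` inside the jitted function, so `np.int32` is used instead.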

In [None]:
@njit(fastmath=True)
def mandel4(c, maxiter):
    z = c
    for n in range(maxiter):
        if abs(z) > 2:
            return n
        z = z*z + c
    return n

@njit
def mandel_set4(xmin, xmax, ymin, ymax, width, height, maxiter):
    r = np.linspace(xmin, xmax, width)
    i = np.linspace(ymin, ymax, height)
    n = np.empty((height, width), dtype=np.int32)
    for x in range(width):
        for y in range(height):
            n[y, x] = mandel4(complex(r[x], i[y]), maxiter)
    return n

# warm up jit
_ = mandel_set4(**fast_args)

In [None]:
t1 = %timeit -o mandel_set4(**fast_args)

In [None]:
t2 = %timeit -o mandel_set4(**slow_args)

In [None]:
all_results.append(
    TimingResult(
        name='Numba all functions',
        uses_gpu = False,
        timing_result1 = get_timing(t1),
        timing_result2 = get_timing(t2)))

Note: `line_profiler` cannot profile the jitted functions, because they run as compiled machine code rather than as Python bytecode.

### Implementation: Improve the math and run the code in parallel

The following are the definitions of the square of a complex number, the absolute value, and the sum of two complex numbers:

$$
\begin{aligned}
(a + bi)^2 &= (a + bi)(a + bi) = (a^2 - b^2) + 2abi \\
|a + bi| &= \sqrt{a^2 + b^2} \\
(a + bi) + (c + di) &= (a + c) + (b + d)i
\end{aligned}
$$

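Based on this, we can factor out some redundant calculations. We pass the real and imaginary parts of the complex numbers directly to the `mandel` function, compute their squares only once per iteration, and avoid the square root: since both sides are non-negative, $|a + bi| > 2$ is equivalent to $a^2 + b^2 > 4$. In addition, `parallel=True` together with `prange` lets Numba distribute the iterations of the outer loop over the columns across all CPU cores.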
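The rewrite can be checked in plain Python. The sketch below puts the complex-number version (as in `mandel1`) next to the split real/imaginary version (as in `mandel5`) and compares them on a small grid. The function names are illustrative. The two versions compute the same iterates, but `abs(z) > 2` and `real2 + imag2 > 4.0` round differently. Counts could therefore differ for a point whose orbit lands within rounding error of $|z| = 2$, which should be extremely rare.

```python
def escape_time_complex(c, maxiter):
    # Complex-number version, as in mandel1/mandel2.
    z = c
    for n in range(maxiter):
        if abs(z) > 2:
            return n
        z = z * z + c
    return n


def escape_time_split(creal, cimag, maxiter):
    # Real/imaginary version, as in mandel5: squares computed once, no square root.
    real, imag = creal, cimag
    for n in range(maxiter):
        real2 = real * real
        imag2 = imag * imag
        if real2 + imag2 > 4.0:  # |z| > 2  <=>  |z|^2 > 4
            return n
        imag = 2 * real * imag + cimag  # Im(z^2 + c) = 2ab + Im(c)
        real = real2 - imag2 + creal    # Re(z^2 + c) = a^2 - b^2 + Re(c)
    return n


width, height, maxiter = 41, 37, 60
xs = [-2.0 + 2.5 * k / (width - 1) for k in range(width)]
ys = [-1.25 + 2.5 * k / (height - 1) for k in range(height)]
mismatches = sum(
    escape_time_complex(complex(x, y), maxiter) != escape_time_split(x, y, maxiter)
    for x in xs for y in ys
)
print("mismatches:", mismatches)
```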

In [None]:
@njit(fastmath=True)
def mandel5(creal, cimag, maxiter):
    real = creal
    imag = cimag
    for n in range(maxiter):
        real2 = real*real
        imag2 = imag*imag
        if real2 + imag2 > 4.0:
            return n
        imag = 2 * real*imag + cimag
        real = real2 - imag2 + creal       
    return n

@njit(parallel=True)
def mandel_set5(xmin, xmax, ymin, ymax, width, height, maxiter):
    r = np.linspace(xmin, xmax, width)
    i = np.linspace(ymin, ymax, height)
    n = np.empty((height, width), dtype=np.int32)
    for x in prange(width):
        for y in range(height):
            n[y, x] = mandel5(r[x], i[y], maxiter)
    return n

# warm up jit
_ = mandel_set5(**fast_args)

In [None]:
t1 = %timeit -o mandel_set5(**fast_args)

In [None]:
t2 = %timeit -o mandel_set5(**slow_args)

In [None]:
all_results.append(
    TimingResult(
        name='Numba parallel',
        uses_gpu = False,
        timing_result1 = get_timing(t1),
        timing_result2 = get_timing(t2)))

Checking that the results are still correct:

In [None]:
mandelbrot_image(mandel_set5, **fast_args)

And drawing the slower slice is much more pleasant now:

In [None]:
mandelbrot_image(mandel_set5, **slow_args)

### Implementation: Cython

[Cython](http://cython.org/) is a static compiler for Python and for the Cython language extensions, such as static type declarations. It generates C code that uses the Python C API to build C extension modules.

How does a cythonized version of implementation 5 perform? To use Cython code conveniently in a Jupyter notebook, we use the `%%cython` cell magic from the Cython extension loaded in the setup cell. Under the hood, this generates C code, compiles the extension module, and loads it.

In [None]:
%%cython
import cython
import numpy as np

cdef int mandel6(const double creal, const double cimag, const int maxiter):
    cdef:
        double real2, imag2
        double real = creal, imag = cimag
        int n

    for n in range(maxiter):
        real2 = real*real
        imag2 = imag*imag
        if real2 + imag2 > 4.0:
            return n
        imag = 2 * real*imag + cimag
        real = real2 - imag2 + creal
    return n

@cython.boundscheck(False) 
@cython.wraparound(False)
cpdef mandel_set6(double xmin, double xmax, double ymin, double ymax, int width, int height, int maxiter):
    cdef:
        double[:] r1 = np.linspace(xmin, xmax, width)
        double[:] r2 = np.linspace(ymin, ymax, height)
        int[:,:] n = np.empty((height, width), np.int32)
        int i,j
    
    for i in range(width):
        for j in range(height):
            n[j,i] = mandel6(r1[i], r2[j], maxiter)
    # Return a NumPy array rather than the Cython memoryview.
    return np.asarray(n)

In [None]:
t1 = %timeit -o mandel_set6(**fast_args)

In [None]:
t2 = %timeit -o mandel_set6(**slow_args)

In [None]:
all_results.append(
    TimingResult(
        name='Cython',
        uses_gpu = False,
        timing_result1 = get_timing(t1),
        timing_result2 = get_timing(t2)))

In [None]:
mandelbrot_image(mandel_set6, **slow_args)

### Implementation: Fortran

The `f2py` tool included with NumPy can be used to bind Fortran code to Python. Since no Fortran cell magic is installed, we write Fortran code implementing the same algorithm to a file with `%%writefile`:

In [None]:
%%writefile mandel07.f90
subroutine mandel_set7(xmin, xmax, ymin, ymax, width, height, maxiter, n)
    real(8), intent(in)   :: xmin, xmax, ymin, ymax
    integer, intent(in)   :: width, height, maxiter
    integer               :: niter
    integer, dimension(height, width), intent(out) :: n
    integer               :: x, y
    real(8)               :: xstep, ystep
    
    xstep = (xmax - xmin) / (width - 1)
    ystep = (ymax - ymin) / (height - 1)
    do x = 1, width
        do y = 1, height
            call mandel7(xmin + (x - 1) * xstep, ymin + (y - 1) * ystep, maxiter, niter)
            n(y, x) = niter
        end do
    end do
end subroutine mandel_set7

subroutine mandel7(cre, cim, itermax, n)
    real(8), intent(in)      :: cre, cim
    integer, intent(in)      :: itermax
    integer, intent(out)     :: n
    real(8)                  :: re2, im2, re, im

    re = cre
    im = cim 
    do n = 0, itermax - 1
        re2 = re ** 2
        im2 = im ** 2
        if (re2 + im2 > 4.0) then
            exit
        end if
        im = 2 * re * im + cim
        re = re2 - im2 + cre
    end do
end subroutine mandel7


It is compiled into the `mb_fort` module with the following command:


In [None]:
!f2py -m mb_fort -c mandel07.f90 --fcompiler=gnu95 --opt=-O3

In [None]:
from mb_fort import mandel7, mandel_set7

In [None]:
t1 = %timeit -o mandel_set7(**fast_args)

In [None]:
t2 = %timeit -o mandel_set7(**slow_args)

In [None]:
all_results.append(
    TimingResult(
        name='Fortran',
        uses_gpu = False,
        timing_result1 = get_timing(t1),
        timing_result2 = get_timing(t2)))

In [None]:
mandelbrot_image(mandel_set7, **slow_args)
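One detail of `mandel_set7` is its loop order. It loops with `y` innermost and writes `n(y, x)`. Fortran arrays are column-major, so consecutive `y` values are adjacent in memory. The Numba and Cython versions use the same loop order on NumPy's default row-major (C-order) arrays, where consecutive `y` values in `n[y, x]` are a whole row apart. The strides make the difference visible. Whether it affects the timings here has not been measured in this notebook.

```python
import numpy as np

height, width = 1024, 1024
c_order = np.empty((height, width), dtype=np.int32)             # NumPy default (row-major)
f_order = np.empty((height, width), dtype=np.int32, order="F")  # Fortran layout (column-major)

# strides = bytes to step along (axis 0 = y, axis 1 = x)
print("C order strides:", c_order.strides)  # (4096, 4): stepping y jumps a whole row
print("F order strides:", f_order.strides)  # (4, 4096): stepping y moves to the next int32
```

Swapping the two loops (making `x` innermost) in the Numba and Cython versions would make their writes contiguous as well. Since the Mandelbrot iteration itself does most of the work, any effect on the timings would need to be measured.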

## GPU based implementations


### Implementation: Cupy

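CuPy provides a NumPy-compatible array library that runs on the GPU. Here the per-iteration arithmetic is written as a plain Python function, and `cp.fuse` compiles it into a single fused elementwise kernel. The loop over `maxiter` still runs in Python and launches that fused kernel once per iteration on the whole grid. The counter `res` accumulates, for each point, the number of iterations in which it has not yet diverged. Only the final result is copied back to host memory with `cp.asnumpy`.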


In [None]:
import cupy as cp

# Limit the total amount of GPU memory allocated by cupy.
cp.get_default_memory_pool().set_limit(4 * 1024**3)

# Note that this function does not specify any types or specific packages.
def inner_loop(real, imag, grid_r, grid_i):
    real2 = real*real
    imag2 = imag*imag
    # Escape test on the squared magnitude: |z|^2 < 4  <=>  |z| < 2.
    not_diverged = (real2 + imag2) < 4.0
    imag = 2 * real*imag + grid_i
    real = real2 - imag2 + grid_r 
    return real, imag, not_diverged

mandelbrot_kernel = cp.fuse(inner_loop)

def mandel_set8(xmin=-2.0, xmax=0.5, ymin=-1.25, ymax=1.25, width=1024, height=1024, maxiter=80):
  r = cp.linspace(xmin, xmax, width)
  i = cp.linspace(ymin, ymax, height)
  grid_r, grid_i = cp.meshgrid(r, i)
  res = cp.zeros_like(grid_r)

  real = grid_r
  imag = grid_i
  for n in range(maxiter):
      real, imag, not_diverged = mandelbrot_kernel(real, imag, grid_r, grid_i)
      res += not_diverged

  return cp.asnumpy(res)

# Kick initial compilation off.
_ = mandel_set8(**fast_args)

In [None]:
t1 = %timeit -o mandel_set8(**fast_args)

In [None]:
t2 = %timeit -o mandel_set8(**slow_args)

In [None]:
all_results.append(
    TimingResult(
        name='Cupy',
        uses_gpu = True,
        timing_result1 = get_timing(t1),
        timing_result2 = get_timing(t2)))

In [None]:
mandelbrot_image(mandel_set8, **slow_args)
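Because `inner_loop` uses only arithmetic operators and comparisons, the same whole-array iteration also runs on plain NumPy arrays. The sketch below is a CPU re-implementation with the same structure, using the squared-magnitude escape test. The names `inner_loop_np` and `mandel_set_numpy` are hypothetical. It can be handy for checking a GPU result on a small grid. Like the GPU versions, it counts non-diverged iterations, so points that never escape get `maxiter` rather than `maxiter - 1`.

```python
import numpy as np


def inner_loop_np(real, imag, grid_r, grid_i):
    # One whole-array Mandelbrot step with the squared-magnitude escape test.
    real2 = real * real
    imag2 = imag * imag
    not_diverged = (real2 + imag2) < 4.0
    imag = 2 * real * imag + grid_i
    real = real2 - imag2 + grid_r
    return real, imag, not_diverged


def mandel_set_numpy(xmin=-2.0, xmax=0.5, ymin=-1.25, ymax=1.25, width=1024, height=1024, maxiter=80):
    r = np.linspace(xmin, xmax, width)
    i = np.linspace(ymin, ymax, height)
    grid_r, grid_i = np.meshgrid(r, i)  # both have shape (height, width)
    res = np.zeros(grid_r.shape, dtype=np.int32)
    real, imag = grid_r.copy(), grid_i.copy()
    # Diverged points overflow to inf/nan; the comparison then stays False,
    # so the floating-point warnings are silenced.
    with np.errstate(over="ignore", invalid="ignore"):
        for _ in range(maxiter):
            real, imag, not_diverged = inner_loop_np(real, imag, grid_r, grid_i)
            res += not_diverged
    return res


counts = mandel_set_numpy(width=11, height=11, maxiter=30)
print(counts)
```

Each iteration allocates temporary arrays for every intermediate result. Avoiding those temporaries by combining the operations into one kernel is what `cp.fuse` does on the GPU.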

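### Implementation: JAX XLA and JIT

JAX traces `mandel_set9` and compiles it with XLA through `jax.jit`. The iteration is written with `jax.lax.fori_loop`, so the whole loop becomes part of one compiled XLA computation instead of being driven from Python as in the CuPy version.
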
In [None]:
from functools import partial

# Note that this function does not specify any types or specific packages.
def inner_loop(real, imag, grid_r, grid_i):
    real2 = real*real
    imag2 = imag*imag
    # Escape test on the squared magnitude: |z|^2 < 4  <=>  |z| < 2.
    not_diverged = (real2 + imag2) < 4.0
    imag = 2 * real*imag + grid_i
    real = real2 - imag2 + grid_r 
    return real, imag, not_diverged


def outer_loop(state, grid_r, grid_i):
    [res, real, imag] = state
    real, imag, not_diverged = inner_loop(real, imag, grid_r, grid_i)
    res += not_diverged
    return [res, real, imag]


def mandel_set9(xmin=-2.0, xmax=0.5, ymin=-1.25, ymax=1.25, width=1024, height=1024, maxiter=80):
    r = jnp.linspace(xmin, xmax, width)
    i = jnp.linspace(ymin, ymax, height)
    grid_r, grid_i = jnp.meshgrid(r, i)
    res = jnp.zeros_like(grid_r)
    state = [res, grid_r, grid_i]
    res, _, _ = jax.lax.fori_loop(0, maxiter, lambda i, s: outer_loop(s, grid_r, grid_i), state)
    return res


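# jnp.linspace needs concrete (static) sizes, so the arguments are bound with
# functools.partial before jitting; jax.jit's static_argnames would be an alternative.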
mandel_set9_fast = jax.jit(partial(mandel_set9, **fast_args))
mandel_set9_slow = jax.jit(partial(mandel_set9, **slow_args))

_ = mandel_set9_fast().block_until_ready()
_ = mandel_set9_slow().block_until_ready()
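Precision deserves a note when comparing results. By default JAX computes in 32-bit floating point; 64-bit has to be switched on, e.g. with `jax.config.update("jax_enable_x64", True)`. As a result, `jnp.linspace` in `mandel_set9` produces `float32` coordinates. For the deep zoom in `slow_args`, the pixel spacing is smaller than the spacing between adjacent `float32` values near 0.75, so neighbouring pixels collapse onto the same coordinate. The NumPy sketch below (no GPU needed) illustrates this. The same consideration likely applies to `tf.linspace` with Python float arguments in the TensorFlow version.

```python
import numpy as np

# x coordinates of the slow benchmark area, computed in double precision.
xs64 = np.linspace(-0.74877, -0.74872, 1024)
xs32 = xs64.astype(np.float32)  # what a float32 pipeline would hold

print("distinct float64 x values:", np.unique(xs64).size)
print("distinct float32 x values:", np.unique(xs32).size)
print("pixel spacing:", xs64[1] - xs64[0])
print("float32 spacing near 0.75:", np.spacing(np.float32(0.75)))
```

If the float32 image of the slow area looks blocky compared with the CPU versions, this loss of resolution is a likely cause. It also means the timings of float32 and float64 implementations are not strictly like-for-like.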


In [None]:
t1 = %timeit -o mandel_set9_fast().block_until_ready()

In [None]:
t2 = %timeit -o mandel_set9_slow().block_until_ready()

In [None]:
all_results.append(
    TimingResult(
        name='JAX',
        uses_gpu = True,
        timing_result1 = get_timing(t1),
        timing_result2 = get_timing(t2)))

In [None]:
mandelbrot_image(mandel_set9_slow)

### Implementation: Tensorflow

In [None]:
import collections
Loopvars = collections.namedtuple('Loopvars', 'real, imag, res')


@tf.function(experimental_compile=True, experimental_follow_type_hints=True)
def inner_loop(real: tf.Tensor, imag: tf.Tensor, grid_r: tf.Tensor, grid_i: tf.Tensor):
    real2 = real*real
    imag2 = imag*imag
    not_diverged = tf.cast(real2 + imag2 < 4.0, tf.int32)
    imag = 2 * real*imag + grid_i
    real = real2 - imag2 + grid_r
    return real, imag, not_diverged

@tf.function(experimental_compile=True, experimental_follow_type_hints=True)
def outer_loop(grid_r: tf.Tensor, grid_i: tf.Tensor, maxiter: int):
    res = tf.zeros_like(grid_r, dtype=tf.int32)
    real = tf.identity(grid_r)
    imag = tf.identity(grid_i)

    initial_vars = (tf.constant(0), Loopvars(real, imag, res))
    cond = lambda i, v: i < maxiter
    def body(i, v):
        real, imag, not_diverged = inner_loop(v.real, v.imag, grid_r, grid_i)
        return (i + 1, Loopvars(real, imag, v.res + not_diverged))

    final_vars = tf.while_loop(cond, body, initial_vars)
    return final_vars[1].res

def mandel_set10(xmin: float, xmax: float, ymin: float, ymax: float, width: int, height: int, maxiter: int):
    r = tf.linspace(xmin, xmax, width)
    i = tf.linspace(ymin, ymax, height)
    grid_r, grid_i = tf.meshgrid(r, i)
    res = outer_loop(grid_r, grid_i, maxiter)
    return res.numpy()


_ = mandel_set10(**fast_args)

In [None]:
t1 = %timeit -o mandel_set10(**fast_args)

In [None]:
t2 = %timeit -o mandel_set10(**slow_args)

In [None]:
all_results.append(
    TimingResult(
        name='Tensorflow',
        uses_gpu = True,
        timing_result1 = get_timing(t1),
        timing_result2 = get_timing(t2)))

In [None]:
mandelbrot_image(mandel_set10, **slow_args)

### Implementation: Tensorflow Numpy

In [None]:
# Note that this function does not specify any types or specific packages.
@tf.function
def inner_loop(real, imag, grid_r, grid_i):
    real2 = real*real
    imag2 = imag*imag
    # Escape test on the squared magnitude: |z|^2 < 4  <=>  |z| < 2.
    not_diverged = (real2 + imag2) < 4.0
    imag = 2 * real*imag + grid_i
    real = real2 - imag2 + grid_r 
    return real, imag, not_diverged


def mandel_set11(xmin=-2.0, xmax=0.5, ymin=-1.25, ymax=1.25, width=1024, height=1024, maxiter=80):
    r = tnp.linspace(xmin, xmax, width)
    i = tnp.linspace(ymin, ymax, height)
    grid_r, grid_i = tnp.meshgrid(r, i)
    res = tnp.zeros_like(grid_r, dtype=tf.int32)

    real = grid_r
    imag = grid_i
    for n in range(maxiter):
        real, imag, not_diverged = inner_loop(real, imag, grid_r, grid_i)
        res += tf.cast(not_diverged, tf.int32)

    return res

_ = mandel_set11(**fast_args)


In [None]:
t1 = %timeit -o mandel_set11(**fast_args)

In [None]:
t2 = %timeit -o mandel_set11(**slow_args)

In [None]:
all_results.append(
    TimingResult(
        name='TF Numpy',
        uses_gpu = True,
        timing_result1 = get_timing(t1),
        timing_result2 = get_timing(t2)))

In [None]:
mandelbrot_image(mandel_set11, **slow_args)

### Implementation: Numba & Cuda

In [None]:
import numpy
import cupy as cp
from numba import cuda
from numba import vectorize


def inner_loop(creal, cimag, maxiter):
    real = creal
    imag = cimag
    for n in range(maxiter):
        real2 = real*real
        imag2 = imag*imag
        if real2 + imag2 > 4.0:
            return n
        imag = 2 * real*imag + cimag
        real = real2 - imag2 + creal       
    return n


inner_loop_gpu = cuda.jit(device=True)(inner_loop)

@cuda.jit
def mandel_set12_gpu(M, xmin, xmax, ymin, ymax, maxiter):
    """Calculate the Mandelbrot set on the GPU.
    Parameters
    ----------
    M : device array
        a two-dimensional integer array in GPU memory (here a CuPy array)
        that will contain the escape times for each point.
    xmin: float
        minimum value on the real axis
    xmax: float
        maximum value on the real axis
    ymin: float
        minimum value on the imaginary axis
    ymax: float
        maximum value on the imaginary axis
    maxiter: int
        maximum number of iterations per point
    """
    ny, nx = M.shape
    # cuda.grid(2) returns (x, y): i indexes columns (real axis), j rows (imaginary axis).
    i, j = cuda.grid(2)

    if i < nx and j < ny:
        # Same pixel spacing as np.linspace(xmin, xmax, nx) in the CPU versions.
        dx = (xmax - xmin) / (nx - 1)
        dy = (ymax - ymin) / (ny - 1)
        M[j, i] = inner_loop_gpu(xmin + dx * i, ymin + dy * j, maxiter)


def mandel_set12(xmin=-2.0, xmax=0.5, ymin=-1.25, ymax=1.25, width=1024, height=1024, maxiter=80):
  # We define M as a Cupy device array on GPU memory.
  M = cp.zeros((height, width), dtype=numpy.int32)
  block = (32, 32)
  # Ceiling division so that the grid covers every pixel; the first grid
  # dimension runs along x (width), the second along y (height).
  grid = ((width + block[0] - 1) // block[0],
          (height + block[1] - 1) // block[1])
  mandel_set12_gpu[grid, block](M, xmin, xmax, ymin, ymax, maxiter)
  # Only now we have to copy the results back to host memory.
  return cp.asnumpy(M)

# warm up jit
_ = mandel_set12(**fast_args)

In [None]:
t1 = %timeit -o mandel_set12(**fast_args)

In [None]:
t2 = %timeit -o mandel_set12(**slow_args)

In [None]:
all_results.append(
    TimingResult(
        name='Cuda Numba',
        uses_gpu = True,
        timing_result1 = get_timing(t1),
        timing_result2 = get_timing(t2)))

In [None]:
mandelbrot_image(mandel_set12, **slow_args)

In [None]:
df_timings = pd.DataFrame.from_records([asdict(x) for x in all_results])
df_timings['fast_set'] = df_timings.timing_result1.apply(np.median)
df_timings['slow_set'] = df_timings.timing_result2.apply(np.median)
# Speedup relative to the pure Python baseline in the first row: baseline time / time.
df_timings['speedup_fast_set'] = (df_timings.fast_set.iloc[0] / df_timings.fast_set).round(2)
df_timings['speedup_slow_set'] = (df_timings.slow_set.iloc[0] / df_timings.slow_set).round(2)

In [None]:
df_timings[['name', 'uses_gpu', 'fast_set', 'slow_set', 'speedup_fast_set', 'speedup_slow_set']]