# Lesson 4 project 1: Drawing the Mandelbrot set

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In this project, you'll draw the Mandelbrot set—a fractal—in imperative, array-oriented, and JIT-compiled ways.

The [Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set) is the set of [complex numbers](https://en.wikipedia.org/wiki/Complex_number) $c$ for which the sequence

$$z_{i + 1} = z_i^2 + c \quad \text{with} \quad z_0 = 0$$

does not diverge to infinity (that is, $|z_i|$ stays bounded as $i \to \infty$).

To draw this figure, we start with a two-dimensional array representing the plane of complex numbers (real coordinates from $-2$ to $1/2$ and imaginary coordinates from $-1.2$ to $1.2$), which can be produced in NumPy with [np.ogrid](https://numpy.org/doc/stable/reference/generated/numpy.ogrid.html):

In [None]:
height   = 35
width    = 70
real_min = -2
real_max = 0.5
imag_min = -1.2
imag_max = 1.2

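# "j" is the imaginary unit sqrt(-1) in Python; as the step of an np.ogrid slice,
# an imaginary number like height*1j asks for that many evenly spaced points (endpoints included)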
y, x = np.ogrid[imag_min:imag_max:height*1j, real_min:real_max:width*1j]
c = x + y*1j

# show the four corners
c[0, 0], c[0, -1], c[-1, 0], c[-1, -1]
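To make the shapes concrete, here is a small sketch (the 3×5 grid size is chosen only for illustration). `np.ogrid` returns a column of imaginary coordinates and a row of real coordinates, and adding them broadcasts to the full two-dimensional grid.

```python
import numpy as np

# a tiny grid: 3 points along the imaginary axis, 5 along the real axis
y, x = np.ogrid[-1.2:1.2:3j, -2:0.5:5j]
c = x + y*1j

print(y.shape, x.shape, c.shape)   # column, row, and the broadcast grid
print(c[0, 0], c[-1, -1])          # first and last corners of the grid
```

Row index `0` corresponds to the smallest imaginary coordinate and column index `0` to the smallest real coordinate.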

Each element of this array is a different value of $c$. Now we want another variable, $z$, to start at $0$ and iterate

```python
z = z**2 + c
```

infinitely many times to see if $z \to \infty$. We can't iterate infinitely many times or actually compute infinity, so

* if `abs(z) > 2`, then continued iteration would take `z` to infinity (this is not an approximation; it follows from a theorem)
* if `abs(z)` hasn't exceeded `2` after 20 iterations, we assume that it won't (an approximation that isn't good enough for [lesson-5b-gpu/project-area.ipynb](../lesson-5b-gpu/project-area.ipynb), but good enough for drawing pictures).
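The theorem behind the first bullet has a short proof. What follows is a sketch of the standard argument, not part of the original lesson. Suppose $|z_i| > 2$ and $|z_i| \ge |c|$. By the triangle inequality,

$$|z_{i + 1}| = |z_i^2 + c| \ge |z_i|^2 - |c| \ge |z_i|^2 - |z_i| = |z_i| \, (|z_i| - 1) > |z_i|$$

so the magnitude grows by a factor of at least $|z_i| - 1 > 1$ on this and every later step, and $|z_i| \to \infty$. The assumption $|z_i| \ge |c|$ holds the first time $|z_i|$ exceeds $2$: if $|c| \le 2$ it is immediate, and if $|c| > 2$ then $z_1 = c$ is already the first such iterate.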

In [None]:
h = w = 0

z = 0
for i in range(20):
    z = z**2 + c[h, w]
    if abs(z) > 2:
        print(f"{c[h, w]} is NOT in the Mandelbrot set")
        break
else:
    # Python language feature: `for ... else` enters the `else` clause if it does not `break`
    print(f"{c[h, w]} is in the Mandelbrot set")

Draw a picture by iterating over all the sampled points in the complex plane.

In [None]:
picture = np.empty(c.shape, dtype=np.bool_)

for h in range(height):
    for w in range(width):
        z = 0
        for i in range(20):
            z = z**2 + c[h, w]
            if abs(z) > 2:
                picture[h, w] = False
                break
        else:
            picture[h, w] = True

In [None]:
for h in range(height):
    for w in range(width):
        print("*" if picture[h, w] else " ", end="")
    print()

It's not much different from the first-ever picture of a Mandelbrot set by Robert Brooks and Peter Matelski (45 years ago!),

<center>
<img src="../img/Mandel.png" width="500px">
</center>

but not as glamorous as pictures in calendars, which add color by labeling how many iterations were needed to get to `abs(z) > 2`.

In [None]:
picture = np.empty(c.shape, dtype=np.int64)

for h in range(height):
    for w in range(width):
        z = 0
        for i in range(20):
            z = z**2 + c[h, w]
            if abs(z) > 2:
                picture[h, w] = i
                break
        else:
            picture[h, w] = 20  # at least

In [None]:
fig, ax = plt.subplots(1, 1)

ax.imshow(picture);
ax.set_xlabel("index along real axis");
ax.set_ylabel("index along imaginary axis");

They're also usually in higher resolution, but that will be a problem for us if we keep iterating in Python.

Before we get to alternatives, let's wrap the above up as a function, to more easily change the scale.

In [None]:
def run_python(height, width, real_min=-2, real_max=0.5, imag_min=-1.2, imag_max=1.2):
    y, x = np.ogrid[imag_min:imag_max:height*1j, real_min:real_max:width*1j]
    c = x + y*1j

    picture = np.empty(c.shape, dtype=np.int64)

    for h in range(height):
        for w in range(width):
            z = 0
            for i in range(20):
                z = z**2 + c[h, w]
                if abs(z) > 2:
                    picture[h, w] = i
                    break
            else:
                picture[h, w] = 20

    return picture

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 10))

ax.imshow(run_python(400, 600));

Now with twice the resolution...

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 10))

ax.imshow(run_python(800, 1200));

You probably noticed the lag.

<br><br><br><br><br>

## Exercise 1

Accelerate the calculation with NumPy.

Notice that the Mandelbrot calculation has an "iterate until converged" step, like the special functions we tried to accelerate in the lecture.

Note that you can use a boolean mask (an array of `True`/`False` values used as an index) to select which array values to assign:

In [None]:
array = np.arange(10)
array

In [None]:
array[array % 2 == 0] = 999

In [None]:
array

You want to know at which iteration each array element _starts_ to diverge, so be sure to only overwrite array elements if they have not yet been overwritten. How you do that is the key to this exercise.
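As a hint that stops short of the solution, here is a toy sketch of the "record only the first time" pattern on a different problem. The names `values` and `first_step`, the threshold of `10`, and the sentinel `-1` are all made up for illustration. Two boolean masks are combined with `&` so that an element is written only if it crosses the threshold and has not been recorded yet.

```python
import numpy as np

values = np.array([1.0, 2.0, 3.0, 4.0])
first_step = np.full(values.shape, -1, dtype=np.int64)   # -1 means "not recorded yet"

for step in range(5):
    values = values * 2                                   # every element is updated every time
    newly_over = (values > 10) & (first_step == -1)       # over threshold AND not yet recorded
    first_step[newly_over] = step

print(first_step)   # the step at which each element first exceeded 10
```

In the Mandelbrot exercise, the initial value of `picture` can play the role of the sentinel.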

In [None]:
def run_numpy(height, width, real_min=-2, real_max=0.5, imag_min=-1.2, imag_max=1.2):
    y, x = np.ogrid[imag_min:imag_max:height*1j, real_min:real_max:width*1j]
    c = x + y*1j

    z = np.zeros(c.shape, dtype=np.complex128)       # initial values of all elements of z are 0
    picture = np.full(c.shape, 20, dtype=np.int64)   # initial values of all pixels are 20

    for i in range(20):
        ...

    return picture

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 10))

with np.errstate(over="ignore", invalid="ignore"):   # ignore warnings about inf and nan values in z
    ax.imshow(run_numpy(800, 1200));

How fast is the NumPy solution, compared with pure Python?

Keep in mind that the algorithms are different. The Python algorithm _stops processing_ a point as soon as `abs(z)` exceeds `2`, but the NumPy solution keeps updating every element on every iteration because it's array-oriented.
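One consequence of never stopping early is that escaped elements keep getting squared until they overflow, which is why the cells around this one wrap `run_numpy` in `np.errstate`. A minimal sketch, with two sample points chosen only for illustration:

```python
import numpy as np

c = np.array([1 + 1j, 0j])                  # the first point escapes quickly, the second never does
z = np.zeros(c.shape, dtype=np.complex128)

with np.errstate(over="ignore", invalid="ignore"):
    for i in range(20):
        z = z**2 + c                        # no early exit: both elements are updated 20 times

print(np.isfinite(z))                       # whether each element is still a finite number
```

Once an element has overflowed to `inf` or `nan`, later tests on it are no longer meaningful (any comparison with `nan` is `False`), which is another reason to record only the first iteration at which `abs(z) > 2`.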

In [None]:
%%timeit -r1 -n1

run_python(800, 1200)

In [None]:
%%timeit -r1 -n1

with np.errstate(over="ignore", invalid="ignore"):
    run_numpy(800, 1200)

<br><br><br><br><br>

## Exercise 2

Now compile the function with Numba. Start with `run_python` and modify it so that Numba can JIT-compile it.

At the time that I'm writing this, `np.ogrid` is not in [Numba's list of supported NumPy functions](https://numba.readthedocs.io/en/stable/reference/numpysupported.html), so if you try to include it in the JIT-compiled part, it will raise a compilation error. How can you work around that?

See Numba's [JIT-compilation docs](https://numba.readthedocs.io/en/stable/reference/jit-compilation.html) for documentation on `nb.jit` and `nb.vectorize`, and use `nopython=True` to ensure that it doesn't try to use "object mode" (which usually only results in more confusing error messages).

In [None]:
import numba as nb

In [None]:
def run_numba(height, width, real_min=-2, real_max=0.5, imag_min=-1.2, imag_max=1.2):
    ...

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 10))

ax.imshow(run_numba(800, 1200));

How fast is it?

In [None]:
%%timeit -r1 -n1

run_python(800, 1200)

In [None]:
%%timeit -r1 -n1

with np.errstate(over="ignore", invalid="ignore"):
    run_numpy(800, 1200)

In [None]:
%%timeit -r1 -n1

run_numba(800, 1200)

<br><br><br><br><br>

## Exercise 3

Now compile it with JAX. Whereas a Numba solution looks like `run_python`, a JAX solution looks like `run_numpy`, because JAX compiles array-oriented code.

Unlike Numba, you can't use functions from the NumPy namespace (`np.*`); you have to use JAX's equivalents in `jax.numpy.*`.

JAX has a [jax.numpy.ogrid](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ogrid.html) function, but it can't be used in a JIT-compiled function (because the [shape of the array depends on its arguments](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit)), so it has to be outside the compiled part. You can solve this in a similar way as in exercise 2.

Another issue is that in-place assignment,

```python
array[array % 2 == 0] = 999
```

won't work. ([JAX arrays are immutable](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates).) However, JAX provides a [jax.numpy.where](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) function that lets you do the above as

```python
#                          condition    if_true  if_false
array = jax.numpy.where(array % 2 == 0,   999,    array  )
```

In [None]:
import jax

# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

In [None]:
def run_jax(height, width, real_min=-2, real_max=0.5, imag_min=-1.2, imag_max=1.2):
    ...

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 10))

ax.imshow(run_jax(800, 1200));

How fast is it?

In [None]:
%%timeit -r1 -n1

run_python(800, 1200)

In [None]:
%%timeit -r1 -n1

with np.errstate(over="ignore", invalid="ignore"):
    run_numpy(800, 1200)

In [None]:
%%timeit -r1 -n1

run_numba(800, 1200)

In [None]:
%%timeit -r1 -n1

run_jax(800, 1200)

Note: it's hard to get JAX to use only one thread, or to use the CPU when a GPU is available. This is good if you want speed at all costs, but bad if you're trying to discover _why_ JAX is faster. If JAX is using all CPU cores while Numba is using only one, a fairer comparison would be against [Numba with parallel processing](https://numba.readthedocs.io/en/stable/user/parallel.html). The same goes for 32-bit types and mathematical approximations (`fastmath=True`): compare like with like.

<br><br><br><br><br>

When you're done with this exercise, see [Mandelbrot on all accelerators](https://colab.research.google.com/drive/1J0l5e0NZm5kEm5BEUDG4neN5EN0VVCnt#scrollTo=JMJx2GOjtdyz) (which has spoilers/solutions to the above) and a [discussion about it with the JAX developers](https://colab.research.google.com/drive/google/jax#11078) for a deep-dive into accelerated Python.

<br><br><br><br><br>

Also, try using your fastest implementation to zoom in on the fine structure. Wheee!

<center>
<img src="../img/Mandelbrot_sequence_new.gif">
</center>
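To zoom in, only the `real_min`, `real_max`, `imag_min`, and `imag_max` arguments need to change. Here is a minimal sketch of one way to compute them. The helper `zoom_window` is hypothetical (not part of the lesson), and the center point and span are arbitrary choices near the boundary of the set.

```python
def zoom_window(center_real, center_imag, real_span, height, width):
    """Bounds of a window `real_span` wide, centered on a point, with square pixels."""
    imag_span = real_span * height / width   # keep the same scale on both axes
    return dict(
        real_min=center_real - real_span / 2,
        real_max=center_real + real_span / 2,
        imag_min=center_imag - imag_span / 2,
        imag_max=center_imag + imag_span / 2,
    )

window = zoom_window(-0.75, 0.1, 0.05, height=800, width=1200)
print(window)

# then, with whichever implementation is fastest, for example:
# ax.imshow(run_numba(800, 1200, **window))
```

Deep zooms will likely need more than 20 iterations to resolve fine structure, since points close to the boundary take longer to escape.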