
In [None]:
import numpy as np
import matplotlib.pyplot as plt

## Rendering the Mandelbrot set

Below is an implementation of the naïve escape time rendering algorithm.

Each pixel in the resulting image corresponds to a complex number `c`, with the real component on the x-axis and the imaginary component on the y-axis. Starting at the origin (`z[0] = 0`), each iteration moves `z` according to `z[i+1] = z[i]^2 + c`. If `z` remains bounded, `c` is a member of the set. Once `|z| > 2`, the sequence is guaranteed to diverge, so the code tests `x**2 + y**2 < 4` (which avoids a square root) and stops iterating as soon as that test fails. The image is colored by how many iterations it takes to determine that each pixel is *not* a member of the set; points that never escape within the iteration limit are treated as members.
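To make the recurrence concrete, here is a minimal sketch of the escape-time count for a single point, written with Python's built-in `complex` type instead of separate `x` and `y` variables. The helper name `escape_time` is hypothetical (it is not part of the notebook), and it uses the same `x**2 + y**2 < 4` test as the code below.

```python
def escape_time(c: complex, iterations: int = 256) -> int:
    """Count how many iterations z stays inside |z| < 2, capped at `iterations`."""
    z = 0j  # z[0] = 0
    for n in range(iterations):
        z = z * z + c  # z[i+1] = z[i]^2 + c
        if not z.real**2 + z.imag**2 < 4:
            return n  # escaped during iteration n, so it survived n iterations
    return iterations  # never escaped: treated as a member of the set


print(escape_time(0))   # → 256 (the origin is in the set)
print(escape_time(1))   # → 1 (0 -> 1 -> 2, escapes on the second iteration)
print(escape_time(1j))  # → 256 (i -> -1+i -> -i -> -1+i ..., a bounded cycle)
```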

We use `px` and `py` to refer to pixel locations in the rendered image, and `x0` and `y0` (the real and imaginary parts of `c` above) to refer to the corresponding point in the complex plane. The variables `x` and `y` hold the real and imaginary parts of the current value of `z`. The `shape` parameter gives the size of the image in pixels, and `bounds=((xmin, ymin), (xmax, ymax))` gives the patch of the complex plane being rendered.
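The mapping from pixels to the complex plane is what `np.linspace` computes: the `k`-th of `n` evenly spaced samples between `lo` and `hi` (endpoints included) is `lo + k * (hi - lo) / (n - 1)`. A minimal check, using a hypothetical helper `pixel_to_coord` written out by hand:

```python
import numpy as np


def pixel_to_coord(k: int, n: int, lo: float, hi: float) -> float:
    """Coordinate of pixel k out of n pixels spanning [lo, hi], endpoints included."""
    return lo + k * (hi - lo) / (n - 1)


xmin, xmax, width = -2.0, 1.0, 7
by_hand = np.array([pixel_to_coord(px, width, xmin, xmax) for px in range(width)])
print(np.allclose(by_hand, np.linspace(xmin, xmax, width)))  # → True
```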

In [None]:
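def mandelbrot_python(iterations=256, shape=(800, 800), bounds=((-2.0, -1.5), (1.0, 1.5))):
    # shape is (rows, columns), i.e. (height, width), following NumPy's convention.
    result = np.zeros(shape, np.uint)
    x0s = np.linspace(bounds[0][0], bounds[1][0], shape[1])  # real parts, one per column
    y0s = np.linspace(bounds[0][1], bounds[1][1], shape[0])  # imaginary parts, one per row
    for py, y0 in enumerate(y0s):
        for px, x0 in enumerate(x0s):
            x = 0.0  # real part of z
            y = 0.0  # imaginary part of z
            for n in range(iterations):
                # z = z**2 + c, written out in real and imaginary parts
                x, y = x0 + x**2 - y**2, y0 + 2*x*y
                if not x**2 + y**2 < 4:
                    break  # |z| >= 2: the point has escaped
                # z has survived n + 1 iterations; members end up with `iterations`
                result[py, px] = n + 1
    return result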

In [None]:
%timeit mandelbrot_python(10, shape=(100, 100))
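`%timeit` is an IPython magic, so it only works inside a notebook or an IPython shell. Outside one, the standard library's `timeit` module gives a comparable measurement. A minimal sketch, where `work` is a hypothetical stand-in for the call you want to time (for example `mandelbrot_python(10, shape=(100, 100))`):

```python
import timeit

import numpy as np


def work():
    # Hypothetical stand-in: replace with the function call you want to time.
    return np.linspace(-2.0, 1.0, 10_000) ** 2


# 5 measurements of 100 calls each; the minimum is the least noisy estimate.
times = timeit.repeat(work, number=100, repeat=5)
best_per_call = min(times) / 100
print(f"best of 5: {best_per_call * 1e6:.1f} µs per call")
```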

Generate a picture with relatively few iterations.

In [None]:
iter = 128
img = mandelbrot_python(iter)

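Show the picture using matplotlib. Notice that we rescale the values from the range `0` to `iter` into the range `1` to `0`. This makes the members of the set (points that did not escape within `iter` iterations) map to 0, the darkest end of the colormap. `origin="lower"` places the first row of the array at the bottom so that the imaginary axis points up, and `extent` labels the axes with complex-plane coordinates.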
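On a tiny made-up array of iteration counts with an iteration limit of 128, the rescaling maps 0 (escaped immediately) to 1.0 and 128 (never escaped) to 0.0:

```python
import numpy as np

iterations = 128  # plays the role of `iter` in the notebook
counts = np.array([0, 32, 64, 128], dtype=np.uint)  # made-up iteration counts
scaled = (iterations - counts) / iterations
print(scaled.tolist())  # → [1.0, 0.75, 0.5, 0.0]
```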

*Enjoy!*

Try different colormaps (see the [documentation](https://matplotlib.org/stable/users/explain/colors/colormaps.html) for valid names, or try any string and read the error message).

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow((iter - img) / iter, cmap="Spectral", origin="lower", extent=[-2, 1, -1.5, 1.5])
plt.show()

## Workshop

Complete the NumPy implementation of the naïve escape time algorithm below! The goal is to replace the nested loops over pixels with whole-array operations, leaving a single loop over `iterations`.

You may see `RuntimeWarning`s (for example about overflow) because points that have already escaped keep being iterated and diverge to infinity. These warnings can be safely ignored.
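If you would rather silence these warnings than avoid them, NumPy's `np.errstate` context manager changes floating-point error handling for the enclosed block only. A minimal sketch on made-up values chosen to force an overflow:

```python
import numpy as np

big = np.array([1e200, 2.0])

with np.errstate(over="ignore", invalid="ignore"):
    squared = big**2          # 1e400 overflows float64 to inf, without a warning
    diff = squared - squared  # inf - inf is nan, without a warning

print(squared.tolist())  # → [inf, 4.0]
print(diff.tolist())     # → [nan, 0.0]
```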

### Bonus task

Can you speed it up even more and get rid of the warnings by only updating points that have not yet escaped?
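One NumPy feature that may help here is boolean mask indexing: comparing an array to a value gives an array of `True`/`False`, and indexing with that mask reads or writes only the selected elements. A generic illustration on made-up values (deliberately not the Mandelbrot update itself):

```python
import numpy as np

values = np.array([0.5, 3.0, 1.5, 10.0])
active = values < 2                   # [True, False, True, False]
values[active] = values[active] ** 2  # update only the selected elements
print(values.tolist())           # → [0.25, 3.0, 2.25, 10.0]
print(np.count_nonzero(active))  # → 2
```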

In [None]:
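def mandelbrot_numpy(iterations=256, shape=(800, 800), bounds=((-2.0, -1.5), (1.0, 1.5))):
    # shape is (rows, columns), i.e. (height, width), following NumPy's convention.
    result = np.zeros(shape, np.uint)
    # Real parts of c vary along the columns, imaginary parts along the rows.
    x0 = np.repeat(np.expand_dims(np.linspace(bounds[0][0], bounds[1][0], shape[1]), axis=0), shape[0], axis=0)
    y0 = np.repeat(np.expand_dims(np.linspace(bounds[0][1], bounds[1][1], shape[0]), axis=1), shape[1], axis=1)
    x = np.zeros(shape)  # real parts of z, one per pixel
    y = np.zeros(shape)  # imaginary parts of z, one per pixel
    for n in range(iterations):
        # Your code goes here
        pass  # placeholder so the cell runs; replace it with your code
    return result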
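The `np.repeat`/`np.expand_dims` construction above builds two 2-D grids holding the real and imaginary part of `c` for every pixel. `np.meshgrid` builds the same grids in one call. A minimal sketch comparing the two on a small grid (`nx` and `ny` are made-up sizes):

```python
import numpy as np

nx, ny = 4, 3  # made-up image width and height in pixels
xs = np.linspace(-2.0, 1.0, nx)  # real parts, one per column
ys = np.linspace(-1.5, 1.5, ny)  # imaginary parts, one per row

# Add a length-1 axis, then repeat along it to fill the grid.
x0_repeat = np.repeat(np.expand_dims(xs, axis=0), ny, axis=0)
y0_repeat = np.repeat(np.expand_dims(ys, axis=1), nx, axis=1)

# meshgrid's default indexing="xy" gives arrays of shape (len(ys), len(xs)).
x0_mesh, y0_mesh = np.meshgrid(xs, ys)

print(x0_mesh.shape)  # → (3, 4)
print(np.array_equal(x0_repeat, x0_mesh), np.array_equal(y0_repeat, y0_mesh))  # → True True
```

Broadcasting offers a third option that skips the explicit repeats: `xs[np.newaxis, :]` and `ys[:, np.newaxis]` combine element-wise into a full `(ny, nx)` grid as soon as they appear together in an arithmetic expression.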

Timing the NumPy version on the same input as before should show a nice speedup!

In [None]:
%timeit mandelbrot_numpy(10, shape=(100, 100))

Now we can render a more precise image with more iterations!

In [None]:
iter = 1024
img = mandelbrot_numpy(iter)

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow((iter - img) / iter, cmap="Spectral", origin="lower", extent=[-2, 1, -1.5, 1.5])
plt.show()

## Play around

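Try zooming in on the boundary of the set. As you zoom in, you may need to increase the number of iterations: points close to the boundary can take many iterations to escape, and with too few iterations they are wrongly shown as members.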
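When exploring, it is often easier to think in terms of a center point and a width than in terms of corners. A small hypothetical helper, `zoom_bounds`, returns the `bounds` tuple that `mandelbrot_numpy` takes together with the matching `extent` list for `plt.imshow`:

```python
from typing import Optional


def zoom_bounds(cx: float, cy: float, width: float, height: Optional[float] = None):
    """Return (bounds, extent) for a view centered on cx + cy*i.

    `height` defaults to `width`, giving a square view.
    """
    if height is None:
        height = width
    xmin, xmax = cx - width / 2, cx + width / 2
    ymin, ymax = cy - height / 2, cy + height / 2
    bounds = ((xmin, ymin), (xmax, ymax))  # format taken by mandelbrot_numpy
    extent = [xmin, xmax, ymin, ymax]      # format taken by plt.imshow
    return bounds, extent


# Approximately the view used in the next cell: center -0.735 + 0.18i, width 0.03.
bounds, extent = zoom_bounds(-0.735, 0.18, 0.03)
print(extent)
```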

In [None]:
iter = 1024
xmin, xmax = -0.750, -0.720
ymin, ymax =  0.165,  0.195
img = mandelbrot_numpy(iter, bounds=((xmin, ymin), (xmax, ymax)))

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow((iter - img) / iter, cmap="inferno", origin="lower", extent=[xmin, xmax, ymin, ymax])
plt.show()