In [None]:
# !poetry add -D matplotlib

In [None]:
from timeit import default_timer as timer
from matplotlib.pylab import imshow, show
import numpy as np


def mandelbrot(x, y, max_iters):
    """Escape-time test for the point ``c = x + y*1j``.

    Iterates ``z -> z*z + c`` starting from ``z = 0`` and returns the 0-based
    index of the first iteration whose result satisfies ``|z|^2 >= 4``.
    If the orbit has not escaped after ``max_iters`` iterations, ``c`` is
    treated as a candidate member of the Mandelbrot set and 255 is returned.
    """
    point = complex(x, y)
    orbit = 0.0j
    for step in range(max_iters):
        orbit = orbit * orbit + point
        squared_modulus = orbit.real * orbit.real + orbit.imag * orbit.imag
        if squared_modulus >= 4:
            return step
    return 255


def create_fractal(min_x, max_x, min_y, max_y, image, iters):
    """Fill ``image`` in place with Mandelbrot escape-time values.

    ``image`` is indexed ``image[y, x]``. Column ``x`` maps to the real
    coordinate ``min_x + x * (max_x - min_x) / width``. Row ``y`` maps to the
    imaginary coordinate ``min_y + y * (max_y - min_y) / height``. Each pixel
    is set to ``mandelbrot(real, imag, iters)``.

    An image with no rows or no columns is left untouched (previously this
    raised ZeroDivisionError). Returns None.
    """
    height = image.shape[0]
    width = image.shape[1]
    # Nothing to draw; returning here also avoids dividing by a zero extent.
    if height == 0 or width == 0:
        return

    pixel_size_x = (max_x - min_x) / width
    pixel_size_y = (max_y - min_y) / height
    # Rows go in the outer loop. For a C-ordered (NumPy default) array the
    # inner loop then writes consecutive memory. Each row's imaginary
    # coordinate is also computed once instead of once per pixel.
    for y in range(height):
        imag = min_y + y * pixel_size_y
        for x in range(width):
            real = min_x + x * pixel_size_x
            image[y, x] = mandelbrot(real, imag, iters)

In [None]:
# Zero-initialised canvas indexed image[y, x]: 5000 rows (y) by 7500 columns (x).
image = np.zeros((5000, 7500), dtype=np.uint8)

In [None]:
%%time

create_fractal(-2.0, 1.0, -1.0, 1.0, image, 20)

In [None]:
imshow(image)  # rows = imaginary axis (y), columns = real axis (x)

## Compiling with numba on CPU

In [None]:
import numba as na

@na.jit
def mandelbrot_cpu(x, y, max_iters):
    """Numba-compiled escape-time test for the point ``c = x + y*1j``.

    Iterates ``z -> z*z + c`` from ``z = 0`` and returns the 0-based index of
    the first iteration with ``|z|^2 >= 4``. Returns 255 if the orbit has not
    escaped within ``max_iters`` iterations.
    """
    point = complex(x, y)
    orbit = 0.0j
    for step in range(max_iters):
        orbit = orbit * orbit + point
        squared_modulus = orbit.real * orbit.real + orbit.imag * orbit.imag
        if squared_modulus >= 4:
            return step
    return 255

@na.jit
def create_fractal_cpu(min_x, max_x, min_y, max_y, image, iters):
    """Fill ``image`` in place with Mandelbrot escape-time values (numba-compiled).

    ``image`` is indexed ``image[y, x]``. Column ``x`` maps to the real
    coordinate ``min_x + x * (max_x - min_x) / width``. Row ``y`` maps to the
    imaginary coordinate ``min_y + y * (max_y - min_y) / height``. Each pixel
    is set to ``mandelbrot_cpu(real, imag, iters)``.

    An image with no rows or no columns is left untouched.
    """
    height = image.shape[0]
    width = image.shape[1]
    # Nothing to draw; returning here also avoids dividing by a zero extent.
    if height == 0 or width == 0:
        return

    pixel_size_x = (max_x - min_x) / width
    pixel_size_y = (max_y - min_y) / height
    # Rows go in the outer loop. For a C-ordered (NumPy default) array the
    # inner loop then writes consecutive memory instead of striding a whole
    # row per step. Each row's imaginary coordinate is also computed once.
    for y in range(height):
        imag = min_y + y * pixel_size_y
        for x in range(width):
            real = min_x + x * pixel_size_x
            image[y, x] = mandelbrot_cpu(real, imag, iters)

In [None]:
image_cpu = np.zeros((500 * 10, 750 * 10), dtype=np.uint8)

# Warm-up: `@na.jit` compiles `create_fractal_cpu` on its first call. Trigger
# that here on a 1x1 image with the same argument types as the timed call
# below (float64 bounds, 2-D C-ordered uint8 array, int iters). The %%time
# cell then measures execution rather than JIT compilation.
create_fractal_cpu(-2.0, 1.0, -1.0, 1.0, np.zeros((1, 1), dtype=np.uint8), 20)

In [None]:
%%time

create_fractal_cpu(-2.0, 1.0, -1.0, 1.0, image_cpu, 20)

In [None]:
imshow(image_cpu)  # rows = imaginary axis (y), columns = real axis (x)

## Compiling with numba on GPU

In [None]:
from numba import cuda

@cuda.jit(device=True)
def mandelbrot_gpu(x, y, max_iters):
    """CUDA device function: escape-time test for the point ``c = x + y*1j``.

    Iterates ``z -> z*z + c`` from ``z = 0`` and returns the 0-based index of
    the first iteration with ``|z|^2 >= 4``. Returns 255 if the orbit has not
    escaped within ``max_iters`` iterations.
    """
    point = complex(x, y)
    orbit = 0.0j
    for step in range(max_iters):
        orbit = orbit * orbit + point
        squared_modulus = orbit.real * orbit.real + orbit.imag * orbit.imag
        if squared_modulus >= 4:
            return step
    return 255

@cuda.jit
def create_fractal_gpu(min_x, max_x, min_y, max_y, image, iters):
    """CUDA kernel: one thread per pixel of ``image`` (indexed ``image[y, x]``).

    The thread at 2-D grid position (x, y) writes
    ``mandelbrot_gpu(real, imag, iters)`` for its pixel. Threads that fall
    outside the image do nothing.
    """
    x, y = cuda.grid(2)
    rows = image.shape[0]
    cols = image.shape[1]
    # The launch grid may be larger than the image; skip off-image threads.
    if x >= cols or y >= rows:
        return

    real = min_x + x * ((max_x - min_x) / cols)
    imag = min_y + y * ((max_y - min_y) / rows)
    image[y, x] = mandelbrot_gpu(real, imag, iters)

In [None]:
image_gpu = np.zeros((500 * 10, 750 * 10), dtype=np.uint8)

# Threads per block along each axis (blocks are nthread x nthread).
nthread = 32
# Ceil-divide each extent by the block size. This gives just enough blocks to
# cover every pixel, without the spare block row/column that
# `n // nthread + 1` adds when `n` is an exact multiple of `nthread`.
# NOTE output `image` shaped as `(y, x)`, so grid axis 0 (cuda.grid's x)
# spans shape[1] (columns) and grid axis 1 (y) spans shape[0] (rows).
blockspergrid = (
    (image_gpu.shape[1] + nthread - 1) // nthread,
    (image_gpu.shape[0] + nthread - 1) // nthread,
)

# Warm-up: the kernel is JIT-compiled on its first launch. Launch it once on
# a 1x1 image with the same argument types as the timed launch below, so the
# %%time cell does not include compilation.
create_fractal_gpu[(1, 1), (nthread, nthread)](
    -2.0, 1.0, -1.0, 1.0, np.zeros((1, 1), dtype=np.uint8), 20
)

In [None]:
%%time

create_fractal_gpu[blockspergrid, (nthread, nthread)](
    -2.0, 1.0, -1.0, 1.0, image_gpu, 20
)

In [None]:
imshow(image_gpu)  # rows = imaginary axis (y), columns = real axis (x)

## Compiling as a CUDA ufunc with `@vectorize`

In [None]:
from numba import vectorize

# Ufunc signature: uint8 result from (tid, min_x, max_x, min_y, max_y,
# width, height, iters), with float32 bounds and uint32 integers.
sig = "uint8(uint32, f4, f4, f4, f4, uint32, uint32, uint32)"


@vectorize([sig], target="cuda")
def mandel(tid, min_x, max_x, min_y, max_y, width, height, iters):
    """CUDA ufunc: escape-time value for the pixel with flat index ``tid``.

    The image is ``width`` columns by ``height`` rows, laid out row-major, so
    pixel ``tid`` sits at column ``tid % width`` and row ``tid // width``.
    Returns the 0-based index of the first iteration with ``|z|^2 >= 4``, or
    255 if the orbit stays bounded for ``iters`` iterations.
    """
    pixel_size_x = (max_x - min_x) / width
    pixel_size_y = (max_y - min_y) / height

    # The pixel coordinates must be computed here from the flat index. Floor
    # division is required for the row. True division `/` would give a
    # fractional row, so pixels within one row would sample drifting
    # imaginary coordinates.
    x = tid % width
    y = tid // width

    real = min_x + x * pixel_size_x
    imag = min_y + y * pixel_size_y

    c = complex(real, imag)
    z = 0.0j

    for i in range(iters):
        z = z * z + c
        if (z.real * z.real + z.imag * z.imag) >= 4:
            return i
    return 255


def create_fractal_ufunc(min_x, max_x, min_y, max_y, width, height, iters):
    tids = np.arange(width * height, dtype=np.uint32)
    return mandel(
        tids,
        np.float32(min_x),
        np.float32(max_x),
        np.float32(min_y),
        np.float32(max_y),
        np.uint32(height),
        np.uint32(width),
        np.uint32(iters),
    )

In [None]:
width, height = 500 * 10, 750 * 10
image_ufunc = np.zeros((width, height), dtype=np.uint8)

In [None]:
pixels = create_fractal_ufunc(-2.0, 1.0, -1.0, 1.0, width, height, 20)

In [None]:
imshow(pixels.reshape((width, height)))
# mandel.functions