# Black and white

Now that we've had a look at multi-dimensional indexing, why don't you try using two-dimensional indexing to make our image black and white?

Instead of operating on every pixel channel by channel, we now want to operate on each pixel as a whole and set its RGB channels to their average.
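For reference, here is a minimal CPU sketch of the same operation in plain NumPy, assuming an RGBA float image like the one `plt.imread` returns for a PNG. The function name `black_white_reference` is hypothetical. It averages the first three channels of every pixel, writes that average into all three colour channels and keeps the alpha channel unchanged, which is what the GPU kernel below should reproduce.

```python
import numpy as np

def black_white_reference(im):
    """CPU reference: average R, G and B per pixel, keep alpha (hypothetical helper)."""
    # Mean over the RGB channels only; keepdims gives an (H, W, 1) array for broadcasting
    average = im[:, :, :3].mean(axis=2, keepdims=True)
    output = np.empty_like(im)
    # Broadcast the per-pixel average into all three colour channels
    output[:, :, :3] = average
    # Pass the alpha channel through untouched
    output[:, :, 3] = im[:, :, 3]
    return output

# Usage on a tiny 1x2 RGBA image
tiny = np.array([[[0.3, 0.6, 0.9, 1.0],
                  [0.0, 0.0, 0.3, 0.5]]], dtype=np.float32)
print(black_white_reference(tiny))
```

The vectorised version is handy for checking the kernel's output, but it is exactly the per-pixel work we want to move onto the GPU.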

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
from numba import cuda
import numpy as np
import math

plt.rcParams["figure.figsize"] = (30,4)


**1. Load our image with matplotlib.**

In [None]:
!wget https://raw.githubusercontent.com/jacobtomlinson/gpu-python-tutorial/main/images/numba.png
im = plt.imread("numba.png")
plt.imshow(im)


**2. Move our image to the GPU and create an output array of the same size.**

In [None]:
gpu_im = cuda.to_device(im)
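# Build the zeroed output from the host image so it has the same shape and dtype,
# without copying the device array back to the host first
gpu_output = cuda.to_device(np.zeros_like(im))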


**3. Set our two-dimensional block size (threads per block) and grid size (blocks per grid).** _Hint: Our `threadsperblock` should still multiply to `128`._

In [None]:
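# 16 * 8 = 128 threads per block, as the hint asks
threadsperblock = (16, 8)
# Round up so the grid covers every pixel, even if the image isn't a multiple of the block size
blockspergrid_x = math.ceil(gpu_im.shape[0] / threadsperblock[0])
blockspergrid_y = math.ceil(gpu_im.shape[1] / threadsperblock[1])
blockspergrid = (blockspergrid_x, blockspergrid_y)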
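The grid size is a ceiling division: each axis needs enough blocks that `blocks * threads >= pixels`, so the grid can overshoot the image by up to one partial block. The sketch below checks this for a hypothetical 300×500 image with a 16×8 block (which multiplies to the 128 threads the hint asks for); neither shape comes from the tutorial's image, and `launch_config` is a hypothetical helper. The `demo_` names avoid overwriting the notebook's own variables.

```python
import math

def launch_config(shape, threadsperblock):
    """Blocks per grid along each axis so the grid covers `shape` (hypothetical helper)."""
    return tuple(math.ceil(n / t) for n, t in zip(shape, threadsperblock))

demo_threads = (16, 8)       # 16 * 8 = 128 threads per block
demo_shape = (300, 500)      # hypothetical image height and width
demo_blocks = launch_config(demo_shape, demo_threads)

# Threads launched along each axis; this can exceed the image size
demo_covered = tuple(b * t for b, t in zip(demo_blocks, demo_threads))
print(demo_blocks, demo_covered)  # (19, 63) (304, 504)
```

Here 304×504 threads are launched for 300×500 pixels, so the extra threads along each edge must be told to do nothing.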


**4. Write our kernel.**

In [None]:
@cuda.jit
def black_white(im, output):
    # With our two-dimensional grid we can get our index position in two dimensions
    x, y = cuda.grid(2)

    # Because the grid may be slightly larger than the image, threads that land outside it do nothing
    if x < im.shape[0] and y < im.shape[1]:

        # Calculate the average across the RGB channels
        average = (im[x, y, 0] + im[x, y, 1] + im[x, y, 2]) / 3

        # Set all output RGB channels to the average
        output[x, y, 0] = average
        output[x, y, 1] = average
        output[x, y, 2] = average

        # Pass the alpha channel through
        output[x, y, 3] = im[x, y, 3]


**5. Run the kernel.**

In [None]:
black_white[blockspergrid, threadsperblock](gpu_im, gpu_output)


**6. Move the data back from the GPU and plot it.**

In [None]:
plt.imshow(gpu_output.copy_to_host())
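Before trusting the plot, it can be worth checking the GPU result against a plain NumPy computation on the host. A minimal sketch, assuming `result` is the array returned by `gpu_output.copy_to_host()` and `original` is the `im` array loaded in step 1; the helper name `check_black_white` is hypothetical and the tolerance is a loose choice for float32 data.

```python
import numpy as np

def check_black_white(original, result, atol=1e-6):
    """Compare a GPU black-and-white result with a NumPy reference (hypothetical helper)."""
    expected = original[:, :, :3].mean(axis=2)
    # Every colour channel should hold the per-pixel average
    for channel in range(3):
        if not np.allclose(result[:, :, channel], expected, atol=atol):
            return False
    # Alpha should be copied through unchanged
    return np.array_equal(result[:, :, 3], original[:, :, 3])

# Usage in the notebook (requires the GPU cells above to have run):
# check_black_white(im, gpu_output.copy_to_host())
```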
