
Optimizing Cython for Mandelbrot fractal calculations

[Image: examples-mandelbrot/mandelbrot.png]

This is a continuation of a series of blog posts found here:

http://aroberge.blogspot.com/2010/01/profiling-adventures-and-cython-setting.html

At the end of that series of posts, the OP had Cython code that calculated and drew the fractal image. I noted that the drawing might be faster if you modified the pixels one by one as an array in memory, then blitted that to the image.

Not knowing quite how to do that with Tk, I made a wx version, then set about making the Cython modifications, using a numpy array for the image.

First version I got working

# mandelcy1.pyx
# cython: profile=True

import cython

@cython.profile(False)
cdef inline int mandel(double real, double imag, int max_iterations=20):
    '''Escape-time test: returns the iteration at which |z| first reaches 2
       (the point escapes the Mandelbrot set), or -1 if z stays bounded for
       max_iterations iterations (the point is treated as inside the set).'''
    cdef double z_real = 0., z_imag = 0.
    cdef int i

    for i in range(0, max_iterations):
        z_real, z_imag = ( z_real*z_real - z_imag*z_imag + real,
                           2*z_real*z_imag + imag )
        if (z_real*z_real + z_imag*z_imag) >= 4:
            return i
    return -1

def create_fractal( double min_x,
                    double min_y,
                    double pixel_size,
                    int nb_iterations,
                    colours,
                    image):

    cdef int width, height
    cdef int x, y, start_y, end_y
    cdef int nb_colours, current_colour, new_colour
    cdef double real, imag

    nb_colours = len(colours)
    # image is an ndarray of size: w,h,3
    width = image.shape[0]
    height = image.shape[1]

    for x in range(width):
        real = min_x + x*pixel_size
        for y in range(height):
            imag = min_y + y*pixel_size
            colour = mandel(real, imag, nb_iterations)
            image[x, y, :] = colours[ colour%nb_colours ]

What I did differently:

  • moved the makePallet code out of cython -- it's only done once, no need for Cython
  • changed the colors to be numbers (RGB) rather than hex strings.
  • had create_fractal take a numpy array of pixels as input. The code then loops through the pixels and sets their color. Since I'm setting each pixel, there is none of that line-drawing stuff, so it's pretty simple.
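Switching from hex strings to numbers means the pixel loop deals only in integers. If you're starting from a palette of Tk-style '#rrggbb' strings (like the commented-out lines in make_palette() further down), a small helper can convert them once, up front. This is a sketch; hex_to_rgb is a hypothetical name, not something from the OP's code:

```python
def hex_to_rgb(colour):
    """Convert a Tk-style '#rrggbb' string to an (r, g, b) tuple of ints."""
    if len(colour) != 7 or not colour.startswith("#"):
        raise ValueError("expected '#rrggbb', got %r" % (colour,))
    # parse each two-character hex pair: rr, gg, bb
    return tuple(int(colour[i:i + 2], 16) for i in (1, 3, 5))


# The same string format as the commented-out lines in make_palette()
old_style = "#%02x%02x%02x" % (10, 8, 50)
print(old_style)                              # #0a0832
print(hex_to_rgb(old_style))                  # (10, 8, 50)
print(hex_to_rgb("#00%02x30" % (10 * 15)))    # (0, 150, 48)
```

The last line shows why the numeric palette uses 48 for the third ramp: it is the old hex 30.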

I put very simple timing code in the wx wrapper (see below for that), and this takes about 5.9 seconds to run on a 500x500 image. Increasing the iterations doesn't change it much; I think most points escape long before the maximum number of iterations anyway. I don't know how this compares to the OP's version -- I never did get that running on my machine.
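To check the hunch that most points bail out long before max_iterations, you can tally escape iterations over a coarse grid in plain Python. This is only a sketch: mandel_py mirrors the Cython mandel() above, escape_histogram is a hypothetical helper, and the 40x40 grid and 200-iteration limit are arbitrary choices to keep it quick. The region (-1.5 to 0 on both axes) is the one the viewer code further down uses.

```python
from collections import Counter


def mandel_py(real, imag, max_iterations=20):
    """Pure-Python mirror of mandel(): the iteration at which |z| first
    reaches 2, or -1 if z stays bounded for max_iterations iterations."""
    z_real = z_imag = 0.0
    for i in range(max_iterations):
        # z -> z*z + c, written out in real and imaginary parts
        z_real, z_imag = (z_real * z_real - z_imag * z_imag + real,
                          2 * z_real * z_imag + imag)
        # compare |z|**2 with 4 to avoid a square root
        if z_real * z_real + z_imag * z_imag >= 4:
            return i
    return -1


def escape_histogram(min_x, max_x, min_y, n, max_iterations):
    """Count how many points of an n x n grid escape at each iteration."""
    pixel_size = (max_x - min_x) / n
    counts = Counter()
    for x in range(n):
        real = min_x + x * pixel_size
        for y in range(n):
            imag = min_y + y * pixel_size
            counts[mandel_py(real, imag, max_iterations)] += 1
    return counts


counts = escape_histogram(-1.5, 0.0, -1.5, 40, 200)
total = 40 * 40
early = sum(c for it, c in counts.items() if 0 <= it < 20)
print("escaped within 20 iterations:", early, "of", total)
print("never escaped (treated as inside the set):", counts[-1], "of", total)
```

The split depends on the region and the limit; the point is that it can be measured rather than guessed. Only points that never escape pay for every iteration, so raising the limit mostly costs time for those.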

What I haven't done:

Adding numpy arrays

I haven't told Cython that this is a numpy array -- Cython understands numpy arrays; I suspect that will make a big difference:

  1. make the colours sequence a numpy array (the last line of make_palette()):

    return np.array(colours, dtype=np.uint8)

WOW! That sped it up by about a factor of two (2.8 seconds), even without telling Cython that it was a numpy array!

  2. let Cython know it's a numpy array:
cimport numpy as np # for the special numpy stuff

def create_fractal( double min_x,
                    double min_y,
                    double pixel_size,
                    int nb_iterations,
                    np.ndarray colours not None,
                    np.ndarray image not None):

The not None means that you'll get a TypeError if you pass in None. Without it, Cython (following Pyrex) accepts None for any argument typed as an extension type such as np.ndarray, and code that then treats None as an array could crash rather than raise a clean Python exception.

Note that you need to tell the compiler where to find the numpy headers; numpy.get_include() returns their location. I do this in the setup.py I'm using to build it:

setup(
    cmdclass = {'build_ext': build_ext},
    ext_modules = [Extension("mandelcy", ["mandelcy.pyx"], )],
    include_dirs = [numpy.get_include(),],
)

OK -- this helped a little, but not much -- about 2 seconds now.

But we can tell Cython what data types to expect:

def create_fractal( double min_x,
                    double min_y,
                    double pixel_size,
                    int nb_iterations,
                    np.ndarray[np.uint8_t, ndim=2, mode="c"] colours not None,
                    np.ndarray[np.uint8_t, ndim=3, mode="c"] image not None):

So now it knows that colours and image are arrays of unsigned 8-bit integers, and how many dimensions each has. The mode="c" means that the array is C-contiguous (rather than in Fortran order). Interestingly, this didn't help speed at all ... yet.
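mode="c" promises Cython that the data is laid out in C (row-major) order, so the address of image[x, y, c] is a fixed formula of the indices. This sketch (hypothetical helper names; itemsize 1 because the arrays are uint8) computes the strides for both layouts. For a C-ordered 500x500x3 uint8 array, numpy's ndarray.strides reports the same (1500, 3, 1):

```python
def c_strides(shape, itemsize=1):
    """Byte strides for a C-contiguous (row-major) array: last axis fastest."""
    strides = []
    step = itemsize
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def f_strides(shape, itemsize=1):
    """Byte strides for a Fortran-contiguous (column-major) array: first axis fastest."""
    strides = []
    step = itemsize
    for dim in shape:
        strides.append(step)
        step *= dim
    return tuple(strides)


def offset(index, strides):
    """Byte offset of an element: sum of index * stride over all axes."""
    return sum(i * s for i, s in zip(index, strides))


shape = (500, 500, 3)                         # the image array used here
print(c_strides(shape))                       # (1500, 3, 1)
print(f_strides(shape))                       # (1, 500, 250000)
print(offset((2, 7, 1), c_strides(shape)))    # 2*1500 + 7*3 + 1 = 3022
```

In C order the three colour bytes of one pixel sit next to each other, which suits a loop that writes one pixel at a time.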

Turning off bounds checking:

@cython.boundscheck(False)
def create_fractal( double min_x, ...):

Not much change there, either.

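Oops, I forgot to tell Cython that colour is an int, so inside the loop it had to convert mandel()'s C int result to a generic Python object, and then do the modulo and the indexing at the Python level. The fix is one declaration alongside the other locals:

cdef int nb_colours, colour

which speeds up these two lines inside the loop:

colour = mandel(real, imag, nb_iterations)
image[x, y, :] = colours[ colour%nb_colours]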

That helped a bit: from 2.04 to 1.96 seconds.
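One thing to watch once colour is a C int: mandel() returns -1 for points that never escape, and colour%nb_colours with a negative colour means different things in C and Python. C truncates toward zero, so -1 % 53 is -1; Python gives 52. As far as I know, Cython's default (the cdivision directive is off) follows Python's rule for % on C integers, so the index stays in range and in-set points get the last palette entry. A plain-Python sketch of the two conventions (c_remainder and make_palette_list are hypothetical helpers; 53 is the length of the palette that make_palette() builds):

```python
def make_palette_list():
    # Same colour ramps as make_palette(), kept as plain tuples (no numpy)
    colours = []
    for i in range(0, 25):
        colours.append((i * 10, i * 8, 50 + i * 8))
    for i in range(25, 5, -1):
        colours.append((50 + i * 8, 150 + i * 2, i * 10))
    for i in range(10, 2, -1):
        colours.append((0, i * 15, 48))
    return colours


def c_remainder(a, n):
    """Remainder with C semantics: the quotient is truncated toward zero."""
    q = abs(a) // abs(n)
    if (a < 0) != (n < 0):
        q = -q
    return a - n * q


colours = make_palette_list()
nb_colours = len(colours)                # 25 + 20 + 8 = 53
colour = -1                              # mandel() result for a point in the set
print(colour % nb_colours)               # Python semantics: 52
print(colours[colour % nb_colours])      # (0, 45, 48), the last palette entry
print(c_remainder(colour, nb_colours))   # C semantics: -1
```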

Cython can index numpy arrays with simple pointer arithmetic, but only if you use plain integer indices; a slice such as image[x, y, :] falls back to slower, generic Python indexing. So I changed:

image[x, y, :] = colours[ colour%nb_colours]

to:

image[x, y, 0] = colours[ colour%nb_colours, 0 ]
image[x, y, 1] = colours[ colour%nb_colours, 1 ]
image[x, y, 2] = colours[ colour%nb_colours, 2 ]

BINGO! 0.03 seconds, about a 66 times speed up!
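Roughly speaking, once every index is an integer and the buffer types are known, each of those three assignments can become an offset calculation plus a one-byte copy. The model below uses a flat bytearray standing in for the image's data buffer and bytes for the palette; fill_pixel is a hypothetical name, and this illustrates the arithmetic, not the C code Cython actually generates:

```python
def fill_pixel(image_buf, height, x, y, colour_buf, colour_index):
    """Copy palette entry colour_index into pixel (x, y) of a C-ordered
    (width, height, 3) uint8 buffer, one channel at a time."""
    dst = (x * height + y) * 3    # offset of image[x, y, 0]
    src = colour_index * 3        # offset of colours[colour_index, 0]
    image_buf[dst + 0] = colour_buf[src + 0]
    image_buf[dst + 1] = colour_buf[src + 1]
    image_buf[dst + 2] = colour_buf[src + 2]


width, height = 4, 3
image_buf = bytearray(width * height * 3)         # zero-filled "image"
colour_buf = bytes([10, 20, 30, 200, 150, 100])   # two palette entries
fill_pixel(image_buf, height, 2, 1, colour_buf, 1)
print(list(image_buf[21:24]))   # pixel (2, 1) starts at (2*3 + 1)*3 = 21 -> [200, 150, 100]
```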

Someone on the Cython list suggested that using a single integer to store the color, rather than three separate 8-bit ones, would speed things up. I tried it and saw no noticeable difference. I think the calculation is swamping the color-setting at this point, and perhaps pushing around the extra data (I had to use RGBA to fill a 32-bit integer) slows things down a bit.
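The single-integer idea packs a pixel's channels into one 32-bit value, so each pixel is written with one store instead of three; the fourth (alpha) byte is why the data grows from 3 to 4 bytes per pixel. A sketch of the packing with hypothetical helper names; the byte order (R in the lowest byte, so the bytes land as R, G, B, A in memory on a little-endian machine) is an assumption for illustration, and the right order depends on what the display code expects:

```python
def pack_rgba(r, g, b, a=255):
    """Pack four 8-bit channels into one 32-bit int, R in the lowest byte."""
    for channel in (r, g, b, a):
        if not 0 <= channel <= 255:
            raise ValueError("channel out of range: %r" % (channel,))
    return r | (g << 8) | (b << 16) | (a << 24)


def unpack_rgba(value):
    """Split a packed 32-bit int back into (r, g, b, a)."""
    return (value & 0xFF, (value >> 8) & 0xFF,
            (value >> 16) & 0xFF, (value >> 24) & 0xFF)


packed = pack_rgba(10, 20, 30)
print(hex(packed))                            # 0xff1e140a
print(list(packed.to_bytes(4, "little")))     # [10, 20, 30, 255]
print(unpack_rgba(packed))                    # (10, 20, 30, 255)
```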

Here is the final three-channel version (I also tweaked the inputs a little to make it easier to change which part of the fractal to view):

# mandel3cy.pyx
# cython: profile=True

import cython
#import numpy as np
cimport numpy as np # for the special numpy stuff

@cython.profile(False)
cdef inline int mandel(double real, double imag, int max_iterations=20):
    '''Escape-time test: returns the iteration at which |z| first reaches 2
       (the point escapes the Mandelbrot set), or -1 if z stays bounded for
       max_iterations iterations (the point is treated as inside the set).'''
    cdef double z_real = 0., z_imag = 0.
    cdef int i

    for i in range(0, max_iterations):
        z_real, z_imag = ( z_real*z_real - z_imag*z_imag + real,
                           2*z_real*z_imag + imag )
        if (z_real*z_real + z_imag*z_imag) >= 4:
            return i
    return -1

@cython.boundscheck(False)
def create_fractal( double min_x,
                    double max_x,
                    double min_y,
                    int nb_iterations,
                    np.ndarray[np.uint8_t, ndim=2, mode="c"] colours not None,
                    np.ndarray[np.uint8_t, ndim=3, mode="c"] image not None):

    cdef int width, height
    cdef int x, y, start_y, end_y
    cdef int nb_colours, colour
    cdef double real, imag, pixel_size



    nb_colours = len(colours)

    width = image.shape[0]
    height = image.shape[1]

    pixel_size = (max_x - min_x) / width

    for x in range(width):
        real = min_x + x*pixel_size
        for y in range(height):
            imag = min_y + y*pixel_size
            colour = mandel(real, imag, nb_iterations)
            image[x, y, 0] = colours[ colour%nb_colours, 0 ]
            image[x, y, 1] = colours[ colour%nb_colours, 1 ]
            image[x, y, 2] = colours[ colour%nb_colours, 2 ]
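create_fractal() takes only min_x, max_x and min_y; the top edge follows from keeping the pixels square, which is what the "max_y is calculated from X" comment in the viewer's OnRun() refers to. A sketch of that bookkeeping (viewport is a hypothetical helper; it just restates the pixel_size and imag lines above):

```python
def viewport(min_x, max_x, min_y, width, height):
    """Return (pixel_size, max_y) the way create_fractal() implies:
    square pixels, so the y extent follows from the x extent."""
    pixel_size = (max_x - min_x) / width
    # imag = min_y + y*pixel_size for y in range(height), so the image
    # covers height*pixel_size along the imaginary axis
    max_y = min_y + height * pixel_size
    return pixel_size, max_y


# The values used by the viewer below, on its 500x500 array
pixel_size, max_y = viewport(-1.5, 0.0, -1.5, 500, 500)
print(pixel_size, max_y)
```

With those numbers pixel_size is 0.003 and the view runs from -1.5 to (about) 0 on both axes.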

and the setup.py:

#!/usr/bin/env python

"""
setup.py to build the mandelbrot code with Cython
"""
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
import numpy # to get includes


setup(
    cmdclass = {'build_ext': build_ext},
    ext_modules = [Extension("mandelcy", ["mandelcy.pyx"], )],
    include_dirs = [numpy.get_include(),],
)

And here is the wx code to view it:

#!/usr/bin/env python

"""
A simple app to show mandelbrot fractals.
"""

import time
import numpy as np
import wx

import mandelcy # this is the cython module that does the real work.

class BitmapWindow(wx.Window):
    """
    A simple window to display a bitmap from a numpy array
    """
    def __init__(self, parent, bytearray, *args, **kwargs):
        wx.Window.__init__(self, parent, *args, **kwargs)

        self.bytearray = bytearray
        self.Bind(wx.EVT_PAINT, self.OnPaint)

    def OnPaint(self, evt):
        dc = wx.PaintDC(self)
        w, h = self.bytearray.shape[:2]
        bmp = wx.Bitmap.FromBuffer(w, h, self.bytearray)
        # draw at the window's origin so the whole image fits in the window
        dc.DrawBitmap(bmp, 0, 0)

class DemoFrame(wx.Frame):
    def __init__(self, title = "Mandelbrot Demo"):
        wx.Frame.__init__(self, None , -1, title)#, size = (800,600), style=wx.DEFAULT_FRAME_STYLE|wx.NO_FULL_REPAINT_ON_RESIZE)

        # create the array and bitmap:
        self.bytearray = np.zeros((500, 500, 3), dtype=np.uint8) + 125

        self.BitmapWindow = BitmapWindow(self, self.bytearray,
                                         size=self.bytearray.shape[:2])

        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(self.BitmapWindow, 0, wx.ALIGN_CENTER|wx.ALL, 10)
        # set up the buttons
        sizer.Add(self.SetUpTheButtons(), 0, wx.EXPAND)
        self.SetSizerAndFit(sizer)

        self.colours = make_palette()

    def SetUpTheButtons(self):
        RunButton = wx.Button(self, wx.ID_ANY, "Run")
        RunButton.Bind(wx.EVT_BUTTON, self.OnRun)

        self.IterSlider =   wx.Slider( self, wx.ID_ANY,
                                  value=20,
                                  minValue=20,
                                  maxValue=10000,
                                  size=(250, -1),
                                  style = wx.SL_HORIZONTAL | wx.SL_AUTOTICKS | wx.SL_LABELS
                                  )

        QuitButton = wx.Button(self, wx.ID_ANY, "Quit")
        QuitButton.Bind(wx.EVT_BUTTON, self.OnQuit)

        self.Bind(wx.EVT_CLOSE, self.OnQuit)

        sizer = wx.BoxSizer(wx.HORIZONTAL)
        sizer.Add((1,1), 1)
        sizer.Add(RunButton, 0, wx.ALIGN_CENTER | wx.ALL, 4 )
        sizer.Add((1,1), 1)
        sizer.Add(self.IterSlider, 0, wx.ALIGN_CENTER | wx.ALL, 4 )
        sizer.Add((1,1), 1)
        sizer.Add(QuitButton, 0, wx.ALIGN_CENTER | wx.ALL, 4 )
        sizer.Add((1,1), 1)
        return sizer

    def OnRun(self,Event):
        width, height = self.bytearray.shape[:2]

        min_x = -1.5
        max_x =  0
        min_y = -1.5
        # max_y is calculated from X, to keep it symmetric

        nb_iterations = self.IterSlider.Value
        print("Calculating with %i iterations:" % nb_iterations)
        start = time.perf_counter()  # time.clock() was removed in Python 3.8
        mandelcy.create_fractal(min_x, max_x, min_y, nb_iterations, self.colours, self.bytearray)
        print("it took %f seconds to run" % (time.perf_counter() - start))
        self.Refresh()

    def OnStop(self, Event=None):
        self.Timer.Stop()

    def OnQuit(self,Event):
        self.Destroy()

def make_palette():
    '''sample coloring scheme for the fractal - feel free to experiment

        No need for this to be in Cython
    '''
    colours = []

    for i in range(0, 25):
        #colours.append('#%02x%02x%02x' % (i*10, i*8, 50 + i*8))
        colours.append( (i*10, i*8, 50 + i*8), )
    for i in range(25, 5, -1):
        #colours.append('#%02x%02x%02x' % (50 + i*8, 150+i*2,  i*10))
        colours.append( (50 + i*8, 150+i*2,  i*10), )
    for i in range(10, 2, -1):
        #colours.append('#00%02x30' % (i*15))
        colours.append( (0, i*15, 48), )
    return np.array(colours, dtype=np.uint8)


app = wx.App(False)  # wx.PySimpleApp is deprecated; wx.App(False) does the same job
frame = DemoFrame()
frame.Show()
app.MainLoop()

If you don't have wx installed, here's a script that uses PIL (Pillow) to save the image to a file:

import time
import numpy as np
from PIL import Image  # Pillow, the maintained fork of PIL

import mandelcy # this is the cython module that does the real work.

def make_palette():
    '''sample coloring scheme for the fractal - feel free to experiment

        No need for this to be in Cython
    '''
    colours = []

    for i in range(0, 25):
        colours.append( (i*10, i*8, 50 + i*8), )
    for i in range(25, 5, -1):
        colours.append( (50 + i*8, 150+i*2,  i*10), )
    for i in range(10, 2, -1):
        colours.append( (0, i*15, 48), )
    return np.array(colours, dtype=np.uint8)

min_x = -1.5
max_x =  0
min_y = -1.5
nb_iterations = 500

bytearray = np.zeros((500, 500, 3), dtype=np.uint8) + 125
colours = make_palette()
start = time.perf_counter()
mandelcy.create_fractal(min_x, max_x, min_y, nb_iterations, colours, bytearray)
print("it took %f seconds to run" % (time.perf_counter() - start))

# Image.fromarray() wraps a (rows, columns, 3) uint8 array as an RGB image
Image.fromarray(bytearray).save("mandelbrot.png")