# Vanilla Python

To illustrate Cython, consider the following Python function, which computes the first `kmax` primes (capped at 1000) and returns them as a list.

In [1]:
%cat primes_vanilla.py

def primes(kmax):
    p = [0]*1000
    result = []
    if kmax > 1000:
        kmax = 1000
    k = 0
    n = 2
    while k < kmax:
        i = 0
        while i < k and n % p[i] != 0:
            i = i + 1
        if i == k:
            p[k] = n
            k = k + 1
            result.append(n)
        n = n + 1
    return result


You can import the module, check the result, and then use the `%timeit` magic to establish a baseline timing.

In [4]:
import primes_vanilla

In [5]:
primes_vanilla.primes(20)

[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71]

In [6]:
%timeit primes_vanilla.primes(1000)

17.4 ms ± 1.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
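Outside of a notebook, a comparable baseline can be obtained with the standard `timeit` module. The following is a minimal sketch, assuming the function body shown above is pasted into the script rather than imported; the `repeat=7` and `number=10` values simply mirror the `%timeit` output above, and the script reports the best run rather than the mean and standard deviation.

```python
import timeit


def primes(kmax):
    """Return the first kmax primes (at most 1000), as in primes_vanilla.py."""
    p = [0] * 1000
    result = []
    if kmax > 1000:
        kmax = 1000
    k = 0
    n = 2
    while k < kmax:
        i = 0
        # Trial division by the primes found so far.
        while i < k and n % p[i] != 0:
            i = i + 1
        if i == k:
            p[k] = n
            k = k + 1
            result.append(n)
        n = n + 1
    return result


# 7 runs of 10 calls each; each entry is the total time for 10 calls.
timings = timeit.repeat(lambda: primes(1000), repeat=7, number=10)
best_ms = min(timings) / 10 * 1e3
print(f"best of 7 runs: {best_ms:.2f} ms per call")
```

Absolute numbers depend on the machine and Python version, so only ratios between implementations measured on the same system are meaningful.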


# Cython `.pyx` files

A first approach to speeding up this computation is to rewrite the function in Cython.  You can review the source code below.

In [4]:
%cat primes_cython.pyx

def primes(int kmax):
    cdef int n, k, i
    cdef int p[1000]
    result = []
    if kmax > 1000:
        kmax = 1000
    k = 0
    n = 2
    while k < kmax:
        i = 0
        while i < k and n % p[i] != 0:
            i = i + 1
        if i == k:
            p[k] = n
            k = k + 1
            result.append(n)
        n = n + 1
    return result


As you can see, the only changes to the original function are
* the type declaration for the function's argument `kmax`,
* the type declarations (`cdef int`) for the variables `n`, `k`, `i`, and
* replacing the Python list `p` by a C array of `int`.
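Declaring variables as C `int` gives them a fixed width, whereas Python integers have arbitrary precision. As a rough check that this is harmless for this example, the sketch below uses the standard `ctypes` module (not part of the original code) to query the width of a C `int` on the current platform and compares its maximum value with 7919, the 1000th prime.

```python
import ctypes

# Width of a C int on this platform, in bits (commonly 32).
int_bits = 8 * ctypes.sizeof(ctypes.c_int)
int_max = 2 ** (int_bits - 1) - 1

# 7919 is the 1000th prime, the largest value stored for kmax = 1000.
largest_needed = 7919

print(f"C int: {int_bits} bits, maximum value {int_max}")
print(f"1000th prime fits in a C int: {largest_needed <= int_max}")
```

For computations that produce much larger values, a wider C type would be the safer choice.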

This code needs to be compiled before it can be run.  Fortunately, this can easily be done from a Jupyter notebook using the `pyximport` module.  The `install` function hooks into Python's import mechanism so that `.pyx` files are compiled when they are imported; with `pyimport=True`, the same is attempted for `.py` files, which the pure Python examples further on rely on.  The `language_level` is set to `'3str'` to select Python 3 semantics.

In [7]:
import pyximport
pyximport.install(pyximport=True, pyimport=True, language_level='3str');

Now you can import the Cython module that implements the `primes` function and time it for comparison with the vanilla Python implementation.

In [8]:
import primes_cython

In [9]:
primes_cython.primes(20)

[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71]

In [10]:
%timeit primes_cython.primes(1000)

920 µs ± 16.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


The speedup is considerable, roughly a factor of 19 (17.4 ms versus 920 µs), for very little effort on your part.

# Pure Python & Cython

However, it is also possible to get a similar result using pure Python with type annotations.

In [8]:
%cat primes_pure_python.py

import cython

def primes(nb_primes: cython.int):
    i: cython.int
    p: cython.int[1000]

    if nb_primes > 1000:
        nb_primes = 1000

    if not cython.compiled:  # Only if regular Python is running
        p = [0] * 1000       # Make p work almost like a C array

    len_p: cython.int = 0  # The current number of elements in p.
    n: cython.int = 2
    while len_p < nb_primes:
        # Is n prime?
        for i in p[:len_p]:
            if n % i == 0:
                break

        # If no break occurred in the loop, we have a prime.
        else:
            p[len_p] = n
            len_p += 1
        n += 1

    # Let's copy the result into a Python list:
    result_as_list = [prime for prime in p[:len_p]]
    return result_as_list

Note that
* the `cython` module has to be imported,
* the Cython types such as `cython.int` have to be specified, rather than `int`,
* you can check whether the Python function has been compiled using `cython.compiled`.
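The primality test in this version relies on Python's `for ... else` construct: the `else` branch runs only when the loop finishes without hitting `break`. The following plain Python sketch isolates that behaviour; the helper `first_divisor` is a hypothetical name introduced for illustration and does not appear in the original module.

```python
def first_divisor(n, candidates):
    """Return the first candidate that divides n, or None if there is none."""
    for q in candidates:
        if n % q == 0:
            break        # a divisor was found, so the else clause is skipped
    else:
        return None      # the loop ran to completion without a break
    return q


known_primes = [2, 3, 5, 7]
print(first_divisor(9, known_primes))   # 9 is divisible by 3
print(first_divisor(11, known_primes))  # no divisor among the known primes
```

Note that the `else` branch also runs when the sequence is empty, which is why the very first candidate, 2, is accepted as a prime while `len_p` is still 0.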

In [11]:
import primes_pure_python

In [12]:
primes_pure_python.primes(20)

[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71]

In [13]:
%timeit primes_pure_python.primes(1000)

907 µs ± 9.59 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


The performance is almost identical to that of the `.pyx` file, and the code is pure Python.

# Dynamic memory management

You can use `malloc` and `free` in Cython code, both in `.pyx` files and using the pure Python syntax.

## Cython `.pyx` files

In [1]:
%cat primes_malloc.pyx

from libc.stdlib cimport malloc, free


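def primes(int kmax):
    cdef int n, k, i
    cdef int *p
    result = []
    if kmax > 1000:
        kmax = 1000
    if kmax <= 0:
        return result
    # Allocate only after kmax has been capped and validated.
    p = <int *> malloc(kmax*sizeof(int))
    if p == NULL:
        raise MemoryError()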
    k = 0
    n = 2
    while k < kmax:
        i = 0
        while i < k and n % p[i] != 0:
            i = i + 1
        if i == k:
            p[k] = n
            k = k + 1
            result.append(n)
        n = n + 1
    free(p)
    return result


In [2]:
import primes_malloc

In [3]:
primes_malloc.primes(20)

[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71]

In [14]:
%timeit primes_malloc.primes(1000)

965 µs ± 71 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Pure Python syntax

In [16]:
%cat primes_pure_malloc.py

import cython
from cython.cimports.libc.stdlib import malloc, free

def primes(nb_primes: cython.int):
    i: cython.int
    p: cython.p_int = cython.cast(cython.p_int, malloc(nb_primes*cython.sizeof(cython.int)))


    len_p: cython.int = 0  # The current number of elements in p.
    n: cython.int = 2
    while len_p < nb_primes:
        # Is n prime?
        for i in p[:len_p]:
            if n % i == 0:
                break

        # If no break occurred in the loop, we have a prime.
        else:
            p[len_p] = n
            len_p += 1
        n += 1

    # Let's copy the result into a Python list:
    result_as_list = [prime for prime in p[:len_p]]
    free(p)
    return result_as_list


In [17]:
import primes_pure_malloc

In function ‘__pyx_pf_18primes_pure_malloc_primes’,
    inlined from ‘__pyx_pw_18primes_pure_malloc_1primes’ at /home/gjb/.pyxbld/temp.linux-x86_64-cpython-311/home/gjb/Projects/Python-for-HPC/source-code/cython/Primes/primes_pure_malloc.c:2139:13:
 2173 |   __pyx_v_p = ((int *)malloc((__pyx_v_nb_primes * (sizeof(int)))));
      |                       ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from /home/gjb/mambaforge/envs/python_for_hpc/include/python3.11/Python.h:23,
                 from /home/gjb/.pyxbld/temp.linux-x86_64-cpython-311/home/gjb/Projects/Python-for-HPC/source-code/cython/Primes/primes_pure_malloc.c:28:
/home/gjb/.pyxbld/temp.linux-x86_64-cpython-311/home/gjb/Projects/Python-for-HPC/source-code/cython/Primes/primes_pure_malloc.c: In function ‘__pyx_pw_18primes_pure_malloc_1primes’:
/usr/include/stdlib.h:540:14: note: in a call to allocation function ‘malloc’ declared here
  540 | extern void *malloc (size_t __size) __THROW __attribute_malloc__
     

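Although the compiler warning seems a bit unsettling, it can be ignored in this case.  It points at the `malloc` call, whose size argument is computed from the signed integer `nb_primes`; presumably the compiler cannot rule out a negative value, which is not a concern for the small positive arguments used here.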

In [18]:
primes_pure_malloc.primes(20)

[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71]

In [19]:
%timeit primes_pure_malloc.primes(1000)

911 µs ± 17.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
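To put the measurements side by side, the short sketch below computes the speedup of each compiled variant relative to the vanilla Python baseline. The numbers are the mean timings reported by `%timeit` in the cells above; the dictionary and variable names are chosen here for illustration only.

```python
# Mean timings per call, in seconds, copied from the %timeit outputs above.
timings_s = {
    "primes_vanilla": 17.4e-3,
    "primes_cython": 920e-6,
    "primes_pure_python": 907e-6,
    "primes_malloc": 965e-6,
    "primes_pure_malloc": 911e-6,
}

baseline = timings_s["primes_vanilla"]
speedups = {name: baseline / t for name, t in timings_s.items()}

for name, factor in speedups.items():
    print(f"{name:20s} {factor:5.1f}x")
```

All four compiled variants come out between roughly 18 and 19 times faster than the baseline, and the differences among them are comparable to the reported standard deviations.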
