In [1]:
########################################
## Evaluate this cell before starting ##
########################################

import timeit
from collections import OrderedDict
from itertools import islice
from math import log
from bokeh.models.formatters import NumeralTickFormatter
from bokeh.models.ranges import DataRange1d
from bokeh.models.sources import ColumnDataSource
from bokeh.palettes import Category10
from bokeh.plotting import figure
from bokeh.io import output_notebook, show, push_notebook

# Configure Bokeh so that figures are displayed inline in this notebook.
output_notebook()

# Colour cycle used by timing_plot: series number k is drawn with
# timing_palette[k % len(timing_palette)].
timing_palette = Category10[10]
# Timing data recorded by earlier timing_plot runs, keyed by generator name
# (insertion-ordered), so that later plots can overlay the previous results.
timing_lines = OrderedDict()

def iterations():
    """Yield the endless 1-2-5 series: 1, 2, 5, 10, 20, 50, 100, ...

    These are the prime counts at which timing_plot takes a measurement.
    """
    decade = 1
    while True:
        yield decade
        yield 2 * decade
        yield 5 * decade
        decade *= 10

def approx_nth(n):
    """Return an integer estimate of the size of the n-th prime.

    timing_plot passes this value as the limit argument of generators that
    need one. For n >= 6 the estimate is n * (ln n + ln ln n); for smaller n
    a simple linear formula is used instead.
    """
    if n >= 6:
        ln_n = log(n)
        return int(n * (ln_n + log(ln_n)))
    return int(2.2 * n + 1)

def time_gen(genfn, num):
    """Return the fastest of three timings (in seconds, as measured by timeit)
    for fetching the num-th item produced by genfn()."""
    def fetch_nth():
        # Skip the first num - 1 items and materialise only the num-th
        window = islice(genfn(), num - 1, num)
        return list(window)[0]
    timings = timeit.repeat(fetch_nth, number=1, repeat=3, globals=globals())
    return min(timings)

def timing_plot(genfn, need_nth=False):
    """Time a prime generator on growing inputs and plot the results live.

    For each count in the 1-2-5 series from ``iterations()``, measures (via
    ``time_gen``) how long ``genfn`` takes to produce that many primes and
    streams the point into a Bokeh figure shown in the notebook. Stops after
    the first measurement that takes one second or more. Results of earlier
    runs stored in ``timing_lines`` are overlaid as dashed lines, and this
    run's data is stored there under ``genfn.__name__`` afterwards.

    Parameters
    ----------
    genfn : callable
        Returns an iterable of primes. Called with no arguments unless
        ``need_nth`` is true.
    need_nth : bool
        If true, ``genfn`` is called with a single argument,
        ``approx_nth(i)``, where ``i`` is the number of primes requested.
    """
    def plot(fig, name, vals, num, dash='solid'):
        # One colour per series index, wrapping around the palette
        col = timing_palette[num % len(timing_palette)]
        fig.line('x', 'y', legend=name, source=vals, line_dash=dash, color=col)
        fig.scatter('x', 'y', legend=name, source=vals, marker='o', color=col)
    name = genfn.__name__
    exist = None
    log_log = True  # set to False for linear axes
    if log_log:
        extra_args = dict(y_range=[1e-6, 1], x_range=DataRange1d(start=1),
                          x_axis_type='log', y_axis_type='log')
    else:
        extra_args = dict(y_range=[0,1], x_range=DataRange1d(start=0))
    fig = figure(plot_width=800, plot_height=400, toolbar_location='above', title="Timing", **extra_args)
    # Overlay every previously recorded series as a dashed line
    num = 0
    for k, v in timing_lines.items():
        plot(fig, k, v, num, 'dashed')
        if k == name:
            exist = num
        num += 1
    source = ColumnDataSource(data=dict(x=[], y=[]))
    # Reuse the colour of an earlier run of the same generator. The index can
    # legitimately be 0, so compare against None instead of relying on
    # truthiness (`exist or num` would discard index 0).
    plot(fig, name, source, num if exist is None else exist)
    fig.xaxis.axis_label = "Primes"
    fig.xaxis.formatter = NumeralTickFormatter(format='0[.]0 a')
    fig.yaxis.axis_label = "Seconds"
    fig.legend.location = 'top_left'
    fig.legend.click_policy='hide'
    fig.legend.background_fill_alpha = 0.5
    handle = show(fig, notebook_handle=True)
    for i in iterations():
        if need_nth:
            # Bounded generators need a limit; the closure is called
            # immediately below, so it sees the current value of i.
            def gen():
                return genfn(approx_nth(i))
            gen.__name__ = genfn.__name__
        else:
            gen = genfn
        t = time_gen(gen, i)
        source.stream(dict(x=[i], y=[t]))
        push_notebook(handle=handle)
        if t >= 1: break
    # Remember this run so later plots can show it for comparison
    timing_lines[name] = source.data

from IPython.display import Javascript

# Presentation setup, run in the browser: clears the outputs of all other code
# cells and makes the sieve image start its animation when clicked.
Javascript('''
require(['base/js/namespace', 'base/js/events'],
function (Jupyter, events) {
    // Show `src` in image element `el`, then restore the previous source
    // after `t` milliseconds.
    function swap_src(el, src, t) {
        console.log("swap", el, src, t);
        var old = el.src;
        el.src = src;
        setTimeout(function() {el.src = old;}, t);
    }

    // save a reference to the cell we're currently executing inside of,
    // to avoid clearing it later (which would remove this js)
    var this_cell = $(element).closest('.cell').data('cell');
    function init_presentation() {
        // Clear (other) cell outputs
        Jupyter.notebook.get_cells().forEach(function (cell) {
            if (cell.cell_type === 'code' && cell !== this_cell) {
                cell.clear_output();
            }
        });
        Jupyter.notebook.set_dirty(true);
        // Make sieve clickable to start gif.
        // Look the image up explicitly and skip the setup if it has not been
        // rendered, rather than relying on the implicit `sieve` global.
        var sieve_img = document.getElementById("sieve");
        if (sieve_img) {
            sieve_img.src = 'sieve1.png';
            sieve_img.onclick = function() {
                // The helper is called swap_src; the previous call to the
                // undefined `swap` raised a ReferenceError on click.
                swap_src(sieve_img, 'sieve.gif', 37000);
            };
        }
    }

    if (Jupyter.notebook._fully_loaded) {
        // notebook has already been fully loaded, so init now
        init_presentation();
    }
    // Also clear on any future load
    // (e.g. when notebook finishes loading, or when a checkpoint is reloaded)
    events.on('notebook_loaded.Notebook', init_presentation);
});
''')

<IPython.core.display.Javascript object>

# Optimised Primes

Emlyn Corrin

<img data-gifffer="sieve.gif" />

![](prime.png)
<!--- Image (public domain) from:
https://www.flickr.com/photos/114305749@N08/24438440681
-->

## Why?

- Online programming contests (Project Euler etc.)

- Mathematical or programming exercise

- Because it's fun!

## What is a prime?

A prime number (or a prime) is a natural number greater than 1 that has no positive divisors other than 1 and itself.
<div style="text-align: right">&mdash; Wikipedia</div>

In [2]:
def is_prime(n):
    """Return True if n is prime, following the definition directly."""
    # A prime must be greater than 1 ...
    # ... and have no positive divisor strictly between 1 and itself
    return n > 1 and all(n % d != 0 for d in range(2, n))

## Let's generate a few

In [3]:
list(filter(is_prime, range(20)))

[2, 3, 5, 7, 11, 13, 17, 19]

In [4]:
print(*filter(is_prime, range(2000)), sep=', ')

2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997, 1009, 1013, 1019, 1021, 1031, 1033, 1039, 1049, 1051, 1061, 1063, 1069, 1087, 1091, 1093, 1097, 1103, 1109, 1117, 1123, 1129, 1151, 1153, 1163, 1171, 1181, 1187, 1193, 1201, 1213, 1217, 122

## What about generating them on demand?

In [5]:
from itertools import count, islice

def primes1():
    """Lazily yield the primes in increasing order, without end."""
    # Test every natural number in turn and pass on those that are prime
    yield from filter(is_prime, count())
    
list(islice(primes1(), 20))

[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71]

## But how fast is it?

In [6]:
timing_plot(primes1)

## Can we make it faster?

### What about skipping all even numbers (apart from 2)?

In [7]:
from itertools import count

def primes2():
    """Yield the primes in increasing order, without end.

    After 2, only odd candidates are tested, and only against odd divisors.
    """
    def has_no_odd_divisor(n):
        # n is odd here, so trial divisors 3, 5, 7, ... below n suffice
        return all(n % d for d in range(3, n, 2))
    yield 2  # the only even prime
    yield from filter(has_no_odd_divisor, count(3, 2))

### How much faster is it?

In [8]:
timing_plot(primes2)

## Can we reduce the number of checks further?

### Yes!
Factors always come in pairs: $n = f*g$ (except square numbers where you can have $f = g = \sqrt n$).  
If $f \leq \sqrt n$ then $g \geq \sqrt n$, and vice versa.  
So if $n$ has any prime factors, at least one of them must always be $\leq \sqrt n$.  
So we only have to check up to $\sqrt n$, not to $n$.

In [9]:
from itertools import count
from math import sqrt

def primes3():
    """Yield the primes in increasing order, without end.

    Odd candidates are trial-divided by odd numbers up to their square root.
    """
    def no_divisor_up_to_root(n):
        # A composite n always has a factor no larger than sqrt(n)
        limit = int(sqrt(n)) + 1
        return not any(n % d == 0 for d in range(3, limit, 2))
    yield 2
    for odd in count(3, 2):
        if no_divisor_up_to_root(odd):
            yield odd

## How much faster is this?

In [10]:
timing_plot(primes3)

## Is this the best we can do?

We are still checking more numbers than necessary:  
e.g. once we've tested for divisibility by 3 and 5,  
we shouldn't need to test 9, 15, 25, 45, 75, etc.

i.e. we only need to check for divisibility by primes.

## The sieve of Eratosthenes

1. Start with a grid of numbers, from 2 to max_prime.
2. Find the first (next) unmarked number, $p$, and return it as a prime.
3. Mark all multiples of $p$ (it is enough to start from $p^2$, since smaller multiples have already been marked by smaller primes).
4. Go back to step 2.
<img src="sieve.png" id="sieve" />

In [11]:
def primes4(max_prime):
    """Sieve of Eratosthenes: yield each prime below max_prime, in order."""
    unmarked = [True] * max_prime
    for n in range(2, max_prime):
        if not unmarked[n]:
            continue  # already crossed off as a multiple of a smaller prime
        yield n
        # Cross off n and every multiple of it below the limit
        for multiple in range(n, max_prime, n):
            unmarked[multiple] = False

In [12]:
def primes5(max_prime):
    """Odd-only sieve of Eratosthenes: yield each prime below max_prime.

    Only odd numbers are stored: sieve[k] stands for the odd number 2*k + 1,
    which halves the memory used. Multiples of a prime i are marked starting
    at i*i in steps of 2*i, which visits only its odd multiples.

    Yields nothing when max_prime <= 2, so that the result always contains
    only primes strictly below max_prime (as primes4 does).
    """
    # 2 is handled separately, and only if it is actually below the limit
    if max_prime > 2:
        yield 2
    sieve = [True] * (max_prime // 2)
    for i in range(3, max_prime, 2):
        if sieve[i//2]:
            yield i
            for j in range(i*i, max_prime, i*2):
                sieve[j//2] = False

In [13]:
timing_plot(primes4, need_nth=True)

In [14]:
timing_plot(primes5, need_nth=True)

## Problems?

### Memory use
- Use packed data structure (e.g. struct module), encode 8 cells/byte
- Also skip multiples of 3 (only check numbers of form $6n \pm 1$)

### Need to allocate storage upfront
Often don't know in advance how much to allocate
(e.g. first 100k primes)

## What about storing a list of primes so far, and only test dividing by those?

In [15]:
from itertools import count

def primes6():
    """Yield the primes in increasing order, without end.

    Each odd candidate is trial-divided only by the odd primes found so far.
    """
    yield 2
    odd_primes = []
    for n in count(3, 2):
        if any(n % p == 0 for p in odd_primes):
            continue  # divisible by an earlier prime, so composite
        yield n
        odd_primes.append(n)

In [16]:
timing_plot(primes6)

## Better?
Now we don’t have to decide the upper limit in advance, but it is slower.

What about switching things around… for each prime, store the next multiple higher than the candidate; then we just have to check whether the candidate is in the list, instead of doing multiple test divisions per candidate.
For each multiple in the list, we store the original prime, so that when we reach it, we can add it to generate the next multiple. But it could be a multiple of more than one prime, so we have to store a list of source primes:

In [17]:
from itertools import count

def primes7():
    """Yield the primes in increasing order, without end (incremental sieve).

    A dict maps each upcoming composite to the list of primes it is the next
    scheduled multiple of. A candidate missing from the dict is prime.
    """
    upcoming = {}
    for n in count(2):
        factors = upcoming.pop(n, None)
        if factors is None:
            # Nobody scheduled n, so it is prime; schedule its first multiple
            yield n
            upcoming[2 * n] = [n]
            continue
        # n is composite: move each of its primes on to its next multiple
        for p in factors:
            upcoming.setdefault(n + p, []).append(p)

In [18]:
timing_plot(primes7)

We can make a few optimisations:

- Use a `defaultdict` so we don’t have to check whether a number is already present.
- Skip even numbers, and therefore even multiples of primes.
- When we find a prime $p$, the first multiple we have to add to the state is $p^2$, because every smaller multiple $p \cdot q$ (where $q < p$) has a factor less than $p$.


In [19]:
from collections import defaultdict
from itertools import count

def primes8():
    """Yield the primes in increasing order, without end (incremental sieve).

    Only odd candidates are considered. A defaultdict maps each upcoming odd
    composite to the list of step sizes (2 * prime) that lead to it; a prime
    p is first scheduled at p * p.
    """
    yield 2
    pending = defaultdict(list)
    for n in count(3, 2):
        # pop() with a default does not create an entry for a missing key
        steps = pending.pop(n, None)
        if steps is None:
            yield n
            pending[n * n] = [2 * n]
        else:
            for step in steps:
                pending[n + step].append(step)

In [20]:
timing_plot(primes8)

In [21]:
import heapq
from itertools import count

def primes9():
    """Yield the primes in increasing order, without end (heap-based sieve).

    The heap holds (next_multiple, step) pairs, one per odd prime found, so
    the smallest upcoming odd composite is always at the top of the heap.
    """
    yield 2
    yield 3
    heap = [(9, 2 * 3)]
    for n in count(5, 2):
        if n != heap[0][0]:
            # Not the next scheduled composite, so n is prime
            yield n
            heapq.heappush(heap, (n * n, 2 * n))
            continue
        # n is composite: advance every entry that is currently sitting on n
        while heap[0][0] == n:
            next_multiple, step = heap[0]
            heapq.heapreplace(heap, (next_multiple + step, step))

In [22]:
timing_plot(primes9)

In [23]:
import pyprimesieve

timing_plot(pyprimesieve.primes, need_nth=True)

In [24]:
from IPython.display import Javascript

# Point the image with id "sieve" back at the static 'sieve.png'
# (the setup cell at the top of the notebook switches it to 'sieve1.png').
Javascript('''
sieve.src = 'sieve.png';
''')

<IPython.core.display.Javascript object>