In [1]:
########################################
## Evaluate this cell before starting ##
########################################

from bokeh.io import output_notebook, push_notebook, show
from bokeh.models.formatters import BasicTickFormatter, NumeralTickFormatter
from bokeh.models.ranges import DataRange1d
from bokeh.models.sources import ColumnDataSource
from bokeh.models.widgets import Panel, Tabs
from bokeh.palettes import Category10
from bokeh.plotting import figure
from collections import OrderedDict
from itertools import count, islice
from math import log
from timeit import timeit

# Number of timing passes per generator (later passes keep the faster time).
REPEATS = 3

# Colours for successive timing lines, and the plot data of every generator
# timed so far (name -> data), kept in insertion order so colours stay stable.
timing_palette = Category10[10]
timing_lines = OrderedDict()

# Render Bokeh output inline in the notebook.
output_notebook()

def iterations():
    """Yield the endless 1-2-5 series 1, 2, 5, 10, 20, 50, 100, ... (ints)."""
    scale = 1
    while True:
        for mantissa in (1, 2, 5):
            yield mantissa * scale
        scale *= 10

def approx_nth(n):
    """Return an integer upper estimate for the n-th prime (1-based).

    For n >= 6 this is floor(n * (ln n + ln ln n)), the classical upper bound
    for the n-th prime. For smaller n (where ln ln n is unusable or too small)
    the linear estimate floor(2.2 * n + 1) is used instead.
    """
    if n >= 6:
        ln_n = log(n)
        return int(n * (ln_n + log(ln_n)))
    return int(2.2 * n + 1)

def null_func():
    """Timing baseline: run the generator + islice + list + [0] pattern used
    by the timed calls, on a trivial generator that yields a single 0.

    Returns 0.
    """
    def one_value():
        yield 0
    return list(islice(one_value(), 0, 1))[0]

# Mean cost of one null_func() call over 1000 runs. It is subtracted from every
# measurement so the plots approximately reflect only the generator's own work.
NULL_TIME = timeit(null_func, globals=globals(), number=1000) / 1000

def plot_line_separate(genfn, source, handle):
    """Benchmark a bounded prime generator, streaming the timings into ``source``.

    For each target count from ``iterations()`` (1, 2, 5, 10, 20, ...), a fresh
    generator is created with limit ``approx_nth(target)``. It is timed while
    producing its ``target``-th prime. The first pass appends points until one
    takes at least 1.2 s. Each of the remaining ``REPEATS - 1`` passes re-times
    the same targets and keeps whichever time is lower.

    genfn: callable taking an upper limit and returning an iterable of primes.
    source: Bokeh ColumnDataSource receiving x (prime count) and y (seconds).
    handle: notebook handle passed to push_notebook after every measurement.
    """
    # Warm-up: produce the 1000th prime once, presumably so that one-off
    # start-up costs don't land in the first measurement.
    list(islice(genfn(approx_nth(1000)), 999, 1000))[0]
    for rep in range(REPEATS):
        first_pass = rep == 0
        for idx, target in enumerate(iterations()):
            def nth_prime():
                return list(islice(genfn(approx_nth(target)), target - 1, target))[0]
            elapsed = timeit(nth_prime, number=1, globals=globals()) - NULL_TIME
            if first_pass:
                # First pass: add a new point.
                source.stream(dict(x=[target], y=[elapsed]))
            elif elapsed < source.data['y'][idx]:
                # Later passes: keep the best time seen for this point.
                source.patch(dict(x=[(idx, target)], y=[(idx, elapsed)]))
            push_notebook(handle=handle)
            if first_pass:
                if elapsed >= 1.2:
                    break
            elif idx >= len(source.data['x']) - 1:
                break

def plot_line_combined(genfn, source, handle):
    """Time an unbounded prime generator and stream cumulative timings to a plot.

    A single generator instance is advanced in stages, so that after each
    stage it has produced 1, 2, 5, 10, 20, ... primes in total (the values
    from ``iterations()``). Each stage is timed separately, and the running
    total time is plotted against the number of primes produced so far.

    The first pass streams new points until the cumulative time exceeds 1.2 s.
    Each of the remaining ``REPEATS - 1`` passes re-times the same stages.
    Whenever the new cumulative time at a point is lower than the stored one,
    that point and every later point are lowered by the difference. The net
    effect is that each stage's increment is the minimum seen over all passes.

    genfn: zero-argument callable returning an iterator of primes.
    source: Bokeh ColumnDataSource fed with 'x' (prime count) and 'y'
        (cumulative seconds) via stream/patch.
    handle: notebook handle passed to push_notebook to refresh the plot.
    """
    # Generate first 1000 primes to warm up the code first
    list(islice(genfn(), 999, 1000))[0]
    for r in range(REPEATS):
        t = 0.0  # unused initial value: t is reassigned before first read
        last_i = 0  # number of primes consumed from `gen` so far
        last_t = 0.0  # cumulative time up to the previous point
        gen = genfn()  # one generator instance shared by all stages of this pass
        for n, i in enumerate(iterations()):
            def timed():
                # consume the next (i - last_i) primes, returning the last one
                num = i - last_i
                return list(islice(gen, num - 1, num))[0]
            t = timeit(timed, number=1, globals=globals()) - NULL_TIME
            last_t += t
            if r == 0:
                source.stream(dict(x=[i], y=[last_t]))
            else:
                if last_t < source.data['y'][n]:
                    # This stage was faster than before: shift this point and
                    # all later (cumulative) points down by the improvement.
                    diff = source.data['y'][n] - last_t
                    source.patch(dict(y=[(y, source.data['y'][y] - diff)
                                         for y in range(n, len(source.data['y']))]))
                else:
                    # Not faster: continue accumulating from the stored value.
                    last_t = source.data['y'][n]
            push_notebook(handle=handle)
            last_i = i
            if r == 0:
                # first pass decides how many points there are
                if last_t > 1.2: break
            else:
                # later passes stop at the last point recorded in the first pass
                if n >= len(source.data['x']) - 1: break

def timing_plot(genfn):
    """Time a prime generator and display the results as a live Bokeh plot.

    Shows two tabs ("Linear" and "Log" axes) of seconds against number of
    primes. Every generator timed earlier in the session (``timing_lines``)
    is redrawn as a dashed line. The new timings are streamed into a solid
    line while they are being measured.

    genfn: if it takes no positional arguments it is treated as an unbounded
        generator and timed with ``plot_line_combined``. Otherwise it is called
        with an upper limit and timed with ``plot_line_separate``.

    Afterwards the measured data is saved in ``timing_lines`` under
    ``genfn.__name__``, replacing any earlier entry with the same name. A
    re-timed function is drawn in the same colour as its earlier line.
    """
    def plot(fig, name, vals, num, dash='solid'):
        # One series = a line plus point markers, coloured by series index.
        col = timing_palette[num % len(timing_palette)]
        fig.line('x', 'y', legend=name, source=vals, line_dash=dash, color=col)
        fig.scatter('x', 'y', legend=name, source=vals, marker='o', color=col)

    name = genfn.__name__
    exist = None  # index of an earlier run of the same function, if any
    args = dict(plot_width=800, plot_height=400, toolbar_location='above', title="Timing")
    linfig = figure(**args, y_range=[0,1], x_range=DataRange1d(start=0))
    logfig = figure(**args, y_range=[1e-6, 1], x_range=DataRange1d(start=1),
                    x_axis_type='log', y_axis_type='log')
    num = 0
    # add previous lines (dashed) for comparison
    for k, v in timing_lines.items():
        plot(linfig, k, v, num, 'dashed')
        plot(logfig, k, v, num, 'dashed')
        if k == name:
            exist = num
        num += 1
    # Re-timing a function reuses the colour of its earlier (dashed) line;
    # otherwise take the next colour. Compare with None explicitly: index 0 is
    # a valid match, and `exist or num` would wrongly discard it.
    colour_index = exist if exist is not None else num
    source = ColumnDataSource(data=dict(x=[], y=[]))
    for fig in (linfig, logfig):
        plot(fig, name, source, colour_index)
        fig.xaxis.axis_label = "Primes"
        fig.xaxis.formatter = NumeralTickFormatter(format='0[.]0 a')
        fig.xgrid.minor_grid_line_color = 'lightgrey'
        fig.xgrid.minor_grid_line_alpha = 0.2
        fig.yaxis.axis_label = "Seconds"
        fig.legend.location = 'top_left'
        fig.legend.click_policy='hide'
        fig.legend.background_fill_alpha = 0.5
    linfig.yaxis.formatter = BasicTickFormatter()
    logfig.yaxis.formatter = BasicTickFormatter(use_scientific=True, precision=0)
    lintab = Panel(child=linfig, title="Linear")
    logtab = Panel(child=logfig, title="Log")
    tabs = Tabs(tabs=[lintab, logtab])
    handle = show(tabs, notebook_handle=True)
    # Zero-argument generators are unbounded; others need an upper limit.
    if genfn.__code__.co_argcount == 0:
        plot_line_combined(genfn, source, handle)
    else:
        plot_line_separate(genfn, source, handle)
    # save line data to show on next plot
    timing_lines[name] = source.data

from IPython.display import Javascript

# Cell clearing code based on:
# https://stackoverflow.com/questions/45638720/jupyter-programmatically-clear-output-from-all-cells-when-kernel-is-ready

# Presentation set-up, run in the browser through the classic Jupyter Notebook
# JavaScript API (RequireJS modules 'base/js/namespace' / 'base/js/events' and
# the `element` variable given to JavaScript outputs; presumably unsupported
# in JupyterLab). It:
#   - clears the output of every other code cell, now and on future loads;
#   - shows 'resources/sieve1.png' in the id="sieve" image, and plays
#     'resources/sieve.gif' for 37 s when that image is clicked.
Javascript('''
require(['base/js/namespace', 'base/js/events'],
function (Jupyter, events) {
    // Show `src` in image `el` for `t` milliseconds, then restore the old source.
    function swap_src(el, src, t) {
        var old = el.src;
        el.src = src;
        setTimeout(function() {el.src = old;}, t);
    }

    // save a reference to the cell we're currently executing inside of,
    // to avoid clearing it later (which would remove this js)
    var this_cell = $(element).closest('.cell').data('cell');
    function init_presentation() {
        // Clear (other) cell outputs
        Jupyter.notebook.get_cells().forEach(function (cell) {
            if (cell.cell_type === 'code' && cell !== this_cell) {
                cell.clear_output();
            }
        });
        Jupyter.notebook.set_dirty(true);
        // Make sieve clickable to start gif. Look the image up explicitly
        // (not via the implicit window.sieve global) and skip it if absent.
        var sieve = document.getElementById("sieve");
        if (sieve) {
            sieve.src = 'resources/sieve1.png';
            sieve.onclick = function() {
                swap_src(sieve, 'resources/sieve.gif', 37000);
            };
        }
    }

    if (Jupyter.notebook._fully_loaded) {
        // notebook has already been fully loaded, so init now
        init_presentation();
    }
    // Also clear on any future load
    // (e.g. when notebook finishes loading, or when a checkpoint is reloaded)
    events.on('notebook_loaded.Notebook', init_presentation);
});
''')

<IPython.core.display.Javascript object>

# Optimised Primes

Emlyn Corrin

<img data-gifffer="resources/sieve.gif" />

![](resources/prime.png)
<!--- Image (public domain) from:
https://www.flickr.com/photos/114305749@N08/24438440681
-->

## Why?

- Online programming contests (Project Euler etc.)

- Mathematical or programming exercise

- Because it's fun!

## What is a prime?

A prime number (or a prime) is a natural number greater than 1 that has no positive divisors other than 1 and itself.
<div style="text-align: right">&mdash; Wikipedia</div>

In [2]:
def is_prime(n):
    """Return True if n is prime, using the most literal trial division.

    Straight from the definition: n must be greater than 1, and no integer
    d with 2 <= d < n may divide it.
    """
    if n <= 1:
        return False  # primes are greater than 1 by definition
    for divisor in range(2, n):
        if n % divisor == 0:
            return False  # found a divisor other than 1 and n
    return True  # nothing divides n, so it is prime

## Let's generate a few

In [3]:
# Every prime below 20
list(filter(is_prime, range(20)))

[2, 3, 5, 7, 11, 13, 17, 19]

In [4]:
# Primes below 2000, joined into one string (a list would be shown one per line)
print(', '.join(map(str, filter(is_prime, range(2000)))))

2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997, 1009, 1013, 1019, 1021, 1031, 1033, 1039, 1049, 1051, 1061, 1063, 1069, 1087, 1091, 1093, 1097, 1103, 1109, 1117, 1123, 1129, 1151, 1153, 1163, 1171, 1181, 1187, 1193, 1201, 1213, 1217, 122

## What about generating them on demand?

In [5]:
from itertools import count, islice

def first_try():
    """Yield the primes in increasing order, testing 0, 1, 2, ... with is_prime."""
    for candidate in count():
        if not is_prime(candidate):
            continue
        yield candidate

# The first 20 primes
list(islice(first_try(), 20))

[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71]

## But how fast is it?

In [6]:
# Time first_try and plot it
timing_plot(genfn=first_try)

## Can we make it faster?

### What about skipping all even numbers (apart from 2)?

In [7]:
from itertools import count

def skip_even():
    """Yield primes: 2, then odd candidates trial-divided by odd numbers only."""
    def is_odd_prime(candidate):
        # candidate is odd, so no even number can divide it
        for divisor in range(3, candidate, 2):
            if candidate % divisor == 0:
                return False
        return True

    yield 2  # the only even prime
    for candidate in count(3, 2):
        if is_odd_prime(candidate):
            yield candidate

### How much faster is it?

In [8]:
# Time skip_even against the earlier (dashed) lines
timing_plot(genfn=skip_even)

## Can we reduce the number of checks further?

### Yes!
Factors always come in pairs: if $n$ has a factor $f$, then $g = n / f$ must also be a factor.  
If $f \leq \sqrt n$ then $g \geq \sqrt n$, and vice versa (equal when $f = g = \sqrt n$).  
So if $n$ has any factor other than $1$ and $n$ itself, at least one such factor must be $\leq \sqrt n$,  
and therefore we can stop checking once we reach $\sqrt n$.

In [9]:
from itertools import count
from math import sqrt

def to_sqrt():
    """Yield primes, trial-dividing odd candidates by odd numbers up to sqrt."""
    def is_odd_prime(candidate):
        # a composite always has a factor <= sqrt(candidate), so stop there
        limit = int(sqrt(candidate)) + 1
        for divisor in range(3, limit, 2):
            if candidate % divisor == 0:
                return False
        return True

    yield 2
    for candidate in count(3, 2):
        if is_odd_prime(candidate):
            yield candidate

## How much faster is this?

In [10]:
# Time to_sqrt against the earlier (dashed) lines
timing_plot(genfn=to_sqrt)

## Is this the best we can do?

We are still checking more numbers than necessary:  
e.g. once we've tested for divisibility by 3 and 5,  
we shouldn't need to test their (odd) multiples (9, 15, 21, 25, 27, ...).

i.e. we only need to check for divisibility by primes.

## What about storing a list of the primes found so far, and only testing division by those?

In [11]:
from itertools import count

def check_primes():
    """Yield primes, trial-dividing each odd candidate by every odd prime found so far."""
    yield 2
    found = []  # odd primes so far (candidates are odd, so 2 is never needed)
    for candidate in count(3, 2):
        has_divisor = False
        for prime in found:
            if candidate % prime == 0:
                has_divisor = True
                break
        if not has_divisor:
            yield candidate
            found.append(candidate)

In [12]:
# Time check_primes against the earlier (dashed) lines
timing_plot(genfn=check_primes)

### Why is this slower?

We forgot to stop testing primes $> \sqrt n$.  
Let's fix that.

In [13]:
from itertools import count

def check_primes_sqrt():
    """Yield primes, trial-dividing odd candidates only by known odd primes <= sqrt."""
    yield 2
    odd_primes = []
    for candidate in count(3, 2):
        composite = False
        for prime in odd_primes:
            if prime * prime > candidate:
                break  # past sqrt(candidate): no smaller factor exists
            if candidate % prime == 0:
                composite = True
                break
        if composite:
            continue
        yield candidate
        odd_primes.append(candidate)

In [14]:
# Time check_primes_sqrt against the earlier (dashed) lines
timing_plot(genfn=check_primes_sqrt)

## The sieve of Eratosthenes

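1. Start with a grid of numbers from 2 up to (but not including) `max_prime`
2. Find the first (next) unmarked number, and return it as a prime, $p$
3. Mark all multiples of $p$ as composite (it is enough to start at $p^2$: every smaller multiple has a smaller prime factor, so it is already marked)
4. Go back to step 2.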
<img src="resources/sieve.png" id="sieve" />

In [15]:
def sieve(max_prime):
    """Yield every prime strictly less than max_prime (sieve of Eratosthenes).

    Only odd numbers are stored: slot k of the table represents 2*k + 1.
    Yields nothing when max_prime <= 2, because there are no primes below 2.
    (Previously 2 was yielded unconditionally, even for max_prime <= 2.)
    """
    if max_prime <= 2:
        return
    is_candidate = [True] * (max_prime // 2)
    yield 2
    for i in range(3, max_prime, 2):
        if is_candidate[i // 2]:
            yield i
            # Smaller multiples of i have a smaller prime factor and are
            # already crossed off; step by 2*i to visit odd multiples only.
            for j in range(i * i, max_prime, i * 2):
                is_candidate[j // 2] = False

In [16]:
# Time the bounded sieve against the earlier (dashed) lines
timing_plot(genfn=sieve)

## Problems?

### Memory use
- Use a packed data structure (e.g. a `bytearray`, manipulating individual bits) to store 8 cells per byte
- Also skip multiples of 3 (only check numbers of form $6n \pm 1$)

### Need to allocate storage upfront
We often don't know in advance how much to allocate
(e.g. when we want the first 100k primes, rather than all primes below some limit).

## Better?
Now we don’t have to decide the upper limit in advance, but it is slower.

What about switching things around? For each prime found so far, store its next multiple above the current candidate; then we just have to check whether the candidate is in that collection, instead of doing several trial divisions per candidate.
For each stored multiple we also keep the prime it came from, so that when we reach it we can add that prime to get the next multiple. A number can be a multiple of more than one prime, though, so each entry (in a dictionary keyed by the multiple) holds a list of source primes:

In [17]:
from itertools import count

def unbounded_sieve():
    """Yield primes without an upper limit (incremental sieve).

    ``composites`` maps each upcoming composite to the primes that produced
    it. A candidate absent from the map is prime.
    """
    composites = {}
    for candidate in count(2):
        factors = composites.pop(candidate, None)
        if factors is None:
            yield candidate
            # first multiple of this new prime still to be reached
            composites[2 * candidate] = [candidate]
            continue
        # move each source prime on to its next multiple
        for factor in factors:
            composites.setdefault(candidate + factor, []).append(factor)

In [18]:
# Time unbounded_sieve against the earlier (dashed) lines
timing_plot(genfn=unbounded_sieve)

We can make a few optimisations:
- Use a `defaultdict`, so we don’t have to check whether a number is already present
- Skip even numbers, and therefore even multiples of primes (we store the step $2p$ rather than $p$, so we only ever land on odd multiples)
- When we find a prime $p$, the first multiple we have to add to the state is $p^2$, because any smaller multiple $p \cdot q$ (with $q < p$) has a smaller prime factor and is already covered.


In [19]:
from collections import defaultdict
from itertools import count

def improved_sieve():
    """Yield primes without an upper limit, visiting odd candidates only.

    ``upcoming`` maps each upcoming odd composite to the step sizes (2*p for
    each prime p) that reach it. A prime p first enters at p*p.
    """
    yield 2
    upcoming = defaultdict(list)
    for candidate in count(3, 2):
        steps = upcoming.pop(candidate, None)
        if steps is None:
            yield candidate
            upcoming[candidate * candidate] = [2 * candidate]
        else:
            for step in steps:
                upcoming[candidate + step].append(step)

In [20]:
# Time improved_sieve against the earlier (dashed) lines
timing_plot(genfn=improved_sieve)

In [21]:
import heapq
from itertools import count

def priority_queue():
    """Yield primes without an upper limit, using a heap of upcoming multiples.

    The min-heap holds (next odd multiple, step) pairs, where step is 2*p for
    prime p. An odd candidate that is not at the top of the heap is prime.
    """
    yield 2
    yield 3
    heap = [(9, 6)]  # 3's first relevant multiple (3*3), stepping by 2*3
    for candidate in count(5, 2):
        if heap[0][0] != candidate:
            yield candidate
            heapq.heappush(heap, (candidate * candidate, 2 * candidate))
            continue
        # candidate is composite: advance every entry sitting on it
        while heap[0][0] == candidate:
            _, step = heap[0]
            heapq.heapreplace(heap, (candidate + step, step))

In [22]:
# Time priority_queue against the earlier (dashed) lines
timing_plot(genfn=priority_queue)

In [23]:
from pyprimesieve import primes

def library(n):
    """Third-party baseline: return ``pyprimesieve.primes(n)``.

    Because it takes one argument (an upper limit), timing_plot treats it as a
    bounded generator. Presumably the result holds every prime below n —
    TODO confirm against the pyprimesieve documentation.
    """
    result = primes(n)
    return result

timing_plot(genfn=library)

In [24]:
from IPython.display import Javascript

# Put the static sieve image back (the set-up cell swapped in
# 'resources/sieve1.png'). Look the element up explicitly and do nothing if it
# is not on the page, instead of relying on the implicit `sieve` window global.
Javascript('''
(function () {
    var sieve = document.getElementById("sieve");
    if (sieve) {
        sieve.src = 'resources/sieve.png';
    }
})();
''')

<IPython.core.display.Javascript object>