In [None]:
########################################
## Evaluate this cell before starting ##
########################################

from collections import defaultdict
from itertools import count, islice
import time

from IPython.core.magics.execution import _format_time as format_time
from IPython.display import Javascript

# Make the sieve diagram (<img id="sieve"> in the "sieve of Eratosthenes"
# section) clickable: a click shows the animated sieve.gif for 37 s, then
# restores the original image.
Javascript('''
(function() {
  function swap(el, src, t) {
    // Ignore clicks while the animation is already showing. Otherwise a
    // second click would save the GIF as the "old" image and the static
    // picture would never come back.
    if (el.dataset.swapping) {
      return;
    }
    el.dataset.swapping = "true";
    var old = el.src;
    el.src = src;
    setTimeout(function() {
      el.src = old;
      delete el.dataset.swapping;
    }, t);
  }

  // Look the element up explicitly rather than relying on the browser
  // exposing element ids as global variables, and do nothing (instead of
  // throwing) if the image has not been rendered yet.
  var img = document.getElementById("sieve");
  if (img) {
    img.onclick = function() {
      swap(img, 'sieve.gif', 37000);
    };
  }
})();
''')

# Optimised Primes

Emlyn Corrin

<img data-gifffer="sieve.gif" />

![](prime.png)
<!--- Image (public domain) from:
https://www.flickr.com/photos/114305749@N08/24438440681
-->

## Why?

- Online programming contests (Project Euler etc.)

- Mathematical or programming exercise

- Because it's fun!

## What is a prime?

A prime number (or a prime) is a natural number greater than 1 that has no positive divisors other than 1 and itself.
<div style="text-align: right">&mdash; Wikipedia</div>

In [None]:
def is_prime(n):
    """Return True if n is prime: greater than 1, with no divisor in 2..n-1.

    Deliberately naive: it follows the definition literally and tries every
    possible divisor.
    """
    if n <= 1:
        return False
    return not any(n % divisor == 0 for divisor in range(2, n))

## Let's generate a few

In [None]:
# Keep only the numbers below 20 that pass is_prime
list(filter(is_prime, range(20)))

In [None]:
# Printed as one comma-separated line rather than a long list repr
print(', '.join(map(str, filter(is_prime, range(2000)))))

## What about generating them on demand?

In [None]:
from itertools import count, islice

def primes1():
    """Yield the primes in increasing order, forever, using the naive is_prime."""
    yield from filter(is_prime, count())
    
# primes1() never ends, so pair it with range(20) to take just 20 values
[p for _, p in zip(range(20), primes1())]

## But how fast is it?

In [None]:
import time

def _ordinal(n):
    """Return n with its English ordinal suffix, e.g. 1 -> '1st', 12 -> '12th'."""
    if n % 100 in (11, 12, 13):
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return '{}{}'.format(n, suffix)


def test(gen, num):
    """Time pulling the first `num` primes from the generator `gen`.

    Prints the elapsed time (measured with time.perf_counter) and the last
    prime produced. If `gen` runs out early (e.g. a bounded sieve whose limit
    is too small), the message reports how many primes were actually
    produced. If it produced none, that is reported instead of raising
    IndexError.
    """
    t0 = time.perf_counter()
    numbers = list(islice(gen, num))
    elapsed = time.perf_counter() - t0
    print("Took: {}".format(format_time(elapsed)))
    if not numbers:
        print("No primes were generated")
        return
    print("{0} prime is {1}".format(_ordinal(len(numbers)), numbers[-1]))

In [None]:
# Baseline: naive trial division by every number below each candidate
test(gen=primes1(), num=1000)

## Can we make it faster?

### What about skipping all even numbers (apart from 2)?

In [None]:
def primes2():
    """Yield primes forever, testing only odd candidates against odd divisors.

    After 2, every candidate is odd, so even trial divisors can never divide
    it and are skipped.
    """
    def has_no_odd_divisor(n):
        # True when no odd d with 3 <= d < n divides n
        return all(n % d != 0 for d in range(3, n, 2))

    yield 2
    yield from filter(has_no_odd_divisor, count(3, 2))

### How much faster is it?

In [None]:
# Same count as the primes1 run, so the timings are directly comparable
test(gen=primes2(), num=1000)

## Can we reduce the number of checks further?

### Yes!
Factors always come in pairs: $n = f \cdot g$ (for square numbers the pair can be $f = g = \sqrt n$).  
If $f \leq \sqrt n$ then $g \geq \sqrt n$, and vice versa.  
So if $n$ has any factors other than 1 and itself, at least one of them must be $\leq \sqrt n$.  
So we only have to check up to $\sqrt n$, not to $n$.

In [None]:
from math import sqrt

def primes3():
    """Yield primes forever, trial-dividing odd candidates up to their square root.

    A composite n always has a factor <= sqrt(n), so odd divisors above
    int(sqrt(n)) never need checking. (math.sqrt works in floating point,
    so the bound can be inexact for extremely large n.)
    """
    def no_small_odd_factor(n):
        limit = int(sqrt(n)) + 1
        return not any(n % d == 0 for d in range(3, limit, 2))

    yield 2
    for odd in count(3, 2):
        if no_small_odd_factor(odd):
            yield odd

## How much faster is this?

In [None]:
# Same 1000-prime count as primes1/primes2 for comparison
test(gen=primes3(), num=1000)

## Is this the best we can do?

We are still checking more numbers than necessary:  
e.g. once we've tested for divisibility by 3 and 5,  
we shouldn't need to test 9, 15, 21, 25, 27, 33, 35, etc.

i.e. we only need to check for divisibility by primes.

## The sieve of Eratosthenes

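1. Start with a grid of numbers from 2 up to (but not including) max_prime.
2. Find the first (next) unmarked number, and return that as a prime, $p$.
3. Mark all multiples of it (actually just from $p^2$ onwards: smaller multiples have a smaller prime factor, so they are already marked).
4. Go back to step 2.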
<img src="sieve.png" id="sieve" />

In [None]:
def primes4(max_prime):
    """Yield every prime below max_prime using the sieve of Eratosthenes.

    sieve[k] stays True until k is marked as a multiple of a smaller prime.
    Indices 0 and 1 are never read, because the scan starts at 2.
    """
    sieve = [True] * max_prime
    for i in range(2, max_prime):
        if sieve[i]:
            yield i
            # Multiples below i*i have a smaller prime factor and were
            # already marked, so crossing off can start at i*i.
            for j in range(i * i, max_prime, i):
                sieve[j] = False

In [None]:
def primes5(max_prime):
    """Yield every prime below max_prime, sieving odd numbers only.

    sieve[k] represents the odd number 2*k + 1, so the list is half the size
    of the one in primes4. Index 0 stands for 1 and is never read.
    """
    sieve = [True] * (max_prime // 2)
    # 2 is the only even prime. Only yield it when it is below the limit,
    # matching primes4, which yields nothing for max_prime <= 2.
    if max_prime > 2:
        yield 2
    for i in range(3, max_prime, 2):
        if sieve[i // 2]:
            yield i
            # Start at i*i and step by 2*i so only odd multiples are visited.
            for j in range(i * i, max_prime, i * 2):
                sieve[j // 2] = False

In [None]:
# The sieve only yields primes below its bound, so the bound must be large
# enough to contain 100000 primes
test(gen=primes4(max_prime=2000000), num=100000)

In [None]:
# Same bound and count as the primes4 run, for a direct comparison
test(gen=primes5(max_prime=2000000), num=100000)

## Problems?

### Memory use
- Use a packed data structure (e.g. a `bytearray`, or bit-packing with the `struct` module to encode 8 cells per byte)
- Also skip multiples of 3 (only check numbers of the form $6n \pm 1$)

### Need to allocate storage upfront
We often don't know in advance how large the sieve must be
(e.g. to find the first 100k primes, we would need to know how big the 100,000th prime is).

## What about storing a list of the primes found so far, and only testing division by those?

In [None]:
def primes6():
    """Yield primes forever by trial-dividing each odd candidate by earlier primes.

    Unlike the sieves, no upper limit is needed. As argued for primes3, a
    composite candidate must have a prime factor <= sqrt(candidate), so the
    scan over primelist stops as soon as p*p exceeds the candidate.
    """
    primelist = [2]
    yield 2
    for candidate in count(3, 2):
        isprime = True
        for p in primelist:
            if p * p > candidate:
                # No prime factor <= sqrt(candidate) was found, so it's prime
                break
            if candidate % p == 0:
                isprime = False
                break
        if isprime:
            yield candidate
            primelist.append(candidate)

In [None]:
# Same 1000-prime count as the primes1-primes3 runs, for comparison
test(gen=primes6(), num=1000)

## Better?
Now we don't have to decide on an upper limit in advance, but it's slower.

What about switching things around? For each prime, store its next multiple above the current candidate; then we just have to check whether the candidate is one of those stored multiples, rather than doing multiple test divisions per candidate.
For each stored multiple we also keep the original prime, so that when we reach it, we can add the prime to generate the next multiple. But a number can be a multiple of more than one prime, so we store a list of source primes:

In [None]:
def primes7():
    """Incremental sieve: yield primes forever without a fixed upper bound.

    state maps each upcoming composite to the list of primes for which it is
    the next pending multiple. A candidate missing from state is prime.
    """
    state = {}
    for candidate in count(2):
        factors = state.pop(candidate, None)
        if factors is None:
            yield candidate
            state[2 * candidate] = [candidate]
            continue
        # Re-file each prime under its next multiple
        for factor in factors:
            state.setdefault(candidate + factor, []).append(factor)

In [None]:
# Same 100000-prime count as the bounded sieves, but no upper limit needed
test(gen=primes7(), num=100000)

We can make a few optimisations:

- Use a `defaultdict`, so we don't have to check whether a number is already present.
- Skip even numbers, and therefore even multiples of primes.
- When we find a prime $p$, the first multiple we have to add to the state is $p^2$, because every smaller multiple $p \cdot q$ (with $q < p$) has a smaller prime factor and is already handled.


In [None]:
from collections import defaultdict

def primes8():
    """Incremental odd-only sieve: yield primes forever.

    multiples maps each upcoming odd composite to the primes whose next odd
    multiple it is. A new prime p is first filed under p*p, and each prime
    advances in steps of 2*p, so even multiples are never visited.
    """
    yield 2
    multiples = defaultdict(list)
    for n in count(3, 2):
        owners = multiples.pop(n, None)
        if owners is None:
            yield n
            multiples[n * n] = [n]
        else:
            for p in owners:
                multiples[n + 2 * p].append(p)

In [None]:
# Same 100000-prime benchmark as primes7, for a direct comparison
test(gen=primes8(), num=100000)