# The joy of fast cars 

Hi guys, in this lecture I want to talk about speed in Python. Now, since you are beginners, I want to emphasize that you should be focusing on writing *correct* programs and not pay too much mind to how fast they run. That said, it is my opinion that if you want to teach people any skill, **THE ABSOLUTE BEST THING you can try to do is inspire them.** Knowledge is easily lost with time, but passion tends to stick around a lot longer.

Of course, inspiring people is not easy, and in this particular case I don't know anything about you or your life story, which means I can't really 'connect' with you on any meaningful/personal level. But what I can do is share a few of the things I enjoy doing/thinking about and hope some of that enthusiasm rubs off on you.

The 'joy of fast cars' is a somewhat cryptic title, but the explanation is fairly straightforward: one of the things I find a lot of fun is trying to make my code more efficient. I genuinely enjoy the process of taking a bit of code and trying to come up with ways to make it faster. In my mind programming is at its most interesting when you can look past stuff like language syntax and instead focus on the very nature of the problem itself. The aim of today is to try to give you a glimpse of that.

Alright, let's start with some code; I want a list of all the prime numbers from 0 to 400.

In [1]:
# Attempt #1

# Prime detection...
def is_prime(num):
    """Returns True if number is prime, False otherwise"""
    if num <= 1: return False    # 0, 1 and negative numbers are not prime
    # check for factors
    for i in range(2,num): # for loop that iterates from 2 up to num-1. Each number in the iteration is called "i"
        if (num % i) == 0: # modular arithmetic; this asks if num is divisible by i (with no remainder).
            return False
        # If we have iterated through every number up to num without finding a divisor it must be prime.
    return True


# Making the list:
def get_primes(b):
    primes = []
    for num in range(0,b+1):
        if is_prime(num): # <== Yes you can call functions inside other functions!
            primes.append(num)
    return primes

print(get_primes(400))

[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397]


Alright, it seems like we are off to a reasonable start! The first question we want to ask ourselves is "just how fast is it?" If we want to improve the speed then it is obviously helpful to have a benchmark. On my machine I asked Python to calculate all the primes from 0 to 30,000. Using this code results in:
    
> 33251 function calls in 10.761 seconds
 
Ten seconds is pretty slow, I think. The next step is to find our ‘bottleneck’:

> Function: is_prime ... calls: 30001 ... Total Time: 10.744

In other words, the is_prime function is responsible for about 90% of the calls and 99.8% of the time. Thus, if we try to fix another part of the program the performance gains would be trivial at best. 

For example, list comprehensions are usually a bit faster than for-loops, so I could speed up the ‘get_primes’ function by using one. For argument's sake, let's suppose using a list comprehension can speed it up by 50%. That’s a huge improvement! But when we look at the total time we see that get_primes took 0.016 seconds on my machine. So a 50% improvement would save us about 0.008 seconds. We are talking about savings of a few thousandths of a second here; to hell with that!
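If you want to check this sort of thing on your own machine, here is a minimal sketch using Python's built-in `cProfile` module. The names `get_primes_comprehension` and `profile_report` are made up for this example, and the limit of 2,000 is only there to keep the run short; your timings will differ from mine.

```python
import cProfile
import io
import pstats


def is_prime(num):
    """Attempt #1: try every possible divisor from 2 up to num - 1."""
    if num <= 1:
        return False
    for i in range(2, num):
        if num % i == 0:
            return False
    return True


def get_primes(b):
    """The for-loop version from above."""
    primes = []
    for num in range(0, b + 1):
        if is_prime(num):
            primes.append(num)
    return primes


def get_primes_comprehension(b):
    """Hypothetical variant: same result, written as a list comprehension."""
    return [num for num in range(0, b + 1) if is_prime(num)]


def profile_report(func, limit):
    """Run func(limit) under cProfile and return the report as a string."""
    profiler = cProfile.Profile()
    profiler.enable()
    func(limit)
    profiler.disable()
    buffer = io.StringIO()
    # Sort by time spent inside each function and show the top 5 entries.
    pstats.Stats(profiler, stream=buffer).sort_stats("tottime").print_stats(5)
    return buffer.getvalue()


# Both versions must produce exactly the same list.
print(get_primes(2000) == get_primes_comprehension(2000))  # True

print(profile_report(get_primes, 2000))
print(profile_report(get_primes_comprehension, 2000))
```

In both reports you should see `is_prime` sitting at the top of the list, which is the whole point: the way we build the list barely matters.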

Okay, so the function we need to improve is the ‘is_prime’ function. I think the first line of code to study is the for-loop:

>  for i in range(2, num):

So this is where the fun begins! To solve this puzzle we need to think logically and be a bit creative. Improving the speed of this line of code is not simply a matter of “knowing more Python”. Rather, we need to apply a splash of mathematics. Here, let me show you something:

In [14]:
for i in range(1, 20):
    print(i, "--", 20/i)

1 -- 20.0
2 -- 10.0
3 -- 6.666666666666667
4 -- 5.0
5 -- 4.0
6 -- 3.3333333333333335
7 -- 2.857142857142857
8 -- 2.5
9 -- 2.2222222222222223
10 -- 2.0
11 -- 1.8181818181818181
12 -- 1.6666666666666667
13 -- 1.5384615384615385
14 -- 1.4285714285714286
15 -- 1.3333333333333333
16 -- 1.25
17 -- 1.1764705882352942
18 -- 1.1111111111111112
19 -- 1.0526315789473684


So this code is dividing the number 20 by 'i', where 'i' runs from 1 to 19. The salient point here is that for every 'i' past 10 the result is not a whole number. This makes a lot of sense when you think about it; the minimum number of ‘parts’ we can split X into (besides 1) is two. Thus, once we start looking at numbers greater than n/2 (and smaller than n itself) the result will never be a whole number. And that holds for all numbers, not just 20.
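If you don't want to take my word for it, here is a small sketch that checks the claim by brute force. The helper name `largest_proper_divisor` is made up for this example; it simply walks down from n-1 until it finds something that divides n.

```python
def largest_proper_divisor(n):
    """Return the largest divisor of n that is smaller than n itself."""
    for candidate in range(n - 1, 0, -1):
        if n % candidate == 0:
            return candidate


# For every n we try, the largest proper divisor never exceeds n // 2.
for n in range(2, 1000):
    assert largest_proper_divisor(n) <= n // 2

print(largest_proper_divisor(20))  # 10
print(largest_proper_divisor(21))  # 7
```

A check over the first thousand numbers is not a proof, of course, but it is a cheap way to catch a wrong assumption before building on it.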

Now, we can use this information to make our prime search smarter. As things currently stand, proving 1499 is prime requires about that many steps; our code is (at the moment) asking if numbers like 1001, 1002, 1003, ... are divisors of 1499, but as the above logic demonstrates, these checks are unnecessary. So, if we stop iterating at the number 750 we can approximately halve the time it takes to confirm a prime number and still have a correct solution.

> for i in range(2, num//2):

As a quick note, we are using integer division here because the range function cannot handle floats. Before we run the benchmark, though, we need to check for correctness; whenever you make a change, even if it is a small one, you should test it on a few inputs. We want to check we haven’t broken anything with our change (more on ‘regression testing’ later). With this in mind, I ran the following code on my machine *(where is_prime is the old function and is_prime2 is with the change):*
  
    x = [i for i in range(0,30000) if is_prime(i)]
    y = [i for i in range(0,30000) if is_prime2(i)]
    print(x == y) ---> False

We have a bug, Batman! What went wrong? To find out, I ran the following bit of code:
    
    x2 = set(x)
    y2 = set(y)
    x2.symmetric_difference(y2) ---> {4}

I converted the lists to sets because sets have this handy method for quickly telling the difference between two items. It turns out we have two lists, each with 3200+ numbers and the only difference is that one of these lists contains the number 4 and the other does not. So what’s the problem?  

Well, our new function uses:

    range(2, n//2)

and:

    4//2 == 2

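In short, for n = 4 we end up with range(2, 2), which is empty, so the loop never runs and 4 is wrongly reported as prime. Our change to the function works great for large input but breaks for tiny inputs. I think the simplest fix to this problem is to use n//2+1, which should fix our error with an insignificant performance cost.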

In [2]:
n = 4
for i in range(2, n//2):
    print("(1)...", i)
    # Nothing is printed! WTF!!

# Okay, attempted fix:
for i in range(2, n//2+1):
    print("(2)..." , i)


(2)... 2


I ran our correctness test and this time it worked. Okay, let's benchmark!

> 33251 function calls in 5.386 seconds

So this small change has roughly halved the amount of time it takes to get all primes up to 30,000. The number of function calls stayed the same, which makes sense when you think about it (we are still checking whether every number from 0 to 30,000 is prime, after all).

Are we done? Well actually I can think of a few more tweaks...

Let's think about the nature of primes for one moment. The definition of a prime is that it is only divisible by itself and 1. And since an even number is, by definition, divisible by 2 we know that the only prime that is even is 2. 

Let’s think of a large odd number (not necessarily prime). Our code is going to ask if 2, 4, 6, 8, 10, 12… are divisors. But from the definition of even numbers we know that if 12, 18, 22, etc. are divisors of X then so must 2 be. Which therefore means that if 2 is not a divisor then neither is 6, 8, 100, 102, etc.

In short, checking for 2 is equivalent to checking for all even numbers. Can we apply this insight to our code? I think so:

In [None]:
# Prime detection...
def is_prime2(num):
    if num > 1:
        if num == 2: # covering special case, 2 only even prime.
            return True
        elif num % 2 == 0: # tests n for all even numbers
            return False
        
        else:
            for i in range(3,num//2+1, 2): # range function starts at odd number with a step of 2. 
                if (num % i) == 0: 
                    return False
            return True
    else:
        return False

So this code has a special check for the number 2 itself; next up we check if n is even. And finally we enter our for-loop, and because of the step argument to range we are only considering odd numbers (3, 5, 7, ...). At first glance it seems as if our function has more or less halved the search space once again. So now I'm expecting a run time of about 2.5 seconds for the primes from 0 to 30,000.

> 33250 function calls in 3.439 seconds

So, we got a noticeable speed-up here but not quite the 2.5 seconds I was expecting. For what it's worth, measuring time isn’t really a great way to measure the speed of programs because it is highly variable; CPU temperature, workload, available RAM, etc. could all affect these numbers (which is why I ran each test a handful of times). But it's also the case that measuring time is easier than proving the speed of an algorithm with maths.
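As an aside, running a test "a handful of times" can be automated. The sketch below uses the standard `timeit` module rather than the profiler I used for the numbers in this lecture, so treat it as an alternative way of measuring, not as the method behind my results. The limit of 3,000 and the five repeats are arbitrary choices.

```python
import timeit


def is_prime2(num):
    """Trial division that skips even divisors (same logic as above)."""
    if num <= 1:
        return False
    if num == 2:
        return True
    if num % 2 == 0:
        return False
    for i in range(3, num // 2 + 1, 2):
        if num % i == 0:
            return False
    return True


def get_primes(b):
    return [num for num in range(0, b + 1) if is_prime2(num)]


# Run the same job 5 times; each entry is the time taken by one full run.
timings = timeit.repeat(lambda: get_primes(3000), repeat=5, number=1)

print("fastest run: {:.4f} seconds".format(min(timings)))
print("slowest run: {:.4f} seconds".format(max(timings)))
```

The gap between the fastest and slowest run gives you a rough feel for how noisy your machine is, and the fastest run is usually the fairest number to compare.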

I decided to run one more test but now with an input of 200,000. I’m hoping that the large input size will ‘drown out’ the noise and give us a better approximation. 

> func1: 217989 function calls in 254.250 seconds
> func2: 217989 function calls in 127.124 seconds

The number of function calls doesn’t change because ‘get_primes’ hasn’t changed. We could optimise that function to only call is_prime on odd numbers, but as I pointed out above get_primes itself is responsible for only about 0.15% of the time (0.016 of 10.761 seconds), so halving the number of calls would NOT offer much in terms of performance. Anyway, since 254/2 = 127 it seems like the intuition was correct; checking for even divisors once and then only looking at odd numbers seems about twice as fast.

Can we do better? Perhaps, but I’m out of ideas at this point. But hey, Google is a treasure trove of information; I wonder if there is some other maths ‘trick’ out there we could use…

After googling, I found out that apparently we only need to check divisors up to the square root of n! And here I’m going to reproduce a maths proof which I found [here](https://stackoverflow.com/questions/5811151/why-do-we-check-up-to-the-square-root-of-a-prime-number-to-determine-if-it-is-pr).

## Proof

Suppose ‘M’ is the square root of ‘N’. This means M\*M = N.

If N is not prime, then N can be written as N = A\*B, where A and B are both greater than 1.

Thus, M\*M = A\*B.

Now there are three possible cases:

1. A > M then B < M
1. A = M then B = M
1. A < M then B > M

In all three cases the min(A, B) <= M. Thus we only need to search until M in order to find a factor of N and thus prove N is not prime. Neat huh?
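Here is a quick sanity check of the proof, written as a sketch. The helper `smallest_factor` is a made-up name; it finds the smallest divisor greater than 1 the slow way, and then we confirm that for every composite number this divisor is no bigger than the square root.

```python
def smallest_factor(n):
    """Return the smallest divisor of n that is greater than 1."""
    for a in range(2, n + 1):
        if n % a == 0:
            return a


checked = 0
for n in range(4, 5000):
    a = smallest_factor(n)
    if a < n:                  # n is composite, so n = a * b with b = n // a
        b = n // a
        assert a <= b          # a plays the role of min(A, B)
        assert a * a <= n      # same as a <= sqrt(n), but without floats
        checked += 1

print(checked, "composite numbers checked")
print(smallest_factor(91), smallest_factor(49), smallest_factor(97))  # 7 7 97
```

Note the last example: 97 is prime, so its smallest factor is 97 itself and the square root rule simply finds nothing to report.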

## Benchmarking Sqrt(N)

Alright, how can we implement this? Well, it took a bit of testing, but eventually I came up with this line (after importing the math module, of course):

> for i in range(3, math.ceil(math.sqrt(num))+1, 2):

In many cases sqrt(n) is not a whole number and, as discussed elsewhere, range requires an integer. That’s where math.ceil comes in; it rounds its argument up to the next integer (e.g. math.ceil(6.0003) ---> 7). I then add one because range stops one short of its end value; without it, square numbers like 9 and 25 would never be tested against their own square root and would wrongly be reported as prime.
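Putting that line into the function gives something like the sketch below. The name `is_prime3` is just what I'm calling it here, and `is_prime_slow` is the very first version of the function, kept around purely as a reference to test against.

```python
import math


def is_prime3(num):
    """Trial division over odd divisors, stopping at the square root."""
    if num <= 1:
        return False
    if num == 2:          # the only even prime
        return True
    if num % 2 == 0:      # every other even number is out
        return False
    for i in range(3, math.ceil(math.sqrt(num)) + 1, 2):
        if num % i == 0:
            return False
    return True


def is_prime_slow(num):
    """Attempt #1 from the start of the lecture, used as a reference."""
    if num <= 1:
        return False
    for i in range(2, num):
        if num % i == 0:
            return False
    return True


# Correctness first: both functions must agree on every input we try.
assert all(is_prime3(n) == is_prime_slow(n) for n in range(2000))

print([n for n in range(50) if is_prime3(n)])
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]
```

If you are on Python 3.8 or newer, `math.isqrt(num)` returns the integer square root directly, which avoids floats altogether; I've stuck with `math.ceil(math.sqrt(num))` here to match the line above.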

How much faster do we think this function will be? Well, our current function runs roughly N/4 operations. So if the input is somewhere close to 10,000 we are probably looking at about 2,500 checks (when N is prime). Using sqrt(N) AND only looking at odd numbers means we have to check approximately 50 values. In short, we can verify that primes are in fact prime orders of magnitude faster than we could previously. I'm going to guess our new function will take about 10 seconds for an input of 200,000 (which would be roughly 13 times faster than the 127 seconds our current function needs). After testing for correctness, I ran the benchmark:

> 417987 function calls in 1.351 seconds

Oh boy, using the square root blew straight past my guess; at 1.351 seconds versus 127.124 seconds it is roughly 94 times faster than the previous function. I guess in hindsight this is not so surprising: our previous improvements were *linear* in nature (for every N, we check up to N/4 numbers), but sqrt(N) grows much, much more slowly than that.
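To put numbers on that difference, here is a small sketch that simply counts how many divisors each version of the loop would try for a prime near 10,000. I picked 10007 as the example; nothing is being timed, we are only measuring the length of the two ranges.

```python
import math

n = 10007  # a prime close to 10,000, so the loop has to run to the very end

# Odd divisors tried when we stop at n // 2 (the previous function).
half_range = range(3, n // 2 + 1, 2)

# Odd divisors tried when we stop at the square root (the new function).
sqrt_range = range(3, math.ceil(math.sqrt(n)) + 1, 2)

print(len(half_range))  # 2501
print(len(sqrt_range))  # 50
```

Try bumping `n` up by a factor of 100: the first count grows by about 100 as well, while the second only grows by about 10.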

Let's take a breath for a moment; we started with a function that took more than 10 seconds to find all primes up to 30,000. Now we have a function that needs roughly 8 times less time to handle an input nearly 7 times larger!

Are we done?

## Rethinking the problem

So, our task is to create lists of primes up to N, and so far our strategy for improving performance has been to reduce the number of divisors we need to check.

The interesting thing about our current method is that right now we are searching for ‘needles in haystacks’. **And right now our method for improvement is to find ways to search the haystack faster.** 

What if we could find a method for *generating* prime numbers via some mathematical operation instead of *searching* for them? Might that be faster?

After a bit of research, it seems like the ["Sieve of Atkin"](https://en.wikipedia.org/wiki/Sieve_of_Atkin) is one of the fastest known algorithms for this job. Seeing as I felt a bit lazy, I decided to google someone else’s solution rather than try to implement this algorithm myself. I found [this bit of code](https://gist.github.com/mineta/7840849) on GitHub and decided to benchmark it (after testing for correctness, of course).

    For input size 200,000:
    Sqrt Func:   417987 function calls in 1.326 seconds
    Atkin Func:   18439 function calls in 1.258 seconds

Atkin is a tiny, tiny bit faster here but these results are well within our margin of error. Let's see how well these functions handle inputs of size ten-fucking-million!

    For input size 10,000,000:
    Sqrt Func: 20664582 function calls in 273.478 seconds
    Atkin Func:  667749 function calls in  65.531 seconds


So the Atkin algorithm is a lot faster once we start looking at massive lists of prime numbers. But the main lesson I want you to learn here is that optimisation can be thought of as two separate ideas. The first idea is low-level tinkering; in other words, we look at all the small details and see if we can save a byte or two of memory here or there. But then there is ‘high-level’ optimisation, and that is where we try to come up with an entirely new (and hopefully better) strategy for solving the problem. In this lecture we started by *searching* for needles in haystacks and then we thought of a new idea, which was to *generate* primes.
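To give you a feel for what 'generating' looks like, here is a sketch of the much simpler Sieve of Eratosthenes. To be clear, this is NOT the Atkin code I benchmarked above and I have not timed it; it is only here because it shows the same high-level idea in a dozen lines. The name `sieve_primes` is made up for this example.

```python
def sieve_primes(limit):
    """Return all primes <= limit using the Sieve of Eratosthenes."""
    if limit < 2:
        return []

    # Start by assuming every number is prime, then cross out the multiples.
    is_candidate = [True] * (limit + 1)
    is_candidate[0] = is_candidate[1] = False

    n = 2
    while n * n <= limit:                  # same square root idea as before
        if is_candidate[n]:
            # Smaller multiples of n were already crossed out earlier.
            for multiple in range(n * n, limit + 1, n):
                is_candidate[multiple] = False
        n += 1

    return [number for number, flag in enumerate(is_candidate) if flag]


print(sieve_primes(50))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]
```

Notice there is not a single `%` in there: instead of asking "is this number prime?" over and over, we cross out everything that cannot be prime and keep whatever is left.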

## Homework assignment

In this week's (optional) homework, your task is to try and write a bit of code that is *faster* than my code. And there are two basic ways to do it; you can get your hands dirty and try some low-level optimisation, or you can ditch all that and favour a high-level approach.

Unlike most of the homeworks, this is more about being clever than it is about understanding Python.

    The Challenge: BEAT MY TIME!!
    
The below code will create a list of all *ODD* square numbers starting at 1 and ending at x. Example:

    If x is 100, the squares are:
    [1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
    
    Of which, we only want the odd numbers:
    [1, 9, 25, 49, 81]

A few hints...
* Remember that "a and b" can be slower than "b and a" (see logic lecture). Basically, **the order** in which you do things can make a difference.
* Finding a needle in a haystack is probably slower than [BLANK] ? 
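As a reminder of what the first hint means, here is a generic sketch (deliberately nothing to do with squares). The functions `cheap` and `expensive` are stand-ins I made up; the point is that `and` stops as soon as the left-hand side is False, so the order decides how often the slow check runs.

```python
def run(cheap_first):
    """Filter 0..999 with two checks and count how often each one is called."""
    calls = {"cheap": 0, "expensive": 0}

    def cheap(n):
        calls["cheap"] += 1
        return n % 10 == 0

    def expensive(n):
        calls["expensive"] += 1
        return sum(range(n)) % 2 == 0   # stand-in for some slow calculation

    if cheap_first:
        kept = [n for n in range(1000) if cheap(n) and expensive(n)]
    else:
        kept = [n for n in range(1000) if expensive(n) and cheap(n)]
    return kept, calls


kept_a, calls_a = run(cheap_first=True)
kept_b, calls_b = run(cheap_first=False)

print(kept_a == kept_b)        # True: the answer is the same either way
print(calls_a["expensive"])    # 100
print(calls_b["expensive"])    # 1000
```

Same answer, but one ordering does ten times as much slow work as the other.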

Please study the code below. Your job is to either make it faster by tinkering with it or, alternatively, to use your own algorithm.

In [2]:
import math                 

# My code, this is the function to beat! How can you improve it?
def squares(x):
    lst = []                    
    for number in range(1, x+1):    
        square = math.sqrt(number)     # We call the square_root function on the number.
        if square.is_integer():
        # is_integer is a float method that returns True if the number can be represented as an integer. 
        # for example, 4.0 = True, 4.89 = False 
            
            if number % 2 != 0: # checks if number is odd.
                lst.append(number) 
    return lst

print(squares(1000))

[1, 9, 25, 49, 81, 121, 169, 225, 289, 361, 441, 529, 625, 729, 841, 961]


In [1]:
import math 

def my_squares(x):
    # YOUR CODE GOES HERE !!!
    # x --> int
    # Return a list of odd squares...
    
    # Note, don't change the name of this function; if you do, I can't test it!

    pass # <--- delete this line!

##################################
# MY CODE, a.k.a THE CODE TO BEAT!
# Please do not change this!!!

def hamster_squares(x):
    lst = []                    
    for number in range(1, x+1):    
        square = math.sqrt(number)    
        if square.is_integer():
            if number % 2 != 0: 
                lst.append(number) 
    return lst

################## THE CONTROL PANEL ################################
#####################################################################
verbose = False # set to True if you want more detailed statistics...
X = 5000000 
# Lower X if tests are taking too long on your machine. 
# Raise this value if you want higher accuracy.
#####################################################################

teacher = hamster_squares(10000)
student = my_squares(10000)
correct = None

# TEST 1: CORRECTNESS
if teacher == student:
    print("CORRECTNESS TEST = PASSED")
    correct = True
else:
    print("CORRECTNESS TEST = FAILED", "NOW TRYING TO DEBUG...", sep="\n")   
    
    # here is a bit of code to help you find the problem(s)!   
    # returning a list?
    if not isinstance(student, list):
        print("... Try returning a list next time, not a bloody {} !".format(type(student)))
    
    # too many/too few items?
    elif len(teacher) != len(student):
        print(".... Your list has {} items, it should have {} items".format(len(student), len(teacher)))
        
    # small numbers correct?
    elif student[:10] != teacher[:10]:
        print("... Start of list incorrect.\nYOURS: {}\nEXPECTED: {}".format(student[:10], teacher[:10]))
    
    # testing for same items. Note that this test DOES NOT take order into consideration.
    else:
        ts = set(teacher)
        st = set(student)
        diff = ts.symmetric_difference(st)
        if diff:
            print("... The lists contain different numbers, these are... \n {}".format(diff))
    

# SPEED TESTS ... (just ignore this code)
if correct:
    print("...Now testing speed. Please note, this may take a while...\n", 
          "Also, I'd advise a margin of error of about +- 0.2 seconds\n")
    
    # Feel free to ignore this function, even I don't understand it!
    def profile(function, *args, **kwargs):
        """ Returns performance statistics (as a string) for the given function.
        """
        def _run():
            function(*args, **kwargs)
        import cProfile as profile
        import pstats
        import os
        import sys; sys.modules['__main__'].__profile_run__ = _run
        id = function.__name__ + '()'
        profile.run('__profile_run__()', id)
        p = pstats.Stats(id)
        p.stream = open(id, 'w')
        p.sort_stats('tottime').print_stats(20)
        p.stream.close()
        s = open(id).read()
        os.remove(id)
        return s
    
    def string(i, func, detail): 
        i = i.split("\n")     
        s= "✿ Stats for {} function... \n{}".format(func, i[2])
        if detail:
            s = s + "\n" + "\n".join(i[3:-7]) + "\n"
        return s
    
    print("-------- Solution Comparison, where input size is {}. -------- \n".format(X))
    
    hs = profile(hamster_squares, X)
    print(string(hs, "Teacher's Squares", verbose))

    ss = profile(my_squares, X)
    print(string(ss, "'YOUR'", verbose))

CORRECTNESS TEST = FAILED
NOW TRYING TO DEBUG...
... Try returning a list next time, not a bloody <class 'NoneType'> !
