# Topic 6a: Improving code

Our strategy in this course is always to write the simplest implementation of an idea first. We should not over-complicate our code on the first pass, because many parts of the idea may fail; we should always test that the basic idea works.

Once something is working and debugged, we might want to make it more elegant or make it run faster.

```python
import numpy as np
```

### Numpy tricks

When we think about what a computer is doing, it is often useful to think in terms of for loops or while loops that run through lists of numbers or data. With NumPy, however, we want to avoid explicit Python loops as much as possible: they are slow, and NumPy is designed to do the looping behind the scenes, in compiled code, much faster.

The most obvious time we might reach for a loop is with logical statements: if, <, >, etc. Comparisons such as < already work element by element on NumPy arrays:

```python
a=np.random.randn(10)
b=np.random.randn(10)
print(a<b)   # element-wise comparison gives an array of booleans
```

```
[False False False  True  True  True False  True  True  True]
```


But now suppose we want to return $a^2$ where $a<b$ and $b^2$ where $a\geq b$, element by element. We can't use a single if statement on the whole array:

```python
if a<b:
    print(a**2)
```

```
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```
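The error message suggests a.any() or a.all(). These collapse a boolean array into a single True/False, which is what an if statement needs. A minimal sketch with small fixed arrays (values chosen here only for illustration, so the result is reproducible):

```python
import numpy as np

# Fixed example arrays (illustrative values, not the random a and b above)
x = np.array([1.0, -2.0, 3.0])
y = np.array([0.5, 0.0, 4.0])

mask = x < y   # element-wise comparison -> [False, True, True]

# .any() is True if at least one element is True
if mask.any():
    print("x < y somewhere")

# .all() is True only if every element is True
if mask.all():
    print("x < y everywhere")
else:
    print("x < y is not true everywhere")
```

Neither of these makes the element-by-element choice we want, though.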

The obvious solution is a for loop:

```python
c=(a<b)                  # boolean array: True where a<b
out1=np.zeros(len(a))    # pre-allocate the output array
for i in range(len(a)):
    if c[i]:
        out1[i]=a[i]**2
    else:
        out1[i]=b[i]**2
print(out1)
```

```
[4.77808459e-02 1.01905449e-01 2.21222039e-03 1.99083342e-01
 3.34851358e-02 1.53966435e+00 1.40768957e-01 8.89941634e-01
 3.87085309e+00 1.07781536e+00]
```


One minor improvement is that we can use enumerate to simplify the loop:

```python
c=(a<b)
out2=np.zeros(len(a))
for i,item in enumerate(c):   # enumerate gives both the index i and the value item=c[i]
    if item:
        out2[i]=a[i]**2
    else:
        out2[i]=b[i]**2
print(out1)
print(out2)
```

```
[4.77808459e-02 1.01905449e-01 2.21222039e-03 1.99083342e-01
 3.34851358e-02 1.53966435e+00 1.40768957e-01 8.89941634e-01
 3.87085309e+00 1.07781536e+00]
[4.77808459e-02 1.01905449e-01 2.21222039e-03 1.99083342e-01
 3.34851358e-02 1.53966435e+00 1.40768957e-01 8.89941634e-01
 3.87085309e+00 1.07781536e+00]
```
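Going one step further in the same direction, Python's built-in zip lets us walk through two arrays together, so we don't need the index i at all. This is a sketch of an alternative style only; it is still a Python loop, so it is no faster than the versions above. The names a_ex, b_ex, loop_out and zip_out are made up here so that the block doesn't overwrite a, b or out1:

```python
import numpy as np

a_ex = np.random.randn(10)
b_ex = np.random.randn(10)

# zip pairs up a_ex[i] and b_ex[i]; the conditional expression picks which one to square
zip_out = np.array([x**2 if x < y else y**2 for x, y in zip(a_ex, b_ex)])

# Reference: the indexed loop, as in the cells above
loop_out = np.zeros(len(a_ex))
for i in range(len(a_ex)):
    if a_ex[i] < b_ex[i]:
        loop_out[i] = a_ex[i]**2
    else:
        loop_out[i] = b_ex[i]**2

print(np.array_equal(loop_out, zip_out))   # True
```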


To really go further, we want to use NumPy's built-in logical operations. The most useful of these is np.where(). Let's start with a single array where we only care about the entries that are > 0. First, we can call np.where with just the condition:

```python
in1=np.array([10,4,-2])
np.where(in1>0)
```

```
(array([0, 1]),)
```

What this returns is a tuple containing an array of the indices where the condition is True. We can use these indices to pull out only those values. To see why, note that indexing an array with an array of integer positions returns the entries at those positions:

```python
print(a)
print(a[np.array([0,2,3])])   # pick out entries 0, 2 and 3
```

```
[-1.20262133  0.65775847  0.65264411 -0.29702135 -0.58284805 -0.78547945
 -0.12167353  0.99295514  1.41426119  0.12115153]
[-1.20262133  0.65264411 -0.29702135]
```


The same works with an array of True/False values (booleans) of the same length: we get back the entries where the boolean array is True.

```python
print(b<0)
print(a[b<0])   # entries of a at the positions where b<0 is True
```

```
[False False  True  True  True  True  True  True  True  True]
[ 0.65264411 -0.29702135 -0.58284805 -0.78547945 -0.12167353  0.99295514
  1.41426119  0.12115153]
```
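Boolean indexing also works on the left-hand side of an assignment, where it overwrites only the selected entries. A short sketch (the array values are made up for illustration):

```python
import numpy as np

x = np.array([3.0, -1.0, 4.0, -1.5, 5.0])

# Select the negative entries and overwrite them in place
x[x < 0] = 0.0
print(x)   # [3. 0. 4. 0. 5.]
```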


In this way, the one-argument form np.where(condition) does something relatively simple that we could have done without it:

```python
locs1=np.where(in1>0)
print(in1[locs1],in1[in1>0])   # index array vs boolean array: same result
```

```
[10  4] [10  4]
```
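The reason the one-argument np.where returns a tuple is that it gives one index array per dimension. For a 1D array the tuple has a single entry, but for a 2D array we get row indices and column indices. NumPy's documentation describes this one-argument form as a shorthand for np.nonzero. A small sketch with a made-up 2D array:

```python
import numpy as np

m = np.array([[1, -2, 3],
              [-4, 5, -6]])

rows, cols = np.where(m > 0)   # one index array per dimension
print(rows, cols)              # [0 0 1] [0 2 1]
print(m[rows, cols])           # [1 3 5]

# np.nonzero gives the same index arrays
rows2, cols2 = np.nonzero(m > 0)
print(np.array_equal(rows, rows2) and np.array_equal(cols, cols2))   # True
```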


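Of course, np.where is meant to do more. It is designed to be called with three arguments, np.where(condition, out_true, out_false): wherever condition is True the result takes the value from out_true, and wherever it is False it takes the value from out_false.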

```python
print(np.where(in1>0))            # one argument: indices where True
print(np.where(in1>0,5,-100))     # three arguments: 5 where True, -100 where False
```

```
(array([0, 1]),)
[   5    5 -100]
```
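The two output arguments don't have to be the same kind of object: NumPy broadcasts them against the condition, so a common pattern mixes an array with a single number. A sketch that keeps the positive entries of in1 and replaces the rest with 0:

```python
import numpy as np

in1 = np.array([10, 4, -2])

# Where in1 > 0 take the entry itself, otherwise take the scalar 0
clipped = np.where(in1 > 0, in1, 0)
print(clipped)   # [10  4  0]
```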


We can give the outputs for true and false in the form of arrays:

```python
in2=np.array([2,9,15])
in3=np.array([20,25,-3])
np.where(in2<in3,in2**2,in3**2)
```

```
array([ 4, 81,  9])
```

We can reproduce what this is doing by hand. First, make the arrays for the two possible outputs and the boolean array of where the condition is True:

```python
out_true=in2**2
out_false=in3**2

truelist=in2<in3
```

We will also need to know which elements are False (i.e. NOT True):

```python
falselist=np.logical_not(truelist)
print(truelist,falselist)
```

```
[ True  True False] [False False  True]
```


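```python
out_comb=np.zeros(len(in2))
out_comb[truelist]=out_true[truelist]      # fill the positions where the condition is True
out_comb[falselist]=out_false[falselist]   # fill the positions where it is False
print(out_comb)
```

```
[ 4. 81.  9.]
```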
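The steps above can be packaged into a small function. my_where is a hypothetical helper written only to illustrate the idea for 1D inputs of equal length; it is not how NumPy implements np.where internally, and in practice you should just call np.where. Unlike the hand-made version, which started from np.zeros and so produced floats, this sketch keeps the dtype of the inputs:

```python
import numpy as np

def my_where(condition, out_true, out_false):
    """Hypothetical element-wise choice between two 1D arrays of equal length."""
    condition = np.asarray(condition, dtype=bool)
    out_true = np.asarray(out_true)
    out_false = np.asarray(out_false)
    # Allocate an output whose dtype can hold values from either input
    result = np.zeros(len(condition), dtype=np.result_type(out_true, out_false))
    result[condition] = out_true[condition]       # fill the True positions
    result[~condition] = out_false[~condition]    # ~ is element-wise NOT on booleans
    return result

in2 = np.array([2, 9, 15])
in3 = np.array([20, 25, -3])
print(my_where(in2 < in3, in2**2, in3**2))   # [ 4 81  9]
print(np.where(in2 < in3, in2**2, in3**2))   # [ 4 81  9]
```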


So np.where does all of this bookkeeping for us in a single call. Now we can solve our original problem in one line, and check it against the loop version:

```python
print(np.where(a<b,a**2,b**2))
print(out1)
```

```
[4.77808459e-02 1.01905449e-01 2.21222039e-03 1.99083342e-01
 3.34851358e-02 1.53966435e+00 1.40768957e-01 8.89941634e-01
 3.87085309e+00 1.07781536e+00]
[4.77808459e-02 1.01905449e-01 2.21222039e-03 1.99083342e-01
 3.34851358e-02 1.53966435e+00 1.40768957e-01 8.89941634e-01
 3.87085309e+00 1.07781536e+00]
```
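To check the claim that loops are slow, we can time the loop against np.where on a larger array. This is a rough sketch using the standard-library timeit module; the array size and repeat count are arbitrary choices, and timings depend on your machine, so no numbers are shown. You should typically find the np.where version is much faster. The names big_a and big_b are chosen so the block doesn't overwrite a and b:

```python
import timeit
import numpy as np

n = 100_000
big_a = np.random.randn(n)
big_b = np.random.randn(n)

def loop_version(a, b):
    # Same logic as the for loop earlier in this section
    out = np.zeros(len(a))
    for i in range(len(a)):
        if a[i] < b[i]:
            out[i] = a[i]**2
        else:
            out[i] = b[i]**2
    return out

def where_version(a, b):
    return np.where(a < b, a**2, b**2)

# Make sure both approaches agree before comparing speed
print(np.allclose(loop_version(big_a, big_b), where_version(big_a, big_b)))

t_loop = timeit.timeit(lambda: loop_version(big_a, big_b), number=3)
t_where = timeit.timeit(lambda: where_version(big_a, big_b), number=3)
print(f"loop: {t_loop:.3f} s, np.where: {t_where:.3f} s")
```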


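NumPy also provides element-wise logical functions such as np.logical_or and np.logical_and (and np.logical_not, used above), so we can build more complicated conditions and aren't limited to what np.where does on its own: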

```python
test1=np.array([0,5,-10])
test2=np.array([-1,1,-5])
print(np.logical_or(test1<test2,test2>0))
print(np.logical_and(test1<test2,test2>0))
```

```
[False  True  True]
[False False False]
```
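On boolean arrays, the operators & (and), | (or) and ~ (not) work element-wise and give the same results as np.logical_and, np.logical_or and np.logical_not. Because & and | bind more tightly than comparisons in Python, each comparison needs its own parentheses:

```python
import numpy as np

test1 = np.array([0, 5, -10])
test2 = np.array([-1, 1, -5])

or_ops = (test1 < test2) | (test2 > 0)    # same as np.logical_or(...)
and_ops = (test1 < test2) & (test2 > 0)   # same as np.logical_and(...)
not_ops = ~(test2 > 0)                    # same as np.logical_not(...)

print(or_ops)    # [False  True  True]
print(and_ops)   # [False False False]
print(not_ops)   # [ True False  True]
```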


### Tricks for Functions: using args or kwargs

If you look at the documentation for many functions, you will see *args and **kwargs among the inputs. These are special kinds of inputs that can be very useful.

The first is *args, which collects any number of extra positional inputs into a tuple, so the caller can pass several values without packaging them into a list. The name args is just a common convention (below we use argv); what really matters is the single star:

```python
def sum_X(X,*argv):
    out=0
    for arg in argv:   # argv holds all of the extra positional inputs
        out+=arg
    return X*out
```

```python
print(sum_X(2,1),sum_X(2,1,3),sum_X(2,1,3,4))
```

```
2 8 16
```
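Inside the function, the starred input is an ordinary tuple. The star also works in the other direction: putting * in front of a list when calling a function unpacks it into separate positional arguments. A short sketch (sum_X is redefined so the block runs on its own; show_args is a made-up helper for illustration):

```python
def sum_X(X, *argv):
    out = 0
    for arg in argv:
        out += arg
    return X * out

def show_args(*argv):
    # argv collects all positional arguments into a tuple
    return argv

print(show_args(1, 3, 4))      # (1, 3, 4)
print(type(show_args(1, 3)))   # <class 'tuple'>

values = [1, 3, 4]
print(sum_X(2, *values))       # same as sum_X(2, 1, 3, 4) -> 16
```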


Obviously we could have done the same by passing in a list. Where *args really helps is when we want to call functions that take different numbers of inputs, depending on the function. Here func_of_func passes whatever extra inputs it receives straight on to fun:

```python
def func_of_func(fun,*args):
    return fun(2,*args)   # unpack the extra inputs into the call to fun
```

```python
def fun1(x,p=1,q=2):
    return x*p*q
def fun2(x,p=1):
    return x*p
```

```python
func_of_func(fun1,4,4)
```

```
32
```

```python
func_of_func(fun2,4)
```

```
8
```

Two stars, usually written **kwargs (short for "keyword arguments"), collect any extra keyword inputs into a dictionary. Here each input has both a key (the argument's name) and a value. Recall how a dictionary works:

```python
dict1={'a':'hi','b':2}
print(dict1.items())
print(dict1.keys())
print(dict1.values())
```

```
dict_items([('a', 'hi'), ('b', 2)])
dict_keys(['a', 'b'])
dict_values(['hi', 2])
```


We can treat kwargs exactly like a dictionary, since inside the function that is what it is:

```python
def kwargs_test(**kwargs):
    for arg,n in kwargs.items():   # each item is a (name, value) pair
        print('hi '+arg)
        print(n*4)
```

The key difference is how we call it: each input is written as name=value, and the name becomes a string key in kwargs. Here the keys are names and the values are numbers:

```python
kwargs_test(alice=2,bob=50)
```

```
hi alice
8
hi bob
200
```


The values can be of any type. Here is a version where the values are strings too:

```python
def kwargs_test2(**kwargs):
    for arg,name in kwargs.items():
        print('hi '+arg)
        print('hi '+name)
```

```python
kwargs_test2(alice='a',bob='50')
```

```
hi alice
hi a
hi bob
hi 50
```


Because kwargs is a dictionary, we can also loop over just its keys or just its values:

```python
def kwargs_test3(**kwargs):
    for key in kwargs.keys():
        print(key)
    for val in kwargs.values():
        print(val)
```

```python
kwargs_test3(alice=2,bob=50)
```

```
alice
bob
2
50
```
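The double star also works when calling a function: putting ** in front of a dictionary unpacks it into keyword arguments. A sketch reusing fun1 from above (redefined so the block is self-contained); kwargs_type is a made-up helper that just confirms kwargs is a dict:

```python
def fun1(x, p=1, q=2):
    return x * p * q

params = {'p': 3, 'q': 4}
print(fun1(2, **params))   # same as fun1(2, p=3, q=4) -> 24

def kwargs_type(**kwargs):
    # Inside the function, kwargs is an ordinary dict
    return type(kwargs)

print(kwargs_type(alice=2))   # <class 'dict'>
```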


Now you can understand how plt.plot (and ax.plot) works. All of the optional features of the line (color, linewidth, label, and so on) are passed as keyword arguments, as we already know from experience, and plot collects them through **kwargs. We could write functions that behave the same way ourselves.
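As a sketch of how we could do this ourselves: a function can accept **kwargs, fill in defaults for anything the caller didn't specify, and pass everything on to another function. The names draw_line and my_plot, and the default style values, are hypothetical and only imitate the pattern; this is not matplotlib's actual code. With matplotlib itself, the same idea would be a function that takes an Axes ax and calls ax.plot(x, y, **kwargs).

```python
def draw_line(x, y, **kwargs):
    """Hypothetical plotting-style function: returns the settings it would use."""
    # Default style; any keyword the caller passes overrides these
    style = {'color': 'blue', 'linewidth': 1.0, 'linestyle': '-'}
    style.update(kwargs)
    return {'x': list(x), 'y': list(y), **style}

def my_plot(x, y, **kwargs):
    # Forward every keyword argument, unchanged, to draw_line
    return draw_line(x, y, **kwargs)

result = my_plot([0, 1, 2], [0, 1, 4], color='red', linewidth=2)
print(result['color'], result['linewidth'], result['linestyle'])   # red 2 -
```

The caller only names the options it wants to change; everything else keeps its default, which is the same experience we have with plt.plot.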