# Numba experiments

Create a function that can easily be adapted to calculate the heat transport from MITgcm output.

In [None]:
from netCDF4 import Dataset
import os

# Open the remote OFES `uvel` dataset read-only, as a playground for the
# experiments below.
url = 'http://apdrc.soest.hawaii.edu:80/dods/public_ofes/OfES/ncep_0.1_global_mmean/uvel'
dataset = Dataset(url, 'r')

In [None]:
def transport(dataset, year, yy):
    """Sum zonal velocity over depth, a narrow latitude band and all longitudes.

    Pure-Python triple loop, kept as the baseline for the Numba comparison.

    Parameters
    ----------
    dataset : netCDF4.Dataset-like
        Must expose ``variables['uvel']`` indexable as ``[year, depth, lat, lon]``
        (axis order taken from the loop comments below).
        NOTE(review): the OFES dataset URL ends in ``_mmean``, which suggests
        monthly means, so the first index may be a month rather than a year
        -- confirm.
    year : int
        Index into the first axis of ``uvel``.
    yy : int
        Latitude index; rows ``yy-2`` .. ``yy+1`` are summed (the slice upper
        bound is exclusive). A start below 0 is clipped instead of wrapping.

    Returns
    -------
    The plain sum of the velocities (no area/thickness weighting yet).
    Masked cells and ``-9999`` cells count as 0.
    """
    import numpy as np

    lat_start = max(yy - 2, 0)  # a negative start index would wrap around
    uvel = dataset.variables['uvel'][year, :, lat_start:yy + 2, :]
    # netCDF4 auto-masking can return a masked array. Adding one masked
    # element turns the whole running total into ``masked``, so fill those
    # cells with 0. Then zero the explicit -9999 sentinel; np.where builds a
    # new array instead of mutating a view of the caller's data.
    uvel = np.ma.filled(uvel, 0)
    uvel = np.where(uvel == -9999, 0, uvel)

    u_cum = 0
    # Last axis innermost, so a C-ordered array is walked contiguously.
    for n in range(uvel.shape[0]):  # depth
        for j in range(uvel.shape[1]):  # lat
            for i in range(uvel.shape[2]):  # lon
                # later: multiply by temperature and a thermal coefficient
                u_cum = u_cum + uvel[n, j, i]

    return u_cum

In [None]:
%%time
# Pure-Python baseline: year index 700, latitude index 10
transport(dataset, 700, 10)

## Try NUMBA

In [None]:
from numba import jit

In [None]:
@jit(nopython=True)
def calc_cum(uvel):
    """Sum every element of a 3-D array indexed as ``[depth, lat, lon]``.

    Compiled by Numba in nopython mode. The loops run depth -> lat -> lon so
    the innermost index is the last axis, which walks a C-ordered array
    contiguously. The previous lon-outer/depth-inner order strided through
    memory on every step.
    """
    u_cum = 0.0  # float from the start: no int/float type unification needed
    for n in range(uvel.shape[0]):  # depth
        for j in range(uvel.shape[1]):  # lat
            for i in range(uvel.shape[2]):  # lon
                # later: multiply by temperature and a thermal coefficient
                u_cum += uvel[n, j, i]
    return u_cum

def transport_numba(dataset, year, yy):
    """Same band sum as the pure-Python baseline, delegated to ``calc_cum``.

    Parameters
    ----------
    dataset : netCDF4.Dataset-like
        Must expose ``variables['uvel']`` indexable as ``[year, depth, lat, lon]``.
    year : int
        Index into the first axis of ``uvel``.
    yy : int
        Latitude index; rows ``yy-2`` .. ``yy+1`` are summed. A start below 0
        is clipped instead of wrapping.

    Returns
    -------
    Whatever ``calc_cum`` returns for the band. Masked and ``-9999`` cells
    count as 0.
    """
    import numpy as np

    lat_start = max(yy - 2, 0)  # a negative start index would wrap around
    uvel = dataset.variables['uvel'][year, :, lat_start:yy + 2, :]
    # netCDF4 may return a numpy.ma.MaskedArray, which Numba's nopython mode
    # does not support. Convert it to a plain ndarray with masked cells set
    # to 0, then zero the -9999 sentinel without mutating the caller's data.
    uvel = np.ma.filled(uvel, 0)
    uvel = np.where(uvel == -9999, 0, uvel)

    return calc_cum(uvel)

In [None]:
%%time
# Same band via the Numba-compiled loop. NOTE: with lazy @jit (no signature),
# the first call presumably also includes compilation time.
transport_numba(dataset, 700, 10)

More than 3 times faster!


This used only one processor.

Let's now try with two!

**NOTA BENE:** `multiprocessing` is currently incompatible with Jupyter.

In [None]:
import multiprocessing

# Redefined with a single positional argument because Pool.map passes exactly
# one argument per call; the latitude index is now an optional parameter.
def transport_numba(year, yy=10):
    """Open the OFES ``uvel`` dataset and return ``calc_cum`` of one band.

    Each call opens its own ``Dataset`` (and now closes it again), so the
    function can run inside a ``multiprocessing.Pool`` worker.

    Parameters
    ----------
    year : int
        Index into the first axis of ``uvel``. Must be an int, not a string.
    yy : int, optional
        Latitude index; rows ``yy-2`` .. ``yy+1`` are summed. Defaults to 10,
        the previously hard-coded value. A start below 0 is clipped.

    Returns
    -------
    Whatever ``calc_cum`` returns for the band. Masked and ``-9999`` cells
    count as 0.
    """
    import numpy as np

    url = 'http://apdrc.soest.hawaii.edu:80/dods/public_ofes/OfES/ncep_0.1_global_mmean/uvel'
    lat_start = max(yy - 2, 0)  # a negative start index would wrap around

    # Close the remote dataset once the slice is read, instead of leaking
    # one open handle per call.
    with Dataset(url, mode='r') as dataset:
        uvel = dataset.variables['uvel'][year, :, lat_start:yy + 2, :]

    # netCDF4 may return a numpy.ma.MaskedArray, which Numba's nopython mode
    # does not support. Fill masked cells with 0, then zero the -9999 sentinel.
    uvel = np.ma.filled(uvel, 0)
    uvel = np.where(uvel == -9999, 0, uvel)

    return calc_cum(uvel)

In [None]:
import time  # previously missing: time.time() raised NameError

# Testing on 5 years, one after another (serial baseline)
years = range(700, 705)

t1 = time.perf_counter()  # monotonic, high-resolution clock for intervals
for year in years:
    transport_numba(year)
t2 = time.perf_counter()
print(t2 - t1)

In [None]:
import time  # previously missing: time.time() raised NameError

# Pass ints, not strings: transport_numba uses the year as an array index.
years = list(range(700, 705))

# The context manager shuts the pool down even if a worker raises.
with multiprocessing.Pool(processes=2) as pool:
    t1 = time.perf_counter()
    r = pool.map(transport_numba, years)
    t2 = time.perf_counter()
print(t2 - t1)

This runs almost 2 times faster.

**NOTA BENE:** Specifying an explicit signature (input and output types) speeds up the code a tiny bit more.

In [None]:
@jit(nopython=True,'f8(f4[:,:,:])')