## Use np.apply_along_axis for faster 1d operations on data

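When we have patch data and want to apply an operation/function only along the last axis (the spectrum), we always reshape the data into a flat stack of 1-D spectra (a 2-D array with one spectrum per row), apply the function, and reshape the result back to the original shape.
With `np.apply_along_axis` we don't need to reshape the data, so it should be faster. Note, however, that `np.apply_along_axis` calls the Python function once per spectrum, so the benchmark below is needed to check whether skipping the reshape actually pays off.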

In [1]:
from sklearn.preprocessing import StandardScaler
from datetime import datetime

import numpy as np

# Fix NumPy's legacy global RNG for reproducibility. The arrays used in this
# notebook are built with np.arange, so they do not depend on this seed.
np.random.seed(seed=42)

In [2]:
# Shapes of the synthetic test data.
size_1d = [1000, 15]  # (samples, spectrum)
size_3d = [1000, 3, 3, 15]  # (samples, patch0, patch1, spectrum) -- 4 axes despite the name

# Deterministic ramps 0, 1, 2, ... stored as float32. All values stay below
# 2**24, so each one is exactly representable in float32.
# NOTE: array_2d holds the 4-axis patch data; the name is kept for callers.
array_1d, array_2d = (
    np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    for shape in (size_1d, size_3d)
)

In [3]:
def old_transform(X: np.ndarray) -> np.ndarray:
    """Standardise every row of a 2-D array with scikit-learn's StandardScaler.

    StandardScaler standardises columns, so the array is transposed before
    fitting and transposed back afterwards. The net effect is that each row
    (one spectrum) is standardised independently of the others.
    """
    spectra_as_columns = X.T
    standardised = StandardScaler().fit_transform(spectra_as_columns)
    return standardised.T


def old_scale_X(X: np.ndarray) -> np.ndarray:
    """Standardise every spectrum of *X* by flattening it to 2-D first.

    Arrays with more than two dimensions are reshaped to
    ``(prod(X.shape[:-1]), X.shape[-1])`` so that each row is one spectrum.
    That 2-D array is passed to ``old_transform`` and the result is reshaped
    back to the original shape. Arrays with at most two dimensions are handed
    to ``old_transform`` directly.

    Args:
        X: Data whose last axis holds the spectrum.

    Returns:
        The output of ``old_transform``, in the same shape as *X* for inputs
        with more than two dimensions.
    """
    if X.ndim <= 2:
        return old_transform(X)

    original_shape = X.shape
    # Collapse all leading axes into one so every row is a single spectrum.
    # The shape is passed positionally: the ``newshape=`` keyword of
    # np.reshape is deprecated since NumPy 2.1.
    flat = np.reshape(X, (np.prod(original_shape[:-1]), original_shape[-1]))
    return np.reshape(old_transform(flat), original_shape)

In [4]:
def new_transform(X: np.ndarray) -> np.ndarray:
    """Standardise every 1-D slice of *X* along its last axis.

    Each spectrum ``X[..., :]`` is shifted to zero mean and divided by its
    population standard deviation (``ddof=0``). ``np.apply_along_axis`` makes
    this work for arrays of any rank without manual reshaping. It does,
    however, call ``transform_func`` once per spectrum from Python.

    Constant spectra are only centred (i.e. divided by 1) instead of being
    divided by a zero standard deviation, which would yield NaN. This mirrors
    how scikit-learn's ``StandardScaler`` treats zero-variance features.

    Args:
        X: Array whose last axis holds the spectrum.

    Returns:
        Array with the same shape as *X*.
    """
    def transform_func(spectrum: np.ndarray) -> np.ndarray:
        centred = spectrum - spectrum.mean()
        # Detect constant spectra exactly via max == min. Rounding in mean()
        # can leave a tiny non-zero std that would turn every element into
        # +/-1 on division. The size check keeps empty spectra on the
        # original path (max() has no identity for empty arrays).
        if spectrum.size and spectrum.max() == spectrum.min():
            return centred
        return centred / spectrum.std()

    return np.apply_along_axis(func1d=transform_func, axis=-1, arr=X)

def new_scale_X(X: np.ndarray) -> np.ndarray:
    """Standardise each spectrum of *X*, whatever its rank.

    No reshaping happens here: ``new_transform`` already operates along the
    last axis of its input via ``np.apply_along_axis``.
    """
    scaled = new_transform(X)
    return scaled

In [7]:
def compate_scale_functions(arr: np.ndarray, repeat: int = 1):
    """Time ``old_scale_X`` against ``new_scale_X`` on *arr* and check they agree.

    Prints the elapsed time of each implementation. It then prints whether
    their outputs are element-wise close (``np.allclose`` with its default
    tolerances).

    Args:
        arr: Input data passed unchanged to both scaling functions.
        repeat: How many times to run each function. The fastest run is
            reported, which filters out one-off noise such as warm-up cost.
            Defaults to 1 (a single run).

    Raises:
        ValueError: If *repeat* is smaller than 1.
    """
    import time
    from datetime import timedelta

    if repeat < 1:
        raise ValueError("repeat must be at least 1")

    def step(func):
        # perf_counter is monotonic and high-resolution; datetime.now() follows
        # the adjustable wall clock and is not meant for benchmarking.
        timings = []
        for _ in range(repeat):
            start = time.perf_counter()
            scaled = func(arr)
            timings.append(time.perf_counter() - start)
        # A timedelta prints as H:MM:SS.ffffff, the same format as before.
        return scaled, timedelta(seconds=min(timings))

    old_scaled, old_time = step(old_scale_X)
    new_scaled, new_time = step(new_scale_X)

    print(f"Old function needs {old_time}.")
    print(f"New function needs {new_time}.")

    if np.allclose(old_scaled, new_scaled):
        print("Arrays from the functions are the same.")
    else:
        print("Arrays from the functions are not the same.")


# Correctly spelled alias; the original name is kept for existing callers.
compare_scale_functions = compate_scale_functions

In [8]:
# Benchmark on the flat (samples, spectrum) data first, then on the
# patch-shaped (samples, patch0, patch1, spectrum) data.
for label, data in (("1d Test", array_1d), ("3d Test", array_2d)):
    print(label)
    compate_scale_functions(arr=data)

1d Test
Old function needs 0:00:00.008265.
New function needs 0:00:00.155586.
Arrays from the functions are the same.
3d Test
Old function needs 0:00:00.011223.
New function needs 0:00:01.180703.
Arrays from the functions are the same.
