<a href="https://colab.research.google.com/github/djrakita/CPSC_487_587_Assignment_Demo/blob/main/WASPDerivatives.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports and settings

In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

print(jax.devices())


[cuda(id=0)]


# Derivative Engines

## Derivative Engine ABC

In [None]:
from enum import Enum
import numpy as np
import jax.numpy as jnp

class ADMode(Enum):
    Forward = 1,
    Reverse = 2

class DerivativeEngine:
    def __init__(self, f, n: int, m: int):
        self.f = f
        self.n = n
        self.m = m

    def call(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.f(x)

    def derivative(self, x) -> jnp.ndarray:
        raise NotImplemented("This must be implemented in subclass")

## Finite Differencing engine

In [None]:
import copy

class FDDerivativeEngine(DerivativeEngine):
    def __init__(self, f, n: int, m: int, jit_compile_f: bool = True, jit_compile_df: bool = True):
        super().__init__(f, n, m)

        if jit_compile_f:
            self.f = jax.jit(f)
        else:
            self.f = f

        if jit_compile_df:
            self.df = jax.jit(self.derivative_internal)
        else:
            self.df = self.derivative_internal


    def derivative_internal(self, x: jnp.ndarray) -> jnp.ndarray:
        out = jnp.zeros((self.m, self.n))

        p = 0.0000001

        f0 = self.f(x)
        for i in range(self.n):
            xh = copy.deepcopy(x)
            xh = x.at[i].add(p)
            fh = self.f(xh)
            col = (fh - f0) / p
            out = out.at[:, i].set(col)

        return out

    def derivative(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.df(x)



In [None]:
def f(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.array(x[0] * x[1])

e = FDDerivativeEngine(f, 2, 1)
e.derivative(jnp.array([5.0, 2.0]))

Array([[2.00000001, 4.99999999]], dtype=float64)

## JAX engine

In [None]:
class JaxDerivativeEngine(DerivativeEngine):
    def __init__(self, f, n: int, m: int, jit_compile_f: bool = True, jit_compile_df: bool = True, ad_mode: ADMode = ADMode.Reverse):
        super().__init__(f, n, m)

        self.ad_mode = ad_mode

        if jit_compile_f:
            self.f = jax.jit(f)

        if self.ad_mode == ADMode.Forward:
            self.jac_fn = jax.jacfwd(self.f)
        elif self.ad_mode == ADMode.Reverse:
            self.jac_fn = jax.jacrev(self.f)

        if jit_compile_df:
            self.jac_fn = jax.jit(self.jac_fn)

    def derivative(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.jac_fn(x)

In [59]:
def f(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.array(x[0] * x[1])

e = JaxDerivativeEngine(f, 2, 1)
e.derivative(jnp.array([5.0, 2.0]))

Array([2., 5.], dtype=float64)

## WASP engine

In [61]:
class WASPDerivativeEngine(DerivativeEngine):
    def __init__(self, f, n: int, m: int, jit_compile_f: bool = True, jit_compile_df: bool = True):
        super().__init__(f, n, m)

        if jit_compile_f:
            self.f = jax.jit(f)
        else:
            self.f = f

        if jit_compile_df:
            self.df = jax.jit(self.derivative_internal)
        else:
            self.df = self.derivative_internal

    def _derivative_internal_loop(self, x, recursive_call=False, f0=None):
        pass

    def derivative_internal(self, x: jnp.ndarray) -> jnp.ndarray:
        pass

    def derivative(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.df(x)