
In [1]:
!git clone https://github.com/dhassan24/DeepLearning_For_Biology

Cloning into 'DeepLearning_For_Biology'...
remote: Enumerating objects: 4, done.
remote: Counting objects: 100% (4/4), done.
remote: Compressing objects: 100% (4/4), done.
remote: Total 4 (delta 0), reused 0 (delta 0), pack-reused 0 (from 0)
Receiving objects: 100% (4/4), done.


In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
# Mean squared error (MSE) is a common metric for evaluating model performance
# in deep learning. Lower is better, so training aims to minimize it.

def mean_squared_error(y_true, y_pred):
  # Element-wise squared difference between true and predicted values
  squared_error = (y_true - y_pred)**2
  # Average over all elements
  return np.mean(squared_error)

#Example of usage:

y_true = np.array([0.8, 1.5, -0.34, 1.91, 0.49])
y_pred = np.array([0.25, 1.2, -0.07, 1.32, 0.67])

print('MSE:', mean_squared_error(y_true, y_pred))

MSE: 0.16917999999999997


In [5]:
mean_squared_error(y_true, y_pred)

np.float64(0.16917999999999997)
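
As a check on the value above, MSE is the mean of the squared differences between the true and predicted values. The symbols $y_i$ (true value), $\hat{y}_i$ (predicted value), and $n$ (number of elements, here 5) are notation introduced for this sketch, not taken from the notebook:

$$
\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2
= \frac{0.3025 + 0.09 + 0.0729 + 0.3481 + 0.0324}{5}
= \frac{0.8459}{5}
= 0.16918
$$

The printed result differs from 0.16918 only by floating-point rounding.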

We can also improve the function by adding type hints, which specify that the inputs are np.ndarray objects and that the return type is a float.

In [6]:
def mean_squared_error(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  """
  Calculate the Mean Squared Error (MSE) between two NumPy arrays.

  Args:
    y_true (np.ndarray): The true target values.
    y_pred (np.ndarray): The predicted target values.

  Returns:
    float: The mean of the squared differences.
  """
  squared_error = (y_true - y_pred)**2
  return np.mean(squared_error)
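
Python does not enforce type hints at runtime, so the annotated function still accepts whatever it is given. The sketch below is one possible way to add explicit checks. The name `checked_mean_squared_error` and the choice to raise `ValueError` on a shape mismatch are assumptions made for illustration, not part of the original notebook.

```python
import numpy as np

def checked_mean_squared_error(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Hypothetical variant of mean_squared_error with runtime input checks."""
    # Coerce inputs so that plain Python lists also work
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Refuse mismatched shapes instead of silently broadcasting
    if y_true.shape != y_pred.shape:
        raise ValueError(f"shape mismatch: {y_true.shape} vs {y_pred.shape}")
    # Convert NumPy's float64 scalar to a built-in float to match the hint
    return float(np.mean((y_true - y_pred) ** 2))

mse = checked_mean_squared_error(
    [0.8, 1.5, -0.34, 1.91, 0.49],
    [0.25, 1.2, -0.07, 1.32, 0.67],
)
print("MSE:", mse)
```

The hints themselves stay available for tooling: static checkers read them, and they can be inspected through `checked_mean_squared_error.__annotations__`.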

Decorators are functions that modify the behavior of other functions or methods. They are commonly used to enhance performance, cache results, or log function behavior.
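
To make the idea concrete, here is a minimal sketch of a hand-written decorator that logs calls. The names `log_calls` and `square` are hypothetical and exist only for this example.

```python
import functools

def log_calls(func):
    """Hypothetical decorator: record the arguments and result of each call."""
    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        wrapper.calls.append((args, kwargs, result))
        return result
    wrapper.calls = []  # call history stored on the wrapper itself
    return wrapper

@log_calls
def square(x):
    return x * x

square(3)
square(4)
print(square.calls)
```

Writing `@log_calls` above the definition is shorthand for `square = log_calls(square)`: the decorator receives the original function and returns a replacement that adds behavior around it.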

One of the more common decorators in JAX (our high-performance numerical computing and ML platform) is jax.jit. It performs just-in-time (JIT) compilation to accelerate code execution.

In [7]:
# Example: a function that takes a JAX array, raises all values to the 10th power,
# and then sums them.

import jax
import jax.numpy as jnp

def compute_ten_power_sum(arr: jax.Array) -> jax.Array:
  # jnp.sum returns a zero-dimensional JAX array, not a Python float
  return jnp.sum(arr**10)

arr = jnp.array([1, 2, 3, 4, 5])
compute_ten_power_sum(arr)

Array(10874275, dtype=int32)

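We can speed up this function with jax.jit in one of two ways: by calling jax.jit on the existing function, or by applying it as a decorator when the function is defined.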

In [9]:
%%time
# First way: wrap the existing function with jax.jit.
# (%%time is a cell magic and must be the first line of the cell.)
jitted_compute_ten_power_sum = jax.jit(compute_ten_power_sum)
jitted_compute_ten_power_sum(arr)

CPU times: user 1.62 ms, sys: 69 µs, total: 1.69 ms
Wall time: 2.44 ms


Array(10874275, dtype=int32)

In [10]:
%%time
# Second way: apply jax.jit as a decorator when defining the function.
@jax.jit
def compute_ten_power_sum(arr: jax.Array) -> jax.Array:
  return jnp.sum(arr**10)

compute_ten_power_sum(arr)

CPU times: user 46.8 ms, sys: 593 µs, total: 47.3 ms
Wall time: 46.4 ms


Array(10874275, dtype=int32)

When @jax.jit is applied, JAX first traces the function: it runs through it once with a special tracer object to build a computational graph, which is a static representation of all the numerical operations performed. JAX then compiles this graph using XLA (Accelerated Linear Algebra), a compiler backend that generates highly optimized machine code. The compiled version is cached and reused whenever the function is called again with the same input shapes and types, which is where the speedup comes from. The first call therefore includes the cost of tracing and compilation, and the benefit shows up on later calls.
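
The trace-once, reuse-afterwards behavior described above can be observed with a small sketch. It relies on the fact that ordinary Python side effects inside a jitted function run only during tracing. The names `trace_log` and `ten_power_sum` are hypothetical helpers for this illustration.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: one entry per trace of the function

@jax.jit
def ten_power_sum(arr):
    # This append is a Python side effect, so it runs only while JAX traces
    # the function, not when the cached compiled version is executed.
    trace_log.append((arr.shape, str(arr.dtype)))
    return jnp.sum(arr ** 10)

a = jnp.array([1, 2, 3, 4, 5])
b = jnp.array([2, 3, 4, 5, 6])

ten_power_sum(a)              # first call: trace and compile
ten_power_sum(b)              # same shape and dtype: compiled version is reused
ten_power_sum(jnp.arange(3))  # new shape: traced and compiled again

print(trace_log)
```

The log should contain one entry per distinct input shape and dtype rather than one per call. When timing jitted code, it is also worth calling `.block_until_ready()` on the result, because JAX dispatches work asynchronously and may return before the computation has finished.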