# Day 1: BrainPy programming basics homework  
This is the first assignment for this course. Its purpose is to familiarize you with the basic BrainPy programming covered in class: fill in the missing code according to the code comments, then execute the cells and observe the results.  

First, we import the required libraries.

In [None]:
import brainpy as bp
import brainpy.math as bm
import numpy as np

## 1. JIT compilation  
Just-in-time (JIT) compilation is the core technique behind BrainPy's efficiency. In this section, we show the basic usage of JIT compilation and see how it improves running performance.

### 1.1 Functional JIT compilation

Let's start with a function. Suppose we implement the Gaussian Error Linear Unit (GELU) activation function using its common tanh approximation.

In [None]:
def gelu(x):
  sqrt = bm.sqrt(2 / bm.pi)
  cdf = 0.5 * (1.0 + bm.tanh(sqrt * (x + 0.044715 * (x ** 3))))
  y = x * cdf
  return y
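The `gelu` above uses the tanh approximation of $\mathrm{GELU}(x) = x\,\Phi(x)$, where $\Phi$ is the standard normal CDF. As a quick sanity check, here is a sketch in plain Python (independent of BrainPy) that compares it with the exact form computed through `math.erf`; the helper names `gelu_exact` and `gelu_tanh` are our own.

```python
import math


def gelu_exact(x: float) -> float:
    # GELU(x) = x * Phi(x), with the standard normal CDF written via erf
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # Same tanh approximation as the notebook's gelu(), in plain Python
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))


# Largest deviation on a grid over [-5, 5]
xs = [i / 100.0 for i in range(-500, 501)]
max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
print(f"max |exact - tanh approx| on [-5, 5]: {max_err:.2e}")
```

The approximation avoids `erf`, which is why it is a popular choice in deep-learning code.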

Let's test the execution time without JIT compilation first:

In [None]:
x = bm.random.random(100000)
%timeit gelu(x)

If we JIT-compile the function by passing it to `bm.jit()`, its execution time drops significantly.

In [None]:
# TODO: JIT compile the gelu function using the brainpy.math library
# Hint: Use the bm.jit()
gelu_jit = ...
%timeit gelu_jit(x)
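Two details matter when timing compiled code: the first call pays for tracing and compilation, and JAX (on which BrainPy is built) dispatches work asynchronously, so a timer may stop before the result is actually computed. The sketch below illustrates both points in plain JAX rather than BrainPy; `block_until_ready()` waits for the result, and the variable names are our own.

```python
import time

import jax
import jax.numpy as jnp


def gelu(x):
    # Same tanh-approximated GELU as above, written with jax.numpy
    c = jnp.sqrt(2.0 / jnp.pi)
    return 0.5 * x * (1.0 + jnp.tanh(c * (x + 0.044715 * x ** 3)))


gelu_jit = jax.jit(gelu)
x = jnp.linspace(-3.0, 3.0, 100_000)

t0 = time.perf_counter()
gelu_jit(x).block_until_ready()        # first call: trace + compile + run
t_first = time.perf_counter() - t0

t0 = time.perf_counter()
y = gelu_jit(x).block_until_ready()    # later calls reuse the compiled program
t_cached = time.perf_counter() - t0

print(f"first call: {t_first:.4f} s, cached call: {t_cached:.6f} s")
```

The exact timings depend on your hardware; the point is only that the first call is expected to be much slower than the cached ones.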

### 1.2 Object-oriented JIT compilation

We use a logistic regression classifier as an example. In this model, the weight $w$ is modified during training, so it must be defined as a ``brainpy.math.Variable``. The remaining parameters are treated as static values during compilation, and their values will not change.
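Why must $w$ be a `Variable`? Under JIT compilation, a plain Python value that the code reads is baked into the compiled program as a constant. The plain-JAX sketch below (the `Model` class and `forward` function are illustrative, not part of BrainPy) shows that changing such an attribute after compilation does not affect the cached program.

```python
import jax
import jax.numpy as jnp


class Model:
    def __init__(self, scale):
        self.scale = scale  # plain Python float, not a dynamic variable


model = Model(2.0)


@jax.jit
def forward(x):
    # model.scale is read once, at trace time, and stored as a constant
    return model.scale * x


x = jnp.ones(3)
before = forward(x)   # compiles with scale = 2.0
model.scale = 10.0
after = forward(x)    # same input shape/dtype, so the cached program is reused
print(before, after)  # both still equal 2.0 elementwise
```

Declaring state as a `Variable` is what lets BrainPy's JIT treat it as dynamic data instead of a compile-time constant.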

In [None]:
class LogisticRegression(bp.BrainPyObject):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters
        self.dimension = dimension

        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        self.w.value = self.w - u # in-place update
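The update `u` in `__call__` is the gradient of the logistic loss $L(w) = \sum_i \log\bigl(1 + e^{-y_i x_i^\top w}\bigr)$. With $z = Xw$ and $\sigma$ the sigmoid, $\partial L/\partial z_i = y_i\,(\sigma(y_i z_i) - 1)$, so $\nabla L = X^\top\bigl[(\sigma(y \odot Xw) - 1) \odot y\bigr]$, and `self.w.value = self.w - u` is a gradient-descent step with unit learning rate. A small NumPy finite-difference check (a sketch; `loss` and `grad` are our own helper names) confirms this:

```python
import numpy as np


def loss(w, X, y):
    # Logistic loss summed over all samples
    return np.sum(np.log1p(np.exp(-y * (X @ w))))


def grad(w, X, y):
    # Same expression as `u` in LogisticRegression.__call__
    sig = 1.0 / (1.0 + np.exp(-y * (X @ w)))
    return ((sig - 1.0) * y) @ X


rng = np.random.default_rng(0)
X = rng.random((50, 10))
y = rng.random(50)
w = 2.0 * np.ones(10) - 1.3

# Central finite differences along each coordinate direction
eps = 1e-6
numeric = np.array([
    (loss(w + eps * e, X, y) - loss(w - eps * e, X, y)) / (2 * eps)
    for e in np.eye(10)
])
analytic = grad(w, X, y)
print("max abs difference:", np.max(np.abs(numeric - analytic)))
```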

To measure the execution time, we write a benchmarking function and define the dataset:

In [None]:
import time

def benchmark(model, points, labels, num_iter=30, name=''):
    t0 = time.time()
    for i in range(num_iter):
        model(points, labels)
    print(f'{name} used time {time.time() - t0} s')

num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)

Next, let's test the execution time without JIT compilation.

In [None]:
lr1 = LogisticRegression(num_dim)
benchmark(lr1, points, labels, name='Logistic Regression (without jit)')

We can list the variables that the JIT-compiled object tracks with `.vars()`:

In [None]:
lr1 = bm.jit(LogisticRegression(num_dim))
lr1.vars().keys()

Next, we test the execution time with JIT compilation. Objects are compiled the same way as functions: simply pass the class instance to ``brainpy.math.jit()``. Note that the first call of the compiled model also includes the compilation time.

In [None]:
lr2 = LogisticRegression(num_dim)
# TODO: JIT compile the lr2 model using the brainpy.math library
# Hint: Use the bm.jit()
lr2 = ...
benchmark(lr2, points, labels, name='Logistic Regression (with jit)')

## 2. Data structures  
### 2.1 Arrays  
An array is a data structure that organizes elements in a multidimensional grid. In BrainPy, an array holds elements of a single data type, most commonly numeric or boolean.

In [None]:
bm_array = bm.array([0, 1, 2, 3, 4, 5])
np_array = np.array([0, 1, 2, 3, 4, 5])
bm_array

We can create a high-dimensional array and check the properties of the array.

In [None]:
# TODO: Create a new brainpy array named t2
t2 = ...
print('t2.ndim: {}'.format(t2.ndim))
print('t2.shape: {}'.format(t2.shape))
print('t2.size: {}'.format(t2.size))
print('t2.dtype: {}'.format(t2.dtype))
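BrainPy arrays expose the same attributes as NumPy arrays. For reference, here is what they evaluate to for a NumPy array with the same contents as the 3-D `t2` used in the Solutions. The `dtype` may differ: JAX, and therefore BrainPy, typically defaults to 32-bit types unless 64-bit mode is enabled, while NumPy's default integer type is platform dependent.

```python
import numpy as np

t2_np = np.array([[[0, 1, 2, 3], [1, 2, 3, 4], [4, 5, 6, 7]],
                  [[0, 0, 0, 0], [-1, 1, -1, 1], [2, -2, 2, -2]]])

print('ndim:', t2_np.ndim)    # number of axes: 3
print('shape:', t2_np.shape)  # length of each axis: (2, 3, 4)
print('size:', t2_np.size)    # total number of elements: 24
print('dtype:', t2_np.dtype)  # element type (platform dependent for NumPy)
```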


An array created by ``brainpy.math`` is stored in a JaxArray, which internally holds a JAX DeviceArray. To unwrap the JaxArray and get the DeviceArray inside, simply access its ``.value`` attribute:

In [None]:
# TODO: Get value from t2
t2_value = ...
print('t2_value: {}'.format(t2_value))

### 2.2 Variables

A dynamic variable is a pointer to an array of values (DeviceArray) stored in memory. Unlike an ordinary array, its data can be modified inside JIT-compiled code. Declaring an array as a dynamic variable means that it changes dynamically over time. To convert an array to a dynamic variable, simply wrap it in `brainpy.math.Variable`.

In [None]:
t = bm.arange(4)
# TODO: Convert t to Variable
v = ...

Since dynamic variables store arrays, all array operations can be applied to them directly. In addition, dynamic variables can be modified in place; the following subsections explain in detail how to modify them under JIT compilation.
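JAX arrays are immutable, so a dynamic variable is best pictured as a mutable handle that is re-pointed to a new array on every update. The NumPy sketch below defines a hypothetical `ArrayBox`; it is not BrainPy's actual `Variable` class, it only mimics the rule described below for `.value` and `.update` assignments, namely that the shape and element type must stay the same.

```python
import numpy as np


class ArrayBox:
    """Hypothetical stand-in for a dynamic variable: a mutable handle to an array."""

    def __init__(self, value):
        self._value = np.asarray(value)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        new = np.asarray(new)
        # Shape and dtype must not change
        if new.shape != self._value.shape or new.dtype != self._value.dtype:
            raise ValueError(
                f"expected {self._value.shape}/{self._value.dtype}, "
                f"got {new.shape}/{new.dtype}")
        self._value = new  # re-point the handle; the old array is untouched

    def update(self, new):
        # In this sketch, update() is just an alias of the .value setter
        self.value = new


box = ArrayBox(np.arange(4))
old = box.value
box.value = old + 1
print(old, box.value)
```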

#### Indexing and slicing  
Use indexing to modify individual elements of a dynamic variable:

In [None]:
v = bm.Variable(bm.arange(4))
# TODO: Set the first element of v to 10
...
v
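For comparison, plain JAX arrays (which BrainPy arrays wrap) cannot be modified by indexing at all; JAX's functional alternative is the `.at[...]` syntax, which returns a new array. Indexed assignment on a `Variable` spares you this bookkeeping. A plain-JAX sketch:

```python
import jax.numpy as jnp

x = jnp.arange(4)

try:
    x[0] = 10                  # JAX arrays are immutable
    item_assignment_ok = True
except TypeError as err:
    item_assignment_ok = False
    print("TypeError:", err)

y = x.at[0].set(10)            # functional update: returns a new array
print(x, y)                    # x itself is unchanged
```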

#### Augmented assignment  
Python's augmented assignments (`+=`, `-=`, `*=`, and so on) modify only the internal value of a dynamic variable, so you can use them to update dynamic variables without any extra care.

In [None]:
# TODO: all the elements in v add 1
...
v

#### `.value` assignment

This is one of the most common ways to update a variable in place. We often need to assign a whole array of values to a dynamic variable, for example to reset it during the iterative update of a dynamical system. In this case, the `.value` assignment overrides the data of the dynamic variable `v` by directly accessing the data stored in its JaxArray.

In [None]:
# TODO: reset all the elements in v to 0
...
v

#### `.update` assignment

This method works like the `.value` assignment and is another way BrainPy provides to override a dynamic variable. It likewise requires the new array to have the same shape and element type as the dynamic variable.

In [None]:
# TODO: set v to be [3, 4, 5, 6]
...
v

## 3. Control flows  
### 3.1 If-else

Under JIT compilation, an `if`-`else` statement whose condition depends on a dynamic variable causes a compilation error, because the value of the condition is not known while the code is being compiled. The error message suggests alternatives; below are two ways to replace such an `if`-`else` statement.

First, let's look at a simple example that triggers this compilation error:

In [None]:
class OddEvenCauseError(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenCauseError, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        if self.rand < 0.5:  
            self.a += 1
        else:  
            self.a -= 1
        return self.a

In [None]:
wrong_model = bm.jit(OddEvenCauseError())

try:
    wrong_model()
except Exception as e:
    print(f"{e.__class__.__name__}: {str(e)}")
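The error comes from tracing: during JIT compilation the condition is an abstract value with a known shape and dtype but no concrete value, so Python's `if` cannot decide which branch to take. The same thing happens in plain JAX, as this sketch shows (the function `step` is our own; the exact exception class depends on the JAX version, so we catch `Exception`):

```python
import jax
import jax.numpy as jnp


def step(a, r):
    if r[0] < 0.5:           # needs a concrete Python bool
        return a + 1.0
    return a - 1.0


a, r = jnp.zeros(1), jnp.array([0.3])

print(step(a, r))            # without jit: works, r is concrete

try:
    jax.jit(step)(a, r)      # with jit: r is an abstract tracer
    jit_failed = False
except Exception as err:
    jit_failed = True
    print(type(err).__name__)
```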

#### `brainpy.math.where()`  
`bm.where()` corresponds to `numpy.where()`: `where(condition, x, y)` returns `x` where `condition` is true and `y` where it is false. We can use it to fix the failing example above:

In [None]:
class OddEvenWhere(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenWhere, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        # TODO: Use bm.where() to fix the error
        ...
        return self.a

In [None]:
model = bm.jit(OddEvenWhere())
model()
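`numpy.where` performs an elementwise selection. Both `x` and `y` are evaluated before the selection is made, so it suits cheap branches such as constants. A small NumPy illustration with made-up values:

```python
import numpy as np

rand = np.array([0.2, 0.7, 0.4, 0.9])
step = np.where(rand < 0.5, 1.0, -1.0)  # +1 where the condition holds, else -1

a = np.zeros(4)
a += step                               # accumulate the selected increments
print(step, a)
```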

#### `brainpy.math.ifelse()`  
BrainPy also provides `bm.ifelse()`, a generic conditional statement that supports multiple branches. Rewrite the example using `bm.ifelse`:

In [None]:
class OddEvenCond(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenCond, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        # TODO: Use bm.ifelse() to fix the error
        ...
        return self.a

In [None]:
model = bm.jit(OddEvenCond())
model()
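In plain JAX, the structured counterparts are `jax.lax.cond` for two branches and `jax.lax.switch` for several; `bm.ifelse` fills a similar role in BrainPy. The sketch below is plain JAX, not the `bm.ifelse` solution, and `step2` and `step3` are our own names.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step2(a, r):
    # Two branches chosen by a scalar boolean predicate
    return jax.lax.cond(r[0] < 0.5,
                        lambda a: a + 1.0,
                        lambda a: a - 1.0,
                        a)


@jax.jit
def step3(a, r):
    # Three branches chosen by an integer index (0, 1 or 2 for r in [0, 1))
    idx = jnp.floor(r[0] * 3).astype(jnp.int32)
    return jax.lax.switch(idx,
                          [lambda a: a + 1.0, lambda a: a, lambda a: a - 1.0],
                          a)


a = jnp.zeros(1)
print(step2(a, jnp.array([0.3])), step2(a, jnp.array([0.8])))
print(step3(a, jnp.array([0.1])), step3(a, jnp.array([0.5])), step3(a, jnp.array([0.9])))
```

All branches must return values of the same shape and dtype, because only one of them is selected at run time.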

### 3.2 For loop

BrainPy code can use ordinary Python loops: simply iterate over a sequence and operate on each element. This loop syntax is compatible with JIT compilation, but it can lead to long tracing and compilation times. The following class implements a for loop in its `__call__` method.

In [None]:
class LoopSimple(bp.BrainPyObject):
    def __init__(self):
        super(LoopSimple, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = bm.Variable(rng.random(1000))
        self.res = bm.Variable(bm.zeros(1))

    def __call__(self):
        for s in self.seq:
            self.res += s
        return self.res.value

Running the following code shows that the first call, which triggers compilation, takes much longer than the second. If the statements inside the loop are more complex, compilation can take a prohibitively long time.

In [None]:
import time

def measure_time(f, return_res=False, verbose=True):
    t0 = time.time()
    r = f()
    t1 = time.time()
    if verbose:
        print(f'Result: {r}, Time: {t1 - t0}')
    return r if return_res else None

model = bm.jit(LoopSimple())

# First time will trigger compilation
measure_time(model)

# Second running
measure_time(model)
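The slowdown comes from unrolling: a Python `for` loop runs while the code is traced, so every iteration adds its own operations to the compiled program. The plain-JAX sketch below counts the operations in the traced program (its jaxpr) for a Python loop versus `jax.lax.scan`, a structured loop whose body is traced only once; `python_loop` and `scan_loop` are our own names.

```python
import jax
import jax.numpy as jnp


def python_loop(seq):
    res = jnp.zeros(1)
    for s in seq:                # unrolled: new operations for every element
        res = res + s
    return res


def scan_loop(seq):
    def body(carry, s):
        carry = carry + s
        return carry, carry      # (new carry, per-step output)
    final, _ = jax.lax.scan(body, jnp.zeros(1), seq)
    return final


seq = jnp.arange(100.0)
n_unrolled = len(jax.make_jaxpr(python_loop)(seq).jaxpr.eqns)
n_scan = len(jax.make_jaxpr(scan_loop)(seq).jaxpr.eqns)
print(n_unrolled, n_scan)        # grows with len(seq) vs. stays small
```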

#### `brainpy.math.for_loop()`  
We can speed up the code with a structured loop statement. Fill in the blank in the code below:

In [None]:
class LoopStruct(bp.BrainPyObject):
    def __init__(self):
        super(LoopStruct, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = rng.random(1000)
        self.res = bm.Variable(bm.zeros(1))

    def __call__(self):
        # TODO: Use bm.for_loop() to complete the loop
        ...

In [None]:
model = bm.jit(LoopStruct())

r = measure_time(model, verbose=False, return_res=True)
r.shape
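With the loop body used in the Solutions, which returns `self.res.value` at every step, we expect `bm.for_loop` to stack the per-step outputs along a new leading axis, which is why `r.shape` has 1000 rows. The NumPy reference below reproduces that stacking with an explicit loop; it is a sketch that uses NumPy's own random generator, so its numbers differ from those of `bm.random.RandomState(123)`.

```python
import numpy as np

rng = np.random.default_rng(123)
seq = rng.random(1000)

res = np.zeros(1)
outputs = []
for s in seq:
    res = res + s              # same update as the loop body
    outputs.append(res)        # per-step output, shape (1,)

stacked = np.stack(outputs)    # shape (1000, 1): one row per iteration
print(stacked.shape)
```

The last row holds the final sum, and the first column is the running (cumulative) sum of the sequence.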

## Solutions

In [None]:

# Functional JIT compilation: 
gelu_jit = bm.jit(gelu)

# Object-oriented JIT compilation:
lr2 = bm.jit(lr2)

# Create arrays:
t2 = bm.array([[[0, 1, 2, 3], [1, 2, 3, 4], [4, 5, 6, 7]],
               [[0, 0, 0, 0], [-1, 1, -1, 1], [2, -2, 2, -2]]])


# Get values of arrays:
t2_value = t2.value

# Convert to variable:
v = bm.Variable(t)

# Indexing and slicing:
v[0] = 10

# Augmented assignment:
v += 1

# .value assignment:
v.value = bm.zeros(4, dtype=int)

# .update assignment:
v.update(bm.array([3, 4, 5, 6]))

# where condition:
class OddEvenWhere(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenWhere, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        self.a += bm.where(self.rand < 0.5, 1., -1.)
        return self.a

# ifelse condition:
class OddEvenCond(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenCond, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        self.a += bm.ifelse(self.rand[0] < 0.5,
                            [1., -1.])
        return self.a

# For loop:
class LoopStruct(bp.BrainPyObject):
    def __init__(self):
        super(LoopStruct, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = rng.random(1000)
        self.res = bm.Variable(bm.zeros(1))

    def __call__(self):
        def add(s):
            self.res += s
            return self.res.value

        return bm.for_loop(body_fun=add, operands=self.seq)