# Vectorization Solutions

In [1]:
import jax.numpy as jnp
import chex

## Array Operations

In [2]:
def array_operations_q1(N):
    """Return log(1) + log(2) + ... + log(N) without a Python loop.

    `jnp.arange(N)` produces 0, 1, ..., N-1; adding 1.0 shifts the terms to
    1, ..., N and makes them floating point before the logarithm is taken.
    """
    terms = jnp.arange(N) + 1.0
    return jnp.log(terms).sum()

# Evaluate the sum of logs for two example upper limits.
for upper_limit in (10.0, 20.0):
    print(array_operations_q1(upper_limit))

15.104412
42.335617


In [3]:
def array_operations_q2(N, M):
    """Return the double sum of log(i * j + 1) for i in 0..N-1 and j in 0..M-1.

    Two (N, M) index grids are built so that the whole double sum is a single
    element-wise expression followed by one reduction.
    """
    # With indexing='ij', row_idx[r, c] == r and col_idx[r, c] == c.
    row_idx, col_idx = jnp.meshgrid(jnp.arange(N), jnp.arange(M), indexing='ij')
    log_terms = jnp.log(row_idx * col_idx + 1.0)
    return log_terms.sum()

# Evaluate the double sum for two example grid sizes.
for n_rows, n_cols in ((5, 10), (30, 20)):
    print(array_operations_q2(n_rows, n_cols))

84.905975
2507.8203


## Slicing

In [4]:
def array_slicing_q1(a):
    """Return the differences between consecutive entries along the first axis.

    Entry k of the result is a[k + 1] - a[k], so the result is one element
    shorter than `a` along that axis.
    """
    earlier, later = a[:-1], a[1:]
    return later - earlier

# Consecutive differences of a small 1-D example (displayed as the cell output).
array_slicing_q1(a=jnp.asarray([0.0, 1.0, 5.0, 10.0, 20.0]))

Array([ 1.,  4.,  5., 10.], dtype=float32)

In [5]:
def array_slicing_q2(a):
    """Return consecutive differences along the last axis of `a`.

    The ellipsis keeps every leading axis intact, so each row (or higher
    dimensional slice) is differenced independently.
    """
    successors = a[..., 1:]
    predecessors = a[..., :-1]
    return successors - predecessors

# Row-wise consecutive differences of a 2-D example (displayed as the cell output).
array_slicing_q2(a=jnp.asarray([
    [0.0, 1.0, 5.0, 10.0, 20.0],
    [2.0, 5.0, 6.0, 20.0, 30.0],
]))

Array([[ 1.,  4.,  5., 10.],
       [ 3.,  1., 14., 10.]], dtype=float32)

## Indexing with Boolean Arrays

In [6]:
def boolean_indexing_q1(a):
    """Return only the rows of the rank-2 array `a` whose entries sum to more than zero.

    The rows that are kept appear in their original order.
    """
    chex.assert_rank(a, 2)

    row_totals = a.sum(axis=-1)
    # Boolean mask with one entry per row; True marks a row to keep.
    keep_row = row_totals > 0.0
    return a[keep_row]

# Only the first and third rows have a positive sum (displayed as the cell output).
boolean_indexing_q1(a=jnp.asarray([
    [1.0, 2.0, 3.0],
    [-1.0, -2.0, 3.0],
    [-1.0, -2.0, 4.0],
    [-1.0, -2.0, -3.0],
]))

Array([[ 1.,  2.,  3.],
       [-1., -2.,  4.]], dtype=float32)

In [7]:
def boolean_indexing_q2(N, x, y, r):
    """Rasterize a filled disk onto an N x N grid of 0s and 1s.

    Args:
        N: Side length of the grid; must be positive.
        x: Row index of the disk center; must satisfy 0 <= x < N.
        y: Column index of the disk center; must satisfy 0 <= y < N.
        r: Disk radius; must be positive.

    Returns:
        An (N, N) int32 array that is 1 where the squared distance from
        (row, column) to (x, y) is at most r squared, and 0 elsewhere.
    """
    assert 0 < N
    assert 0 <= x < N
    assert 0 <= y < N
    assert 0 < r

    axis = jnp.arange(N)
    # With indexing='ij', row_idx[i, j] == i and col_idx[i, j] == j.
    row_idx, col_idx = jnp.meshgrid(axis, axis, indexing='ij')
    dist_sq = (row_idx - x) ** 2.0 + (col_idx - y) ** 2.0
    inside = dist_sq <= r ** 2.0
    result = inside.astype('int32')

    chex.assert_shape(result, (N, N))
    return result

# One disk fully inside the grid and one clipped by the grid boundary,
# printed with a blank line between them.
disk_inside = boolean_indexing_q2(N=10, x=3, y=4, r=2)
disk_clipped = boolean_indexing_q2(N=10, x=9, y=8, r=2)
print(disk_inside, '', disk_clipped, sep='\n')

[[0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0]
 [0 0 0 1 1 1 0 0 0 0]
 [0 0 1 1 1 1 1 0 0 0]
 [0 0 0 1 1 1 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]]

[[0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 1 1 1]
 [0 0 0 0 0 0 1 1 1 1]]


## Indexing with Arrays of Indices

In [8]:
def integer_indexing_q1(a):
    """Return the main diagonal of the square rank-2 array `a`.

    The diagonal is gathered with a pair of identical integer index arrays,
    so element k of the result is a[k, k].
    """
    chex.assert_rank(a, 2)
    chex.assert_size(a, a.shape[0] ** 2)

    n = a.shape[0]
    diag_idx = jnp.arange(n)
    result = a[diag_idx, diag_idx]

    chex.assert_shape(result, (n,))
    return result

# Show a 6 x 6 matrix and its main diagonal, separated by a blank line.
a = jnp.arange(36).reshape((6, 6))
print(a, '', integer_indexing_q1(a), sep='\n')

[[ 0  1  2  3  4  5]
 [ 6  7  8  9 10 11]
 [12 13 14 15 16 17]
 [18 19 20 21 22 23]
 [24 25 26 27 28 29]
 [30 31 32 33 34 35]]

[ 0  7 14 21 28 35]


In [9]:
def integer_indexing_q2(a, offset=0):
    """Return the `offset`-th diagonal of the square rank-2 array `a`.

    The diagonal is gathered with a pair of integer index arrays.

    Args:
        a: Rank-2 array with as many columns as rows.
        offset: Which diagonal to extract. 0 is the main diagonal, positive
            values select diagonals above it (starting at column `offset`),
            negative values select diagonals below it (starting at row
            `-offset`).

    Returns:
        A 1-D array of length max(a.shape[0] - abs(offset), 0). An offset that
        lies entirely outside the matrix yields an empty array (the same
        convention as `jnp.diagonal`) rather than failing a shape assertion
        against a negative length.
    """
    chex.assert_rank(a, 2)
    chex.assert_size(a, a.shape[0] ** 2)

    n = a.shape[0]
    # Clamp at zero so out-of-range offsets give an empty diagonal.
    length = max(n - abs(offset), 0)
    steps = jnp.arange(length)

    # A non-positive offset shifts the starting row down; a positive offset
    # shifts the starting column right. Only one of the two shifts is non-zero.
    row_start = -offset if offset <= 0 else 0
    col_start = offset if offset > 0 else 0
    result = a[steps + row_start, steps + col_start]

    chex.assert_shape(result, (length,))
    return result

# Show a 6 x 6 matrix, a blank line, then its diagonals for offsets -2..2.
a = jnp.arange(36).reshape((6, 6))
print(a, end='\n\n')

for offset in range(-2, 3):
    diagonal = integer_indexing_q2(a, offset=offset)
    print(f'Offset {offset}:', diagonal)

[[ 0  1  2  3  4  5]
 [ 6  7  8  9 10 11]
 [12 13 14 15 16 17]
 [18 19 20 21 22 23]
 [24 25 26 27 28 29]
 [30 31 32 33 34 35]]

Offset -2: [12 19 26 33]
Offset -1: [ 6 13 20 27 34]
Offset 0: [ 0  7 14 21 28 35]
Offset 1: [ 1  8 15 22 29]
Offset 2: [ 2  9 16 23]


## Broadcasting

In [10]:
def broadcasting_q1(a, b):
    """Return all pairwise squared Euclidean distances between rows of `a` and `b`.

    Both inputs must be rank 2 and share the size of their last axis.
    Entry [i, j] of the result is the sum over that axis of
    (b[j] - a[i]) ** 2, so the result has shape (a.shape[0], b.shape[0]).
    """
    chex.assert_rank((a, b), 2)
    chex.assert_equal_shape_suffix((a, b), 1)

    # Insert singleton axes so the subtraction broadcasts over every row pair.
    a_rows = jnp.expand_dims(a, axis=1)  # (a.shape[0], 1, D)
    b_rows = jnp.expand_dims(b, axis=0)  # (1, b.shape[0], D)
    deltas = b_rows - a_rows
    result = (deltas ** 2.0).sum(axis=-1)

    chex.assert_shape(result, (a.shape[0], b.shape[0]))
    return result

# Three 2-D points in `a` and four in `b`; the cell displays the 3 x 4
# table of pairwise squared distances.
a = jnp.array([[1, 2], [2, 4], [5, 6]])
b = jnp.array([[5, 3], [4, 1], [6, 6], [7, 1]])

broadcasting_q1(a, b)

Array([[17., 10., 41., 37.],
       [10., 13., 20., 34.],
       [ 9., 26.,  1., 29.]], dtype=float32)

## Regression Solutions

In [11]:
# Import a bunch of libraries we'll be using below
import pandas as pd
import matplotlib.pylab as plt
import numpyro
import numpyro.distributions as D
import jax.numpy as jnp
import jax.random as jrandom
import numpyro
import numpyro.distributions as D
import numpyro.distributions.constraints as C
from cs349 import *

# Read the CSV into a pandas dataframe indexed by patient ID
csv_fname = 'data/IHH-CTR-CGLF.csv'
data = pd.read_csv(filepath_or_buffer=csv_fname, index_col='Patient ID')

# Display a reproducible random sample of patients to get a feel for the data
data.sample(n=15, random_state=0)

Unnamed: 0_level_0,Glow,Telekinetic-Ability
Patient ID,Unnamed: 1_level_1,Unnamed: 2_level_1
90,0.604085,0.079067
254,0.613645,0.029835
283,0.829212,0.240791
445,0.98112,0.361027
461,0.688329,0.07275
15,0.796853,0.066299
316,0.839546,0.44451
489,0.929422,0.368031
159,0.893813,0.522464
153,0.832483,0.475658


In [12]:
def model_polynomial_regression(N, x, y=None, degree=1):
    """NumPyro model: y is Normal around a polynomial in x with a learned spread.

    Args:
        N: Size of the 'data' plate.
        x: Inputs at which the polynomial is evaluated.
        y: Observed outputs; when None the 'y' site is sampled instead of
            being conditioned on.
        degree: Polynomial degree; `degree + 1` coefficients are learned.

    Learnable parameters:
        coefficients: Unconstrained, initialized to ones, passed to
            `jnp.polyval` (highest-order coefficient first).
        std_dev: Positive scalar, initialized to 1.0, used as the scale of the
            Normal observation distribution.
    """
    coefficients = numpyro.param('coefficients', jnp.ones(degree + 1), constraint=C.real)
    std_dev = numpyro.param('std_dev', jnp.array(1.0), constraint=C.positive)

    # The mean is a deterministic function of x, so it can be computed up front.
    mean = jnp.polyval(coefficients, x)

    with numpyro.plate('data', N):
        numpyro.sample('y', D.Normal(mean, std_dev), obs=y)

In [None]:
NUM_ITERATIONS = 10_000
DEGREE = 2

# Optimizer used for the fit; here we chose the "Adam" algorithm
optimizer = numpyro.optim.Adam(step_size=0.01)

# Random generator key handed to the optimization routine
key_optimizer = jrandom.PRNGKey(seed=0)

# Pull the two columns out of the dataframe once, as JAX arrays
glow = jnp.array(data['Glow'])
telekinetic_ability = jnp.array(data['Telekinetic-Ability'])

# Everything after NUM_ITERATIONS lines up with the model's own signature
# (N, x, y=..., degree=...).
result = cs349_mle(
    model_polynomial_regression,
    optimizer,
    key_optimizer,
    NUM_ITERATIONS,
    glow.shape[0],
    glow,
    y=telekinetic_ability,
    degree=DEGREE,
)

 39%|▍| 3891/10000 [00:00<00:01, 5940.56it/s, init loss: 1608.8676, avg. loss [3001-3500

In [None]:
# Loss recorded at every optimization step, to check that the fit converged.
fig, ax = plt.subplots()
ax.scatter(jnp.arange(NUM_ITERATIONS), result.losses)
ax.set_xlabel('Optimization Step')
ax.set_ylabel('Loss')
ax.set_title('Convergence of MLE')
plt.show()

In [None]:
# Inputs at which the fitted model is sampled
test_x = jnp.linspace(0.0, 1.3, 100)

# Number of samples drawn from the fitted generative process
S = 10
samples = cs349_sample_generative_process(
    result.model_mle,
    jrandom.PRNGKey(seed=0),
    test_x.shape[0],
    test_x,
    degree=DEGREE,
    num_samples=S,
)

# S back-to-back copies of test_x, one copy per sample, matching the flattened
# samples below (assumes samples['y'] is laid out as (S, len(test_x)) -- TODO confirm).
sample_x = jnp.tile(test_x, S)
sample_y = samples['y'].flatten()

fig, ax = plt.subplots()
ax.scatter(sample_x, sample_y, color='blue', alpha=0.5, label='Samples')
ax.scatter(data['Glow'], data['Telekinetic-Ability'], color='red', alpha=0.5, label='Data')

ax.set_xlabel('Glow')
ax.set_ylabel('Telekinetic Ability')
ax.set_title('Telekinetic Ability vs. Glow')

ax.legend()
plt.show()