<a href="https://colab.research.google.com/github/daniel-hu-37/thesis/blob/main/thesis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
# Third-party imports
import jax
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Circle, Rectangle

# Local application/library specific imports
try:
  import brax
except ImportError:
  from IPython.display import clear_output
  !pip install git+https://github.com/google/brax.git@main
  clear_output(wait=True)
  import brax

# Remaining imports
import brax.envs as envs
from jax import random
import copy
from tqdm.notebook import tqdm

In [8]:
N = 10
T = 8
L = 3
epsilon = 1e-4  # Small value for finite difference calculation
costs = []

def cost(state, action):
    return state[1]**2 + 0.1*action**2

def project(vector, u):
  abs_max = np.abs(vector).max()
  scale = u / abs_max if abs_max > 0 else 1
  return vector * scale

def igpc_finite_difference(offset, control):
  env_plus = envs.create(env_name='inverted_pendulum')
  rng = random.PRNGKey(0)
  plus_state = env_plus.reset(rng=rng)

  env_minus = envs.create(env_name='inverted_pendulum')
  rng = random.PRNGKey(0)
  minus_state = env_minus.reset(rng=rng)

  gradient = np.zeros_like(control)
  o_plus_u = offset + control
  for t in tqdm(range(T)):  # Iterate over each time step in the trajectory
      u_plus = np.copy(o_plus_u)
      u_minus = np.copy(o_plus_u)

      # Perturb the control input at time t
      u_plus[t] += epsilon
      u_minus[t] -= epsilon

      # Compute the cost for the perturbed policies
      cost_plus, cost_minus = 0, 0
      for t1 in range(T):
        cost_plus += cost(plus_state.obs, u_plus[t1])
        cost_minus += cost(minus_state.obs, u_minus[t1])

        plus_state = env_plus.step(plus_state, u_plus[t1])
        minus_state = env_minus.step(minus_state, u_minus[t1])

      gradient[t] = (cost_plus - cost_minus) / (2 * epsilon)
  return gradient

def rollout_finite_difference(M, w_prime, control, state):
    # M is the array of matrices for which we want to compute the gradient
    # func is the function that takes M and returns a scalar value
    grad_M = np.zeros_like(M)  # Initialize gradient array with the same shape as M
    M_plus = np.copy(M)
    M_minus = np.copy(M)
    env_plus = envs.create(env_name='inverted_pendulum')
    env_minus = envs.create(env_name='inverted_pendulum')
    state_plus = copy.deepcopy(state)
    state_minus = copy.deepcopy(state)

    for i in tqdm(range(M.shape[0])):  # Loop over the matrices
        for j in range(M.shape[1]):  # Loop over the rows
            for k in range(M.shape[2]):  # Loop over the columns
                original_value = M[i, j, k]

                # Perturb the current element positively
                M_plus[i, j, k] = original_value + epsilon
                M_minus[i, j, k] = original_value - epsilon
                c_plus = 0
                c_minus = 0

                # Compute the finite difference
                for t in range(T):
                  o_plus = np.sum(np.array([M_plus[i] @ w_prime[i + t] for i in range(L)]))
                  o_minus = np.sum(np.array([M_minus[i] @ w_prime[i + t] for i in range(L)]))
                  action_plus = control[t] + o_plus
                  action_minus = control[t] + o_minus
                  c_plus += cost(state.obs, action_plus)
                  c_minus += cost(state.obs, action_minus)

                # Compute the finite difference
                grad_M[i, j, k] = (c_plus - c_minus) / (2 * epsilon)

    return grad_M

def igpc_rollout(control, disturbances, L, eta, gamma=2, S=5):
  M = np.random.rand(L, 4, 1)
  x_T = []
  a_T = np.array([])
  w_prime = np.concatenate((np.zeros((L, 1)), disturbances), axis=0)
  o_T = np.array([])

  env = envs.create(env_name='inverted_pendulum')
  env_plus = envs.create(env_name='inverted_pendulum')
  env_minus = envs.create(env_name='inverted_pendulum')

  rng = random.PRNGKey(0)
  state = env.reset(rng=rng)
  state_plus = copy.deepcopy(state)
  state_minus = copy.deepcopy(state)

  for t in tqdm(range(T)):
    o = np.sum(np.array([M[i] @ w_prime[i + t] for i in range(L)]))
    action = control[t] + o
    c = cost(state.obs, action)

    prev_state = copy.deepcopy(state)
    state = env.step(state, action)
    prev_state = env.step(prev_state, control[t])
    perturbation = prev_state.obs - state.obs


    # Finite Difference
    gradient = rollout_finite_difference(M, w_prime, control, state)


    # Gradient Step on GPC loss
    M = M - eta * gradient
    for i, matrix in enumerate(M):
      U, s, V = np.linalg.svd(matrix)
      spectral_norm = s[0]
      if spectral_norm > gamma:
        scaling_factor = gamma/spectral_norm
        M[i] = matrix * scaling_factor
    x_T.append(state.obs)
    a_T = np.append(a_T, [action])
    o_T = np.append(o_T, [o])
    w_prime = np.concatenate((np.zeros((L, 1)), disturbances), axis=0)
  return np.array(x_T), a_T, w_prime[L:], o_T

def igpc(w_0=0.1, u=1, eta=0.01):
  # initialize control and disturbances
  control = np.random.uniform(-u, u, T)
  print("Controls and Disturbances Initialized")

  # iterate
  for i in tqdm(range(N)):
    # initializations
    print("Iteration: ", i)
    disturbances = np.array([[np.random.choice([-w_0, 0], p=[0.3, 0.7])] for i in range(T)])
    disturbances = np.zeros_like(disturbances)

    # rollout
    print("Rollout")
    x, a, w, o = igpc_rollout(control, disturbances, L, eta)

    # print cumulative cost from rollout
    cur_cost = sum([cost(xt, at) for (xt, at) in zip(x, a)])/len(x)
    costs.append(cur_cost)
    print(cur_cost)
    print()

    # update
    print("Control Gradient Update")
    gradient = igpc_finite_difference(o, control)

    # Update policy
    control = control - (eta * gradient)
    control = project(control, u)
    print("Control: ", control)
    print()

  return control




In [None]:
final_control = igpc()

Controls and Disturbances Initialized


  0%|          | 0/10 [00:00<?, ?it/s]

Iteration:  0
Rollout


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

0.031069227573801587

Control Gradient Update


  0%|          | 0/8 [00:00<?, ?it/s]

Control:  [-0.61169011 -0.36558618  0.57839112  0.33582401 -0.79445722  1.
  0.29341535  0.4478997 ]

Iteration:  1
Rollout


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

0.04094825578697514

Control Gradient Update


  0%|          | 0/8 [00:00<?, ?it/s]

Control:  [-0.61027601 -0.34904628  0.58403761  0.33277519 -0.79474339  1.
  0.29894634  0.45002189]

Iteration:  2
Rollout


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

0.040541419455820626

Control Gradient Update


  0%|          | 0/8 [00:00<?, ?it/s]

Control:  [-0.6115092  -0.33566187  0.59770495  0.3333045  -0.80057204  1.
  0.3002228   0.45586701]

Iteration:  3
Rollout


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

0.04051014291979186

Control Gradient Update


  0%|          | 0/8 [00:00<?, ?it/s]

Control:  [-0.61186358 -0.32343464  0.61712423  0.33409029 -0.80272528  1.
  0.29991559  0.4554722 ]

Iteration:  4
Rollout


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

0.0403527440098551

Control Gradient Update


  0%|          | 0/8 [00:00<?, ?it/s]

Control:  [-0.6078153  -0.31093093  0.66057557  0.33188783 -0.79870873  1.
  0.32018437  0.45079117]

Iteration:  5
Rollout


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

0.04016043118566445

Control Gradient Update


  0%|          | 0/8 [00:00<?, ?it/s]

Control:  [-0.60675888 -0.30438992  0.65421869  0.26238425 -0.8019558   1.
  0.31807157  0.45001425]

Iteration:  6
Rollout


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

0.03998515652514567

Control Gradient Update


  0%|          | 0/8 [00:00<?, ?it/s]

Control:  [-0.60779609 -0.29541     0.72980177  0.28739675 -0.80658427  1.
  0.31914557  0.45163998]

Iteration:  7
Rollout


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
plt.plot(costs)

In [None]:
# Constants for visualization
PENDULUM_LENGTH = 2
CART_WIDTH = 1.0
CART_HEIGHT = 0.5

# Modified visualize function to handle an observation
def visualize(ax, obs, alpha=0.5):
    cart_position = obs[0]
    angle = obs[1]

    plt.xlim([-3, 3])
    plt.ylim([-1, 3])

    # Draw cart
    cart = Rectangle(xy=(cart_position - CART_WIDTH / 2, -CART_HEIGHT / 2),
                     width=CART_WIDTH,
                     height=CART_HEIGHT,
                     color=(0, 0, 1, alpha))
    ax.add_patch(cart)

    # Calculate pendulum end position and draw
    pendulum_x = cart_position + PENDULUM_LENGTH * np.sin(angle)
    pendulum_y = PENDULUM_LENGTH * np.cos(angle)

    ax.add_line(Line2D([cart_position, pendulum_x], [0, pendulum_y], lw=2, color=(1, 0, 0, alpha)))
    pendulum_mass = Circle(xy=(pendulum_x, pendulum_y),
                           radius=0.1,
                           color=(0, 0, 0, alpha))
    ax.add_patch(pendulum_mass)

    plt.draw()

In [None]:
env = envs.create(env_name='inverted_pendulum')
rng = random.PRNGKey(0)
state = env.reset(rng=rng)

_, ax = plt.subplots()
# Set these limits based on the range you expect the pendulum to move
plt.xlim([-3, 3])
plt.ylim([-1, 3])

# We'll use a simple heuristic to extract the angle; replace this with the correct method
def get_angle(state):
    # Assuming state.obs is [cart position, pole angle, cart velocity, pole angular velocity]
    return state.obs[1]

total_iterations = 100
observations = [state.obs]
obs = state.obs
for i in range(len(final_control)):
    if state.done:
      break
    print("Iteration: ", i)
    print("Cart position: ", obs[0], "Pendulum Angle: ", obs[1], "Cart velocity: ", obs[2], "Pole angular velocity: ", obs[3])
    print()
    action = final_control[i]
    state = env.step(state, action)
    print()
    obs = state.obs
    observations.append(obs)


total_iterations = len(observations)
for i in range(total_iterations):
  obs = observations[i]
  visualize(ax, obs, alpha=(0.1 + 0.9 * ((i+1) / (total_iterations+1))))  # Fading effect

plt.title('Inverted Pendulum in Motion')
plt.show()