# Generalized CPGs


## Requirements

First, we import the required libraries. 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import grad, jit, vmap

from vector_field import vector_field, utilities

To model the tangential flow, we construct a simple counterclockwise rotational field.

In [None]:
# Tangential component of the CPG update: a rotational field around the origin.

class SimpleRotationalField(vector_field.VectorField):
    """Unit-magnitude rotational field centred at the origin.

    For a point ``x = (x0, x1)`` at radius ``r > 0`` this returns
    ``(-x1 / r, x0 / r)``: the position direction rotated by +90 degrees.
    Whether the simulated motion runs counterclockwise depends on how the
    simulator applies ``get_gradient`` (add vs. subtract), which is not
    visible here.  At the origin ``arctan2(0, 0) == 0``, so ``(-1, 0)`` is
    returned.
    """

    def __init__(self):
        # The field has no parameters, so there is nothing to store.
        pass

    def get_gradient(self, x):
        # Note the argument order: arctan2(x0, x1) measures the angle from
        # the second axis, hence sin(angle) = x0/r and cos(angle) = x1/r.
        angle = np.arctan2(x[0], x[1])
        tangent = np.array([-np.cos(angle), np.sin(angle)])
        return tangent

## Constructing a basic CPG

Now we combine the elements above into a CPG: a radial potential built from $\lVert x\rVert^2$ and $1/\lVert x\rVert^2$, together with the rotational field for the tangential motion.

In [None]:
# Radial part of the CPG: potentials built from ||x||^2 and 1/||x||^2,
# combined into a single potential field.
def square(x):
    """Return the squared Euclidean norm of ``x``."""
    return jnp.dot(x, x)


def inv_sq(x):
    """Return the reciprocal of the squared Euclidean norm of ``x``."""
    return 1 / jnp.dot(x, x)


s1, s2 = [vector_field.FunctionalPotentialField(fn) for fn in (square, inv_sq)]
s3 = vector_field.LinearCombinationPotentialField([s1, s2])

# Tangential part, then the full CPG field combining radial and tangential.
m = SimpleRotationalField()
d = vector_field.LinearCombinationVectorField([s3, m])

We simulate the CPG update for 100 steps with a step size of 0.1, starting from $(0.5, 0.5)$.

In [None]:
# Integrate the CPG field `d` from (0.5, 0.5) for 100 steps of size 0.1.
history = utilities.simulate_trajectory(
    d,
    jnp.array([0.5, 0.5]),
    step_size=0.1,
    num_iters=100,
)

Lastly, we visualize the resulting trajectory.
As we can see, the system has a stable limit cycle on the unit circle $x^2 + y^2 = 1$.

In [None]:
def plot_history(x_history, **subplot_kwargs):
    """Scatter-plot a 2-D trajectory on a new figure.

    Parameters
    ----------
    x_history : array-like
        Visited states, indexed as ``[:, 0]`` (first coordinate) and
        ``[:, 1]`` (second coordinate) -- presumably shape
        ``(num_steps, 2)``; confirm against ``simulate_trajectory``.
    **subplot_kwargs
        Forwarded unchanged to ``plt.subplots`` (e.g. ``figsize``).
    """
    # Accept jax arrays, numpy arrays, or lists of points alike.
    points = np.asarray(x_history)
    _, ax = plt.subplots(**subplot_kwargs)
    ax.scatter(points[:, 0], points[:, 1])
    # Equal axis scaling so a circular limit cycle is not drawn as an ellipse.
    ax.set_aspect("equal")
    ax.grid(True)


plot_history(history, figsize=(8, 8))

## Linear Transformations

We consider linear transformations of 2D space. The matrix $A$ below warps the unit circle into an ellipse; we also plot that ellipse rotated by $\pi/2$.

In [None]:
# Symmetric 2x2 linear map used below to warp the unit circle into an ellipse.
A = jnp.array([
    [1.0, -0.5],
    [-0.5, 1.0],
])

def scatter_circle_points(n=100):
    """Return ``n`` points evenly spaced (by angle) on the unit circle.

    Angles come from ``np.linspace(0, 2*pi, n)``, which includes both
    endpoints, so the first and last points coincide at ``(1, 0)``.

    Parameters
    ----------
    n : int, optional
        Number of points to generate (default 100).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, 2)`` whose rows are ``(cos(phi), sin(phi))``.
    """
    phases = np.linspace(0, 2 * np.pi, n)
    # Build all rows at once so the output length always matches ``n``.
    return np.column_stack((np.cos(phases), np.sin(phases)))

fig, ax = plt.subplots(figsize=(5, 5))

circle_points = scatter_circle_points()
ax.scatter(circle_points[:, 0], circle_points[:, 1], label="unit circle")

# Points are stored as rows, so applying A to each one is P @ A.T.
ellipse_points = circle_points @ A.T
ax.scatter(ellipse_points[:, 0], ellipse_points[:, 1], label="A @ circle")

# Rotate the ellipse; presumably get_rotational_matrix(pi/2) is a
# rotation by 90 degrees.
rot_ellipse_points = ellipse_points @ utilities.get_rotational_matrix(np.pi / 2).T
ax.scatter(rot_ellipse_points[:, 0], rot_ellipse_points[:, 1], label="rotated ellipse")

ax.legend()
# Equal axis scaling so the circle and ellipses keep their true shapes.
ax.set_aspect("equal")

In [None]:
# Invert A once rather than on every evaluation of the transform.
A_inv = jnp.linalg.inv(A)


def f(x):
    """Apply the inverse of ``A`` to the point ``x``.

    This is the map handed to ``SmoothTransformationVectorField``; the plots
    in this notebook suggest the resulting limit cycle is the ellipse
    ``A @ circle`` (NOTE(review): confirm the transform convention there).
    """
    return A_inv @ x


d_ellipse = vector_field.SmoothTransformationVectorField(d, f)

history = utilities.simulate_trajectory(
    d_ellipse, jnp.array([0.5, 0.5]),
    step_size=0.1, num_iters=100, grad_clip=0.1)
plot_history(history, figsize=(8, 8))

We can also obtain a rotated version of the elliptical shape above with a linear transformation: the rotation by $\pi/2$ is composed with $A$ before inverting.

In [None]:
# Compose the pi/2 rotation with A and invert the combined map once, using
# jnp like the elliptic cell above instead of mixing in np.linalg.
rot_A_inv = jnp.linalg.inv(utilities.get_rotational_matrix(np.pi / 2) @ A)


def g(x):
    """Apply the inverse of ``R(pi/2) @ A`` to the point ``x``.

    Used as the transform of ``SmoothTransformationVectorField`` to obtain a
    rotated version of the elliptical limit cycle.
    """
    return rot_A_inv @ x


rot_ellipse = vector_field.SmoothTransformationVectorField(d, g)

history = utilities.simulate_trajectory(
    rot_ellipse, jnp.array([0.5, 0.5]),
    step_size=0.1, num_iters=100, grad_clip=0.1)
plot_history(history, figsize=(8, 8))