# Generalized CPGs


## Requirements

First, we import the required libraries. 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import grad, jit, vmap

from vector_field import vector_field, utilities

To model the tangential flow, we construct a simple counterclockwise rotational field.

In [None]:
# Defining a general class of functions which define the
# tangential component of the CPG update. 

class SimpleRotationalField(vector_field.VectorField):
    """Unit-length rotational field used as the tangential CPG component.

    ``get_gradient`` evaluates ``theta = arctan2(x[0], x[1])`` and returns
    ``[-cos(theta), sin(theta)]``.  Because the first coordinate is passed
    in the "y" slot of ``arctan2``, ``cos(theta)`` equals ``x[1] / r`` and
    ``sin(theta)`` equals ``x[0] / r``; the result is therefore
    ``[-x[1], x[0]] / r``, a unit vector perpendicular to ``x`` that points
    counterclockwise.
    """

    def __init__(self):
        # No state to set up; note that the base-class initializer is not
        # called here.
        pass

    def get_gradient(self, x):
        """Return the unit tangent vector at the 2-D point ``x``.

        At the origin ``arctan2(0, 0)`` is 0, so the result is ``[-1, 0]``.
        """
        angle = np.arctan2(x[0], x[1])
        tangent = [-np.cos(angle), np.sin(angle)]
        return np.array(tangent)

## Constructing a basic CPG

Now we're ready to combine the above elements to construct a CPG out of base components. 

In [None]:
def square(x):
    """Return the squared Euclidean norm of ``x`` (``x . x``)."""
    return jnp.dot(x, x)


def inv_sq(x):
    """Return the reciprocal of the squared Euclidean norm of ``x``."""
    return 1 / jnp.dot(x, x)


# Wrap each scalar function as a potential field, then combine the two.
s1 = vector_field.FunctionalPotentialField(square)
s2 = vector_field.FunctionalPotentialField(inv_sq)
s3 = vector_field.LinearCombinationPotentialField([s1, s2])

# Combine the potential-derived field with the rotational field to get the
# full CPG dynamics.
m = SimpleRotationalField()
d = vector_field.LinearCombinationVectorField([s3, m])

We simulate the CPG update for 100 steps with a step size of 0.1.

In [None]:
# Roll the combined field forward from (0.5, 0.5): 100 iterations with a
# step size of 0.1.
history = utilities.simulate_trajectory(
    d,
    jnp.array([0.5, 0.5]),
    step_size=0.1,
    num_iters=100,
)

Lastly, we visualize the resulting trajectory. 
As we can see, we have constructed a system with a stable limit cycle at `x^2 + y^2 = 1`.

In [None]:
def plot_history(x_history, **subplot_kwargs):
    """Scatter-plot one trajectory on a new figure with the grid enabled.

    x_history: array indexed as ``x_history[:, 0]`` / ``x_history[:, 1]``;
        column 0 goes on the horizontal axis, column 1 on the vertical.
    subplot_kwargs: forwarded unchanged to ``plt.subplots``
        (for example ``figsize``).
    """
    _, axes = plt.subplots(**subplot_kwargs)
    xs, ys = x_history[:, 0], x_history[:, 1]
    axes.scatter(xs, ys)
    axes.grid(True)


plot_history(history, figsize=(8, 8))

## Linear Transformations

We consider linear transformations of 2D space. The transform shown below warps the circle into an ellipse.

In [None]:
# Symmetric 2x2 linear map (A equals its transpose) used below to warp the
# unit circle into an ellipse.
A = jnp.array([
    [1.0, -0.5],
    [-0.5, 1.0],
])

def scatter_circle_points(n=100):
    """Return ``n`` points on the unit circle as an ``(n, 2)`` NumPy array.

    Phases are evenly spaced over ``[0, 2*pi]`` with both endpoints
    included, so for ``n > 1`` the first and last points coincide (up to
    floating-point rounding).  Row ``i`` is ``[cos(phase_i), sin(phase_i)]``.

    n: number of points to generate; defaults to 100.
    """
    phases = np.linspace(0, 2 * np.pi, n)
    # Build both columns in one vectorised step so the number of rows always
    # follows ``n`` instead of a separately hard-coded row count.
    return np.stack([np.cos(phases), np.sin(phases)], axis=1)

# Compute the three point clouds first, then draw them in this order: the
# unit circle, its image under A, and that image further transformed by the
# matrix returned from utilities.get_rotational_matrix(pi / 2).
circle_points = scatter_circle_points()
ellipse_points = circle_points @ A.T
rot_ellipse_points = (
    ellipse_points @ utilities.get_rotational_matrix(np.pi / 2).T
)

fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(circle_points[:, 0], circle_points[:, 1])
ax.scatter(ellipse_points[:, 0], ellipse_points[:, 1])
ax.scatter(rot_ellipse_points[:, 0], rot_ellipse_points[:, 1])

In [None]:
# Invert A once up front; the inverse does not depend on ``x``, so there is
# no reason to recompute it on every evaluation of ``f``.
A_inv = jnp.linalg.inv(A)


def f(x):
    """Apply the inverse of ``A`` to ``x``."""
    return A_inv @ x


# Field obtained by viewing the base CPG ``d`` through the coordinate map
# ``f``.
ellipse = vector_field.SmoothTransformationVectorField(d, f)

history = utilities.simulate_trajectory(
    ellipse,
    jnp.array([0.5, 0.5]),
    step_size=0.1,
    num_iters=100,
    grad_clip=0.1,
)
plot_history(history, figsize=(8, 8))

We can also construct a rotated version of the above ellipse with a linear transformation.

In [None]:
# Inverse of the combined map (A followed by the matrix from
# get_rotational_matrix(pi / 2)), computed once rather than on every call of
# ``g``.  jnp.linalg.inv is used so the result is a JAX array throughout.
rot_A_inv = jnp.linalg.inv(utilities.get_rotational_matrix(np.pi / 2) @ A)


def g(x):
    """Apply the inverse of ``get_rotational_matrix(pi / 2) @ A`` to ``x``."""
    return rot_A_inv @ x


# Field obtained by viewing the base CPG ``d`` through the coordinate map
# ``g``.
rot_ellipse = vector_field.SmoothTransformationVectorField(d, g)

history = utilities.simulate_trajectory(
    rot_ellipse,
    jnp.array([0.5, 0.5]),
    step_size=0.1,
    num_iters=100,
    grad_clip=0.1,
)
plot_history(history, figsize=(8, 8))

We can also visualize the potential fields used to obtain these elliptical limit cycles

In [None]:
# Potential field s3 viewed through the map f.
# NOTE(review): the two [-1, 1] arrays are presumably the x and y plotting
# ranges - confirm against utilities.plot_potential_field.
s3_f = vector_field.SmoothTransformationPotentialField(s3, f)
utilities.plot_potential_field(
    s3_f,
    jnp.array([-1, 1]),
    jnp.array([-1, 1]),
    max_clip=3,
)

In [None]:
# Potential field s3 viewed through the map g.
# NOTE(review): the two [-1, 1] arrays are presumably the x and y plotting
# ranges - confirm against utilities.plot_potential_field.
s3_g = vector_field.SmoothTransformationPotentialField(s3, g)
utilities.plot_potential_field(
    s3_g,
    jnp.array([-1, 1]),
    jnp.array([-1, 1]),
    max_clip=3,
)

## Designing A Complex Limit Cycle

With the above building blocks, we can now try to build a dynamical system that exhibits a more complex limit cycle.

By superimposing two elliptical potential fields in a cross-shape and the rotational field previously given, we theorize that we can construct a clover shape. 

In [None]:
# Superimpose the two transformed fields into a single vector field.
clover = vector_field.LinearCombinationVectorField([ellipse, rot_ellipse])

# Simulate from (0.5, 0.5) with the same settings as the individual ellipses.
history = utilities.simulate_trajectory(
    clover,
    jnp.array([0.5, 0.5]),
    step_size=0.1,
    num_iters=100,
    grad_clip=0.1,
)
plot_history(history, figsize=(8, 8))

That worked surprisingly well. Overall, the shape seems to be rotated slightly from the ideal X-shape. I am not entirely sure why that happens, but I think it is because we generated `rot_ellipse` as a 90-degree rotation of `ellipse`, and some of that rotation is preserved in the linear combination.

Nonetheless, we observe four distinct corners. 

Let's try a wider variety of starting conditions. 

In [None]:
# Five random 2-D starting points; each coordinate is drawn uniformly from
# [-1.2, 1.2) using NumPy's global (unseeded) random state.
starting_conditions = [np.random.uniform(-1.2, 1.2, size=2) for _ in range(5)]

# One trajectory per starting point: 400 iterations at step size 0.02.
x_history = []
for s in starting_conditions:
    history = utilities.simulate_trajectory(
        clover,
        s,
        step_size=0.02,
        num_iters=400,
        grad_clip=0.1,
    )
    x_history.append(history)

def plot_histories(x_history, **subplot_kwargs):
    """Draw several trajectories as lines on one shared set of axes.

    x_history: an iterable of histories, each indexed as ``hist[:, 0]`` /
        ``hist[:, 1]`` (horizontal / vertical coordinate).
    subplot_kwargs: forwarded unchanged to ``plt.subplots``
        (for example ``figsize``).
    """
    fig, ax = plt.subplots(**subplot_kwargs)
    for x_hist in x_history:
        ax.plot(x_hist[:, 0], x_hist[:, 1])
    # Enable the grid once, outside the loop: it does not depend on the
    # individual history, and this way an empty list still gets a grid.
    ax.grid(True)


plot_histories(x_history, figsize=(8, 8))

It's a bit rough but the seeds of something very cool are here. 

## Learning Limit Cycles

So far we have designed functions that warp the state space in some analytic way. 

Now we want to see if we can learn simple transformations that deform the circular limit cycle into an arbitrary shape, e.g., a five-pointed star.

This involves a supervised learning task. Minimally, we need to construct pairs of corresponding points on the base and target limit cycles, and learn a smooth, preferably invertible map (ideally a diffeomorphism) between them.

In [None]:
import flax
import optax


As a warm-up, let's use points on the limit cycle obtained from our previous construction and pair them with points on the circle (of similar phase). 

To do this, we initialize a trajectory, allow it to converge, and then collect 1000 data points associated with phases. 

In [None]:
def get_dataset():
    """Simulate the clover system and return its later samples.

    Runs 1200 iterations from (1.0, 1.0) with step size 0.02 and gradient
    clip 0.1, then drops the first 200 entries of the returned history.
    """
    start = np.array([1.0, 1.0])
    trajectory = utilities.simulate_trajectory(
        clover, start, step_size=0.02, num_iters=1200, grad_clip=0.1
    )
    # Drop the leading samples, which the notebook treats as the transient
    # before the trajectory reaches the limit cycle.
    burn_in = 200
    return trajectory[burn_in:]


targets = get_dataset()
plot_history(targets)

In [None]:
# Phase of each target point, measured counterclockwise from the +x axis.
# np.arctan2 takes its arguments as (y, x), so column 1 is passed first.
phases = np.arctan2(targets[:, 1], targets[:, 0])

# Unit-circle point sharing each target's phase.  The array is sized from
# ``targets`` so the two always have the same number of rows.
features = np.zeros([len(targets), 2])
features[:, 0] = np.cos(phases)
features[:, 1] = np.sin(phases)

plot_history(features)

## Implementing ML in Flax

Adapted from https://flax.readthedocs.io/en/latest/notebooks/annotated_mnist.html

In [None]:
import jax
import flax.linen as nn
from flax.training import train_state 

# 2. Define Network
class Net(nn.Module):
    """Small MLP with two outputs: Dense(8) -> ELU -> Dense(2).

    Intended as a learnable, nonlinear warp of R^2.  Nothing in the
    architecture itself enforces invertibility.
    """

    @nn.compact
    def __call__(self, x):
        hidden = nn.Dense(features=8)(x)
        activated = nn.elu(hidden)
        return nn.Dense(features=2)(activated)

# 3. Define loss
def mse_loss(*, preds, targets):
    """Return the mean of ``optax.l2_loss(preds, targets)`` over all elements.

    NOTE(review): optax documents ``l2_loss`` as ``0.5 * (pred - target)**2``,
    which would make this half of the conventional MSE - confirm.
    """
    elementwise = optax.l2_loss(preds, targets)
    return elementwise.mean()

# 4. Define metrics - TODO
def compute_metrics(*, preds, targets):
    """Return a metrics dict; it currently holds only the ``'loss'`` entry."""
    return {
        'loss': mse_loss(preds=preds, targets=targets),
    }

# 5. Define dataset
def get_datasets(features, targets, num_train=800):
    """Split paired samples into a leading train part and a trailing test part.

    The split is purely positional (no shuffling): the first ``num_train``
    rows form the training set and all remaining rows form the test set.

    features, targets: equally long, sliceable sequences of samples.
    num_train: number of leading rows assigned to the training set;
        defaults to 800.

    Returns ``(train_ds, test_ds)``, each a dict with the keys
    ``'features'`` and ``'targets'``.

    Raises ValueError if ``features`` and ``targets`` differ in length.
    """
    if len(features) != len(targets):
        raise ValueError(
            'features and targets must have the same length, got %d and %d'
            % (len(features), len(targets)))

    train_ds = {
        'features': features[:num_train],
        'targets': targets[:num_train],
    }
    test_ds = {
        'features': features[num_train:],
        'targets': targets[num_train:],
    }
    return train_ds, test_ds


# 6. Train state
def create_train_state(rng, learning_rate, momentum):
    """Build a fresh ``TrainState`` for ``Net`` optimised with ``optax.sgd``.

    rng: PRNG key used to initialise the network parameters.
    learning_rate, momentum: passed positionally to ``optax.sgd``.
    """
    model = Net()
    # Example input of shape (1, 2); it is only used to initialise the model.
    dummy_input = jnp.ones([1, 2])
    variables = model.init(rng, dummy_input)
    optimizer = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optimizer,
    )

@jax.jit
def train_step(state, batch):
    """Apply one gradient update computed on ``batch``.

    state: train state providing ``params`` and ``apply_gradients``.
    batch: dict with ``'features'`` and ``'targets'`` entries.

    Returns the updated state together with the metrics computed from the
    predictions made with the pre-update parameters.
    """
    def loss_fn(params):
        predictions = Net().apply({'params': params}, batch['features'])
        batch_loss = mse_loss(preds=predictions, targets=batch['targets'])
        # The predictions ride along as auxiliary output so they can be
        # reused for the metrics below.
        return batch_loss, predictions

    value_and_grad = jax.value_and_grad(loss_fn, has_aux=True)
    (_, preds), grads = value_and_grad(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, compute_metrics(preds=preds, targets=batch['targets'])

@jax.jit
def eval_step(params, batch):
    """Compute the metrics for ``batch`` under ``params`` without updating them."""
    model_vars = {'params': params}
    predictions = Net().apply(model_vars, batch['features'])
    return compute_metrics(preds=predictions, targets=batch['targets'])

def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Run one pass over ``train_ds`` in shuffled mini-batches.

    state: current train state; the updated state is returned.
    train_ds: dict of arrays; its ``'features'`` entry fixes the dataset size.
    batch_size: number of examples per mini-batch.
    epoch: epoch number, used only in the printed progress line.
    rng: PRNG key used to shuffle the example order.

    Examples left over after forming full batches are skipped.  The mean of
    each metric over the epoch's batches is computed and the loss printed.
    """
    num_examples = len(train_ds['features'])
    num_batches = num_examples // batch_size

    # Shuffle the indices, drop the incomplete tail, and arrange what is
    # left as one row of indices per batch.
    shuffled = jax.random.permutation(rng, num_examples)
    usable = shuffled[:num_batches * batch_size]
    batch_indices = usable.reshape((num_batches, batch_size))

    collected = []
    for idx in batch_indices:
        minibatch = {name: arr[idx, ...] for name, arr in train_ds.items()}
        state, step_metrics = train_step(state, minibatch)
        collected.append(step_metrics)

    # Pull the metrics to the host and average each one across batches.
    collected_np = jax.device_get(collected)
    epoch_means = {
        key: np.mean([entry[key] for entry in collected_np])
        for key in collected_np[0]
    }

    print('train epoch: %d, loss: %.4f ' % (epoch, epoch_means['loss']))

    return state

def eval_model(params, test_ds):
    """Evaluate ``params`` on ``test_ds`` and return the loss as a Python scalar.

    params: model parameters passed through to ``eval_step``.
    test_ds: dict with ``'features'`` and ``'targets'`` entries, evaluated
        as a single batch.
    """
    metrics = jax.device_get(eval_step(params, test_ds))
    # Use jax.tree_util.tree_map: the top-level ``jax.tree_map`` alias was
    # deprecated and is absent from newer JAX releases, while the tree_util
    # spelling works on both old and new versions.
    summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return summary['loss']

In [None]:
# Build the train/test split, then initialise the model and optimiser state.
train_ds, test_ds = get_datasets(features, targets)
rng = jax.random.PRNGKey(0)
# Split off a dedicated key for parameter initialisation.
rng, init_rng = jax.random.split(rng)
learning_rate = 0.1
momentum = 0.9
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

num_epochs = 10
batch_size = 32

for epoch in range(1, num_epochs + 1):
  # Use a fresh PRNG key each epoch to shuffle the training examples.
  rng, input_rng = jax.random.split(rng)
  # Run one full training epoch (many mini-batch updates).
  state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch.
  # NOTE(review): the test loss is printed with 2 decimals while the train
  # loss uses 4; small losses will show as 0.00 - consider aligning.
  test_loss = eval_model(state.params, test_ds)
  print(' test epoch: %d, loss: %.2f ' % (
      epoch, test_loss))