<a href="https://colab.research.google.com/github/mc2engel/noneq_opt/blob/main/notebooks/optimize_ising.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Instructions ##

Since `noneq_opt` is currently private, you need to create a [Personal Access Token](https://docs.github.com/en/github/authenticating-to-github/creating-a-personal-access-token) that can read the repository. Once you have the token, assign it to `token` in the first code cell below and run that cell to pip install the package. Do not save or share the notebook with the token filled in.

To run with a GPU, go to `Runtime > Change runtime type` and choose `GPU`.

<!---TODO: add TPU instructions and code.--->

# Installs and Imports

In [None]:
# GitHub Personal Access Token used to authenticate the pip install below.
# Paste it here at run time only; do not save the notebook with a real token.
token=''
# IPython expands `$token` into the shell command before running it.
!pip install git+https://$token@github.com/mc2engel/noneq_opt.git --upgrade

In [None]:
import functools

import tqdm

from google.colab import files

import jax
import jax.numpy as jnp
import jax.experimental.optimizers as jopt
import numpy as np

from noneq_opt import ising
from noneq_opt import parameterization as p10n

import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib import rc
import pandas as pd
import seaborn as sns
rc('animation', html='jshtml')

# Simulation and training parameters

In [None]:
# Side length of the square spin lattice (spins start as a size x size array).
size = 40  #@param
# Random seed for the run.
seed = 0  #@param

# Number of points at which the protocol is evaluated on [0, 1]; also passed
# to `ising.get_train_step`.
time_steps = 31 #@param

# Lengths of the (initially zero) Chebyshev coefficient vectors for the learned
# field and log-temperature corrections.
field_degree = 16 #@param
log_temp_degree = 16 #@param

# Passed to `ising.get_train_step`; presumably the number of trajectories
# simulated per training step -- confirm against `noneq_opt.ising`.
batch_size = 256 #@param

# Adam optimizer with step size 3e-2.
optimizer = jopt.adam(3e-2) #@param
# Number of iterations of the training loop.
training_steps = 1000 #@param

# Define initial guess for the optimal protocol

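Each protocol (log-temperature and field) is written as a fixed "baseline" function of time plus a learned "diff": a Chebyshev series constrained to vanish at both endpoints. The diff coefficients start at zero, so the initial guess is exactly the baseline, and training only changes the diff.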

In [None]:
def log_temp_baseline(min_temp=.75, max_temp=3., degree=1):
  """Build the baseline log-temperature protocol.

  The returned function maps a time `t` to the log of a temperature "bump":
  `min_temp + (max_temp - min_temp) * (4 * t * (1 - t)) ** degree` (written
  below as a product of separate powers). For positive `degree` the bump is
  zero at t = 0 and t = 1 and one at t = 1/2, so the temperature is `min_temp`
  at both ends and `max_temp` at the midpoint.

  Args:
    min_temp: temperature at the endpoints.
    max_temp: temperature at t = 1/2.
    degree: exponent controlling how sharply peaked the bump is.

  Returns:
    A function `t -> log(temperature(t))`.
  """
  temp_range = max_temp - min_temp

  def _log_temp_baseline(t):
    # Normalised bump; the 4 ** degree factor makes its peak value exactly 1.
    bump = (1 - t)**degree * t**degree * 4 ** degree
    return jnp.log(min_temp + bump * temp_range)

  return _log_temp_baseline

def field_baseline(start_field=-1., end_field=1.):
  """Build the baseline field protocol: a linear ramp in time.

  Args:
    start_field: field value at t = 0.
    end_field: field value at t = 1.

  Returns:
    A function `t -> field(t)` interpolating linearly between the two values.
  """
  def _field_baseline(t):
    remaining = 1 - t  # Weight on the starting value.
    return remaining * start_field + t * end_field

  return _field_baseline

def _baseline_plus_zero_diff(baseline, degree):
  """Wrap `baseline` with a learnable correction that starts at zero.

  The correction is a Chebyshev series with `degree` zero coefficients, pinned
  to 0 at both endpoints via `ConstrainEndpoints`, then added to `baseline`.
  """
  zero_diff = p10n.ConstrainEndpoints(
      p10n.Chebyshev(jnp.zeros(degree)),
      y0=0.,
      y1=0.,
  )
  return p10n.AddBaseline(zero_diff, baseline=baseline)


initial_log_temp_schedule = _baseline_plus_zero_diff(
    log_temp_baseline(), log_temp_degree)
initial_field_schedule = _baseline_plus_zero_diff(
    field_baseline(), field_degree)

# Both protocols are expected to be defined on the unit time interval.
assert initial_field_schedule.domain == (0., 1.)
assert initial_log_temp_schedule.domain == (0., 1.)

# `schedule` is what gets handed to the optimizer; `initial_schedule` keeps a
# reference to the starting point for later comparison plots.
schedule = ising.IsingSchedule(
    initial_log_temp_schedule, initial_field_schedule)
initial_schedule = schedule

# Starting temperature and field sampled on a uniform time grid.
time = jnp.linspace(0, 1, 100)
initial_temp = jnp.exp(initial_log_temp_schedule(time))
initial_field = initial_field_schedule(time)

def plot_schedules(schedules):
  """Plot temperature and field protocols for one or more schedules.

  Draws three side-by-side panels: temperature over time, field over time,
  and the path traced in the (temperature, field) plane.

  Args:
    schedules: mapping from legend label to a schedule object exposing
      `log_temp` and `field`, each callable on an array of times.
  """
  grid = np.linspace(0, 1, 100)
  fig, (temp_ax, field_ax, phase_ax) = plt.subplots(1, 3, figsize=[21, 6])

  for label, sched in schedules.items():
    temp = np.exp(sched.log_temp(grid))
    field = sched.field(grid)
    temp_ax.plot(grid, temp, label=label)
    field_ax.plot(grid, field, label=label)
    phase_ax.plot(temp, field, label=label)

  # (axes, title, x label, y label) for each panel.
  panels = [
      (temp_ax, 'Time vs. Temperature', 'Time', 'Temperature'),
      (field_ax, 'Time vs. Field', 'Time', 'Field'),
      (phase_ax, 'Temperature vs. Field', 'Temperature', 'Field'),
  ]
  for panel, title, xlabel, ylabel in panels:
    panel.set_title(title)
    panel.set_xlabel(xlabel)
    panel.set_ylabel(ylabel)
    panel.legend()

plot_schedules(dict(initial=initial_schedule))

# Training

In [None]:
def seed_stream(seed):
  """Yield an endless sequence of fresh JAX PRNG keys derived from `seed`.

  Args:
    seed: integer seed passed to `jax.random.PRNGKey`.

  Yields:
    A new PRNG key on every `next()` call.
  """
  key = jax.random.PRNGKey(seed)
  while True:
    # `key` is the private carry that will be split again on the next
    # iteration, so it must never be handed to a consumer. Only the subkey is
    # yielded; yielding the carry would reuse it as a split parent.
    key, subkey = jax.random.split(key)
    yield subkey

# Drive all training randomness from the `seed` parameter set in the
# parameters cell, rather than a hard-coded 0, so changing `seed` has an effect.
stream = seed_stream(seed)
state = optimizer.init_fn(schedule)
# Every trajectory starts from the all-spins-down configuration.
initial_spins = -jnp.ones([size, size])

train_step = ising.get_train_step(optimizer,
                                  initial_spins,
                                  batch_size,
                                  time_steps,
                                  ising.total_entropy_production)

# One summary per training step, copied to host memory for plotting later.
summaries = []

for j in tqdm.trange(training_steps, position=0):
  # Each step gets the step index and a fresh PRNG key from the stream.
  state, summary = train_step(state, j, next(stream))
  summaries.append(jax.device_get(summary))

# Plot entropy production during training.
# Summing over the last axis gives one total per entry of the summary
# (presumably one per trajectory in the batch -- confirm against `noneq_opt`).
per_step_totals = [s.entropy_production.sum(-1) for s in summaries]
per_step_means = [s.entropy_production.sum(-1).mean() for s in summaries]

train_fig, train_ax = plt.subplots(figsize=[12, 8])
train_ax.plot(per_step_totals, 'r,', alpha=.1)  # Individual totals as faint pixels.
train_ax.plot(per_step_means, 'b-')             # Mean over each step's totals.
train_ax.set_xlabel('Training step')
train_ax.set_ylabel('Entropy production')
train_ax.set_title('Training')
plt.show();


# Plot initial and final protocols

In [None]:
# Recover the trained schedule from the final optimizer state.
final_schedule = optimizer.params_fn(state)
# Overlay the starting and trained protocols on the same axes.
plot_schedules(dict(initial=initial_schedule, final=final_schedule))

# Plot initial and final summaries

In [None]:
# If plotting is slow, increase subsampling.
time_subsampling = 1

# Compare the summaries from the first and last training steps.
metric_dict = {'initial': summaries[0]._asdict(),
               'final': summaries[-1]._asdict()}
dt = 1 / time_steps
dataframes = []
for name, metrics in metric_dict.items():
  for metric_name, metric in metrics.items():
    metric = metric[:, ::time_subsampling]  # Subsample
    # Long format: one row per (column index, value) pair; `variable` holds the
    # column index of the *subsampled* array.
    dataframe = pd.DataFrame(metric).melt()
    dataframe['metric'] = metric_name
    # Scale by the subsampling stride so `time` refers to the original step,
    # not to the position within the subsampled array.
    dataframe['time'] = dataframe.variable * time_subsampling * dt
    dataframe = dataframe.drop(columns='variable')
    dataframe['version'] = name
    dataframes.append(dataframe)
data = pd.concat(dataframes)
# One facet per metric, with initial vs. final as separate hues.
grid = sns.FacetGrid(data, col='metric', hue='version', col_wrap=2, aspect=4, sharey=False)
grid.map(sns.lineplot, 'time', 'value', ci=99.99)
grid.add_legend()
plt.show();


# Animate a single trajectory for initial and final protocols

In [None]:
# Evaluate both protocols on the simulation time grid.
times = jnp.linspace(0, 1, time_steps)
initial_params = initial_schedule(times)
final_params = final_schedule(times)

# Jit-compile the simulator, treating its fourth positional argument as static.
simulate = jax.jit(ising.simulate_ising, static_argnums=3)


def _rollout(params):
  """Simulate one trajectory under `params` from `initial_spins`.

  Returns the summary, the raw per-step states, and the spin trajectory with
  the initial configuration prepended as frame 0.
  """
  # NOTE(review): the static `True` presumably asks the simulator to return
  # the per-step states as well -- confirm against `ising.simulate_ising`.
  _, (summary, states) = simulate(params,
                                  initial_spins,
                                  jax.random.PRNGKey(0),
                                  True)
  trajectory = jnp.concatenate([initial_spins[jnp.newaxis], states.spins])
  return summary, states, trajectory


# Both rollouts use the same PRNG key so only the protocol differs.
initial_summary, initial_states, initial_trajectory = _rollout(initial_params)
final_summary, final_states, final_trajectory = _rollout(final_params)

# Running totals of entropy production along each trajectory.
initial_entropy_production = initial_summary.entropy_production.cumsum()
final_entropy_production = final_summary.entropy_production.cumsum()

# 3 rows (spins, protocol, entropy) x 2 columns (initial, optimized).
fig, ax = plt.subplots(3, 2, figsize=[18, 18])

def frame(j):
  """Redraw all six panels of the module-level 3x2 figure for frame `j`.

  Row 0 shows the spin configuration at step `j`, row 1 the protocol path in
  the (temperature, field) plane up to step `j`, and row 2 the cumulative
  entropy production up to step `j`. The left column is the initial protocol
  and the right column the optimized one.

  Args:
    j: index of the frame to draw.

  Returns:
    An empty tuple (no artists are reported back to the animation).
  """
  shown = slice(None, j + 1)  # Everything up to and including step j.
  spin_axes, protocol_axes, entropy_axes = ax

  # Plot images of trajectories
  trajectories = [initial_trajectory, final_trajectory]
  titles = ['Initial', 'Optimized']
  for panel, traj, title in zip(spin_axes, trajectories, titles):
    panel.clear()
    panel.imshow(traj[j])
    panel.get_xaxis().set_visible(False)
    panel.get_yaxis().set_visible(False)
    panel.set_title(title, fontsize=18)

  # Plot temperature vs. field
  # Shared axis limits so the two columns are directly comparable.
  both_fields = [initial_params.field, final_params.field]
  min_field = np.min(both_fields)
  max_field = np.max(both_fields)
  max_temp = np.exp(np.maximum(initial_params.log_temp.max(),
                               final_params.log_temp.max()))
  for params, panel in zip([initial_params, final_params], protocol_axes):
    panel.clear()
    temp = np.exp(params.log_temp)
    panel.plot(temp[shown], params.field[shown], 'r-', linewidth=4)
    panel.set_xlim(0, 1.3 * max_temp)
    panel.set_ylim(1.3 * min_field, 1.3 * max_field)
    panel.set_xlabel('Temperature')
    panel.set_ylabel('Field')

  # Plot cumulative entropy production
  both_entropy = [initial_entropy_production, final_entropy_production]
  entropy_times = np.linspace(0, 1, len(initial_entropy_production))
  max_entropy_production = np.max(both_entropy)
  min_entropy_production = np.min(both_entropy)
  for entropy_production, panel in zip(both_entropy, entropy_axes):
    panel.clear()
    panel.plot(entropy_times[shown], entropy_production[shown], 'b-', linewidth=4)
    panel.set_xlim(entropy_times.min(), entropy_times.max())
    panel.set_ylim(1.3 * min_entropy_production, 1.3 * max_entropy_production)
    panel.set_xlabel('Time')
    panel.set_ylabel('Cumulative entropy production')

  return ()

# `frame` clears and fully redraws every axes and returns no artists, so
# blitting cannot be used: with blit=True the callback must return every artist
# it modified. Request full redraws instead.
anim = animation.FuncAnimation(
    fig, frame, blit=False, frames=time_steps)
# Close the static figure so only the animation widget is displayed.
plt.close(fig)
anim

In [None]:
# Download the animation

# Write the animation to a file under /tmp, then hand that file to the browser
# through Colab's `files.download`.
# NOTE(review): saving as .mp4 presumably needs a movie writer such as ffmpeg
# on the runtime -- confirm.
path = '/tmp/optimized_ising.mp4'
anim.save(path)
files.download(path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>