## Structure solving as meta-optimization (demo)

This is going to be so cool!

In the work of Senior et al. (2019), Yang et al. (2020), and others, a network predicts static optimization constraints (for example, inter-residue distance distributions). These constraints are then handed to a static, general-purpose optimization algorithm, with some manual tuning of the optimization parameters for the specific task.
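For reference, the "predict constraints, then optimize" pattern can be sketched in a few lines. This is a minimal illustration under simplifying assumptions, not the AlphaFold or trRosetta procedure. A predicted distance matrix (here computed from a known straight chain, standing in for a network's output) defines a squared-error restraint loss. A fixed, hand-tuned optimizer (plain gradient descent with an arbitrarily chosen learning rate) then moves the coordinates to satisfy it.

```python
import jax
import jax.numpy as jnp

def pairwise_distances(x):
    # (N, D) coordinates -> (N, N) Euclidean distances; the epsilon keeps the
    # gradient finite on the zero diagonal.
    diff = x[:, None, :] - x[None, :, :]
    return jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-12)

def restraint_loss(x, d_pred):
    # Squared mismatch between current and "predicted" distances.
    return jnp.sum((pairwise_distances(x) - d_pred) ** 2)

@jax.jit
def gd_step(x, d_pred, lr=5e-3):
    # One step of a static, general-purpose optimizer: plain gradient descent.
    return x - lr * jax.grad(restraint_loss)(x, d_pred)

# Stand-in for predicted constraints: distances of a straight 5-bead chain.
target = jnp.stack([jnp.arange(5.0), jnp.zeros(5)], axis=1)
d_pred = pairwise_distances(target)

x = jax.random.normal(jax.random.PRNGKey(0), (5, 2))  # random starting structure
initial = restraint_loss(x, d_pred)
for _ in range(1000):
    x = gd_step(x, d_pred)
final = restraint_loss(x, d_pred)
```

Everything task-specific here (the loss, the learning rate, the step count) is fixed by hand. That fixed, hand-tuned choice is exactly the part a learned optimizer would replace.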

Fascinatingly, there is a broad modern literature on using neural networks to learn to optimize. For example, Andrychowicz et al. (2016) trained a domain-specific, LSTM-based optimizer. On the classes of problems it was trained for, it outperformed the hand-designed optimizers commonly used there (such as SGD, RMSprop, and Adam), methods that represent many years of painstaking effort.
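To make "learning to optimize" concrete, here is a deliberately tiny learned optimizer. It is a minimal sketch, not the method of Andrychowicz et al. (who used a coordinatewise LSTM). The "optimizer" is a momentum update whose learning rate and momentum coefficient are meta-learned. Training unrolls the update for 20 steps on randomly sampled quadratic problems and differentiates the summed losses with `jax.grad`. The problem family, the two-parameter update rule, and the meta-gradient clipping are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def sample_problem(key, dim=8):
    # A random diagonal quadratic f(x) = 0.5 * sum(a * x**2) - sum(b * x)
    # stands in for the family of problems the optimizer specializes to.
    k1, k2 = jax.random.split(key)
    a = jax.random.uniform(k1, (dim,), minval=0.5, maxval=2.0)
    b = jax.random.normal(k2, (dim,))
    return a, b

def f(x, a, b):
    return 0.5 * jnp.sum(a * x ** 2) - jnp.sum(b * x)

def unrolled_meta_loss(theta, a, b, steps=20):
    # Run the learned optimizer for `steps` iterations and sum the losses it reaches.
    lr = jnp.exp(theta["log_lr"])
    beta = jax.nn.sigmoid(theta["beta_logit"])

    def step(carry, _):
        x, m = carry
        g = jax.grad(f)(x, a, b)
        m = beta * m + g  # momentum buffer
        x = x - lr * m
        return (x, m), f(x, a, b)

    init = (jnp.zeros_like(b), jnp.zeros_like(b))
    _, losses = jax.lax.scan(step, init, None, length=steps)
    return jnp.sum(losses)

@jax.jit
def meta_step(theta, key, meta_lr=0.05):
    # One (clipped) meta-gradient step on the optimizer's own parameters.
    a, b = sample_problem(key)
    loss, grads = jax.value_and_grad(unrolled_meta_loss)(theta, a, b)
    theta = jax.tree_util.tree_map(
        lambda p, g: p - meta_lr * jnp.clip(g, -1.0, 1.0), theta, grads)
    return theta, loss

def mean_meta_loss(theta, keys):
    # Average meta-loss over a fixed set of held-out problems.
    a, b = jax.vmap(sample_problem)(keys)
    return jnp.mean(jax.vmap(unrolled_meta_loss, in_axes=(None, 0, 0))(theta, a, b))

theta = {"log_lr": jnp.log(0.05), "beta_logit": jnp.array(0.0)}
eval_keys = jax.random.split(jax.random.PRNGKey(123), 16)
before = mean_meta_loss(theta, eval_keys)

key = jax.random.PRNGKey(0)
for _ in range(50):
    key, sub = jax.random.split(key)
    theta, _ = meta_step(theta, sub)
after = mean_meta_loss(theta, eval_keys)
```

The same recipe (parameterize the update rule, unroll it, and backpropagate through the trajectory) is what would let an optimizer specialize to a problem family such as protein energy landscapes.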

This is amazing for two reasons. Learning better and better optimizers from data could save time and money in future work. It also raises the interesting possibility of an optimizer that becomes specialized to individual optimization problems, such as navigating the energy landscape of a protein structure.

<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Folding_funnel_schematic.svg" alt="Folding funnel schematic.svg" height="480" width="463">

(Image [CC-BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0) / [Thomas Splettstoesser](https://commons.wikimedia.org/wiki/User:Splette); [original](https://commons.wikimedia.org/wiki/File:Folding_funnel_schematic.svg#/media/File:Folding_funnel_schematic.svg))



### Work in progress

The plan is to modify the [GraphNetEncoder](https://github.com/google/jax-md/blob/master/jax_md/nn.py#L650) and [EnergyGraphNet](https://github.com/google/jax-md/blob/master/jax_md/energy.py#L944) from jax-md so that they also accept evolutionary data as input. Instead of predicting a single energy value, they would predict several quantities, including:

1. A future conformation,
2. A distance matrix,
3. Bond angles, and
4. Compound interaction strengths

The simplest way to include (1) in a loss seems to be as follows. One of the model outputs is a coordinate for each node. Those coordinates are passed to a conventional jax-md energy function, and the resulting energy is used to incentivize mapping input conformations to output conformations with lower energy.
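A minimal sketch of such a loss, under stated assumptions. `toy_energy` is a hand-written stand-in for a jax-md energy function: harmonic bonds plus a soft repulsion between non-adjacent beads. `toy_model` is a hypothetical linear stand-in for the graph network's per-node coordinate output, initialized to the identity map. The loss is the mean energy of the *output* conformations, so gradient descent on the model parameters pushes outputs toward lower energy.

```python
import jax
import jax.numpy as jnp

def toy_energy(r, bond_length=1.0, sigma=1.0):
    # Stand-in for a jax-md energy function: harmonic bonds between consecutive
    # beads plus a soft repulsion between beads at least two bonds apart.
    bond_vecs = r[1:] - r[:-1]
    bonds = jnp.sqrt(jnp.sum(bond_vecs ** 2, axis=-1) + 1e-12)
    e_bond = jnp.sum((bonds - bond_length) ** 2)
    diff = r[:, None, :] - r[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-12)
    i, j = jnp.triu_indices(r.shape[0], k=2)
    overlap = jnp.maximum(1.0 - dist[i, j] / sigma, 0.0)
    return e_bond + jnp.sum(overlap ** 2)

def toy_model(params, r):
    # Hypothetical stand-in for the network's per-node coordinate head:
    # a linear map from the flattened conformation to per-node displacements.
    delta = params["w"] @ r.reshape(-1) + params["b"]
    return r + delta.reshape(r.shape)

def loss_fn(params, batch):
    # Mean energy of the output conformations over the batch.
    outputs = jax.vmap(toy_model, in_axes=(None, 0))(params, batch)
    return jnp.mean(jax.vmap(toy_energy)(outputs))

n, d = 10, 2
base = jnp.stack([jnp.arange(n, dtype=jnp.float32) - (n - 1) / 2, jnp.zeros(n)], axis=1)
batch = base + 0.3 * jax.random.normal(jax.random.PRNGKey(0), (16, n, d))

# Zero weights make the model start out as the identity map.
params = {"w": jnp.zeros((n * d, n * d)), "b": jnp.zeros(n * d)}
loss_and_grads = jax.jit(jax.value_and_grad(loss_fn))

initial_loss, _ = loss_and_grads(params, batch)
for _ in range(200):
    loss, grads = loss_and_grads(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-4 * g, params, grads)
final_loss, _ = loss_and_grads(params, batch)
```

On its own, an output-energy loss does not stop a model from mapping every input to the same low-energy conformation. In practice it would likely be combined with the other prediction targets.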

It looks like (2) and (3) would be straightforward if the model also returned an edge representation in some form. For now, (4) could also be accomplished this way.
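If the model returns edge-level outputs, targets for (2) and (3) can be computed directly from ground-truth coordinates. The sketch below assumes a single chain with beads ordered along the backbone. It defines the bond angle at each interior bead as the angle between its two bonds. `edge_loss` is a hypothetical mean-squared-error loss against such predictions.

```python
import jax.numpy as jnp

def distance_matrix(r):
    # (N, D) coordinates -> (N, N) Euclidean distances. Used here only to build
    # targets from ground truth, so no gradient flows through the sqrt at zero.
    diff = r[:, None, :] - r[None, :, :]
    return jnp.sqrt(jnp.sum(diff ** 2, axis=-1))

def bond_angles(r):
    # Angle at each interior bead i between bond vectors i -> i-1 and i -> i+1.
    u = r[:-2] - r[1:-1]
    v = r[2:] - r[1:-1]
    cos = jnp.sum(u * v, axis=-1) / (
        jnp.linalg.norm(u, axis=-1) * jnp.linalg.norm(v, axis=-1))
    return jnp.arccos(jnp.clip(cos, -1.0, 1.0))

def edge_loss(pred_dist, pred_angles, r_true):
    # Hypothetical loss on the model's edge-level predictions.
    return (jnp.mean((pred_dist - distance_matrix(r_true)) ** 2)
            + jnp.mean((pred_angles - bond_angles(r_true)) ** 2))

square = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
dist = distance_matrix(square)    # dist[0, 2] is the diagonal, sqrt(2)
angles = bond_angles(square)      # both interior angles are pi / 2
loss = edge_loss(dist + 0.1, angles, square)  # distances off by 0.1 -> loss 0.01
```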

The philosophy regarding (4) is that, when folding a new protein, you could obtain its interaction profile fairly easily. A model previously trained to use interaction profiles as a guide (in the same way it uses evolutionary data as a guide) might then be able to solve the structure more easily. Succeeding with that means architecting the model in a way consistent with that use case.

This might be done in a variety of ways. In the spirit of our learned optimizer, we might wish to learn an optimizer that not only minimizes energy but also predicts conformations that are increasingly consistent with their interaction profiles against a set of compounds. To do this, it seems we would need to simulate the interactions between structures and compounds. That would be computationally expensive but not impossible, especially for important structures. The learned minimizer's energy-minimizing behavior could then be fine-tuned based on how the structures it produces interact with those compounds.

Or we might treat compound interactions simply as an auxiliary guide that helps the model learn to extract information from evolutionary data, and ignore their predictions at structure-inference time.

Alternatively, we might treat compound-polymer interaction strengths as another type of input, like evolutionary data. They would need to be encoded correctly but need not be predicted by the network: they are simply one more kind of information that can help the model learn to predict low-energy structures.
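One way to use interaction strengths purely as input is to embed the whole affinity profile as a global context vector and attach it to every node's features. This is a minimal sketch. The layer sizes, the tanh projection, the 8-compound profile, and `num_aa_types=3` (chosen to cover the residue codes 0 to 2 seen in the `flatland_mock` example printed under "Loading examples") are all assumptions. The residue sequence is copied from that example.

```python
import jax
import jax.numpy as jnp

def encode_inputs(params, aa_sequence, compound_affinity, num_aa_types):
    # Per-node features: one-hot residue type, shape (N, num_aa_types).
    node_feats = jax.nn.one_hot(aa_sequence, num_aa_types)
    # Global context: embed the whole affinity profile, shape (H,).
    context = jnp.tanh(params["w_ctx"] @ compound_affinity + params["b_ctx"])
    # Attach the same context vector to every node.
    tiled = jnp.broadcast_to(context, (node_feats.shape[0], context.shape[0]))
    return jnp.concatenate([node_feats, tiled], axis=-1)

aa_sequence = jnp.array([1, 2, 0, 0, 0, 0, 0, 1, 1, 0])  # from the mock example
k_aff, k_w = jax.random.split(jax.random.PRNGKey(0))
compound_affinity = jax.random.uniform(k_aff, (8,))  # hypothetical 8-compound profile
params = {"w_ctx": 0.1 * jax.random.normal(k_w, (4, 8)), "b_ctx": jnp.zeros(4)}

features = encode_inputs(params, aa_sequence, compound_affinity, num_aa_types=3)
# features.shape == (10, 7): 3 one-hot columns plus a 4-dimensional context.
```

Because the affinities only enter through the node features, nothing in the loss has to predict them. That is the distinction this option draws.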

We might also want to build on the energy-predicting approach of jax-md, since learning to predict lower-energy structures seems closely related to computing energies. Training node functions to compute partial energies might therefore be good pre-training for learning to perform position updates that reduce energy.
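The link between partial energies and position updates can be sketched directly. If a model predicts per-node partial energies whose sum is the total energy, the negative gradient of that sum with respect to positions is already an energy-reducing update. Here the partial energies are analytic, with each harmonic bond's energy split evenly between its two beads. They stand in for a network's per-node outputs, and the step size is an arbitrary assumption.

```python
import jax
import jax.numpy as jnp

def partial_energies(r, bond_length=1.0):
    # Per-bond harmonic energies, each split evenly between its two end beads.
    bonds = jnp.linalg.norm(r[1:] - r[:-1], axis=-1)
    e_bond = (bonds - bond_length) ** 2
    per_node = jnp.zeros(r.shape[0])
    per_node = per_node.at[:-1].add(0.5 * e_bond)
    per_node = per_node.at[1:].add(0.5 * e_bond)
    return per_node

def total_energy(r):
    # Summing the per-node terms recovers the total energy.
    return jnp.sum(partial_energies(r))

@jax.jit
def descent_step(r, step_size=0.05):
    # Position update implied by the energy model: a plain gradient step.
    return r - step_size * jax.grad(total_energy)(r)

# A 4-bead chain with every bond stretched to 1.5 (three bonds, 0.25 each).
r = jnp.array([[0.0, 0.0], [1.5, 0.0], [3.0, 0.0], [4.5, 0.0]])
e_before = total_energy(r)            # 0.75
e_after = total_energy(descent_step(r))
```

A learned update head could start from this gradient step and then be trained to take larger or better-directed steps.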


### Setup

Ensure the most recent version of Flatland is installed.

```python
!pip install git+https://github.com/cayley-group/flatland.git --quiet
```


### Loading examples

Here we use a [TensorFlow Datasets](https://github.com/tensorflow/datasets) definition of a dataset generated with the Flatland environment. It provides a simple interface for obtaining a [tf.data](https://www.tensorflow.org/guide/data) Dataset, which has a variety of convenient methods for handling the stream of input examples (e.g. batching, shuffling, caching, and prefetching).

Let's load an example from the "flatland_mock" dataset to see what the structure and data type of examples will be.


```python
from absl import logging
logging.set_verbosity(logging.INFO)

import tensorflow as tf
import tensorflow_datasets as tfds
import flatland.dataset  # imported for its side effect of defining the Flatland datasets

ds = tfds.load('flatland_mock', split="train")
assert isinstance(ds, tf.data.Dataset)

ds = ds.cache().repeat()
# Take the first example as NumPy arrays.
example = next(iter(tfds.as_numpy(ds)))
```

```
INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/flatland_mock/0.0.1
INFO:absl:Reusing dataset flatland_mock (/home/jupyter/tensorflow_datasets/flatland_mock/0.0.1)
INFO:absl:Constructing tf.data.Dataset for split train, from /home/jupyter/tensorflow_datasets/flatland_mock/0.0.1
```


```python
example
```

```
{'aa_sequence': array([1, 2, 0, 0, 0, 0, 0, 1, 1, 0], dtype=int32),
 'alignments': <tf.RaggedTensor [[2, 2, 2, 0, 0, 2, 1, 2, 0, 0], [1, 1, 0, 2, 1, 1, 2, 1, 2, 0], [1, 1, 2, 0, 1, 1, 0, 2, 0, 2], [1, 0, 1, 1, 0, 2, 1, 1, 2, 2], [0, 2, 0, 1, 2, 0, 2, 2, 0, 0]]>,
 'compound_affinity': array([0.00758087, 0.23774783, 0.37480316, 0.08547738, 0.93872553,
        0.5524232 , 0.31635335, 0.12758023, 0.28242147, 0.44345865,
        0.83017063, 0.4613364 , 0.84054464, 0.43620852, 0.8709514 ,
        0.6868894 , 0.21353707, 0.21109186, 0.7294403 , 0.32652786,
        0.06179944, 0.58756477, 0.2238007 , 0.78140414, 0.93743885,
        0.10661111, 0.6824019 , 0.15397342, 0.7832685 , 0.7026999 ,
        0.8029139 , 0.4829773 , 0.12657185, 0.747258  , 0.8097496 ,
        0.5689918 , 0.63764423, 0.04804136, 0.73256326, 0.31102976,
        0.8087152 , 0.9840257 , 0.9967061 , 0.8089805 , 0.6003544 ,
        0.80616623, 0.522206  , 0.43120426, 0.13516551, 0.23430112,
        0.69846284, 0.47294742, 0.94
```

## Train demo solver

Here we use a wrapper, `train_demo_solver`, to train the demo solver. It currently trains only an energy-predicting model; later this model will be transfer-learned to predict lower-energy structures.

```python
from flatland.train import train_demo_solver
from absl import logging
logging.set_verbosity(logging.INFO)

params = train_demo_solver(num_training_steps=1,
                           training_log_every=1,
                           batch_size=16)
```


```python
from flatland.train import demo_example_stream, graph_network_neighbor_list
from flatland.train import OrigamiNet
from jax_md import space
from functools import partial

box_size = 10.862
batch_size = 16

iter_examples = demo_example_stream(
  batch_size=batch_size, split="train")

positions, energies, forces = next(iter_examples)
_, polymer_length, polymer_dimensions = positions.shape

displacement, shift = space.periodic(box_size)

neighbor_fn, init_fn, apply_fn = graph_network_neighbor_list(
  network=OrigamiNet,
  displacement_fn=displacement,
  box_size=box_size,
  polymer_length=polymer_length,
  polymer_dimensions=polymer_dimensions,
  r_cutoff=3.0,
  dr_threshold=0.0)

neighbor = neighbor_fn(positions[0], extra_capacity=6)

structure_fn = partial(apply_fn, params)
```

```
INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/flatland_mock/0.0.1
INFO:absl:Reusing dataset flatland_mock (/home/jupyter/tensorflow_datasets/flatland_mock/0.0.1)
INFO:absl:Constructing tf.data.Dataset for split train, from /home/jupyter/tensorflow_datasets/flatland_mock/0.0.1
```


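```python
# Drop the first element of the model output; the remaining values are the predicted structure.
structure = structure_fn(positions[0], neighbor)[1:]
```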


```python
structure
```

```
DeviceArray([-125.68002 , -180.35202 ,   75.57445 ,   37.897186,
               72.177246,  145.60895 , -163.79515 ,   23.409634,
               30.021103,  101.55815 ,   38.034508, -137.15297 ,
             -132.84296 , -161.73619 ,   99.6241  ,  -55.685173,
               49.254753,   70.4132  ,  -14.434248,  131.40115 ],            dtype=float32)
```

```python
# A polymer of length 10 and dimension 2
structure.shape
```

```
(20,)
```
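Assuming the 20 values are laid out bead by bead (x and y of bead 0, then bead 1, and so on), which the notebook does not confirm, they can be viewed as per-bead coordinates with a reshape such as `structure.reshape(polymer_length, polymer_dimensions)`. A self-contained illustration with a dummy array of the same shape:

```python
import jax.numpy as jnp

polymer_length, polymer_dimensions = 10, 2
# Dummy flat output with the same shape as `structure`, values 0..19.
flat = jnp.arange(polymer_length * polymer_dimensions, dtype=jnp.float32)
# Row-major reshape: row i holds the coordinates of bead i (assumed layout).
coords = flat.reshape(polymer_length, polymer_dimensions)
```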

```python
%timeit structure_fn(next(iter_examples)[0][0], neighbor)
```

```
150 ms ± 3.39 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
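Two caveats apply when reading this timing. First, the timed statement also pulls the next batch from `iter_examples`, so data loading is included. Second, JAX dispatches work asynchronously, so a benchmark should call `block_until_ready()` on the result to measure the computation itself. For `jax.jit`-compiled functions, a warm-up call also keeps compilation time out of the measurement. A self-contained sketch of the pattern, with a toy function standing in for `structure_fn`:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def toy_fn(x):
    # Placeholder computation standing in for the model's forward pass.
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((256, 256))
toy_fn(x).block_until_ready()  # warm-up: triggers compilation outside the timed loop

n_runs = 10
start = time.perf_counter()
for _ in range(n_runs):
    toy_fn(x).block_until_ready()  # wait for the asynchronous computation to finish
mean_seconds = (time.perf_counter() - start) / n_runs
```

In the notebook, the equivalent would be to fetch `x = next(iter_examples)[0][0]` first and then run `%timeit structure_fn(x, neighbor).block_until_ready()`.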


## Long auto-regressive search

Here we run a minimal experiment in which the model actually optimizes a structure by repeatedly applying the structure minimizer. We'll characterize what happens to the energy: does it consistently decrease over time, or does it diverge after a certain length of such a "rollout"?
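The rollout can be sketched with `jax.lax.scan`, recording the energy after every step so that both steady decrease and divergence show up in the trace. In this sketch, `toy_update` is a hypothetical stand-in for one application of the learned minimizer: a gradient step on a toy harmonic-bond energy. It is stable by construction, which the learned model need not be.

```python
import jax
import jax.numpy as jnp

def energy(r):
    # Toy harmonic-bond energy standing in for a jax-md energy function.
    bonds = jnp.linalg.norm(r[1:] - r[:-1], axis=-1)
    return jnp.sum((bonds - 1.0) ** 2)

def toy_update(r):
    # Hypothetical stand-in for one application of the learned structure minimizer.
    return r - 0.05 * jax.grad(energy)(r)

def rollout(update_fn, r0, num_steps):
    # Apply update_fn num_steps times, recording the energy after each step.
    def step(r, _):
        r_next = update_fn(r)
        return r_next, energy(r_next)
    return jax.lax.scan(step, r0, None, length=num_steps)

# A stretched, slightly perturbed 10-bead chain as the starting structure.
r0 = 1.5 * jnp.stack([jnp.arange(10.0), jnp.zeros(10)], axis=1)
r0 = r0 + 0.1 * jax.random.normal(jax.random.PRNGKey(0), r0.shape)
r_final, energies = rollout(toy_update, r0, num_steps=200)

# Diagnostics: did the energy blow up, and did it only ever go down?
diverged = bool(jnp.any(~jnp.isfinite(energies)))
monotone = bool(jnp.all(jnp.diff(energies) <= 1e-6))
```

Swapping `toy_update` for the trained model (wrapped to take and return coordinates) would give the energy trace the questions above ask about.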


```python
# WIP
```


## Genetic + short auto-regressive

Presuming the previous approach won't be stable under long rollouts, we'll apply it only over fairly short rollouts (within the horizon over which it is stable). We'll combine this with an evolutionary optimization approach that progressively finds better and better starting points for the optimization.
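A minimal sketch of one such hybrid, assuming a simple elitist genetic algorithm. Each generation refines every member of the population with a short rollout and keeps the lowest-energy members as parents. It then fills the rest of the population with noisy copies of those parents. `short_rollout` is again a stand-in for the learned minimizer over its stable horizon (here a few gradient steps on a toy energy). The population size, number of parents, and mutation noise are arbitrary choices.

```python
import jax
import jax.numpy as jnp

def energy(r):
    # Toy harmonic-bond energy standing in for a jax-md energy function.
    bonds = jnp.linalg.norm(r[1:] - r[:-1], axis=-1)
    return jnp.sum((bonds - 1.0) ** 2)

def short_rollout(r, num_steps=10, step_size=0.05):
    # Stand-in for a few applications of the learned minimizer (its stable horizon).
    def step(r, _):
        return r - step_size * jax.grad(energy)(r), None
    r, _ = jax.lax.scan(step, r, None, length=num_steps)
    return r

@jax.jit
def generation(key, population, num_parents=4, noise=0.1):
    # Refine every candidate, then keep the lowest-energy ones as parents.
    refined = jax.vmap(short_rollout)(population)
    scores = jax.vmap(energy)(refined)
    parents = refined[jnp.argsort(scores)[:num_parents]]
    # Fill the population with mutated copies of randomly chosen parents.
    k_pick, k_noise = jax.random.split(key)
    picks = jax.random.randint(k_pick, (population.shape[0],), 0, num_parents)
    children = parents[picks] + noise * jax.random.normal(k_noise, population.shape)
    # Elitism: carry the parents over unmutated.
    children = children.at[:num_parents].set(parents)
    return children, scores.min()

key = jax.random.PRNGKey(0)
key, sub = jax.random.split(key)
base = 1.5 * jnp.stack([jnp.arange(10.0), jnp.zeros(10)], axis=1)
population = base + 0.3 * jax.random.normal(sub, (16, 10, 2))

best_scores = []
for _ in range(5):
    key, sub = jax.random.split(key)
    population, best = generation(sub, population)
    best_scores.append(float(best))
```

With this toy minimizer, the best energy per generation cannot increase, because parents are carried over unmutated and every refinement step lowers their energy. With a learned minimizer, that property would need to be checked.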


```python
# WIP
```
