## Structure solving as meta-optimization (demo)

This is going to be so cool!

In the work of Senior et al. (2019), Yang et al. (2020), and others, optimization constraints are first predicted and then handed to a fixed, general-purpose optimization algorithm (with some manual tuning of the optimizer's parameters for the specific task).

Fascinatingly, there is a broad modern literature on using neural networks to learn to optimize. For example, Andrychowicz et al. (2016) demonstrate learning a domain-specific optimization algorithm that was subsequently shown to outperform the best-in-class optimizers available for that problem (themselves the product of more than a decade of painstaking effort).

This is amazing because optimizers learned from data could keep getting better, which in turn can save time and money in future work. It's also quite interesting to think about how an optimizer might learn to specialize to individual optimization problems (such as navigating the energy landscape of a protein structure).

<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Folding_funnel_schematic.svg" alt="Folding funnel schematic.svg" height="480" width="463">

(Image [CC-BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0) / [Thomas Splettstoesser](https://commons.wikimedia.org/wiki/User:Splette); [original](https://commons.wikimedia.org/wiki/File:Folding_funnel_schematic.svg#/media/File:Folding_funnel_schematic.svg))
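
To make the idea of "learning an optimizer" concrete, here is a minimal, dependency-free sketch. It is not the LSTM optimizer of Andrychowicz et al. (2016). It assumes the simplest possible learned optimizer: a single step size, tuned by an outer loop so that a few inner gradient steps do well on a family of toy quadratic problems. All names here (`inner_loss`, `unrolled_loss`, `meta_loss`, `meta_train`) are hypothetical and exist only for illustration.

```python
import random


def inner_loss(x, curvature):
    """Toy 1-D objective: a quadratic bowl with the given curvature."""
    return 0.5 * curvature * x * x


def inner_grad(x, curvature):
    """Gradient of the quadratic bowl with respect to x."""
    return curvature * x


def unrolled_loss(step_size, curvature, x0=1.0, num_steps=5):
    """Loss reached after `num_steps` of gradient descent using `step_size`."""
    x = x0
    for _ in range(num_steps):
        x = x - step_size * inner_grad(x, curvature)
    return inner_loss(x, curvature)


def meta_loss(step_size, curvatures):
    """Average final loss of the inner optimizer over a family of problems."""
    return sum(unrolled_loss(step_size, c) for c in curvatures) / len(curvatures)


def meta_train(curvatures, step_size=0.01, meta_lr=0.05, meta_steps=200, eps=1e-4):
    """Outer loop: tune the step size by finite-difference gradient descent."""
    for _ in range(meta_steps):
        grad = (meta_loss(step_size + eps, curvatures)
                - meta_loss(step_size - eps, curvatures)) / (2 * eps)
        step_size -= meta_lr * grad
    return step_size


rng = random.Random(0)
curvatures = [rng.uniform(0.5, 2.0) for _ in range(16)]  # the "training problems"
learned_step_size = meta_train(curvatures)
```

The learned step size is specialized to the distribution of curvatures it was trained on, which is the same property we hope a learned structure optimizer would have with respect to protein energy landscapes. A real learned optimizer would replace the single scalar with a neural network and the finite differences with automatic differentiation.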



### Work in progress

The plan is to modify the [GraphNetEncoder](https://github.com/google/jax-md/blob/master/jax_md/nn.py#L650) and [EnergyGraphNet](https://github.com/google/jax-md/blob/master/jax_md/energy.py#L944) from jax-md so that they also accept evolutionary data as input and, instead of predicting a single energy value, predict several things, including:

1. A future conformation,
2. A distance matrix,
3. Bond angles, and
4. Compound interaction strengths

The simplest way to include (1) in a loss seems to be to have the model output a coordinate for each node, pass those coordinates to a conventional jax-md energy function, and use the resulting energy to incentivize mapping input conformations to output conformations with lower energy.
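
As a rough sketch of what such a loss could look like, the snippet below scores a predicted conformation by how much it lowers the energy relative to the input. It is written in plain Python so it runs without dependencies. `toy_energy` is a hypothetical stand-in (harmonic springs along a chain) for a real jax-md energy function, and `energy_decrease_loss` is likewise an illustrative name, not part of jax-md or Flatland.

```python
import math


def toy_energy(positions, bond_length=1.0, bond_k=1.0):
    """Stand-in energy: harmonic springs between consecutive nodes of a chain.

    Assumption: a real implementation would call a jax-md energy function here.
    """
    energy = 0.0
    for a, b in zip(positions[:-1], positions[1:]):
        energy += 0.5 * bond_k * (math.dist(a, b) - bond_length) ** 2
    return energy


def energy_decrease_loss(input_positions, predicted_positions, energy_fn=toy_energy):
    """Negative when the predicted conformation has lower energy than the input."""
    return energy_fn(predicted_positions) - energy_fn(input_positions)


# A chain whose bonds are stretched to length 2, and a relaxed version of it.
stretched = [(0.0, 0.0), (2.0, 0.0), (4.0, 0.0)]
relaxed = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]

loss = energy_decrease_loss(stretched, relaxed)
```

Minimizing this quantity with respect to the model parameters rewards outputs that sit lower on the energy landscape than their inputs. With a differentiable energy function, the gradient could presumably be obtained by automatic differentiation through both the energy and the model.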

It looks like (2) and (3) would be straightforward if the model returned edge representations in some form. For now, it may also be possible to accomplish (4) this way.
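
For reference, the supervision targets for (2) and (3) could be computed from a known conformation roughly as follows. This is a sketch under two assumptions that the text above does not state: coordinates are 2-D, and consecutive nodes in the list are bonded. The function names are hypothetical.

```python
import math


def distance_matrix(positions):
    """Pairwise Euclidean distances between all nodes."""
    return [[math.dist(p, q) for q in positions] for p in positions]


def bond_angles(positions):
    """Angle in radians at each interior node between its two chain neighbours."""
    angles = []
    for prev, mid, nxt in zip(positions, positions[1:], positions[2:]):
        v1 = (prev[0] - mid[0], prev[1] - mid[1])
        v2 = (nxt[0] - mid[0], nxt[1] - mid[1])
        dot = v1[0] * v2[0] + v1[1] * v2[1]
        cos_angle = dot / (math.hypot(*v1) * math.hypot(*v2))
        # Clamp to guard against round-off pushing the cosine outside [-1, 1].
        angles.append(math.acos(max(-1.0, min(1.0, cos_angle))))
    return angles


# Three nodes forming a right angle at the middle node.
positions = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)]
dists = distance_matrix(positions)
angles = bond_angles(positions)
```

Each distance is naturally a per-edge quantity and each angle involves a pair of adjacent edges, which is why having the model return edge representations would make both targets easy to attach.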

The philosophy regarding (4) is that when folding a new protein you could obtain its interaction profile fairly easily, and a model previously trained to use interaction profiles as a guide (in the same way as it uses evolutionary data) might then be able to solve the structure more easily. Succeeding at that means architecting the model in a way consistent with that use case.

This might be done in a variety of ways. In the spirit of our learned optimizer, we might wish to learn an optimizer that not only minimizes energy but also predicts conformations that are increasingly consistent with the protein's interaction profile against a set of compounds. To do this it seems we may need to run a simulator of those structure/compound interactions (which would be computationally expensive but not impossible, especially for important structures). The tendency of the learned energy minimizer to minimize energy could then be fine-tuned based on the interactions of the produced structures with compounds.

Or, we might consider the compound interactions as simply a guide to better learning how to extract information from evolutionary data and ignore their predictions at structure inference time.

Alternatively, we might consider compound-polymer interaction strengths as a type of input, like evolutionary data, that needs to be correctly encoded but need not be predicted by the network. They would simply be yet another kind of input information that can help the model learn to predict low-energy structures.

It's possible we might want to synergize with the energy-predicting approach of jax-md given that the task of learning to predict structures of lower energy seems closely related to that of computing energies - so training node functions to compute partial energies might be nice pre-training for their learning to perform position updates that reduce energy.


### Setup

Ensure the most recent version of Flatland is installed.

In [None]:

!pip install git+https://github.com/cayley-group/flatland.git --quiet


### Loading examples

Here we use a [TensorFlow Datasets](https://github.com/tensorflow/datasets) definition of a dataset generated using the Flatland environment. This provides a simple way to obtain a [tf.data](https://www.tensorflow.org/guide/data) Dataset, which has a variety of convenient methods for handling the input example stream (e.g. for batching, shuffling, caching, and pre-fetching).

Let's load an example from the "flatland_mock" dataset to see what the structure and data type of examples will be.


In [2]:

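import tensorflow as tf
import tensorflow_datasets as tfds
# Presumably imported for its side effect of registering 'flatland_mock' with TFDS.
import flatland.dataset

ds = tfds.load('flatland_mock', split="train")
assert isinstance(ds, tf.data.Dataset)

# Keep a single cached example and pull it out as NumPy for inspection.
ds = ds.take(1).cache().repeat()
example = next(iter(tfds.as_numpy(ds)))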


In [3]:
example

{'aa_sequence': array([2, 2, 3, 1, 1, 1, 0, 2, 1, 0], dtype=int32),
 'alignments': <tf.RaggedTensor [[2, 0, 3, 2, 1, 2, 2, 1, 3, 2], [2, 3, 0, 3, 2, 1, 2, 3, 1, 1], [0, 3, 2, 2, 2, 0, 1, 2, 1, 2], [2, 2, 0, 2, 0, 1, 2, 3, 2, 0], [0, 3, 3, 0, 2, 2, 3, 1, 1, 3]]>,
 'compound_affinity': array([0.2486111 , 0.54139775, 0.73639923, 0.69151855, 0.24467154,
        0.11261518, 0.82251084, 0.9783993 , 0.4251778 , 0.55030143,
        0.8410047 , 0.2991653 , 0.62026346, 0.16084638, 0.7104868 ,
        0.34349436, 0.85450906, 0.839869  , 0.05980841, 0.2693819 ,
        0.8683603 , 0.6705369 , 0.329811  , 0.90815157, 0.4958696 ,
        0.8589536 , 0.26986575, 0.35537535, 0.83140236, 0.78714424,
        0.47348848, 0.5145506 , 0.0813486 , 0.8461258 , 0.76298654,
        0.0742124 , 0.73888475, 0.07325479, 0.8835788 , 0.64374226,
        0.8619347 , 0.2644174 , 0.6814376 , 0.8745566 , 0.7278718 ,
        0.15643644, 0.65877795, 0.62081593, 0.20465161, 0.5998772 ,
        0.8594883 , 0.9285772 , 0.04