# Introduction

This notebook visualizes the optimization process of the `PointRotations` class. The input points, transformed by the network, are displayed as the parameters of the network&mdash;the transformation quaternions&mdash;are optimized.

In addition to `symmys` and `tensorflow`, this notebook requires the following other packages (all available on PyPI):
- `flowws-analysis`
- `flowws-freud`
- `keras-gtar`
- `plato-draw`
- `pyriodic-aflow`

In [None]:
# Configure TensorFlow up front: turn on the optimizer's JIT option and, when
# GPUs are present, request on-demand memory growth for each of them.
import tensorflow as tf
tf.config.optimizer.set_jit(True)
gpus = tf.config.experimental.list_physical_devices('GPU')
# Nothing further to configure when no GPU is reported.
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        # Best effort: report the problem instead of aborting the notebook.
        print(e)

In [None]:
# render in the notebook or render in an external window (and save frames to make a movie)
# This flag is also checked by the final cell, which only exports frames
# when `notebook` is False.
notebook = True

if not notebook:
    # External-window mode: enable IPython's Qt5 event-loop integration and
    # tell vispy to use the 'pyside2' application backend.
    %gui qt5
    import vispy, vispy.app
    vispy.app.use_app('pyside2')

import functools

import ipywidgets
import keras_gtar
import plato, plato.draw.vispy as draw
import symmys

import flowws
from flowws_analysis import Pyriodic
from flowws_freud import SmoothBOD

In [None]:
def get_bonds(name, N=512, noise=1e-2, neighbors=4):
    """Run a two-stage flowws workflow and return its 'SmoothBOD.bonds' entry.

    The workflow consists of a ``Pyriodic`` stage followed by a ``SmoothBOD``
    stage (with ``r_max`` fixed at 2).

    :param name: passed to ``Pyriodic`` as ``structure``
    :param N: passed to ``Pyriodic`` as ``size``
    :param noise: passed to ``Pyriodic`` as ``noise``
    :param neighbors: passed to ``SmoothBOD`` as ``num_neighbors``
    :returns: the value stored under ``'SmoothBOD.bonds'`` in the scope
        produced by running the workflow
    """
    stages = [
        Pyriodic(structure=name, size=N, noise=noise),
        SmoothBOD(r_max=2, num_neighbors=neighbors),
    ]
    workflow = flowws.Workflow(stages)
    result_scope = workflow.run()
    return result_scope['SmoothBOD.bonds']

In [None]:
# Optimize a PointRotations network on the bonds of one structure, logging
# to `fname` through a GTARLogger callback so later cells can replay it.
fname = '/tmp/dump.tar'
# NOTE(review): 'cF8-C' is handed to Pyriodic as the structure name via
# get_bonds(); presumably a Pearson-symbol style identifier -- confirm.
structure = 'cF8-C'
neighbors = 6

bonds = get_bonds(structure, neighbors=neighbors)
print(bonds.shape)

# currently very complex structures may require adjusting the distance scale for the loss function, like:
# loss = functools.partial(symmys.losses.mean_exp_rsq, r_scale=1./4)
loss = symmys.losses.mean_exp_rsq
# NOTE(review): the meaning of the positional arguments 32 and 8 is not
# visible in this file -- check the symmys PointRotations documentation.
opt = symmys.optimization.PointRotations(32, 8, loss=loss)

# append=False: presumably starts a fresh log rather than extending an
# existing file at `fname` -- confirm against keras_gtar.
callbacks = [keras_gtar.callbacks.GTARLogger(fname, when='pre_batch', append=False)]
# The trailing semicolon keeps the notebook from echoing fit()'s return value.
opt.fit(bonds, extra_callbacks=callbacks);

In [None]:
# Open the trajectory file at `fname` (the path given to GTARLogger in the
# training cell) and record how many frames it holds for the frame slider.
traj = keras_gtar.Trajectory(fname)
num_frames = len(traj)

# A single point primitive is created once; update() replaces its `points`
# for each frame instead of rebuilding the scene.
prim = draw.SpherePoints(on_surface=False)
scene = draw.Scene(
    prim, size=(4, 4), pixel_scale=128,
    features=dict(additive_rendering=dict(invert=True)))

scene.show()

# Separate set of bonds for display, generated with N=64 rather than
# get_bonds()'s default of 512.
test_bonds = get_bonds(structure, N=64, neighbors=neighbors)

# cache models to avoid having to load and recompile if backtracking
# in the interactive visualization
# NOTE: lru_cache is given an explicit maxsize (128, the standard-library
# default) because the bare ``@functools.lru_cache`` form is only accepted
# on Python 3.8+; older interpreters raise TypeError at definition time.
@functools.lru_cache(maxsize=128)
def get_model(index):
    """Return ``traj.load(index)``, memoized per frame index.

    :param index: frame index into the module-level ``traj`` trajectory
    :returns: whatever ``traj.load`` returns for that frame; repeated calls
        with the same index reuse the cached object
    """
    return traj.load(index)

@ipywidgets.interact(frame=(0, num_frames - 1))
def update(frame=0, replicated=True):
    """Redraw the scene for one trajectory frame.

    :param frame: index of the trajectory frame whose model is used
    :param replicated: when true, display the output of the frame's model
        applied to ``test_bonds`` (flattened to rows of 3 values); when
        false, display ``test_bonds`` unchanged
    """
    if not replicated:
        points = test_bonds
    else:
        prediction = get_model(frame).predict(test_bonds)
        points = prediction.reshape((-1, 3))
    prim.points = points
    scene.render()

In [None]:
# Movie export (external-window mode only): re-render every trajectory frame
# and save each one as a numbered PNG under /tmp/frames.
if not notebook:
    # Start from an empty output directory (IPython shell escape).
    !rm -rf /tmp/frames && mkdir /tmp/frames

    for i in range(num_frames):
        # update() is called with its default replicated=True.
        update(i)
        scene.save('/tmp/frames/frame.{:05d}.png'.format(i))