In [1]:
import deeprefine as dr
import torch
import numpy as np
from openmm import unit
import mdtraj as md
import matplotlib.pyplot as plt

### Read in data

In [2]:
# Input files for the 1BTI system, given relative to the notebook's working directory.
# pdb_file is used later both as the alignment reference and to set up the protein;
# traj_file is the trajectory that gets aligned.
pdb_file = "./1BTI/1bti_fixed.pdb"
traj_file = "./1BTI/1bti_implicit_traj.h5"

In [3]:
# Load the trajectory aligned against the reference PDB; returns the coordinate
# array and its topology.
# NOTE(review): shuffle=True with no visible seed — frame order is presumably
# not reproducible across runs; confirm whether align_md accepts a seed.
sim_x, top = dr.utils.align_md(
    traj_file,
    shuffle=True,
    ref_pdb=pdb_file,
)

# Build the energy model for the same structure. The second positional argument
# (300) is presumably the temperature in kelvin — TODO confirm.
# NOTE(review): platform='CUDA' is hard-coded, so this cell presumably needs a
# CUDA-capable machine.
top2, mm_1bti = dr.setup_protein(
    pdb_file,
    300,
    implicit_solvent=True,
    platform='CUDA',
    length_scale=unit.nanometer,
)

In [4]:
# Sanity check: the topology returned with the aligned trajectory must equal
# the topology returned by setup_protein for the PDB file.
assert top == top2

In [5]:
# Show the trajectory array shape and the topology summary, one per line.
print(sim_x.shape, top, sep="\n")

(4000, 2646)
<mdtraj.Topology with 1 chains, 58 residues, 882 atoms, 895 bonds>


### Create data-processing blocks

In [6]:
# Coordinate converter built from the topology. With vec_angles=True the
# converter exposes separate cos*/sin* index arrays (used below), which suggests
# angles are encoded as (cos, sin) components — TODO confirm.
icconverter = dr.ICConverter(top, vec_angles=True)

In [7]:
# Convert the trajectory array to a tensor and run it through the converter's
# xyz2ic transform (2646 -> 4062 features according to bg.summarize() below).
ic0 = icconverter.xyz2ic(dr.assert_tensor(sim_x))

In [8]:
# Merge the angle and torsion index arrays so that all cosine components (and,
# separately, all sine components) are handed to the freezer as one array each.
cosangle_idx = np.concatenate((icconverter.cosangle_idxs, icconverter.costorsion_idxs))
sinangle_idx = np.concatenate((icconverter.sinangle_idxs, icconverter.sintorsion_idxs))

# Fit the FeatureFreezer on the converted trajectory (ic0).
# bg.summarize() reports 4062 -> 1938 features for this layer.
featurefreezer = dr.FeatureFreezer(
    ic0,
    bond_idx=icconverter.bond_idxs,
    cosangle_idx=cosangle_idx,
    sinangle_idx=sinangle_idx,
)

In [9]:
# Apply the fitted freezer to the converted trajectory; ic1 is the input used
# to fit the whitener in the next cell.
ic1 = featurefreezer.forward(ic0)

In [11]:
# Fit the whitening layer on the frozen features (ic1).
# NOTE: this may emit warnings when N_samples < 2 * N_features.
# keepdims=-6 matches the 1938 -> 1932 reduction reported by bg.summarize();
# presumably the six dropped dimensions are the global rigid-body degrees of
# freedom — TODO confirm.
whitener = dr.Whitener(X0=ic1, 
                       dim_cart_signal=icconverter.dim_cart_signal, 
                       keepdims=-6)

### Create bg

In [12]:
# Hyper-parameters forwarded as keyword arguments to construct_bg for the
# RealNVP part of the flow.
realnvp_args = dict(
    n_layers=4,
    n_hidden=[128, 256, 128],
    activation=torch.relu,
    activation_scale=torch.tanh,
    init_output_scale=0.01,
)

# Chain the three preprocessing layers with the RealNVP flow. n_realnvp=8
# corresponds to the eight RealNVP rows printed by bg.summarize().
bg = dr.construct_bg(
    icconverter, featurefreezer, whitener,
    n_realnvp=8,
    **realnvp_args,
    prior='normal',
)

In [13]:
# Print each layer's input -> output dimensions and the total parameter count.
bg.summarize()

ICConverter    :         2646  ->          4062
FeatureFreezer :         4062  ->          1938
Whitener       :         1938  ->          1932
SplitChannels  :         1932  ->    [966, 966]
RealNVP        :   [966, 966]  ->    [966, 966]
RealNVP        :   [966, 966]  ->    [966, 966]
RealNVP        :   [966, 966]  ->    [966, 966]
RealNVP        :   [966, 966]  ->    [966, 966]
RealNVP        :   [966, 966]  ->    [966, 966]
RealNVP        :   [966, 966]  ->    [966, 966]
RealNVP        :   [966, 966]  ->    [966, 966]
RealNVP        :   [966, 966]  ->    [966, 966]
MergeChannels  :   [966, 966]  ->          1932
Number of parameters:     10057920


### ML Training

In [14]:
# Optimizer and trainer for the ML stage (presumably maximum-likelihood training
# on the MD samples — TODO confirm). Only the parameters under bg.flow are
# handed to Adam.
optim = torch.optim.Adam(params=bg.flow.parameters(), lr=1e-3)
mltrainer = dr.nn.flow.MLTrainer(bg, optim, iwae=False)

In [15]:
# Training data: a float32 tensor built from the aligned trajectory array.
# X0 is reused by both the ML and the KL + ML training cells.
X0 = torch.tensor(sim_x, dtype=torch.float32)

In [None]:
# Batch-size schedule as (batch size, number of epochs) stages. Both the
# per-epoch batch-size list and the total epoch count are derived from it, so
# the two always agree.
batch_schedule = [(128, 2), (256, 2), (512, 6), (1024, 10), (2048, 20)]
batchsize = [size for size, count in batch_schedule for _ in range(count)]
epochs = sum(count for _, count in batch_schedule)

# "xxx/prefix_" is a placeholder: set it to a real directory/prefix before running.
mltrain_record = mltrainer.train(
    X0,
    epochs=epochs,
    batch_size=batchsize,
    checkpoint_epoch=4,
    checkpoint_name="xxx/prefix_",
)

### KL + ML Training

In [None]:
# Load a previously saved checkpoint (presumably one written by the ML training
# cell) together with the energy model mm_1bti. 'xxx_xxx.pkl' is a placeholder:
# replace it with a real checkpoint path.
# This rebinds `bg`; the earlier `optim` and `mltrainer` keep referring to the
# previous object, so new optimizers/trainers must be built from this one.
# NOTE(review): the .pkl extension suggests pickle-based loading — only load
# checkpoint files you trust.
bg = dr.load_bg('xxx_xxx.pkl', mm_1bti)

In [None]:
# Fresh optimizer and trainer for the combined KL + ML stage, built from the
# current `bg`, with a learning rate ten times smaller than in the ML stage.
optim2 = torch.optim.Adam(params=bg.flow.parameters(), lr=1e-4)
kltrainer = dr.nn.flow.FlexibleTrainer(bg, optim2)

In [None]:
# Staged KL + ML training schedule: entry i of each list configures stage i.
#   epochs_KL     - number of epochs to train in the stage
#   high_energies - value passed as Ehigh to the trainer for the stage
#   w_KLs         - value passed as w_KL to the trainer for the stage
epochs_KL     = [  1,   1,   1,   1,   1,   1,  1,  1,  2, 2, 2, 3, 4]
high_energies = [1e10,  1e9,  1e8,  1e7,  1e6,  1e5,  1e5,  1e5,  5e4,  5e4,  2e4,  2e4, 2e4]
w_KLs         = [1e-12, 1e-6, 1e-5, 1e-4, 1e-3, 1e-3, 5e-3, 1e-3, 5e-3, 5e-2, 0.05, 0.05, 0.05]

# The three lists are consumed in lockstep; fail loudly if one of them is edited
# without the others instead of silently truncating or raising mid-training.
assert len(epochs_KL) == len(high_energies) == len(w_KLs), \
    "epochs_KL, high_energies and w_KLs must have the same length"

# Distinct energy thresholds used for the per-stage report. They do not change
# between stages, so compute them once; sort them so the report order is stable
# (set iteration order is arbitrary).
Elevels = sorted(set(high_energies))

report = []
# Loop variables are named so that they do not rebind the notebook-level
# `epochs` used by the ML training cell.
for s, (n_epochs, e_high, w_kl) in enumerate(zip(epochs_KL, high_energies, w_KLs)):
    # The record from the previous stage is passed back in so that `report`
    # carries over from stage to stage. "xxx/prefix_" is a placeholder path.
    report = kltrainer.train(X0,
                             epochs=n_epochs, batchsize_ML=1024, batchsize_KL=1024,
                             w_KL=w_kl, Ehigh=e_high,
                             record=report, checkpoint_name=f"xxx/prefix_{s}")
    # Analyze: draw latent samples, map them through the flow, evaluate their
    # energies, and count how many exceed each threshold.
    samples_z = bg.sample_z(nsample=2000, return_energy=False)
    samples_x, _ = bg.TzxJ(samples_z)
    samples_e = dr.assert_numpy(bg.energy_model.energy(samples_x))
    energy_violations = [np.count_nonzero(samples_e > E) for E in Elevels]
    print('Energy violations:', flush=True)
    for E, V in zip(Elevels, energy_violations):
        print(V, '\t>\t', E, flush=True)