In [None]:
import asmsa
import mdtraj as md
import nglview as nv
import numpy as np

In [None]:
# Input files used throughout the notebook.
conf = 'trpcage_correct.pdb'    # PDB: topology for md.load() and input of asmsa.Molecule
topol = 'topol_correct.top'     # topology file handed to asmsa.Molecule
index = 'index_correct.ndx'     # index file handed to asmsa.Molecule and parsed further below
gro = 'trpcage_correct.gro'     # structure file loaded with mdtraj further below

In [None]:
# Trajectories written by prepare.ipynb; `conf` supplies the topology.
# Alternative input names, kept for reference:
#train_tr = md.load('x_train.xtc',top=conf)
#test_tr = md.load('x_test.xtc',top=conf)

train_tr, test_tr = (
    md.load(xtc_name, top=conf)
    for xtc_name in ('train.xtc', 'test.xtc')
)

In [None]:
# Move the leading axis of the coordinate arrays to the end, so that the
# per-frame index becomes the last axis of train_g / test_g.
train_g = np.moveaxis(train_tr.xyz, source=0, destination=-1)
test_g = np.moveaxis(test_tr.xyz, source=0, destination=-1)

In [None]:
# Sanity check: display the shape of the reordered training coordinates.
train_g.shape

In [None]:
# Sparse non-bonded distance feature map, sized by the first axis of train_g.
sparse_dists = asmsa.NBDistancesSparse(len(train_g), density=2)

# Molecule model built from the PDB, topology and index files; the sparse
# distances are registered as its only additional feature map.
mol = asmsa.Molecule(
    pdb=conf,
    top=topol,
    ndx=index,
    fms=[sparse_dists],
)

In [None]:
# Convert the training geometries with mol.intcoord() and display the shape.
# Presumably these are internal coordinates laid out as (features, frames) —
# the later `train_int.T - dec_out` relies on that layout.
train_int = mol.intcoord(train_g)
train_int.shape

In [None]:
# Same conversion for the test geometries; the shape is displayed as a check.
test_int = mol.intcoord(test_g)
test_int.shape

In [None]:
# Placeholder: the real output of the decoder (a predicted structure) should
# be supplied here. Until then, one column of the test set stands in for it.

out_idx = 100

dec_out = test_int[:,out_idx]
dec_out.shape

In [None]:
# Deviation of every training column from the (placeholder) decoder output,
# one row per training sample after the transpose.
diff = train_int.T - dec_out
# Sum of squared deviations per training sample (a sum, not a mean,
# despite the variable name).
msd = (diff * diff).sum(axis=1)

In [None]:
# Index of the training sample with the smallest squared deviation.
minidx = msd.argmin()

In [None]:
# Per-coordinate mean and scale stored as text files, read as float32.
train_mean, train_scale = (
    np.loadtxt(stats_path, dtype=np.float32)
    for stats_path in (
        'datasets/intcoords/mean.txt',
        'datasets/intcoords/scale.txt',
    )
)

In [None]:
# Apply the stored per-coordinate scale and mean to the decoder output.
# NOTE(review): presumably this undoes a normalization applied at training
# time. While dec_out is still the placeholder taken directly from
# mol.intcoord(), it may already be in unnormalized units — confirm before
# trusting the restraint distances derived from dec_out_scaled.
dec_out_scaled = dec_out * train_scale + train_mean

In [None]:
# see train.ipynb
# Indices of all non-hydrogen atoms in the GROMACS structure file.
grotr = md.load(gro)
nhs = grotr.topology.select('element != H')

# Parse the index file: the first line is dropped, the remaining lines are
# read as whitespace-separated integers and shifted down by one.
with open(index) as f:
    ndx_lines = f.readlines()
ndx = np.fromstring(" ".join(ndx_lines[1:]), dtype=np.int32, sep=' ') - 1

# Non-hydrogen atom indices reordered by the sort order of the index-file
# entries, shifted up by one (used as atom numbers in the topology later on).
pdb2gmx = nhs[ndx.argsort()] + 1


In [None]:
# Display the atom-number mapping computed in the previous cell.
pdb2gmx

In [None]:
import gromacs as gmx

In [None]:
# Take the training frame closest to the decoder output (smallest squared
# deviation) and save it as min.pdb, the input of the pdb2gmx step below.
mintr = train_tr[minidx]
mintr.save('min.pdb')

In [None]:
# Parallelism settings consumed by the mdrun() wrappers defined below.
ompthreads = 2   # passed to mdrun as 'ntomp'
mpiranks = 1     # multiplied with ompthreads for the requested core count

In [None]:
# Kubernetes deployment
# (the next cell defines a local alternative that rebinds the same names;
# whichever of the two cells is executed last wins)
mdrunner = gmx.MDrunnerK8s()


def mdrun(**kwargs):
    """Run mdrun through the Kubernetes runner.

    All keyword arguments are forwarded as mdrun arguments; 'ntomp' and
    'pin' are always set from the notebook settings and override any
    values passed in by the caller.
    """
    resources = {'cores': ompthreads * mpiranks, 'gpus': 1}
    mdrun_args = dict(kwargs, ntomp=ompthreads, pin='on')
    mdrunner.run(pre=resources, mdrunargs=mdrun_args, ncores=mpiranks)

In [None]:
# alternative local deployment
# (rebinds `mdrunner` and `mdrun` from the Kubernetes cell above; run only
# the cell matching the desired deployment, or this one takes precedence)
mdrunner = gmx.run.MDrunner()


# XXX: no MPI support so far
def mdrun(**kwargs):
    """Run mdrun locally.

    All keyword arguments are forwarded as mdrun arguments; 'ntomp' and
    'pin' are always set from the notebook settings and override any
    values passed in by the caller.
    """
    mdrun_args = dict(kwargs, ntomp=ompthreads, pin='on')
    mdrunner.run(mdrunargs=mdrun_args)

In [None]:
# Turn the selected frame (min.pdb) into a GROMACS structure (min.gro) and
# topology (min.top) using the amber94 force field and tip3p water model.
gmx.pdb2gmx(
    f='min.pdb',
    o='min.gro',
    p='min.top',
    water='tip3p',
    ff='amber94',
)

In [None]:
# Value handed to editconf's -d option (passed as a string).
mdbox = 2.0

# Center the structure in a dodecahedral box and write it to min-box.gro.
gmx.editconf(
    f='min.gro',
    o="min-box.gro",
    c=True,
    d=str(mdbox),
    bt="dodecahedron",
)

In [None]:
# Build restrained.top: the contents of min.top plus one distance restraint
# per sparse-distance pair, targeting the decoded distances.
#
# The [ distance_restraints ] section has to belong to the molecule it
# restrains, i.e. it must sit inside that molecule's [ moleculetype ] block.
# Appending it to the end of the file places it after the water/ion includes
# and after [ system ] / [ molecules ], where it no longer refers to the
# protein. It is therefore inserted at the end of the first molecule's
# definition instead.

# Number of internal-coordinate entries preceding the sparse distances.
# NOTE(review): assumes mol.intcoord() emits two values per dihedral
# (dihed4, dihed9), then angles, bonds and finally the feature-map
# distances — confirm against asmsa.Molecule.
off = mol.dihed4.shape[0]*2+mol.dihed9.shape[0]*2+mol.angles.shape[0]+mol.bonds.shape[0]

dec_dist = dec_out_scaled[off:]


def _restraint_insert_pos(top_lines):
    """Return the line index at which the first molecule's definition ends.

    Scans a pdb2gmx-style topology for the first line following the first
    [ moleculetype ] that starts something else: another [ moleculetype ],
    the position-restraint or water-topology include blocks, or the
    [ system ] / [ molecules ] sections. Falls back to len(top_lines)
    (plain append) when no such line is found.
    """
    end_markers = (
        ';includepositionrestraintfile',
        '#ifdefposres',
        ';includewatertopology',
        '[system]',
        '[molecules]',
    )
    in_molecule = False
    for pos, line in enumerate(top_lines):
        compact = ''.join(line.split()).lower()   # ignore spacing and case
        if compact.startswith('[moleculetype]'):
            if in_molecule:
                return pos
            in_molecule = True
        elif in_molecule and compact.startswith(end_markers):
            return pos
    return len(top_lines)


with open('min.top') as src:
    top_lines = src.readlines()

insert_at = _restraint_insert_pos(top_lines)

with open('restrained.top','w') as t:
    t.writelines(top_lines[:insert_at])
    t.write('''
[ distance_restraints ]
''')
    # columns: ai aj type index type' low up1 up2 fac
    # low/up1 bracket the decoded distance by +-1 %
    for i,d in enumerate(sparse_dists.bonds):
        t.write(f'{pdb2gmx[d[0]]} {pdb2gmx[d[1]]} 1 {i} 2 {dec_dist[i]*.99} {dec_dist[i]*1.01} 42.0 10.0\n')
    t.write('\n')
    t.writelines(top_lines[insert_at:])

In [None]:
# Energy-minimization parameters (steepest descent, distance restraints
# enabled) written verbatim to min.mdp for the grompp step.
mdp_text = '''
integrator  = steep         ; Algorithm (steep = steepest descent minimization)
emtol       = 1000.0        ; Stop minimization when the maximum force < 1000.0 kJ/mol/nm
emstep      = 0.01          ; Minimization step size
nsteps      = 50000         ; Maximum number of (minimization) steps to perform

nstxout                 = 0         
nstvout                 = 0         
nstfout                 = 0         
nstlog                  = 5
nstxout-compressed      = 0

; Parameters describing how to find the neighbors of each atom and how to calculate the interactions
nstlist         = 1         ; Frequency to update the neighbor list and long range forces
cutoff-scheme   = Verlet    ; Buffered neighbor searching
ns_type         = grid      ; Method to determine neighbor list (simple, grid)
coulombtype     = Cut-off   ; Treatment of long range electrostatic interactions
rcoulomb        = 1.0       ; Short-range electrostatic cut-off
rvdw            = 1.0       ; Short-range Van der Waals cut-off
pbc             = xyz       ; Periodic Boundary Conditions in all 3 dimensions

disre           = Simple
'''

with open('min.mdp', 'w') as m:
    m.write(mdp_text)

In [None]:
# Preprocess the boxed structure with the restrained topology and the
# minimization parameters into the run input file min.tpr.
gmx.grompp(
    f="min.mdp",
    c="min-box.gro",
    p='restrained.top',
    o="min.tpr",
)

In [None]:
# Run the minimization through whichever mdrun() wrapper was defined last;
# deffnm="min" makes all input/output files share the "min" base name.
mdrun(deffnm="min")

In [None]:
# Minimized structure, reduced to its non-hydrogen atoms.
# NOTE: this rebinds the notebook-level `nhs` defined in an earlier cell.
solm = md.load('min.gro')
nhs = solm.topology.select('element != H')
solm.atom_slice(nhs, inplace=True)
v = nv.show_mdtraj(solm)

# Reference structure: the test-set geometry that served as the
# (placeholder) decoder output, superposed onto the minimized structure.
refm = md.load_pdb(conf)
refm.xyz = test_g[:, :, out_idx]
refm.superpose(solm)

v.add_component(refm)

# Drop the default representations of both components ...
for component in (0, 1):
    v.clear(component=component)

# ... and draw the reference (component 1) in green and the minimized
# structure (component 0) in red.
for component, color in ((1, 'green'), (0, 'red')):
    v.add_representation('licorice', color=color, component=component)

v