In [None]:
import asmsa
import mdtraj as md
import nglview as nv
import numpy as np
from tensorflow.keras.saving import load_model
import torch

In [None]:
import os

# Resource requests exported through the environment (CPU count, RAM);
# presumably consumed by the asmsa / GROMACS tooling — TODO confirm.
os.environ.update(REQUEST_CPU='8', REQUEST_RAM='32')

In [None]:
# Run-specific input file names come from inputs.py, which must define:
#   conf  - reference PDB structure, e.g. "trpcage_correct.pdb"
#   topol - GROMACS topology,        e.g. "topol_correct.top"
#   index - GROMACS index file,      e.g. 'index_correct.ndx'
#   gro   - GROMACS coordinates,     e.g. 'trpcage_correct.gro'
# NOTE: exec() runs inputs.py as code, so it must come from a trusted source.
with open('inputs.py') as inputs_file:  # closed deterministically, unlike open(...).read()
    exec(inputs_file.read())

# Fail early with a clear message instead of a NameError several cells later.
_missing_inputs = [name for name in ('conf', 'topol', 'index', 'gro') if name not in globals()]
if _missing_inputs:
    raise NameError(f"inputs.py did not define: {', '.join(_missing_inputs)}")

In [None]:
# Training / test trajectories produced by prepare.ipynb, both read with the
# reference topology `conf` (older runs used x_train.xtc / x_test.xtc).
train_tr, test_tr = (md.load(path, top=conf) for path in ('train.xtc', 'test.xtc'))

In [None]:
# mdtraj stores coordinates as (frames, atoms, 3); reorder to (atoms, 3, frames),
# the layout handed to mol.intcoord further down.
train_g, test_g = (np.transpose(traj.xyz, (1, 2, 0)) for traj in (train_tr, test_tr))

In [None]:
# Sanity check: expect (n_atoms, 3, n_frames) after the axis reordering above.
train_g.shape

In [None]:
# Sparse non-bonded distance features over all atoms, passed to the Molecule as
# an extra feature map (fms) alongside its internal coordinates.
n_atoms = train_g.shape[0]
sparse_dists = asmsa.NBDistancesSparse(n_atoms, density=2)
mol = asmsa.Molecule(pdb=conf, top=topol, ndx=index, fms=[sparse_dists])

In [None]:
# Debug check: output shape of the dihed4 sub-model applied to the test geometry.
# NOTE(review): calls .forward directly; if dihed4_model is a torch nn.Module,
# calling the module itself is the idiomatic form (runs hooks) — confirm type.
mol.model.dihed4_model.forward(torch.tensor(test_g)).shape

In [None]:
# Internal coordinates of every training frame; the shape shows the
# feature/frame layout (the nearest-frame search below uses train_int.T).
train_int = mol.intcoord(train_g)
train_int.shape

In [None]:
# Internal coordinates of every test frame (same layout as train_int).
test_int = mol.intcoord(test_g)
test_int.shape

In [None]:
# Topology file name read from inputs.py.
topol

In [None]:
# Atom-index rows of the dihed4 dihedral set.
mol.dihed4

In [None]:
# Atom-index rows of the dihed9 dihedral set (4 atoms per row).
mol.dihed9

In [None]:
# Decoder samples: latent points and their decoded internal coordinates
# (one sample per row, as decout[out_idx, :] below assumes).
latent, decout = (np.loadtxt(fname) for fname in ('sample_latent.txt', 'sample_int.txt'))

In [None]:
# Latent-space coordinates of the decoder samples.
latent

In [None]:
# Pick one decoded sample to reconstruct as a 3D structure.
# dec_out is in the decoder's normalised space; it is mapped back with the
# training mean/scale into dec_out_scaled further down.
out_idx = 113  # row of sample_int.txt to reconstruct

# earlier runs faked the decoder output with a test frame: dec_out = test_int[:,out_idx]
dec_out = decout[out_idx,:]
dec_out.shape

In [None]:
# dec_out lives in the decoder's normalised feature space (it is later
# un-normalised as dec_out * scale + mean). Compare the training frames in that
# same space: comparing it against raw internal coordinates mixes units and
# lets large-valued features dominate the nearest-frame search.
# Assumes mol.intcoord returns raw (un-normalised) coordinates — TODO confirm.
_int_mean = np.loadtxt('datasets/intcoords/mean.txt', dtype=np.float32)
_int_scale = np.loadtxt('datasets/intcoords/scale.txt', dtype=np.float32)
train_int_norm = (train_int.T - _int_mean) / _int_scale

diff = train_int_norm - dec_out
# Squared Euclidean distance of every training frame to the decoded sample.
msd = np.sum(diff * diff, axis=1)

In [None]:
# Index of the training frame closest to the decoded sample; it seeds the
# GROMACS rebuild below.
minidx = np.argmin(msd)
minidx

In [None]:
# Per-feature normalisation statistics of the internal coordinates, used to map
# decoder output back to physical values.
train_mean, train_scale = [
    np.loadtxt(stats_path, dtype=np.float32)
    for stats_path in ('datasets/intcoords/mean.txt', 'datasets/intcoords/scale.txt')
]

In [None]:
# Undo the normalisation: value = normalised * scale + mean.
dec_out_scaled = np.add(np.multiply(dec_out, train_scale), train_mean)

In [None]:
# Build pdb2gmx: maps the atom indices used by `mol` (built from conf) to
# 1-based atom numbers of the heavy atoms in `gro`. See train.ipynb.
grotr = md.load(gro)
nhs = grotr.topology.select('element != H')

with open(index) as f:
    f.readline()  # skip the "[ group ]" header line
    # Remaining whitespace-separated atom numbers, converted 1-based -> 0-based.
    # Unlike np.fromstring(sep=' '), this raises on unexpected tokens (e.g. a
    # second group header) instead of silently truncating the list.
    ndx = np.array(" ".join(f).split(), dtype=np.int32) - 1

pdb2gmx = nhs[np.argsort(ndx)] + 1


In [None]:
# 1-based GROMACS atom number for each atom index used by `mol`.
pdb2gmx

In [None]:
def _quads_to_gmx(quads):
    """Translate rows of `mol` atom indices into GROMACS atom numbers via pdb2gmx."""
    return [[pdb2gmx[atom] for atom in quad] for quad in quads]


dihed9_gmx = _quads_to_gmx(mol.dihed9)
dihed4_gmx = _quads_to_gmx(mol.dihed4)

In [None]:
# dihed9 quadruplets expressed in GROMACS atom numbers.
dihed9_gmx

In [None]:
import gromacs as gmx

In [None]:
# Write the GROMACS 'Backbone' selection of `gro` to bb.ndx as an index group.
gmx.select(s=gro,on='bb.ndx',select='Backbone')

In [None]:
# Backbone atom numbers (1-based, as written by gmx select) from bb.ndx.
with open('bb.ndx') as bb_file:
    bb_file.readline()  # drop the "[ group ]" header
    bbndx = np.array(bb_file.read().split(), np.int32)

In [None]:
# Closest training frame to the decoded sample; min.pdb is the starting
# structure for the GROMACS rebuild below.
mintr = train_tr.slice(minidx)
mintr.save('min.pdb')

In [None]:
# Visual check of the selected starting frame.
nv.show_mdtraj(mintr)

In [None]:
# Rebuild a GROMACS topology (amber94, TIP3P) for the starting frame.
# NOTE(review): pdb2gmx may add hydrogens and renumber atoms, so numbers in
# min.gro/min.top can differ from `gro`, which bbndx and the pdb2gmx array were
# built from — confirm before using them as restraint indices. pdb2gmx also
# writes its own posre.itp, which a later cell overwrites.
gmx.pdb2gmx(f='min.pdb',o='min.gro',p='min.top',water='tip3p',ff='amber94')

In [None]:
# Centre the molecule in a dodecahedral box with mdbox nm between solute and box edge.
mdbox=2.0
gmx.editconf(f='min.gro',o="min-box.gro",c=True,d=str(mdbox),bt="dodecahedron")

In [None]:
# First backbone quadruplet (GROMACS atom numbers).
bbndx[0:4]

In [None]:
# For every consecutive backbone quadruplet bbndx[i:i+4], find the row of
# dihed9_gmx with exactly those GROMACS atom numbers in the same order.
# XXX: backbone dihedrals seem to be all in dih9 and none in dih4
_dihed9_gmx_arr = np.asarray(dihed9_gmx)  # convert once instead of per quadruplet
dih9_ndx = []
for i in range(len(bbndx) - 3):
    quad = bbndx[i:i+4]
    hits = np.where(np.all(_dihed9_gmx_arr == quad, axis=1))[0]
    if hits.size == 0:
        raise LookupError(
            f"backbone dihedral {quad.tolist()} (bb.ndx positions {i}..{i+3}) not found in mol.dihed9"
        )
    dih9_ndx.append(hits[0])

In [None]:
# Expect one matched dihed9 index per backbone quadruplet: len(bbndx) - 3.
len(dih9_ndx), bbndx.shape

In [None]:
# Start of the dihed9 block in the feature vector. Assumes the layout
# bonds | angles | dihed4 (2 values each) | dihed9 (2 values each) | distances.
n_bonds = mol.bonds.shape[0]
n_angles = mol.angles.shape[0]
n_dihed4 = mol.dihed4.shape[0]
off = n_bonds + n_angles + 2 * n_dihed4
size = mol.dihed9.shape[0]

In [None]:
# The 2*size dihed9 entries of the un-normalised decoder output.
dih9_end = off + 2 * size
dec_dih9_sc = dec_out_scaled[off:dih9_end]
dec_dih9_sc.shape

In [None]:
# dihed9 angles in degrees from the two halves of dec_dih9_sc.
# Assumes the first half holds sines and the second cosines (arctan2(y, x))
# — TODO confirm against asmsa's dihedral encoding.
sin_half = dec_dih9_sc[:size]
cos_half = dec_dih9_sc[size:]
dec_dih9 = np.arctan2(sin_half, cos_half) / np.pi * 180.

In [None]:
# Decoded angle (degrees) for each backbone quadruplet, in bb.ndx order.
bb_angles = list(dec_dih9[dih9_ndx])

In [None]:
# Backbone dihedral restraints towards the decoded angles, written as
# posre.itp (replacing the file pdb2gmx produced). Columns per line:
# four atom numbers, type 1, target angle, tolerance 0, force constant 5000.
with open('posre.itp', 'w') as itp:
    itp.write('[ dihedral_restraints ]\n')
    for i in range(len(bbndx) - 3):
        atoms = '  '.join(str(atom) for atom in bbndx[i:i+4])
        itp.write(atoms + f' 1 {bb_angles[i]} 0 5000\n')


In [None]:
# First minimisation input: steepest descent, writing coordinates every step
# (nstxout = 1) so the relaxation can be inspected from min.trr later.
# `define = -DPOSRES` presumably activates the POSRES include in min.top, i.e.
# posre.itp, which now holds the backbone [ dihedral_restraints ]
# (TODO confirm the #ifdef POSRES include in min.top).
with open('min.mdp','w') as m:
    m.write('''
integrator  = steep         ; Algorithm (steep = steepest descent minimization)
emtol       = 1000.0        ; Stop minimization when the maximum force < 1000.0 kJ/mol/nm
emstep      = 0.01          ; Minimization step size
nsteps      = 50000         ; Maximum number of (minimization) steps to perform

nstxout                 = 1         
nstvout                 = 0         
nstfout                 = 0         
nstlog                  = 5
nstxout-compressed      = 1

; Parameters describing how to find the neighbors of each atom and how to calculate the interactions
nstlist         = 1         ; Frequency to update the neighbor list and long range forces
cutoff-scheme   = Verlet    ; Buffered neighbor searching
ns_type         = grid      ; Method to determine neighbor list (simple, grid)
coulombtype     = Cut-off   ; Treatment of long range electrostatic interactions
rcoulomb        = 1.0       ; Short-range electrostatic cut-off
rvdw            = 1.0       ; Short-range Van der Waals cut-off
pbc             = xyz       ; Periodic Boundary Conditions in all 3 dimensions

disre           = Simple
define                  = -DPOSRES 
''')

In [None]:
# Preprocess the first (dihedral-restrained) minimisation into min.tpr.
gmx.grompp(f='min.mdp',c='min-box.gro',p='min.top',o='min.tpr')

In [None]:
# Visual check of the boxed starting structure.
nv.show_file('min-box.gro')

In [None]:
# Run the first minimisation; outputs use the min.* prefix (min.trr, min.log, min.edr, ...).
gmx.mdrun(deffnm='min')

In [None]:
# Last 30 lines of the minimisation log (convergence summary).
!tail -30 min.log

In [None]:
# Animate the first 100 frames of the first minimisation trajectory
# (min.trr is written every step, nstxout = 1) as licorice.
tr = md.load('min.trr',top='min-box.gro')[:100]
v=nv.show_mdtraj(tr)
#v.clear()
v.add_representation('licorice')
v

In [None]:
# Boxed starting structure again, for side-by-side comparison with the trajectory.
nv.show_file('min-box.gro')

In [None]:
import shutil

# Topology for the second minimisation: a copy of min.top plus distance
# restraints that hold every sparse atom pair near its decoded distance.
shutil.copyfile('min.top', 'restrained.top')

# Distances are the last block of the feature vector, after bonds, angles and
# the two-valued dihed4 / dihed9 entries. Use a dedicated name so the dihed9
# offset `off` computed earlier is not clobbered; re-running those cells after
# this one would otherwise slice the wrong features.
dist_off = (mol.bonds.shape[0] + mol.angles.shape[0]
            + 2 * mol.dihed4.shape[0] + 2 * mol.dihed9.shape[0])

dec_dist = dec_out_scaled[dist_off:]
if len(dec_dist) != len(sparse_dists.bonds):
    raise ValueError(
        f"decoder output has {len(dec_dist)} distance features but "
        f"sparse_dists defines {len(sparse_dists.bonds)} pairs; feature layout mismatch"
    )

# NOTE(review): these lines are appended at the very end of the topology, i.e.
# after [ system ] / [ molecules ] when pdb2gmx writes the protein inline;
# grompp expects restraint directives inside the moleculetype — verify.
with open('restrained.top','a') as t:
    t.write('''
[ distance_restraints ]
''')
    # Columns: ai aj type label type' low up1 up2 fac; window is +-1% of the
    # decoded distance.
    for i,d in enumerate(sparse_dists.bonds):
        t.write(f'{pdb2gmx[d[0]]} {pdb2gmx[d[1]]} 1 {i} 2 {dec_dist[i]*.99} {dec_dist[i]*1.01} 42.0 10.0\n')

In [None]:
# Second minimisation input (overwrites min.mdp): same settings as the first,
# but no coordinate output and no `define = -DPOSRES`, so (presumably) the
# posre.itp dihedral restraints are not included; the distance restraints in
# restrained.top drive this run.
# NOTE(review): this duplicates the previous mdp cell almost verbatim; a helper
# that renders the mdp from a dict of overrides would avoid drift.
with open('min.mdp','w') as m:
    m.write('''
integrator  = steep         ; Algorithm (steep = steepest descent minimization)
emtol       = 1000.0        ; Stop minimization when the maximum force < 1000.0 kJ/mol/nm
emstep      = 0.01          ; Minimization step size
nsteps      = 50000         ; Maximum number of (minimization) steps to perform

nstxout                 = 0         
nstvout                 = 0         
nstfout                 = 0         
nstlog                  = 5
nstxout-compressed      = 0

; Parameters describing how to find the neighbors of each atom and how to calculate the interactions
nstlist         = 1         ; Frequency to update the neighbor list and long range forces
cutoff-scheme   = Verlet    ; Buffered neighbor searching
ns_type         = grid      ; Method to determine neighbor list (simple, grid)
coulombtype     = Cut-off   ; Treatment of long range electrostatic interactions
rcoulomb        = 1.0       ; Short-range electrostatic cut-off
rvdw            = 1.0       ; Short-range Van der Waals cut-off
pbc             = xyz       ; Periodic Boundary Conditions in all 3 dimensions

disre           = Simple
''')

In [None]:
# Preprocess the distance-restrained minimisation (overwrites min.tpr).
gmx.grompp(f="min.mdp",c="min-box.gro",p='restrained.top',o="min.tpr")

In [None]:
# Run the distance-restrained minimisation. mdrun is only reachable through
# the gromacs wrapper module; a bare `mdrun` is undefined (NameError).
gmx.mdrun(deffnm="min")

In [None]:
# Compare the restrained minimisation result (red, component 0) with the
# closest training frame (green, component 1), heavy atoms only, superposed.
# The heavy-atom selection gets its own name: `nhs` holds the heavy atoms of
# `gro` used to build pdb2gmx, and overwriting it would corrupt a re-run of
# that mapping.
solm = md.load('min.gro')
min_heavy = solm.topology.select('element != H')
solm = solm.atom_slice(min_heavy)
v = nv.show_mdtraj(solm)

refm = md.load_pdb(conf)
refm.xyz = train_g[:,:,minidx]
refm.superpose(solm)

v.add_component(refm)
v.clear(component=0)
v.clear(component=1)

v.add_representation('licorice',color='green',component=1)
v.add_representation('licorice',color='red',component=0)

v

In [None]:
# Extract the Potential and Dis.-Rest. energy terms of the minimisation into min.xvg.
gmx.energy(f='min.edr',input=['Potential','Dis.-Rest.'],o='min.xvg')

In [None]:
import matplotlib.pyplot as plt

# min.xvg: column 0 is time/step, then the selected terms in .edr order.
# NOTE(review): assumes Dis.-Rest. is column 1 and Potential column 2 — confirm
# from the '@ s0/s1 legend' lines of min.xvg.
energ = np.loadtxt('min.xvg',comments=['#','@'])
disres = energ[:,1]
potential = energ[:,2]

fig, ax = plt.subplots()
ax.plot(disres, label='Dis.-Rest.')
ax.plot(potential, label='Potential')
ax.plot(potential - disres, label='Net')
# symlog instead of log: potential energies are typically negative and would
# be dropped entirely from a plain log axis.
ax.set_yscale('symlog')
ax.set(xlabel='energy frame', ylabel='energy (kJ/mol)', title='Restrained minimisation energies')
ax.legend()
plt.show()
