In [None]:
import os
import shutil

import gromacs as gmx
import mdtraj as md
import nglview as nv
import numpy as np
from gromacs import formats as gmf

In [None]:
# Show the working directory that all relative paths below are resolved against.
# os.system("pwd") only returned the exit status; its stdout typically goes to
# the Jupyter server's terminal rather than to this notebook.
working_dir = os.getcwd()
working_dir

In [None]:
# Input locations. `data_dir` replaces the old name `dir`, which shadowed the
# built-in dir().
data_dir = '../../My_AE/trpcage/ae/'
ftrain = os.path.join(data_dir, 'trpcage_ds.npy')
pdb = os.path.join(data_dir, 'trpcage_npt400.pdb')
dec = "../../My_AE/trpcage/MC/ae_noise_1/dec_out_LS_matrix_pos_32x32.npy"
xtc = os.path.join(data_dir, 'trpcage_ds_fit.xtc')

train = np.load(ftrain)

# Later cells treat each row of `train` as one structure of 272 atoms x 3
# coordinates (train - dec_out, summed over axis 1); fail early otherwise.
assert train.ndim == 2 and train.shape[1] == 272 * 3, (
    f"unexpected training-set shape {train.shape}"
)
train_view = np.reshape(train, (-1, 272, 3))

train.shape

In [None]:
# Visualise every training structure on top of the reference PDB topology.
tr = md.load(pdb)

# Replace the single PDB frame with all training structures so the viewer can
# step through them. Only xyz is replaced.
# NOTE(review): tr.time / unit-cell data still describe the one-frame PDB load,
# so their length no longer matches the number of frames.
tr.xyz = train_view

v=nv.show_mdtraj(tr)
v.add_representation("licorice")
v

In [None]:
# Decoder output sampled on the latent-space grid (the file name says 32x32).
dec_out_all = np.load(dec)
np.shape(dec_out_all)

In [None]:
# Flatten the grid: one row per grid point, 272 atoms * 3 coordinates per row.
dec_out_all_r = dec_out_all.reshape(32 * 32, 272 * 3)
dec_out_all_r.shape

In [None]:
def Unstandardize(x, scale=7, offset=-1):
    """Map standardized values back to the original coordinate scale.

    Computes ``x * scale + offset``. The defaults (scale=7, offset=-1)
    reproduce the original hard-coded transform ``x*7 - 1``. Presumably they
    invert the standardization used when training the autoencoder; confirm
    against the training code.

    Parameters
    ----------
    x : array_like
        Standardized values; converted with ``np.array``.
    scale : float, optional
        Multiplicative factor of the inverse transform (default 7).
    offset : float, optional
        Additive offset of the inverse transform (default -1).

    Returns
    -------
    numpy.ndarray
        Unstandardized values, same shape as ``x``.
    """
    x = np.array(x)
    return x * scale + offset

In [None]:
# Derive the unstandardized matrix from the raw decoder output instead of
# overwriting dec_out_all_r in place, so re-running this cell does not apply
# the transform twice.
# NOTE(review): dec_out_all_r is not used by any later cell; dec_out and tr3
# below are built from the raw dec_out_all. Confirm which one is intended.
dec_out_all_r = Unstandardize(np.reshape(dec_out_all, (32*32, 816)))

In [None]:
# Pick one decoded structure: grid point (3, 27), i.e. flat index 123 on the 32x32 grid.
dec_out = dec_out_all[3][27]

In [None]:
# Squared distance (summed over all coordinates, not averaged) between the
# chosen decoded structure and every training row; dec_out broadcasts across
# the rows of train.
diff = train - dec_out
msd = (diff ** 2).sum(axis=1)

In [None]:
# Index of the training structure closest to the decoded one.
imin = msd.argmin()
imin

In [None]:
# Display the selected decoded structure on the reference topology.
tr2 = md.load(pdb)
print(tr2.xyz.shape)  # shape of the reference PDB before its coordinates are replaced
tr2.xyz = dec_out.reshape(-1, 272, 3)
v2 = nv.show_mdtraj(tr2)
v2.add_representation("licorice")
v2

In [None]:
# Move the training-set viewer `v` to the training frame closest to the decoded structure.
# NOTE(review): `v` is rebound by later cells, so this only targets the
# training-set viewer when the cells run in order.
v.frame=int(imin)

In [None]:
# Package every decoded structure as a trajectory and write generated.xtc.
# The reference xtc only supplies per-frame time stamps and box vectors, so it
# must contain at least as many frames as there are generated structures.
# Slicing the Trajectory keeps xyz, time and unit-cell data the same length.
gen_xyz = np.reshape(dec_out_all, (-1, 272, 3))
tr3 = md.load(xtc, top=pdb)
assert tr3.n_frames >= gen_xyz.shape[0], (
    f"{xtc} has {tr3.n_frames} frames but {gen_xyz.shape[0]} structures were generated"
)
tr3 = tr3[:gen_xyz.shape[0]]
tr3.xyz = gen_xyz

tr3.save_xtc('generated.xtc')

### Identify backbone atoms

In [None]:
# Copy the reference structure under a fixed name for the GROMACS steps.
# shutil is used instead of `!cp {pdb} ...` so that paths with spaces or shell
# metacharacters work and no POSIX shell is needed.
shutil.copyfile(pdb, 'reference.pdb')

In [None]:
# Write the atom indices of the 'Backbone' selection of reference.pdb to bb.ndx.
gmx.select(s='reference.pdb',on='bb.ndx',select='Backbone')

### Compute backbone dihedrals of the generated structures

In [None]:
# Parse bb.ndx: skip the group-header line, then collect every remaining
# whitespace-separated token as a backbone atom index.
with open('bb.ndx') as ndx_file:
    ndx_file.readline()
    ndx = np.array(ndx_file.read().split(), np.int32)

In [None]:
# Every run of four consecutive backbone atoms defines one dihedral; write all
# of them as the index group 'dihedrals'.
with open('angle.ndx', 'w') as a:
    a.write('[ dihedrals ]\n')
    a.writelines(
        '  '.join(map(str, ndx[start:start + 4])) + '\n'
        for start in range(len(ndx) - 3)
    )

In [None]:
# Measure every backbone dihedral in angle.ndx for each generated frame; the
# per-dihedral values (-oall) are read back from dihedrals.xvg in the next cell.
gmx.gangle(f='generated.xtc',n='angle.ndx',g1='dihedral',group1='dihedrals',oall='dihedrals.xvg')

In [None]:
# Load the dihedrals, dropping array row 0 (presumably the time column -- confirm).
# Later cells index this as dihs[i, n]: dihedral i of generated frame n.
dihs = gmf.XVG('dihedrals.xvg').array[1:]

### Generate restrained topology and run minimization

In [None]:
# Build a GROMACS topology (amber99 force field, tip3p water) for the reference structure.
# NOTE(review): later cells overwrite posre.itp with dihedral restraints and set
# `define = -DPOSRES` in min.mdp. This presumably relies on pdb2gmx writing an
# `#ifdef POSRES` include of posre.itp into reference.top -- confirm.
gmx.pdb2gmx(f='reference.pdb',o='reference.gro',p='reference.top',water='tip3p',ff='amber99')

In [None]:
# Steepest-descent minimization settings. The dihedral restraints are written
# to posre.itp by the minimization cells and enabled through `define = -DPOSRES`
# (presumably via the POSRES include in reference.top).
# The emtol comment previously said "< 1000.0" although emtol is 100.0.
with open('min.mdp','w') as m:
    m.write('''
integrator  = steep         ; Algorithm (steep = steepest descent minimization)
emtol       = 100.0        ; Stop minimization when the maximum force < 100.0 kJ/mol/nm
emstep      = 0.01          ; Minimization step size
nsteps      = 50000         ; Maximum number of (minimization) steps to perform

nstxout                 = 1         
nstvout                 = 0         
nstfout                 = 0         
nstlog                  = 5
nstxout-compressed      = 1

; Parameters describing how to find the neighbors of each atom and how to calculate the interactions
nstlist         = 1         ; Frequency to update the neighbor list and long range forces
cutoff-scheme   = Verlet    ; Buffered neighbor searching
ns_type         = grid      ; Method to determine neighbor list (simple, grid)
coulombtype     = Cut-off   ; Treatment of long range electrostatic interactions
rcoulomb        = 1.0       ; Short-range electrostatic cut-off
rvdw            = 1.0       ; Short-range Van der Waals cut-off
pbc             = xyz       ; Periodic Boundary Conditions in all 3 dimensions

disre           = Simple
define                  = -DPOSRES 
''')

In [None]:
# Flat index 123 -> (row, column) on the 32x32 latent grid.
np.unravel_index(123, (32, 32))

In [None]:
# Grid point (3, 27) -> flat index on the 32x32 latent grid.
np.ravel_multi_index((3, 27), (32, 32))

In [None]:
%%time 
for n in range(1,2):
    n = np.ravel_multi_index([1,1], dims=((32,32)))
    with open('posre.itp','w') as p:
        p.write('[ dihedral_restraints ]\n')
        for i in range(0,len(ndx)-3):
            p.write('  '.join(map(str,ndx[i:i+4])))
            p.write(f' 1 {dihs[i,n]} 0 5000\n')
    
    gtr = tr3[n]
    gtr.save_gro('min.gro')
    
    gmx.grompp(f='min.mdp',c='min.gro',p='reference.top',o='min.tpr',maxwarn=100000)
    
    gmx.mdrun(deffnm='min')

In [None]:
# Show min.gro after the trial run (presumably overwritten by mdrun's -deffnm min output, i.e. the minimized structure).
v=nv.show_file('min.gro')
v.add_representation('licorice')
v

In [None]:
# Show the same generated structure before minimization, for comparison.
v2=nv.show_mdtraj(tr3[n])
v2.add_representation('licorice')
v2

In [None]:
# Backbone dihedrals of the minimized structure, to compare with the restraint targets dihs[:, n].
gmx.gangle(f='min.gro',n='angle.ndx',g1='dihedral',group1='dihedrals',oall='min.xvg')

In [None]:
# Dump the raw dihedral output of the minimized structure.
!cat min.xvg

In [None]:
# Restraint target dihedrals used for generated frame n.
dihs.T[n]

In [None]:
# Extract the Potential and dihedral-restraint energy terms of the minimization into emin.xvg.
# NOTE(review): the next cells read column 1 as Dih.-Rest. and column 2 as
# Potential, presumably because gmx energy orders columns by term index rather
# than by the order of `input` -- confirm.
gmx.energy(f='min.edr',input=['Potential','Dih.-Rest.'],o='emin.xvg')

In [None]:
import matplotlib.pyplot as plt

# Energy terms per minimization frame: column 1 = Dih.-Rest., column 2 = Potential.
# ndmin=2 keeps a single-frame file two-dimensional so row indexing always works.
energ = np.loadtxt('emin.xvg',comments=['#','@'], ndmin=2)

skip = 10  # leading frames left out of the plot
fig, ax = plt.subplots()
ax.plot(energ[skip:, 1], label='Dih.-Rest.')
ax.plot(energ[skip:, 2], label='Potential')
ax.plot(energ[skip:, 2] - energ[skip:, 1], label='Net')
# symlog instead of log: Potential and Net can be negative, and a log axis
# silently drops every non-positive point.
ax.set_yscale('symlog')
ax.set_xlabel(f'energy frame (first {skip} skipped)')
ax.set_ylabel('energy (kJ/mol)')
ax.set_title('Restrained minimization')
ax.legend()
plt.show()


In [None]:
# Energies of the final minimization frame (kJ/mol): Potential, Dih.-Rest.,
# and their difference, i.e. the potential without the restraint contribution.
last_row = energ[-1]
pot_E = last_row[2] - last_row[1]
(last_row[2], last_row[1], pot_E)

In [None]:
# Load the minimization trajectory and prepend a copy of its final frame
# (presumably so the viewer opens on the minimized structure).
# Indexing the Trajectory, instead of assigning a longer xyz array, keeps the
# time and unit-cell data the same length as xyz.
mtr = md.load('min.trr',top=pdb)
mtr = mtr[np.r_[mtr.n_frames - 1, 0:mtr.n_frames]]
v = nv.show_mdtraj(mtr)
v.add_representation('licorice')
v

## Minimize all generated structures

In [None]:
# Result grids for the full scan: minimized coordinates (272 atoms * 3 per grid
# point) and the restraint-free potential energy of each grid point.
matrix_res = 32
LS_matrix_pos_min = np.zeros((matrix_res,) * 2 + (272 * 3,))
LS_matrix_E_min = np.zeros((matrix_res,) * 2)

In [None]:
%%time
for n in range(0, matrix_res*matrix_res):

    !rm *\#
    
    j,k = np.unravel_index(n, shape=((32,32)))
    print(f"\n\n***   ***   {n}   ***   ***\n\n")
    with open('posre.itp','w') as p:
        p.write('[ dihedral_restraints ]\n')
        for i in range(0,len(ndx)-3):
            p.write('  '.join(map(str,ndx[i:i+4])))
            p.write(f' 1 {dihs[i,n]} 0 5000\n')
    
    gtr = tr3[n]
    gtr.save_gro('min.gro')
    
    gmx.grompp(f='min.mdp',c='min.gro',p='reference.top',o='min.tpr',maxwarn=1)

    try:
        gmx.mdrun(deffnm='min')
    
        LS_matrix_pos_min[j,k,:] = np.reshape(md.load("min.gro", top=pdb).xyz, (272*3))
    
        gmx.energy(f='min.edr',input=['Potential','Dih.-Rest.'],o='emin.xvg')
        energ = np.loadtxt('emin.xvg',comments=['#','@'])
        try:
            LS_matrix_E_min[j,k] = energ[-1,2] - energ[-1,1]
        except IndexError:
            print(energ)
            LS_matrix_E_min[j,k] = energ[2] - energ[1]
            
    except:
        LS_matrix_pos_min[j,k,:] = np.reshape(md.load("min.gro", top=pdb).xyz, (272*3))
        LS_matrix_E_min[j,k] = 1e12

np.save("LS_matrix_E_min.npy", LS_matrix_E_min*1000)
np.save("LS_matrix_pos_min.npy", LS_matrix_pos_min)

In [None]:
# Restraint-free potential energy over the latent-space grid.
# LS_matrix_E_min.T drawn with origin='lower' gives the same picture as
# np.rot90(LS_matrix_E_min, axes=(0,1)) with the default origin. The y tick
# labels now show the true second index k instead of running 0..31 from the
# top (which was k = 31..0).
fig, ax = plt.subplots()
im = ax.imshow(LS_matrix_E_min.T, origin='lower', vmax=2000)
ax.set_xlabel('grid index j (axis 0 of LS_matrix_E_min)')
ax.set_ylabel('grid index k (axis 1 of LS_matrix_E_min)')
ax.set_title('Minimized energy over the latent grid')
# vmax caps the colour scale, presumably so the 1e12 failure sentinel does not wash out the map.
fig.colorbar(im, ax=ax, label='Potential - Dih.-Rest. (kJ/mol)')
plt.show()