In [None]:
import torch
import torch.nn as nn
from openmmtorch import TorchForce
import os

import yaml
from pathlib import Path


import numpy as np
from openmm.app import *
from openmm import *
from openmm.unit import *

from cgmap.mapping import Mapper
from maputils import EquiValReporter

import training_modules as tm

from torch.optim import Adam
from torch.nn import MSELoss, L1Loss
import matplotlib.pyplot as plt
from torch.utils.data import Dataset , DataLoader
from training_utils import CGDataset , TrainSystem


In [None]:
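# TODO: move this disabled dataset-building cell (maps the .trr trajectories to beads and saves dataset.NoPBCVoidALLFrames.npz) to a dedicated notebook/script.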


# # Assign directory
# input_directory = 'StartingStructures'

# output_directory = '../datasets/NoPBCChiignolinVoid'

# trj_list = []

# for name in os.listdir(output_directory):
#     if name.endswith(".trr"):
#         trj_list.append(output_directory + '/' + name)

# print(trj_list)

# for sim_path in trj_list:

#     config = {
#         'mapping': 'martini3',
#         'input': 'StartingStructures/chignolin_frames0.pdb',
#         'inputtraj': [sim_path],
#         'selection': 'protein',
#         'output': 'pro.gro',
#         'outputtraj': 'xtc',
#         #'trajslice': slice(100,10000)
#     }

#     # Map atoms to beads (only protein for now)
#     mapping = Mapper(config)
#     mapping.map()
#     try:
#         dataset["bead_forces"] = np.append(dataset["bead_forces"], mapping.dataset['bead_forces'],axis=0)
#         dataset["bead_pos"] = np.append(dataset["bead_pos"], mapping.dataset['bead_pos'],axis=0)

#     except:
#         dataset = mapping.dataset

#     print(dataset['bead_forces'].shape)
#     # force_set = forces[:,162]

# current_dir = os.getcwd()

# output_file = os.path.join(current_dir, "dataset.NoPBCVoidALLFrames.npz")
# np.savez(output_file, **dataset)
# print(f"{output_file} successfully saved!")

In [None]:
# Load the pre-built coarse-grained dataset into a plain dict of arrays.
# The context manager closes the underlying .npz archive once dict() has read
# every entry; a bare np.load() would leave the file handle open.
# SECURITY: allow_pickle=True unpickles object arrays, which can execute
# arbitrary code -- only load archives produced by a trusted pipeline.
with np.load('dataset.NoPBCVoidALLFrames.npz', allow_pickle=True) as npz_archive:
    dataset = dict(npz_archive)

In [None]:
# Wrap the loaded dataset in an EquiValReporter (imported from maputils).
equvalrep = EquiValReporter(dataset=dataset)

# Run the bond, angle, improper-dihedral and bead-charge mappers in turn.
# NOTE(review): the bond config is read from "config/" while the angle and
# dihedral configs come from "test_conf/" -- confirm the mixed directories are intended.
equvalrep.bondMapper(config_file_path="config/bond_config.yaml")
equvalrep.angleMapper(conf_angles_path="test_conf/config.angles.yaml")
equvalrep.improperDihedralMapper(conf_angles_path="test_conf/config.dihedrals.yaml")
equvalrep.beadChargeMapper()

In [None]:
# Ask the reporter to emit its report, passing 'test_conf/' as the report path.
# NOTE(review): presumably this writes the mapped equilibrium values to files in that directory -- confirm in maputils.
equvalrep.reportEquiVals(reportPath='test_conf/')

In [None]:
# Rebind `dataset` to the dataset returned by the reporter; it replaces the dict loaded from the .npz file.
dataset = equvalrep.getDataset()

# Disabled experiment: print the sums of bead_forces over the last two axes,
# subtract from every entry the sum over axis 1 divided by len(bead_types),
# then print the sums again.
# print(dataset['bead_forces'][:].sum(axis=-1).sum(axis=-1))

# dataset['bead_forces'] = (dataset['bead_forces'][:,:,:] - (dataset['bead_forces'][:,:,:].sum(axis=1)/len(dataset['bead_types']))[:,None,:])

# print(dataset['bead_forces'][:].sum(axis=-1).sum(axis=-1))


In [None]:
# Fetch the bond, angle, improper-dihedral and bead-charge tables from the
# reporter; these four objects are later handed to TrainSystem.
conf_bonds: dict = equvalrep.getBonds()
conf_angles: dict = equvalrep.getAngles()
conf_dihedrals: dict = equvalrep.getImproperDihs()
conf_bead_charges: dict = equvalrep.getBeadCharges()

In [None]:
# Sanity check: display the shape of the bead-force array as the cell output.
dataset['bead_forces'].shape

In [None]:
# print(torch.cuda.current_device())
# torch.cuda.set_device(0)

# Build the training wrapper from the dataset and the four config tables.
# NOTE: `system` is the TrainSystem here; a later cell rebinds the same name to
# an OpenMM System(), so the plotting cells must be run before that one.
# (`device_index=1` was previously tried as an extra constructor argument.)
system = TrainSystem(
    dataset,
    conf_bonds,
    conf_angles,
    conf_dihedrals,
    conf_bead_charges,
)

# Run training and keep the returned model for the TorchForce cell further down.
model = system.initiateTraining(
    dataset=dataset,
    train_steps=800,
    batch_size=128,
    patience=30,
    model_name='DihandAnglesTestTruncated',
)

In [None]:
# Show the loss plot of the TrainSystem; truncate=10 is forwarded as-is (its meaning is defined in training_utils).
system.plotLosses(truncate=10)

In [None]:
# Force-magnitude matching plot for bead index 0, with to_frame=30.
# NOTE(review): presumably compares predicted and reference force magnitudes up to that frame -- confirm in training_utils.
system.plotForceMagnitudeMatching(bead_index=0,to_frame=30)

In [None]:
# Force-matching plot for frame 10 (the method name is spelled "Mathing" in the TrainSystem API as called here).
system.plotForceMathingByFrame(frame=10)

In [None]:
# Plot the initial force guess of the TrainSystem (called with its default arguments).
system.plotInitialForceGuess()

In [None]:
# Force-matching plot for bead index 0, with to_frame=78.
system.plotForceMathing(bead_index=0,to_frame=78)

In [None]:
# Force-matching plot for bead index 0, with to_frame=14.
# NOTE(review): "Val" presumably refers to the validation split -- confirm in training_utils.
system.plotValForceMathing(bead_index=0,to_frame=14)

In [None]:
# "ABS" force-matching plot for frame 77.
# NOTE(review): presumably plots absolute force values -- confirm in training_utils.
system.plotABSForceMathing(frame=77)

In [None]:
# Display the TrainSystem's state dict as the cell output.
# NOTE(review): this prints every entry in full; consider showing only the keys if the output is large.
system.state_dict()

In [None]:

# Coarse-grained starting structure for the OpenMM simulation.
# The path used to be hard-coded to one user's home directory (wrapped in a
# no-op single-argument os.path.join). It is still the default, but it can now
# be overridden with the CG_PDB_FILE environment variable on other machines.
# Alternative structures previously tried here:
#   /home/enere@usi.ch/FMartIP/original_CG_A2A.pdb
#   ChignCG_unfolded.pdb
#   original_CG_A2A.pdb
#   chig_CG/original_CG_a2a_Water.pdb
#   /home/enere@usi.ch/FMartIP/chig_CG/original_CG_a2a_4.pdb
pdb_file = os.environ.get("CG_PDB_FILE") or "/home/enere@usi.ch/FMartIP/ChigCG.pdb"
pdb = PDBFile(pdb_file)  # OpenMM loader; its .topology and .positions are used by later cells

In [None]:
# Give topology atoms a synthetic OpenMM Element whose atomic number is the
# bead type id taken from the dataset.
#
# Fix: Element() registers each symbol in a global table and raises
# ValueError ("Duplicate element symbol ...") when the symbol already exists,
# so the original cell could only be executed once per kernel. Symbols that
# are already registered are now looked up and reused, which makes the cell
# safe to re-run. Restart the kernel if the bead-name/bead-type mapping changes.
#
# NOTE(review): zip() pairs atoms with np.unique(bead_names), which is sorted
# and de-duplicated. Atoms beyond the number of unique names keep their
# original element, and the pairing follows sorted-name order rather than
# topology order -- confirm this is intended.
# NOTE(review): mass is hard-coded to 0 (the dataset lookup is commented out).
# OpenMM integrators do not move zero-mass particles -- confirm before relying
# on the dynamics run further down.
index = 300  # code point of the first synthetic element symbol
for atom, bead in zip(pdb.topology.atoms(), np.unique(dataset['bead_names'])):
    # Bead type ids of every entry carrying this bead name; only the first is used.
    i = dataset['bead_types'][np.where(dataset['bead_names'] == bead)]
    print(i[0])
    mass = 0 #dataset['bead_mass_dict'][bead]
    print(mass*amu)
    print(bead)
    symbol = chr(index)
    try:
        # Reuse the element registered by a previous run of this cell.
        atom.element = Element.getBySymbol(symbol)
    except KeyError:
        atom.element = Element(number = i[0], name = bead, symbol = symbol, mass = mass*amu)
    index += 3
    print(Element.getByAtomicNumber(i[0]))

In [None]:
# Build a bare OpenMM System with one particle per topology atom; each
# particle takes the mass of the element currently attached to that atom.
# NOTE: this rebinds `system`, which previously referred to the TrainSystem.
system = System()
for topology_atom in pdb.topology.atoms():
    particle_mass = topology_atom.element.mass
    system.addParticle(particle_mass)

# Carry the periodic box over from the topology when it defines one.
boxVectors = pdb.topology.getPeriodicBoxVectors()
has_box = boxVectors is not None
if has_box:
    box_a, box_b, box_c = boxVectors[0], boxVectors[1], boxVectors[2]
    system.setDefaultPeriodicBoxVectors(box_a, box_b, box_c)
print(boxVectors)


In [None]:
# Wrap the trained model in an OpenMM-Torch force.
# NOTE(review): assumes `model` (returned by initiateTraining) is in a form TorchForce accepts -- confirm.
force = TorchForce(model)

# Nose-Hoover integrator constructed with 300 K, 1/ps and a 0.010 ps value
# (presumably temperature, collision frequency and step size -- confirm against the OpenMM signature).
integrator = NoseHooverIntegrator(300*kelvin, 1/picosecond, 0.010*picoseconds)


# Strip any forces already present so that the system carries nothing but the NNP.
while system.getNumForces() > 0:
    system.removeForce(0)
    
# The system must not contain any other forces or constraints
assert system.getNumConstraints() == 0
assert system.getNumForces() == 0

# Add the NNP to the system
system.addForce(force)

# Combine the molecular topology, system and integrator into a Simulation object.
# A Simulation object manages all the processes involved in running a simulation, such as advancing time and writing output.
simulation = Simulation(pdb.topology, system, integrator)
simulation.context.setPositions(pdb.positions)

# Perform a local energy minimization. It is usually a good idea to do this at the start of a simulation, since the coordinates in the PDB file might produce very large forces.
simulation.minimizeEnergy()
print("starting Sim")

# Report every 100 steps: frames to output.pdb, and step/time/energies/temperature to output.dat.
simulation.reporters.append(PDBReporter('output.pdb', 100))
simulation.reporters.append(StateDataReporter('output.dat', 100, step=True, potentialEnergy=True, kineticEnergy=True, temperature=True, time=True, totalEnergy=True))

# Run 50,000 integration steps, then fetch the final state with positions, energy and forces.
simulation.step(50000)
state = simulation.context.getState(getPositions=True, getEnergy=True, getForces=True)
# Convert the final forces (f) and positions (p) to NumPy arrays with one [x, y, z] row per entry.
f = np.array([[a.x,a.y,a.z]for a in state.getForces()])
p = np.array([[a.x,a.y,a.z]for a in state.getPositions()])
# print(state.getForces(), state.getPositions())