In [None]:
from openff.units import unit
from perses.protocols.nonequilibrium_cycling import SimulationUnit, ResultUnit
from gufe import ChemicalSystem, SmallMoleculeComponent, ProteinComponent, SolventComponent
from gufe.mapping.ligandatommapping import LigandAtomMapping


In [None]:
# Load the receptor and the two ligands of the transformation from local files
protein_comp = ProteinComponent.from_pdb_file("Tyk2_protein.pdb")
ligand_a_comp = SmallMoleculeComponent.from_sdf_file('ejm_31.sdf')
ligand_b_comp = SmallMoleculeComponent.from_sdf_file('jmc_30.sdf')


def _nacl_solvent():
    """Return a fresh SolventComponent with 0.15 M Na/Cl ions."""
    return SolventComponent(
        ion_concentration=0.15*unit.molar,
        positive_ion="Na",
        negative_ion="Cl",
    )


# One solvent component per end state (identical parameters)
solvent_a_comp = _nacl_solvent()
solvent_b_comp = _nacl_solvent()

# Component dictionaries for each end state, keyed by component role.
# Complex leg: protein + ligand + solvent
state_a_complex = dict(protein=protein_comp, ligand=ligand_a_comp, solvent=solvent_a_comp)
state_b_complex = dict(protein=protein_comp, ligand=ligand_b_comp, solvent=solvent_b_comp)
# Solvent leg: ligand + solvent
state_a_solvent = dict(ligand=ligand_a_comp, solvent=solvent_a_comp)
state_b_solvent = dict(ligand=ligand_b_comp, solvent=solvent_b_comp)
# Vacuum leg: ligand only
state_a_vacuum = dict(ligand=ligand_a_comp)
state_b_vacuum = dict(ligand=ligand_b_comp)

# Wrap each pair of component dictionaries in gufe ChemicalSystems
system_a_complex, system_b_complex = (
    ChemicalSystem(components=state) for state in (state_a_complex, state_b_complex)
)
system_a_solvent, system_b_solvent = (
    ChemicalSystem(components=state) for state in (state_a_solvent, state_b_solvent)
)
system_a_vacuum, system_b_vacuum = (
    ChemicalSystem(components=state) for state in (state_a_vacuum, state_b_vacuum)
)



In [None]:
# Build gufe mapping object
# Atom correspondence manually extracted from perses AtomMapper
#    NOTE: perses has a different sense for the mapping so this is
#    componentB_to_componentA in gufe terms
# NOTE(review): despite the note above, the dict is passed below as
#    `componentA_to_componentB`; confirm which direction the perses
#    protocol actually expects before reusing this mapping elsewhere.
#
# Layout of the correspondence:
#   atoms 0-18  -> themselves
#   atom  20    -> 23
#   atoms 22-29 -> shifted by +4 (26-33)
mapping_dict = {idx: idx for idx in range(19)}
mapping_dict[20] = 23
mapping_dict.update({idx: idx + 4 for idx in range(22, 30)})

mapping = LigandAtomMapping(
    componentA=system_a_complex.components['ligand'],
    componentB=system_b_complex.components['ligand'],
    componentA_to_componentB=mapping_dict,
)

In [None]:
from openff.units import unit
# Build Settings gufe object
from gufe.settings.models import (
    Settings, 
    ThermoSettings, 
)
from perses.protocols.settings import NonEqCyclingSettings

# Start from the gufe default settings and override only what this run needs
settings = Settings.get_defaults()
settings.thermo_settings.temperature = 300 * unit.kelvin

# Protocol-specific step counts and save frequencies for the
# non-equilibrium cycling protocol
settings.protocol_settings = NonEqCyclingSettings(
    eq_steps=2500,
    neq_steps=2500,
    traj_save_frequency=250,
    work_save_frequency=25,
)

In [None]:
# Inspect the assembled settings object in its JSON form
settings.json()

In [None]:
# Inspect the same settings object as a Python dict
settings.dict()

In [None]:
# Running the NonEq Cycling Protocol
from perses.protocols.nonequilibrium_cycling import NonEquilibriumCyclingProtocol
from gufe.protocols.protocoldag import execute_DAG
neq_cycling = NonEquilibriumCyclingProtocol(settings)

# Build the protocol DAG for the vacuum leg (ligand A -> ligand B) ...
neq_dag = neq_cycling.create(
    stateA=system_a_vacuum,
    stateB=system_b_vacuum,
    mapping=mapping,
)
# ... and execute it locally
dag_result = execute_DAG(neq_dag)

In [None]:
# Overall success flag for the DAG run (presumably False if any unit failed)
dag_result.ok()

In [None]:
# Protocol units of the DAG run that failed (first one is inspected further below)
dag_result.protocol_unit_failures

In [None]:
# Outputs of the first entry in `protocol_unit_results`
dag_result.protocol_unit_results[0].outputs

In [None]:
# Outputs of the second entry in `protocol_unit_results`
# NOTE(review): raises IndexError if the run produced fewer than two unit results
dag_result.protocol_unit_results[1].outputs

In [None]:
# Same vacuum transformation, but driven by the protocol's own default settings
# instead of the customised `settings` built above
default_settings = NonEquilibriumCyclingProtocol.default_settings()
default_neq_cycling = NonEquilibriumCyclingProtocol(default_settings)

default_dag = default_neq_cycling.create(
    stateA=system_a_vacuum,
    stateB=system_b_vacuum,
    mapping=mapping,
)
default_dag_result = execute_DAG(default_dag)

## Debugging individual protocol units

In [None]:
# Debugging SimulationUnit: run one simulation unit directly, outside the DAG.
# Uses the same vacuum end states and mapping as the DAG runs above
# (the previously referenced `system_a` / `system_b` were never defined).
simulation = SimulationUnit(
    state_a=system_a_vacuum,
    state_b=system_b_vacuum,
    mapping=mapping,
    settings=settings,
)
result = simulation.execute(shared='/tmp/', **simulation.inputs)


In [None]:
# Did the standalone SimulationUnit run succeed?
result.ok()

In [None]:
# print() (rather than bare display) so the traceback's line breaks render;
# presumably only populated when the unit failed
print(result.traceback)

In [None]:
# GatherUnit -- Results: run the result unit for the vacuum phase directly.
# NOTE: this rebinds `result`, replacing the SimulationUnit result from above.
neq_results = ResultUnit(phase="vacuum")
result = neq_results.execute(
    shared='/tmp/',
    **neq_results.inputs,
)

In [None]:
# Outputs of the ResultUnit execution (the most recent binding of `result`)
result.outputs

In [None]:
# First failed unit of the main DAG run.
# NOTE(review): presumably raises IndexError when the run had no failures
failure = dag_result.protocol_unit_failures[0]

In [None]:
# print() so the failed unit's traceback renders with its line breaks
print(failure.traceback)

In [None]:
# Graph of unit results from the main DAG run
result_graph = dag_result.result_graph

In [None]:
# Nodes of the result graph
result_graph.nodes

In [None]:
# Nodes of the DAG result's `graph`, for comparison with `result_graph.nodes`
dag_result.graph.nodes