-
Notifications
You must be signed in to change notification settings - Fork 75
/
multistatesampler.py
1800 lines (1481 loc) · 85.7 KB
/
multistatesampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/local/bin/env python
# ==============================================================================
# MODULE DOCSTRING
# ==============================================================================
"""
MultistateSampler
=================
Base multi-thermodynamic state multistate class
COPYRIGHT
Current version by Andrea Rizzi <andrea.rizzi@choderalab.org>, Levi N. Naden <levi.naden@choderalab.org> and
John D. Chodera <john.chodera@choderalab.org> while at Memorial Sloan Kettering Cancer Center.
Original version by John D. Chodera <jchodera@gmail.com> while at the University of
California Berkeley.
LICENSE
This code is licensed under the latest available version of the MIT License.
"""
# ==============================================================================
# GLOBAL IMPORTS
# ==============================================================================
import os
import copy
import time
import typing
import inspect
import logging
import datetime
import numpy as np
try:
import openmm
from openmm import unit
except ImportError: # OpenMM < 7.6
from simtk import unit, openmm
from openmmtools import multistate, utils, states, mcmc, cache
import mpiplus
from openmmtools.multistate.utils import SimulationNaNError
from pymbar.utils import ParameterError
from openmmtools.integrators import FIREMinimizationIntegrator
logger = logging.getLogger(__name__)
# ==============================================================================
# MULTISTATE SAMPLER
# ==============================================================================
class MultiStateSampler(object):
"""
Base class for samplers that sample multiple thermodynamic states using
one or more replicas.
This base class provides a general simulation facility for multistate sampling from multiple
thermodynamic states, allowing any set of thermodynamic states to be specified.
If instantiated on its own, the thermodynamic state indices associated with each
state are specified and replica mixing does not change any thermodynamic states,
meaning that each replica remains in its original thermodynamic state.
Stored configurations, energies, swaps, and restart information are all written
to a single output file using the platform portable, robust, and efficient
NetCDF4 library.
Parameters
----------
mcmc_moves : MCMCMove or list of MCMCMove, optional
The MCMCMove used to propagate the thermodynamic states. If a list of MCMCMoves,
they will be assigned to the correspondent thermodynamic state on
creation. If None is provided, Langevin dynamics with 2 fs timestep, 5.0/ps collision rate,
and 500 steps per iteration will be used.
number_of_iterations : int or infinity, optional, default: 1
The number of iterations to perform. Both ``float('inf')`` and
``numpy.inf`` are accepted for infinity. If you set this to infinity,
be sure to set also ``online_analysis_interval``.
online_analysis_interval : None or Int >= 1, optional, default: 200
Choose the interval at which to perform online analysis of the free energy.
After every interval, the simulation will be stopped and the free energy estimated.
If the error in the free energy estimate is at or below ``online_analysis_target_error``, then the simulation
will be considered completed.
If set to ``None``, then no online analysis is performed
online_analysis_target_error : float >= 0, optional, default 0.0
The target error for the online analysis measured in kT per phase.
Once the free energy is at or below this value, the phase will be considered complete.
If ``online_analysis_interval`` is None, this option does nothing.
Default is set to 0.0 since online analysis runs by default, but a finite ``number_of_iterations`` should also
be set to ensure there is some stop condition. If target error is 0 and an infinite number of iterations is set,
then the sampler will run until the user stops it manually.
online_analysis_minimum_iterations : int >= 0, optional, default 200
Set the minimum number of iterations which must pass before online analysis is carried out.
Since the initial samples are unlikely to yield a good estimate of the free energy, save time and just skip them.
If ``online_analysis_interval`` is None, this does nothing
locality : int > 0, optional, default None
If None, the energies at all states will be computed for every replica each iteration.
If int > 0, energies will only be computed for states ``range(max(0, state-locality), min(n_states, state+locality))``.
Attributes
----------
n_replicas
n_states
iteration
mcmc_moves
sampler_states
metadata
is_completed
energy_context_cache : openmmtools.cache.ContextCache, default=openmmtools.cache.global_context_cache
Context cache to be used for energy computations. Defaults to using global context cache.
sampler_context_cache : openmmtools.cache.ContextCache, default=openmmtools.cache.global_context_cache
Context cache to be used for propagation. Defaults to using global context cache.
Examples
--------
Sampling multiple states of an alanine dipeptide in implicit solvent system.
>>> import math
>>> import tempfile
>>> from openmm import unit
>>> from openmmtools import testsystems, states, mcmc
>>> from openmmtools.multistate import MultiStateSampler, MultiStateReporter
>>> testsystem = testsystems.AlanineDipeptideImplicit()
Create thermodynamic states
>>> n_replicas = 3
>>> T_min = 298.0 * unit.kelvin # Minimum temperature.
>>> T_max = 600.0 * unit.kelvin # Maximum temperature.
>>> temperatures = [T_min + (T_max - T_min) * (math.exp(float(i) / float(n_replicas-1)) - 1.0) / (math.e - 1.0)
... for i in range(n_replicas)]
>>> temperatures = [T_min + (T_max - T_min) * (math.exp(float(i) / float(n_replicas-1)) - 1.0) / (math.e - 1.0)
... for i in range(n_replicas)]
>>> thermodynamic_states = [states.ThermodynamicState(system=testsystem.system, temperature=T)
... for T in temperatures]
Initialize simulation object with options. Run with a GHMC integrator.
>>> move = mcmc.GHMCMove(timestep=2.0*unit.femtoseconds, n_steps=50)
>>> simulation = MultiStateSampler(mcmc_moves=move, number_of_iterations=2)
Create simulation and store output in temporary file
>>> storage_path = tempfile.NamedTemporaryFile(delete=False).name + '.nc'
>>> reporter = MultiStateReporter(storage_path, checkpoint_interval=1)
>>> simulation.create(thermodynamic_states=thermodynamic_states,
... sampler_states=states.SamplerState(testsystem.positions), storage=reporter)
Optionally, specify unlimited context cache attributes using the fastest mixed precision platform
>>> from openmmtools.cache import ContextCache
>>> from openmmtools.utils import get_fastest_platform
>>> platform = get_fastest_platform(minimum_precision='mixed')
>>> simulation.energy_context_cache = ContextCache(capacity=None, time_to_live=None, platform=platform)
>>> simulation.sampler_context_cache = ContextCache(capacity=None, time_to_live=None, platform=platform)
Run the simulation
>>> simulation.run()
Note that to avoid the repex cycling problem upon resuming a simulation, make sure to specify the optional
energy and sampler context caches.
>>> reporter = MultiStateReporter(reporter_file, checkpoint_interval=10)
>>> simulation = HybridRepexSampler.from_storage(reporter)
>>> simulation.energy_context_cache = cache.ContextCache(capacity=None, time_to_live=None, platform=platform)
>>> simulation.sampler_context_cache = cache.ContextCache(capacity=None, time_to_live=None, platform=platform)
>>> simulation.extend(n_iterations=1)
"""
# -------------------------------------------------------------------------
# Constructors.
# -------------------------------------------------------------------------
def __init__(self, mcmc_moves=None, number_of_iterations=1,
             online_analysis_interval=200, online_analysis_target_error=0.0,
             online_analysis_minimum_iterations=200,
             locality=None):
    """Store the sampler options; the simulation itself is set up later by ``create()``.

    Parameters
    ----------
    mcmc_moves : MCMCMove or list of MCMCMove, optional
        Move(s) used to propagate the replicas. A deep copy is stored. If
        None, a ``LangevinDynamicsMove`` with 2 fs timestep, 5/ps collision
        rate, 500 steps, velocity reassignment and 6 restart attempts is used.
    number_of_iterations : int or infinity, optional, default: 1
        Maximum number of iterations to perform.
    online_analysis_interval : None or int >= 1, optional, default: 200
        Interval at which online analysis is performed, or None to disable it.
    online_analysis_target_error : float >= 0, optional, default: 0.0
        Target error of the online free energy estimate.
    online_analysis_minimum_iterations : int >= 0, optional, default: 200
        Iterations that must pass before online analysis is carried out.
    locality : int > 0, optional, default: None
        Neighborhood size for the energy computation, or None for all states.

    Raises
    ------
    ValueError
        If one of the stored options is rejected by its validator.
    """
    # Warn that API is experimental.
    # ``Logger.warn`` is a deprecated alias, so ``warning`` is used instead.
    logger.warning('Warning: The openmmtools.multistate API is experimental and may change in future releases')

    # Display cuda device in debug log
    self._display_cuda_devices()

    # These will be set on initialization. See function
    # create() for explanation of single variables.
    self._thermodynamic_states = None
    self._unsampled_states = None
    self._sampler_states = None
    self._replica_thermodynamic_states = None
    self._iteration = None
    self._energy_thermodynamic_states = None
    self._neighborhoods = None
    self._energy_unsampled_states = None
    self._n_accepted_matrix = None
    self._n_proposed_matrix = None
    # Must be assigned before any stored option below: the option
    # descriptors read ``_reporter`` when a value is set.
    self._reporter = None
    self._metadata = None
    self._timing_data = dict()

    # Handling default propagator.
    if mcmc_moves is None:
        # This will be converted to a list in create().
        self._mcmc_moves = mcmc.LangevinDynamicsMove(timestep=2.0 * unit.femtosecond,
                                                     collision_rate=5.0 / unit.picosecond,
                                                     n_steps=500, reassign_velocities=True,
                                                     n_restart_attempts=6)
    else:
        self._mcmc_moves = copy.deepcopy(mcmc_moves)

    # Store constructor parameters. Everything is marked for internal
    # usage because any change to these attribute implies a change
    # in the storage file as well. Use properties for checks.
    self.number_of_iterations = number_of_iterations

    # Store locality
    self.locality = locality

    # Online analysis options. The interval is assigned first because the
    # target-error and minimum-iterations validators read it.
    self.online_analysis_interval = online_analysis_interval
    self.online_analysis_target_error = online_analysis_target_error
    self.online_analysis_minimum_iterations = online_analysis_minimum_iterations
    self._online_error_trap_counter = 0  # Counter for errors in the online estimate
    self._online_error_bank = []
    self._last_mbar_f_k = None
    self._last_err_free_energy = None

    self._have_displayed_citations_before = False

    # Initializing context cache attributes
    self._initialize_context_caches()

    # Check convergence: with unlimited iterations, warn when no other
    # stop condition is configured.
    if self.number_of_iterations == np.inf:
        if self.online_analysis_target_error == 0.0:
            logger.warning("WARNING! You have specified an unlimited number of iterations and a target error "
                           "for online analysis of 0.0! Your simulation may never reach 'completed' state!")
        elif self.online_analysis_interval is None:
            logger.warning("WARNING! This simulation will never be considered 'complete' since there is no "
                           "specified maximum number of iterations!")
@classmethod
def from_storage(cls, storage):
    """Build a sampler from the contents of an existing storage file.

    Parameters
    ----------
    storage : str or Reporter
        If str: The path to the storage file.
        If :class:`Reporter`: uses the :class:`Reporter` options
        In the future this will be able to take a Storage class as well.

    Returns
    -------
    sampler : MultiStateSampler
        A new instance of MultiStateSampler (or subclass) in the same state of the
        last stored iteration.
    """
    # Accepts either a path or a ready-made reporter.
    storage_reporter = cls._reporter_from_storage(storage, check_exist=True)

    try:
        # Read-only pass: instantiate the sampler, then restore its state.
        storage_reporter.open(mode='r')
        restored = cls._instantiate_sampler_from_reporter(storage_reporter)
        restored._restore_sampler_from_reporter(storage_reporter)
    finally:
        # The read-mode handle is always released.
        storage_reporter.close()

    # Attach the reporter and reopen it in append mode on node 0 only.
    # The last iteration is not rewritten here: nothing has been written
    # yet, so there is no "junk" to clean up.
    restored._reporter = storage_reporter
    mpiplus.run_single_node(0, restored._reporter.open, mode='a',
                            broadcast_result=False, sync_nodes=False)
    return restored
# Lightweight record returned by ``read_status()``: the last stored iteration,
# the online-analysis target error (None when it is not a stop condition) and
# whether the calculation is completed.
# TODO use Python 3.6 namedtuple syntax when we drop Python 3.5 support.
Status = typing.NamedTuple('Status', [
('iteration', int),
('target_error', float),
('is_completed', bool)
])
@classmethod
def read_status(cls, storage):
    """Read the status of the calculation from the storage file.

    This gives a quick way to check how far a simulation has progressed
    without restoring the full sampler object from disk.

    Parameters
    ----------
    storage : str or Reporter
        The path to the storage file or the reporter object.

    Returns
    -------
    status : Status
        A named tuple with three fields: ``iteration``, ``target_error``,
        and ``is_completed``.
    """
    # Accepts either a path or a ready-made reporter.
    reporter = cls._reporter_from_storage(storage, check_exist=True)

    # Both stay None unless online analysis is an active stop condition.
    target_error = None
    last_err_free_energy = None

    try:
        reporter.open(mode='r')
        options = reporter.read_dict('options')
        iteration = reporter.read_last_iteration(last_checkpoint=False)

        # The cached free energy matters only if online analysis is on AND
        # the target error is a real stopping condition (non-zero).
        online_analysis_on = options['online_analysis_interval'] is not None
        if online_analysis_on and options['online_analysis_target_error'] != 0.0:
            target_error = options['online_analysis_target_error']
            try:
                free_energy_data = cls._read_last_free_energy(reporter, iteration)
                last_err_free_energy = free_energy_data[1][1]
            except TypeError:
                # Undefined free energy (analysis has not been run yet).
                last_err_free_energy = np.inf
    finally:
        reporter.close()

    # Decide whether the calculation is done.
    is_completed = cls._is_completed_static(options['number_of_iterations'], iteration,
                                            last_err_free_energy,
                                            options['online_analysis_target_error'])
    return cls.Status(iteration=iteration, target_error=target_error,
                      is_completed=is_completed)
# -------------------------------------------------------------------------
# Public properties.
# -------------------------------------------------------------------------
@property
def n_states(self):
    """Number of thermodynamic states as an int; 0 before any are set (read-only)."""
    thermodynamic_states = self._thermodynamic_states
    return 0 if thermodynamic_states is None else len(thermodynamic_states)
@property
def n_replicas(self):
    """Number of replicas as an int; 0 before any sampler state is set (read-only)."""
    sampler_states = self._sampler_states
    return 0 if sampler_states is None else len(sampler_states)
@property
def iteration(self):
    """Current iteration of the simulation as an int (read-only).

    This is None until the simulation has been created.
    """
    current_iteration = self._iteration
    return current_iteration
@property
def mcmc_moves(self):
    """Deep copy of the MCMCMove(s) used to propagate the simulation.

    Assignment is allowed only before creation.
    """
    moves_copy = copy.deepcopy(self._mcmc_moves)
    return moves_copy

@mcmc_moves.setter
def mcmc_moves(self, new_value):
    already_created = self._thermodynamic_states is not None
    if already_created:
        # The representation of the MCMCMoves cannot be modified because
        # groups/variables cannot be deleted from a NetCDF file. This could
        # be supported by JSONizing the dict serialization and storing it
        # as a string instead, if it were ever needed.
        raise RuntimeError('Cannot modify MCMCMoves after creation.')
    # A single MCMCMove is turned into a list later, in create().
    self._mcmc_moves = copy.deepcopy(new_value)
@property
def sampler_states(self):
    """Deep copy of the list of sampler states at the current iteration.

    Assignment is allowed only between ``create()`` and ``run()``.
    """
    states_copy = copy.deepcopy(self._sampler_states)
    return states_copy

@sampler_states.setter
def sampler_states(self, value):
    # Only iteration 0 (created, not yet run) accepts new sampler states.
    if self._iteration != 0:
        raise RuntimeError('Sampler states can be assigned only between '
                           'create() and run().')

    # Exactly one sampler state per replica is required.
    n_given = len(value)
    if n_given != self.n_replicas:
        raise ValueError('Passed {} sampler states for {} replicas'.format(
            n_given, self.n_replicas))

    # Keep a private copy and mirror it to storage through node 0.
    self._sampler_states = copy.deepcopy(value)
    mpiplus.run_single_node(0, self._reporter.write_sampler_states,
                            self._sampler_states, self._iteration)
@property
def is_periodic(self):
    """Return True if system is periodic, False if not, and None if not initialized.

    The answer is taken from the first thermodynamic state. The sampler
    counts as not initialized when the sampler states are unset, or when
    there is no thermodynamic state to query.
    """
    if self._sampler_states is None:
        return None
    # Also guard the attribute that is actually indexed below, so a
    # half-initialized sampler yields None instead of TypeError/IndexError.
    if not self._thermodynamic_states:
        return None
    return self._thermodynamic_states[0].is_periodic
class _StoredProperty(object):
    """
    Descriptor of a property stored as an option.

    The value is kept on the owning instance under ``'_' + option_name``.
    On assignment, the optional ``validate_function`` is called as
    ``validate_function(instance, new_value)`` and whatever it returns is
    stored. If the instance has an open reporter, the options are then
    written to storage through MPI node 0.

    validate_function is a simple function for checking things like "X > 0", but exposes both the
    sampler instance and the new value for the variable, in that order.
    More complex checks which rely on the sampler instance, like "if Y == True, then check X", can be
    implemented through the instance argument of the function.
    """

    def __init__(self, option_name, validate_function=None):
        self._option_name = option_name
        self._validate_function = validate_function

    def __get__(self, instance, owner_class=None):
        # Class-level access (e.g. ``Owner.option``) passes instance=None.
        # Return the descriptor itself, as is conventional, instead of
        # failing with an AttributeError on ``None``.
        if instance is None:
            return self
        return getattr(instance, '_' + self._option_name)

    def __set__(self, instance, new_value):
        if self._validate_function is not None:
            new_value = self._validate_function(instance, new_value)
        setattr(instance, '_' + self._option_name, new_value)
        # Update storage if the sampler is initialized.
        if instance._reporter is not None and instance._reporter.is_open():
            mpiplus.run_single_node(0, instance._store_options)

    # ----------------------------------
    # Value Validation of the properties
    # Should be @staticmethod with arguments of (instance, value) in that order, even if instance is not used
    # ----------------------------------

    @staticmethod
    def _number_of_iterations_validator(_, number_of_iterations):
        """Accept any value in [0, inf]; infinity is supported on purpose."""
        if not (0 <= number_of_iterations <= float('inf')):
            # The two literals are concatenated, so the separating space
            # must be inside one of them.
            raise ValueError('Accepted values for number_of_iterations are '
                             'non-negative integers and infinity.')
        return number_of_iterations

    @staticmethod
    def _oa_interval_validator(_, online_analysis_interval):
        """Check the online_analysis_interval value for consistency"""
        # Exact type comparison is kept on purpose: unlike isinstance(),
        # it also rejects bool.
        if online_analysis_interval is not None and (
                type(online_analysis_interval) != int or online_analysis_interval < 1):
            raise ValueError('online_analysis_interval must be an integer >=1 or None')
        return online_analysis_interval

    @staticmethod
    def _oa_target_error_validator(instance, online_analysis_target_error):
        """Check the target error; only enforced when online analysis is enabled."""
        if instance.online_analysis_interval is not None:
            if online_analysis_target_error < 0:
                raise ValueError("online_analysis_target_error must be a float >= 0")
            # NOTE(review): the number_of_iterations validator compares its
            # value against 0, which would fail for None on Python 3, so this
            # branch looks unreachable -- confirm whether infinity was meant.
            elif online_analysis_target_error == 0 and instance.number_of_iterations is None:
                logger.warning("online_analysis_target_error of 0 and number of iterations undefined "
                               "will never converge!")
        return online_analysis_target_error

    @staticmethod
    def _oa_min_iter_validator(instance, online_analysis_minimum_iterations):
        """Check the minimum iterations; only enforced when online analysis is enabled."""
        if (instance.online_analysis_interval is not None and
                (type(online_analysis_minimum_iterations) is not int or
                 online_analysis_minimum_iterations < 0)):
            raise ValueError("online_analysis_minimum_iterations must be an integer >= 0")
        return online_analysis_minimum_iterations

    @staticmethod
    def _locality_validator(_, locality):
        """Accept None or a strictly positive int."""
        if locality is not None:
            if (type(locality) != int) or (locality <= 0):
                raise ValueError("locality must be an int > 0")
        return locality
# Options exposed as stored properties. Each assignment runs the matching
# validator, stores the result on the instance under '_<name>' and, when the
# instance has an open reporter, triggers a rewrite of the stored options.
number_of_iterations = _StoredProperty('number_of_iterations',
validate_function=_StoredProperty._number_of_iterations_validator)
online_analysis_interval = _StoredProperty('online_analysis_interval',
validate_function=_StoredProperty._oa_interval_validator) #:interval to carry out online analysis
# The two validators below read ``instance.online_analysis_interval``, so that
# option has to be assigned before these are.
online_analysis_target_error = _StoredProperty('online_analysis_target_error',
validate_function=_StoredProperty._oa_target_error_validator)
online_analysis_minimum_iterations = _StoredProperty('online_analysis_minimum_iterations',
validate_function=_StoredProperty._oa_min_iter_validator)
locality = _StoredProperty('locality', validate_function=_StoredProperty._locality_validator)
@property
def metadata(self):
    """Deep copy of the stored metadata dictionary (read-only)."""
    metadata_copy = copy.deepcopy(self._metadata)
    return metadata_copy
@property
def is_completed(self):
    """Whether any of the stop target criteria has been reached (read-only)."""
    completed = self._is_completed()
    return completed
# -------------------------------------------------------------------------
# Main public interface.
# -------------------------------------------------------------------------
# Title text with a single ``{}`` placeholder; where it is filled in is not
# visible in this part of the file.
# NOTE(review): the text still says ``yank.multistate`` although this module
# imports from ``openmmtools`` -- confirm whether the title should be updated.
_TITLE_TEMPLATE = ('Multi-state sampler simulation created using MultiStateSampler class '
'of yank.multistate on {}')
def create(self, thermodynamic_states: list, sampler_states, storage,
           initial_thermodynamic_states=None, unsampled_thermodynamic_states=None,
           metadata=None):
    """Set up a brand-new multistate sampler simulation and its storage.

    Parameters
    ----------
    thermodynamic_states : list of states.ThermodynamicState
        Thermodynamic states to simulate, where one replica is allocated per state.
        Each state must have a system with the same number of atoms.
    sampler_states : states.SamplerState or list
        One or more sets of initial sampler states.
        The number of replicas is taken to be the number of sampler states provided.
        If the sampler states do not have box_vectors attached and the system is periodic,
        an exception will be thrown.
    storage : str or instanced Reporter
        If str: the path to the storage file. Default checkpoint options from Reporter class are used
        If Reporter: Uses the reporter options and storage path
        In the future this will be able to take a Storage class as well.
    initial_thermodynamic_states : None or list or array-like of int of length len(sampler_states), optional,
        default: None.
        Initial thermodynamic_state index for each sampler_state.
        If no initial distribution is chosen, ``sampler_states`` are distributed between the
        ``thermodynamic_states`` following these rules:

            * If ``len(thermodynamic_states) == len(sampler_states)``: 1-to-1 distribution

            * If ``len(thermodynamic_states) > len(sampler_states)``: First and last state distributed first
              remaining ``sampler_states`` spaced evenly by index until ``sampler_states`` are depleted.
              If there is only one ``sampler_state``, then only the first ``thermodynamic_state`` will be chosen

            * If ``len(thermodynamic_states) < len(sampler_states)``, each ``thermodynamic_state`` receives an
              equal number of ``sampler_states`` until there are insufficient number of ``sampler_states`` remaining
              to give each ``thermodynamic_state`` an equal number. Then the rules from the previous point are
              followed.
    unsampled_thermodynamic_states : list of states.ThermodynamicState, optional, default=None
        These are ThermodynamicStates that are not propagated, but their
        reduced potential is computed at each iteration for each replica.
        These energies can be used as data for reweighting schemes (default
        is None).
    metadata : dict, optional, default=None
       Simulation metadata to be stored in the file.

    Raises
    ------
    RuntimeError
        If the storage file already exists.
    """
    # ``storage`` may be a path string rather than a Reporter object.
    self._reporter = self._reporter_from_storage(storage, check_exist=False)

    # The existence check runs on MPI node 0 only and is broadcast, so
    # that other nodes reaching this point after node 0 has created the
    # file do not fail spuriously.
    storage_already_exists = mpiplus.run_single_node(0, self._reporter.storage_exists,
                                                     broadcast_result=True)
    if storage_already_exists:
        raise RuntimeError('Storage file {} already exists; cowardly '
                           'refusing to overwrite.'.format(self._reporter.filepath))

    # A lone SamplerState is wrapped so the rest of the code sees a list.
    if isinstance(sampler_states, states.SamplerState):
        sampler_states = [sampler_states]

    # Initialize internal attributes and the dataset.
    optional_arguments = dict(initial_thermodynamic_states=initial_thermodynamic_states,
                              unsampled_thermodynamic_states=unsampled_thermodynamic_states,
                              metadata=metadata)
    self._pre_write_create(thermodynamic_states, sampler_states, storage,
                           **optional_arguments)

    # Display papers to be cited.
    self._display_citations()

    self._initialize_reporter()
@utils.with_timer('Minimizing all replicas')
def minimize(self, tolerance=1.0 * unit.kilojoules_per_mole / unit.nanometers,
             max_iterations=0):
    """Run an energy minimization on every replica.

    The minimized positions are written to storage at the end.

    Parameters
    ----------
    tolerance : openmm.unit.Quantity, optional
        Minimization tolerance (units of energy/mole/length, default is
        ``1.0 * unit.kilojoules_per_mole / unit.nanometers``).
    max_iterations : int, optional
        Maximum number of iterations for minimization. If 0, minimization
        continues until converged.

    Raises
    ------
    RuntimeError
        If the simulation has not been created yet.
    """
    # Nothing to minimize before create() has been called.
    if self.n_replicas == 0:
        raise RuntimeError('Cannot minimize replicas. The simulation must be created first.')

    logger.debug("Minimizing all replicas...")

    # Spread the minimization over the nodes. Only node 0 receives all the
    # positions; the other nodes need only the ones they use for propagation
    # and for their entries of the energy matrix.
    replica_ids = range(self.n_replicas)
    minimized_positions, sampler_state_ids = mpiplus.distribute(
        self._minimize_replica, replica_ids, tolerance, max_iterations, send_results_to=0)

    # On nodes other than 0 this updates only the sampler states of the
    # replicas handled by that node.
    for new_positions, replica_id in zip(minimized_positions, sampler_state_ids):
        self._sampler_states[replica_id].positions = new_positions

    # Persist the minimized positions.
    mpiplus.run_single_node(0, self._reporter.write_sampler_states, self._sampler_states, self._iteration)
def equilibrate(self, n_iterations, mcmc_moves=None):
    """Equilibrate all replicas.

    This does not increase the iteration counter. The equilibrated
    positions are stored at the end.

    Parameters
    ----------
    n_iterations : int
        Number of equilibration iterations.
    mcmc_moves : MCMCMove or list of MCMCMove, optional
        Optionally, the MCMCMoves to use for equilibration can be
        different from the ones used in production.

    Raises
    ------
    RuntimeError
        If the simulation has not been created yet, or if the number of
        MCMCMoves does not match the number of thermodynamic states.
    """
    # Check that simulation has been created.
    if self.n_replicas == 0:
        raise RuntimeError('Cannot equilibrate replicas. The simulation must be created first.')

    # If no MCMCMove is specified, use the ones for production.
    if mcmc_moves is None:
        mcmc_moves = self._mcmc_moves

    # Make sure there is one MCMCMove per thermodynamic state.
    if isinstance(mcmc_moves, mcmc.MCMCMove):
        mcmc_moves = [copy.deepcopy(mcmc_moves) for _ in range(self.n_states)]
    elif len(mcmc_moves) != self.n_states:
        # Report the length of the moves being checked, which may be the
        # equilibration moves passed in rather than the production ones.
        raise RuntimeError('The number of MCMCMoves ({}) and ThermodynamicStates ({}) for equilibration'
                           ' must be the same.'.format(len(mcmc_moves), self.n_states))

    timer = utils.Timer()
    timer.start('Run Equilibration')

    # Temporarily set the equilibration MCMCMoves.
    production_mcmc_moves = self._mcmc_moves
    self._mcmc_moves = mcmc_moves
    try:
        for iteration in range(1, 1 + n_iterations):
            logger.debug("Equilibration iteration {}/{}".format(iteration, n_iterations))
            timer.start('Equilibration Iteration')

            # NOTE: Unlike run(), do NOT increment iteration counter.

            # Propagate replicas.
            self._propagate_replicas()

            # Compute energies of all replicas at all states
            self._compute_energies()

            # Update thermodynamic states
            # NOTE(review): run() assigns the result of _mix_replicas() to
            # self._replica_thermodynamic_states, this call discards it --
            # confirm that is intended.
            self._mix_replicas()

            # Computing timing information
            iteration_time = timer.stop('Equilibration Iteration')
            partial_total_time = timer.partial('Run Equilibration')
            time_per_iteration = partial_total_time / iteration
            estimated_time_remaining = time_per_iteration * (n_iterations - iteration)
            estimated_total_time = time_per_iteration * n_iterations
            estimated_finish_time = time.time() + estimated_time_remaining
            # TODO: Transmit timing information

            # Show timing statistics if debug level is activated.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug("Iteration took {:.3f}s.".format(iteration_time))
                if estimated_time_remaining != float('inf'):
                    logger.debug("Estimated completion (of equilibration only) in {}, at {} (consuming total wall clock time {}).".format(
                        str(datetime.timedelta(seconds=estimated_time_remaining)),
                        time.ctime(estimated_finish_time),
                        str(datetime.timedelta(seconds=estimated_total_time))))
        timer.report_timing()
    finally:
        # Restore production MCMCMoves even when an iteration raises, so a
        # failed equilibration does not leave the equilibration moves in place.
        self._mcmc_moves = production_mcmc_moves

    # Update stored positions.
    mpiplus.run_single_node(0, self._reporter.write_sampler_states, self._sampler_states, self._iteration)
def run(self, n_iterations=None):
    """Run the replica-exchange simulation.

    This runs at most ``number_of_iterations`` iterations. Use :func:`extend`
    to pass the limit.

    Each iteration, in order: mixes the replicas, propagates them, computes
    the energies, reports the iteration, updates the analysis, updates the
    timing data and finally checks for NaN energies.

    Parameters
    ----------
    n_iterations : int, optional
        If specified, only at most the specified number of iterations
        will be run (default is None).

    Raises
    ------
    SimulationNaNError
        If computing the starting energies (only done when the iteration
        counter is 0) fails with an error whose message contains
        "coordinate is nan" (case-insensitive). The intercepted exception
        is attached as ``__cause__``.

    """
    # If this is the first iteration, compute and store the
    # starting energies of the minimized/equilibrated structures.
    if self._iteration == 0:
        try:
            self._compute_energies()
        # We're intercepting a possible initial NaN position here thrown by OpenMM, which is a simple exception
        # So we have to under-specify this trap.
        except Exception as e:
            if 'coordinate is nan' not in str(e).lower():
                # Not the special case: re-raise the active exception untouched.
                raise
            err_message = "Initial coordinates were NaN! Check your inputs!"
            logger.critical(err_message)
            # Chain explicitly so the underlying error stays attached as the cause.
            raise SimulationNaNError(err_message) from e
        mpiplus.run_single_node(0, self._reporter.write_energies, self._energy_thermodynamic_states,
                                self._neighborhoods, self._energy_unsampled_states, self._iteration)
        self._check_nan_energy()

    timer = utils.Timer()
    timer.start('Run ReplicaExchange')
    # Remember where this call started; passed to the timing update below.
    run_initial_iteration = self._iteration

    # Handle default argument and determine number of iterations to run.
    # The limit never exceeds number_of_iterations, whatever n_iterations is.
    if n_iterations is None:
        iteration_limit = self.number_of_iterations
    else:
        iteration_limit = min(self._iteration + n_iterations, self.number_of_iterations)

    # Main loop.
    while not self._is_completed(iteration_limit):
        # Increment iteration counter.
        self._iteration += 1

        logger.debug('*' * 80)
        logger.debug('Iteration {}/{}'.format(self._iteration, iteration_limit))
        logger.debug('*' * 80)
        timer.start('Iteration')

        # Update thermodynamic states
        self._replica_thermodynamic_states = self._mix_replicas()

        # Propagate replicas.
        self._propagate_replicas()

        # Compute energies of all replicas at all states
        self._compute_energies()

        # Write iteration to storage file
        self._report_iteration()

        # Update analysis
        self._update_analysis()

        # Computing and transmitting timing information
        iteration_time = timer.stop('Iteration')
        partial_total_time = timer.partial('Run ReplicaExchange')
        self._update_timing(iteration_time, partial_total_time, run_initial_iteration, iteration_limit)

        # Show timing statistics if debug level is activated.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Iteration took {:.3f}s.".format(self._timing_data["iteration_seconds"]))
            if self._timing_data["estimated_time_remaining"] != float('inf'):
                logger.debug("Estimated completion in {}, at {} (consuming total wall clock time {}).".format(
                    self._timing_data["estimated_time_remaining"],
                    self._timing_data["estimated_localtime_finish_date"],
                    self._timing_data["estimated_total_time"]))

        # Perform sanity checks to see if we should terminate here.
        self._check_nan_energy()
def extend(self, n_iterations):
    """Run ``n_iterations`` more iterations, raising the iteration limit if needed.

    Unlike :func:`run`, this is not capped by the current
    ``number_of_iterations``: when the requested iterations would go past
    it, ``number_of_iterations`` is first increased to the new final
    iteration, and then :func:`run` is invoked.

    Parameters
    ----------
    n_iterations : int
        The number of iterations to run.

    """
    final_iteration = self._iteration + n_iterations
    if final_iteration > self.number_of_iterations:
        # This MUST be assigned to a property or the storage won't be updated.
        self.number_of_iterations = final_iteration
    self.run(n_iterations)
def __repr__(self):
"""Return a 'formal' representation that can be used to reconstruct the class, if possible."""
return "<instance of {}>".format(self.__class__.__name__)
def __del__(self):
    """Close the attached reporter, if any, when this object is finalized."""
    # The ``_reporter`` attribute can be absent or None, e.g. if the
    # MultiStateSampler was not created; nothing to close in that case.
    reporter = getattr(self, '_reporter', None)
    if reporter is None:
        return
    mpiplus.run_single_node(0, reporter.close)
# -------------------------------------------------------------------------
# Internal-usage.
# -------------------------------------------------------------------------
def _pre_write_create(self,
                      thermodynamic_states,
                      sampler_states,
                      storage,
                      initial_thermodynamic_states=None,
                      unsampled_thermodynamic_states=None,
                      metadata=None,):
    """
    Internal function which allocates and sets up ALL variables prior to actually using them.
    This is helpful to ensure subclasses have all variables created prior to writing them out with
    :func:`_report_iteration`.

    All calls to this function should be *identical* to :func:`create` itself

    Parameters
    ----------
    thermodynamic_states : sequence
        The thermodynamic states to sample. A deep copy is stored.
    sampler_states : sequence
        The sampler states, one per replica. Each one is deep-copied.
    storage
        Not used by this implementation; kept so that the signature matches :func:`create`.
    initial_thermodynamic_states : sequence of int, optional
        Thermodynamic state index for each replica. If None, it is obtained from
        ``self._default_initial_thermodynamic_states``.
    unsampled_thermodynamic_states : sequence, optional
        Additional states stored (deep-copied) as unsampled states. Defaults to an empty list.
    metadata : dict, optional
        Simulation metadata. A shallow copy is stored, so the caller's object is never modified;
        a default ``'title'`` entry is added to the copy if there is none.

    Raises
    ------
    Exception
        If the thermodynamic states mix periodic and non-periodic systems, or if the system
        is periodic and a sampler state has no box vectors.
    ValueError
        If the thermodynamic and sampler states do not all have the same number of particles.
    RuntimeError
        If ``self._mcmc_moves`` is a sequence whose length differs from ``self.n_states``.

    """
    # Check all systems are either periodic or not.
    is_periodic = thermodynamic_states[0].is_periodic
    for thermodynamic_state in thermodynamic_states:
        if thermodynamic_state.is_periodic != is_periodic:
            raise Exception('Thermodynamic states contain a mixture of '
                            'systems with and without periodic boundary conditions.')

    # Check that sampler states specify box vectors if the system is periodic
    if is_periodic:
        for sampler_state in sampler_states:
            if sampler_state.box_vectors is None:
                raise Exception('All sampler states must have box_vectors defined if the system is periodic.')

    # Make sure all states have same number of particles. We don't
    # currently support writing storage with different n_particles
    n_particles = thermodynamic_states[0].n_particles
    for the_states in [thermodynamic_states, sampler_states]:
        for state in the_states:
            if state.n_particles != n_particles:
                raise ValueError('All ThermodynamicStates and SamplerStates must '
                                 'have the same number of particles')

    # Handle default argument for metadata and add default simulation title.
    default_title = (self._TITLE_TEMPLATE.format(time.asctime(time.localtime())))
    if metadata is None:
        metadata = dict(title=default_title)
    else:
        # Work on a shallow copy: adding the default title below must not
        # modify the object owned by the caller.
        metadata = copy.copy(metadata)
        if 'title' not in metadata:
            metadata['title'] = default_title
    self._metadata = metadata

    # Save an independent (deep) copy of the thermodynamic states.
    self._thermodynamic_states = copy.deepcopy(thermodynamic_states)

    # Handle default unsampled thermodynamic states.
    if unsampled_thermodynamic_states is None:
        self._unsampled_states = []
    else:
        self._unsampled_states = copy.deepcopy(unsampled_thermodynamic_states)

    # Deep copy sampler states.
    self._sampler_states = [copy.deepcopy(sampler_state) for sampler_state in sampler_states]

    # Set initial thermodynamic state indices if not specified
    if initial_thermodynamic_states is None:
        initial_thermodynamic_states = self._default_initial_thermodynamic_states(thermodynamic_states,
                                                                                  sampler_states)
    self._replica_thermodynamic_states = np.array(initial_thermodynamic_states, np.int64)

    # Assign default system box vectors if None has been specified.
    # (Periodic systems without box vectors were already rejected above.)
    for replica_id, thermodynamic_state_id in enumerate(self._replica_thermodynamic_states):
        sampler_state = self._sampler_states[replica_id]
        if sampler_state.box_vectors is not None:
            continue
        thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id]
        sampler_state.box_vectors = thermodynamic_state.system.getDefaultPeriodicBoxVectors()

    # Ensure there is an MCMCMove for each thermodynamic state.
    if isinstance(self._mcmc_moves, mcmc.MCMCMove):
        # A single move was given: replicate it so every state gets its own copy.
        self._mcmc_moves = [copy.deepcopy(self._mcmc_moves) for _ in range(self.n_states)]
    elif len(self._mcmc_moves) != self.n_states:
        raise RuntimeError('The number of MCMCMoves ({}) and ThermodynamicStates ({}) must '
                           'be the same.'.format(len(self._mcmc_moves), self.n_states))

    # Reset iteration counter.
    self._iteration = 0

    # Reset statistics.
    # _n_accepted_matrix[i][j] is (per its name) the number of swaps accepted between thermodynamic states i and j.
    # _n_proposed_matrix[i][j] is the number of swaps proposed between thermodynamic states i and j.
    self._n_accepted_matrix = np.zeros([self.n_states, self.n_states], np.int64)
    self._n_proposed_matrix = np.zeros([self.n_states, self.n_states], np.int64)

    # Allocate memory for energy matrix. energy_thermodynamic/unsampled_states[k][l]
    # is the reduced potential computed at the positions of SamplerState sampler_states[k]
    # and ThermodynamicState thermodynamic/unsampled_states[l].
    self._energy_thermodynamic_states = np.zeros([self.n_replicas, self.n_states], np.float64)
    self._neighborhoods = np.zeros([self.n_replicas, self.n_states], 'i1')
    self._energy_unsampled_states = np.zeros([self.n_replicas, len(self._unsampled_states)], np.float64)
@classmethod
def _instantiate_sampler_from_reporter(cls, reporter):
    """
    Build a new sampler instance from the options stored in a reporter.

    The constructor keyword arguments are the reporter's ``options`` dictionary, with the ``mcmc_moves``
    entry filled in from ``reporter.read_mcmc_moves()``. Nothing else is restored here, use
    :func:`_restore_sampler_from_reporter` after calling this to set the remaining variables.

    Helper function to break up the :func:`from_storage` method in a way that subclasses can specialize

    Parameters
    ----------
    reporter : Reporter
        A reporter open for reading.

    Returns
    -------
    sampler : MultiStateSampler
        A new instance of MultiStateSampler (or subclass) with options
        restored from disk.

    """
    # Gather the constructor arguments: stored options plus the MCMC moves.
    constructor_kwargs = reporter.read_dict('options')
    constructor_kwargs['mcmc_moves'] = reporter.read_mcmc_moves()

    # Create the new simulation, then display papers to be cited.
    new_sampler = cls(**constructor_kwargs)
    new_sampler._display_citations()
    return new_sampler
def _restore_sampler_from_reporter(self, reporter):
"""
(Re-)initialize the instanced sampler from the reporter. Intended to be called as the second half of a
:func:`from_storage` method after the :class:`MultiStateSampler` has been instanced from disk.
The ``self.reporter`` instance of this sampler will be in an open state for append mode after this has been set,
and the ``reporter`` used as argument will be closed. In the event they are the same, reporter will be
returned as open in append mode.
Note: Needs an already initialized reporter to work correctly.
Warning: can overwrite the current state of this :class:`MultiStateSampler` instance.
Helper function to break up the from_storage method in a way that subclasses can specialize
Parameters
----------
reporter : multistate.MultiStateReporter
Reporter open for reading.
"""
# Read the last iteration reported to ensure we don't include junk
# data written just before a crash.
logger.debug("Reading storage file {}...".format(reporter.filepath))
metadata = reporter.read_dict('metadata')
thermodynamic_states, unsampled_states = reporter.read_thermodynamic_states()
def _read_options(check_iteration):
internal_sampler_states = reporter.read_sampler_states(iteration=check_iteration)
internal_state_indices = reporter.read_replica_thermodynamic_states(iteration=check_iteration)
internal_energy_thermodynamic_states, internal_neighborhoods, internal_energy_unsampled_states = \
reporter.read_energies(iteration=check_iteration)
internal_n_accepted_matrix, internal_n_proposed_matrix = \
reporter.read_mixing_statistics(iteration=check_iteration)
# Search for last cached free energies only if online analysis is activated.
internal_last_mbar_f_k, internal_last_err_free_energy = None, None
if self.online_analysis_interval is not None:
online_analysis_info = self._read_last_free_energy(reporter, check_iteration)
try:
internal_last_mbar_f_k, (_, internal_last_err_free_energy) = online_analysis_info
except TypeError:
# Trap case where online analysis is set but not run yet and (_, ...) = None is not iterable
pass
return (internal_sampler_states, internal_state_indices, internal_energy_thermodynamic_states,
internal_neighborhoods, internal_energy_unsampled_states, internal_n_accepted_matrix,