/
multistatereporter.py
1874 lines (1568 loc) · 83.1 KB
/
multistatereporter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/local/bin/env python
# ==============================================================================
# MODULE DOCSTRING
# ==============================================================================
"""
Multistatereporter
==================
Master multi-thermodynamic state reporter module. Handles all Disk I/O
reporting operations for any MultiStateSampler derived classes.
COPYRIGHT
Current version by Andrea Rizzi <andrea.rizzi@choderalab.org>, Levi N. Naden <levi.naden@choderalab.org> and
John D. Chodera <john.chodera@choderalab.org> while at Memorial Sloan Kettering Cancer Center.
Original version by John D. Chodera <jchodera@gmail.com> while at the University of
California Berkeley.
LICENSE
This code is licensed under the latest available version of the MIT License.
"""
# ==============================================================================
# GLOBAL IMPORTS
# ==============================================================================
import os
import copy
import time
import uuid
import yaml
import warnings
import logging
import collections
import numpy as np
import netCDF4 as netcdf
from typing import Union, Any
try:
from openmm import unit
except ImportError: # OpenMM < 7.6
from simtk import unit
from openmmtools.utils import deserialize, with_timer, serialize, quantity_from_string
from openmmtools import states
logger = logging.getLogger(__name__)
# ==============================================================================
# MULTISTATE SAMPLER REPORTER
# ==============================================================================
class MultiStateReporter(object):
"""Handle storage write/read operations and different format conventions.
You can use this object to programmatically inspect the data generated by
ReplicaExchange.
Parameters
----------
storage : str
The path to the storage file for analysis.
A second checkpoint file will be determined from either ``checkpoint_storage`` or automatically based on
the storage option
In the future this will be able to take Storage classes as well.
open_mode : str or None
The mode of the file between 'r', 'w', and 'a' (or equivalently 'r+').
If None, the storage file won't be open on construction, and a call to
:func:`Reporter.open` will be needed before attempting read/write operations.
checkpoint_interval : int >= 1, Default: 50
The frequency at which checkpointing information is written relative to analysis information.
This is a multiple
of the iteration at which energies is written, hence why it must be greater than or equal to 1.
Checkpoint information cannot be written on iterations which where ``iteration % checkpoint_interval != 0``.
checkpoint_storage : str or None, optional
Optional name of the checkpoint point file. This file is used to save trajectory information and other less
frequently accessed data.
This should NOT be a full path, and instead just a filename
If None: the derived checkpoint name is the same as storage, less any extension, then "_checkpoint.nc" is added.
The reporter internally tracks what data goes into which file, so it's transparent to all other classes
In the future, this will be able to take Storage classes as well
analysis_particle_indices : tuple of ints, Optional. Default: () (empty tuple)
Indices of particles which should be treated as special when manipulating read and write functions.
If this is an empty tuple, then no particles are treated as special
Attributes
----------
filepath
checkpoint_interval
is_periodic
n_states
n_replicas
analysis_particle_indices
"""
def __init__(self, storage, open_mode=None,
             checkpoint_interval=50, checkpoint_storage=None,
             analysis_particle_indices=()):
    """Initialize the reporter, optionally opening the storage files.

    See the class docstring for a description of the parameters.

    Raises
    ------
    ValueError
        If ``checkpoint_interval`` is not an integer.
    """
    # Warn that API is experimental.
    # Logger.warn() is a deprecated alias of Logger.warning() since Python 3.3.
    logger.warning('Warning: The openmmtools.multistate API is experimental and may change in future releases')
    # Handle checkpointing. isinstance() is the idiomatic type check; bool is
    # explicitly rejected to preserve the original strict type(...) != int
    # behavior (bool is a subclass of int).
    if not isinstance(checkpoint_interval, int) or isinstance(checkpoint_interval, bool):
        raise ValueError("checkpoint_interval must be an integer!")
    dirname, filename = os.path.split(storage)
    if checkpoint_storage is None:
        # Derive the checkpoint name from the analysis file name:
        # "<basename>_checkpoint<ext>" in the same directory.
        basename, ext = os.path.splitext(filename)
        addon = "_checkpoint"
        checkpoint_storage = os.path.join(dirname, basename + addon + ext)
        logger.debug("Initial checkpoint file automatically chosen as {}".format(checkpoint_storage))
    else:
        # checkpoint_storage is documented as a bare filename; it always lives
        # next to the analysis file.
        checkpoint_storage = os.path.join(dirname, checkpoint_storage)
    self._storage_analysis_file_path = storage
    self._storage_checkpoint_file_path = checkpoint_storage
    # Dataset handles are created lazily by open().
    self._storage_checkpoint = None
    self._storage_analysis = None
    self._checkpoint_interval = checkpoint_interval
    # Cast to tuple no matter what 1-D-like input was given
    self._analysis_particle_indices = tuple(analysis_particle_indices)
    if open_mode is not None:
        self.open(open_mode)
    # Flag to check whether to overwrite real time statistics file
    self._overwrite_statistics = True
@property
def filepath(self):
"""
Returns the string file name of the primary storage file
Classes outside the Reporter can access the file string for error messages and such.
"""
return self._storage_analysis_file_path
@property
def _storage(self):
"""
Return an iterable of the storage objects, avoids having the [list, of, storage, objects] everywhere
Object 0 is always the primary file, all others are subfiles
"""
return self._storage_analysis, self._storage_checkpoint
@property
def _storage_paths(self):
"""
Return an iterable of paths to the storage files
Object 0 is always the primary file, all others are subfiles
"""
return self._storage_analysis_file_path, self._storage_checkpoint_file_path
@property
def _storage_dict(self):
"""Return an iterable dictionary of the self._storage_X objects"""
return {'checkpoint': self._storage_checkpoint, 'analysis': self._storage_analysis}
@property
def n_states(self):
if not self.is_open():
return None
return self._storage_analysis.dimensions['state'].size
@property
def n_replicas(self):
if not self.is_open():
return None
return self._storage_analysis.dimensions['replica'].size
@property
def is_periodic(self):
if not self.is_open():
return None
if 'box_vectors' in self._storage_analysis.variables:
return True
return False
@property
def analysis_particle_indices(self):
"""Return the tuple of indices of the particles which additional information is stored on for analysis"""
return self._analysis_particle_indices
@property
def checkpoint_interval(self):
"""Returns the checkpoint interval"""
return self._checkpoint_interval
def storage_exists(self, skip_size=False):
"""
Check if the storage files exist on disk.
Reads information on the primary file to see existence of others
Parameters
----------
skip_size : bool, Optional, Default: False
Skip the check of the file size. Helpful if you have just initialized a storage file but written nothing to
it yet and/or its still entirely in memory (e.g. just opened NetCDF files)
Returns
-------
files_exist : bool
If the primary storage file and its related subfiles exist, returns True.
If the primary file or any subfiles do not exist, returns False
"""
# This function serves as a way to mask the subfiles from everything outside the reporter
for file_path in self._storage_paths:
if not os.path.exists(file_path):
return False # Return if any files do not exist
elif not os.path.getsize(file_path) > 0 and not skip_size:
return False # File is 0 size
return True
def is_open(self):
"""Return True if the Reporter is ready to read/write."""
if self._storage[0] is None:
return False
else:
return self._storage[0].isopen()
def _are_subfiles_open(self):
"""Internal function to check if subfiles are open"""
open_check_list = []
for storage in self._storage[1:]:
if storage is None:
return False
else:
open_check_list.append(storage.isopen())
return np.all(open_check_list)
def open(self, mode='r', convention='ReplicaExchange', netcdf_format='NETCDF4'):
    """
    Open the storage file for reading/writing.

    Creates and pre-formats the required files if they don't exist.
    This is not necessary if you have indicated in the constructor to open.

    Parameters
    ----------
    mode : str, Optional, Default: 'r'
        The mode of the file between 'r', 'w', and 'a' (or equivalently 'r+').
    convention : str, Optional, Default: 'ReplicaExchange'
        NetCDF convention to write
    netcdf_format : str, Optional, Default: 'NETCDF4'
        The NetCDF file format to use

    Raises
    ------
    IOError
        If the checkpoint file carries a UUID different from the analysis
        file's (i.e. it belongs to another simulation).
    """
    # Ensure we don't have already another file
    # open (possibly in a different mode).
    self.close()
    # Create directory if we want to write.
    if mode != 'r':
        for storage_path in self._storage_paths:
            # normpath() transform '' to '.' for makedirs().
            storage_dir = os.path.normpath(os.path.dirname(storage_path))
            os.makedirs(storage_dir, exist_ok=True)

    # Analysis file.
    # ---------------
    # Open analysis file.
    self._storage_analysis = self._open_dataset_robustly(self._storage_analysis_file_path,
                                                         mode, version=netcdf_format)
    # The analysis netcdf file holds a reference UUID so that we can check
    # that the secondary netcdf files (currently only the checkpoint
    # file) have the same UUID to verify that the user isn't erroneously
    # trying to associate the analysis file to the incorrect checkpoint.
    try:
        # Check if we have previously created the file.
        primary_uuid = self._storage_analysis.UUID
    except AttributeError:
        # This is a new file. Use uuid4 to avoid assigning hostname information.
        primary_uuid = str(uuid.uuid4())
        self._storage_analysis.UUID = primary_uuid

    # Initialize dataset, if needed.
    self._initialize_storage_file(self._storage_analysis, 'analysis', convention)

    # Checkpoint file.
    # -----------------
    # Open checkpoint netcdf files.
    msg = ('Could not locate checkpoint subfile. This is okay for analysis if the '
           'solvent trajectory is not needed, but not for production simulation!')
    self._storage_checkpoint = self._open_dataset_robustly(self._storage_checkpoint_file_path,
                                                           mode, catch_io_error=True,
                                                           io_error_warning=msg,
                                                           version=netcdf_format)
    if self._storage_checkpoint is not None:
        # Check that the checkpoint file has the same UUID of the analysis file.
        # NOTE: this check used to be an ``assert`` wrapped in try/except
        # AssertionError, which is silently skipped when Python runs with -O.
        # An explicit comparison keeps the integrity check always active.
        checkpoint_uuid = getattr(self._storage_checkpoint, 'UUID', None)
        if checkpoint_uuid is None:
            # This is a new file. Assign UUID.
            self._storage_checkpoint.UUID = primary_uuid
        elif checkpoint_uuid != primary_uuid:
            raise IOError('Checkpoint UUID does not match analysis UUID! '
                          'This checkpoint file came from another simulation!\n'
                          'Analysis UUID: {}; Checkpoint UUID: {}'.format(
                              primary_uuid, checkpoint_uuid))
        # Initialize dataset, if needed.
        self._initialize_storage_file(self._storage_checkpoint, 'checkpoint', convention)

    # Further checkpoint interval checks.
    # -----------------------------------
    if self._storage_analysis is not None:
        # The same number will be on checkpoint file as well, but its not guaranteed to be present
        on_file_interval = self._storage_analysis.CheckpointInterval
        if on_file_interval != self._checkpoint_interval:
            # The on-file value always wins so existing data stays consistent.
            logger.debug("checkpoint_interval != on-file checkpoint interval! "
                         "Using on file analysis interval of {}.".format(on_file_interval))
            self._checkpoint_interval = on_file_interval
        # Check the special particle indices
        # Handle the "variable does not exist" case
        if 'analysis_particle_indices' not in self._storage_analysis.variables:
            n_particles = len(self._analysis_particle_indices)
            # This dimension won't exist if the above statement does not either
            self._storage_analysis.createDimension('analysis_particles', n_particles)
            ncvar_analysis_particles = \
                self._storage_analysis.createVariable('analysis_particle_indices', int, 'analysis_particles')
            ncvar_analysis_particles[:] = self._analysis_particle_indices
            ncvar_analysis_particles.long_name = ("analysis_particle_indices[analysis_particles] is the indices of "
                                                  "the particles with extra information stored about them in the"
                                                  "analysis file.")
        # Now handle the "variable does exist but does not match the provided ones"
        # Although redundant if it was just created, its an easy check to make
        stored_analysis_particles = self._storage_analysis.variables['analysis_particle_indices'][:]
        if self._analysis_particle_indices != tuple(stored_analysis_particles.astype(int)):
            # On-file indices win, same rationale as the checkpoint interval.
            logger.debug("analysis_particle_indices != on-file analysis_particle_indices!"
                         "Using on file analysis indices of {}".format(stored_analysis_particles))
            self._analysis_particle_indices = tuple(stored_analysis_particles.astype(int))
def _open_dataset_robustly(self, *args, n_attempts=5, sleep_time=2,
                           catch_io_error=False, io_error_warning=None,
                           **kwargs):
    """Attempt to open the dataset multiple times if it raises an error.

    This may be useful to solve some MPI concurrency and locking issues
    that routinely and randomly pop up with HDF5. Some sleep time is
    added between attempts (in seconds).

    If the file is not found and catch_io_error is True, None is returned.

    Parameters
    ----------
    *args
        Positional arguments forwarded to ``netcdf.Dataset``; ``args[0]``
        must be the file path.
    n_attempts : int, Optional, Default: 5
        Total number of attempts before giving up.
    sleep_time : int or float, Optional, Default: 2
        Seconds to wait between failed attempts.
    catch_io_error : bool, Optional, Default: False
        If True and the file does not exist, return None instead of raising.
    io_error_warning : str or None, Optional, Default: None
        Message logged as a warning when the file is missing and
        ``catch_io_error`` is True.
    **kwargs
        Keyword arguments forwarded to ``netcdf.Dataset``.

    Returns
    -------
    netcdf.Dataset or None
    """
    # Catch eventual errors n_attempts - 1 times.
    for attempt in range(n_attempts - 1):
        try:
            return netcdf.Dataset(*args, **kwargs)
        except Exception:
            # Fixed: a bare "except:" here also swallowed KeyboardInterrupt
            # and SystemExit, making the retry loop uninterruptible.
            logger.debug('Attempt {}/{} to open {} failed. Retrying '
                         'in {} seconds'.format(attempt + 1, n_attempts, args[0], sleep_time))
            time.sleep(sleep_time)
    # Check if file exists and warn if asked
    # raise IOError otherwise
    if not os.path.isfile(args[0]):
        if catch_io_error:
            if io_error_warning is not None:
                logger.warning(io_error_warning)
            return None
        raise IOError(f"{args[0]} does not exist")
    # At the very last attempt, we try setting the environment variable
    # controlling the locking mechanism of HDF5 (see choderalab/yank#1165).
    if n_attempts > 1:
        os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
    # Last attempt finally raises any error.
    return netcdf.Dataset(*args, **kwargs)
def _initialize_storage_file(self, ncfile, nc_name, convention):
    """Helper function to initialize dimensions and global attributes.

    If the dataset has been initialized before, nothing happens.

    Parameters
    ----------
    ncfile : netcdf.Dataset
        The open dataset to initialize.
    nc_name : str
        Label stored in the file's ``DataUsedFor`` attribute
        ('analysis' or 'checkpoint').
    convention : str
        NetCDF convention string stored on the file.

    Returns
    -------
    bool
        True if this call initialized the file (first use),
        False if the file had already been initialized.
    """
    from openmmtools import __version__
    # The 'scalar' dimension is created on first initialization, so its
    # presence marks an already-initialized file.
    if 'scalar' not in ncfile.dimensions:
        # Create common dimensions.
        ncfile.createDimension('scalar', 1)  # Scalar dimension.
        ncfile.createDimension('iteration', 0)  # Unlimited number of iterations.
        ncfile.createDimension('spatial', 3)  # Number of spatial dimensions.
        # Set global attributes.
        ncfile.application = 'YANK'
        ncfile.program = 'yank.py'
        ncfile.programVersion = __version__
        ncfile.Conventions = convention
        ncfile.ConventionVersion = '0.2'
        ncfile.DataUsedFor = nc_name
        ncfile.CheckpointInterval = self._checkpoint_interval
        # Create and initialize the global variables
        nc_last_good_iter = ncfile.createVariable('last_iteration', int, 'scalar')
        nc_last_good_iter[0] = 0
        return True
    else:
        return False
def close(self):
"""Close the storage files"""
for storage_name, storage in self._storage_dict.items():
if storage is not None:
if storage.isopen():
storage.sync()
storage.close()
setattr(self, '_storage' + storage_name, None)
def sync(self):
"""Force any buffer to be flushed to the file"""
for storage in self._storage:
if storage is not None:
storage.sync()
def __del__(self):
    """Synchronize and close the storage."""
    # close() syncs any still-open dataset before closing it.
    self.close()
def read_end_thermodynamic_states(self):
    """Read thermodynamic states at the ends of the protocol.

    Returns
    -------
    end_thermodynamic_states : list of ThermodynamicState
        unsampled_states, if present, or first and last sampled states
    """
    end_thermodynamic_states = list()
    # Prefer the unsampled (end-point) states when they exist; otherwise
    # fall back to the first/last sampled thermodynamic states.
    if 'unsampled_states' in self._storage_analysis.groups:
        state_type = 'unsampled_states'
    else:
        state_type = 'thermodynamic_states'
    # Read thermodynamic end states
    # Cache: standard_system_name -> serialized standard system, shared by
    # the closure below to avoid re-reading the same system from disk.
    states_serializations = dict()
    n_states = len(self._storage_analysis.groups[state_type].variables)

    def extract_serialized_state(inner_type, inner_id):
        """Inner function to help extract the correct serialized state while minimizing number of disk reads

        Parameters
        ----------
        inner_type : str, 'unsampled_states' or 'thermodynamic_states'
            Where to read the data from, inherited from parent function's property or on the recursive loop
        inner_id : int
            Which state to pull data from
        """
        inner_serialized_state = self.read_dict('{}/state{}'.format(inner_type, inner_id))

        def isolate_thermodynamic_state(isolating_serialized_state):
            """Helper function to find true bottom level thermodynamic state from any level of nesting, reduce code
            """
            isolating_serial_thermodynamic_state = isolating_serialized_state
            while 'thermodynamic_state' in isolating_serial_thermodynamic_state:
                # The while loop is necessary for nested CompoundThermodynamicStates.
                isolating_serial_thermodynamic_state = isolating_serial_thermodynamic_state['thermodynamic_state']
            return isolating_serial_thermodynamic_state

        serialized_thermodynamic_state = isolate_thermodynamic_state(inner_serialized_state)
        # Check if the standard state is in a previous state.
        try:
            standard_system_name = serialized_thermodynamic_state.pop('_Reporter__compatible_state')
        except KeyError:
            # Cache the standard system serialization for future usage.
            standard_system_name = '{}/{}'.format(inner_type, inner_id)
            states_serializations[standard_system_name] = serialized_thermodynamic_state['standard_system']
        else:
            # The system serialization can be retrieved from another state.
            # Because the unsampled states rely on the thermodynamic states for their deserialization, we have
            # to try a secondary/recursive loop to get the thermodynamic states
            # However, this loop happens less often as the states_serializations dict fills up.
            try:
                serialized_standard_system = states_serializations[standard_system_name]
            except KeyError:
                # Referenced state not cached yet: recurse into it to pull
                # its standard system serialization.
                loop_type, loop_id = standard_system_name.split('/')
                looped_standard_state = extract_serialized_state(loop_type, loop_id)
                looped_serial_thermodynamic_state = isolate_thermodynamic_state(looped_standard_state)
                serialized_standard_system = looped_serial_thermodynamic_state['standard_system']
            serialized_thermodynamic_state['standard_system'] = serialized_standard_system
        return inner_serialized_state

    # Only the two end states of the protocol are needed.
    for state_id in [0, n_states - 1]:
        serialized_state = extract_serialized_state(state_type, state_id)
        # Create ThermodynamicState object.
        end_thermodynamic_states.append(deserialize(serialized_state))
    return end_thermodynamic_states
@with_timer('Reading thermodynamic states from storage')
def read_thermodynamic_states(self):
    """Retrieve the stored thermodynamic states from the checkpoint file.

    Returns
    -------
    thermodynamic_states : list of ThermodynamicStates
        The previously stored thermodynamic states. During the simulation
        these are swapped among replicas.
    unsampled_states : list of ThermodynamicState
        The previously stored unsampled thermodynamic states.

    See Also
    --------
    read_replica_thermodynamic_states
    """
    # We have to parse the thermodynamic states first because the
    # unsampled states may refer to them for the serialized system.
    # NOTE: this local name shadows the module-level ``openmmtools.states``
    # import for the remainder of this method.
    states = collections.OrderedDict([('thermodynamic_states', list()),
                                      ('unsampled_states', list())])
    # Caches standard_system_name: serialized_standard_system
    states_serializations = dict()
    # Read state information.
    for state_type, state_list in states.items():
        # There may not be unsampled states.
        if state_type not in self._storage_analysis.groups:
            assert state_type == 'unsampled_states'
            continue
        # We keep looking for states until we can't find them anymore.
        n_states = len(self._storage_analysis.groups[state_type].variables)
        for state_id in range(n_states):
            serialized_state = self.read_dict('{}/state{}'.format(state_type, state_id))
            # Find the thermodynamic state representation.
            serialized_thermodynamic_state = serialized_state
            while 'thermodynamic_state' in serialized_thermodynamic_state:
                # The while loop is necessary for nested CompoundThermodynamicStates.
                serialized_thermodynamic_state = serialized_thermodynamic_state['thermodynamic_state']
            # Check if the standard state is in a previous state.
            try:
                standard_system_name = serialized_thermodynamic_state.pop('_Reporter__compatible_state')
            except KeyError:
                # Cache the standard system serialization for future usage.
                standard_system_name = '{}/{}'.format(state_type, state_id)
                states_serializations[standard_system_name] = serialized_thermodynamic_state['standard_system']
            else:
                # The system serialization can be retrieved from another state.
                serialized_standard_system = states_serializations[standard_system_name]
                serialized_thermodynamic_state['standard_system'] = serialized_standard_system
            # Create ThermodynamicState object.
            states[state_type].append(deserialize(serialized_state))
    return [states['thermodynamic_states'], states['unsampled_states']]
@with_timer('Storing thermodynamic states')
def write_thermodynamic_states(self, thermodynamic_states, unsampled_states):
    """Store all the ThermodynamicStates to the checkpoint file.

    Compatible states share a single serialized standard system on disk:
    later states store only a reference name instead of the full system,
    which greatly reduces file size.

    Parameters
    ----------
    thermodynamic_states : list of ThermodynamicState
        The thermodynamic states to store.
    unsampled_states : list of ThermodynamicState
        The unsampled thermodynamic states to store.

    See Also
    --------
    write_replica_thermodynamic_states
    """
    # Store all thermodynamic states as serialized dictionaries.
    # Maps an already-stored state object -> its on-file reference name.
    stored_states = dict()

    def unnest_thermodynamic_state(serialized):
        # Walk down nested CompoundThermodynamicState serializations to the
        # innermost thermodynamic state dictionary.
        while 'thermodynamic_state' in serialized:
            serialized = serialized['thermodynamic_state']
        return serialized

    for state_type, states in [('thermodynamic_states', thermodynamic_states),
                               ('unsampled_states', unsampled_states)]:
        for state_id, state in enumerate(states):
            # Check if any compatible state has been found
            found_compatible_state = False
            for compare_state in stored_states:
                if compare_state.is_state_compatible(state):
                    # Skip the system: store only a reference to the
                    # compatible state's serialized standard system.
                    serialized_state = serialize(state, skip_system=True)
                    serialized_thermodynamic_state = unnest_thermodynamic_state(serialized_state)
                    serialized_thermodynamic_state.pop('standard_system')  # Remove the unneeded system object
                    reference_state_name = stored_states[compare_state]
                    serialized_thermodynamic_state['_Reporter__compatible_state'] = reference_state_name
                    found_compatible_state = True
                    break
            # If no compatible state is found, do full serialization
            if not found_compatible_state:
                serialized_state = serialize(state)
                serialized_thermodynamic_state = unnest_thermodynamic_state(serialized_state)
                serialized_standard_system = serialized_thermodynamic_state['standard_system']
                reference_state_name = '{}/{}'.format(state_type, state_id)
                len_serialization = len(serialized_standard_system)
                # Store new compatibility data
                stored_states[state] = reference_state_name
                logger.debug("Serialized state {} is {}B | {:.3f}KB | {:.3f}MB".format(
                    reference_state_name, len_serialization, len_serialization/1024.0,
                    len_serialization/1024.0/1024.0))
            # Finally write the dictionary with fixed dimension to improve compression.
            self._write_dict('{}/state{}'.format(state_type, state_id),
                             serialized_state, fixed_dimension=True)
def read_sampler_states(self, iteration, analysis_particles_only=False):
"""Retrieve the stored sampler states on the checkpoint file
If the iteration is not on the checkpoint interval, None is returned.
Exception to this is if``analysis_particles_only``, see the Returns for output behavior.
Parameters
----------
iteration : int
The iteration at which to read the data.
analysis_particles_only : bool, Optional, Default: False
If set, will return the trajectory of ONLY the analysis particles flagged on original creation of the files
Returns
-------
sampler_states : list of SamplerStates or None
The previously stored sampler states for each replica.
If the iteration is not on the ``checkpoint_interval`` and the ``analysis_particles_only`` is not set,
None is returned
If ``analysis_particles_only`` is set, will return only the subset of particles which were defined by the
``analysis_particle_indices`` on reporter creation
"""
if analysis_particles_only:
if len(self._analysis_particle_indices) == 0:
raise ValueError("No particles were flagged for special analysis! "
"No such trajectory would have been written!")
return self._read_sampler_states_from_given_file(iteration, storage_file='analysis',
obey_checkpoint_interval=False)
else:
return self._read_sampler_states_from_given_file(iteration, storage_file='checkpoint',
obey_checkpoint_interval=True)
@with_timer('Storing sampler states')
def write_sampler_states(self, sampler_states: list, iteration: int):
    """Store all sampler states for a given iteration on the checkpoint file

    If the iteration is not on the checkpoint interval, only the ``analysis_particle_indices`` data is written,
    if set.

    Parameters
    ----------
    sampler_states : list of SamplerStates
        The sampler states to store for each replica.
    iteration : int
        The iteration at which to store the data.
    """
    # Full-resolution states always target the checkpoint file (written
    # only on checkpoint iterations).
    self._write_sampler_states_to_given_file(sampler_states, iteration, storage_file='checkpoint',
                                             obey_checkpoint_interval=True)
    if len(self._analysis_particle_indices) == 0:
        return
    # Special case: build reduced states holding only the analysis particles.
    subset_states = []
    for full_state in sampler_states:
        # The [indices, :] form behaves uniformly whether the indices are a
        # tuple or a list, whereas ndarray[tuple] differs from ndarray[list].
        subset_positions = full_state.positions[self._analysis_particle_indices, :]
        if full_state._unitless_velocities is None:
            subset_velocities = None
        else:
            subset_velocities = full_state.velocities[self._analysis_particle_indices, :]
        subset_states.append(states.SamplerState(subset_positions,
                                                 velocities=subset_velocities,
                                                 box_vectors=full_state.box_vectors))
    self._write_sampler_states_to_given_file(subset_states, iteration, storage_file='analysis',
                                             obey_checkpoint_interval=False)
def read_replica_thermodynamic_states(self, iteration=slice(None)):
    """Retrieve the indices of the ThermodynamicStates for each replica on the analysis file

    Parameters
    ----------
    iteration : int or slice
        The iteration(s) at which to read the data. The slice(None) allows fetching all iterations at once.

    Returns
    -------
    state_indices : np.ndarray of int
        At the given iteration, replica ``i`` propagated the system in
        SamplerState ``sampler_states[i]`` and ThermodynamicState
        ``thermodynamic_states[states_indices[i]]``.
        If a slice is given, returns shape ``[len(iteration), len(sampler_states)]``
    """
    # Clamp/translate the requested iteration(s) to the last consistent
    # iteration recorded on file.
    iteration = self._map_iteration_to_good(iteration)
    logger.debug('read_replica_thermodynamic_states: iteration = {}'.format(iteration))
    return self._storage_analysis.variables['states'][iteration].astype(np.int64)
def write_replica_thermodynamic_states(self, state_indices, iteration):
    """Store the indices of the ThermodynamicStates for each replica on the analysis file

    Parameters
    ----------
    state_indices : list of int of size n_replicas
        At the given iteration, replica ``i`` propagated the system in
        SamplerState ``sampler_states[i]`` and ThermodynamicState
        ``thermodynamic_states[replica_thermodynamic_states[i]]``.
    iteration : int
        The iteration at which to store the data.
    """
    # Initialize schema if needed.
    if 'states' not in self._storage_analysis.variables:
        n_replicas = len(state_indices)
        # Create dimension if they don't exist.
        self._ensure_dimension_exists('replica', n_replicas)
        # Create variables and attach units and description.
        # Chunked one iteration at a time to match the write pattern.
        ncvar_states = self._storage_analysis.createVariable(
            'states', 'i4', ('iteration', 'replica'),
            zlib=False, chunksizes=(1, n_replicas)
        )
        ncvar_states.units = 'none'
        ncvar_states.long_name = ("states[iteration][replica] is the thermodynamic state index "
                                  "(0..n_states-1) of replica 'replica' of iteration 'iteration'.")
    # Store thermodynamic states indices.
    self._storage_analysis.variables['states'][iteration, :] = state_indices[:]
def read_mcmc_moves(self):
    """Return the MCMCMoves of the :class:`yank.multistate.MultiStateSampler` simulation on the checkpoint

    Returns
    -------
    mcmc_moves : list of MCMCMove
        The MCMCMoves used to propagate the simulation.
    """
    n_moves = len(self._storage_analysis.groups['mcmc_moves'].variables)
    # Retrieve and deserialize every move, preserving on-file order.
    return [deserialize(self.read_dict('mcmc_moves/move{}'.format(move_id)))
            for move_id in range(n_moves)]
def write_mcmc_moves(self, mcmc_moves):
    """Store the MCMCMoves of the :class:`yank.multistate.MultiStateSampler` simulation or subclasses on the checkpoint

    Parameters
    ----------
    mcmc_moves : list of MCMCMove
        The MCMCMoves used to propagate the simulation.
    """
    # Each move is serialized and stored under a sequential name.
    for move_id, move in enumerate(mcmc_moves):
        self.write_dict('mcmc_moves/move{}'.format(move_id), serialize(move))
def read_energies(self, iteration=slice(None)):
    """Retrieve the energy matrix at the given iteration on the analysis file

    Parameters
    ----------
    iteration : int or slice
        The iteration(s) at which to read the data. The slice(None) allows fetching all iterations at once.

    Returns
    -------
    energy_thermodynamic_states : n_replicas x n_states numpy.ndarray
        ``energy_thermodynamic_states[iteration, i, j]`` is the reduced potential computed at
        SamplerState ``sampler_states[iteration, i]`` and ThermodynamicState ``thermodynamic_states[iteration, j]``.
    energy_neighborhoods : n_replicas x n_states numpy.ndarray
        energy_neighborhoods[replica_index, state_index] is 1 if the energy was computed for this state,
        0 otherwise
    energy_unsampled_states : n_replicas x n_unsampled_states numpy.ndarray
        ``energy_unsampled_states[iteration, i, j]`` is the reduced potential computed at SamplerState
        ``sampler_states[iteration, i]`` and ThermodynamicState ``unsampled_thermodynamic_states[iteration, j]``.
    """
    # Restrict the request to the last consistent iteration.
    iteration = self._map_iteration_to_good(iteration)
    variables = self._storage_analysis.variables
    # Energies evaluated at every sampled thermodynamic state.
    energy_thermodynamic_states = np.array(variables['energies'][iteration, :, :], np.float64)
    # Pre-neighborhoods files lack this variable; fall back to a global
    # neighborhood (every state marked as computed).
    if 'neighborhoods' in variables:
        energy_neighborhoods = np.array(variables['neighborhoods'][iteration, :, :], 'i1')
    else:
        energy_neighborhoods = np.ones(energy_thermodynamic_states.shape, 'i1')
    # Energies at unsampled states, when any were stored.
    if 'unsampled_energies' in variables:
        energy_unsampled_states = np.array(variables['unsampled_energies'][iteration, :, :], np.float64)
    else:
        # No unsampled thermodynamic states: return an empty last dimension.
        empty_shape = energy_thermodynamic_states.shape[:-1] + (0,)
        energy_unsampled_states = np.zeros(empty_shape)
    return energy_thermodynamic_states, energy_neighborhoods, energy_unsampled_states
def write_energies(self, energy_thermodynamic_states, energy_neighborhoods, energy_unsampled_states, iteration):
    """Store the energy matrix at the given iteration on the analysis file

    Parameters
    ----------
    energy_thermodynamic_states : n_replicas x n_states numpy.ndarray
        ``energy_thermodynamic_states[i][j]`` is the reduced potential computed at
        SamplerState ``sampler_states[i]`` and ThermodynamicState ``thermodynamic_states[j]``.
    energy_neighborhoods : n_replicas x n_states numpy.ndarray
        energy_neighborhoods[replica_index, state_index] is 1 if the energy was computed for this state,
        0 otherwise
    energy_unsampled_states : n_replicas x n_unsampled_states numpy.ndarray
        ``energy_unsampled_states[i][j]`` is the reduced potential computed at SamplerState
        ``sampler_states[i]`` and ThermodynamicState ``unsampled_thermodynamic_states[j]``.
    iteration : int
        The iteration at which to store the data.
    """
    n_replicas, n_states = energy_thermodynamic_states.shape
    # Create dimensions and variables if they weren't created by other functions.
    self._ensure_dimension_exists('replica', n_replicas)
    self._ensure_dimension_exists('state', n_states)
    if 'energies' not in self._storage_analysis.variables:
        ncvar_energies = self._storage_analysis.createVariable(
            'energies', 'f8', ('iteration', 'replica', 'state'),
            zlib=False, chunksizes=(1, n_replicas, n_states)
        )
        ncvar_energies.units = 'kT'
        ncvar_energies.long_name = ("energies[iteration][replica][state] is the reduced (unitless) "
                                    "energy of replica 'replica' from iteration 'iteration' evaluated "
                                    "at the thermodynamic state 'state'.")
    if 'neighborhoods' not in self._storage_analysis.variables:
        ncvar_neighborhoods = self._storage_analysis.createVariable(
            'neighborhoods', 'i1', ('iteration', 'replica', 'state'),
            zlib=False, fill_value=1,  # old-style files will be upgraded to have all states
            chunksizes=(1, n_replicas, n_states)
        )
        ncvar_neighborhoods.long_name = ("neighborhoods[iteration][replica][state] is 1 if "
                                         "this energy was computed during this iteration.")
    # Create the unsampled-energies schema lazily: only once there actually are
    # unsampled states to store. (The original code repeated the 'not in
    # variables' check inside this branch; that inner test was redundant.)
    if ('unsampled_energies' not in self._storage_analysis.variables
            and energy_unsampled_states.shape[1] > 0):
        n_unsampled_states = energy_unsampled_states.shape[1]
        self._ensure_dimension_exists('unsampled', n_unsampled_states)
        # Create variable for thermodynamic state energies with units and descriptions.
        ncvar_unsampled = self._storage_analysis.createVariable(
            'unsampled_energies', 'f8', ('iteration', 'replica', 'unsampled'),
            zlib=False, chunksizes=(1, n_replicas, n_unsampled_states)
        )
        ncvar_unsampled.units = 'kT'
        ncvar_unsampled.long_name = ("unsampled_energies[iteration][replica][state] is the reduced "
                                     "(unitless) energy of replica 'replica' from iteration 'iteration' "
                                     "evaluated at unsampled thermodynamic state 'state'.")
    # Store values.
    self._storage_analysis.variables['energies'][iteration, :, :] = energy_thermodynamic_states
    self._storage_analysis.variables['neighborhoods'][iteration, :, :] = energy_neighborhoods
    if energy_unsampled_states.shape[1] > 0:
        self._storage_analysis.variables['unsampled_energies'][iteration, :, :] = energy_unsampled_states[:, :]
def read_mixing_statistics(self, iteration=slice(None)):
    """Retrieve the mixing statistics for the given iteration on the analysis file

    Parameters
    ----------
    iteration : int or slice
        The iteration(s) at which to read the data.

    Returns
    -------
    n_accepted_matrix : numpy.ndarray with shape (n_states, n_states)
        ``n_accepted_matrix[i][j]`` is the number of accepted moves from
        state ``thermodynamic_states[i]`` to ``thermodynamic_states[j]`` going
        from ``iteration-1`` to ``iteration`` (not cumulative).
    n_proposed_matrix : numpy.ndarray with shape (n_states, n_states)
        ``n_proposed_matrix[i][j]`` is the number of proposed moves from
        state ``thermodynamic_states[i]`` to ``thermodynamic_states[j]`` going
        from ``iteration-1`` to ``iteration`` (not cumulative).
    """
    # Restrict the request to the last consistent iteration.
    good_iteration = self._map_iteration_to_good(iteration)
    variables = self._storage_analysis.variables
    # Cast to int64 so callers get plain integer count matrices.
    n_accepted_matrix = variables['accepted'][good_iteration, :, :].astype(np.int64)
    n_proposed_matrix = variables['proposed'][good_iteration, :, :].astype(np.int64)
    return n_accepted_matrix, n_proposed_matrix
def write_mixing_statistics(self, n_accepted_matrix, n_proposed_matrix, iteration):
    """Store the mixing statistics for the given iteration on the analysis file

    Parameters
    ----------
    n_accepted_matrix : numpy.ndarray with shape (n_states, n_states)
        ``n_accepted_matrix[i][j]`` is the number of accepted moves from
        state ``thermodynamic_states[i]`` to ``thermodynamic_states[j]`` going
        from iteration-1 to iteration (not cumulative).
    n_proposed_matrix : numpy.ndarray with shape (n_states, n_states)
        ``n_proposed_matrix[i][j]`` is the number of proposed moves from
        state ``thermodynamic_states[i]`` to ``thermodynamic_states[j]`` going
        from ``iteration-1`` to ``iteration`` (not cumulative).
    iteration : int
        The iteration for which to store the data.
    """
    # Create schema if necessary.
    if 'accepted' not in self._storage_analysis.variables:
        n_states = n_accepted_matrix.shape[0]
        # Create dimension if it doesn't already exist
        self._ensure_dimension_exists('state', n_states)
        # Create variables with units and descriptions.
        ncvar_accepted = self._storage_analysis.createVariable(
            'accepted', 'i4', ('iteration', 'state', 'state'),
            zlib=False, chunksizes=(1, n_states, n_states)
        )
        ncvar_proposed = self._storage_analysis.createVariable(
            'proposed', 'i4', ('iteration', 'state', 'state'),
            zlib=False, chunksizes=(1, n_states, n_states)
        )
        ncvar_accepted.units = 'none'
        ncvar_proposed.units = 'none'
        # Fix: the 'accepted' description previously said "proposed transitions"
        # (copy-paste from ncvar_proposed); it describes accepted transitions.
        ncvar_accepted.long_name = ("accepted[iteration][i][j] is the number of accepted transitions "
                                    "between states i and j from iteration 'iteration-1'.")
        ncvar_proposed.long_name = ("proposed[iteration][i][j] is the number of proposed transitions "
                                    "between states i and j from iteration 'iteration-1'.")
    # Store statistics.
    self._storage_analysis.variables['accepted'][iteration, :, :] = n_accepted_matrix[:, :]
    self._storage_analysis.variables['proposed'][iteration, :, :] = n_proposed_matrix[:, :]
def read_timestamp(self, iteration=slice(None)):
    """Return the timestamp for the given iteration.

    Read from the analysis file, although there is a paired timestamp on the checkpoint file as well

    Parameters
    ----------
    iteration : int or slice
        The iteration(s) at which to read the data.

    Returns
    -------
    timestamp : str
        The timestamp at which the iteration was stored.
    """
    # Restrict the request to the last consistent iteration before reading.
    good_iteration = self._map_iteration_to_good(iteration)
    return self._storage_analysis.variables['timestamp'][good_iteration]
def write_timestamp(self, iteration: int):
"""Store a timestamp for the given iteration on both analysis and checkpoint file.
If the iteration is not on the ``checkpoint_interval``, no timestamp is written on the checkpoint file
Parameters
----------
iteration : int
The iteration at which to read the data.
"""
# Create variable if needed.
for storage_key, storage in self._storage_dict.items():
if 'timestamp' not in storage.variables:
storage.createVariable('timestamp', str, ('iteration',),
zlib=False, chunksizes=(1,))
timestamp = time.ctime()