Skip to content

Commit

Permalink
Merge pull request #683 from mv1388/message_passing_setting_types
Browse files Browse the repository at this point in the history
Message passing setting types
  • Loading branch information
mv1388 committed Jul 9, 2022
2 parents 93a3e11 + 4b0e097 commit 210876a
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 117 deletions.
6 changes: 3 additions & 3 deletions aitoolbox/torchtrain/callbacks/performance_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os

from aitoolbox.torchtrain.callbacks.abstract import AbstractCallback, AbstractExperimentCallback
from aitoolbox.torchtrain.train_loop.components import message_passing as msg_passing_settings
from aitoolbox.torchtrain.train_loop.components.message_passing import MessageHandling
from aitoolbox.cloud.AWS.results_save import BaseResultsSaver as BaseResultsS3Saver
from aitoolbox.cloud.GoogleCloud.results_save import BaseResultsGoogleStorageSaver
from aitoolbox.cloud import s3_available_options, gcs_available_options
Expand Down Expand Up @@ -435,7 +435,7 @@ def plot_current_train_history(self, prefix=''):
results_file_local_paths = [result_local_path for _, result_local_path in saved_local_results_details]
self.message_service.write_message('ModelTrainHistoryPlot_results_file_local_paths',
results_file_local_paths,
msg_handling_settings=msg_passing_settings.UNTIL_END_OF_EPOCH)
msg_handling_settings=MessageHandling.UNTIL_END_OF_EPOCH)

if self.cloud_results_saver is not None:
experiment_cloud_path = \
Expand Down Expand Up @@ -519,7 +519,7 @@ def write_current_train_history(self, prefix=''):

self.message_service.write_message('ModelTrainHistoryFileWriter_results_file_local_paths',
[results_file_local_path],
msg_handling_settings=msg_passing_settings.UNTIL_END_OF_EPOCH)
msg_handling_settings=MessageHandling.UNTIL_END_OF_EPOCH)

if self.cloud_results_saver is not None:
experiment_cloud_path = \
Expand Down
57 changes: 31 additions & 26 deletions aitoolbox/torchtrain/train_loop/components/message_passing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
KEEP_FOREVER = 'keep_forever'
UNTIL_END_OF_EPOCH = 'until_end_of_epoch'
UNTIL_READ = 'until_read'
OVERWRITE = 'overwrite'
from enum import Enum

ACCEPTED_SETTINGS = (KEEP_FOREVER, UNTIL_END_OF_EPOCH, UNTIL_READ, OVERWRITE)

class MessageHandling(Enum):
KEEP_FOREVER = 'keep_forever'
UNTIL_END_OF_EPOCH = 'until_end_of_epoch'
UNTIL_READ = 'until_read'
OVERWRITE = 'overwrite'


class Message:
Expand All @@ -13,7 +15,8 @@ def __init__(self, key, value, msg_handling_settings):
Args:
key (str): message key
value: message value
msg_handling_settings (str or list): selected message handling settings for this particular message
msg_handling_settings (MessageHandling or list[MessageHandling]): selected message handling settings
for this particular message
"""
self.key = key
self.value = value
Expand Down Expand Up @@ -41,20 +44,23 @@ def read_messages(self, key):
"""
if key in self.message_store:
messages = [msg.value for msg in self.message_store[key]]
self.message_store[key] = [msg for msg in self.message_store[key] if UNTIL_READ not in msg.msg_handling_settings]
self.message_store[key] = [
msg for msg in self.message_store[key]
if MessageHandling.UNTIL_READ not in msg.msg_handling_settings
]
return messages
else:
return None

def write_message(self, key, value, msg_handling_settings=UNTIL_END_OF_EPOCH):
def write_message(self, key, value, msg_handling_settings=MessageHandling.UNTIL_END_OF_EPOCH):
"""Write a new message to the message service
Args:
key (str): message key
value: message content
msg_handling_settings (str or list): setting how to handle the lifespan of the message.
Can use one of the following message lifecycle handling settings which are variables imported from this
script file and can be found defined at the beginning of the script:
msg_handling_settings (MessageHandling or list[MessageHandling]): setting how to handle the lifespan of
the message. Can use one of the following message lifecycle handling settings which are variables
imported from this script file and can be found defined at the beginning of the script:
* ``KEEP_FOREVER``
* ``UNTIL_END_OF_EPOCH``
Expand All @@ -64,14 +70,14 @@ def write_message(self, key, value, msg_handling_settings=UNTIL_END_OF_EPOCH):
Returns:
None
"""
self.validate_msg_handling_settings(msg_handling_settings)
msg_handling_settings = self.validate_msg_handling_settings(msg_handling_settings)

if key not in self.message_store:
self.message_store[key] = []

message = Message(key, value, msg_handling_settings)

if OVERWRITE in msg_handling_settings:
if MessageHandling.OVERWRITE in msg_handling_settings:
self.message_store[key] = [message]
else:
self.message_store[key].append(message)
Expand All @@ -86,25 +92,24 @@ def end_of_epoch_trigger(self):
"""
for key, msgs_list in list(self.message_store.items()):
self.message_store[key] = [msg for msg in self.message_store[key]
if UNTIL_END_OF_EPOCH not in msg.msg_handling_settings]
if MessageHandling.UNTIL_END_OF_EPOCH not in msg.msg_handling_settings]

if len(self.message_store[key]) == 0:
del self.message_store[key]

@staticmethod
def validate_msg_handling_settings(msg_handling_settings):
if type(msg_handling_settings) == str:
if msg_handling_settings not in ACCEPTED_SETTINGS:
raise ValueError(f'Provided msg_handling_settings {msg_handling_settings} is not supported. '
f'Currently supported settings are: {ACCEPTED_SETTINGS}.')
elif type(msg_handling_settings) == list:
if type(msg_handling_settings) == list:
for msg_setting in msg_handling_settings:
if msg_setting not in ACCEPTED_SETTINGS:
raise ValueError(f'Provided msg_handling_settings {msg_setting} is not supported. '
f'Currently supported settings are: {ACCEPTED_SETTINGS}.')
if type(msg_setting) != MessageHandling:
raise TypeError('msg_setting is not of the correct MessageHandling type. '
f'It is {type(msg_setting)}.')

if len(msg_handling_settings) > 1 and OVERWRITE not in msg_handling_settings:
if len(msg_handling_settings) > 1 and MessageHandling.OVERWRITE not in msg_handling_settings:
raise ValueError(f'Provided two incompatible msg_handling_settings {msg_handling_settings}. '
f'Only OVERRIDE setting can currently be combined with another available setting')
else:
raise ValueError(f'Provided msg_handling_settings {msg_handling_settings} type not supported str or list')
'Only OVERRIDE setting can currently be combined with another available setting.')
elif type(msg_handling_settings) != MessageHandling:
raise ValueError(f'Provided msg_handling_settings {msg_handling_settings} type not of the supported '
'MessageHandling or list of MessageHandling.')

return msg_handling_settings if type(msg_handling_settings) is list else [msg_handling_settings]

0 comments on commit 210876a

Please sign in to comment.