In [None]:
# Dependencies.
from typing import Union, Optional


In [1]:

class EarlyStopper(object):
	r"""Template for early stopping of model training based on validation
	performance improvements, subject to a `patience` budget.

	The base class only does the bookkeeping (`best_val`, `anxiety`,
	`should_stop`) and dispatches to event hooks; nothing in it depends on a
	particular training framework. An actual early stopper must subclass it
	and implement the hook methods that currently raise
	`NotImplementedError`: `on_improvement`, `on_no_improvement`, `on_stop`
	and `step_after_early_stopper_stops`.
	"""

	#
	def __init__(self,
		patience: int, least_count: Union[float, int], is_bigger_better: bool,
		is_verbose: bool = False, **kwargs
	) -> None:
		r"""The initializer.

		Parameters
		----------
		patience:
			The number of consecutive non-improving steps to tolerate; the
			early stopper stops on the next non-improving step after that.
			Must be positive.

		least_count:
			The least amount by which the performance must improve for it to
			be considered significant. Must be non-negative; `0` means any
			strict improvement counts.

		is_bigger_better:
			Whether the tracked values are rewards/accuracies, in which case
			the bigger values are the better ones. If `False`, smaller values
			are better; e.g., the validation loss while training the model.

		is_verbose:
			Whether to print a warning when the early stopper stops.

		kwargs:
			Any other stuff to be used by the early stopper; stored as-is in
			`self.kwargs`.

		Raises
		------
		AssertionError
			If `patience` is not positive or `least_count` is negative (only
			while assertions are enabled, i.e. not under `python -O`).
		"""
		# Set args.
		self.patience: int = patience
		assert self.patience > 0, \
			'[ERROR] patience value must be a positive integer.'
		self.least_count: Union[float, int] = least_count
		# Zero is allowed: it turns the significance margin off, giving
		# strict comparisons in `is_improvement`.
		assert self.least_count >= 0, \
			'[ERROR] least count must be a non-negative real number.'
		self.is_bigger_better: bool = is_bigger_better
		self.is_verbose: bool = is_verbose
		self.kwargs = kwargs
		# Set attributes.
		# Best value seen so far; `None` until the first `step` call.
		self.best_val: Union[int, float, None] = None
		# Number of consecutive steps without a significant improvement.
		self.anxiety: int = 0
		# Latched to `True` once `anxiety` exceeds `patience`.
		self.should_stop: bool = False

	#
	def reset(self) -> None:
		r"""Resets the early stopper to its freshly-initialized state:
		`best_val` is `None`, `anxiety` is `0` and `should_stop` is `False`.
		"""
		self.best_val = None
		self.anxiety = 0
		self.should_stop = False

	#
	def step(self, val: Union[float, int], *args, **kwargs) -> Optional[object]:
		r"""Updates the early stopper with the current value and triggers
		the matching hook.

		Parameters
		----------
		val:
			The value of the tracked variable for the current iteration.

		args, kwargs:
			All other arguments; forwarded unchanged to the triggered hook.

		Returns
		-------
		result:
			Whatever the triggered hook returns.

		Notes
		-----
		- If `should_stop` is already `True`, `val` is ignored and
		  `step_after_early_stopper_stops` is called.
		- On the first call (`best_val` is `None`), or when `is_improvement`
		  reports a significant improvement, `best_val` is set to `val`,
		  `anxiety` is reset to `0` and `on_improvement` is called.
		- Otherwise `anxiety` is incremented and `best_val` is kept. If
		  `anxiety` now exceeds `patience`, `should_stop` is set to `True`
		  and `on_stop` is called; else `on_no_improvement` is called.

		All hooks MUST be implemented by the subclass that actually
		implements an early stopper; they may optionally return stuff.
		"""
		# Once stopped, every further call is routed to a dedicated hook.
		if self.should_stop:
			return self.step_after_early_stopper_stops(*args, **kwargs)

		# The first value is trivially the best so far. `is_improvement` is
		# short-circuited here so subclass overrides never see `None`.
		if self.best_val is None or self.is_improvement(val=val):
			self.best_val = val
			self.anxiety = 0
			return self.on_improvement(*args, **kwargs)

		# No significant improvement: spend one unit of patience.
		self.anxiety += 1
		if self.anxiety > self.patience:
			self.should_stop = True
			if self.is_verbose:
				print(
					'[WARNING] the early stopper stopped as '
					f'anxiety: {self.anxiety} crossed patience: {self.patience}'
				)
			return self.on_stop(*args, **kwargs)
		return self.on_no_improvement(*args, **kwargs)

	##
	def step_after_early_stopper_stops(self, *args, **kwargs):
		r"""The hook called by `step` when `should_stop` is already `True`.

		Parameters
		----------
		args, kwargs:
			The extra arguments passed to `step`.

		Returns
		-------
		result:
			Anything; `step` returns it to its caller.

		Examples
		--------
		A basic implementation is to just raise an out-of-range error::

			raise IndexError(
				'[ERROR] the early stopper is already stopped; ' + \
				'self.should_stop: {}'.format(self.should_stop)
			)
		"""
		raise NotImplementedError(
			'[ERROR] the subclass must implement this method.'
		)

	##
	def is_improvement(self, val: Union[float, int]) -> bool:
		r"""Checks if the current `val` is an improvement over the `best_val`.

		Parameters
		----------
		val:
			The value for the current iteration.

		Returns
		-------
		is_improvement:
			Whether `val` is an improvement over the current `best_val`. When
			no value has been recorded yet (`best_val` is `None`), any `val`
			is an improvement.

		Notes
		-----
		If `is_bigger_better` is `True`, we want a significant increase over
		the current `best_val`, i.e. `val > best_val + least_count`.

		Else, along the same lines, we want `val < best_val - least_count`.

		Set `least_count = 0` in the initializer for strict improvements.
		"""
		if self.best_val is None:
			return True
		if self.is_bigger_better:
			return val > (self.best_val + self.least_count)
		return val < (self.best_val - self.least_count)

	##
	def on_improvement(self, *args, **kwargs):
		r"""The hook called by `step` when the performance improves (including
		on the very first `step` call).

		Parameters
		----------
		args, kwargs:
			The extra arguments passed to `step`.

		Returns
		-------
		result:
			Anything; `step` returns it to its caller.

		Examples
		--------
		A basic implementation would checkpoint the current model, which is
		the best till now, at the given path.
		"""
		raise NotImplementedError(
			'[ERROR] the subclass must implement this method.'
		)

	##
	def on_stop(self, *args, **kwargs):
		r"""The hook called by `step` on the step at which the early stopper
		stops, i.e. when `anxiety` first exceeds `patience`.

		Parameters
		----------
		args, kwargs:
			The extra arguments passed to `step`.

		Returns
		-------
		result:
			Anything; `step` returns it to its caller.

		Examples
		--------
		A basic implementation would checkpoint the last model at given path.
		"""
		raise NotImplementedError(
			'[ERROR] the subclass must implement this method.'
		)

	##
	def on_no_improvement(self, *args, **kwargs):
		r"""The hook called by `step` when the early stopper is still running
		but the value did not improve significantly.

		Parameters
		----------
		args, kwargs:
			The extra arguments passed to `step`.

		Returns
		-------
		result:
			Anything; `step` returns it to its caller.

		Examples
		--------
		A basic implementation would consist of doing nothing (possibly
		printing the current `anxiety` and the `patience` values).
		"""
		raise NotImplementedError(
			'[ERROR] the subclass must implement this method.'
		)
    

NameError: name 'Union' is not defined