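## ModelFactory

Create each model wrapper once and make it available for lookup by name.

FUTURE: discover the model classes dynamically (e.g. via metaclass-based registration) instead of listing them by hand.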

In [1]:
debug = True

# Root of the project folder on the mounted Google Drive
DRIVE_PATH = "/content/drive/MyDrive/data606"

# Folder holding this notebook and the model notebooks it loads
SCRIPT_PATH = f"{DRIVE_PATH}/src/"

In [2]:
# Mount Google Drive so the project files under DRIVE_PATH are reachable.
# Colab-only; when Drive is already mounted, drive.mount just reports it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Work from the project's src folder so the relative %run paths in the next cell resolve.
%cd $SCRIPT_PATH

/content/drive/MyDrive/data606/src


In [4]:
# Load model classes
%run -i "./Model_Base.ipynb"
%run -i "./Model_Densev1.ipynb"
#%run -i "./Model_LSTMv1.ipynb"
#%run -i "./Model_LSTMv2.ipynb"
%run -i "./Model_LSTMv3.ipynb"
%run -i "./Model_Transformerv1.ipynb"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/data606/src
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/data606/src
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/data606/src


In [5]:

class ModelFactory():
  """
  Construct a ModelFactory that provides keyword access to models for training.

  Each registered model wrapper is instantiated once, up front, with the same
  settings and stored in ``self.models`` keyed by the name its ``get_name()``
  returns.

  Args:
    window_size: forwarded unchanged to every model wrapper.
    label_window: forwarded unchanged to every model wrapper.
    num_labels: forwarded unchanged to every model wrapper.
    num_epochs: forwarded unchanged to every model wrapper.
    debug: forwarded unchanged to every model wrapper.

  Attributes:
    total_labels: label_window * num_labels (computed here, not forwarded).
    models: dict mapping model name -> model wrapper instance.
  """
  def __init__(self, window_size=30, label_window=1, num_labels=1, num_epochs=300, debug=False):

    self.debug = debug

    self.window_size = window_size
    self.label_window = label_window
    self.num_labels = num_labels
    self.num_epochs = num_epochs

    self.total_labels = label_window * num_labels

    self.init_models()


  def __repr__(self):
    """
    Print object stats, one ``name: value`` per line.
    """
    return '\n'.join([
         f'window_size: {self.window_size}',
         f'label_window: {self.label_window}',
         f'num_labels: {self.num_labels}',
         f'num_epochs: {self.num_epochs}'
        ])

  def init_models(self, model_classes=None):
    """
    Instantiate each model wrapper class and register it under its get_name().

    Args:
      model_classes: optional iterable of model wrapper classes. Each is called
        with window_size/label_window/num_labels/num_epochs/debug keywords and
        must provide get_name(). Defaults to Model_Densev1, Model_LSTMv3 and
        Model_Transformerv1, which must already be defined (they are loaded by
        the %run cells of this notebook).

    FUTURE: discover models instead of listing them.
    """
    if model_classes is None:
      # Resolved at call time because these classes come from %run notebooks.
      model_classes = (Model_Densev1, Model_LSTMv3, Model_Transformerv1)

    self.models = {}
    for model_class in model_classes:
      model = model_class(window_size=self.window_size,
                          label_window=self.label_window,
                          num_labels=self.num_labels,
                          num_epochs=self.num_epochs,
                          debug=self.debug)
      self.models[model.get_name()] = model

  def get(self, model_name):
    """
    Return the registered model wrapper called ``model_name``.

    Raises:
      KeyError: if no model with that name is registered; the message lists
        the available names.
    """
    try:
      return self.models[model_name]
    except KeyError:
      raise KeyError(
          f'Unknown model {model_name!r}; available: {sorted(self.models)}') from None

  def get_saved(self, model_path):
    """
    Load a saved model and attach it to the matching registered wrapper.

    The wrapper type is inferred from the path: registered name ``p`` matches
    when ``f'{p}-'`` occurs in it (e.g. ``.../LSTMv3-<suffix>``). The final
    path component is checked first, so a registered name that only appears
    in a parent directory does not shadow the real model type. The full path
    is the fallback.

    NOTE: the returned object is the factory's shared wrapper instance, so later
    ``get()`` calls for the same name return the wrapper holding this model.

    Args:
      model_path: path of the saved model, as accepted by ``load_model``.

    Returns:
      The registered wrapper after ``set_model`` has been called on it.

    Raises:
      AssertionError: if model_path is empty/None, or no registered model
        name can be found in it.
    """
    # `not` covers both None and the empty string.
    if not model_path:
      raise AssertionError('Model path must be provided')

    # Load model first to ensure we have a real model file before touching
    # any wrapper.
    model = load_model(model_path)

    model_name = self._match_model_name(model_path)
    if model_name is None:
      # Could just log and return the loaded model instead.
      raise AssertionError('Model type could not be found')

    # Fetch the model wrapper and make it usable
    known_model = self.get(model_name)
    known_model.set_model(model)
    return known_model

  def _match_model_name(self, model_path):
    """
    Return the first registered name whose ``'<name>-'`` tag is in model_path, or None.
    """
    import os

    # NOTE - highly dependent upon the saved filename format '<name>-...'.
    path_text = str(model_path)  # also accepts pathlib.Path
    base_name = os.path.basename(os.path.normpath(path_text))
    for candidate in (base_name, path_text):
      for name in self.models:
        if f'{name}-' in candidate:
          return name
    return None

---

## Unit testing

The cells below run only when `WG_UNIT_TEST` is set to `True`.

---

In [6]:
# Set to True to execute the unit-test cells below.
WG_UNIT_TEST = False

In [7]:
if WG_UNIT_TEST:

  # A factory built with defaults must expose the Densev1 wrapper by name.
  print('-------Case 1: get single model -----------')
  mf = ModelFactory()

  model = mf.get("Densev1")
  assert model
  print(model)

-------Case 1: get single model -----------
Model: Densev1
	window_size: 30
	label_window: 1
	num_labels: 1
	num_epochs: 300
	alpha: 0.0001


In [8]:
if WG_UNIT_TEST:

  # Every model registered by default must be retrievable by its name.
  print('-------Case 2: get all models -----------')
  mf = ModelFactory()

  models = ['Densev1', 'TXERv1', 'LSTMv3']

  for m in models:
    model = mf.get(m)
    assert model
    print(model)

-------Case 2: get all models -----------
Model: Densev1
	window_size: 30
	label_window: 1
	num_labels: 1
	num_epochs: 300
	alpha: 0.0001
Model: TXERv1
	window_size: 30
	label_window: 1
	num_labels: 1
	num_epochs: 300
	alpha: 0.0001
Model: LSTMv3
	window_size: 30
	label_window: 1
	num_labels: 1
	num_epochs: 300
	alpha: 0.0001
