/
train_models.py
36 lines (26 loc) · 1.17 KB
/
train_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from main.core.algorithm.trainer import LSTMTrainer, TrainingConfiguration
import pickle
DATASET_NAME = 'dsl_dataset'
DATASET_DIR = f'main\\core\\algorithm\\datasets\\{DATASET_NAME}'
MODELS_DIR = f'main\\core\\algorithm\\models\\{DATASET_NAME}'
TENSORBOARD_DIR = f'{MODELS_DIR}\\logs'
MODEL_NAME = "model"
EPOCHS = 25
LAYER_SIZES = [64]
DENSE_LAYERS = [0]
DENSE_ACTIVATION = "relu"
LSTM_LAYERS = [2]
LSTM_ACTIVATION = "tanh"
OUTPUT_ACTIVATION = "softmax"
def train_models():
dataset = pickle.load(open(f"{DATASET_DIR}\\{DATASET_NAME}.pickle", "rb"))
training_conf = TrainingConfiguration(model_name=MODEL_NAME, epochs=EPOCHS, layer_sizes=LAYER_SIZES,
lstm_activation=LSTM_ACTIVATION, dense_layers=DENSE_LAYERS,
lstm_layers=LSTM_LAYERS, dense_activation=DENSE_ACTIVATION,
output_activation=OUTPUT_ACTIVATION)
trainer = LSTMTrainer(training_conf)
models = trainer.create_models(dataset)
trainer.train_models(models, dataset, TENSORBOARD_DIR)
trainer.save_models(models, dataset, MODELS_DIR)
if __name__ == '__main__':
train_models()