-
Notifications
You must be signed in to change notification settings - Fork 1
/
config.py
68 lines (55 loc) · 2.12 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Model configuration."""
from layers_torch import *
from utils import _get_vocab_size
def build_config(args: dict, device: str, PATH: str) -> dict:
    """
    Assemble the hyperparameters and the layer stack used by the model.

    @param args (dict): arguments from terminal (whether to train, test, or fine-tune);
        forwarded unchanged to `_get_vocab_size`.
    @param device (str): device to run model training and evaluation on; handed to every layer.
    @param PATH (str): path to the repository, used as the prefix of all data/model file paths.

    @returns (dict): dictionary with the keys 'training_params', 'fine_tuning_params',
        'testing_params' and 'model_layers'.
    """
    hidden_size = 256

    def _optimizer_settings(n_iter, learning_rate):
        # Settings common to training and fine-tuning; only the iteration
        # count and the learning rate differ between the two phases.
        # A fresh dict is returned on every call so the phases never share state.
        return {
            "n_iter": n_iter,
            "n_timesteps": 512,
            "batch_size": 16,
            "learning_rate": learning_rate,
            "regularization": 0.001,
            "patience": 7,
            "evaluation_interval": 250,
        }

    corpus_file = f"{PATH}/data/shakespeare.txt"
    pretrained_file = f"{PATH}/models/my_pretrained_model.json"
    sampling_file = f"{PATH}/models/my_pretrained_lstm_model.json"

    # Training reads the corpus and writes the pretrained model.
    train_cfg = {
        '--corpus': corpus_file,
        '--to_path': pretrained_file,
        **_optimizer_settings(150000, 0.0006),
    }

    # Fine-tuning starts from the pretrained model and writes a new one.
    tune_cfg = {
        '--corpus': corpus_file,
        '--to_path': f"{PATH}/models/my_model.json",
        '--from_path': pretrained_file,
        **_optimizer_settings(20000, 0.0001),
    }

    # Testing only loads a stored model and samples from it, starting at the seed.
    test_cfg = {
        '--from_path': sampling_file,
        'n_timesteps': 750,
        '--seed': ". ",
    }

    # Vocabulary size (number of unique characters) that the model will accept as input.
    vocab_size = _get_vocab_size(args, corpus_file, pretrained_file, sampling_file)

    # Layers are instantiated in the order they appear in the stack.
    layer_stack = [
        Embedding(vocab_size, hidden_size, device=device),
        *[RNNBlock(hidden_size, hidden_size, device=device) for _ in range(2)],
        TemporalDense(hidden_size, vocab_size, device=device),
        TemporalSoftmax(device=device),
    ]

    return {
        'training_params': train_cfg,
        'fine_tuning_params': tune_cfg,
        'testing_params': test_cfg,
        'model_layers': layer_stack,
    }