Skip to content

Commit

Permalink
feat: merge speller and hybrid training logic
Browse files Browse the repository at this point in the history
  • Loading branch information
my-master committed Dec 17, 2017
1 parent d9d4a7d commit f088ca4
Showing 1 changed file with 25 additions and 25 deletions.
50 changes: 25 additions & 25 deletions deeppavlov/core/commands/train.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
from deeppavlov.core.common.file import read_json
from deeppavlov.core.common.registry import _REGISTRY
from deeppavlov.core.commands.utils import set_vocab_path, build_agent_from_config, set_usr_dir, \
USR_DIR
from deeppavlov.core.commands.utils import set_vocab_path, build_agent_from_config, set_usr_dir
from deeppavlov.core.common.params import from_params
from deeppavlov.core.models.trainable import Trainable
from deeppavlov.core.common import paths


# TODO pass paths to local model configs to agent config.

def get_data(skill_config, dataset_config, vocab_path):
dataset_name = dataset_config['name']
data_path = skill_config['data_path']

data_reader = from_params(_REGISTRY[dataset_name], dataset_config)
data = data_reader.read(data_path)
data_reader.save_vocab(data, vocab_path)
return data
# def get_data(skill_config, datareader_config, vocab_path):
# datareader_name = datareader_config['name']
# data_path = skill_config['data_path']
#
# data_reader = from_params(_REGISTRY[datareader_name], datareader_config)
# data = data_reader.read(data_path)
# data_reader.save_vocab(data, vocab_path)
# return data


def train_agent_models(config_path: str):
set_usr_dir(config_path, USR_DIR)
usr_dir = paths.USR_PATH
a = build_agent_from_config(config_path)

for skill_config in a.skill_configs:
Expand All @@ -28,7 +28,9 @@ def train_agent_models(config_path: str):
model_name = model_config['name']

if issubclass(_REGISTRY[model_name], Trainable):
data = get_data(skill_config, skill_config['dataset_reader'], vocab_path)
reader_config = skill_config['dataset_reader']
reader = from_params(_REGISTRY[reader_config['name']], {})
data = reader.read(reader_config.get('data_path', usr_dir))

model = from_params(_REGISTRY[model_name], model_config, vocab_path=vocab_path)

Expand All @@ -44,24 +46,22 @@ def train_agent_models(config_path: str):
# "Only TFModel instances can train for now.")


def train_model_from_config(config_path: str, usr_dir_name=USR_DIR):
# make a serialization user dir
usr_dir_path = set_usr_dir(config_path, usr_dir_name)

def train_model_from_config(config_path: str):
usr_dir = paths.USR_PATH
config = read_json(config_path)
vocab_path = usr_dir_path.joinpath('vocab.txt')

data = get_data(config, config['dataset_reader'], vocab_path)
reader_config = config['dataset_reader']
reader = from_params(_REGISTRY[reader_config['name']], {})
data = reader.read(reader_config.get('data_path', usr_dir))

dataset_config = config['dataset']
dataset_name = dataset_config['name']
dataset = from_params(_REGISTRY[dataset_name], dataset_config, data=data)

model_config = config['model']
model_name = model_config['name']
model = from_params(_REGISTRY[model_name], model_config, vocab_path=vocab_path)

num_epochs = config['num_epochs']
num_tr_data = config['num_train_instances']
model = from_params(_REGISTRY[model_name], model_config)

####### Train
# TODO do batching in the train script.
model.train(data, num_epochs=num_epochs, num_tr_data=num_tr_data)
model.train(dataset)

# The result is a saved to user_dir trained model.

0 comments on commit f088ca4

Please sign in to comment.