Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
…#15) * WIP: implementation of algorithm_registry and base classifier. * Implement main worker process of VM. * Addressed review comments * change FIXED_TIME_WAITING_SECS to FIXED_TIME_WAITING_PERIOD in vmconf * Fix main.py code. * Addressed review comments
- Loading branch information
1 parent
c62e3f8
commit 543df4e
Showing
8 changed files
with
372 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# coding: utf-8 | ||
# | ||
# Copyright 2017 The Oppia Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS-IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Base class for classification algorithms""" | ||
|
||
import abc | ||
|
||
|
||
class BaseClassifier(object): | ||
"""A base class for classifiers that uses supervised learning to match | ||
free-form text answers to answer groups. The classifier trains on answers | ||
that exploration editors have assigned to an answer group. Given a new | ||
answer, it predicts the answer group. | ||
Below are some concepts used in this class. | ||
training_data: list(dict). The training data that is used for training | ||
the classifier. | ||
label - An answer group that the training sample should correspond to. If a | ||
sample is being added to train a model, labels are provided. | ||
""" | ||
|
||
__metaclass__ = abc.ABCMeta | ||
|
||
def __init__(self): | ||
pass | ||
|
||
@abc.abstractmethod | ||
def to_dict(self, model): | ||
"""Returns a dict representing this classifier. | ||
Returns: | ||
dict. A dictionary representation of classifier referred as | ||
'classifier_data'. This data is used for prediction. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
def train(self, training_data): | ||
"""Trains classifier using given training_data. | ||
Args: | ||
training_data: list(dict). The training data that is used for | ||
training the classifier. The list contains dicts where each dict | ||
represents a single training data group, for example: | ||
training_data = [ | ||
{ | ||
'answer_group_index': 1, | ||
'answers': ['a1', 'a2'] | ||
}, | ||
{ | ||
'answer_group_index': 2, | ||
'answers': ['a2', 'a3'] | ||
} | ||
] | ||
""" | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
def validate(self, classifier_data): | ||
"""Validates classifier data. | ||
Args: | ||
classifier_data: dict of the classifier attributes specific to | ||
the classifier algorithm used. | ||
""" | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# coding: utf-8 | ||
# | ||
# Copyright 2017 The Oppia Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS-IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Registry for classification algorithms/classifiers.""" | ||
|
||
import os | ||
import pkgutil | ||
|
||
import vmconf | ||
|
||
|
||
class Registry(object): | ||
"""Registry of all classifier classes.""" | ||
|
||
# pylint: disable=fixme | ||
# TODO (prasanna08): Add unittest for algorithm registry when we have | ||
# classifier(s) to test it. | ||
|
||
# Dict mapping algorithm IDs to classifier classes. | ||
_classifier_classes = {} | ||
|
||
@classmethod | ||
def get_all_classifier_algorithm_ids(cls): | ||
"""Retrieves a list of all classifier algorithm IDs. | ||
Returns: | ||
A list containing all the classifier algorithm IDs. | ||
""" | ||
return [classifier_id | ||
for classifier_id in vmconf.ALGORITHM_IDS] | ||
|
||
@classmethod | ||
def _refresh(cls): | ||
"""Refreshes the dict mapping algorithm IDs to instances of | ||
classifiers. | ||
""" | ||
cls._classifier_classes.clear() | ||
|
||
all_classifier_ids = cls.get_all_classifier_algorithm_ids() | ||
|
||
# Assemble all paths to the classifiers. | ||
extension_paths = [ | ||
os.path.join(vmconf.CLASSIFIERS_DIR, classifier_id) | ||
for classifier_id in all_classifier_ids] | ||
|
||
# Crawl the directories and add new classifier instances to the | ||
# registry. | ||
for loader, name, _ in pkgutil.iter_modules(path=extension_paths): | ||
module = loader.find_module(name).load_module(name) | ||
clazz = getattr(module, name) | ||
|
||
ancestor_names = [ | ||
base_class.__name__ for base_class in clazz.__bases__] | ||
if 'BaseClassifier' in ancestor_names: | ||
cls._classifier_classes[clazz.__name__] = clazz | ||
|
||
@classmethod | ||
def get_all_classifiers(cls): | ||
"""Retrieves a list of instances of all classifiers. | ||
Returns: | ||
A list of instances of all the classification algorithms. | ||
""" | ||
if not cls._classifier_classes: | ||
cls._refresh() | ||
return [clazz() for clazz in cls._classifier_classes.values()] | ||
|
||
@classmethod | ||
def get_classifier_by_algorithm_id(cls, classifier_algorithm_id): | ||
"""Retrieves a classifier instance by its algorithm id. | ||
Refreshes once if the classifier is not found; subsequently, throws a | ||
KeyError. | ||
Args: | ||
classifier_algorithm_id: str. The ID of the classifier algorithm. | ||
Raises: | ||
KeyError: If the classifier is not found the first time. | ||
Returns: | ||
An instance of the classifier. | ||
""" | ||
if classifier_algorithm_id not in cls._classifier_classes: | ||
cls._refresh() | ||
clazz = cls._classifier_classes[classifier_algorithm_id] | ||
return clazz() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# coding: utf-8 | ||
# | ||
# Copyright 2017 The Oppia Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS-IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""This module contains functions used for polling, training and saving jobs.""" | ||
|
||
from core.classifiers import algorithm_registry | ||
from core.services import remote_access_services | ||
|
||
# pylint: disable=too-many-branches | ||
def _validate_job_data(job_data): | ||
if not isinstance(job_data, dict): | ||
raise Exception('Invalid format of job data') | ||
|
||
if 'job_id' not in job_data: | ||
raise Exception('job data should contain job id') | ||
|
||
if 'training_data' not in job_data: | ||
raise Exception('job data should contain training data') | ||
|
||
if 'algorithm_id' not in job_data: | ||
raise Exception('job data should contain algorithm id') | ||
|
||
if not isinstance(job_data['job_id'], str): | ||
raise Exception( | ||
'Expected job id to be a string, received %s' % | ||
job_data['job_id']) | ||
|
||
if not isinstance(job_data['algorithm_id'], str): | ||
raise Exception( | ||
'Expected algorithm id to be a string, received %s' % | ||
job_data['algorithm_id']) | ||
|
||
if not isinstance(job_data['training_data'], list): | ||
raise Exception( | ||
'Expected training data to be a list, received %s' % | ||
job_data['training_data']) | ||
|
||
algorithm_ids = ( | ||
algorithm_registry.Registry.get_all_classifier_algorithm_ids()) | ||
if job_data['algorithm_id'] not in algorithm_ids: | ||
raise Exception('Invalid algorithm id %s' % job_data['algorithm_id']) | ||
|
||
for grouped_answers in job_data['training_data']: | ||
if 'answer_group_index' not in grouped_answers: | ||
raise Exception( | ||
'Expected answer_group_index to be a key in training_data', | ||
' list item') | ||
if 'answers' not in grouped_answers: | ||
raise Exception( | ||
'Expected answers to be a key in training_data list item') | ||
if not isinstance(grouped_answers['answer_group_index'], int): | ||
raise Exception( | ||
'Expected answer_group_index to be an int, received %s' % | ||
grouped_answers['answer_group_index']) | ||
if not isinstance(grouped_answers['answers'], list): | ||
raise Exception( | ||
'Expected answers to be a list, received %s' % | ||
grouped_answers['answers']) | ||
|
||
|
||
def get_next_job(): | ||
"""Get next job request. | ||
Returns: dict. A dictionary containing job data. | ||
""" | ||
job_data = remote_access_services.fetch_next_job_request() | ||
_validate_job_data(job_data) | ||
return job_data | ||
|
||
|
||
def train_classifier(algorithm_id, training_data): | ||
"""Train classifier associated with 'algorithm_id' using 'training_data'. | ||
Args: | ||
algorithm_id: str. ID of classifier algorithm. | ||
training_data: list(dict). A list containing training data. Each dict | ||
stores 'answer_group_index' and 'answers'. | ||
Returns: | ||
dict. Result of trained classifier algorithm. | ||
""" | ||
classifier = algorithm_registry.Registry.get_classifier_by_algorithm_id( | ||
algorithm_id) | ||
classifier.train(training_data) | ||
classifier_data = classifier.to_dict() | ||
classifier.validate(classifier_data) | ||
return classifier_data | ||
|
||
|
||
def store_job_result(job_id, classifier_data): | ||
"""Store result of job in the Oppia server. | ||
Args: | ||
job_id: str. ID of the job whose result is to be stored. | ||
classifier_data: dict. A dictionary representing result of the job. | ||
Returns: | ||
int. Status code of response. | ||
""" | ||
job_result_dict = { | ||
'job_id': job_id, | ||
'classifier_data': classifier_data | ||
} | ||
|
||
status = remote_access_services.store_trained_classifier_model( | ||
job_result_dict) | ||
return status |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.