Fix #7: implement main worker process, algorithm_registry and logging (#15)

* WIP: implementation of algorithm_registry and base classifier.

* Implement main worker process of VM.

* Addressed review comments

* change FIXED_TIME_WAITING_SECS to FIXED_TIME_WAITING_PERIOD in vmconf

* Fix main.py code.

* Addressed review comments
prasanna08 committed Jun 23, 2017
1 parent c62e3f8 commit 543df4e
Showing 8 changed files with 372 additions and 8 deletions.
79 changes: 79 additions & 0 deletions core/classifiers/BaseClassifier/BaseClassifier.py
@@ -0,0 +1,79 @@
# coding: utf-8
#
# Copyright 2017 The Oppia Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for classification algorithms"""

import abc


class BaseClassifier(object):
"""A base class for classifiers that uses supervised learning to match
free-form text answers to answer groups. The classifier trains on answers
that exploration editors have assigned to an answer group. Given a new
answer, it predicts the answer group.
Below are some concepts used in this class.
training_data: list(dict). The training data that is used for training
the classifier.
label - An answer group that the training sample should correspond to. If a
sample is being added to train a model, labels are provided.
"""

__metaclass__ = abc.ABCMeta

def __init__(self):
pass

@abc.abstractmethod
    def to_dict(self):
        """Returns a dict representing this classifier.

        Returns:
            dict. A dictionary representation of the classifier, referred to
            as 'classifier_data'. This data is used for prediction.
        """
raise NotImplementedError

@abc.abstractmethod
def train(self, training_data):
"""Trains classifier using given training_data.
Args:
training_data: list(dict). The training data that is used for
training the classifier. The list contains dicts where each dict
represents a single training data group, for example:
training_data = [
{
'answer_group_index': 1,
'answers': ['a1', 'a2']
},
{
'answer_group_index': 2,
'answers': ['a2', 'a3']
}
]
"""
raise NotImplementedError

@abc.abstractmethod
def validate(self, classifier_data):
"""Validates classifier data.
Args:
classifier_data: dict of the classifier attributes specific to
the classifier algorithm used.
"""
raise NotImplementedError
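
To illustrate the interface, here is a minimal, hypothetical sketch of a concrete subclass (not part of this commit). The ExampleClassifier name and its location under core/classifiers/ExampleClassifier/ExampleClassifier.py are assumptions, chosen so that the registry below could discover it; a real classifier would wrap an actual supervised-learning model.

# Hypothetical sketch only: a toy classifier that memorizes which answer
# belongs to which answer group.
from core.classifiers.BaseClassifier.BaseClassifier import BaseClassifier


class ExampleClassifier(BaseClassifier):
    """Toy classifier illustrating the BaseClassifier interface."""

    def __init__(self):
        super(ExampleClassifier, self).__init__()
        self._answer_to_group = {}

    def train(self, training_data):
        # Remember the answer group index of every answer seen in training.
        for group in training_data:
            for answer in group['answers']:
                self._answer_to_group[answer] = group['answer_group_index']

    def to_dict(self):
        # This dict is the 'classifier_data' that is stored on the server.
        return {'answer_to_group': self._answer_to_group}

    def validate(self, classifier_data):
        if 'answer_to_group' not in classifier_data:
            raise Exception('Missing answer_to_group in classifier data.')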
100 changes: 100 additions & 0 deletions core/classifiers/algorithm_registry.py
@@ -0,0 +1,100 @@
# coding: utf-8
#
# Copyright 2017 The Oppia Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Registry for classification algorithms/classifiers."""

import os
import pkgutil

import vmconf


class Registry(object):
"""Registry of all classifier classes."""

# pylint: disable=fixme
# TODO (prasanna08): Add unittest for algorithm registry when we have
# classifier(s) to test it.

# Dict mapping algorithm IDs to classifier classes.
_classifier_classes = {}

@classmethod
def get_all_classifier_algorithm_ids(cls):
"""Retrieves a list of all classifier algorithm IDs.
Returns:
A list containing all the classifier algorithm IDs.
"""
return [classifier_id
for classifier_id in vmconf.ALGORITHM_IDS]

@classmethod
def _refresh(cls):
"""Refreshes the dict mapping algorithm IDs to instances of
classifiers.
"""
cls._classifier_classes.clear()

all_classifier_ids = cls.get_all_classifier_algorithm_ids()

# Assemble all paths to the classifiers.
extension_paths = [
os.path.join(vmconf.CLASSIFIERS_DIR, classifier_id)
for classifier_id in all_classifier_ids]

# Crawl the directories and add new classifier instances to the
# registry.
for loader, name, _ in pkgutil.iter_modules(path=extension_paths):
module = loader.find_module(name).load_module(name)
clazz = getattr(module, name)

ancestor_names = [
base_class.__name__ for base_class in clazz.__bases__]
if 'BaseClassifier' in ancestor_names:
cls._classifier_classes[clazz.__name__] = clazz

@classmethod
def get_all_classifiers(cls):
"""Retrieves a list of instances of all classifiers.
Returns:
A list of instances of all the classification algorithms.
"""
if not cls._classifier_classes:
cls._refresh()
return [clazz() for clazz in cls._classifier_classes.values()]

@classmethod
def get_classifier_by_algorithm_id(cls, classifier_algorithm_id):
"""Retrieves a classifier instance by its algorithm id.
Refreshes once if the classifier is not found; subsequently, throws a
KeyError.
Args:
classifier_algorithm_id: str. The ID of the classifier algorithm.
Raises:
KeyError: If the classifier is not found the first time.
Returns:
An instance of the classifier.
"""
if classifier_algorithm_id not in cls._classifier_classes:
cls._refresh()
clazz = cls._classifier_classes[classifier_algorithm_id]
return clazz()
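
A brief usage sketch for the registry (illustrative only; it assumes the hypothetical ExampleClassifier from the sketch above is listed in vmconf.ALGORITHM_IDS). Because the registry keys classes by class name, an algorithm ID is expected to match both the classifier's module name and its class name.

from core.classifiers import algorithm_registry

# The first lookup populates the registry by crawling the classifier packages.
classifier = algorithm_registry.Registry.get_classifier_by_algorithm_id(
    'ExampleClassifier')

# An unknown id triggers one refresh and then surfaces as a KeyError.
try:
    algorithm_registry.Registry.get_classifier_by_algorithm_id('NoSuchId')
except KeyError:
    pass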
120 changes: 120 additions & 0 deletions core/services/job_services.py
@@ -0,0 +1,120 @@
# coding: utf-8
#
# Copyright 2017 The Oppia Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains functions used for polling, training and saving jobs."""

from core.classifiers import algorithm_registry
from core.services import remote_access_services

# pylint: disable=too-many-branches
def _validate_job_data(job_data):
if not isinstance(job_data, dict):
raise Exception('Invalid format of job data')

if 'job_id' not in job_data:
raise Exception('job data should contain job id')

if 'training_data' not in job_data:
raise Exception('job data should contain training data')

if 'algorithm_id' not in job_data:
raise Exception('job data should contain algorithm id')

if not isinstance(job_data['job_id'], str):
raise Exception(
'Expected job id to be a string, received %s' %
job_data['job_id'])

if not isinstance(job_data['algorithm_id'], str):
raise Exception(
'Expected algorithm id to be a string, received %s' %
job_data['algorithm_id'])

if not isinstance(job_data['training_data'], list):
raise Exception(
'Expected training data to be a list, received %s' %
job_data['training_data'])

algorithm_ids = (
algorithm_registry.Registry.get_all_classifier_algorithm_ids())
if job_data['algorithm_id'] not in algorithm_ids:
raise Exception('Invalid algorithm id %s' % job_data['algorithm_id'])

for grouped_answers in job_data['training_data']:
if 'answer_group_index' not in grouped_answers:
            raise Exception(
                'Expected answer_group_index to be a key in training_data'
                ' list item')
if 'answers' not in grouped_answers:
raise Exception(
'Expected answers to be a key in training_data list item')
if not isinstance(grouped_answers['answer_group_index'], int):
raise Exception(
'Expected answer_group_index to be an int, received %s' %
grouped_answers['answer_group_index'])
if not isinstance(grouped_answers['answers'], list):
raise Exception(
'Expected answers to be a list, received %s' %
grouped_answers['answers'])


def get_next_job():
"""Get next job request.
Returns: dict. A dictionary containing job data.
"""
job_data = remote_access_services.fetch_next_job_request()
_validate_job_data(job_data)
return job_data


def train_classifier(algorithm_id, training_data):
"""Train classifier associated with 'algorithm_id' using 'training_data'.
Args:
algorithm_id: str. ID of classifier algorithm.
training_data: list(dict). A list containing training data. Each dict
stores 'answer_group_index' and 'answers'.
Returns:
dict. Result of trained classifier algorithm.
"""
classifier = algorithm_registry.Registry.get_classifier_by_algorithm_id(
algorithm_id)
classifier.train(training_data)
classifier_data = classifier.to_dict()
classifier.validate(classifier_data)
return classifier_data


def store_job_result(job_id, classifier_data):
"""Store result of job in the Oppia server.
Args:
job_id: str. ID of the job whose result is to be stored.
classifier_data: dict. A dictionary representing result of the job.
Returns:
int. Status code of response.
"""
job_result_dict = {
'job_id': job_id,
'classifier_data': classifier_data
}

status = remote_access_services.store_trained_classifier_model(
job_result_dict)
return status
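
For reference, a hypothetical job payload that would pass _validate_job_data, followed by the train-and-store flow that main.py wires together (the algorithm id and the answers are illustrative only):

from core.services import job_services

job_data = {
    'job_id': 'job_1',
    'algorithm_id': 'ExampleClassifier',
    'training_data': [
        {'answer_group_index': 1, 'answers': ['a1', 'a2']},
        {'answer_group_index': 2, 'answers': ['a2', 'a3']},
    ],
}

classifier_data = job_services.train_classifier(
    job_data['algorithm_id'], job_data['training_data'])
# A status of 200 indicates the Oppia server stored the trained model.
status = job_services.store_job_result(job_data['job_id'], classifier_data)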
6 changes: 3 additions & 3 deletions core/services/remote_access_services.py
@@ -115,10 +115,10 @@ def store_trained_classifier_model(job_result_dict):
if not isinstance(job_result_dict, dict):
raise Exception('job_result_dict must be in dict format.')

if 'job_id' not in job_result_dict.keys():
if 'job_id' not in job_result_dict:
raise Exception('job_result_dict must contain \'job_id\'.')

if 'classifier_data' not in job_result_dict.keys():
if 'classifier_data' not in job_result_dict:
raise Exception('job_result_dict must contain \'classifier_data\'.')

payload = job_result_dict
@@ -128,4 +128,4 @@ def store_trained_classifier_model(job_result_dict):
request_url = "%s:%s/%s" % (
_get_url(), _get_port(), vmconf.STORE_TRAINED_CLASSIFIER_MODEL_HANDLER)
response = requests.post(request_url, json=payload)
return response
return response.status_code
4 changes: 2 additions & 2 deletions core/services/remote_access_services_test.py
@@ -73,9 +73,9 @@ def post_callback(request):
self.assertDictEqual(classifier_data, payload['classifier_data'])

with self.set_job_result_post_callback(post_callback):
resp = remote_access_services.store_trained_classifier_model(
status = remote_access_services.store_trained_classifier_model(
job_result_dict)
self.assertEqual(resp.status_code, 200)
self.assertEqual(status, 200)

def test_exception_is_raised_when_classifier_data_is_inappropriate(self):
"""Test that correct results are stored."""
40 changes: 40 additions & 0 deletions main.py
@@ -18,5 +18,45 @@
# This step should be performed before importing any of the
# third party libraries.

import logging
import sys
import time

import vm_config
vm_config.configure()

# pylint: disable=wrong-import-position
from core.services import job_services
import vmconf

def main():
"""Main process of VM."""
try:
job_data = job_services.get_next_job()
if job_data is None:
logging.info('No pending job requests.')
if vmconf.DEFAULT_WAITING_METHOD == vmconf.FIXED_TIME_WAITING:
time.sleep(vmconf.FIXED_TIME_WAITING_PERIOD)
return
classifier_data = job_services.train_classifier(
job_data['algorithm_id'], job_data['training_data'])
status = job_services.store_job_result(
job_data['job_id'], classifier_data)

if status != 200:
logging.warning(
'Failed to store result of the job with \'%s\' job_id',
job_data['job_id'])
return

except KeyboardInterrupt:
logging.info('Exiting')
sys.exit(0)

except Exception as e: # pylint: disable=broad-except
        # Log any exceptions that arise during processing of the job.
logging.error(e.message)

if __name__ == '__main__':
while True:
main()
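
The polling loop above relies on a few vmconf entries; the sketch below uses illustrative values only (the real constants, including FIXED_TIME_WAITING_PERIOD renamed in this commit from FIXED_TIME_WAITING_SECS, are defined in vmconf.py and may differ):

# Illustrative values only; see vmconf.py for the actual configuration.
FIXED_TIME_WAITING = 'fixed_time_waiting'

# Waiting strategy used when no pending job is found.
DEFAULT_WAITING_METHOD = FIXED_TIME_WAITING

# Seconds to sleep between polls under the fixed-time waiting strategy.
FIXED_TIME_WAITING_PERIOD = 60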
