Skip to content

Commit

Permalink
test magic removal: get rid of pre-Job/JobChain fixtures from disco.test and start rewriting tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jflatow committed Apr 3, 2011
1 parent 06782cc commit 34ec37c
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 259 deletions.
4 changes: 2 additions & 2 deletions bin/discocli.py
Expand Up @@ -276,12 +276,12 @@ def test(program, *tests):
Test names is an optional list of names of modules in the ``$DISCO_HOME/tests`` directory (e.g. ``test_simple``).
Test names may also include the names of specific test cases (e.g. ``test_sort.MemorySortTestCase``).
"""
from disco.test import DiscoTestRunner
from disco.test import TestRunner
if not tests:
tests = list(program.tests)
os.environ.update(program.settings.env)
sys.path.insert(0, program.tests_path)
DiscoTestRunner(program.settings).run(*tests)
TestRunner(program.settings).run(*tests)

@Disco.command
def config(program):
Expand Down
8 changes: 6 additions & 2 deletions lib/disco/job.py
Expand Up @@ -131,9 +131,9 @@ def run(self, **jobargs):
return self

class JobChain(dict):
def run(self, interval=1):
def wait(self, poll_interval=1):
while sum(self.walk()) < len(self):
time.sleep(interval)
time.sleep(poll_interval)
return self

def walk(self):
Expand Down Expand Up @@ -163,6 +163,10 @@ def inputs(self, job):
else:
yield [input]

def purge(self):
for job in self:
job.purge()

class JobPack(object):
"""
This class implements :ref:`jobpack` in Python.
Expand Down
288 changes: 92 additions & 196 deletions lib/disco/test.py
@@ -1,71 +1,21 @@
import os, signal
import os, signal, unittest
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
from SocketServer import ThreadingMixIn
from httplib import OK, INTERNAL_SERVER_ERROR
from threading import Thread
from unittest import TestCase, TestLoader, TextTestRunner

try:
from unittest import SkipTest
except ImportError:
class SkipTest(Exception):
pass

import disco
from disco.core import Disco, result_iterator
from disco.job import Job
from disco.ddfs import DDFS
from disco.settings import DiscoSettings
from disco.util import iterify

class TestServer(ThreadingMixIn, HTTPServer):
allow_reuse_address = True

@property
def address(self):
return 'http://%s:%d' % self.server_address

@classmethod
def create(cls, server_address, data_generator):
return cls(server_address, handler(data_generator))

def start(self):
self.thread = Thread(target=self.serve_forever)
self.thread.start()

def stop(self):
# Workaround for Python2.5 which doesn't have the shutdown
# method
if hasattr(self, "shutdown"):
self.shutdown()
self.socket.close()

def urls(self, inputs):
def serverify(input):
return '%s/%s' % (self.address, input)
return [[serverify(url) for url in iterify(input)] for input in inputs]

class FailedReply(Exception):
pass

def handler(data_generator):
class Handler(BaseHTTPRequestHandler):
def send_data(self, data):
self.send_response(OK)
self.send_header('Content-length', len(data or []))
self.end_headers()
self.wfile.write(data)

def do_GET(self):
try:
self.send_data(data_generator(self.path.strip('/')))
except FailedReply, e:
self.send_error(INTERNAL_SERVER_ERROR, str(e))

def log_request(*args):
pass # suppress logging output for now

return Handler

class InterruptTest(KeyboardInterrupt, SkipTest):
def __init__(self, test):
super(InterruptTest, self).__init__("Test interrupted: May not have finished cleaning up")
Expand All @@ -76,7 +26,7 @@ def __call__(self, signum, frame):
self.test.is_running = False
raise self

class DiscoTestCase(TestCase):
class TestCase(unittest.TestCase):
settings = DiscoSettings()

@property
Expand All @@ -87,51 +37,6 @@ def ddfs(self):
def disco(self):
return Disco(settings=self.settings)

def assertCommErrorCode(self, code, callable):
from disco.error import CommError
try:
ret = callable()
except CommError, e:
return self.assertEquals(code, e.code)
except Exception, e:
raise AssertionError('CommError not raised, got %s' % e)
raise AssertionError('CommError not raised (expected %d), '
'returned %s' % (code, ret))

def run(self, result=None):
self.is_running = True
signal.signal(signal.SIGINT, InterruptTest(self))
super(DiscoTestCase, self).run(result)
self.is_running = False

class DiscoJobTestFixture(object):
jobargs = ('input',
'combiner',
'map',
'map_init',
'map_input_stream',
'map_output_stream',
'map_reader',
'merge_partitions',
'params',
'partition',
'partitions',
'profile',
'save',
'scheduler',
'sort',
'reduce',
'reduce_init',
'reduce_input_stream',
'reduce_output_stream',
'reduce_reader',
'required_files',
'required_modules',
'ext_params',
'ext_map',
'ext_reduce')
result_reader = staticmethod(disco.func.chain_reader)

@property
def nodes(self):
return dict((host, info['max_workers'])
Expand All @@ -147,131 +52,122 @@ def test_server_address(self):
return (str(self.settings['DISCO_TEST_HOST']),
int(self.settings['DISCO_TEST_PORT']))

@property
def profile(self):
return bool(self.settings['DISCO_TEST_PROFILE'])
def assertAllEqual(self, results, answers):
from disco.future import izip_longest as zip
for result, answer in zip(results, answers):
self.assertEquals(result, answer)

@property
def results(self):
return result_iterator(self.job.wait(), reader=self.result_reader)
def assertCommErrorCode(self, code, callable):
from disco.error import CommError
try:
ret = callable()
except CommError, e:
return self.assertEquals(code, e.code)
except Exception, e:
raise AssertionError('CommError not raised, got %s' % e)
raise AssertionError('CommError not raised (expected %d), '
'returned %s' % (code, ret))

@property
def input(self):
return self.test_server.urls(self.inputs)
def assertResults(self, job, answers):
self.assertAllEqual(self.results(job), answers)

def getdata(self, path):
pass
def results(self, job, **kwargs):
return result_iterator(job.wait(), **kwargs)

def setUp(self):
self.test_server = TestServer.create(self.test_server_address, self.getdata)
self.test_server.start()
try:
jobargs = {'name': self.__class__.__name__}
for jobarg in self.jobargs:
if hasattr(self, jobarg):
jobargs[jobarg] = getattr(self, jobarg)
def run(self, result=None):
self.is_running = True
signal.signal(signal.SIGINT, InterruptTest(self))
super(TestCase, self).run(result)
self.is_running = False

self.job = self.disco.new_job(**jobargs)
except:
self.test_server.stop()
raise
def setUp(self):
if hasattr(self, 'serve'):
self.test_server = TestServer.create(self.test_server_address, self.serve)
self.test_server.start()

def tearDown(self):
self.test_server.stop()
if self.settings['DISCO_TEST_PURGE']:
if hasattr(self, 'serve'):
self.test_server.stop()
if hasattr(self, 'job') and self.settings['DISCO_TEST_PURGE']:
self.job.purge()

def runTest(self):
from disco.future import izip_longest as zip
for result, answer in zip(self.results, self.answers):
self.assertEquals(result, answer)

def skipTest(self, message):
# Workaround for python2.5 which doesn't have skipTest in unittests
# make sure calls to skipTest are the last statement in a code branch
# (until we drop 2.5 support)
try:
super(DiscoJobTestFixture, self).skipTest(message)
super(TestCase, self).skipTest(message)
except AttributeError, e:
pass

class DiscoMultiJobTestFixture(DiscoJobTestFixture):
@staticmethod
def result_reader(m):
return disco.func.chain_reader

def profile(self, m):
class TestJob(Job):
@property
def profile(self):
return bool(self.settings['DISCO_TEST_PROFILE'])

def results(self, m):
return result_iterator(self.jobs[m].wait(),
reader=getattr(self, 'result_reader_%d' % (m + 1)))

def input(self, m):
return self.test_servers[m].urls(getattr(self, 'inputs_%d' % (m + 1)))

def __getattribute__(self, name):
try:
return super(DiscoMultiJobTestFixture, self).__getattribute__(name)
except AttributeError:
for prefix in ('input', 'profile', 'results', 'result_reader'):
if name.startswith('%s_' % prefix):
attribute, n = name.rsplit('_', 1)
return getattr(self, attribute)(int(n) - 1)
raise

def setUp(self):
host, port = self.test_server_address
self.test_servers = [None] * self.njobs
self.jobs = [None] * self.njobs
for m in xrange(self.njobs):
n = m + 1
self.test_servers[m] = TestServer.create((host, port + m),
getattr(self, 'getdata_%d' % n, self.getdata))
self.test_servers[m].start()
try:
jobargs = {'name': '%s_%d' % (self.__class__.__name__, n)}
for jobarg in self.jobargs:
attr = getattr(self, '%s_%d' % (jobarg, n), None)
if attr:
jobargs[jobarg] = attr

self.jobs[m] = self.disco.new_job(**jobargs)
setattr(self, 'job_%d' % n, self.jobs[m])
except:
for k in xrange(n):
self.test_servers[k].stop()
raise

def tearDown(self):
for m in xrange(self.njobs):
self.test_servers[m].stop()
if self.settings['DISCO_TEST_PURGE']:
self.jobs[m].purge()

def runTest(self):
for m in xrange(self.njobs):
n = m + 1
for result, answer in zip(getattr(self, 'results_%d' % n),
getattr(self, 'answers_%d' % n)):
self.assertEquals(result, answer)

class DiscoTestLoader(TestLoader):
class TestLoader(unittest.TestLoader):
def __init__(self, settings):
super(DiscoTestLoader, self).__init__()
super(TestLoader, self).__init__()
self.settings = settings

def loadTestsFromTestCase(self, testCaseClass):
if issubclass(testCaseClass, DiscoTestCase):
if issubclass(testCaseClass, TestCase):
testCaseClass.settings = self.settings
return super(DiscoTestLoader, self).loadTestsFromTestCase(testCaseClass)
return super(TestLoader, self).loadTestsFromTestCase(testCaseClass)

class DiscoTestRunner(TextTestRunner):
class TestRunner(unittest.TextTestRunner):
def __init__(self, settings):
debug_levels = {'off': 0, 'log': 1, 'trace': 2}
super(DiscoTestRunner, self).__init__(verbosity=debug_levels[settings['DISCO_DEBUG']])
super(TestRunner, self).__init__(verbosity=debug_levels[settings['DISCO_DEBUG']])
self.settings = settings

def run(self, *names):
suite = DiscoTestLoader(self.settings).loadTestsFromNames(names)
return super(DiscoTestRunner, self).run(suite)
suite = TestLoader(self.settings).loadTestsFromNames(names)
return super(TestRunner, self).run(suite)

class TestServer(ThreadingMixIn, HTTPServer):
allow_reuse_address = True

@property
def address(self):
return 'http://%s:%d' % self.server_address

@classmethod
def create(cls, server_address, data_generator):
return cls(server_address, handler(data_generator))

def start(self):
self.thread = Thread(target=self.serve_forever)
self.thread.start()

def stop(self):
# Workaround for Python2.5 which doesn't have shutdown
if hasattr(self, "shutdown"):
self.shutdown()
self.socket.close()

def urls(self, inputs):
def serverify(input):
return '%s/%s' % (self.address, input)
return [[serverify(url) for url in iterify(input)] for input in inputs]

class FailedReply(Exception):
pass

def handler(data_generator):
class Handler(BaseHTTPRequestHandler):
def send_data(self, data):
self.send_response(OK)
self.send_header('Content-length', len(data or []))
self.end_headers()
self.wfile.write(data)

def do_GET(self):
try:
self.send_data(data_generator(self.path.strip('/')))
except FailedReply, e:
self.send_error(INTERNAL_SERVER_ERROR, str(e))

def log_request(*args):
pass # suppress logging output for now
return Handler

0 comments on commit 34ec37c

Please sign in to comment.