diff --git a/pydgraph/client.py b/pydgraph/client.py index 639c07a..d12e5da 100755 --- a/pydgraph/client.py +++ b/pydgraph/client.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Dgraph python client.""" + import random from pydgraph import txn, util @@ -19,39 +21,49 @@ from pydgraph.proto import api_pb2 as api __author__ = 'Mohit Ranka ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' __version__ = VERSION __status__ = 'development' class DgraphClient(object): """Creates a new Client for interacting with the Dgraph store. - + The client can be backed by multiple connections (to the same server, or multiple servers in a cluster). """ def __init__(self, *clients): - if len(clients) == 0: + if not clients: raise ValueError('No clients provided in DgraphClient constructor') self._clients = clients[:] self._lin_read = api.LinRead() - def alter(self, op, timeout=None, metadata=None, credentials=None): - return self.any_client().alter(op, timeout=timeout, metadata=metadata, credentials=credentials) - - def query(self, q, variables=None, timeout=None, metadata=None, credentials=None): - return self.txn().query(q, variables=variables, timeout=timeout, metadata=metadata, credentials=credentials) + def alter(self, operation, timeout=None, metadata=None, credentials=None): + """Runs a modification via this client.""" + return self.any_client().alter(operation, timeout=timeout, + metadata=metadata, + credentials=credentials) + + def query(self, query, variables=None, timeout=None, metadata=None, + credentials=None): + """Runs a query via this client.""" + return self.txn().query(query, variables=variables, timeout=timeout, + metadata=metadata, credentials=credentials) def txn(self): + """Creates a transaction.""" return txn.Txn(self) def set_lin_read(self, ctx): + """Sets linread map in ctx to the one from this instace.""" ctx.lin_read.MergeFrom(self._lin_read) def merge_lin_reads(self, src): + """Merges linread map in this instance with src.""" util.merge_lin_reads(self._lin_read, src) def any_client(self): + """Returns a random client.""" return random.choice(self._clients) diff --git a/pydgraph/client_stub.py b/pydgraph/client_stub.py index ecb3aa8..1e16cd2 100644 --- a/pydgraph/client_stub.py +++ b/pydgraph/client_stub.py @@ -12,41 +12,59 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Stub for RPC request.""" + import grpc from pydgraph.meta import VERSION from pydgraph.proto import api_pb2_grpc as api_grpc __author__ = 'Garvit Pahal ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' __version__ = VERSION __status__ = 'development' class DgraphClientStub(object): + """Stub for the Dgraph grpc client.""" + def __init__(self, addr='localhost:9080', credentials=None, options=None): if credentials is None: self.channel = grpc.insecure_channel(addr, options) else: - self.channel = grpc.secure_channel(addr, credentials, options) - + self.channel = grpc.secure_channel(addr, credentials, options) + self.stub = api_grpc.DgraphStub(self.channel) - def alter(self, op, timeout=None, metadata=None, credentials=None): - return self.stub.Alter(op, timeout=timeout, metadata=metadata, credentials=credentials) + def alter(self, operation, timeout=None, metadata=None, credentials=None): + """Runs alter operation.""" + return self.stub.Alter(operation, timeout=timeout, metadata=metadata, + credentials=credentials) def query(self, req, timeout=None, metadata=None, credentials=None): - return self.stub.Query(req, timeout=timeout, metadata=metadata, credentials=credentials) - - def mutate(self, mu, timeout=None, metadata=None, credentials=None): - return self.stub.Mutate(mu, timeout=timeout, metadata=metadata, credentials=credentials) - - def commit_or_abort(self, ctx, timeout=None, metadata=None, credentials=None): - return self.stub.CommitOrAbort(ctx, timeout=timeout, metadata=metadata, credentials=credentials) - - def check_version(self, check, timeout=None, metadata=None, credentials=None): - return self.stub.CheckVersion(check, timeout=timeout, metadata=metadata, credentials=credentials) + """Runs query operation.""" + return self.stub.Query(req, timeout=timeout, metadata=metadata, + credentials=credentials) + + def mutate(self, mutation, timeout=None, metadata=None, credentials=None): + """Runs mutate operation.""" + return self.stub.Mutate(mutation, timeout=timeout, metadata=metadata, + credentials=credentials) + + def commit_or_abort(self, ctx, timeout=None, metadata=None, + credentials=None): + """Runs commit or abort operation.""" + return self.stub.CommitOrAbort(ctx, timeout=timeout, metadata=metadata, + credentials=credentials) + + def check_version(self, check, timeout=None, metadata=None, + credentials=None): + """Returns the version of the Dgraph instance.""" + return self.stub.CheckVersion(check, timeout=timeout, + metadata=metadata, + credentials=credentials) def close(self): + """Deletes channel and stub.""" del self.channel del self.stub diff --git a/pydgraph/errors.py b/pydgraph/errors.py index 4f54f37..da453e2 100644 --- a/pydgraph/errors.py +++ b/pydgraph/errors.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Errors thrown by the Dgraph client.""" + from pydgraph.meta import VERSION __author__ = 'Garvit Pahal ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' __version__ = VERSION __status__ = 'development' class AbortedError(Exception): + """Error thrown by aborted transactions.""" + def __init__(self): - super(AbortedError, self).__init__('Transaction has been aborted. Please retry') + super(AbortedError, self).__init__( + 'Transaction has been aborted. Please retry') diff --git a/pydgraph/meta.py b/pydgraph/meta.py index c2b6553..cf544b1 100644 --- a/pydgraph/meta.py +++ b/pydgraph/meta.py @@ -12,4 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Metadata about this package.""" + VERSION = '1.0.0' diff --git a/pydgraph/txn.py b/pydgraph/txn.py index cfe997a..c7232bb 100644 --- a/pydgraph/txn.py +++ b/pydgraph/txn.py @@ -12,30 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -import grpc +"""Dgraph atomic transaction support.""" + import json +import grpc from pydgraph import errors, util from pydgraph.meta import VERSION from pydgraph.proto import api_pb2 as api __author__ = 'Shailesh Kochhar ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' __version__ = VERSION __status__ = 'development' class Txn(object): """Txn is a single atomic transaction. - + A transaction lifecycle is as follows: - + 1. Created using Client.newTxn. - - 2. Various query and mutate calls made. - - 3. commit or discard used. If any mutations have been made, It's important - that at least one of these methods is called to clean up resources. discard + + 2. Modified via calls to query and mutate. + + 3. Committed or discarded. If any mutations have been made, it's important + that at least one of these methods is called to clean up resources. Discard is a no-op if commit has already been called, so it's safe to call discard after calling commit. """ @@ -50,129 +52,153 @@ def __init__(self, client): self._sequencing = api.LinRead.CLIENT_SIDE def sequencing(self, sequencing): + """Sets sequencing.""" self._sequencing = sequencing - def query(self, q, variables=None, timeout=None, metadata=None, credentials=None): - req = self._common_query(q, variables=variables) - res = self._dc.any_client().query(req, timeout=timeout, metadata=metadata, credentials=credentials) + def query(self, query, variables=None, timeout=None, metadata=None, + credentials=None): + """Adds a query operation to the transaction.""" + req = self._common_query(query, variables=variables) + res = self._dc.any_client().query(req, timeout=timeout, + metadata=metadata, + credentials=credentials) self.merge_context(res.txn) return res - - def _common_query(self, q, variables=None): + + def _common_query(self, query, variables=None): if self._finished: - raise Exception('Transaction has already been committed or discarded') + raise Exception( + 'Transaction has already been committed or discarded') lin_read = self._ctx.lin_read lin_read.sequencing = self._sequencing - req = api.Request(query=q, start_ts=self._ctx.start_ts, lin_read=lin_read) + req = api.Request(query=query, start_ts=self._ctx.start_ts, + lin_read=lin_read) if variables is not None: for key, value in variables.items(): if util.is_string(key) and util.is_string(value): req.vars[key] = value - + return req - def mutate(self, mu=None, set_obj=None, del_obj=None, set_nquads=None, del_nquads=None, commit_now=None, - ignore_index_conflict=None, timeout=None, metadata=None, credentials=None): - mu = self._common_mutate(mu=mu, set_obj=set_obj, del_obj=del_obj, set_nquads=set_nquads, del_nquads=del_nquads, - commit_now=commit_now, ignore_index_conflict=ignore_index_conflict) + def mutate(self, mutation=None, set_obj=None, del_obj=None, set_nquads=None, + del_nquads=None, commit_now=None, ignore_index_conflict=None, + timeout=None, metadata=None, credentials=None): + """Adds a mutate operation to the transaction.""" + mutation = self._common_mutate( + mutation=mutation, set_obj=set_obj, del_obj=del_obj, + set_nquads=set_nquads, del_nquads=del_nquads, + commit_now=commit_now, ignore_index_conflict=ignore_index_conflict) try: - ag = self._dc.any_client().mutate(mu, timeout=timeout, metadata=metadata, credentials=credentials) - except Exception as e: + assigned = self._dc.any_client().mutate(mutation, timeout=timeout, + metadata=metadata, + credentials=credentials) + except Exception as error: try: - self.discard(timeout=timeout, metadata=metadata, credentials=credentials) + self.discard(timeout=timeout, metadata=metadata, + credentials=credentials) except: # Ignore error - user should see the original error. pass - self._common_except_mutate(e) + self._common_except_mutate(error) - if mu.commit_now: + if mutation.commit_now: self._finished = True - self.merge_context(ag.context) - return ag - - def _common_mutate(self, mu=None, set_obj=None, del_obj=None, set_nquads=None, del_nquads=None, + self.merge_context(assigned.context) + return assigned + + def _common_mutate(self, mutation=None, set_obj=None, del_obj=None, + set_nquads=None, del_nquads=None, commit_now=None, ignore_index_conflict=None): - if not mu: - mu = api.Mutation() + if not mutation: + mutation = api.Mutation() if set_obj: - mu.set_json = json.dumps(set_obj).encode('utf8') + mutation.set_json = json.dumps(set_obj).encode('utf8') if del_obj: - mu.delete_json = json.dumps(del_obj).encode('utf8') + mutation.delete_json = json.dumps(del_obj).encode('utf8') if set_nquads: - mu.set_nquads = set_nquads.encode('utf8') + mutation.set_nquads = set_nquads.encode('utf8') if del_nquads: - mu.del_nquads = del_nquads.encode('utf8') + mutation.del_nquads = del_nquads.encode('utf8') if commit_now: - mu.commit_now = True + mutation.commit_now = True if ignore_index_conflict: - mu.ignore_index_conflict = True - + mutation.ignore_index_conflict = True + if self._finished: - raise Exception('Transaction has already been committed or discarded') - + raise Exception( + 'Transaction has already been committed or discarded') + self._mutated = True - mu.start_ts = self._ctx.start_ts - return mu + mutation.start_ts = self._ctx.start_ts + return mutation @staticmethod - def _common_except_mutate(e): - if isinstance(e, grpc._channel._Rendezvous): - e.details() - status_code = e.code() - if status_code == grpc.StatusCode.ABORTED or status_code == grpc.StatusCode.FAILED_PRECONDITION: + def _common_except_mutate(error): + if isinstance(error, grpc._channel._Rendezvous): + error.details() + status_code = error.code() + if (status_code == grpc.StatusCode.ABORTED or + status_code == grpc.StatusCode.FAILED_PRECONDITION): raise errors.AbortedError() - - raise e + + raise error def commit(self, timeout=None, metadata=None, credentials=None): + """Commits the transaction.""" if not self._common_commit(): return try: - self._dc.any_client().commit_or_abort(self._ctx, timeout=timeout, metadata=metadata, + self._dc.any_client().commit_or_abort(self._ctx, timeout=timeout, + metadata=metadata, credentials=credentials) - except Exception as e: - self._common_except_commit(e) - + except Exception as error: + self._common_except_commit(error) + def _common_commit(self): if self._finished: - raise Exception('Transaction has already been committed or discarded') - + raise Exception( + 'Transaction has already been committed or discarded') + self._finished = True return self._mutated @staticmethod - def _common_except_commit(e): - if isinstance(e, grpc._channel._Rendezvous): - e.details() - status_code = e.code() + def _common_except_commit(error): + if isinstance(error, grpc._channel._Rendezvous): + error.details() + status_code = error.code() if status_code == grpc.StatusCode.ABORTED: raise errors.AbortedError() - raise e + raise error def discard(self, timeout=None, metadata=None, credentials=None): + """Discards the transaction.""" if not self._common_discard(): return - self._dc.any_client().commit_or_abort(self._ctx, timeout=timeout, metadata=metadata, credentials=credentials) - + self._dc.any_client().commit_or_abort(self._ctx, timeout=timeout, + metadata=metadata, + credentials=credentials) + def _common_discard(self): if self._finished: return False - + self._finished = True if not self._mutated: return False - + self._ctx.aborted = True return True - + def merge_context(self, src=None): + """Merges context from this instance with src.""" if src is None: # This condition will be true only if the server doesn't return a # txn context after a query or mutation. diff --git a/pydgraph/util.py b/pydgraph/util.py index 273e7d9..8ba8e1c 100644 --- a/pydgraph/util.py +++ b/pydgraph/util.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Various utility functions.""" + import sys from pydgraph.meta import VERSION __author__ = 'Shailesh Kochhar ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' __version__ = VERSION __status__ = 'development' def merge_lin_reads(target, src): + """Merger src linread map into target linread map.""" if src is None: return target @@ -37,8 +40,9 @@ def merge_lin_reads(target, src): return target -def is_string(s): +def is_string(string): + """Checks if argument is a string. Compatible with Python 2 and 3.""" if sys.version_info[0] < 3: - return isinstance(s, basestring) + return isinstance(string, basestring) - return isinstance(s, str) + return isinstance(string, str) diff --git a/tests/helper.py b/tests/helper.py index 0e44ae1..569c57c 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Utilities used by tests.""" + __author__ = 'Garvit Pahal ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' import unittest @@ -21,23 +23,25 @@ def create_lin_read(src_ids): - lr = pydgraph.LinRead() - ids = lr.ids + """Creates a linread map using src_ids.""" + lin_read = pydgraph.LinRead() + ids = lin_read.ids for key, value in src_ids.items(): ids[key] = value - return lr + return lin_read -def are_lin_reads_equal(a, b): - a_ids = a.ids - b_ids = b.ids +def are_lin_reads_equal(lin_read1, lin_read2): + """Returns True if both linread maps are equal.""" + ids1 = lin_read1.ids + ids2 = lin_read2.ids - if len(a_ids) != len(b_ids): + if len(ids1) != len(ids2): return False - for (key, value) in a_ids.items(): - if key not in b_ids or b.ids[key] != value: + for (key, value) in ids1.items(): + if key not in ids2 or lin_read2.ids[key] != value: return False return True @@ -47,21 +51,25 @@ def are_lin_reads_equal(a, b): def create_client(addr=SERVER_ADDR): + """Creates a new client object using the given address.""" return pydgraph.DgraphClient(pydgraph.DgraphClientStub(addr)) -def set_schema(c, schema): - return c.alter(pydgraph.Operation(schema=schema)) +def set_schema(client, schema): + """Sets the schema in the given client.""" + return client.alter(pydgraph.Operation(schema=schema)) -def drop_all(c): - return c.alter(pydgraph.Operation(drop_all=True)) +def drop_all(client): + """Drops all data in the given client.""" + return client.alter(pydgraph.Operation(drop_all=True)) def setup(): - c = create_client() - drop_all(c) - return c + """Creates a new client and drops all existing data.""" + client = create_client() + drop_all(client) + return client class ClientIntegrationTestCase(unittest.TestCase): diff --git a/tests/test_acct_upsert.py b/tests/test_acct_upsert.py index c299c5c..3201101 100644 --- a/tests/test_acct_upsert.py +++ b/tests/test_acct_upsert.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests to verify upsert directive.""" + __author__ = 'Shailesh Kochhar ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' import unittest import logging @@ -33,6 +35,7 @@ class TestAccountUpsert(helper.ClientIntegrationTestCase): + """Tests to verify upsert directive.""" def setUp(self): super(TestAccountUpsert, self).setUp() @@ -63,7 +66,8 @@ def do_upserts(self, account_list, concurrency): retry_ctr = multiprocessing.Value('i', 0, lock=True) def _updater(acct): - upsert_account(addr=self.TEST_SERVER_ADDR, account=acct, success_ctr=success_ctr, retry_ctr=retry_ctr) + upsert_account(addr=self.TEST_SERVER_ADDR, account=acct, + success_ctr=success_ctr, retry_ctr=retry_ctr) pool = mpd.Pool(concurrency) results = [ @@ -71,21 +75,21 @@ def _updater(acct): for acct in account_list for _ in range(concurrency) ] - [res.get() for res in results] + _ = [res.get() for res in results] pool.close() def assert_changes(self, firsts, accounts): """Will check to see changes have been made.""" - q = """{{ + query = """{{ all(func: anyofterms(first, "{}")) {{ first last age }} }}""".format(' '.join(firsts)) - logging.debug(q) - result = json.loads(self.client.query(q=q).json) + logging.debug(query) + result = json.loads(self.client.query(query).json) account_set = set() for acct in result['all']: @@ -100,8 +104,9 @@ def assert_changes(self, firsts, accounts): def upsert_account(addr, account, success_ctr, retry_ctr): - c = helper.create_client(addr) - q = """{{ + """Runs upsert operation.""" + client = helper.create_client(addr) + query = """{{ acct(func:eq(first, "{first}")) @filter(eq(last, "{last}") AND eq(age, {age})) {{ uid }} @@ -110,12 +115,13 @@ def upsert_account(addr, account, success_ctr, retry_ctr): last_update_time = time.time() - 10000 while True: if time.time() > last_update_time + 10000: - logging.debug('Success: %d Retries: %d', success_ctr.value, retry_ctr.value) + logging.debug('Success: %d Retries: %d', success_ctr.value, + retry_ctr.value) last_update_time = time.time() - txn = c.txn() + txn = client.txn() try: - result = json.loads(txn.query(q=q).json) + result = json.loads(txn.query(query).json) assert len(result['acct']) <= 1, ('Lookup of account %s found ' 'multiple accounts' % account) @@ -135,7 +141,8 @@ def upsert_account(addr, account, success_ctr, retry_ctr): uid = acct['uid'] assert uid is not None, 'Account with uid None' - updatequads = '<{0}> "{1:d}"^^ .'.format(uid, int(time.time())) + updatequads = '<{0}> "{1:d}"^^ .'.format( + uid, int(time.time())) txn.mutate(set_nquads=updatequads) txn.commit() @@ -153,9 +160,10 @@ def upsert_account(addr, account, success_ctr, retry_ctr): def suite(): - s = unittest.TestSuite() - s.addTest(TestAccountUpsert()) - return s + """Returns a test suite object.""" + suite_obj = unittest.TestSuite() + suite_obj.addTest(TestAccountUpsert()) + return suite_obj if __name__ == '__main__': diff --git a/tests/test_bank.py b/tests/test_bank.py index e9f0919..4cf0b54 100644 --- a/tests/test_bank.py +++ b/tests/test_bank.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Test that runs through example bank transactions.""" + __author__ = 'Garvit Pahal ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' import unittest import logging @@ -31,6 +33,8 @@ class TestBank(helper.ClientIntegrationTestCase): + """Test that runs through example bank transactions.""" + def setUp(self): super(TestBank, self).setUp() @@ -54,10 +58,11 @@ def test_bank_transfer(self): pool = mpd.Pool(CONCURRENCY) results = [pool.apply_async( run_transfers, - (self.TEST_SERVER_ADDR, TRANSFER_COUNT, self.uids, success_ctr, retry_ctr) + (self.TEST_SERVER_ADDR, TRANSFER_COUNT, self.uids, success_ctr, + retry_ctr) ) for _ in range(CONCURRENCY)] - [res.get() for res in results] + _ = [res.get() for res in results] pool.close() finally: total_watcher.terminate() @@ -85,6 +90,7 @@ def start_total_watcher(self): def looper(func, *args, **kwargs): + """Returns a function that runs func in an infinite loop.""" def _looper(): while True: func(*args, **kwargs) @@ -93,10 +99,10 @@ def _looper(): return _looper -def run_total(c, uids): +def run_total(client, uids): """Calculates the total amount in the accounts.""" - q = """{{ + query = """{{ var(func: uid("{uids:s}")) {{ b as bal }} @@ -105,16 +111,17 @@ def run_total(c, uids): }} }}""".format(uids='", "'.join(uids)) - resp = c.query(q) + resp = client.query(query) total = json.loads(resp.json)['total'] logging.info('Response: %s', total) assert total[0]['bal'] == 10000 def run_transfers(addr, transfer_count, account_ids, success_ctr, retry_ctr): + """Runs transfers between the given accounts.""" pname = mpd.current_process().name log = logging.getLogger('test_bank.run_transfers[%s]' % (pname,)) - c = helper.create_client(addr) + client = helper.create_client(addr) while True: from_acc, to_acc = select_account_pair(account_ids) @@ -125,7 +132,7 @@ def run_transfers(addr, transfer_count, account_ids, success_ctr, retry_ctr): }} }}""".format(uid1=from_acc, uid2=to_acc) - txn = c.txn() + txn = client.txn() try: accounts = load_from_query(txn, query, 'me') accounts[0]['bal'] += 5 @@ -135,10 +142,11 @@ def run_transfers(addr, transfer_count, account_ids, success_ctr, retry_ctr): success_ctr.value += 1 if not success_ctr.value % 100: - log.info('Runs %d. Aborts: %d', success_ctr.value, retry_ctr.value) + log.info('Runs %d. Aborts: %d', success_ctr.value, + retry_ctr.value) if success_ctr.value >= transfer_count: break - except: + except BaseException: with retry_ctr.get_lock(): retry_ctr.value += 1 @@ -163,6 +171,7 @@ def load_from_query(txn, query, field): def dump_from_obj(txn, obj, commit=False): + """Dumps the given object into the transaction.""" assigned = txn.mutate(set_obj=obj) if not commit: @@ -171,9 +180,10 @@ def dump_from_obj(txn, obj, commit=False): def suite(): - s = unittest.TestSuite() - s.addTest(TestBank()) - return s + """Returns a test suite object.""" + suite_obj = unittest.TestSuite() + suite_obj.addTest(TestBank()) + return suite_obj if __name__ == '__main__': diff --git a/tests/test_client.py b/tests/test_client.py index cbf1c5d..1fc040c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests construction of Dgraph client.""" + __author__ = 'Garvit Pahal ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' import unittest @@ -21,15 +23,17 @@ class TestDgraphClient(unittest.TestCase): + """Tests construction of Dgraph client.""" def test_constructor(self): with self.assertRaises(ValueError): pydgraph.DgraphClient() def suite(): - s = unittest.TestSuite() - s.addTest(TestDgraphClient()) - return s + """Returns a tests suite object.""" + suite_obj = unittest.TestSuite() + suite_obj.addTest(TestDgraphClient()) + return suite_obj if __name__ == '__main__': diff --git a/tests/test_client_stub.py b/tests/test_client_stub.py index 254698a..61e0a1a 100644 --- a/tests/test_client_stub.py +++ b/tests/test_client_stub.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests client stub.""" + __author__ = 'Garvit Pahal ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' import unittest import sys @@ -22,6 +24,7 @@ from . import helper class TestDgraphClientStub(helper.ClientIntegrationTestCase): + """Tests client stub.""" def validate_version_object(self, version): tag = version.tag @@ -51,9 +54,10 @@ def test_close(self): def suite(): - s = unittest.TestSuite() - s.addTest(TestDgraphClientStub()) - return s + """Returns a test suite object.""" + suite_obj = unittest.TestSuite() + suite_obj.addTest(TestDgraphClientStub()) + return suite_obj if __name__ == '__main__': diff --git a/tests/test_essentials.py b/tests/test_essentials.py index 022e3b4..7f0555a 100644 --- a/tests/test_essentials.py +++ b/tests/test_essentials.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests mutation after query behavior.""" + __author__ = 'Shailesh Kochhar ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' import unittest import logging @@ -23,6 +25,8 @@ class TestEssentials(helper.ClientIntegrationTestCase): + """Tests mutation after query behavior.""" + def testMutationAfterQuery(self): """Tests what happens when making a mutation on a txn after querying on the client.""" @@ -44,9 +48,10 @@ def testMutationAfterQuery(self): def suite(): - s = unittest.TestSuite() - s.addTest(TestEssentials()) - return s + """Returns a test suite object.""" + suite_obj = unittest.TestSuite() + suite_obj.addTest(TestEssentials()) + return suite_obj if __name__ == '__main__': diff --git a/tests/test_queries.py b/tests/test_queries.py index 68b7e5b..8a61a5d 100755 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests behavior of queries after mutation in the same transaction.""" + __author__ = 'Mohit Ranka ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' import unittest import sys @@ -26,6 +28,8 @@ class TestQueries(helper.ClientIntegrationTestCase): + """Tests behavior of queries after mutation in the same transaction.""" + def setUp(self): super(TestQueries, self).setUp() @@ -33,6 +37,8 @@ def setUp(self): helper.set_schema(self.client, 'name: string @index(term) .') def test_mutation_and_query(self): + """Runs mutation and verifies queries see the results.""" + txn = self.client.txn() _ = txn.mutate(pydgraph.Mutation(commit_now=True), set_nquads=""" <_:alice> \"Alice\" . @@ -52,23 +58,29 @@ def test_mutation_and_query(self): }""" response = self.client.query(query, variables={'$a': 'Alice'}) - self.assertEqual([{'name': 'Alice', 'follows': [{'name': 'Greg'}]}], json.loads(response.json).get('me')) - self.assertTrue(is_number(response.latency.parsing_ns), 'Parsing latency is not available') - self.assertTrue(is_number(response.latency.processing_ns), 'Processing latency is not available') - self.assertTrue(is_number(response.latency.encoding_ns), 'Encoding latency is not available') - - -def is_number(n): + self.assertEqual([{'name': 'Alice', 'follows': [{'name': 'Greg'}]}], + json.loads(response.json).get('me')) + self.assertTrue(is_number(response.latency.parsing_ns), + 'Parsing latency is not available') + self.assertTrue(is_number(response.latency.processing_ns), + 'Processing latency is not available') + self.assertTrue(is_number(response.latency.encoding_ns), + 'Encoding latency is not available') + + +def is_number(number): + """Returns true if object is a number. Compatible with Python 2 and 3.""" if sys.version_info[0] < 3: - return isinstance(n, (int, long)) + return isinstance(number, (int, long)) - return isinstance(n, int) + return isinstance(number, int) def suite(): - s = unittest.TestSuite() - s.addTest(TestQueries()) - return s + """Returns a test suite object.""" + suite_obj = unittest.TestSuite() + suite_obj.addTest(TestQueries()) + return suite_obj if __name__ == '__main__': diff --git a/tests/test_util.py b/tests/test_util.py index e6d900c..2480272 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests utility functions.""" + __author__ = 'Garvit Pahal ' -__maintainer__ = 'Garvit Pahal ' +__maintainer__ = 'Martin Martinez Rivera ' import unittest @@ -24,6 +26,8 @@ class TestMergeLinReads(unittest.TestCase): + """Tests merge_lin_reads utility function.""" + def common_test(self, lr1, lr2, expected): self.assertTrue(helper.are_lin_reads_equal(util.merge_lin_reads(lr1, lr2), expected)) self.assertTrue(helper.are_lin_reads_equal(lr1, expected)) @@ -33,37 +37,37 @@ def test_disjoint(self): lr2 = helper.create_lin_read({2: 2, 3: 3}) res = helper.create_lin_read({1: 1, 2: 2, 3: 3}) self.common_test(lr1, lr2, res) - + def test_lower_value(self): lr1 = helper.create_lin_read({1: 2}) lr2 = helper.create_lin_read({1: 1}) res = helper.create_lin_read({1: 2}) self.common_test(lr1, lr2, res) - + def test_higher_value(self): lr1 = helper.create_lin_read({1: 1}) lr2 = helper.create_lin_read({1: 2}) res = helper.create_lin_read({1: 2}) self.common_test(lr1, lr2, res) - + def test_equal_value(self): lr1 = helper.create_lin_read({1: 1}) lr2 = helper.create_lin_read({1: 1}) res = helper.create_lin_read({1: 1}) self.common_test(lr1, lr2, res) - + def test_none(self): lr1 = helper.create_lin_read({1: 1}) lr2 = None res = helper.create_lin_read({1: 1}) self.common_test(lr1, lr2, res) - + def test_no_src_ids(self): lr1 = helper.create_lin_read({1: 1}) lr2 = pydgraph.LinRead() res = helper.create_lin_read({1: 1}) self.common_test(lr1, lr2, res) - + def test_no_target_ids(self): lr1 = pydgraph.LinRead() lr2 = helper.create_lin_read({1: 1}) @@ -72,6 +76,8 @@ def test_no_target_ids(self): class TestIsString(unittest.TestCase): + """Tests is_string utility function.""" + def test_is_string(self): self.assertTrue(util.is_string('')) self.assertTrue(util.is_string('a')) @@ -80,10 +86,11 @@ def test_is_string(self): def suite(): - s = unittest.TestSuite() - s.addTest(TestMergeLinReads()) - s.addTest(TestIsString()) - return s + """Returns a test suite object.""" + suite_obj = unittest.TestSuite() + suite_obj.addTest(TestMergeLinReads()) + suite_obj.addTest(TestIsString()) + return suite_obj if __name__ == '__main__':