diff --git a/corehq/apps/commtrack/tests/test_stock_report.py b/corehq/apps/commtrack/tests/test_stock_report.py index 28f65c850d50..f2febb917885 100644 --- a/corehq/apps/commtrack/tests/test_stock_report.py +++ b/corehq/apps/commtrack/tests/test_stock_report.py @@ -4,7 +4,7 @@ from django.test import TestCase from casexml.apps.case.tests.util import delete_all_xforms -from casexml.apps.stock.utils import get_current_ledger_transactions, get_current_ledger_transactions_multi +from casexml.apps.stock.utils import get_current_ledger_transactions, get_current_ledger_state from corehq.apps.commtrack.models import NewStockReport, SQLProduct, StockTransaction as STrans from casexml.apps.stock.const import REPORT_TYPE_BALANCE @@ -132,12 +132,12 @@ def test_transactions(expected): self.assertEqual({}, get_current_ledger_transactions('non-existent')) - def test_get_current_ledger_transactions_multi(self): + def test_get_current_ledger_state(self): def test_transactions(expected): - transactions = get_current_ledger_transactions_multi(self.case_ids.keys()) - for case, sections in transactions.items(): + state = get_current_ledger_state(self.case_ids.keys()) + for case, sections in state.items(): self._validate_case_data(sections, expected[case]) self._test_get_current_ledger_transactions(test_transactions) - self.assertEqual({}, get_current_ledger_transactions_multi([])) + self.assertEqual({}, get_current_ledger_state([])) diff --git a/corehq/ex-submodules/casexml/apps/phone/data_providers/case/stock.py b/corehq/ex-submodules/casexml/apps/phone/data_providers/case/stock.py index c2ee8d1cd58e..0cd223b55b7e 100644 --- a/corehq/ex-submodules/casexml/apps/phone/data_providers/case/stock.py +++ b/corehq/ex-submodules/casexml/apps/phone/data_providers/case/stock.py @@ -1,6 +1,6 @@ from collections import defaultdict from casexml.apps.stock.consumption import compute_consumption_or_default -from casexml.apps.stock.utils import get_current_ledger_transactions_multi +from casexml.apps.stock.utils import get_current_ledger_state from dimagi.utils.parsing import json_format_datetime from datetime import datetime from casexml.apps.stock.const import COMMTRACK_REPORT_XMLNS @@ -19,8 +19,8 @@ def entry_xml(id, quantity): quantity=str(int(quantity)), ) - def transaction_to_xml(trans): - return entry_xml(trans.product_id, trans.stock_on_hand) + def state_to_xml(state): + return entry_xml(state.product_id, state.stock_on_hand) def consumption_entry(case_id, product_id, section_id): consumption_value = compute_consumption_or_default( @@ -34,7 +34,7 @@ def consumption_entry(case_id, product_id, section_id): return entry_xml(product_id, consumption_value) case_ids = [case.case_id for case in case_state_list] - all_current_ledgers = get_current_ledger_transactions_multi(case_ids) + all_current_ledgers = get_current_ledger_state(case_ids) for commtrack_case in case_state_list: case_id = commtrack_case.case_id current_ledgers = all_current_ledgers[case_id] @@ -42,13 +42,13 @@ def consumption_entry(case_id, product_id, section_id): section_product_map = defaultdict(lambda: []) section_timestamp_map = defaultdict(lambda: json_format_datetime(datetime.utcnow())) for section_id in sorted(current_ledgers.keys()): - transactions_map = current_ledgers[section_id] - sorted_product_ids = sorted(transactions_map.keys()) - transactions = [transactions_map[p] for p in sorted_product_ids] - as_of = json_format_datetime(max(txn.report.date for txn in transactions)) + state_map = current_ledgers[section_id] + sorted_product_ids = sorted(state_map.keys()) + stock_states = [state_map[p] for p in sorted_product_ids] + as_of = json_format_datetime(max(txn.last_modified_date for txn in stock_states)) section_product_map[section_id] = sorted_product_ids section_timestamp_map[section_id] = as_of - yield E.balance(*(transaction_to_xml(e) for e in transactions), + yield E.balance(*(state_to_xml(e) for e in stock_states), **{'entity-id': case_id, 'date': as_of, 'section-id': section_id}) for section_id, consumption_section_id in stock_settings.section_to_consumption_types.items(): diff --git a/corehq/ex-submodules/casexml/apps/stock/utils.py b/corehq/ex-submodules/casexml/apps/stock/utils.py index 8fc8edbf0abf..9608ec705482 100644 --- a/corehq/ex-submodules/casexml/apps/stock/utils.py +++ b/corehq/ex-submodules/casexml/apps/stock/utils.py @@ -45,19 +45,33 @@ def state_stock_category(state): def get_current_ledger_transactions(case_id): """ Given a case returns a dict of all current ledger data. + { + "section_id": { + "product_id": StockTransaction, + "product_id": StockTransaction, + ... + }, + ... + } """ - trans = get_current_ledger_transactions_multi([case_id]) - return trans[case_id] + from corehq.apps.commtrack.models import StockState + results = StockState.objects.filter(case_id=case_id).values_list('case_id', 'section_id', 'product_id') + + ret = {} + for case_id, section_id, product_id in results: + sections = ret.setdefault(section_id, {}) + sections[product_id] = StockTransaction.latest(case_id, section_id, product_id) + return ret -def get_current_ledger_transactions_multi(case_ids): +def get_current_ledger_state(case_ids): """ Given a list of cases returns a dict of all current ledger data of the following format: { "case_id": { "section_id": { - "product_id": StockTransaction, - "product_id": StockTransaction, + "product_id": StockState, + "product_id": StockState, ... }, ... @@ -66,16 +80,16 @@ def get_current_ledger_transactions_multi(case_ids): } Where you get one stock transaction per product/section which is the last one seen. """ + from corehq.apps.commtrack.models import StockState if not case_ids: return {} - results = StockTransaction.objects.filter( + states = StockState.objects.filter( case_id__in=case_ids - ).values_list('case_id', 'section_id', 'product_id').distinct() - + ) ret = {case_id: {} for case_id in case_ids} - for case_id, section_id, product_id in results: - sections = ret[case_id].setdefault(section_id, {}) - sections[product_id] = StockTransaction.latest(case_id, section_id, product_id) + for state in states: + sections = ret[state.case_id].setdefault(state.section_id, {}) + sections[state.product_id] = state return ret