Skip to content

Commit

Permalink
ENH: add version of CS without the "contribution" of the logit error
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffgortmaker committed May 11, 2022
1 parent 1a670db commit a75c608
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 6 deletions.
6 changes: 4 additions & 2 deletions pyblp/markets/results_market.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,8 @@ def safely_compute_profits(

@NumericalErrorHandler(exceptions.PostEstimationNumericalError)
def safely_compute_consumer_surplus(
self, keep_all: bool, eliminate_product_ids: Optional[Any], prices: Optional[Array]) -> (
Tuple[Array, List[Error]]):
self, keep_all: bool, eliminate_product_ids: Optional[Any], include_logit_error: bool,
prices: Optional[Array]) -> Tuple[Array, List[Error]]:
"""Estimate population-normalized consumer surplus or keep all individual-level surpluses. By default, use
unchanged prices, handling any numerical errors.
"""
Expand Down Expand Up @@ -399,6 +399,8 @@ def safely_compute_consumer_surplus(
# compute individual-level consumer surpluses
numerator = np.log(np.exp(log_scale) + (scale_weights * exp_utilities).sum(axis=0, keepdims=True)) - log_scale
surpluses = numerator / derivatives
if not include_logit_error:
surpluses -= np.log(1 + self.J) / derivatives
if keep_all:
return surpluses, errors

Expand Down
18 changes: 14 additions & 4 deletions pyblp/results/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,14 +953,14 @@ def compute_profit_hessians(

def compute_consumer_surpluses(
self, prices: Optional[Any] = None, keep_all: bool = False, eliminate_product_ids: Optional[Any] = None,
market_id: Optional[Any] = None) -> Array:
include_logit_error: bool = True, market_id: Optional[Any] = None) -> Array:
r"""Estimate population-normalized consumer surpluses, :math:`\text{CS}`.
Assuming away nonlinear income effects, the surplus in market :math:`t` is
.. math:: \text{CS} = \sum_{i \in I_t} w_{it}\text{CS}_{it},
in which the consumer surplus for individual :math:`i` is
in which the consumer surplus for individual :math:`i` is, up to an unknown constant,
.. math::
Expand All @@ -978,6 +978,12 @@ def compute_consumer_surpluses(
where :math:`V_{ijt}` is defined in :eq:`utilities` and :math:`V_{iht}` is defined in :eq:`inclusive_value`.
These expressions include the contribution of the logit errors :math:`\epsilon_{ijt}` to consumer surplus:
.. math::
\log(1 + J_t) \Big/ \left(-\frac{\partial V_{i1t}}{\partial p_{1t}}\right).
.. warning::
:math:`\frac{\partial V_{1ti}}{\partial p_{1t}}` is the derivative of utility for the first product with
Expand All @@ -1002,6 +1008,10 @@ def compute_consumer_surpluses(
IDs of the products to eliminate from the choice set. These IDs should show up in the ``product_ids`` field
of ``product_data`` in :class:`Problem`. Eliminating one or more products and comparing consumer surpluses
gives a measure of willingness to pay for these products.
include_logit_error : `bool, optional`
Whether to include the above contribution of the logit errors :math:`\epsilon_{ijt}`. By default, their
contribution is included. If ``False``, it is subtracted. When comparing changes in consumer surplus, this
will only matter if the number of alternatives changes.
market_id : `object, optional`
ID of the market in which to compute consumer surplus. By default, consumer surpluses are computed in all
markets and stacked.
Expand All @@ -1024,6 +1034,6 @@ def compute_consumer_surpluses(
market_ids = self._select_market_ids(market_id)
prices = self._coerce_optional_prices(prices, market_ids)
return self._combine_arrays(
ResultsMarket.safely_compute_consumer_surplus, market_ids, fixed_args=[keep_all, eliminate_product_ids],
market_args=[prices]
ResultsMarket.safely_compute_consumer_surplus, market_ids,
fixed_args=[keep_all, eliminate_product_ids, include_logit_error], market_args=[prices]
)
7 changes: 7 additions & 0 deletions tests/test_blp.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,11 +697,18 @@ def test_surplus(simulated_problem: SimulatedProblemFixture) -> None:
# compute surpluses for a single market
t = problem.products.market_ids[0]
surpluses = results.compute_consumer_surpluses(market_id=t, keep_all=True)
observed_surpluses = results.compute_consumer_surpluses(market_id=t, include_logit_error=False, keep_all=True)
surplus = results.compute_consumer_surpluses(market_id=t)
observed_surplus = results.compute_consumer_surpluses(market_id=t, include_logit_error=False)

# test that removing the contribution of the logit error reduces surplus
np.testing.assert_array_less(observed_surpluses, surpluses)
np.testing.assert_array_less(observed_surplus, surplus)

# test that we get the same result when manually integrating over surpluses
weights = problem.agents.weights[problem.agents.market_ids.flat == t]
np.testing.assert_allclose(surpluses @ weights, surplus, atol=1e-14, rtol=0, verbose=True)
np.testing.assert_allclose(observed_surpluses @ weights, observed_surplus, atol=1e-14, rtol=0, verbose=True)


@pytest.mark.usefixtures('simulated_problem')
Expand Down

0 comments on commit a75c608

Please sign in to comment.