Skip to content

Commit

Permalink
Rollback previous PR due to breakage of non-local colab runtime.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 537020902
  • Loading branch information
marcolongfils committed Jun 1, 2023
1 parent 0245558 commit 0be8552
Show file tree
Hide file tree
Showing 12 changed files with 45 additions and 154 deletions.
23 changes: 4 additions & 19 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
#

# Abseil
http_archive(
name = "com_google_absl",
sha256 = "44634eae586a7158dceedda7d8fd5cec6d1ebae08c83399f75dd9ce76324de40", # Last updated 2022-05-18
strip_prefix = "abseil-cpp-3e04aade4e7a53aebbbed1a1268117f1f522bfb0",
urls = ["https://github.com/abseil/abseil-cpp/archive/3e04aade4e7a53aebbbed1a1268117f1f522bfb0.zip"],
)
# skylib dependency required for Abseil
http_archive(
name = "bazel_skylib",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz"],
sha256 = "f7be3474d42aae265405a592bb7da8e171919d74c16f082a5457840f06054728",
git_repository(
name = "com_google_absl",
tag = "20190808",
remote = "https://github.com/abseil/abseil-cpp",
)

# Google Logging Library
Expand Down Expand Up @@ -62,14 +55,6 @@ http_archive(
load("@pybind11_bazel//:python_configure.bzl", "python_configure")
python_configure(name = "local_config_python")

http_archive(
name = "pybind11_abseil",
sha256 = "6481888831cd548858c09371ea892329b36c8d4d961f559876c64e009d0bc630",
strip_prefix = "pybind11_abseil-3922b3861a2b27d4111e3ac971e6697ea030a36e",
url = "https://github.com/pybind/pybind11_abseil/archive/3922b3861a2b27d4111e3ac971e6697ea030a36e.tar.gz",
patches = ["//trimmed_match:status_module.patch"],
)

# Bazel Skylib library required for Absl C++ library
http_archive(
name = "bazel_skylib",
Expand Down
6 changes: 1 addition & 5 deletions trimmed_match/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,14 @@ py_library(
srcs = [
"estimator.py",
],
data = [
"//trimmed_match/core/python:estimator_ext.so",
"@pybind11_abseil//pybind11_abseil:status.so",
],
data = ["//trimmed_match/core/python:estimator_ext.so"],
srcs_version = "PY3",
visibility = ["//visibility:public"],
)

py_test(
name = "estimator_test",
srcs = ["estimator_test.py"],
data = ["@pybind11_abseil//pybind11_abseil:status.so"],
main = "estimator_test.py",
python_version = "PY3",
deps = [":estimator"],
Expand Down
3 changes: 0 additions & 3 deletions trimmed_match/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ cc_library(
deps = [
":geox_data_util",
":math_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@glog",
],
Expand Down
38 changes: 14 additions & 24 deletions trimmed_match/core/estimator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
#include <vector>

#include "glog/logging.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/substitute.h"
#include "absl/types/optional.h"
#include "trimmed_match/core/geox_data_util.h"
#include "trimmed_match/core/math_util.h"
Expand Down Expand Up @@ -81,12 +78,10 @@ TrimmedMatch::TrimmedMatch(const std::vector<double>& delta_response,
<< delta_response.size() << " vs " << delta_cost.size();
}

absl::StatusOr<double> TrimmedMatch::CalculateIroas(
const double trim_rate) const {
if (trim_rate < 0.0 || trim_rate > max_trim_rate_) {
return absl::InvalidArgumentError(absl::Substitute(
"Trim rate must be in (0,$0), but got $1", max_trim_rate_, trim_rate));
}
double TrimmedMatch::CalculateIroas(const double trim_rate) const {
CHECK(trim_rate >= 0.0 && trim_rate <= max_trim_rate_)
<< "Trim rate must be in (0, " << max_trim_rate_ << "), but got "
<< trim_rate;

if (trim_rate == 0.0) {
return geox_util_->CalculateEmpiricalIroas();
Expand All @@ -95,11 +90,8 @@ absl::StatusOr<double> TrimmedMatch::CalculateIroas(
const std::vector<double> candidates =
geox_util_->FindAllZerosOfTrimmedMean(trim_rate);

if (candidates.empty()) {
return absl::InternalError(
"We could not find a root for the TM equation. One likely reason is "
"that the incremental cost for the untrimmed geo pairs is 0.");
}
CHECK(!candidates.empty())
<< "Incremental cost for the untrimmed geo pairs is 0";

if (candidates.size() == 1) {
return candidates[0];
Expand Down Expand Up @@ -158,30 +150,28 @@ double TrimmedMatch::CalculateStandardError(const double trim_rate,
return std::sqrt(approx_variance / num_pairs_);
}

absl::StatusOr<Result> TrimmedMatch::Report(const double normal_quantile,
const double trim_rate) const {
Result TrimmedMatch::Report(const double normal_quantile,
const double trim_rate) const {
TrimAndError result;
std::vector<TrimAndError> candidate_results;

// If trim_rate falls into [0, max_trim_rate_), use it.
// Otherwise, choose a trim rate in that range so that the corresponding
// standard error of the estimate is close to the minimum.
if (trim_rate >= 0.0 && trim_rate <= max_trim_rate_) {
const absl::StatusOr<double> iroas = CalculateIroas(trim_rate);
if (!iroas.ok()) return iroas.status();
const double std_error = CalculateStandardError(trim_rate, *iroas);
result = {trim_rate, *iroas, std_error};
const double iroas = CalculateIroas(trim_rate);
const double std_error = CalculateStandardError(trim_rate, iroas);
result = {trim_rate, iroas, std_error};
candidate_results.push_back(result);
} else {
const int max_num_trim =
static_cast<int>(std::ceil(max_trim_rate_ * num_pairs_));
for (int i = 0; i <= max_num_trim; ++i) {
const double rate = static_cast<double>(i) / num_pairs_;
if (rate > max_trim_rate_) break;
const absl::StatusOr<double> iroas = CalculateIroas(rate);
if (!iroas.ok()) return iroas.status();
const double std_error = CalculateStandardError(rate, *iroas);
candidate_results.push_back({rate, *iroas, std_error});
const double iroas = CalculateIroas(rate);
const double std_error = CalculateStandardError(rate, iroas);
candidate_results.push_back({rate, iroas, std_error});
}

// Choose the result close to, but no more than 1 + 0.25/sqrt(num_pairs)
Expand Down
7 changes: 3 additions & 4 deletions trimmed_match/core/estimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include <limits>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "trimmed_match/core/geox_data_util.h"
#include "trimmed_match/core/math_util.h"
Expand Down Expand Up @@ -70,7 +69,7 @@ class TrimmedMatch {

// Returns the root of the trimmed mean equation which minimizes
// TrimmedSymmetricNorm().
absl::StatusOr<double> CalculateIroas(double trim_rate) const;
double CalculateIroas(double trim_rate) const;

// Returns the square root of (asymptotic variance / number of pairs), where
// asymptotic variance is given by Eq (8.2) in
Expand All @@ -85,8 +84,8 @@ class TrimmedMatch {
// Otherwise, report the result for the given trim rate.
// Normal_quantile is by default the 90% normal percentile, which corresponds
// to 80% 2-sided confidence interval.
absl::StatusOr<Result> Report(double normal_quantile = 1.281551566,
double trim_rate = -1.0) const;
Result Report(double normal_quantile = 1.281551566,
double trim_rate = -1.0) const;

private:
const double max_trim_rate_;
Expand Down
46 changes: 18 additions & 28 deletions trimmed_match/core/estimator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,18 @@
namespace trimmedmatch {
namespace {

void CompareReports(const Result& expected,
const absl::StatusOr<Result>& result,
void CompareReports(const Result& expected, const Result& result,
const double epsilon = 1e-6) {
if (result.ok()) {
EXPECT_NEAR(expected.estimate, result->estimate, epsilon);
EXPECT_NEAR(expected.std_error, result->std_error, epsilon);
EXPECT_NEAR(expected.trim_rate, result->trim_rate, epsilon);
EXPECT_NEAR(expected.conf_interval_low, result->conf_interval_low, epsilon);
EXPECT_NEAR(expected.conf_interval_up, result->conf_interval_up, epsilon);
EXPECT_EQ(expected.candidate_results.size(),
result->candidate_results.size());

for (size_t i = 0; i < result->candidate_results.size(); ++i) {
EXPECT_NEAR(expected.candidate_results[i].iroas,
result->candidate_results[i].iroas, epsilon);
}
EXPECT_NEAR(expected.estimate, result.estimate, epsilon);
EXPECT_NEAR(expected.std_error, result.std_error, epsilon);
EXPECT_NEAR(expected.trim_rate, result.trim_rate, epsilon);
EXPECT_NEAR(expected.conf_interval_low, result.conf_interval_low, epsilon);
EXPECT_NEAR(expected.conf_interval_up, result.conf_interval_up, epsilon);
EXPECT_EQ(expected.candidate_results.size(), result.candidate_results.size());

for (size_t i = 0; i < result.candidate_results.size(); ++i) {
EXPECT_NEAR(expected.candidate_results[i].iroas,
result.candidate_results[i].iroas, epsilon);
}
}

Expand Down Expand Up @@ -70,21 +66,15 @@ TEST(TrimmedMatchInitialization, TrimmedMatchInvalidInput) {
EXPECT_DEATH(auto result = TrimmedMatch({1, 2, 3}, {1, 2}), "");
}

TEST_F(EstimatorInternalTest, CalculateIroasEmptyRoot) {
TEST(CalculateIroasTest, CalculateIroasEmptyRoot) {
TrimmedMatch trimmed_match({1, 2, 3, 4}, {-10, -1, 1, 10});
EXPECT_FALSE(trimmed_match.CalculateIroas(0.1).ok());
EXPECT_EQ(
trimmed_match.CalculateIroas(0.1).status(),
absl::InternalError(
"We could not find a root for the TM equation. One likely reason is "
"that the incremental cost for the untrimmed geo pairs is 0."));
EXPECT_DEATH(auto iroas = trimmed_match.CalculateIroas(0.25), "");
}

TEST_F(EstimatorInternalTest, InvalidInput) {
EXPECT_FALSE(trimmed_match1_.CalculateIroas(-0.25).ok());
EXPECT_EQ(trimmed_match1_.CalculateIroas(-0.25).status(),
absl::InvalidArgumentError(
"Trim rate must be in (0,0.25), but got -0.25"));
EXPECT_DEATH(auto iroas = trimmed_match1_.CalculateIroas(-0.25), "");
EXPECT_DEATH(auto error = trimmed_match1_.CalculateStandardError(-0.25, 0.0),
"");
}

TEST_F(EstimatorInternalTest, CalculateIroasNoTrim) {
Expand All @@ -96,7 +86,7 @@ TEST_F(EstimatorInternalTest, CalculateIroasNoTrim) {
}

EXPECT_NEAR(total_delta_response / total_delta_cost,
*trimmed_match1_.CalculateIroas(0.0), 1e-6);
trimmed_match1_.CalculateIroas(0.0), 1e-6);
}

TEST_F(EstimatorInternalTest, CalculateIroasWithTrim) {
Expand All @@ -109,7 +99,7 @@ TEST_F(EstimatorInternalTest, CalculateIroasWithTrim) {
}

EXPECT_NEAR(total_delta_response / total_delta_cost,
*trimmed_match1_.CalculateIroas(0.25), 1e-6);
trimmed_match1_.CalculateIroas(0.25), 1e-6);
}

TEST_F(EstimatorInternalTest, CalculateStandardErrorNoTrim) {
Expand Down
5 changes: 1 addition & 4 deletions trimmed_match/core/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,5 @@ pybind_extension(
srcs = ["estimator_ext.cc"],
copts = ["-fexceptions"],
features = ["-use_header_modules"],
deps = [
"//trimmed_match/core:estimator",
"@pybind11_abseil//pybind11_abseil:status_casters",
],
deps = ["//trimmed_match/core:estimator"],
)
3 changes: 0 additions & 3 deletions trimmed_match/core/python/estimator_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@
#include "trimmed_match/core/estimator.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11_abseil/status_casters.h"

namespace trimmedmatch {

namespace py = pybind11;

PYBIND11_MODULE(estimator_ext, m) {
pybind11::google::ImportStatusModule();

py::class_<TrimAndError>(m, "TrimAndError")
.def(py::init<>())
.def_readwrite("trim_rate", &TrimAndError::trim_rate)
Expand Down
32 changes: 0 additions & 32 deletions trimmed_match/core/python/estimator_test.py

This file was deleted.

7 changes: 1 addition & 6 deletions trimmed_match/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import numpy as np
from scipy import stats
from trimmed_match.core.python import estimator_ext
from pybind11_abseil.pybind11_abseil import status

# A class to report the Trimmed Match estimator for a fixed trim rate:
# trim_rate: float
Expand Down Expand Up @@ -212,11 +211,7 @@ def Report(self, confidence: float = 0.80, trim_rate: float = -1.0) -> Report:
raise ValueError(f"trim_rate {trim_rate} is greater than max_trim_rate "
f"which is {self._max_trim_rate}.")

try:
output = self._tm.Report(
stats.norm.ppf(0.5 + 0.5 * confidence), trim_rate)
except status.StatusNotOk as e:
raise ValueError(str(e)) from e
output = self._tm.Report(stats.norm.ppf(0.5 + 0.5 * confidence), trim_rate)
epsilons = self._CalculateEpsilons(output.estimate)
temp = np.array(epsilons).argsort()
ranks = np.empty_like(temp)
Expand Down
18 changes: 3 additions & 15 deletions trimmed_match/estimator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,26 +126,14 @@ def testTrimmedMatchValueError(self):
_ = estimator.TrimmedMatch(self._delta_response, self._delta_cost, -0.1)
# if delta_response and delta_delta have different lengths
with self.assertRaisesRegex(
ValueError, "Lengths of delta_response and delta_spend differ."
):
ValueError, "Lengths of delta_response and delta_spend differ."):
_ = estimator.TrimmedMatch(self._delta_response, self._delta_cost + [1.0])
# if confidence is outside of (0, 1]
tm = estimator.TrimmedMatch(self._delta_response, self._delta_cost)
with self.assertRaisesRegex(
ValueError, r"Confidence is outside of \(0, 1\]"
):
with self.assertRaisesRegex(ValueError,
r"Confidence is outside of \(0, 1\]"):
_ = tm.Report(-0.5, 0.0)

def testTrimmedMatchCppError(self):
# catches errors from C++ code
with self.assertRaisesRegex(
ValueError,
"We could not find a root for the TM equation. One likely reason is"
" that the incremental cost for the untrimmed geo pairs is 0.",
):
tm = estimator.TrimmedMatch([1, 2, 3, 4], [-10, -1, 1, 10], 0.25)
tm.Report(confidence=0.8, trim_rate=0.1)

def testCalculateEpsilons(self):
"""Tests _CalculateEpsilons."""
tm = estimator.TrimmedMatch(self._delta_response, self._delta_cost, 0.25)
Expand Down
11 changes: 0 additions & 11 deletions trimmed_match/status_module.patch

This file was deleted.

0 comments on commit 0be8552

Please sign in to comment.