Skip to content

Commit

Permalink
[116] [VSQ] Adding polish and fixes.
Browse files Browse the repository at this point in the history
This cl adds a few things: 1) it now probably handles model updates and reads metadata from model server, 2) we ensure that iframe is loaded before sending postMessage to iFrame, 3) we ensure that companion page is loaded before we do any visual classification, 4) we limit number of results to 2.

(cherry picked from commit c5dd506)

Bug: 1449021, b:284645527
Change-Id: I3ef85ca185c5ce2ba59c446e0a39a356b0b16629
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4638949
Reviewed-by: Michael Crouse <mcrouse@chromium.org>
Commit-Queue: Pierre St Juste <pstjuste@google.com>
Cr-Original-Commit-Position: refs/heads/main@{#1162173}
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4660790
Reviewed-by: Ali Stanfield <stanfield@google.com>
Reviewed-by: Sophie Chang <sophiechang@chromium.org>
Cr-Commit-Position: refs/branch-heads/5845@{#267}
Cr-Branched-From: 5a5dff6-refs/heads/main@{#1160321}
  • Loading branch information
Pierre St Juste authored and Chromium LUCI CQ committed Jun 30, 2023
1 parent 5336205 commit 8d096ae
Show file tree
Hide file tree
Showing 20 changed files with 245 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "base/task/thread_pool.h"
#include "chrome/browser/companion/visual_search/features.h"
#include "chrome/browser/companion/visual_search/visual_search_suggestions_service.h"
#include "content/public/browser/render_frame_host.h"
#include "mojo/public/cpp/bindings/associated_remote.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "third_party/blink/public/common/associated_interfaces/associated_interface_provider.h"
Expand Down Expand Up @@ -78,20 +79,42 @@ void VisualSearchClassifierHost::HandleClassification(
std::move(result_callback_).Run(std::move(data_uris));
}
result_callback_.Reset();
result_handler_.reset();
}

void VisualSearchClassifierHost::StartClassification(
content::RenderFrameHost* render_frame_host,
const GURL& validated_url,
ResultCallback callback) {
base::File model = visual_search_service_->GetModelFile();
if (!render_frame_host) {
LOCAL_HISTOGRAM_BOOLEAN("Companion.VisualSearch.EmptyRenderFrame", true);
return;
}

current_url_ = validated_url;
visual_search_service_->SetModelUpdateCallback(
base::BindOnce(&VisualSearchClassifierHost::StartClassificationWithModel,
weak_ptr_factory_.GetWeakPtr(), render_frame_host,
validated_url, std::move(callback)));
}

void VisualSearchClassifierHost::StartClassificationWithModel(
content::RenderFrameHost* render_frame_host,
const GURL validated_url,
ResultCallback callback,
base::File model,
std::string base64_config) {
LOCAL_HISTOGRAM_BOOLEAN("Companion.VisualSearch.ModelFileSuccess",
model.IsValid());
if (!model.IsValid()) {
return;
}

std::string base64_config;
if (validated_url != current_url_) {
LOCAL_HISTOGRAM_BOOLEAN("Companion.VisualSearch.MismatchURL", true);
return;
}

absl::optional<std::string> config_switch =
switches::GetVisualSearchConfigForCompanionOverride();

Expand All @@ -106,7 +129,7 @@ void VisualSearchClassifierHost::StartClassification(
mojo::AssociatedRemote<mojom::VisualSuggestionsRequestHandler> visual_search;
render_frame_host->GetRemoteAssociatedInterfaces()->GetInterface(
&visual_search);
if (visual_search.is_bound()) {
if (visual_search.is_bound() && !result_handler_.is_bound()) {
visual_search->StartVisualClassification(
std::move(model), base64_config,
result_handler_.BindNewPipeAndPassRemote());
Expand All @@ -120,4 +143,11 @@ void VisualSearchClassifierHost::StartClassification(
LOCAL_HISTOGRAM_BOOLEAN("Companion.VisualSearch.StartClassificationSuccess",
visual_search.is_bound());
}

void VisualSearchClassifierHost::CancelClassification() {
result_callback_.Reset();
current_url_ = GURL::EmptyGURL();
LOCAL_HISTOGRAM_BOOLEAN("Companion.VisualSearch.ClassificationCancelled",
true);
}
} // namespace companion::visual_search
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define CHROME_BROWSER_COMPANION_VISUAL_SEARCH_VISUAL_SEARCH_CLASSIFIER_HOST_H_

#include <memory>
#include "base/memory/weak_ptr.h"
#include "chrome/browser/companion/visual_search/visual_search_suggestions_service.h"
#include "chrome/common/companion/visual_search.mojom.h"
#include "content/public/browser/render_frame_host.h"
Expand Down Expand Up @@ -44,7 +45,22 @@ class VisualSearchClassifierHost : mojom::VisualSuggestionsResultHandler {
const GURL& validated_url,
ResultCallback callback);

// Used to cancel and cleanup any ongoing classification; currently it
// mainly tracks the model fetching step.
void CancelClassification();

private:
// This method performs the actual mojom IPC to start classifier agent after
// we have obtained the model from |visual_search_service_|.
void StartClassificationWithModel(content::RenderFrameHost* render_frame_host,
const GURL validated_url,
ResultCallback callback,
base::File file,
std::string base64_config);

// Used to track the url that is currently being processed.
GURL current_url_;

// Pointer to visual search service which we do not own.
raw_ptr<VisualSearchSuggestionsService> visual_search_service_ = nullptr;

Expand All @@ -53,6 +69,9 @@ class VisualSearchClassifierHost : mojom::VisualSuggestionsResultHandler {

// This reference binds this class to the result handler for the mojom IPC.
mojo::Receiver<mojom::VisualSuggestionsResultHandler> result_handler_{this};

// Pointer factory necessary for scheduling tasks on different threads.
base::WeakPtrFactory<VisualSearchClassifierHost> weak_ptr_factory_{this};
};
} // namespace companion::visual_search

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,23 @@ namespace companion::visual_search {

namespace {

static const char kModelFilename[] = "visual_model.tflite";
constexpr char kValidUrl[] = "https://foo.com/";

base::FilePath model_file_path() {
base::FilePath source_root_dir;
base::PathService::Get(base::DIR_SOURCE_ROOT, &source_root_dir);
return source_root_dir.AppendASCII("chrome")
.AppendASCII("test")
.AppendASCII("data")
.AppendASCII("companion_visual_search")
.AppendASCII("test-model-quantized.tflite");
}

} // namespace

class VisualSearchClassifierHostTest : public ChromeRenderViewHostTestHarness {
public:
VisualSearchClassifierHostTest() : url_("www.style-files.com") {}
VisualSearchClassifierHostTest() : url_(kValidUrl) {}
~VisualSearchClassifierHostTest() override = default;

void SetUp() override {
Expand All @@ -64,24 +74,20 @@ class VisualSearchClassifierHostTest : public ChromeRenderViewHostTestHarness {
void SetModelPath() {
base::FilePath test_data_dir;
base::PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir);
test_data_dir = test_data_dir.AppendASCII("components/test/data");

base::flat_set<base::FilePath> additional_files;
additional_files.insert(test_data_dir.AppendASCII(kModelFilename));
additional_files.insert(model_file_path());

model_info_ =
optimization_guide::TestModelInfoBuilder()
.SetModelFilePath(test_data_dir.AppendASCII(kModelFilename))
.SetAdditionalFiles(additional_files)
.SetVersion(123)
.Build();
model_info_ = optimization_guide::TestModelInfoBuilder()
.SetModelFilePath(model_file_path())
.SetAdditionalFiles(additional_files)
.SetVersion(123)
.Build();

service_->OnModelUpdated(
optimization_guide::proto::OptimizationTarget::
OPTIMIZATION_TARGET_VISUAL_SEARCH_CLASSIFICATION,
*model_info_);

base::RunLoop().RunUntilIdle();
}

void TearDown() override {
Expand All @@ -107,6 +113,7 @@ TEST_F(VisualSearchClassifierHostTest, StartClassification) {
base::BindOnce([](std::vector<std::string> results) {});
visual_search_host_->StartClassification(
web_contents()->GetPrimaryMainFrame(), url_, std::move(callback));
base::RunLoop().RunUntilIdle();
histogram_tester_.ExpectBucketCount("Companion.VisualSearch.ModelFileSuccess",
true, 1);
histogram_tester_.ExpectBucketCount(
Expand All @@ -122,6 +129,7 @@ TEST_F(VisualSearchClassifierHostTest, StartClassification_WithOverride) {
base::BindOnce([](std::vector<std::string> results) {});
visual_search_host_->StartClassification(
web_contents()->GetPrimaryMainFrame(), url_, std::move(callback));
base::RunLoop().RunUntilIdle();
histogram_tester_.ExpectBucketCount("Companion.VisualSearch.ModelFileSuccess",
true, 1);
histogram_tester_.ExpectBucketCount(
Expand All @@ -133,19 +141,28 @@ TEST_F(VisualSearchClassifierHostTest, StartClassification_NoModelSet) {
base::BindOnce([](std::vector<std::string> results) {});
visual_search_host_->StartClassification(
web_contents()->GetPrimaryMainFrame(), url_, std::move(callback));
histogram_tester_.ExpectBucketCount("Companion.VisualSearch.ModelFileSuccess",
false, 1);
base::RunLoop().RunUntilIdle();

// ModelFileSuccess is never called because the |OnModelUpdate| is never
// called by the |service_| since we never setup the model path.
histogram_tester_.ExpectTotalCount("Companion.VisualSearch.ModelFileSuccess",
0);
}

TEST_F(VisualSearchClassifierHostTest,
StartClassification_NoModelSetAndNoCallbackSet) {
base::HistogramTester histogram_tester;
TEST_F(VisualSearchClassifierHostTest, StartClassification_WithCancellation) {
SetModelPath();
VisualSearchClassifierHost::ResultCallback callback =
base::BindOnce([](std::vector<std::string> results) {});
visual_search_host_->StartClassification(
web_contents()->GetPrimaryMainFrame(), url_, std::move(callback));
visual_search_host_->CancelClassification();
base::RunLoop().RunUntilIdle();
histogram_tester_.ExpectBucketCount(
"Companion.VisualSearch.ClassificationCancelled", true, 1);
histogram_tester_.ExpectBucketCount("Companion.VisualSearch.ModelFileSuccess",
false, 1);
true, 1);
histogram_tester_.ExpectBucketCount("Companion.VisualSearch.MismatchURL",
true, 1);
}

} // namespace companion::visual_search
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
#include "base/logging.h"
#include "base/task/sequenced_task_runner.h"
#include "components/optimization_guide/core/optimization_guide_model_provider.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "content/public/browser/browser_thread.h"

using content::BrowserThread;

namespace companion::visual_search {

namespace {
Expand All @@ -35,6 +34,17 @@ void CloseModelFile(base::File model_file) {
model_file.Close();
}

// Extracts the model string value from the model metadata.
// The model string is expected to be a serialized string of the
// |EligibilitySpec| proto.
std::string GetModelSpec(ModelMetadata& metadata) {
std::string model_spec;
if (metadata.has_value() && metadata->has_eligibility_spec()) {
metadata->eligibility_spec().SerializeToString(&model_spec);
}
return model_spec;
}

} // namespace

VisualSearchSuggestionsService::VisualSearchSuggestionsService(
Expand Down Expand Up @@ -81,6 +91,11 @@ void VisualSearchSuggestionsService::OnModelFileLoaded(base::File model_file) {
FROM_HERE, base::BindOnce(&CloseModelFile, std::move(*model_file_)));
}
model_file_ = std::move(model_file);
for (auto& callback : model_callbacks_) {
std::move(callback).Run(model_file_->Duplicate(),
GetModelSpec(model_metadata_));
}
model_callbacks_.clear();
}

void VisualSearchSuggestionsService::OnModelUpdated(
Expand All @@ -91,17 +106,29 @@ void VisualSearchSuggestionsService::OnModelUpdated(
OPTIMIZATION_TARGET_VISUAL_SEARCH_CLASSIFICATION) {
return;
}

const absl::optional<optimization_guide::proto::Any>& metadata =
model_info.GetModelMetadata();

if (metadata.has_value()) {
model_metadata_ = optimization_guide::ParsedAnyMetadata<
optimization_guide::proto::VisualSearchModelMetadata>(metadata.value());
}

background_task_runner_->PostTaskAndReplyWithResult(
FROM_HERE, base::BindOnce(&LoadModelFile, model_info.GetModelFilePath()),
base::BindOnce(&VisualSearchSuggestionsService::OnModelFileLoaded,
weak_ptr_factory_.GetWeakPtr()));
}

base::File VisualSearchSuggestionsService::GetModelFile() {
void VisualSearchSuggestionsService::SetModelUpdateCallback(
ModelUpdateCallback callback) {
if (model_file_) {
return model_file_->Duplicate();
std::move(callback).Run(model_file_->Duplicate(),
GetModelSpec(model_metadata_));
return;
}
return base::File();
model_callbacks_.emplace_back(std::move(callback));
}

} // namespace companion::visual_search
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,24 @@
#include "base/functional/callback_forward.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/optimization_guide/core/optimization_target_model_observer.h"
#include "components/optimization_guide/proto/common_types.pb.h"
#include "components/optimization_guide/proto/visual_search_model_metadata.pb.h"

namespace optimization_guide {
class OptimizationGuideModelProvider;
} // namespace optimization_guide

namespace companion::visual_search {

using ModelMetadata =
absl::optional<optimization_guide::proto::VisualSearchModelMetadata>;

class VisualSearchSuggestionsService
: public KeyedService,
public optimization_guide::OptimizationTargetModelObserver {
public:
using ModelUpdateCallback = base::OnceCallback<void(base::File, std::string)>;

VisualSearchSuggestionsService(
optimization_guide::OptimizationGuideModelProvider* model_provider,
const scoped_refptr<base::SequencedTaskRunner>& background_task_runner);
Expand All @@ -37,15 +44,28 @@ class VisualSearchSuggestionsService
optimization_guide::proto::OptimizationTarget optimization_target,
const optimization_guide::ModelInfo& model_info) override;

// Simple getter to access the model file.
base::File GetModelFile();
// Registers a callback used when model file is available or updated.
void SetModelUpdateCallback(ModelUpdateCallback callback);

private:
void OnModelFileLoaded(base::File model_file);

// Maintain list of callbacks for observers of model updates.
std::vector<ModelUpdateCallback> model_callbacks_;

// Represents the model that we send to the classifier agent.
absl::optional<base::File> model_file_;

// Used to store the model metadata returned from model provider.
ModelMetadata model_metadata_;

// Pointer to the model provider that we use to fetch classifier models.
raw_ptr<optimization_guide::OptimizationGuideModelProvider> model_provider_;

// Background task runner needed to perform I/O operations.
scoped_refptr<base::SequencedTaskRunner> background_task_runner_;

// Pointer factory necessary for scheduling tasks on different threads.
base::WeakPtrFactory<VisualSearchSuggestionsService> weak_ptr_factory_{this};
};

Expand Down

0 comments on commit 8d096ae

Please sign in to comment.