Skip to content

Commit

Permalink
Add a Session interface to OnDeviceModel
Browse files Browse the repository at this point in the history
This will allow adding context to an input before executing. The model
can only handle one session at a time, so sessions will be queued
until the previous one finished.

Bug: b/304353973
Change-Id: I4b3eb4921c94676028b1e4314394c551d9f70123
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4982139
Reviewed-by: Sophie Chang <sophiechang@chromium.org>
Reviewed-by: Nasko Oskov <nasko@chromium.org>
Reviewed-by: Ken Rockot <rockot@google.com>
Commit-Queue: Clark DuVall <cduvall@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1217201}
  • Loading branch information
clarkduvall authored and Chromium LUCI CQ committed Oct 30, 2023
1 parent 277b96a commit 5655160
Show file tree
Hide file tree
Showing 16 changed files with 287 additions and 38 deletions.
10 changes: 7 additions & 3 deletions chrome/browser/resources/on_device_internals/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.m

import {getTemplate} from './app.html.js';
import {BrowserProxy} from './browser_proxy.js';
import {OnDeviceModelRemote, PerformanceClass, StreamingResponderCallbackRouter} from './on_device_model.mojom-webui.js';
import {OnDeviceModelRemote, PerformanceClass, SessionRemote, StreamingResponderCallbackRouter} from './on_device_model.mojom-webui.js';

interface Response {
text: string;
Expand Down Expand Up @@ -97,6 +97,7 @@ class OnDeviceInternalsAppElement extends PolymerElement {
private model_: OnDeviceModelRemote|null;
private performanceClassText_: string;
private responses_: Response[];
private session_: SessionRemote|null = null;
private text_: string;

private proxy_: BrowserProxy = BrowserProxy.getInstance();
Expand Down Expand Up @@ -141,6 +142,8 @@ class OnDeviceInternalsAppElement extends PolymerElement {
this.error_ = result.error;
} else {
this.model_ = result.model || null;
this.session_ = new SessionRemote();
this.model_?.startSession(this.session_.$.bindNewPipeAndPassReceiver());
this.modelPath_ = modelPath;
}
}
Expand All @@ -150,11 +153,12 @@ class OnDeviceInternalsAppElement extends PolymerElement {
}

private onExecute_() {
if (this.model_ === null) {
if (this.session_ === null) {
return;
}
const router = new StreamingResponderCallbackRouter();
this.model_.execute(this.text_, router.$.bindNewPipeAndPassRemote());
this.session_.execute(
{text: this.text_}, router.$.bindNewPipeAndPassRemote());
const onResponseId = router.onResponse.addListener((text: string) => {
this.set(
'currentResponse_.response',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ void OnDeviceModelServiceController::Execute(
mojo::PendingRemote<on_device_model::mojom::StreamingResponder>
streaming_responder) {
if (model_remote_) {
model_remote_->Execute(std::string(input), std::move(streaming_responder));
model_remote_->StartSession(session_remote_.BindNewPipeAndPassReceiver());
session_remote_->Execute(on_device_model::mojom::InputOptions::New(
std::string(input), std::nullopt),
std::move(streaming_responder));
return;
}
LaunchService();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class OnDeviceModelServiceController {
base::FilePath model_path_;
mojo::Remote<on_device_model::mojom::OnDeviceModelService> service_remote_;
mojo::Remote<on_device_model::mojom::OnDeviceModel> model_remote_;
mojo::Remote<on_device_model::mojom::Session> session_remote_;

SEQUENCE_CHECKER(sequence_checker_);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,29 @@

namespace optimization_guide {

class FakeOnDeviceModel : public on_device_model::mojom::OnDeviceModel {
class FakeOnDeviceModel : public on_device_model::mojom::OnDeviceModel,
public on_device_model::mojom::Session {
// on_device_model::mojom::OnDeviceModel:
void Execute(const std::string& input,
void StartSession(
mojo::PendingReceiver<on_device_model::mojom::Session> session) override {
receivers_.Add(this, std::move(session));
}

// on_device_model::mojom::Session:
void AddContext(on_device_model::mojom::InputOptionsPtr input) override {}

void Execute(on_device_model::mojom::InputOptionsPtr input,
mojo::PendingRemote<on_device_model::mojom::StreamingResponder>
response) override {
mojo::Remote<on_device_model::mojom::StreamingResponder> remote(
std::move(response));
remote->OnResponse("Model starting\n");
remote->OnResponse("Input: " + input + "\n");
remote->OnResponse("Input: " + input->text + "\n");
remote->OnComplete();
}

private:
mojo::ReceiverSet<on_device_model::mojom::Session> receivers_;
};

class FakeOnDeviceModelService
Expand Down
1 change: 1 addition & 0 deletions services/on_device_model/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ template("model_service") {
public_deps = [
"//base",
"//mojo/public/cpp/bindings",
"//services/on_device_model/public/cpp",
"//services/on_device_model/public/mojom",
]
defines = [ "IS_ON_DEVICE_MODEL_IMPL" ]
Expand Down
42 changes: 33 additions & 9 deletions services/on_device_model/on_device_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,59 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/on_device_model/public/cpp/on_device_model.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/on_device_model/on_device_model_service.h"

namespace on_device_model {
namespace {

class OnDeviceModel : public mojom::OnDeviceModel {
class SessionImpl : public OnDeviceModel::Session {
public:
OnDeviceModel() = default;
~OnDeviceModel() override = default;
SessionImpl() = default;
~SessionImpl() override = default;

OnDeviceModel(const OnDeviceModel&) = delete;
OnDeviceModel& operator=(const OnDeviceModel&) = delete;
SessionImpl(const SessionImpl&) = delete;
SessionImpl& operator=(const SessionImpl&) = delete;

void AddContext(mojom::InputOptionsPtr input) override {
context_.push_back(input->text);
}

void Execute(
const std::string& input,
mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::StreamingResponder> response) override {
mojo::Remote<mojom::StreamingResponder> remote(std::move(response));
remote->OnResponse("Input: " + input + "\n");
for (const std::string& context : context_) {
remote->OnResponse("Context: " + context + "\n");
}
remote->OnResponse("Input: " + input->text + "\n");
remote->OnComplete();
}

private:
std::vector<std::string> context_;
};

class OnDeviceModelImpl : public OnDeviceModel {
public:
OnDeviceModelImpl() = default;
~OnDeviceModelImpl() override = default;

OnDeviceModelImpl(const OnDeviceModelImpl&) = delete;
OnDeviceModelImpl& operator=(const OnDeviceModelImpl&) = delete;

std::unique_ptr<Session> CreateSession() override {
return std::make_unique<SessionImpl>();
}
};

} // namespace

// static
std::unique_ptr<mojom::OnDeviceModel> OnDeviceModelService::CreateModel(
std::unique_ptr<OnDeviceModel> OnDeviceModelService::CreateModel(
ModelAssets assets) {
return std::make_unique<OnDeviceModel>();
return std::make_unique<OnDeviceModelImpl>();
}

// static
Expand Down
42 changes: 32 additions & 10 deletions services/on_device_model/on_device_model_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,48 @@
#include "services/on_device_model/chrome_ml_instance.h"
#include "services/on_device_model/on_device_model_service.h"
#include "services/on_device_model/public/cpp/model_assets.h"
#include "services/on_device_model/public/cpp/on_device_model.h"
#include "third_party/ml/public/on_device_model_executor.h"
#include "third_party/ml/public/utils.h"

namespace on_device_model {

namespace {

class OnDeviceModel : public mojom::OnDeviceModel {
// TODO(cduvall): Implement sessions in ml::OnDeviceModelExecutor.
class SessionImpl : public OnDeviceModel::Session {
public:
explicit OnDeviceModel(std::unique_ptr<ml::OnDeviceModelExecutor> executor)
: executor_(std::move(executor)) {}
~OnDeviceModel() override = default;
explicit SessionImpl(ml::OnDeviceModelExecutor* executor)
: executor_(executor) {}
~SessionImpl() override = default;

SessionImpl(const SessionImpl&) = delete;
SessionImpl& operator=(const SessionImpl&) = delete;

OnDeviceModel(const OnDeviceModel&) = delete;
OnDeviceModel& operator=(const OnDeviceModel&) = delete;
void AddContext(mojom::InputOptionsPtr input) override {}

void Execute(
const std::string& input,
mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::StreamingResponder> response) override {
executor_->Execute(input, std::move(response));
executor_->Execute(input->text, std::move(response));
}

private:
raw_ptr<ml::OnDeviceModelExecutor> executor_;
};

class OnDeviceModelImpl : public OnDeviceModel {
public:
explicit OnDeviceModelImpl(
std::unique_ptr<ml::OnDeviceModelExecutor> executor)
: executor_(std::move(executor)) {}
~OnDeviceModelImpl() override = default;

OnDeviceModelImpl(const OnDeviceModelImpl&) = delete;
OnDeviceModelImpl& operator=(const OnDeviceModelImpl&) = delete;

std::unique_ptr<Session> CreateSession() override {
return std::make_unique<SessionImpl>(executor_.get());
}

private:
Expand All @@ -37,7 +59,7 @@ class OnDeviceModel : public mojom::OnDeviceModel {
} // namespace

// static
std::unique_ptr<mojom::OnDeviceModel> OnDeviceModelService::CreateModel(
std::unique_ptr<OnDeviceModel> OnDeviceModelService::CreateModel(
ModelAssets assets) {
if (!GetChromeMLInstance()) {
return nullptr;
Expand All @@ -48,7 +70,7 @@ std::unique_ptr<mojom::OnDeviceModel> OnDeviceModelService::CreateModel(
if (!executor) {
return nullptr;
}
return std::make_unique<OnDeviceModel>(std::move(executor));
return std::make_unique<OnDeviceModelImpl>(std::move(executor));
}

// static
Expand Down
57 changes: 56 additions & 1 deletion services/on_device_model/on_device_model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,62 @@

#include "services/on_device_model/on_device_model_service.h"

#include "services/on_device_model/public/cpp/on_device_model.h"

namespace on_device_model {
namespace {

class SessionWrapper : public mojom::Session {
public:
SessionWrapper(mojo::PendingReceiver<mojom::Session> receiver,
std::unique_ptr<OnDeviceModel::Session> session)
: receiver_(this, std::move(receiver)), session_(std::move(session)) {}
~SessionWrapper() override = default;

SessionWrapper(const SessionWrapper&) = delete;
SessionWrapper& operator=(const SessionWrapper&) = delete;

void AddContext(mojom::InputOptionsPtr input) override {
session_->AddContext(std::move(input));
}

void Execute(
mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::StreamingResponder> response) override {
session_->Execute(std::move(input), std::move(response));
}

mojo::Receiver<mojom::Session>& receiver() { return receiver_; }

private:
mojo::Receiver<mojom::Session> receiver_;
std::unique_ptr<OnDeviceModel::Session> session_;
};

class ModelWrapper : public mojom::OnDeviceModel {
public:
explicit ModelWrapper(std::unique_ptr<on_device_model::OnDeviceModel> model)
: model_(std::move(model)) {}
~ModelWrapper() override = default;

ModelWrapper(const ModelWrapper&) = delete;
ModelWrapper& operator=(const ModelWrapper&) = delete;

void StartSession(mojo::PendingReceiver<mojom::Session> session) override {
current_session_ = std::make_unique<SessionWrapper>(
std::move(session), model_->CreateSession());
current_session_->receiver().set_disconnect_handler(base::BindOnce(
&ModelWrapper::SessionDisconnected, base::Unretained(this)));
}

private:
void SessionDisconnected() { current_session_.reset(); }

std::unique_ptr<SessionWrapper> current_session_;
std::unique_ptr<on_device_model::OnDeviceModel> model_;
};

} // namespace

OnDeviceModelService::OnDeviceModelService(
mojo::PendingReceiver<mojom::OnDeviceModelService> receiver)
Expand All @@ -22,7 +77,7 @@ void OnDeviceModelService::LoadModel(ModelAssets assets,
}

mojo::PendingRemote<mojom::OnDeviceModel> remote;
model_receivers_.Add(std::move(model),
model_receivers_.Add(std::make_unique<ModelWrapper>(std::move(model)),
remote.InitWithNewPipeAndPassReceiver());
std::move(callback).Run(mojom::LoadModelResult::NewModel(std::move(remote)));
}
Expand Down
3 changes: 2 additions & 1 deletion services/on_device_model/on_device_model_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "base/component_export.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/unique_receiver_set.h"
#include "services/on_device_model/public/cpp/on_device_model.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"

namespace on_device_model {
Expand Down Expand Up @@ -35,7 +36,7 @@ class COMPONENT_EXPORT(ON_DEVICE_MODEL) OnDeviceModelService
GetEstimatedPerformanceClassCallback callback) override;

private:
static std::unique_ptr<mojom::OnDeviceModel> CreateModel(ModelAssets assets);
static std::unique_ptr<OnDeviceModel> CreateModel(ModelAssets assets);

mojo::Receiver<mojom::OnDeviceModelService> receiver_;
mojo::UniqueReceiverSet<mojom::OnDeviceModel> model_receivers_;
Expand Down

0 comments on commit 5655160

Please sign in to comment.