Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 612923307
  • Loading branch information
MediaPipe Team authored and Copybara-Service committed Mar 5, 2024
1 parent 35c70be commit e58d8c5
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ absl::StatusOr<std::unique_ptr<XnnGraph>> XnnGraphBuilder::Build() {
RET_CHECK_EQ(xnn_status_success,
xnn_create_subgraph(
/*external_value_ids=*/input_tensors_added_order_.size() +
output_tensors.size() + rope_weights_.size(),
output_tensors.size(),
/*flags=*/0, &subgraph_ptr));
RET_CHECK_NE(subgraph_ptr, nullptr);

Expand All @@ -166,10 +166,6 @@ absl::StatusOr<std::unique_ptr<XnnGraph>> XnnGraphBuilder::Build() {
RET_CHECK_EQ(t->tensor_id(subgraph_ptr), XNN_INVALID_VALUE_ID);
t->set_tensor_id(subgraph_ptr, cnt++);
}
for (auto& t : rope_weights_) {
interm_tensors_.erase(t);
t->set_tensor_id(subgraph_ptr, cnt++);
}
}

XnnSubgraphPtr subgraph{subgraph_ptr, xnn_delete_subgraph};
Expand All @@ -183,9 +179,6 @@ absl::StatusOr<std::unique_ptr<XnnGraph>> XnnGraphBuilder::Build() {
for (auto& output : output_tensors) {
MP_RETURN_IF_ERROR(output->DefineAsOutput(*subgraph));
}
for (auto& t : rope_weights_) {
MP_RETURN_IF_ERROR(t->DefineRope(*subgraph));
}

for (auto& step : build_steps_) {
if (auto s = step(subgraph.get()); !s.ok()) {
Expand All @@ -194,7 +187,6 @@ absl::StatusOr<std::unique_ptr<XnnGraph>> XnnGraphBuilder::Build() {
}

build_steps_.clear();
rope_weights_.clear();
XnnGraph result(std::move(subgraph),
std::make_unique<RuntimeConfigs>(*runtime_configs_));
result.input_tensors_ = std::move(input_tensors_added_order_);
Expand Down Expand Up @@ -765,8 +757,6 @@ absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::PerDimScale(

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Rope(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> segment_pos) {
rope_weights_.insert(segment_pos);

const auto& input_dim = input->dims;
const auto& segment_pos_dim = segment_pos->dims;
// B T N H
Expand Down Expand Up @@ -998,7 +988,6 @@ absl::Status XnnGraph::SetupRuntime() {
{
VLOG(3) << "input size " << input_tensors_.size();
VLOG(3) << "output size " << output_tensors_.size();
VLOG(3) << "rope size " << rope_weights_.size();
externals_.clear();
// Init external
for (const auto& input : input_tensors_) {
Expand All @@ -1011,9 +1000,6 @@ absl::Status XnnGraph::SetupRuntime() {
externals_.push_back(xnn_external_value{
output->tensor_id(owned_subgraph_.get()), output->Data()});
}
for (const auto& t : rope_weights_) {
VLOG(3) << "rope id " << t->tensor_id(owned_subgraph_.get());
}
}
RET_CHECK_EQ(
xnn_status_success,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,6 @@ class XnnGraphBuilder {
absl::flat_hash_set<std::shared_ptr<Tensor>> interm_tensors_;
absl::flat_hash_set<std::shared_ptr<Tensor>> static_weights_;

// This is sort of a bug: the weights used for rope have to be defined with
// the EXTERNAL flag, but with an id outside the external range.
absl::flat_hash_set<std::shared_ptr<Tensor>> rope_weights_;

// Caches
absl::flat_hash_map<
size_t /*dim*/,
Expand Down Expand Up @@ -352,8 +348,6 @@ class XnnGraph {
std::vector<std::shared_ptr<Tensor>> input_tensors_;
std::vector<std::shared_ptr<Tensor>> output_tensors_;

absl::flat_hash_set<std::shared_ptr<Tensor>> rope_weights_;

absl::flat_hash_set<std::shared_ptr<Tensor>> static_weights_;
};

Expand Down
12 changes: 7 additions & 5 deletions mediapipe/tasks/cc/genai/inference/utils/xnn_utils/llm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -448,12 +448,14 @@ LlmBuilder::PreProcess(std::shared_ptr<Tensor> token_embedding,
InputResource resource;
constexpr absl::string_view kAttnMaskSource = "atten_mask";
constexpr absl::string_view kPosEmbeddingSource = "pos_embedding";
constexpr absl::string_view kSegmentPosSource = "segment_pos";
if (is_prefix) {
MP_ASSIGN_OR_RETURN(resource.atten_mask, NewInput({llm_params_.seq_size_T,
llm_params_.seq_size_T},
kAttnMaskSource));
resource.segment_pos = std::make_shared<Tensor>(
Tensor::DimsType({llm_params_.seq_size_T, llm_params_.head_dim_H}));
MP_ASSIGN_OR_RETURN(resource.segment_pos, NewInput({llm_params_.seq_size_T,
llm_params_.head_dim_H},
kSegmentPosSource));
MP_RETURN_IF_ERROR(
InitSegmentPos(0, llm_params_.seq_size_T, *resource.segment_pos));
MP_ASSIGN_OR_RETURN(
Expand All @@ -466,8 +468,9 @@ LlmBuilder::PreProcess(std::shared_ptr<Tensor> token_embedding,
NewInput({1, llm_params_.model_dim_D}, kPosEmbeddingSource));
MP_ASSIGN_OR_RETURN(resource.atten_mask,
NewInput({1, llm_params_.seq_size_T}, kAttnMaskSource));
resource.segment_pos =
std::make_shared<Tensor>(Tensor::DimsType{1, llm_params_.head_dim_H});
MP_ASSIGN_OR_RETURN(
resource.segment_pos,
NewInput({1, llm_params_.head_dim_H}, kSegmentPosSource));
MP_RETURN_IF_ERROR(InitSegmentPos(0, 1, *resource.segment_pos));
}
const float dim_scale = std::sqrt(llm_params_.model_dim_D);
Expand Down Expand Up @@ -739,7 +742,6 @@ absl::Status LlmBuilder::InitSegmentPos(size_t current_seq_len,
if (!segment_pos_values_) {
MP_RETURN_IF_ERROR(InitSegmentPosValues(rope_size));
}
MP_RETURN_IF_ERROR(out_segment_pos.LoadFromVec({}, /*exact_match=*/false));

out_segment_pos.Resize(Tensor::DimsType{process_seq_len, rope_size});
MP_RETURN_IF_ERROR(out_segment_pos.LoadFromBuffer(
Expand Down
7 changes: 3 additions & 4 deletions mediapipe/tasks/cc/genai/inference/utils/xnn_utils/phi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@ Phi2Builder::PreProcess(std::shared_ptr<Tensor> token_embedding,
MP_ASSIGN_OR_RETURN(
resource.atten_mask,
NewInput({llm_params_.seq_size_T, llm_params_.seq_size_T}));
resource.segment_pos = std::make_shared<Tensor>(
Tensor::DimsType({llm_params_.seq_size_T, rope_size}));
MP_ASSIGN_OR_RETURN(resource.segment_pos,
NewInput({llm_params_.seq_size_T, rope_size}));
MP_RETURN_IF_ERROR(
InitSegmentPos(0, llm_params_.seq_size_T, *resource.segment_pos));
} else {
MP_ASSIGN_OR_RETURN(resource.atten_mask,
NewInput({1, llm_params_.seq_size_T}));
resource.segment_pos =
std::make_shared<Tensor>(Tensor::DimsType{1, rope_size});
MP_ASSIGN_OR_RETURN(resource.segment_pos, NewInput({1, rope_size}));
MP_RETURN_IF_ERROR(InitSegmentPos(0, 1, *resource.segment_pos));
}
return std::make_pair(token_embedding, resource);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@ Stablelm4E1T3BBuilder::PreProcess(std::shared_ptr<Tensor> token_embedding,
MP_ASSIGN_OR_RETURN(
resource.atten_mask,
NewInput({llm_params_.seq_size_T, llm_params_.seq_size_T}));
resource.segment_pos = std::make_shared<Tensor>(
Tensor::DimsType({llm_params_.seq_size_T, rope_size}));
MP_ASSIGN_OR_RETURN(resource.segment_pos,
NewInput({llm_params_.seq_size_T, rope_size}));
MP_RETURN_IF_ERROR(
InitSegmentPos(0, llm_params_.seq_size_T, *resource.segment_pos));
} else {
MP_ASSIGN_OR_RETURN(resource.atten_mask,
NewInput({1, llm_params_.seq_size_T}));
resource.segment_pos =
std::make_shared<Tensor>(Tensor::DimsType{1, rope_size});
MP_ASSIGN_OR_RETURN(resource.segment_pos, NewInput({1, rope_size}));
MP_RETURN_IF_ERROR(InitSegmentPos(0, 1, *resource.segment_pos));
}
return std::make_pair(token_embedding, resource);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,6 @@ void Tensor::set_tensor_id(xnn_subgraph_t subgraph, uint32_t id) {
map_subgraph_to_tensor_id[subgraph] = id;
}

absl::Status Tensor::DefineRope(xnn_subgraph& subgraph) {
RET_CHECK_NE(tensor_id(&subgraph), XNN_INVALID_VALUE_ID);
return DefineWeight(subgraph, XNN_VALUE_FLAG_EXTERNAL_INPUT);
}

absl::Status Tensor::LoadFromBuffer(const void* buffer) {
AllocateBufferIfNeeded();
memcpy(Data(), buffer, ElementSize(num_elements));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ struct Tensor {
absl::Status DefineAsIntermediateTensor(xnn_subgraph& subgraph);
virtual absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags);
absl::Status DefineWeight(xnn_subgraph& subgraph);
absl::Status DefineRope(xnn_subgraph& subgraph);

// Load the tensor from buffer, assuming the buffer is long enough.
absl::Status LoadFromBuffer(const void* buffer);
Expand Down

0 comments on commit e58d8c5

Please sign in to comment.