Skip to content

Commit

Permalink
[misc] Add SNode to offline-cache key (taichi-dev#4716)
Browse files Browse the repository at this point in the history
* Add SNode to offline-cache key

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix solution of hashing snode

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and k-ye committed May 5, 2022
1 parent 8ff50ad commit d299122
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 24 deletions.
2 changes: 1 addition & 1 deletion taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2389,7 +2389,7 @@ FunctionType CodeGenLLVM::gen() {
std::string kernel_key;
if (config.offline_cache && this->supports_offline_cache() &&
!kernel->is_evaluator) {
kernel_key = get_offline_cache_key(&kernel->program->config, kernel);
kernel_key = get_hashed_offline_cache_key(&kernel->program->config, kernel);

LlvmOfflineCacheFileReader reader(config.offline_cache_file_path);
LlvmOfflineCache::KernelCacheData cache_data;
Expand Down
91 changes: 85 additions & 6 deletions taichi/ir/expression_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
#include "taichi/ir/expr.h"
#include "taichi/ir/expression.h"
#include "taichi/ir/frontend_ir.h"
#include "taichi/program/program.h"
#include "taichi/llvm/llvm_offline_cache.h"

namespace taichi {
namespace lang {

class ExpressionHumanFriendlyPrinter : public ExpressionVisitor {
class ExpressionPrinter : public ExpressionVisitor {
public:
ExpressionHumanFriendlyPrinter(std::ostream *os = nullptr) : os_(os) {
ExpressionPrinter(std::ostream *os = nullptr) : os_(os) {
}

void set_ostream(std::ostream *os) {
Expand All @@ -21,6 +23,16 @@ class ExpressionHumanFriendlyPrinter : public ExpressionVisitor {
return *os_;
}

private:
std::ostream *os_{nullptr};
};

class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
public:
explicit ExpressionHumanFriendlyPrinter(std::ostream *os = nullptr)
: ExpressionPrinter(os) {
}

void visit(ExprGroup &expr_group) override {
emit_vector(expr_group.exprs);
}
Expand Down Expand Up @@ -212,11 +224,10 @@ class ExpressionHumanFriendlyPrinter : public ExpressionVisitor {
return oss.str();
}

private:
protected:
template <typename... Args>
void emit(Args &&... args) {
TI_ASSERT(os_);
(*os_ << ... << std::forward<Args>(args));
(this->get_ostream() << ... << std::forward<Args>(args));
}

template <typename T>
Expand All @@ -243,8 +254,76 @@ class ExpressionHumanFriendlyPrinter : public ExpressionVisitor {
emit(std::forward<D>(e));
}
}
};

std::ostream *os_{nullptr};
// Temporary reuse ExpressionHumanFriendlyPrinter
class ExpressionOfflineCacheKeyGenerator
: public ExpressionHumanFriendlyPrinter {
public:
explicit ExpressionOfflineCacheKeyGenerator(Program *prog,
std::ostream *os = nullptr)
: ExpressionHumanFriendlyPrinter(os), prog_(prog) {
}

void visit(GlobalVariableExpression *expr) override {
emit("#", expr->ident.name());
if (expr->snode) {
emit("(snode=", this->get_hashed_key_of_snode(expr->snode), ')');
} else {
emit("(dt=", expr->dt->to_string(), ')');
}
}

void visit(GlobalPtrExpression *expr) override {
if (expr->snode) {
emit(this->get_hashed_key_of_snode(expr->snode));
} else {
expr->var->accept(this);
}
emit('[');
emit_vector(expr->indices.exprs);
emit(']');
}

void visit(SNodeOpExpression *expr) override {
emit(snode_op_type_name(expr->op_type));
emit('(', this->get_hashed_key_of_snode(expr->snode), ", [");
emit_vector(expr->indices.exprs);
emit(']');
if (expr->value.expr) {
emit(' ');
expr->value->accept(this);
}
emit(')');
}

private:
const std::string &cache_snode_tree_key(int snode_tree_id,
std::string &&key) {
if (snode_tree_id >= snode_tree_key_cache_.size()) {
snode_tree_key_cache_.resize(snode_tree_id + 1);
}
return snode_tree_key_cache_[snode_tree_id] = std::move(key);
}

std::string get_hashed_key_of_snode(SNode *snode) {
TI_ASSERT(snode && prog_);
auto snode_tree_id = snode->get_snode_tree_id();
std::string res;
if (snode_tree_id < snode_tree_key_cache_.size() &&
!snode_tree_key_cache_[snode_tree_id].empty()) {
res = snode_tree_key_cache_[snode_tree_id];
} else {
auto *snode_tree_root = prog_->get_snode_root(snode_tree_id);
auto snode_tree_key =
get_hashed_offline_cache_key_of_snode(snode_tree_root);
res = cache_snode_tree_key(snode_tree_id, std::move(snode_tree_key));
}
return res.append(std::to_string(snode->id));
}

Program *prog_{nullptr};
std::vector<std::string> snode_tree_key_cache_;
};

} // namespace lang
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ void full_simplify(IRNode *root,
const CompileConfig &config,
const FullSimplifyPass::Args &args);
void print(IRNode *root, std::string *output = nullptr);
void gen_offline_cache_key(Program *program, IRNode *root, std::string *output);
void frontend_type_check(IRNode *root);
void lower_ast(IRNode *root);
void type_check(IRNode *root, const CompileConfig &config);
Expand Down
86 changes: 83 additions & 3 deletions taichi/llvm/llvm_offline_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,91 @@ get_offline_cache_key_of_compile_config(CompileConfig *config) {
return serializer.data;
}

std::string get_offline_cache_key(CompileConfig *config, Kernel *kernel) {
static TI_FORCE_INLINE void get_offline_cache_key_of_snode_impl(
SNode *snode,
BinaryOutputSerializer &serializer) {
for (auto &c : snode->ch) {
get_offline_cache_key_of_snode_impl(c.get(), serializer);
}
for (int i = 0; i < taichi_max_num_indices; ++i) {
auto &extractor = snode->extractors[i];
serializer(extractor.num_elements_from_root);
serializer(extractor.shape);
serializer(extractor.acc_shape);
serializer(extractor.num_bits);
serializer(extractor.acc_offset);
serializer(extractor.active);
}
serializer(snode->index_offsets);
serializer(snode->num_active_indices);
serializer(snode->physical_index_position);
serializer(snode->id);
serializer(snode->depth);
serializer(snode->name);
serializer(snode->num_cells_per_container);
serializer(snode->total_num_bits);
serializer(snode->total_bit_start);
serializer(snode->chunk_size);
serializer(snode->cell_size_bytes);
serializer(snode->offset_bytes_in_parent_cell);
if (snode->physical_type) {
serializer(snode->physical_type->to_string());
}
serializer(snode->dt->to_string());
serializer(snode->has_ambient);
if (!snode->ambient_val.dt->is_primitive(PrimitiveTypeID::unknown)) {
serializer(snode->ambient_val.stringify());
}
if (snode->grad_info && !snode->grad_info->is_primal()) {
if (auto *grad_snode = snode->grad_info->grad_snode()) {
get_offline_cache_key_of_snode_impl(grad_snode, serializer);
}
}
if (snode->exp_snode) {
get_offline_cache_key_of_snode_impl(snode->exp_snode, serializer);
}
serializer(snode->bit_offset);
serializer(snode->placing_shared_exp);
serializer(snode->owns_shared_exponent);
for (auto s : snode->exponent_users) {
get_offline_cache_key_of_snode_impl(s, serializer);
}
if (snode->currently_placing_exp_snode) {
get_offline_cache_key_of_snode_impl(snode->currently_placing_exp_snode,
serializer);
}
if (snode->currently_placing_exp_snode_dtype) {
serializer(snode->currently_placing_exp_snode_dtype->to_string());
}
serializer(snode->is_bit_level);
serializer(snode->is_path_all_dense);
serializer(snode->node_type_name);
serializer(snode->type);
serializer(snode->_morton);
serializer(snode->get_snode_tree_id());
}

std::string get_hashed_offline_cache_key_of_snode(SNode *snode) {
TI_ASSERT(snode);

BinaryOutputSerializer serializer;
serializer.initialize();
get_offline_cache_key_of_snode_impl(snode, serializer);
serializer.finalize();

picosha2::hash256_one_by_one hasher;
hasher.process(serializer.data.begin(), serializer.data.end());
hasher.finish();

return picosha2::get_hash_hex_string(hasher);
}

std::string get_hashed_offline_cache_key(CompileConfig *config,
Kernel *kernel) {
std::string kernel_ast_string;
if (kernel) {
irpass::re_id(kernel->ir.get());
irpass::print(kernel->ir.get(), &kernel_ast_string);
irpass::gen_offline_cache_key(kernel->program, kernel->ir.get(),
&kernel_ast_string);
}

std::vector<std::uint8_t> compile_config_key;
Expand Down
6 changes: 4 additions & 2 deletions taichi/llvm/llvm_offline_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

#include "taichi/common/core.h"
#include "taichi/program/kernel.h"
#include "taichi/llvm/llvm_fwd.h"
#include "taichi/util/io.h"

#include "llvm/IR/Module.h"

namespace taichi {
namespace lang {

std::string get_offline_cache_key(CompileConfig *config, Kernel *kernel);
std::string get_hashed_offline_cache_key_of_snode(SNode *snode);
std::string get_hashed_offline_cache_key(CompileConfig *config, Kernel *kernel);

struct LlvmOfflineCache {
struct OffloadedTaskCacheData {
Expand Down
36 changes: 24 additions & 12 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,17 @@ std::string to_string(const LaneAttribute<LocalAddress> &ptr) {

class IRPrinter : public IRVisitor {
private:
ExpressionHumanFriendlyPrinter expr_printer_;
ExpressionPrinter *expr_printer_{nullptr};

public:
int current_indent;
int current_indent{0};

std::string *output;
std::string *output{nullptr};
std::stringstream ss;

IRPrinter(std::string *output = nullptr) : output(output) {
current_indent = 0;
IRPrinter(ExpressionPrinter *expr_printer = nullptr,
std::string *output = nullptr)
: expr_printer_(expr_printer), output(output) {
}

template <typename... Args>
Expand All @@ -75,15 +76,17 @@ class IRPrinter : public IRVisitor {
}
}

static void run(IRNode *node, std::string *output) {
static void run(ExpressionPrinter *expr_printer,
IRNode *node,
std::string *output) {
if (node == nullptr) {
TI_WARN("IRPrinter: Printing nullptr.");
if (output) {
*output = std::string();
}
return;
}
auto p = IRPrinter(output);
auto p = IRPrinter(expr_printer, output);
p.print("kernel {{");
node->accept(&p);
p.print("}}");
Expand Down Expand Up @@ -777,16 +780,18 @@ class IRPrinter : public IRVisitor {
}

std::string expr_to_string(Expression *expr) {
TI_ASSERT(expr_printer_);
std::ostringstream oss;
expr_printer_.set_ostream(&oss);
expr->accept(&expr_printer_);
expr_printer_->set_ostream(&oss);
expr->accept(expr_printer_);
return oss.str();
}

std::string expr_group_to_string(ExprGroup &expr_group) {
TI_ASSERT(expr_printer_);
std::ostringstream oss;
expr_printer_.set_ostream(&oss);
expr_printer_.visit(expr_group);
expr_printer_->set_ostream(&oss);
expr_printer_->visit(expr_group);
return oss.str();
}
};
Expand All @@ -796,7 +801,14 @@ class IRPrinter : public IRVisitor {
namespace irpass {

void print(IRNode *root, std::string *output) {
return IRPrinter::run(root, output);
ExpressionHumanFriendlyPrinter expr_printer;
return IRPrinter::run(&expr_printer, root, output);
}

void gen_offline_cache_key(Program *prog, IRNode *root, std::string *output) {
irpass::re_id(root);
ExpressionOfflineCacheKeyGenerator cache_key_generator(prog);
return IRPrinter::run(&cache_key_generator, root, output);
}

} // namespace irpass
Expand Down

0 comments on commit d299122

Please sign in to comment.