1 change: 0 additions & 1 deletion include/mirage/search/verification/output_match.h
@@ -2,7 +2,6 @@

#include <cstddef>
#include <vector>
#include <cstddef>

namespace mirage {
namespace search {
17 changes: 8 additions & 9 deletions include/mirage/triton_transpiler/transpile.h
@@ -16,10 +16,10 @@
#pragma once

#include <cstddef>
#include <string>
#include <vector>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "mirage/kernel/graph.h"
#include "mirage/threadblock/graph.h"
@@ -42,7 +42,7 @@ struct TritonTranspilerConfig {

// Result of transpiling a custom kernel operator
struct TritonCustomOPTranspileResult {
// Generated kernel function name
// Generated kernel function name
std::string func_name;
// Generated Triton kernel code
std::string code;
@@ -64,26 +64,25 @@ class TritonTranspiler {
std::vector<mirage::kernel::DTensor> mugraph_output_tensors;
std::unordered_map<decltype(tb::STensor::guid), STensorMeta>
stensor_metas; // STensor guid -> metadata

// Internal counter for kernel naming
static int kernel_idx_counter;

public:
TritonTranspiler(kernel::Graph const *_graph,
TritonTranspilerConfig const &_config);
TritonTranspilerConfig const &_config);
// Main entry point for code generation
TritonTranspileResult generate_code();
// Transpile a custom kernel operator
TritonCustomOPTranspileResult
TritonCustomOPTranspileResult
transpile_kn_custom_op(kn::KNCustomizedOp const *op);

// Transpile the kernel graph
TritonTranspileResult transpile_ugraph();

};

TritonTranspileResult transpile(kernel::Graph const *g,
TritonTranspilerConfig const &config);
TritonTranspilerConfig const &config);

} // namespace triton_transpiler
} // namespace mirage
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,4 +4,4 @@ z3-solver
# torch>=2.0
numpy
graphviz
tqdm
tqdm
2 changes: 1 addition & 1 deletion setup.py
@@ -87,7 +87,7 @@ def config_cython():

# Create the build directory if it does not exist
os.makedirs(build_dir, exist_ok=True)
subprocess.check_call(['cmake', '..',
subprocess.check_call(['cmake', '..',
'-DZ3_CXX_INCLUDE_DIRS=' + z3_path + '/include/',
'-DZ3_LIBRARIES=' + path.join(z3_path, 'lib', 'libz3.so'),
'-DCMAKE_C_COMPILER=' + os.environ['CC'],
15 changes: 15 additions & 0 deletions src/nki_transpiler/transpile.cc
@@ -335,6 +335,10 @@ std::optional<NKIErrorInfo> NKITranspiler::resolve_tensor_layout() {
if (stensor.dim[i] > 128) {
opt.add(!s_is_partition[stensor.guid][i]);
}
// The partition dimension must be one of the last two dims
if ((i != num_dims - 1) && (i != num_dims - 2)) {
opt.add(!s_is_partition[stensor.guid][i]);
}
}
opt.add(z3::atmost(partition_exprs, 1));
opt.add(z3::atleast(partition_exprs, 1));
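The added constraints restrict which dimension z3 may pick as an stensor's partition dimension: it must be one of the last two dims, hold at most 128 elements, and exactly one dimension is chosen. A minimal sketch of the same scheme in the Python z3 bindings (z3-solver is already in requirements.txt); the shape and variable names here are hypothetical, not the transpiler's own:

from z3 import AtLeast, AtMost, Bool, Not, Optimize

dims = [1, 64, 256]  # hypothetical stensor shape
num_dims = len(dims)
is_partition = [Bool(f"is_partition_{i}") for i in range(num_dims)]

opt = Optimize()
for i, d in enumerate(dims):
    # a partition dimension may hold at most 128 elements
    if d > 128:
        opt.add(Not(is_partition[i]))
    # the partition dimension must be one of the last two dims
    if i != num_dims - 1 and i != num_dims - 2:
        opt.add(Not(is_partition[i]))
# exactly one dimension is chosen as the partition dimension
opt.add(AtMost(*is_partition, 1))
opt.add(AtLeast(*is_partition, 1))
print(opt.check())  # sat: only dim 1 (size 64) is feasible here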
@@ -540,6 +544,16 @@ NKITranspileResult NKITranspiler::transpile_ugraph() {
exec.e("from torch_xla.core import xla_model as xm");
exec.e("device = xm.xla_device()");
for (kn::KNOperator *const op : g->operators) {
for (kn::DTensor const &dtensor : op->output_tensors) {
std::string shape;
for (int i = 0; i < dtensor.num_dims; i++) {
shape += fmt("$,", dtensor.dim[i]);
}
exec.e("$ = torch.randn(($), dtype=torch.float16).to(device=device)",
fmt("dtensor$", dtensor.guid),
shape);
}
#ifdef DEADCODE
if (op->op_type == type::KNOperatorType::KN_INPUT_OP) {
std::string shape;
kn::DTensor dtensor = op->output_tensors.at(0);
Expand All @@ -560,6 +574,7 @@ NKITranspileResult NKITranspiler::transpile_ugraph() {
fmt("dtensor$", dtensor.guid),
shape);
}
#endif
}
CodeKeeper custom_kernels;
for (kn::KNOperator *const op : g->operators) {
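The loop added above emits a torch.randn prologue for every operator output, replacing the KN_INPUT_OP-only path that is now fenced off under DEADCODE. A sketch of the generated Python, assuming two dtensors with made-up guids and shapes; the trailing comma produced by the shape loop is valid Python:

import torch
from torch_xla.core import xla_model as xm
device = xm.xla_device()
dtensor100 = torch.randn((64,64,), dtype=torch.float16).to(device=device)
dtensor101 = torch.randn((64,256,), dtype=torch.float16).to(device=device)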
183 changes: 111 additions & 72 deletions src/nki_transpiler/transpile_tb.cc
@@ -149,62 +149,83 @@ NKICustomOPTranspileResult
for (int i = 0; i < stensor.num_dims - 2; i++) {
assert(stensor.dim[i] == 1);
}
string instruction = "nl.load";
if (meta.partition_dim == stensor.num_dims - 1) {
instruction = "nl.load_transpose2d";
bool transposed = false;
std::string range = "";
if (meta.partition_dim == stensor.num_dims - 2) {
// Normal case
code.e("$ = nl.ndarray(($, $), dtype=nl.float16, buffer=nl.sbuf)",
fmt("stensor$", stensor.guid),
stensor.dim[stensor.num_dims - 2],
stensor.dim[stensor.num_dims - 1]);
} else {
// Transposed case
assert(meta.partition_dim == stensor.num_dims - 1);
// partition dim is the innermost dimension so we need
// to use nl.load_transpose2d
transposed = true;
code.e("$ = nl.ndarray(($, $), dtype=nl.float16, buffer=nl.sbuf)",
fmt("stensor$", stensor.guid),
stensor.dim[stensor.num_dims - 1],
stensor.dim[stensor.num_dims - 2]);
}

int3 imap = input_op->input_map;
int forloop_dim = input_op->forloop_dim;
std::string range = "";
for (int i = 0; i < stensor.num_dims; i++) {
std::string index;
if (imap.x == i) {
range += fmt("$*$:$*$",
"nl.program_id(0)",
stensor.dim[i],
"(nl.program_id(0)+1)",
stensor.dim[i]);
int scale_factor = stensor.dim[i];
if (forloop_dim == i) {
scale_factor *= g.forloop_range;
}
index = fmt("nl.program_id(0) * $", scale_factor);
} else if (imap.y == i) {
range += fmt("$*$:$*$",
"nl.program_id(1)",
stensor.dim[i],
"(nl.program_id(1)+1)",
stensor.dim[i]);
int scale_factor = stensor.dim[i];
if (forloop_dim == i) {
scale_factor *= g.forloop_range;
}
index = fmt("nl.program_id(1) * $", scale_factor);
} else if (imap.z == i) {
range += fmt("$*$:$*$",
"nl.program_id(2)",
stensor.dim[i],
"(nl.program_id(2)+1)",
stensor.dim[i]);
} else if (forloop_dim == i) {
range +=
fmt("$*$:$*$", "i", stensor.dim[i], "(i+1)", stensor.dim[i]);
} else {
range += fmt("$:$", 0, stensor.dim[i]);
int scale_factor = stensor.dim[i];
if (forloop_dim == i) {
scale_factor *= g.forloop_range;
}
index = fmt("nl.program_id(2) * $", scale_factor);
}
if (forloop_dim == i) {
if (index == "") {
index = fmt("i * $", stensor.dim[i]);
} else {
index = index + fmt("+ i * $", stensor.dim[i]);
}
}
if (i == stensor.num_dims - 2) {
if (index == "") {
index = fmt("nl.arange($)[:, None]", stensor.dim[i]);
} else {
index = index + fmt(" + nl.arange($)[:, None]", stensor.dim[i]);
}
} else if (i == stensor.num_dims - 1) {
if (index == "") {
index = fmt("nl.arange($)[None, :]", stensor.dim[i]);
} else {
index = index + fmt(" + nl.arange($)[None, :]", stensor.dim[i]);
}
}
if (index == "") {
index = "0";
}
range += index;
if (i < stensor.num_dims - 1) {
range += ",";
range += ", ";
}
}
std::string range_suffix = "";
if (stensor.num_dims == 2 || stensor.num_dims == 1) {
} else if (stensor.num_dims == 3) {
range_suffix = "[0]";
} else if (stensor.num_dims == 4) {
range_suffix = "[0,0]";
} else {
assert(false && "Currently unsupported dim size");
}

code.e("$ = $($[$])",
fmt("stensor$", stensor.guid),
instruction,
transposed ? "nl.load_transpose2d" : "nl.load",
fmt("dtensor$", dtensor.guid),
range);
if (range_suffix.length() > 0) {
code.e("$ = $$",
fmt("stensor$", stensor.guid),
fmt("stensor$", stensor.guid),
range_suffix);
}
break;
}
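To illustrate the rewritten index construction: the old code emitted slice ranges per dimension, while the new code builds one offset expression per dimension (grid offset, forloop offset, then an nl.arange over the last two dims). For a hypothetical 64x128 input stensor (guid 5) loaded from dtensor 2 with imap.x == 0 and forloop_dim == 1, the case above would emit the line below, where i is the enclosing forloop variable and nl is NKI's language module:

stensor5 = nl.load(dtensor2[nl.program_id(0) * 64 + nl.arange(64)[:, None], i * 128 + nl.arange(128)[None, :]])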
case type::TB_OUTPUT_OP: {
@@ -409,45 +430,37 @@ NKICustomOPTranspileResult
int3 omap = output_op->output_map;
std::string range = "";
for (int i = 0; i < stensor.num_dims; i++) {
std::string index;
if (omap.x == i) {
range += fmt("$*$:$*$",
"nl.program_id(0)",
stensor.dim[i],
"(nl.program_id(0)+1)",
stensor.dim[i]);
index = fmt("nl.program_id(0) * $", stensor.dim[i]);
} else if (omap.y == i) {
range += fmt("$*$:$*$",
"nl.program_id(1)",
stensor.dim[i],
"(nl.program_id(1)+1)",
stensor.dim[i]);
index = fmt("nl.program_id(1) * $", stensor.dim[i]);
} else if (omap.z == i) {
range += fmt("$*$:$*$",
"nl.program_id(2)",
stensor.dim[i],
"(nl.program_id(2)+1)",
stensor.dim[i]);
} else {
range += fmt("$:$", 0, stensor.dim[i]);
index = fmt("nl.program_id(2) * $", stensor.dim[i]);
}
if (i == stensor.num_dims - 2) {
if (index == "") {
index = fmt("nl.arange($)[:, None]", stensor.dim[i]);
} else {
index = index + fmt(" + nl.arange($)[:, None]", stensor.dim[i]);
}
} else if (i == stensor.num_dims - 1) {
if (index == "") {
index = fmt("nl.arange($)[None, :]", stensor.dim[i]);
} else {
index = index + fmt(" + nl.arange($)[None, :]", stensor.dim[i]);
}
}
if (index == "") {
index = "0";
}
range += index;
if (i < stensor.num_dims - 1) {
range += ",";
range += ", ";
}
}
// Generate code for TB Output
std::string range_prefix = "";
if (stensor.num_dims == 1 || stensor.num_dims == 2) {
// Do nothing
} else if (stensor.num_dims == 3) {
range_prefix = "0,";
} else if (stensor.num_dims == 4) {
range_prefix = "0,0,";
} else {
assert(false && "Currently unsupported dim size");
}
code.e("nl.store($[$$], $)",
code.e("nl.store($[$], $)",
fmt("dtensor$", dtensor.guid),
range_prefix,
range,
need_transpose ? fmt("nl.transpose(stensor$)", stensor.guid)
: fmt("stensor$", stensor.guid));
@@ -539,6 +552,32 @@ NKICustomOPTranspileResult
fmt("stensor$", input.guid));
break;
}
case type::TB_REDUCTION_0_TO_DIMX_OP:
case type::TB_REDUCTION_1_TO_DIMX_OP:
case type::TB_REDUCTION_2_TO_DIMX_OP: {
// May need to recompute
int reduc_dim = tb_op->op_type >= type::TB_REDUCTION_0_TO_DIMX_OP
? tb_op->op_type - type::TB_REDUCTION_0_TO_DIMX_OP
: tb_op->op_type - type::TB_REDUCTION_0_OP;
tb::STensor const &input = tb_op->input_tensors.at(0);
tb::STensor const &output = tb_op->output_tensors.at(0);
STensorMeta meta0 = stensor_metas.at(input.guid);
STensorMeta meta1 = stensor_metas.at(output.guid);
// FIXME: currently assume no change of partition dim in reduction
assert(meta0.partition_dim == meta1.partition_dim);
int num_dims = input.num_dims;
// assert that reduc_dim is among the last two dimensions since
// we omit all other leading dims (which must have a dim size of 1)
assert(num_dims - 2 <= reduc_dim && reduc_dim < num_dims);
// Cannot pick partition dim as the reduce_dim
assert(reduc_dim != meta0.partition_dim);
// reduction is performed on axis=1, since axis=0 maps to
// the partition dim
code.e("$ = nl.sum($, axis=1, keepdims=True)",
fmt("stensor$", output.guid),
fmt("stensor$", input.guid));
break;
}
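Because the partition dimension always maps to axis=0 of the on-chip tile and may not be the reduction dimension, the reduction is always emitted over the free axis. For hypothetical input/output guids 11 and 12, the generated line is simply:

stensor12 = nl.sum(stensor11, axis=1, keepdims=True)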
default: {
assert(false && fmt("Unsupported op_type:$", tb_op->op_type).c_str());
}
3 changes: 2 additions & 1 deletion src/search/search_c.cc
@@ -92,7 +92,8 @@ int cython_search(mirage::kernel::Graph const *input_graph,
config.frange_to_explore.push_back(frange);
}
}
const char *result_filename = filename ? filename : "mirage_search_checkpoint.json";
char const *result_filename =
filename ? filename : "mirage_search_checkpoint.json";
search::KernelGraphGenerator gen(
*input_graph, config, result_filename, verbose);
gen.config.show();