[TOPI] Update softmax compute and CPU schedule #3680

Merged · 7 commits · Aug 5, 2019

Changes from all commits
topi/include/topi/cuda/softmax.h (21 additions, 2 deletions)
@@ -50,13 +50,32 @@ inline Schedule schedule_softmax(const Target &target, const Array<Tensor>& outs
   auto s = create_schedule(out_ops);
 
   auto softmax = outs[0];
-  auto max_elem = softmax->op->InputTensors()[1];
-  auto expsum = softmax->op->InputTensors()[2];
+  tvm::Tensor max_elem;
+  tvm::Tensor expsum;
+  tvm::Tensor exp;
+  bool has_exp = false;
+
+  auto tag = softmax->op.as<ComputeOpNode>()->tag;
+  if (tag == "softmax_output") {
+    expsum = softmax->op->InputTensors()[1];
+    exp = softmax->op->InputTensors()[0];
+    max_elem = s[exp]->op->InputTensors()[1];
+    has_exp = true;
+  } else if (tag == "log_softmax_output") {
+    max_elem = softmax->op->InputTensors()[1];
+    expsum = softmax->op->InputTensors()[2];
+  } else {
+    LOG(ERROR) << "Tag is expected to be softmax_output or log_softmax_output. Got " << tag;
+  }
 
   int num_thread = 64;
   auto block_x = tvm::thread_axis(Range(), "blockIdx.x");
   auto thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
 
+  if (has_exp) {
+    s[exp].bind(exp->op.as<ComputeOpNode>()->axis[0], block_x);
+  }
+
   s[max_elem].bind(max_elem->op.as<ComputeOpNode>()->axis[0], block_x);
 
   auto k = expsum->op.as<ComputeOpNode>()->reduce_axis[0];
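
The tag dispatch above is needed because the staged compute changes the operator's input DAG: a softmax output now reads from (exp, expsum) and max_elem is only reachable through exp, while log_softmax keeps the old (max_elem, expsum) layout. A quick way to confirm the layout from Python (a sketch against the pre-0.7 tvm/topi API used by this PR; the shape is an illustrative assumption):

    import tvm
    import topi

    x = tvm.placeholder((64, 1000), name="x")
    y = topi.nn.softmax(x)

    print(y.op.tag)                              # softmax_output
    print([t.name for t in y.op.input_tensors])  # ['T_softmax_exp', 'T_softmax_expsum']
    # max_elem sits behind exp, which is why the schedule above takes
    # s[exp]->op->InputTensors()[1] in the softmax_output branch.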
topi/include/topi/nn/softmax.h (29 additions, 13 deletions)
@@ -62,6 +62,9 @@ inline Tensor softmax(const Tensor &x,
   auto k2 = tvm::reduce_axis(Range(0, input_shape[axis]), "k2");
   auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false);
 
+  tvm::Map<std::string, NodeRef> attrs;
+  attrs.Set("axis", Integer(axis));
+
   auto insert_reduce_index = [axis, ndim](const Array<Var> &indices,
                                           const IterVar &reduce_index) {
     Array<Expr> eval_range;
@@ -75,35 +78,48 @@
     return eval_range;
   };
 
+  auto get_non_reduce_indices = [axis, ndim](const Array<Var> &indices) {
+    Array<Expr> non_reduce_indices;
+    for (size_t i = 0; i < ndim; ++i) {
+      if (static_cast<int>(i) != axis)
+        non_reduce_indices.push_back(indices[i]);
+    }
+    return non_reduce_indices;
+  };
+
   auto _compute_max = [&](const Array<Var> &indices) {
     auto eval_range = insert_reduce_index(indices, k1);
     return topi::MaxOp(x(eval_range), {k1});
   };
 
-  auto _compute_expsum = [&](const Tensor &max_elem,
+  auto _compute_exp = [&](const Tensor &max_elem,
                              const Array<Var> &indices) {
+    auto non_reduce_indices = get_non_reduce_indices(indices);
+    return tvm::exp(x(indices) - max_elem(non_reduce_indices));
+  };
+
+  auto _compute_expsum = [&](const Tensor &exp,
+                             const Array<Var> &indices) {
     auto eval_range = insert_reduce_index(indices, k2);
-    return tvm::sum(tvm::exp(x(eval_range) - max_elem(indices)), {k2});
+    return tvm::sum(exp(eval_range), {k2});
   };
 
-  auto _normalize = [&](const Tensor &max_elem, const Tensor &expsum,
+  auto _normalize = [&](const Tensor &exp, const Tensor &expsum,
                         const Array<Var> &indices) {
-    Array<Expr> non_reduce_indices;
-    for (size_t i = 0; i < ndim; ++i) {
-      if (static_cast<int>(i) != axis)
-        non_reduce_indices.push_back(indices[i]);
-    }
-    return tvm::exp(x(indices) - max_elem(non_reduce_indices)) /
-        expsum(non_reduce_indices);
+    auto non_reduce_indices = get_non_reduce_indices(indices);
+    return exp(indices) / expsum(non_reduce_indices);
   };
 
   auto max_elem = tvm::compute(reduced_shape, _compute_max);
+  auto exp = tvm::compute(input_shape, [&](const Array<Var> &indices) {
+    return _compute_exp(max_elem, indices);
+  });
   auto expsum = tvm::compute(reduced_shape, [&](const Array<Var> &indices) {
-    return _compute_expsum(max_elem, indices);
+    return _compute_expsum(exp, indices);
   });
   return tvm::compute(input_shape, [&](const Array<Var> &indices) {
-    return _normalize(max_elem, expsum, indices);
-  }, name, tag);
+    return _normalize(exp, expsum, indices);
+  }, name, tag, attrs);
 }
 
 /*!
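
For reference, the staged compute is the standard numerically stable softmax, split into explicit stages so the exp tensor is materialized once and shared by the reduction and the normalization (previously exp was evaluated twice, once inside expsum and once in the final normalize). A NumPy restatement (illustrative only; keepdims=True stands in for the get_non_reduce_indices reindexing):

    import numpy as np

    def staged_softmax(x, axis=-1):
        max_elem = np.max(x, axis=axis, keepdims=True)  # _compute_max
        exp = np.exp(x - max_elem)                      # _compute_exp
        expsum = np.sum(exp, axis=axis, keepdims=True)  # _compute_expsum
        return exp / expsum                             # _normalize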
topi/python/topi/cuda/softmax.py (21 additions, 3 deletions)
@@ -38,17 +38,35 @@ def schedule_softmax(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     softmax = outs[0]
-    max_elem = softmax.op.input_tensors[1]
-    expsum = softmax.op.input_tensors[2]
 
+    op_tag = softmax.op.tag
+    if op_tag == 'softmax_output':
+        expsum = softmax.op.input_tensors[1]
+        exp = softmax.op.input_tensors[0]
+        max_elem = s[exp].op.input_tensors[1]
+    elif op_tag == 'log_softmax_output':
+        exp = None
+        max_elem = softmax.op.input_tensors[1]
+        expsum = softmax.op.input_tensors[2]
+    else:
+        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+Got {0}'.format(op_tag))
+
     if len(softmax.shape) > 2:
-        for op in [max_elem.op, expsum.op, softmax.op]:
+        ops = [max_elem.op, expsum.op, softmax.op]
+        if exp != None:
+            ops.append(exp.op)
+
+        for op in ops:
             s = _schedule_injective(op, s)
     else:
         num_thread = 64
         block_x = tvm.thread_axis("blockIdx.x")
         thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
 
+        if exp != None:
+            s[exp].bind(exp.op.axis[0], block_x)
+
         s[max_elem].bind(max_elem.op.axis[0], block_x)
         k = expsum.op.reduce_axis[0]
         ko, ki = s[expsum].split(k, factor=num_thread)
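
A usage sketch for the schedule above (pre-0.7 TVM API; the shape, names, and the presence of a CUDA-enabled build are assumptions):

    import tvm
    import topi

    x = tvm.placeholder((64, 1000), name="x")
    y = topi.nn.softmax(x)
    s = topi.cuda.schedule_softmax(y)  # 2-D input, so the bind/split path is taken
    f = tvm.build(s, [x, y], "cuda")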
topi/python/topi/hls/nn.py (16 additions, 2 deletions)
@@ -261,8 +261,22 @@ def schedule_softmax(outs):
     tvm.schedule.AutoInlineInjective(s)
 
     softmax = outs[0]
-    max_elem = softmax.op.input_tensors[1]
-    expsum = softmax.op.input_tensors[2]
 
+    op_tag = softmax.op.tag
+    if op_tag == 'softmax_output':
+        expsum = softmax.op.input_tensors[1]
+        exp = softmax.op.input_tensors[0]
+        max_elem = s[exp].op.input_tensors[1]
+    elif op_tag == 'log_softmax_output':
+        exp = None
+        max_elem = softmax.op.input_tensors[1]
+        expsum = softmax.op.input_tensors[2]
+    else:
+        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+Got {0}'.format(op_tag))
+
+    if exp != None:
+        s[exp].compute_at(s[softmax], s[softmax].op.axis[1])
+
     s[expsum].compute_at(s[softmax], s[softmax].op.axis[1])
     s[max_elem].compute_at(s[softmax], s[softmax].op.axis[1])
topi/python/topi/nn/softmax.py (17 additions, 8 deletions)
@@ -48,24 +48,33 @@ def softmax(x, axis=-1):
     def insert_reduce_index(indices, reduce_index):
         return indices[:axis] + (reduce_index,) + indices[axis:]
 
+    def get_non_reduce_indices(indices):
+        return tuple([var for (i, var) in enumerate(indices) if i != axis])
+
     def _compute_max(*indices):
         eval_range = insert_reduce_index(indices, k1)
         return tvm.max(x[eval_range], axis=k1)
 
-    def _compute_expsum(max_elem, *indices):
+    def _compute_exp(max_elem, *indices):
+        non_reduce_indices = get_non_reduce_indices(indices)
+        return tvm.exp(x[indices] - max_elem[non_reduce_indices])
+
+    def _compute_expsum(exp, *indices):
         eval_range = insert_reduce_index(indices, k2)
-        return tvm.sum(tvm.exp(x[eval_range] - max_elem[indices]), axis=k2)
+        return tvm.sum(exp[eval_range], axis=k2)
 
-    def _normalize(max_elem, expsum, *indices):
-        non_reduce_indices = tuple([var for (i, var) in enumerate(indices) if i != axis])
-        return tvm.exp(x[indices] - max_elem[non_reduce_indices]) / expsum[non_reduce_indices]
+    def _normalize(exp, expsum, *indices):
+        non_reduce_indices = get_non_reduce_indices(indices)
+        return exp[indices] / expsum[non_reduce_indices]
 
     reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
     max_elem = tvm.compute(reduced_shape, _compute_max, name='T_softmax_maxelem')
-    expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(max_elem, *indices),
+    exp = tvm.compute(shape, lambda *indices: _compute_exp(max_elem, *indices),
+                      name='T_softmax_exp')
+    expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(exp, *indices),
                          name='T_softmax_expsum')
-    return tvm.compute(shape, lambda *indices: _normalize(max_elem, expsum, *indices),
-                       name='T_softmax_norm')
+    return tvm.compute(shape, lambda *indices: _normalize(exp, expsum, *indices),
+                       name='T_softmax_norm', attrs={"axis" : axis})
 
 
 @tvm.tag_scope(tag='log_softmax_output')
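
The two index helpers carry all of the axis bookkeeping: get_non_reduce_indices drops the softmax axis so the reduced tensors can be indexed, and insert_reduce_index splices a reduction variable back into that slot. Restated standalone with plain tuples (axis fixed to 1 purely for illustration):

    axis = 1

    def insert_reduce_index(indices, reduce_index):
        return indices[:axis] + (reduce_index,) + indices[axis:]

    def get_non_reduce_indices(indices):
        return tuple([var for (i, var) in enumerate(indices) if i != axis])

    print(get_non_reduce_indices(('i0', 'i1', 'i2')))  # ('i0', 'i2')
    print(insert_reduce_index(('i0', 'i2'), 'k2'))     # ('i0', 'k2', 'i2')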
topi/python/topi/opengl/softmax.py (17 additions, 2 deletions)
@@ -37,8 +37,23 @@ def schedule_softmax(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     softmax = outs[0]
-    max_elem = softmax.op.input_tensors[1]
-    expsum = softmax.op.input_tensors[2]
+
+    op_tag = softmax.op.tag
+    if op_tag == 'softmax_output':
+        expsum = softmax.op.input_tensors[1]
+        exp = softmax.op.input_tensors[0]
+        max_elem = s[exp].op.input_tensors[1]
+    elif op_tag == 'log_softmax_output':
+        exp = None
+        max_elem = softmax.op.input_tensors[1]
+        expsum = softmax.op.input_tensors[2]
+    else:
+        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+Got {0}'.format(op_tag))
+
+    if exp != None:
+        s[exp].opengl()
+
     s[max_elem].opengl()
     s[expsum].opengl()
     s[softmax].opengl()
topi/python/topi/x86/nn.py (28 additions, 9 deletions)
@@ -36,15 +36,34 @@ def schedule_softmax(outs):
         The computation schedule for the op.
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    x = outs[0]
+    softmax = outs[0]
     s = tvm.create_schedule([x.op for x in outs])
-    tvm.schedule.AutoInlineInjective(s)
-    if len(s[x].op.axis) >= 5:
-        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
-        s[x].parallel(fused)
-    elif len(s[x].op.axis) >= 3:
-        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
-        s[x].parallel(fused)
+
+    op_tag = softmax.op.tag
+    if op_tag == 'softmax_output':
+        exp = softmax.op.input_tensors[0]
+        expsum = softmax.op.input_tensors[1]
+        max_elem = s[exp].op.input_tensors[1]
+        axis = int(softmax.op.attrs['axis'])
+    elif op_tag == 'log_softmax_output':
+        exp = None
+        max_elem = softmax.op.input_tensors[1]
+        expsum = softmax.op.input_tensors[2]
+        axis = 1
     else:
-        s[x].parallel(s[x].op.axis[0])
+        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+Got {0}'.format(op_tag))
+
+    # only parallelize outer dimensions up to axis
+    outer_axes = [s[softmax].op.axis[i] for i in range(0, axis)]
+    fused_outer_axes = s[softmax].fuse(*outer_axes)
+    s[softmax].parallel(fused_outer_axes)
+
+    # move computations with the same outer dimensions under the same root
+    s[max_elem].compute_at(s[softmax], fused_outer_axes)
+    s[expsum].compute_at(s[softmax], fused_outer_axes)
+
+    if exp != None:
+        s[exp].compute_at(s[softmax], fused_outer_axes)
+
     return s
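
The payoff of the new CPU schedule: the parallel loop now covers exactly the dimensions outside the softmax axis (read from the new axis attribute), rather than a fixed fusion of the first two or three axes, and the producer stages are anchored under that loop instead of being scheduled independently. A sketch (pre-0.7 TVM API; the shape and axis choice are assumptions for illustration):

    import tvm
    import topi

    x = tvm.placeholder((8, 16, 32, 32), name="x")
    y = topi.nn.softmax(x, axis=2)    # attrs['axis'] == 2
    s = topi.x86.schedule_softmax(y)  # fuses and parallelizes axes 0 and 1
    print(tvm.lower(s, [x, y], simple_mode=True))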