
Replay reverted commit.

Maciek Chociej committed May 16, 2017
1 parent c773cda commit 9941cad948313bffd44b8cafe759aa6d8d6a9d75
@@ -148,6 +148,9 @@ def DoubleClobbers(self):

   def FreeRegister(self, register):
+    assert len(register) > 1
+    if register[0] not in ['r', 'd', 'q']:
+      return
     num = int(register[1:])
     if register[0] == 'r':

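FreeRegister parses the register number out of the operand string with int(register[1:]), so freeing anything that is not a plain 'r'/'d'/'q' register name used to raise ValueError; the new guard turns that into a silent no-op. A minimal sketch of the effect (not the gemmlowp API; the set stands in for the emitter's bookkeeping, and the operand name is hypothetical):

    # Minimal sketch: 'allocated' stands in for the emitter's register
    # bookkeeping.
    def free_register(allocated, register):
      assert len(register) > 1
      if register[0] not in ['r', 'd', 'q']:
        return  # Not a plain register name; nothing to free.
      allocated.discard((register[0], int(register[1:])))

    allocated = {('r', 0), ('d', 5)}
    free_register(allocated, 'r0')             # Frees r0.
    free_register(allocated, '%[row_stride]')  # Previously: ValueError.
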
@@ -255,7 +255,9 @@ def Copy(self):
                            self.dereference_increment)

   def __repr__(self):
-    if self.register_bits is 64:
+    if self.register_bits is None:
+      text = '%%[%s]' % self.name
+    elif self.register_bits is 64:
       text = '%%x[%s]' % self.name
     elif self.register_bits <= 32:
       text = '%%w[%s]' % self.name

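__repr__ gains a case for register_bits of None, rendering the operand with no width modifier (%[name]) rather than as a 64-bit (%x[name]) or 32-bit (%w[name]) register view. A standalone sketch of the formatting logic after the change; note the original compares small integers with `is`, which happens to hold under CPython's small-int caching, while the sketch uses the portable `==`:

    # Standalone sketch of the operand-text logic after the change.
    def operand_text(name, register_bits):
      if register_bits is None:
        return '%%[%s]' % name   # No width modifier.
      elif register_bits == 64:
        return '%%x[%s]' % name  # 64-bit view of the register.
      elif register_bits <= 32:
        return '%%w[%s]' % name  # 32-bit view.
      raise ValueError('Unsupported width: %s' % register_bits)

    print(operand_text('offset', None))  # -> %[offset]
    print(operand_text('offset', 64))    # -> %x[offset]
    print(operand_text('offset', 32))    # -> %w[offset]
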
@@ -388,6 +390,9 @@ def Clobbers(self):
                  for i in self.general_ever] + ['v%d' % i for i in self.vector_ever])

   def FreeRegister(self, register):
+    if isinstance(register, _MappedParameter):
+      return
+
     if register.register_type == 'v':
       assert register.number in self.vector
       self.vector.remove(register.number)

@@ -797,7 +802,7 @@ def EmitVPadal(self, add_type, destination, source):
                   _AppendType(add_type, source))

   def EmitLdr(self, register, value):
-    self.EmitOp2('ldr', _Cast(32, register), value)
+    self.EmitOp2('ldr', _Cast(32, register), _Cast(None, value))

   def EmitVLoad(self, load_no, load_type, destination, source):
     self.EmitVLoadA(load_no, load_type, [destination], source)

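Together with the __repr__ change above, wrapping the load source in _Cast(None, value) makes EmitLdr reference a mapped parameter as a bare %[name] operand instead of a width-qualified register view. An illustration of the intended difference in the emitted instruction (operand names hypothetical, and the before-rendering assumes the parameter previously carried a 64-bit width):

    # Before: self.EmitOp2('ldr', _Cast(32, register), value)
    print('ldr %w[destination], %x[additive_sum_offset]')
    # After:  self.EmitOp2('ldr', _Cast(32, register), _Cast(None, value))
    print('ldr %w[destination], %[additive_sum_offset]')
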
@@ -78,13 +78,40 @@ def _GenerateLoadAggregateStore(emitter, registers, lanes_count, elements_count,
   registers.FreeRegisters(block)


+def _LoadMemoryParameter(emitter, registers, name, source):
+  register = registers.GeneralRegister()
+  emitter.EmitLdr(register, registers.MapMemoryParameter(name, source))
+  return register
+
+
+def _GenerateAggregatorReductionLowRegisters(emitter, registers,
+                                             aggregators, output_address):
+  emitter.EmitNewline()
+  emitter.EmitComment('Aggregator Reduction.')
+  _GenerateAggregatorReduction(
+      emitter, registers, aggregators, output_address,
+      _LoadMemoryParameter(emitter, registers, 'multiplicative_sum_offset',
+                           'params.multiplicative_sum_offset'),
+      _LoadMemoryParameter(emitter, registers, 'additive_sum_offset',
+                           'params.additive_sum_offset'))
+
+
+def _GenerateAggregatorReductionHighRegisters(emitter, registers,
+                                              aggregators, output_address):
+  emitter.EmitNewline()
+  emitter.EmitComment('Aggregator Reduction.')
+  _GenerateAggregatorReduction(
+      emitter, registers, aggregators, output_address,
+      registers.MapParameter('multiplicative_sum_offset',
+                             'params.multiplicative_sum_offset'),
+      registers.MapParameter('additive_sum_offset',
+                             'params.additive_sum_offset'))
+
+
 def _GenerateAggregatorReduction(emitter, registers, aggregators,
-                                 output_address):
+                                 output_address, multiplicative_sum_offset,
+                                 additive_sum_offset):
   """Reduce 4 lane sum aggregators to 1 value and store the sums."""
-  emitter.EmitNewline()
-  emitter.EmitComment('Aggregator Reduction.')
   multiplier = registers.DoubleRegister()
   emitter.EmitVMov('32',
                    emitter.Lane(32, multiplier, 0), multiplicative_sum_offset)

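The new helpers give the pack generator two ways to feed the sum offsets into the reduction: the high-registers variant binds params.multiplicative_sum_offset and params.additive_sum_offset directly as inline-asm input operands via MapParameter, while the low-registers variant keeps them as memory parameters and explicitly loads each into a freshly allocated general register via _LoadMemoryParameter, trading an extra ldr for fewer register-backed operands when the kernel is short on registers. A simplified, self-contained sketch of the contrast (the instruction text is illustrative, not the emitter's exact output):

    # Simplified sketch; 'asm' collects illustrative instruction text.
    def emit_high(asm, name):
      # Offset bound directly as a register input operand.
      asm.append('vmov.32 d0[0], %%[%s]' % name)

    def emit_low(asm, name, scratch_reg):
      # Offset left in the params struct as a memory operand and loaded
      # into an explicitly allocated scratch register first.
      asm.append('ldr %s, %%[%s]' % (scratch_reg, name))
      asm.append('vmov.32 d0[0], %s' % scratch_reg)

    asm = []
    emit_high(asm, 'multiplicative_sum_offset')
    emit_low(asm, 'additive_sum_offset', 'r4')
    print('\n'.join(asm))

The call-site hunks below choose between the two based on how many input registers the pack kernel holds.
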
@@ -162,12 +189,14 @@ def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
     _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count,
                                 leftovers, aggregators, inputs, output)

-    _GenerateAggregatorReduction(
-        self.asm_emitter, registers, aggregators, output,
-        registers.MapParameter('multiplicative_sum_offset',
-                               'params.multiplicative_sum_offset'),
-        registers.MapParameter('additive_sum_offset',
-                               'params.additive_sum_offset'))
     registers.FreeRegisters(inputs)

+    if len(inputs) <= 6:
+      _GenerateAggregatorReductionHighRegisters(
+          self.asm_emitter, registers, aggregators, output)
+    else:
+      _GenerateAggregatorReductionLowRegisters(
+          self.asm_emitter, registers, aggregators, output)
+
     self.asm_emitter.EmitAsmEnd(registers)
     self.asm_emitter.PopIndent(len(self.emitter.indent))

@@ -253,12 +282,9 @@ def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
                                 leftovers, aggregators, input_address,
                                 stride, output_address)

-    _GenerateAggregatorReduction(
-        self.asm_emitter, registers, aggregators, output_address,
-        registers.MapParameter('multiplicative_sum_offset',
-                               'params.multiplicative_sum_offset'),
-        registers.MapParameter('additive_sum_offset',
-                               'params.additive_sum_offset'))
+    _GenerateAggregatorReductionHighRegisters(
+        self.asm_emitter, registers, aggregators, output_address)

     self.asm_emitter.EmitAsmEnd(registers)
     self.asm_emitter.PopIndent(len(self.emitter.indent))

@@ -133,15 +133,14 @@ void MultiThreadedMatrixMatrix(gemmlowp::WorkersPool* pool,
   std::uint8_t* task_scratch = scratch;
   std::int32_t scratch_per_thread = operation.ScratchPerThread(m, n, k);
   std::vector<Task*> tasks;
-  std::for_each(task_rects.begin(), task_rects.end(),
-                [&tasks, &task_scratch, lhs, rhs, k, result, result_stride,
-                 operation, scratch_per_thread]
-                (internal::TaskRect& rect) {
-    tasks.push_back(new internal::MetaTask<IN_TYPE, OUT_TYPE, F>(
-        task_scratch, lhs, rhs, rect, k, result, result_stride,
-        operation));
-    task_scratch += scratch_per_thread;
-  });
+  std::for_each(
+      task_rects.begin(), task_rects.end(),
+      [&tasks, &task_scratch, lhs, rhs, k, result, result_stride, operation,
+       scratch_per_thread](internal::TaskRect& rect) {
+        tasks.push_back(new internal::MetaTask<IN_TYPE, OUT_TYPE, F>(
+            task_scratch, lhs, rhs, rect, k, result, result_stride, operation));
+        task_scratch += scratch_per_thread;
+      });

   pool->Execute(tasks);
 }

@@ -37,8 +37,8 @@ void gemm_q8_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
   std::cout << "Legacy::GemmQ8." << std::endl;
 #endif
 #endif
-  typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum, RowMajorWithSum,
-                     QuantizedStaticPreprocessed, RowMajor>
+  typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum,
+                     RowMajorWithSum, QuantizedStaticPreprocessed, RowMajor>
       Params;

   Params params;

@@ -81,8 +81,8 @@ void gemv_q8(std::uint8_t* scratch, const std::uint8_t* lhs,
   std::cout << "Legacy::GemvQ8." << std::endl;
 #endif
 #endif
-  typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum, RowMajorWithSum,
-                     QuantizedStaticPreprocessed, RowMajor>
+  typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum,
+                     RowMajorWithSum, QuantizedStaticPreprocessed, RowMajor>
       Params;

   Params params;

@@ -129,8 +129,9 @@ void gemm_i32_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
   std::cout << "Legacy::GemmI32." << std::endl;
 #endif
 #endif
-  typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum, RowMajorWithSum,
-                     QuantizedStaticPreprocessedAsInt32, RowMajor>
+  typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum,
+                     RowMajorWithSum, QuantizedStaticPreprocessedAsInt32,
+                     RowMajor>
       Params;

   Params params;

@@ -168,8 +169,9 @@ void gemv_i32(std::uint8_t* scratch, const std::uint8_t* lhs,
   std::cout << "Legacy::GemvI32." << std::endl;
 #endif
 #endif
-  typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum, RowMajorWithSum,
-                     QuantizedStaticPreprocessedAsInt32, RowMajor>
+  typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum,
+                     RowMajorWithSum, QuantizedStaticPreprocessedAsInt32,
+                     RowMajor>
       Params;

   Params params;

@@ -27,8 +27,9 @@ const std::int32_t kMinGemmTaskDimension = 4;

 template <typename Executor, typename Params>
 std::uint8_t* PrepareGemmTask(const Params& params, int kernel_m, int kernel_n,
-                              int kernel_k, std::uint8_t* scratch, int m_start, int m,
-                              int n_start, int n, std::vector<Params>* tasks) {
+                              int kernel_k, std::uint8_t* scratch, int m_start,
+                              int m, int n_start, int n,
+                              std::vector<Params>* tasks) {
   tasks->push_back(params);
   Params& task = tasks->back();
   task.scratch = scratch;

@@ -131,7 +132,7 @@ inline void MultiThreadGemm(MultiThreadingContext* context,
   auto workers_pool = context->workers_pool();

   std::vector<Task*> tasks;
-  std::for_each(task_params.begin(), task_params.end(), [tasks](Params *param) {
+  std::for_each(task_params.begin(), task_params.end(), [tasks](Params* param) {
     tasks.push_back(new TaskRunnerType(param));
   });
   workers_pool->Execute(tasks);

@@ -86,7 +86,7 @@ inline void MultiThreadTransform1D(MultiThreadingContext* context,
   auto workers_pool = context->workers_pool();

   std::vector<Task*> tasks;
-  std::for_each(task_params.begin(), task_params.end(), [tasks](Params *param) {
+  std::for_each(task_params.begin(), task_params.end(), [tasks](Params* param) {
     tasks.push_back(new TaskRunnerType(param));
   });
   workers_pool->Execute(tasks);