Skip to content

Commit

Permalink
added backward handle caching and hash modification (tensorflow#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
taknevski authored and benoitsteiner committed Mar 30, 2017
1 parent 73db701 commit 4f4e6d1
Showing 1 changed file with 43 additions and 30 deletions.
73 changes: 43 additions & 30 deletions tensorflow/core/kernels/xsmm_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ class libxsmm_dnn_conv_desc_wrap{

struct HashFunction{
std::size_t operator()(const libxsmm_dnn_conv_desc_wrap & w) const{




//unsigned char ptr[sizeof(&w.d)];


//memcpy(ptr, (unsigned char *)&w.d, sizeof(&w.d))


//
/*
std::ostringstream N,C,H,W,K,R,S,u,v,padh,padw;
N << w.d.N; C << w.d.C;
Expand All @@ -144,28 +156,28 @@ struct HashFunction{
S << w.d.S; u << w.d.u;
v << w.d.v; padh << w.d.pad_h_in;
padw << w.d.pad_w_in;


std::string out_ = N.str() + C.str()\
+ H.str() + W.str()\
+ K.str() + R.str()\
+ S.str() + u.str()\
+ v.str() + padh.str()\
+ padw.str();

return ( std::hash<std::string>()(out_));
//
//
*/
return ( std::hash<unsigned long long>()((unsigned long long)&(w.d)));
}
};

class handles{
public:
libxsmm_dnn_layer* find( const libxsmm_dnn_conv_desc_wrap &w) {
std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*,
HashFunction>::iterator i = libxsmm_handles.find(w);
std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, HashFunction>::iterator i = libxsmm_handles.find(w);
if (i == libxsmm_handles.end()){
libxsmm_dnn_err_t status;
libxsmm_dnn_layer* libxsmm_handle =
libxsmm_dnn_create_conv_layer(w.d, &status);
libxsmm_dnn_layer* libxsmm_handle = libxsmm_dnn_create_conv_layer(w.d, &status);
chk_libxsmm_err(status, "Create handle");
libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
return libxsmm_handle;
Expand All @@ -174,14 +186,15 @@ class handles{
return i->second;
}
~handles(){
std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*,
HashFunction>::iterator i;
std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, HashFunction>::iterator i;
for (i= libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
"Destroy handle");
}
private:

std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, HashFunction> libxsmm_handles;

};

static handles libxsmm_handles;
Expand All @@ -198,12 +211,12 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
libxsmm_dnn_conv_desc_wrap w(desc);
void* scratch;

if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
libxsmm_handle = libxsmm_handles.find(w);
else {
libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
chk_libxsmm_err(status, "Create handle");
}
//if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
libxsmm_handle = libxsmm_handles.find(w);
//else{
// libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
// chk_libxsmm_err(status, "Create handle");
//}

status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
if (status == LIBXSMM_DNN_WARN_FALLBACK) {
Expand All @@ -216,23 +229,23 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
libxsmm_dnn_buffer* libxsmm_input;
libxsmm_dnn_buffer* libxsmm_output;
libxsmm_dnn_filter* libxsmm_filter;

/*
/*
const DeviceBase::CpuWorkerThreads* worker_threads =
ctx->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
*/

int ifmblock = (libxsmm_handle->ifmblock);
int ofmblock = (libxsmm_handle->ofmblock);
int ofmblock = (libxsmm_handle->ofmblock);

int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1;
int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1;
int blocksofm = desc.K%ofmblock ==0 ? desc.K/ofmblock :desc.K/ofmblock + 1;
float *native_filter = (float*)libxsmm_aligned_scratch(
blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float),
2097152);
float *native_filter = (float*)libxsmm_aligned_scratch( blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float), 2097152);



const DeviceBase::CpuWorkerThreads* worker_threads =
ctx->device()->tensorflow_cpu_worker_threads();

Expand Down Expand Up @@ -301,9 +314,9 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
}

/* bind scratch */
scratch = (void*)libxsmm_aligned_scratch( libxsmm_dnn_get_scratch_size( libxsmm_handle, kind, &status ), 2097152);
scratch = (void*)libxsmm_aligned_scratch( libxsmm_dnn_get_scratch_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status ), 2097152);
chk_libxsmm_err( status, "scratch allocation" );
chk_libxsmm_err( libxsmm_dnn_bind_scratch( libxsmm_handle, kind, scratch ), "binding scratch" );
chk_libxsmm_err( libxsmm_dnn_bind_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch ), "binding scratch" );

if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER);
Expand Down Expand Up @@ -335,9 +348,9 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");

if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
"Destroy handle");
//if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
//chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
// "Destroy handle");

libxsmm_free(native_filter);
libxsmm_free(scratch);
Expand Down

0 comments on commit 4f4e6d1

Please sign in to comment.