Skip to content

Commit

Permalink
Clean: Remove len field from DataCallback
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei3131 committed May 29, 2019
1 parent ad1540b commit e97969d
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 32 deletions.
2 changes: 1 addition & 1 deletion srcs/cpp/include/kungfu.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ extern void order_group_wait(order_group_t *);

#include <functional>
typedef std::function<void()> DoneCallback;
typedef std::function<void(void *,int)> DataCallback;
typedef std::function<void(void *)> DataCallback;

extern int KungfuReduce(const void *sendbuf, void *recvbuf, int count,
KungFu_Datatype dtype, KungFu_Op op, const char *name,
Expand Down
6 changes: 3 additions & 3 deletions srcs/cpp/include/kungfu_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ extern void invoke_callback(callback_t *);
extern void delete_callback(callback_t *);

typedef struct data_callback_s data_callback_t;
extern void invoke_data_callback(data_callback_t *, void *, int len);
extern void invoke_data_callback(data_callback_t *, void *);
extern void delete_data_callback(data_callback_t *);

extern void float16_sum(void *z, const void *x, const void *y, int len);
Expand All @@ -37,12 +37,12 @@ struct CallbackWrapper {
};

struct data_callback_s {
using func_t = std::function<void(void *, int)>;
using func_t = std::function<void(void *)>;

public:
explicit data_callback_s(const func_t &f) : f_(f) {}

void operator()(void *data, int len) { f_(data, len); }
void operator()(void *data) { f_(data); }

private:
func_t f_;
Expand Down
2 changes: 1 addition & 1 deletion srcs/cpp/src/kungfu_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void invoke_callback(callback_t *f) { (*f)(); }

void delete_callback(callback_t *f) { delete f; }

void invoke_data_callback(data_callback_t *f, void *data, int len) { (*f)(data, len); }
void invoke_data_callback(data_callback_t *f, void *data) { (*f)(data); }

void delete_data_callback(data_callback_t *f) { delete f; }

Expand Down
39 changes: 13 additions & 26 deletions srcs/cpp/src/tensorflow/ops/p2p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,7 @@ class RequestModel : public AsyncOpKernel
errors::InvalidArgument("ranks_ must not be empty"));

total_buf_size_ = 0;
for (int s : var_sizes_) {
total_buf_size_ += s;
}
for (int s : var_sizes_) { total_buf_size_ += s; }
}

public:
Expand All @@ -143,15 +141,12 @@ class RequestModel : public AsyncOpKernel

std::uniform_int_distribution<int> dist(0, ranks_.size() - 1);
int destination = dist(engine);
while (destination == self_rank_) {
destination = dist(engine);
}
while (destination == self_rank_) { destination = dist(engine); }

std::function<void()> func = [
&, modelBuf = modelBuf, type_size_bytes_ = type_size_bytes_,
outputs = outputs, var_sizes_ = var_sizes_, done = done
]()
{
std::function<void()> func = [&, modelBuf = modelBuf,
type_size_bytes_ = type_size_bytes_,
outputs = outputs,
var_sizes_ = var_sizes_, done = done]() {
std::lock_guard<std::mutex> l(mu_);

int offset = 0;
Expand Down Expand Up @@ -217,9 +212,7 @@ class SaveModel : public OpKernel

// number of floats it has
total_buf_size_ = 0;
for (int s : var_sizes_) {
total_buf_size_ += s;
}
for (int s : var_sizes_) { total_buf_size_ += s; }

modelBuf = (unsigned char *)malloc(total_buf_size_ * type_size_bytes_);
}
Expand Down Expand Up @@ -322,9 +315,7 @@ class RequestModelWithPrefetch : public OpKernel
errors::InvalidArgument("ranks_ must not be empty"));

total_buf_size_ = 0;
for (int s : var_sizes_) {
total_buf_size_ += s;
}
for (int s : var_sizes_) { total_buf_size_ += s; }
}

public:
Expand All @@ -345,9 +336,7 @@ class RequestModelWithPrefetch : public OpKernel

std::uniform_int_distribution<int> dist(0, ranks_.size() - 1);
int destination = dist(engine);
while (destination == self_rank_) {
destination = dist(engine);
}
while (destination == self_rank_) { destination = dist(engine); }

// Fill in the model Buffer with response from random peer
if (modelBuf == nullptr) {
Expand All @@ -357,12 +346,10 @@ class RequestModelWithPrefetch : public OpKernel
_kungfu_world->Request(destination, (void *)modelBuf,
total_buf_size_,
to_kungfu_type(context->input(0).dtype()));
prefetchCallback = [
&, modelBuf = modelBuf, prefetchBuf = prefetchBuf,
total_buf_size_ = total_buf_size_,
type_size_bytes_ = type_size_bytes_
]()
{
prefetchCallback = [&, modelBuf = modelBuf,
prefetchBuf = prefetchBuf,
total_buf_size_ = total_buf_size_,
type_size_bytes_ = type_size_bytes_]() {
std::lock_guard<std::mutex> l(mu_);
std::copy(prefetchBuf,
prefetchBuf + total_buf_size_ * type_size_bytes_,
Expand Down
2 changes: 1 addition & 1 deletion srcs/go/libkungfu-comm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func GoKungfuRank() int {
func GoKungfuRegisterDataCallback(name *C.char, handle *C.data_callback_t) int {
sess := kungfu.CurrentSession()
return sess.RegisterDataCallback(C.GoString(name), func(msg *rch.Message) {
C.invoke_data_callback(handle, unsafe.Pointer(&msg.Data[0]), C.int(msg.Length))
C.invoke_data_callback(handle, unsafe.Pointer(&msg.Data[0]))
})
}

Expand Down

0 comments on commit e97969d

Please sign in to comment.