forked from horovod/horovod
-
Notifications
You must be signed in to change notification settings - Fork 0
/
controller.cc
726 lines (629 loc) · 26.3 KB
/
controller.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
// Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "controller.h"
#include <atomic>
#include <map>
#include <queue>
#include <set>
#include <unordered_set>
#include "global_state.h"
#include "logging.h"
#include "operations.h"
namespace horovod {
namespace common {
// Makes every rank use the coordinator's tunable parameters.
//
// The coordinator reads its current parameters from parameter_manager_. They
// are broadcast from root 0 over the global communicator, and non-coordinator
// ranks install the received values. Every rank then calls Reset() on its
// local parameter manager.
void Controller::SynchronizeParameters() {
  ParameterManager::Params param;
  if (is_coordinator_) {
    param = parameter_manager_.GetParams();
  }

  // Params is broadcast as a raw byte buffer of sizeof(param) bytes.
  // NOTE(review): this relies on Params being trivially copyable (no
  // pointers/heap members) — confirm against its definition.
  void* buffer = static_cast<void*>(&param);
  size_t param_size = sizeof(param);
  Bcast(buffer, param_size, 0, Communicator::GLOBAL);

  if (!is_coordinator_) {
    parameter_manager_.SetParams(param);
  }
  parameter_manager_.Reset();
}
// Constructs a controller bound to the shared runtime components it
// coordinates. All arguments are passed by non-const reference and used to
// initialize members; the stall inspector is built from the response cache.
// NOTE(review): the members are presumably reference members (see
// controller.h), in which case the caller must keep every argument alive for
// the controller's lifetime.
Controller::Controller(ResponseCache& response_cache, TensorQueue& tensor_queue,
Timeline& timeline, ParameterManager& parameter_manager)
: stall_inspector_(response_cache), tensor_queue_(tensor_queue),
timeline_(timeline), response_cache_(response_cache),
parameter_manager_(parameter_manager) {}
// Runs one negotiation cycle and returns the responses (collective operations)
// that this rank should execute next.
//
// Outline:
//   1. Drain this rank's pending requests from the tensor queue. When the
//      response cache is enabled, classify each request as a cache hit,
//      invalid entry, or uncached.
//   2. If due, run the stall check. On the coordinator its result is OR-ed
//      into the shutdown decision.
//   3. With the cache enabled, synchronize cache/state across ranks
//      (CoordinateCacheAndState). Requests that are cache hits here but not
//      in the common hit set are pushed back to the tensor queue for the next
//      cycle.
//   4. Fast path: if the cache is enabled and no rank has uncached requests,
//      build responses directly from the common cache hits.
//      Otherwise the coordinator counts requests from all ranks, constructs
//      and fuses responses, and sends the final list out. Other ranks send
//      their requests and receive that list.
//   5. When communication happened and the cache is enabled, cache every
//      ALLREDUCE response whose device list covers size_ ranks, then reassign
//      cache bits.
//
// shut_down: this rank's local shutdown request; it is read once per call.
// Returns the response list; its shutdown flag carries the shutdown decision
// for this cycle.
ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down) {
  // Update cache capacity if autotuning is active.
  if (parameter_manager_.IsAutoTuning()) {
    response_cache_.set_capacity(
        static_cast<int>(parameter_manager_.CacheEnabled()) * cache_capacity_);
  }

  // Copy the data structures out from parameters.
  // However, don't keep the lock for the rest of the loop, so that
  // enqueued stream callbacks can continue.
  CacheCoordinator cache_coordinator(response_cache_.num_active_bits());

  // Message queue used only in this cycle.
  std::deque<Request> message_queue_tmp;
  tensor_queue_.PopMessagesFromQueue(message_queue_tmp);
  for (auto& message : message_queue_tmp) {
    // Keep track of cache hits.
    if (response_cache_.capacity() > 0) {
      auto cache_state = response_cache_.cached(message);
      if (cache_state == ResponseCache::CacheState::HIT) {
        uint32_t cache_bit = response_cache_.peek_cache_bit(message);
        cache_coordinator.record_hit(cache_bit);

        // Record initial time cached tensor is encountered in queue.
        stall_inspector_.RecordCachedTensorStart(message.tensor_name());
      } else {
        if (cache_state == ResponseCache::CacheState::INVALID) {
          uint32_t cache_bit = response_cache_.peek_cache_bit(message);
          cache_coordinator.record_invalid_bit(cache_bit);
        }
        cache_coordinator.set_uncached_in_queue(true);

        // Remove timing entry if uncached or marked invalid.
        stall_inspector_.RemoveCachedTensor(message.tensor_name());
      }
    }
  }

  // Flag indicating that the background thread should shut down.
  bool should_shut_down = shut_down;

  // Check for stalled tensors.
  if (stall_inspector_.ShouldPerformCheck()) {
    if (is_coordinator_) {
      should_shut_down |= stall_inspector_.CheckForStalledTensors(size_);
    }

    if (response_cache_.capacity() > 0) {
      stall_inspector_.InvalidateStalledCachedTensors(cache_coordinator);
    }
    stall_inspector_.UpdateCheckTime();
  }

  cache_coordinator.set_should_shut_down(should_shut_down);

  if (response_cache_.capacity() > 0) {
    // Obtain common cache hits and cache invalidations across workers. Also,
    // determine if any worker has uncached messages in queue or requests
    // a shutdown. This function removes any invalid cache entries, if they
    // exist.
    CoordinateCacheAndState(cache_coordinator);
    LOG(DEBUG) << "Cache coordinated.";

    // Remove uncommon cached tensors from queue and replace to state
    // queue for next cycle. Skip adding common cached tensors to
    // queue as they are handled separately.
    //
    // Exactly num_messages entries are examined. Each iteration pops the
    // front and may re-append it at the back, so re-appended entries are not
    // visited again.
    size_t num_messages = message_queue_tmp.size();
    for (size_t i = 0; i < num_messages; ++i) {
      // The front slot is popped at the end of this iteration and never read
      // again, so it is safe to move out of it instead of copying.
      Request message = std::move(message_queue_tmp.front());
      if (response_cache_.cached(message) == ResponseCache::CacheState::HIT) {
        uint32_t cache_bit = response_cache_.peek_cache_bit(message);
        if (cache_coordinator.cache_hits().find(cache_bit) ==
            cache_coordinator.cache_hits().end()) {
          // Try to process again in next cycle.
          tensor_queue_.PushMessageToQueue(message);
        } else {
          // Remove timing entry for messages being handled this cycle.
          stall_inspector_.RemoveCachedTensor(message.tensor_name());
        }
      } else {
        // Remove timing entry for messages being handled this cycle.
        stall_inspector_.RemoveCachedTensor(message.tensor_name());
        message_queue_tmp.push_back(std::move(message));
      }
      message_queue_tmp.pop_front();
    }
  }

  if (!message_queue_tmp.empty()) {
    LOG(DEBUG, rank_) << "Sent " << message_queue_tmp.size()
                      << " messages to coordinator.";
  }

  ResponseList response_list;
  response_list.set_shutdown(cache_coordinator.should_shut_down());

  bool need_communication = true;
  if (response_cache_.capacity() > 0 &&
      !cache_coordinator.uncached_in_queue()) {
    // If the cache is enabled and no uncached new message is coming in, no
    // additional communication is needed.
    need_communication = false;

    // If there are no messages to send, return the empty response list; it
    // still carries the shutdown flag set above.
    if (cache_coordinator.cache_hits().empty()) {
      return response_list;
    }
    // Otherwise we need to add cached messages to the response list.
  }

  if (!need_communication) {
    // If all messages in queue have responses in cache, use fast path with
    // no additional coordination.
    std::deque<Response> responses;

    // Convert cache hits to responses. Populate so that least
    // recently used responses get priority. All workers call the code
    // here so we use the get method here to consistently update the cache
    // order.
    for (auto bit : cache_coordinator.cache_hits()) {
      responses.push_back(response_cache_.get_response(bit));
    }

    // Fuse responses as normal.
    response_list = FuseResponses(responses);

    // FuseResponses returns a freshly built list, which discards the shutdown
    // flag set above. Re-apply it so that a shutdown requested during a fully
    // cached cycle is not lost.
    response_list.set_shutdown(cache_coordinator.should_shut_down());
  } else {
    // There are uncached messages coming in, need communication to figure out
    // whether those are ready to be reduced.

    // Collect all tensors that are ready to be reduced. Record them in the
    // tensor count table (rank zero) or send them to rank zero to be
    // recorded (everyone else).
    std::vector<std::string> ready_to_reduce;

    if (is_coordinator_) {
      LOG(TRACE) << "Adding messages from rank 0";
      while (!message_queue_tmp.empty()) {
        // Pop the first available message.
        Request message = std::move(message_queue_tmp.front());
        message_queue_tmp.pop_front();

        bool reduce = IncrementTensorCount(message);
        stall_inspector_.RecordUncachedTensorStart(
            message.tensor_name(), message.request_rank(), size_);
        if (reduce) {
          ready_to_reduce.push_back(message.tensor_name());
        }
      }

      // Receive ready tensors from other ranks.
      std::vector<RequestList> ready_list;
      RecvReadyTensors(ready_to_reduce, ready_list);

      // Process messages.
      for (int i = 1; i < size_; ++i) {
        LOG(TRACE) << "Adding messages from rank " << i;
        // Bind by reference: the list is only read here.
        auto& received_message_list = ready_list[i];
        for (auto& received_message : received_message_list.requests()) {
          auto& received_name = received_message.tensor_name();

          bool reduce = IncrementTensorCount(received_message);
          stall_inspector_.RecordUncachedTensorStart(
              received_message.tensor_name(), received_message.request_rank(),
              size_);
          if (reduce) {
            ready_to_reduce.push_back(received_name);
          }
        }
        if (received_message_list.shutdown()) {
          // Received SHUTDOWN request from one of the workers.
          should_shut_down = true;
        }
      }

      // At this point, rank zero should have a fully updated tensor count
      // table and should know all the tensors that need to be reduced or
      // gathered, and everyone else should have sent all their information
      // to rank zero. We can now do reductions and gathers; rank zero will
      // choose which ones and in what order, and will notify the other ranks
      // before doing each reduction.
      std::deque<Response> responses;

      if (response_cache_.capacity() > 0) {
        // Prepopulate response list with cached responses. Populate so that
        // least recently used responses get priority. Since only the
        // coordinator rank calls this code, use peek instead of get here to
        // preserve cache order across workers.
        for (auto bit : cache_coordinator.cache_hits()) {
          responses.push_back(response_cache_.peek_response(bit));
        }
      }

      for (auto& tensor_name : ready_to_reduce) {
        Response response = ConstructResponse(tensor_name);
        responses.push_back(std::move(response));
      }
      response_list = FuseResponses(responses);
      response_list.set_shutdown(should_shut_down);

      // Broadcast final results to other ranks.
      SendFinalTensors(response_list);

    } else {
      RequestList message_list;
      message_list.set_shutdown(should_shut_down);
      while (!message_queue_tmp.empty()) {
        message_list.add_request(message_queue_tmp.front());
        message_queue_tmp.pop_front();
      }

      // Send ready tensors to rank zero.
      SendReadyTensors(message_list);

      // Receive final tensors to be processed from rank zero.
      RecvFinalTensors(response_list);
    }
  }

  if (!response_list.responses().empty()) {
    std::string tensors_ready;
    for (const auto& r : response_list.responses()) {
      tensors_ready += r.tensor_names_string() + "; ";
    }
    LOG(TRACE) << "Sending ready responses as " << tensors_ready;
  }

  // If need_communication is false, no uncached message came in, so there is
  // nothing new to add to the cache.
  if (need_communication && response_cache_.capacity() > 0) {
    // All workers add supported responses to cache. This updates the cache
    // order consistently across workers.
    for (auto& response : response_list.responses()) {
      if (response.response_type() == Response::ResponseType::ALLREDUCE &&
          static_cast<int>(response.devices().size()) == size_) {
        response_cache_.put(response, tensor_queue_);
      }
    }
  }

  // Reassign cache bits based on current cache order.
  response_cache_.update_cache_bits();

  return response_list;
}
// Returns the tensor fusion threshold, in bytes, to use for this cycle.
//
// The base value comes from the parameter manager. When the cluster is
// homogeneous and hierarchical allreduce is enabled, the value is rounded up
// to the next multiple of
//   local_size_ * GetTypeSize(HOROVOD_FLOAT64) * FUSION_BUFFER_ATOMIC_UNIT
// so the fusion buffer holds an element count that splits evenly across the
// local ranks.
int64_t Controller::TensorFusionThresholdBytes() {
  const int64_t proposed = parameter_manager_.TensorFusionThresholdBytes();

  // Rounding only applies to homogeneous clusters using hierarchical
  // allreduce; everyone else uses the proposal unchanged.
  if (!is_homogeneous_ || !parameter_manager_.HierarchicalAllreduce()) {
    return proposed;
  }

  // Size the granule for the widest type (float64): a buffer divisible for
  // float64 elements is divisible for every narrower element type too.
  const int float64_bytes = GetTypeSize(HOROVOD_FLOAT64);
  const int64_t granule =
      local_size_ * float64_bytes * FUSION_BUFFER_ATOMIC_UNIT;

  // Ceiling division, then scale back up to a whole number of granules.
  const int64_t granule_count = (proposed + granule - 1) / granule;
  return granule_count * granule;
}
// Builds the Response for tensor `name` from the requests recorded for it in
// message_table_. ComputeResponseList calls this for names whose count reached
// size_ in IncrementTensorCount.
//
// The requests are checked for agreement, in this order:
//   - identical tensor data type;
//   - identical request (operation) type;
//   - ALLREDUCE / BROADCAST: identical tensor shapes;
//   - ALLGATHER: non-scalar tensors (dims() > 0) of equal rank whose
//     dimensions match except the first. The per-rank first-dimension sizes
//     are collected into tensor_sizes;
//   - BROADCAST: identical root rank;
//   - either all requests on CPU_DEVICE_ID or none of them.
// Once `error` is set, the loops of later checks stop immediately. The
// allgather "rank-zero tensor" check is the exception: it is not gated on
// `error`, so it may append a second message.
//
// On error, returns an ERROR response carrying the accumulated message.
// Otherwise the response type mirrors the request type (ALLGATHER also carries
// tensor_sizes). In every case the device list (indexed by request rank) is
// attached. The message_table_ entry for `name` is erased and the stall
// inspector stops tracking `name` as uncached.
//
// Preconditions (asserted): `name` is in message_table_ with at least one
// request.
Response Controller::ConstructResponse(std::string& name) {
bool error = false;
auto it = message_table_.find(name);
assert(it != message_table_.end());
std::vector<Request>& requests = it->second;
assert(!requests.empty());
std::ostringstream error_message_stream;
// Check that all data types of tensors being reduced, gathered or broadcasted
// are identical.
auto data_type = requests[0].tensor_type();
for (unsigned int i = 1; i < requests.size(); ++i) {
auto request_type = requests[i].tensor_type();
if (data_type != request_type) {
error = true;
error_message_stream << "Mismatched data types: One rank had type "
<< DataType_Name(data_type)
<< ", but another rank had type "
<< DataType_Name(request_type) << ".";
break;
}
}
// Check that all requested operations are the same
auto message_type = requests[0].request_type();
for (unsigned int i = 1; i < requests.size(); ++i) {
if (error) {
break;
}
auto request_type = requests[i].request_type();
if (message_type != request_type) {
error = true;
error_message_stream << "Mismatched operations: One rank did an "
<< Request::RequestType_Name(message_type)
<< ", but another rank did an "
<< Request::RequestType_Name(request_type) << ".";
break;
}
}
// If we are doing an allreduce or broadcast, check that all tensor shapes are
// identical.
if (message_type == Request::ALLREDUCE ||
message_type == Request::BROADCAST) {
TensorShape tensor_shape;
for (auto dim : requests[0].tensor_shape()) {
tensor_shape.AddDim(dim);
}
for (unsigned int i = 1; i < requests.size(); ++i) {
if (error) {
break;
}
TensorShape request_shape;
for (auto dim : requests[i].tensor_shape()) {
request_shape.AddDim(dim);
}
if (tensor_shape != request_shape) {
error = true;
error_message_stream
<< "Mismatched " << Request::RequestType_Name(message_type)
<< " tensor shapes: One rank sent a tensor of shape "
<< tensor_shape.DebugString()
<< ", but another rank sent a tensor of shape "
<< request_shape.DebugString() << ".";
break;
}
}
}
// If we are doing an allgather, make sure all but the first dimension are
// the same. The first dimension may be different and the output tensor is
// the sum of the first dimension. Collect the sizes by rank.
// NOTE(review): tensor_sizes (and `devices` below) have requests.size()
// slots but are indexed by request_rank(); this assumes each of the ranks
// 0..requests.size()-1 contributed exactly one request.
std::vector<int64_t> tensor_sizes(requests.size());
if (message_type == Request::ALLGATHER) {
TensorShape tensor_shape;
for (auto dim : requests[0].tensor_shape()) {
tensor_shape.AddDim(dim);
}
// NOTE(review): requests[0] is the first request recorded for this name,
// which need not come from rank 0, so the "Rank zero" wording below may be
// inaccurate. This check is also not skipped when `error` is already set.
if (tensor_shape.dims() == 0) {
error = true;
error_message_stream << "Rank zero tried to "
<< Request::RequestType_Name(message_type)
<< " a rank-zero tensor.";
} else {
tensor_sizes[requests[0].request_rank()] = tensor_shape.dim_size(0);
}
for (unsigned int i = 1; i < requests.size(); ++i) {
if (error) {
break;
}
TensorShape request_shape;
for (auto dim : requests[i].tensor_shape()) {
request_shape.AddDim(dim);
}
if (tensor_shape.dims() != request_shape.dims()) {
error = true;
error_message_stream
<< "Mismatched " << Request::RequestType_Name(message_type)
<< " tensor shapes: One rank sent a tensor of rank "
<< tensor_shape.dims()
<< ", but another rank sent a tensor of rank "
<< request_shape.dims() << ".";
break;
}
// Compare every dimension except the first (dimension 0 may differ).
bool dim_mismatch = false;
for (int dim = 1; dim < tensor_shape.dims(); ++dim) {
if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) {
error = true;
error_message_stream
<< "Mismatched " << Request::RequestType_Name(message_type)
<< " tensor shapes: One rank sent a tensor with dimension " << dim
<< " equal to " << tensor_shape.dim_size(dim)
<< ", but another rank sent a tensor with dimension " << dim
<< " equal to " << request_shape.dim_size(dim) << ".";
dim_mismatch = true;
break;
}
}
// The inner break only leaves the dimension loop; leave the request loop
// as well.
if (dim_mismatch) {
break;
}
tensor_sizes[requests[i].request_rank()] = request_shape.dim_size(0);
}
}
// If we are doing a broadcast, check that all root ranks are identical.
if (message_type == Request::BROADCAST) {
int first_root_rank = requests[0].root_rank();
for (unsigned int i = 1; i < requests.size(); ++i) {
if (error) {
break;
}
int this_root_rank = requests[i].root_rank();
if (first_root_rank != this_root_rank) {
error = true;
error_message_stream
<< "Mismatched " << Request::RequestType_Name(message_type)
<< " root ranks: One rank specified root rank " << first_root_rank
<< ", but another rank specified root rank " << this_root_rank
<< ".";
break;
}
}
}
// Check that either every request is on CPU_DEVICE_ID or none is (any
// non-CPU device is reported as "GPU").
bool first_device_is_cpu = requests[0].device() == CPU_DEVICE_ID;
for (unsigned int i = 1; i < requests.size(); ++i) {
if (error) {
break;
}
bool this_device_is_cpu = requests[i].device() == CPU_DEVICE_ID;
if (first_device_is_cpu != this_device_is_cpu) {
error = true;
error_message_stream
<< "Mismatched " << Request::RequestType_Name(message_type)
<< " CPU/GPU device selection: One rank specified device "
<< (first_device_is_cpu ? "CPU" : "GPU")
<< ", but another rank specified device "
<< (this_device_is_cpu ? "CPU" : "GPU") << ".";
break;
}
}
// Device id of each participant, indexed by its request rank.
std::vector<int32_t> devices(requests.size());
for (auto& request : requests) {
devices[request.request_rank()] = request.device();
}
Response response;
response.add_tensor_name(name);
// NOTE(review): if message_type is none of ALLGATHER / ALLREDUCE / BROADCAST
// and no error was found, the response type is left at its default value.
if (error) {
std::string error_message = error_message_stream.str();
response.set_response_type(Response::ERROR);
response.set_error_message(error_message);
} else if (message_type == Request::ALLGATHER) {
response.set_response_type(Response::ALLGATHER);
for (auto dim : tensor_sizes) {
response.add_tensor_size(dim);
}
} else if (message_type == Request::ALLREDUCE) {
response.set_response_type(Response::ALLREDUCE);
} else if (message_type == Request::BROADCAST) {
response.set_response_type(Response::BROADCAST);
}
response.set_devices(devices);
// Clear all queued up requests for this name. They are now taken care of
// by the constructed response.
message_table_.erase(it);
stall_inspector_.RemoveUncachedTensor(name);
return response;
}
// Synchronizes cache and state information across workers via
// cache_coordinator.sync(), then applies the synchronized result locally:
//   - erases the cached response for every invalid bit, and
//   - when the timeline is enabled, marks negotiation start for every
//     timeline bit and negotiation end for every common cache hit.
void Controller::CoordinateCacheAndState(CacheCoordinator& cache_coordinator) {
  cache_coordinator.sync(shared_from_this(), timeline_enabled_);

  // Drop invalidated cache entries. Iterating an empty set is a no-op, so no
  // separate emptiness check is required.
  for (auto invalid_bit : cache_coordinator.invalid_bits()) {
    response_cache_.erase_response(invalid_bit);
  }

  if (!timeline_enabled_) {
    return;
  }

  // Start (or continue) the negotiation phase for timeline bit entries.
  for (auto timeline_bit : cache_coordinator.timeline_bits()) {
    auto& cached = response_cache_.peek_response(timeline_bit);
    timeline_.NegotiateStart(
        cached.tensor_names()[0],
        static_cast<Request::RequestType>(cached.response_type()));
  }

  // End the negotiation phase for the synchronized cache-hit entries.
  for (auto hit_bit : cache_coordinator.cache_hits()) {
    auto& cached = response_cache_.peek_response(hit_bit);
    timeline_.NegotiateEnd(cached.tensor_names()[0]);
  }
}
// Greedily fuses queued single-tensor responses into larger responses bounded
// by TensorFusionThresholdBytes().
//
// Only ALLREDUCE and ALLGATHER responses are fused. Another response joins
// the current one when it has the same response type, the same device list
// and the same tensor dtype, and the combined byte size stays within the
// threshold. ALLREDUCE sizes are the input tensor sizes; ALLGATHER sizes are
// the gathered output sizes (TotalByteSizeOfAllgatherOutput).
//
// Tensors are usually computed in request order, so skipping ahead can delay
// a batch. However, mixed-precision training can interleave dtypes and break
// fusion, so a limited look-ahead is allowed: a non-matching response is
// skipped as long as the current size plus all skipped sizes stays within the
// threshold. Skipped responses are put back at the front of the queue, in
// their original order, before the next fused response is formed.
//
// responses: queue of responses, each with exactly one tensor name
//            (asserted). It is fully consumed.
// Returns the fused responses in the order they were formed.
ResponseList Controller::FuseResponses(std::deque<Response>& responses) {
  ResponseList response_list;

  // Loop-invariant for the duration of this call; query it once.
  const int64_t fusion_threshold = TensorFusionThresholdBytes();

  // Returns look-ahead-skipped responses to the front of the queue,
  // preserving their relative order.
  auto requeue_skipped = [&responses](std::deque<Response>& skipped) {
    while (!skipped.empty()) {
      responses.push_front(std::move(skipped.back()));
      skipped.pop_back();
    }
  };

  while (!responses.empty()) {
    Response response = std::move(responses.front());
    responses.pop_front();
    assert(response.tensor_names().size() == 1);

    // Byte size of the response being built. It is used only for logging and
    // stays 0 for response types that are not fused.
    int64_t fused_size = 0;

    if (response.response_type() == Response::ResponseType::ALLREDUCE) {
      // Attempt to add more responses to this fused response.
      const auto& entry =
          tensor_queue_.GetTensorEntry(response.tensor_names()[0]);
      fused_size = entry.tensor->size();

      std::deque<Response> skipped_responses;
      int64_t skipped_size = 0;
      while (!responses.empty()) {
        // Reference into the queue: the element is either popped or moved
        // out below, and `candidate` is not used after that.
        auto& candidate = responses.front();
        assert(candidate.tensor_names().size() == 1);
        const auto& candidate_entry =
            tensor_queue_.GetTensorEntry(candidate.tensor_names()[0]);
        int64_t candidate_size = candidate_entry.tensor->size();

        if (response.response_type() == candidate.response_type() &&
            response.devices() == candidate.devices() &&
            entry.tensor->dtype() == candidate_entry.tensor->dtype() &&
            fused_size + candidate_size <= fusion_threshold) {
          // These tensors will fuse together well.
          fused_size += candidate_size;
          response.add_tensor_name(candidate.tensor_names()[0]);
          responses.pop_front();
        } else {
          // Look ahead (see function comment) while skipped work still fits.
          skipped_size += candidate_size;
          if (fused_size + skipped_size <= fusion_threshold) {
            skipped_responses.push_back(std::move(responses.front()));
            responses.pop_front();
          } else {
            break;
          }
        }
      }
      requeue_skipped(skipped_responses);

    } else if (response.response_type() ==
               Response::ResponseType::ALLGATHER) {
      // Attempt to add more responses to this fused response.
      const auto& entry =
          tensor_queue_.GetTensorEntry(response.tensor_names()[0]);
      // Size of the gathered output across all ranks, not this rank's input.
      fused_size = TotalByteSizeOfAllgatherOutput(response.tensor_sizes(), entry);

      std::deque<Response> skipped_responses;
      int64_t skipped_size = 0;
      while (!responses.empty()) {
        auto& candidate = responses.front();
        assert(candidate.tensor_names().size() == 1);
        const auto& candidate_entry =
            tensor_queue_.GetTensorEntry(candidate.tensor_names()[0]);
        int64_t candidate_size = TotalByteSizeOfAllgatherOutput(
            candidate.tensor_sizes(), candidate_entry);

        if (response.response_type() == candidate.response_type() &&
            response.devices() == candidate.devices() &&
            entry.tensor->dtype() == candidate_entry.tensor->dtype() &&
            fused_size + candidate_size <= fusion_threshold) {
          // These tensors will fuse together well.
          fused_size += candidate_size;
          response.add_allgather_response(candidate);
          responses.pop_front();
        } else {
          // Look ahead (see function comment) while skipped work still fits.
          skipped_size += candidate_size;
          if (fused_size + skipped_size <= fusion_threshold) {
            skipped_responses.push_back(std::move(responses.front()));
            responses.pop_front();
          } else {
            break;
          }
        }
      }
      requeue_skipped(skipped_responses);
    }

    response_list.add_response(response);
    LOG(DEBUG) << "Created response of size " << fused_size;
  }
  return response_list;
}
// Computes the size, in bytes, of the output of an allgather of `entry`.
//
// tensor_sizes holds the first-dimension size contributed by each
// participant; ConstructResponse builds it indexed by rank. All other
// dimensions are equal across participants. The output therefore has shape
// (sum of tensor_sizes) x (dims 1..n-1 of entry's tensor), and its byte size
// is that element count times the element size of entry's dtype.
int64_t Controller::TotalByteSizeOfAllgatherOutput(
    const std::vector<int64_t>& tensor_sizes, const TensorTableEntry& entry) {
  // Sum of the first dimension across all participants.
  int64_t gathered_rows = 0;
  for (int64_t rows : tensor_sizes) {
    gathered_rows += rows;
  }

  // Multiply in the trailing dimensions (everything after dimension 0).
  const auto& shape = entry.tensor->shape();
  int64_t element_count = gathered_rows;
  for (int axis = 1; axis < shape.dims(); ++axis) {
    element_count *= shape.dim_size(axis);
  }

  const int bytes_per_element = GetTypeSize(entry.tensor->dtype());
  return element_count * bytes_per_element;
}
// Returns the entry of local_sizes_for_cross_rank_ for cross rank `i`.
// Presumably this is the number of local ranks on that node (verify against
// where the table is filled). No bounds checking is performed.
int Controller::GetLocalSizeAtCrossRank(int i) {
  auto& local_sizes = local_sizes_for_cross_rank_;
  return local_sizes[i];
}
// Records `msg` under its tensor name in message_table_.
//
// Returns true once size_ requests have been collected for that name, i.e.
// the tensor is ready to be reduced. Also emits timeline negotiation events:
// NegotiateStart on the first request for a name, NegotiateRankReady for every
// request, and NegotiateEnd when the count reaches size_.
bool Controller::IncrementTensorCount(const Request& msg) {
  const auto& name = msg.tensor_name();
  auto table_iter = message_table_.find(name);
  if (table_iter == message_table_.end()) {
    // First request for this tensor. Reserve room for one request per rank
    // before inserting, so the vector is allocated once rather than being
    // created for a single element and then reallocated by reserve().
    std::vector<Request> messages;
    messages.reserve(static_cast<unsigned long>(size_));
    messages.push_back(msg);
    // Use the iterator returned by emplace instead of a second lookup.
    table_iter = message_table_.emplace(name, std::move(messages)).first;
    timeline_.NegotiateStart(name, msg.request_type());
  } else {
    table_iter->second.push_back(msg);
  }

  timeline_.NegotiateRankReady(name, msg.request_rank());

  const int count = static_cast<int>(table_iter->second.size());
  const bool ready_to_reduce = count == size_;
  if (ready_to_reduce) {
    timeline_.NegotiateEnd(name);
  }
  return ready_to_reduce;
}
} // namespace common
} // namespace horovod