Skip to content

Commit

Permalink
Remove unnecessary calls to iota. (#6797)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 31, 2021
1 parent 79b8b56 commit 138fe85
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/metric/auc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,6 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
/**
* Create sorted index for each class
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::Iota(d_sorted_idx, device);
auto d_predts_t = dh::ToSpan(cache->predts_t);
Transpose(predts, d_predts_t, n_samples, n_classes, device);

Expand All @@ -231,6 +229,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
});
// no out-of-place sort for thrust, cub sort doesn't accept general iterator. So can't
// use transform iterator in sorting.
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::SegmentedArgSort<false>(d_predts_t, d_class_ptr, d_sorted_idx);

/**
Expand Down Expand Up @@ -447,10 +446,9 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
/**
* Sort the labels
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
auto d_labels = info.labels_.ConstDeviceSpan();

dh::Iota(d_sorted_idx, device);
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::SegmentedArgSort<false>(d_labels, d_group_ptr, d_sorted_idx);

auto d_weights = info.weights_.ConstDeviceSpan();
Expand Down

0 comments on commit 138fe85

Please sign in to comment.