Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions src/core/mapping/base_mapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,8 @@ void BaseMapper::map_task(const MapperContext ctx,

auto mappings = store_mappings(legate_task, options);

std::map<RegionField::Id, uint32_t> client_mapped;
std::map<RegionField::Id, uint32_t> client_mapped_regions;
std::map<uint32_t, uint32_t> client_mapped_futures;
for (uint32_t mapping_idx = 0; mapping_idx < mappings.size(); ++mapping_idx) {
auto& mapping = mappings[mapping_idx];

Expand All @@ -549,15 +550,19 @@ void BaseMapper::map_task(const MapperContext ctx,
}

for (auto& store : mapping.stores) {
if (store.is_future()) continue;
if (store.is_future()) {
auto fut_idx = store.future().index();
client_mapped_futures[fut_idx] = mapping_idx;
continue;
}

auto& rf = store.region_field();
auto key = rf.unique_id();

auto finder = client_mapped.find(key);
auto finder = client_mapped_regions.find(key);
// If this is the first store mapping for this requirement,
// we record the mapping index for future reference.
if (finder == client_mapped.end()) client_mapped[key] = mapping_idx;
if (finder == client_mapped_regions.end()) client_mapped_regions[key] = mapping_idx;
// If we're still in the same store mapping, we know for sure
// that the mapping is consistent.
else {
Expand All @@ -576,11 +581,17 @@ void BaseMapper::map_task(const MapperContext ctx,
auto default_option = options.front();
auto generate_default_mappings = [&](auto& stores, bool exact) {
for (auto& store : stores) {
if (store.is_future()) continue;
auto key = store.region_field().unique_id();
if (client_mapped.find(key) != client_mapped.end()) continue;
client_mapped[key] = static_cast<int32_t>(mappings.size());
mappings.push_back(StoreMapping::default_mapping(store, default_option, exact));
if (store.is_future()) {
auto fut_idx = store.future().index();
if (client_mapped_futures.find(fut_idx) == client_mapped_futures.end())
mappings.push_back(StoreMapping::default_mapping(store, default_option, exact));
continue;
} else {
auto key = store.region_field().unique_id();
if (client_mapped_regions.find(key) != client_mapped_regions.end()) continue;
client_mapped_regions[key] = static_cast<int32_t>(mappings.size());
mappings.push_back(StoreMapping::default_mapping(store, default_option, exact));
}
}
};

Expand All @@ -599,7 +610,11 @@ void BaseMapper::map_task(const MapperContext ctx,

if (req_indices.empty()) continue;

if (mapping.for_unbound_stores()) {
if (req_indices.empty()) {
// This is a mapping for futures
output.future_locations.push_back(get_target_memory(task.target_proc, mapping.policy.target));
continue;
} else if (mapping.for_unbound_stores()) {
for (auto req_idx : req_indices) {
output.output_targets[req_idx] = get_target_memory(task.target_proc, mapping.policy.target);
auto ndim = mapping.stores.front().dim();
Expand Down
8 changes: 7 additions & 1 deletion src/core/mapping/task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ IndexSpace RegionField::get_index_space() const
return get_requirement().region.get_index_space();
}

FutureWrapper::FutureWrapper(const Domain& domain) : domain_(domain) {}
FutureWrapper::FutureWrapper(uint32_t idx, const Domain& domain) : idx_(idx), domain_(domain) {}

Domain FutureWrapper::domain() const { return domain_; }

Expand Down Expand Up @@ -105,6 +105,12 @@ const RegionField& Store::region_field() const
return region_field_;
}

const FutureWrapper& Store::future() const
{
assert(is_future());
return future_;
}

Domain Store::domain() const
{
assert(!unbound());
Expand Down
5 changes: 4 additions & 1 deletion src/core/mapping/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,23 @@ class RegionField {
class FutureWrapper {
public:
FutureWrapper() {}
FutureWrapper(const Legion::Domain& domain);
FutureWrapper(uint32_t idx, const Legion::Domain& domain);

public:
FutureWrapper(const FutureWrapper& other) = default;
FutureWrapper& operator=(const FutureWrapper& other) = default;

public:
int32_t dim() const { return domain_.dim; }
uint32_t index() const { return idx_; }

public:
template <int32_t DIM>
Legion::Rect<DIM> shape() const;
Legion::Domain domain() const;

private:
uint32_t idx_{-1U};
Legion::Domain domain_{};
};

Expand Down Expand Up @@ -126,6 +128,7 @@ class Store {
public:
bool can_colocate_with(const Store& other) const;
const RegionField& region_field() const;
const FutureWrapper& future() const;

public:
template <int32_t DIM>
Expand Down
4 changes: 2 additions & 2 deletions src/core/utilities/deserializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ namespace mapping {
MapperDeserializer::MapperDeserializer(const LegionTask* task,
MapperRuntime* runtime,
MapperContext context)
: BaseDeserializer(task), runtime_(runtime), context_(context)
: BaseDeserializer(task), runtime_(runtime), context_(context), future_index_(0)
{
first_task_ = false;
}
Expand Down Expand Up @@ -163,7 +163,7 @@ void MapperDeserializer::_unpack(FutureWrapper& value)
domain.rect_data[idx + domain.dim] = point[idx] - 1;
}

value = FutureWrapper(domain);
value = FutureWrapper(future_index_++, domain);
}

void MapperDeserializer::_unpack(RegionField& value, bool is_output_region)
Expand Down
1 change: 1 addition & 0 deletions src/core/utilities/deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class MapperDeserializer : public BaseDeserializer<MapperDeserializer> {
private:
Legion::Mapping::MapperRuntime* runtime_;
Legion::Mapping::MapperContext context_;
uint32_t future_index_;
};

} // namespace mapping
Expand Down