Calculate nonzero count inside nonzero op #260

Merged 3 commits on Jan 27, 2023

Changes from 2 commits
60 changes: 40 additions & 20 deletions aten/src/ATen/native/mps/operations/Indexing.mm
@@ -140,7 +140,7 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
threadsPerThreadgroup: threadGroupSize];

[computeEncoder endEncoding];
mpsStream->commit(true);
mpsStream->synchronize(SyncType::COMMIT);
}
});

@@ -252,31 +252,43 @@ Tensor nonzero_fallback(const Tensor& self) {
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
MPSGraphTensor* scatterDataTensor_ = nil;
MPSGraphTensor* countNonzeroTensor_ = nil;
};

int64_t total_nonzero = at::count_nonzero(self).item<int64_t>();
at::native::resize_output(out_, {total_nonzero, nDim});
if (out_.numel() == 0) {
return out_;
}

bool contiguous_output = (out_.is_contiguous() && !out_.is_view());
Tensor out = out_;
if (!contiguous_output) {
out = at::native::empty_mps(
out_.sizes(),
// int64_t total_nonzero = at::count_nonzero(self).item<int64_t>();

Review comment:

Do we want to upstream this block of commented code? If not, I suggest you remove it and keep it locally.

Collaborator Author:

Thanks - yes, this needs to be removed (part of old code). I'll update.

// auto self_cpu = self.detach().clone().cpu();
// int64_t total_nonzero_cpu = at::count_nonzero(self_cpu).item<int64_t>();

// std::string s;
// if (total_nonzero != total_nonzero_cpu) {
// s = std::to_string(total_nonzero) + string(" ") + std::to_string(total_nonzero_cpu) + " self(";
// for (int i = 0; i < self.sizes().size(); i++) {
// s += std::to_string(self.sizes()[i]) + ",";
// }
// s += ") ";
// s += self.is_view() ? "view " : "not_view ";
// s += self.is_contiguous() ? "contg" : "not_contg ";
// std::cout << s << std::endl;
// TORCH_CHECK(total_nonzero == total_nonzero_cpu, s.c_str());
// }



stream->synchronize(SyncType::COMMIT_AND_WAIT);
Tensor count_nonzero = at::empty({1}, self.options().dtype(kInt));
Tensor out = at::native::empty_mps(
{self.numel(), nDim == 0 ? 1 : nDim},
out_.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
c10::nullopt);
}

int64_t _apparentInputShape = 1;
for (auto dim : self.sizes()) {
_apparentInputShape *= dim;
}
MPSShape *apparentOutputShape = @[@(total_nonzero * nDim)];
MPSShape *apparentOutputShape = @[@(self.numel() * nDim)];
MPSShape *apparentInputShape = @[@(_apparentInputShape)];

// Pseudocode:
@@ -310,6 +322,11 @@ Tensor nonzero_fallback(const Tensor& self) {
MPSGraphTensor *inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor
secondaryTensor:zeroTensor
name:nil];

MPSGraphTensor *countNonzero = [mpsGraph reductionSumWithTensor:inputNotEqualToZeroTensor
axis:0
name:nil];

MPSGraphTensor *maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor
toType:MPSDataTypeInt32
name:@"castToInt32"];
@@ -358,15 +375,17 @@ Tensor nonzero_fallback(const Tensor& self) {
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->scatterDataTensor_ = scatterDataTensor;
newCachedGraph->outputTensor_ = outputTensor;
newCachedGraph->countNonzeroTensor_ = countNonzero;
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}

Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparentInputShape);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, contiguous_output ? out_ : out, apparentOutputShape);
Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, contiguous_output ? out_ : out, apparentOutputShape);
Placeholder countNonzeroPlaceholder = Placeholder(cachedGraph->countNonzeroTensor_, count_nonzero);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out, apparentOutputShape);
Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, out, apparentOutputShape);

// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -375,15 +394,16 @@ Tensor nonzero_fallback(const Tensor& self) {
};

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(),
countNonzeroPlaceholder.getMPSGraphTensor() : countNonzeroPlaceholder.getMPSGraphTensorData()
};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
if (!contiguous_output) {
out_.copy_(out);
}
}

int32_t total_nonzero = count_nonzero.item<int32_t>();
@razarmehr (Jan 27, 2023):

This will incur two hard-syncs: one from item(), and one from the following out_.copy_().
Later (not now), I suggest we create a dedicated MPSGraph for this part: we pre-allocate out_ with the same size as self (so we don't overflow the buffer when resizing), and do the zero-counting and resizing of the output in a single MPSGraph op. We can discuss that later.

Collaborator Author:

"We pre-allocate out_ with the same size as self (so we don't overflow the buffer when resizing)"

Doing that might allocate a new buffer and change the pointer of the out buffer (causing a failure in the test). E.g., if the user has a pre-allocated buffer from a previous nonzero op (and they know the exact number of nonzeros), calling nonzero(input, out=preallocated_out) again and resizing the output at the beginning of the function to match the input could allocate new memory when the input has more elements than the output (resize_). Previously the op called into count_nonzero at the beginning of the function to get the number of elements; this change gets the count from the same graph as nonzero (which seemed a bit faster in testing than the previous method).

"and do the zero-counting and resizing of output in a single MPSGraph op. We can discuss that later."

We can do it in the graph (both nonzero and count_nonzero happen in the same graph now), but the shape of the returned Tensor we preallocated at the beginning (not the MPSGraphTensor*) would still be wrong, and we'd still need to sync at the end (the .item() part) to get the number of nonzeros and resize it correctly. And if we preallocated the output at the beginning, it would hit the issue above: preallocating the output up front would work for 99% of the tests, but would fail where a preallocated output is passed to us.

at::native::resize_output(out_, {total_nonzero, nDim});
out_.copy_(out.resize_({total_nonzero, nDim}));
return out_;
}
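
For illustration, here is a minimal Python sketch of the preallocated-out scenario from the review thread above. It assumes an MPS-enabled PyTorch build, and the tensor values are made up for the example; it shows why resizing out_ to self.numel() rows up front could reallocate the caller's buffer, which the final resize_output/copy_ approach avoids:

    import torch

    # Hypothetical repro: reuse a preallocated `out` whose shape already
    # matches the result. If the op resized it up front to input.numel()
    # rows, resize_ could allocate new storage and the caller's buffer
    # would no longer receive the result.
    x = torch.tensor([0.0, 1.0, 0.0, 2.0], device="mps")

    out = torch.nonzero(x)        # first call: out has shape [2, 1]
    ptr = out.data_ptr()

    torch.nonzero(x, out=out)     # reuse; the nonzero count is unchanged,
                                  # so no resize (and no realloc) happens
    assert out.data_ptr() == ptr  # the storage must stay in place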
