Calculate nonzero count inside nonzero op #260
Changes from 2 commits
@@ -140,7 +140,7 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
                   threadsPerThreadgroup: threadGroupSize];

       [computeEncoder endEncoding];
-      mpsStream->commit(true);
+      mpsStream->synchronize(SyncType::COMMIT);
     }
   });
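The hunk above routes the commit through the stream's `synchronize` entry point instead of calling `commit(true)` directly. For orientation, a rough sketch of the sync options the two call sites in this diff imply (an approximation inferred from usage here, not the verbatim PyTorch `MPSStream` header):

```cpp
// Hypothetical shape of the SyncType enum, reconstructed from the two call
// sites in this diff (COMMIT here, COMMIT_AND_WAIT below); anything beyond
// those two names is an assumption.
enum class SyncType {
  NONE,            // leave work queued, do not commit
  COMMIT,          // commit the command buffer but do not block the CPU
  COMMIT_AND_WAIT, // commit and block until the GPU finishes
};
```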
@@ -252,31 +252,43 @@ Tensor nonzero_fallback(const Tensor& self) {
     MPSGraphTensor* inputTensor_ = nil;
     MPSGraphTensor* outputTensor_ = nil;
     MPSGraphTensor* scatterDataTensor_ = nil;
+    MPSGraphTensor* countNonzeroTensor_ = nil;
   };

-  int64_t total_nonzero = at::count_nonzero(self).item<int64_t>();
-  at::native::resize_output(out_, {total_nonzero, nDim});
-  if (out_.numel() == 0) {
-    return out_;
-  }
-
-  bool contiguous_output = (out_.is_contiguous() && !out_.is_view());
-  Tensor out = out_;
-  if (!contiguous_output) {
-    out = at::native::empty_mps(
-        out_.sizes(),
+  // int64_t total_nonzero = at::count_nonzero(self).item<int64_t>();
+  // auto self_cpu = self.detach().clone().cpu();
+  // int64_t total_nonzero_cpu = at::count_nonzero(self_cpu).item<int64_t>();
+
+  // std::string s;
+  // if (total_nonzero != total_nonzero_cpu) {
+  //   s = std::to_string(total_nonzero) + string(" ") + std::to_string(total_nonzero_cpu) + " self(";
+  //   for (int i = 0; i < self.sizes().size(); i++) {
+  //     s += std::to_string(self.sizes()[i]) + ",";
+  //   }
+  //   s += ") ";
+  //   s += self.is_view() ? "view " : "not_view ";
+  //   s += self.is_contiguous() ? "contg" : "not_contg ";
+  //   std::cout << s << std::endl;
+  //   TORCH_CHECK(total_nonzero == total_nonzero_cpu, s.c_str());
+  // }
+
+  stream->synchronize(SyncType::COMMIT_AND_WAIT);
+  Tensor count_nonzero = at::empty({1}, self.options().dtype(kInt));
+  Tensor out = at::native::empty_mps(
+      {self.numel(), nDim == 0 ? 1 : nDim},
       out_.scalar_type(),
       c10::nullopt,
       kMPS,
       c10::nullopt,
       c10::nullopt);
-  }

   int64_t _apparentInputShape = 1;
   for (auto dim : self.sizes()) {
     _apparentInputShape *= dim;
   }
-  MPSShape *apparentOutputShape = @[@(total_nonzero * nDim)];
+  MPSShape *apparentOutputShape = @[@(self.numel() * nDim)];
   MPSShape *apparentInputShape = @[@(_apparentInputShape)];

   // Pseudocode:
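Two things change in this hunk: the output is now preallocated at its worst-case size (`self.numel()` rows, since at most every element is nonzero), and a one-element `count_nonzero` tensor is allocated so the count can be produced on-device by the same graph, replacing the separate `at::count_nonzero(...).item()` round-trip the old code paid up front. A minimal libtorch sketch of that allocate-big-then-shrink flow (hypothetical standalone code using `at::nonzero` as a stand-in for the scatter graph, assuming `self.dim() >= 1`):

```cpp
#include <ATen/ATen.h>

// Sketch: allocate for the worst case, produce coordinates and count in one
// device-side pass, and read the count back exactly once at the end.
at::Tensor nonzero_worst_case(const at::Tensor& self) {
  const int64_t ndim = self.dim();
  at::Tensor out = at::empty({self.numel(), ndim}, self.options().dtype(at::kLong));
  at::Tensor count = (self != 0).sum().to(at::kInt);  // device-side; no host sync yet

  at::Tensor coords = at::nonzero(self);              // stand-in for the scatter graph
  out.narrow(0, 0, coords.size(0)).copy_(coords);

  const int64_t total = count.item<int32_t>();        // the single host readback
  return out.narrow(0, 0, total).clone();             // shrink to the real size
}
```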
@@ -310,6 +322,11 @@ Tensor nonzero_fallback(const Tensor& self) {
     MPSGraphTensor *inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor
                                                                     secondaryTensor:zeroTensor
                                                                                name:nil];
+
+    MPSGraphTensor *countNonzero = [mpsGraph reductionSumWithTensor:inputNotEqualToZeroTensor
+                                                               axis:0
+                                                               name:nil];
+
     MPSGraphTensor *maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor
                                                toType:MPSDataTypeInt32
                                                  name:@"castToInt32"];
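The new `reductionSumWithTensor:` node is the graph-side equivalent of `count_nonzero`: it sums the boolean `(input != 0)` mask over the flattened input. The same identity in libtorch, as a quick sanity check (illustrative code, not part of the PR):

```cpp
#include <torch/torch.h>
#include <cassert>

// count_nonzero(x) == sum(x != 0) over the flattened input, which is what
// the notEqual -> reductionSum chain in the graph computes.
int64_t count_nonzero_via_sum(const torch::Tensor& x) {
  return (x.flatten() != 0).sum().item<int64_t>();
}

int main() {
  torch::Tensor t = torch::tensor({0, 3, 0, 5, 7});
  assert(count_nonzero_via_sum(t) == torch::count_nonzero(t).item<int64_t>());  // both 3
  return 0;
}
```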
@@ -358,15 +375,17 @@ Tensor nonzero_fallback(const Tensor& self) {
           newCachedGraph->inputTensor_ = inputTensor;
           newCachedGraph->scatterDataTensor_ = scatterDataTensor;
           newCachedGraph->outputTensor_ = outputTensor;
+          newCachedGraph->countNonzeroTensor_ = countNonzero;
         }
         return newCachedGraph;
       });
       cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }

     Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparentInputShape);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, contiguous_output ? out_ : out, apparentOutputShape);
-    Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, contiguous_output ? out_ : out, apparentOutputShape);
+    Placeholder countNonzeroPlaceholder = Placeholder(cachedGraph->countNonzeroTensor_, count_nonzero);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out, apparentOutputShape);
+    Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, out, apparentOutputShape);

     // Create dictionary of inputs and outputs
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -375,15 +394,16 @@ Tensor nonzero_fallback(const Tensor& self) {
     };

     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
-      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
+      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(),
+      countNonzeroPlaceholder.getMPSGraphTensor() : countNonzeroPlaceholder.getMPSGraphTensorData()
     };

     runMPSGraph(stream, cachedGraph->graph(), feeds, results);
-    if (!contiguous_output) {
-      out_.copy_(out);
-    }
   }

+  int32_t total_nonzero = count_nonzero.item<int32_t>();
+  at::native::resize_output(out_, {total_nonzero, nDim});
+  out_.copy_(out.resize_({total_nonzero, nDim}));
   return out_;
 }

Review thread on `int32_t total_nonzero = count_nonzero.item<int32_t>();`:

Reviewer: This will incur two hard-syncs: one from `item`, and one for the copy that follows.

Author: Doing that might allocate a new buffer and change the pointer of the `out` buffer (causing a failure in the test). We can do it in the graph (both `nonzero` and `count_nonzero` happen in the same graph now), but the shape of the returned Tensor we preallocated at the beginning (not the `MPSGraphTensor*`) would still be wrong, and we'd need to sync at the end.
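For readers outside the thread: a hard sync means the CPU blocks until the queued GPU work completes, and `.item()` forces one because the scalar must exist on the host before the call can return. A standalone sketch of the shrink-to-fit epilogue the comment refers to (hypothetical code mirroring the `item` / `resize_output` / `copy_` sequence, not the PR's implementation):

```cpp
#include <torch/torch.h>

// The count readback is the `item` sync the reviewer points at; copying the
// overallocated buffer into the final output is the second sync point.
torch::Tensor shrink_to_fit(const torch::Tensor& out_full, const torch::Tensor& count) {
  const int64_t total = count.item<int32_t>();   // hard sync: wait for GPU, copy the scalar
  return out_full.narrow(0, 0, total).clone();   // materialize the {total, nDim} result
}
```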
Review thread on the commented-out debugging block:

Reviewer: Do we want to upstream this block of commented code? If not, I suggest you remove it and keep it locally.

Author: Thanks - yes, this needs to be removed (it's part of the old code). I'll update.