Skip to content

Commit

Permalink
Calculate nonzero count inside nonzero op (#260)
Browse files — browse the repository at this point in the history
* Calculate output shape inside nonzero op

* nonzero optimizations

* Fix lintrunner
  • Loading branch information
DenisVieriu97 authored and kulinseth committed Feb 5, 2023
1 parent d3b5cc2 commit dce8fe9
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions aten/src/ATen/native/mps/operations/Indexing.mm
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
threadsPerThreadgroup: threadGroupSize];

[computeEncoder endEncoding];
mpsStream->commit(true);
mpsStream->synchronize(SyncType::COMMIT);
}
});

Expand Down Expand Up @@ -251,31 +251,24 @@ Tensor nonzero_fallback(const Tensor& self) {
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
MPSGraphTensor* scatterDataTensor_ = nil;
MPSGraphTensor* countNonzeroTensor_ = nil;
};

int64_t total_nonzero = at::count_nonzero(self).item<int64_t>();
at::native::resize_output(out_, {total_nonzero, nDim});
if (out_.numel() == 0) {
return out_;
}

bool contiguous_output = (out_.is_contiguous() && !out_.is_view());
Tensor out = out_;
if (!contiguous_output) {
out = at::native::empty_mps(
out_.sizes(),
stream->synchronize(SyncType::COMMIT_AND_WAIT);
Tensor count_nonzero = at::empty({1}, self.options().dtype(kInt));
Tensor out = at::native::empty_mps(
{self.numel(), nDim == 0 ? 1 : nDim},
out_.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
c10::nullopt);
}

int64_t _apparentInputShape = 1;
for (auto dim : self.sizes()) {
_apparentInputShape *= dim;
}
MPSShape *apparentOutputShape = @[@(total_nonzero * nDim)];
MPSShape *apparentOutputShape = @[@(self.numel() * nDim)];
MPSShape *apparentInputShape = @[@(_apparentInputShape)];

// Pseudocode:
Expand Down Expand Up @@ -309,6 +302,9 @@ Tensor nonzero_fallback(const Tensor& self) {
MPSGraphTensor *inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor
secondaryTensor:zeroTensor
name:nil];
MPSGraphTensor *countNonzero = [mpsGraph reductionSumWithTensor:inputNotEqualToZeroTensor
axis:0
name:nil];
MPSGraphTensor *maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor
toType:MPSDataTypeInt32
name:@"castToInt32"];
Expand Down Expand Up @@ -357,15 +353,17 @@ Tensor nonzero_fallback(const Tensor& self) {
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->scatterDataTensor_ = scatterDataTensor;
newCachedGraph->outputTensor_ = outputTensor;
newCachedGraph->countNonzeroTensor_ = countNonzero;
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}

Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparentInputShape);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, contiguous_output ? out_ : out, apparentOutputShape);
Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, contiguous_output ? out_ : out, apparentOutputShape);
Placeholder countNonzeroPlaceholder = Placeholder(cachedGraph->countNonzeroTensor_, count_nonzero);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out, apparentOutputShape);
Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, out, apparentOutputShape);

// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
Expand All @@ -374,15 +372,16 @@ Tensor nonzero_fallback(const Tensor& self) {
};

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(),
countNonzeroPlaceholder.getMPSGraphTensor() : countNonzeroPlaceholder.getMPSGraphTensorData()
};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
if (!contiguous_output) {
out_.copy_(out);
}
}

int32_t total_nonzero = count_nonzero.item<int32_t>();
at::native::resize_output(out_, {total_nonzero, nDim});
out_.copy_(out.resize_({total_nonzero, nDim}));
return out_;
}

Expand Down

0 comments on commit dce8fe9

Please sign in to comment.