Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] TNDArray shouldn't store strides, fix show #9641

Merged
merged 4 commits into from Oct 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 2 additions & 3 deletions hail/python/hail/expr/types.py
Expand Up @@ -656,12 +656,12 @@ def _parsable_string(self):
def _convert_from_json(self, x):
if is_numeric(self._element_type):
np_type = self.element_type.to_numpy()
return np.ndarray(shape=x['shape'], buffer=np.array(x['data'], dtype=np_type), strides=x['strides'], dtype=np_type)
return np.ndarray(shape=x['shape'], buffer=np.array(x['data'], dtype=np_type), dtype=np_type)
else:
raise TypeError("Hail cannot currently return ndarrays of non-numeric or boolean type.")

def _convert_to_json(self, x):
data = x.flatten("F").tolist()
data = x.flatten("C").tolist()

strides = []
axis_one_step_byte_size = x.itemsize
Expand All @@ -671,7 +671,6 @@ def _convert_to_json(self, x):

json_dict = {
"shape": x.shape,
"strides": strides,
"data": data
}
return json_dict
Expand Down
20 changes: 18 additions & 2 deletions hail/src/main/scala/is/hail/annotations/RegionValueBuilder.scala
Expand Up @@ -483,8 +483,24 @@ class RegionValueBuilder(var region: Region) {
addBoolean(i.includesStart)
addBoolean(i.includesEnd)
endStruct()
case t: TNDArray =>
addAnnotation(t.representation, a)
case t@TNDArray(elementType, _) =>
val structWithStrides = TStruct(
("shape", t.shapeType),
("strides", t.shapeType),
("data", TArray(elementType))
)
val ptype = currentType().asInstanceOf[PBaseStruct]
val shapeRow = a.asInstanceOf[Row](0).asInstanceOf[Row]
val shapeArray = shapeRow.toSeq.toIndexedSeq.map(x => x.asInstanceOf[Long])
var runningProduct = ptype.fieldType("data").asInstanceOf[PArray].elementType.byteSize
val stridesArray = new Array[Long](shapeArray.size)
((shapeArray.size - 1) to 0 by -1).foreach { i =>
stridesArray(i) = runningProduct
runningProduct = runningProduct * (if (shapeArray(i) > 0L) shapeArray(i) else 1L)
}
val stridesRow = Row(stridesArray:_*)

addAnnotation(structWithStrides, Row(shapeRow, stridesRow, a.asInstanceOf[Row](1)))
}
}

Expand Down
37 changes: 36 additions & 1 deletion hail/src/main/scala/is/hail/annotations/UnsafeRow.scala
Expand Up @@ -39,6 +39,30 @@ class UnsafeIndexedSeq(
override def toString: String = s"[${this.mkString(",")}]"
}

class UnsafeIndexedSeqRowMajorView(val wrapped: UnsafeIndexedSeq, shape: IndexedSeq[Long], strides: IndexedSeq[Long]) extends IndexedSeq[Annotation] {
val coordStorageArray = new Array[Long](shape.size)
val shapeProduct = shape.foldLeft(1L )(_ * _)
def apply(i: Int): Annotation = {
var workRemaining = i.toLong
var elementsInProcessedDimensions = shapeProduct

(0 until shape.size).foreach { dim =>
elementsInProcessedDimensions = elementsInProcessedDimensions / shape(dim)
coordStorageArray(dim) = workRemaining / elementsInProcessedDimensions
workRemaining = workRemaining % elementsInProcessedDimensions
}

val properIndex = (0 until shape.size).map(dim => coordStorageArray(dim) * strides(dim)).sum
if (properIndex > Int.MaxValue) {
throw new IllegalArgumentException("Index too large")
}

wrapped(properIndex.toInt)
}

override def length: Int = wrapped.length
}

object UnsafeRow {
def readBinary(boff: Long, t: PBinary): Array[Byte] =
t.loadBytes(boff)
Expand Down Expand Up @@ -94,7 +118,18 @@ object UnsafeRow {
val includesStart = x.includesStart(offset)
val includesEnd = x.includesEnd(offset)
Interval(start, end, includesStart, includesEnd)
case nd: PNDArray => read(nd.representation, region, offset)
case nd: PNDArray => {
val nDims = nd.nDims
val elementSize = nd.elementType.byteSize
val urWithStrides = read(nd.representation, region, offset).asInstanceOf[UnsafeRow]
val shapeRow = urWithStrides.get(0).asInstanceOf[UnsafeRow]
val shape = shapeRow.toSeq.map(x => x.asInstanceOf[Long]).toIndexedSeq
val strides = urWithStrides.get(1).asInstanceOf[UnsafeRow].toSeq.map(x => x.asInstanceOf[Long]).toIndexedSeq
val data = urWithStrides.get(2).asInstanceOf[UnsafeIndexedSeq]
val elementWiseStrides = (0 until nDims).map(i => strides(i) / elementSize)
val row = Row(shapeRow, new UnsafeIndexedSeqRowMajorView(data, shape, elementWiseStrides))
row
}
}
}
}
Expand Down
Expand Up @@ -246,7 +246,7 @@ case class RDict(keyType: TypeWithRequiredness, valueType: TypeWithRequiredness)
}
case class RNDArray(override val elementType: TypeWithRequiredness) extends RIterable(elementType, true) {
override def _unionLiteral(a: Annotation): Unit = {
val data = a.asInstanceOf[Row].getAs[Iterable[Any]](2)
val data = a.asInstanceOf[Row].getAs[Iterable[Any]](1)
data.asInstanceOf[Iterable[_]].foreach { elt =>
if (elt != null)
elementType.unionLiteral(elt)
Expand Down
3 changes: 1 addition & 2 deletions hail/src/main/scala/is/hail/types/virtual/TNDArray.scala
Expand Up @@ -48,7 +48,7 @@ final case class TNDArray(elementType: Type, nDimsBase: NatBase) extends Type {
if (a == null) "NA" else {
val a_row = a.asInstanceOf[Row]
val shape = a_row(this.representation.fieldIdx("shape")).asInstanceOf[Row].toSeq.asInstanceOf[Seq[Long]].map(_.toInt)
val data = a_row(this.representation.fieldIdx("data")).asInstanceOf[UnsafeIndexedSeq]
val data = a_row(this.representation.fieldIdx("data")).asInstanceOf[IndexedSeq[Any]]

def dataToNestedString(data: Iterator[Annotation], shape: Seq[Int], sb: StringBuilder):Unit = {
if (shape.isEmpty) {
Expand Down Expand Up @@ -102,7 +102,6 @@ final case class TNDArray(elementType: Type, nDimsBase: NatBase) extends Type {

lazy val representation = TStruct(
("shape", shapeType),
("strides", TTuple(Array.fill(nDims)(TInt64): _*)),
("data", TArray(elementType))
)
}
10 changes: 5 additions & 5 deletions hail/src/test/scala/is/hail/expr/ir/ETypeSuite.scala
Expand Up @@ -113,28 +113,28 @@ class ETypeSuite extends HailSuite {
@Test def testNDArrayEncodeDecode(): Unit = {
val pTypeInt0 = PCanonicalNDArray(PInt32Required, 0, true)
val eTypeInt0 = ENDArrayColumnMajor(EInt32Required, 0, true)
val dataInt0 = Row(Row(), Row(), FastIndexedSeq(0))
val dataInt0 = Row(Row(), FastIndexedSeq(0))

assertEqualEncodeDecode(pTypeInt0, eTypeInt0, pTypeInt0, dataInt0)

val pTypeFloat1 = PCanonicalNDArray(PFloat32Required, 1, true)
val eTypeFloat1 = ENDArrayColumnMajor(EFloat32Required, 1, true)
val dataFloat1 = Row(Row(5L), Row(4L), (0 until 5).map(_.toFloat))
val dataFloat1 = Row(Row(5L), (0 until 5).map(_.toFloat))

assertEqualEncodeDecode(pTypeFloat1, eTypeFloat1, pTypeFloat1, dataFloat1)

val pTypeInt2 = PCanonicalNDArray(PInt32Required, 2, true)
val eTypeInt2 = ENDArrayColumnMajor(EInt32Required, 2, true)
val dataInt2 = Row(Row(2L, 2L), Row(4L, 8L), FastIndexedSeq(10, 20, 30, 40))
val dataInt2 = Row(Row(2L, 2L), FastIndexedSeq(10, 20, 30, 40))

assertEqualEncodeDecode(pTypeInt2, eTypeInt2, pTypeInt2, dataInt2)

val pTypeDouble3 = PCanonicalNDArray(PFloat64Required, 3, false)
val eTypeDouble3 = ENDArrayColumnMajor(EFloat64Required, 3, false)
val dataDouble3 = Row(Row(3L, 2L, 1L), Row(16L, 8L, 8L), FastIndexedSeq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val dataDouble3 = Row(Row(3L, 2L, 1L), FastIndexedSeq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))

assert(encodeDecode(pTypeDouble3, eTypeDouble3, pTypeDouble3, dataDouble3) ==
Row(Row(3L, 2L, 1L), Row(8L, 24L, 48L), FastIndexedSeq(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)))
Row(Row(3L, 2L, 1L), FastIndexedSeq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)))

// Test for skipping
val pStructContainingNDArray = PCanonicalStruct(true,
Expand Down
Expand Up @@ -10,7 +10,7 @@ import org.testng.annotations.Test
class PNDArraySuite extends PhysicalTestUtils {
@Test def copyTests() {
def runTests(deepCopy: Boolean, interpret: Boolean = false) {
copyTestExecutor(PCanonicalNDArray(PInt64(true), 1), PCanonicalNDArray(PInt64(true), 1), Annotation(Annotation(1L), Annotation(1L), IndexedSeq(4L,5L,6L)),
copyTestExecutor(PCanonicalNDArray(PInt64(true), 1), PCanonicalNDArray(PInt64(true), 1), Annotation(Annotation(1L), IndexedSeq(4L,5L,6L)),
deepCopy = deepCopy, interpret = interpret)
}

Expand Down