[query] add eigh method #13166

Merged: 3 commits merged on Jun 22, 2023
3 changes: 2 additions & 1 deletion hail/python/hail/ir/__init__.py
@@ -8,7 +8,7 @@
Void, Cast, NA, IsNA, If, Coalesce, Let, AggLet, Ref, TopLevelReference, ProjectedTopLevelReference, SelectedTopLevelReference, \
TailLoop, Recur, ApplyBinaryPrimOp, ApplyUnaryPrimOp, ApplyComparisonOp, \
MakeArray, ArrayRef, ArraySlice, ArrayLen, ArrayZeros, StreamIota, StreamRange, StreamGrouped, MakeNDArray, \
NDArrayShape, NDArrayReshape, NDArrayMap, NDArrayMap2, NDArrayRef, NDArraySlice, NDArraySVD, \
NDArrayShape, NDArrayReshape, NDArrayMap, NDArrayMap2, NDArrayRef, NDArraySlice, NDArraySVD, NDArrayEigh, \
NDArrayReindex, NDArrayAgg, NDArrayMatMul, NDArrayQR, NDArrayInv, NDArrayConcat, NDArrayWrite, \
ArraySort, ArrayMaximalIndependentSet, ToSet, ToDict, toArray, ToArray, CastToArray, \
ToStream, toStream, LowerBoundOnOrderedCollection, GroupByKey, StreamMap, StreamZip, StreamTake, \
@@ -158,6 +158,7 @@
'NDArrayAgg',
'NDArrayMatMul',
'NDArrayQR',
'NDArrayEigh',
'NDArraySVD',
'NDArrayInv',
'NDArrayConcat',
25 changes: 25 additions & 0 deletions hail/python/hail/ir/ir.py
@@ -1136,6 +1136,31 @@ def _compute_type(self, env, agg_env, deep_typecheck):
return tndarray(tfloat64, 1)


class NDArrayEigh(IR):
    @typecheck_method(nd=IR, eigvals_only=bool, error_id=nullable(int), stack_trace=nullable(str))
    def __init__(self, nd, eigvals_only=False, error_id=None, stack_trace=None):
        super().__init__(nd)
        self.nd = nd
        self.eigvals_only = eigvals_only
        self._error_id = error_id
        self._stack_trace = stack_trace
        if error_id is None or stack_trace is None:
            self.save_error_info()

    def copy(self):
        return NDArrayEigh(self.nd, self.eigvals_only, self._error_id, self._stack_trace)

    def head_str(self):
        return f'{self._error_id} {self.eigvals_only}'

    def _compute_type(self, env, agg_env, deep_typecheck):
        self.nd.compute_type(env, agg_env, deep_typecheck)
        if self.eigvals_only:
            return tndarray(tfloat64, 1)
        else:
            return ttuple(tndarray(tfloat64, 1), tndarray(tfloat64, 2))


class NDArrayInv(IR):
    @typecheck_method(nd=IR, error_id=nullable(int), stack_trace=nullable(str))
    def __init__(self, nd, error_id=None, stack_trace=None):
4 changes: 2 additions & 2 deletions hail/python/hail/nd/__init__.py
@@ -1,9 +1,9 @@
from .nd import array, from_column_major, arange, full, zeros, ones, svd, qr, solve, solve_triangular, diagonal, inv, concatenate, \
from .nd import array, from_column_major, arange, full, zeros, ones, svd, eigh, qr, solve, solve_triangular, diagonal, inv, concatenate, \
eye, identity, vstack, hstack, maximum, minimum

newaxis = None

__all__ = [
'array', 'from_column_major', 'arange', 'full', 'zeros', 'ones', 'qr', 'solve', 'solve_triangular', 'svd', 'diagonal', 'inv',
'array', 'from_column_major', 'arange', 'full', 'zeros', 'ones', 'qr', 'solve', 'solve_triangular', 'svd', 'eigh', 'diagonal', 'inv',
'concatenate', 'eye', 'identity', 'vstack', 'hstack', 'newaxis', 'maximum', 'minimum'
]
27 changes: 26 additions & 1 deletion hail/python/hail/nd/nd.py
@@ -10,7 +10,7 @@
expr_numeric, Int64Expression, cast_expr, construct_expr, expr_bool,
unify_all)
from hail.expr.expressions.typed_expressions import NDArrayNumericExpression
from hail.ir import NDArrayQR, NDArrayInv, NDArrayConcat, NDArraySVD, Apply
from hail.ir import NDArrayQR, NDArrayInv, NDArrayConcat, NDArraySVD, NDArrayEigh, Apply


tsequenceof_nd = oneof(sequenceof(expr_ndarray()), expr_array(expr_ndarray()))
@@ -426,6 +426,31 @@ def svd(nd, full_matrices=True, compute_uv=True):
return construct_expr(ir, return_type, nd._indices, nd._aggregations)


@typecheck(nd=expr_ndarray(), eigvals_only=bool)
def eigh(nd, eigvals_only=False):
    """Performs an eigenvalue decomposition of a symmetric matrix.

    Parameters
    ----------
    nd : :class:`.NDArrayNumericExpression`
        A 2-dimensional ndarray of shape (N, N).
    eigvals_only : :class:`.bool`
        If False (default), compute both the eigenvalues and the eigenvectors.
        Otherwise, compute only the eigenvalues.

    Returns
    -------
    - w: :class:`.NDArrayNumericExpression`
        The eigenvalues, shape (N,).
    - v: :class:`.NDArrayNumericExpression`
        The eigenvectors, shape (N, N). Only returned if eigvals_only is False.
    """
    float_nd = nd.map(lambda x: hl.float64(x))
    ir = NDArrayEigh(float_nd._ir, eigvals_only)

    return_type = tndarray(tfloat64, 1) if eigvals_only else ttuple(tndarray(tfloat64, 1), tndarray(tfloat64, 2))
    return construct_expr(ir, return_type, nd._indices, nd._aggregations)
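A minimal usage sketch of the new method (assumes an initialized Hail session and NumPy; it mirrors the Python test further down in this diff):

import hail as hl
import numpy as np

A = np.array([[6., 3., 1., 5.],
              [3., 0., 5., 1.],
              [1., 5., 6., 2.],
              [5., 1., 2., 2.]])

# Full decomposition: eigenvalues w with shape (N,) and eigenvectors v with shape (N, N).
w, v = hl.eval(hl.nd.eigh(hl.nd.array(A)))

# Eigenvalues only.
w_only = hl.eval(hl.nd.eigh(hl.nd.array(A), eigvals_only=True))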


@typecheck(nd=expr_ndarray())
def inv(nd):
"""Performs a matrix inversion.
31 changes: 31 additions & 0 deletions hail/python/test/hail/expr/test_ndarrays.py
@@ -942,6 +942,37 @@ def assert_evals_to_same_svd(nd_expr, np_array, full_matrices=True, compute_uv=T
    assert_evals_to_same_svd(rank_2_tall_rectangle, np_rank_2_tall_rectangle, full_matrices=False)


def test_eigh():
    def assert_evals_to_same_eig(nd_expr, np_array, eigvals_only=True):
        evaled = hl.eval(hl.nd.eigh(nd_expr, eigvals_only))
        np_eig = np.linalg.eigvalsh(np_array)

        # check shapes
        for h, n in zip(evaled, np_eig):
            assert h.shape == n.shape

        if eigvals_only:
            np.testing.assert_array_almost_equal(evaled, np_eig)
        else:
            he, hv = evaled

            # eigvals match
            np.testing.assert_array_almost_equal(he, np_eig)

            # V is orthonormal
            vvt = hv @ hv.T
            np.testing.assert_array_almost_equal(vvt, np.identity(vvt.shape[0]))

            # V is eigenvectors
            np.testing.assert_array_almost_equal(np_array @ hv, hv * he)

    A = np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]])
    hA = hl.nd.array(A)

    assert_evals_to_same_eig(hA, A)
    assert_evals_to_same_eig(hA, A, eigvals_only=True)


def test_numpy_interop():
    v = [2, 3]
    w = [3, 5]
15 changes: 15 additions & 0 deletions hail/src/main/scala/is/hail/asm4s/Code.scala
@@ -342,6 +342,21 @@
a9ct.runtimeClass, a10ct.runtimeClass, a11ct.runtimeClass, a12ct.runtimeClass, a13ct.runtimeClass, a14ct.runtimeClass, a15ct.runtimeClass, a16ct.runtimeClass),
Array(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16))

def invokeScalaObject19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, S](
  cls: Class[_], method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6], a7: Code[A7], a8: Code[A8],
  a9: Code[A9], a10: Code[A10], a11: Code[A11], a12: Code[A12], a13: Code[A13], a14: Code[A14], a15: Code[A15], a16: Code[A16],
  a17: Code[A17], a18: Code[A18], a19: Code[A19])(
  implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], a7ct: ClassTag[A7],
  a8ct: ClassTag[A8], a9ct: ClassTag[A9], a10ct: ClassTag[A10], a11ct: ClassTag[A11], a12ct: ClassTag[A12], a13ct: ClassTag[A13], a14ct: ClassTag[A14],
  a15ct: ClassTag[A15], a16ct: ClassTag[A16], a17ct: ClassTag[A17], a18ct: ClassTag[A18], a19ct: ClassTag[A19], sct: ClassTag[S]): Code[S] =
  invokeScalaObject[S](
    cls, method,
    Array[Class[_]](
      a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass, a7ct.runtimeClass, a8ct.runtimeClass,
      a9ct.runtimeClass, a10ct.runtimeClass, a11ct.runtimeClass, a12ct.runtimeClass, a13ct.runtimeClass, a14ct.runtimeClass, a15ct.runtimeClass, a16ct.runtimeClass,
      a17ct.runtimeClass, a18ct.runtimeClass, a19ct.runtimeClass),
    Array(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19))

def invokeStatic[S](cls: Class[_], method: String, parameterTypes: Array[Class[_]], args: Array[Code[_]])(implicit sct: ClassTag[S]): Code[S] = {
val m = Invokeable.lookupMethod(cls, method, parameterTypes)(sct)
assert(m.isStatic)
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Children.scala
@@ -170,6 +170,8 @@
Array(nd)
case NDArraySVD(nd, _, _, _) =>
Array(nd)
case NDArrayEigh(nd, _, _) =>
Array(nd)
case NDArrayInv(nd, errorID) =>
Array(nd)
case NDArrayWrite(nd, path) =>
3 changes: 3 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Copy.scala
@@ -137,6 +137,9 @@
case NDArraySVD(_, fullMatrices, computeUV, errorID) =>
assert(newChildren.length == 1)
NDArraySVD(newChildren(0).asInstanceOf[IR], fullMatrices, computeUV, errorID)
case NDArrayEigh(_, eigvalsOnly, errorID) =>
assert(newChildren.length == 1)
NDArrayEigh(newChildren(0).asInstanceOf[IR], eigvalsOnly, errorID)
case NDArrayInv(_, errorID) =>
assert(newChildren.length == 1)
NDArrayInv(newChildren(0).asInstanceOf[IR], errorID)
29 changes: 29 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
@@ -1874,6 +1874,35 @@ class Emit[C](

}

case NDArrayEigh(nd, eigvalsOnly, errorID) =>
  emitNDArrayColumnMajorStrides(nd).map(cb) { case mat: SNDArrayValue =>
    val n = mat.shapes(0)
    val jobz = if (eigvalsOnly) "N" else "V"
    val (workSize, iWorkSize) = SNDArray.syevr_query(cb, jobz, "U", cb.memoize(n.toI), region)

    val matType = PCanonicalNDArray(PFloat64Required, 2)
    val vecType = PCanonicalNDArray(PFloat64Required, 1)
    val intVecType = PCanonicalNDArray(PInt32Required, 1)

    val W = vecType.constructUninitialized(FastIndexedSeq(n), cb, region)
    val work = vecType.constructUninitialized(FastIndexedSeq(SizeValueDyn(workSize)), cb, region)
    val iWork = intVecType.constructUninitialized(FastIndexedSeq(iWorkSize), cb, region)

    if (eigvalsOnly) {
      SNDArray.syevr(cb, "U", mat, W, None, work, iWork)

      W
    } else {
      val resultType = NDArrayEigh.pTypes(false, false).asInstanceOf[PCanonicalTuple]
      val Z = matType.constructUninitialized(FastIndexedSeq(n, n), cb, region)
      val iSuppZ = vecType.constructUninitialized(FastIndexedSeq(SizeValueDyn(cb.memoize(n * 2))), cb, region)

      SNDArray.syevr(cb, "U", mat, W, Some((Z, iSuppZ)), work, iWork)

      resultType.constructFromFields(cb, region, FastIndexedSeq(EmitCode.present(cb.emb, W), EmitCode.present(cb.emb, Z)), false)
    }
  }
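A rough Python-level sketch of what this emit rule computes (assumes SciPy is available): jobz is "N" or "V", and uplo = "U" means the upper triangle is read, which corresponds to lower=False below.

import numpy as np
from scipy.linalg import eigh

A = np.array([[2., 1.], [1., 2.]])

# jobz = "N": eigenvalues only, via the LAPACK *SYEVR driver.
w_only = eigh(A, lower=False, eigvals_only=True, driver='evr')

# jobz = "V": eigenvalues and eigenvectors.
w, z = eigh(A, lower=False, driver='evr')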

case x@NDArrayQR(nd, mode, errorID) =>
// See here to understand different modes: https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.qr.html
emitNDArrayColumnMajorStrides(nd).map(cb) { case pndValue: SNDArrayValue =>
11 changes: 11 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/IR.scala
@@ -521,6 +521,17 @@ final case class NDArrayQR(nd: IR, mode: String, errorID: Int) extends IR

final case class NDArraySVD(nd: IR, fullMatrices: Boolean, computeUV: Boolean, errorID: Int) extends IR

object NDArrayEigh {
  def pTypes(eigvalsOnly: Boolean, req: Boolean): PType = {
    if (eigvalsOnly) {
      PCanonicalNDArray(PFloat64Required, 1, req)
    } else {
      PCanonicalTuple(req, PCanonicalNDArray(PFloat64Required, 1, true), PCanonicalNDArray(PFloat64Required, 2, true))
    }
  }
}

final case class NDArrayEigh(nd: IR, eigvalsOnly: Boolean, errorID: Int) extends IR

final case class NDArrayInv(nd: IR, errorID: Int) extends IR

final case class AggFilter(cond: IR, aggIR: IR, isScan: Boolean) extends IR
6 changes: 6 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/InferType.scala
@@ -222,6 +222,12 @@ object InferType {
} else {
TNDArray(TFloat64, Nat(1))
}
case NDArrayEigh(nd, eigvalsOnly, _) =>
  if (eigvalsOnly) {
    TNDArray(TFloat64, Nat(1))
  } else {
    TTuple(TNDArray(TFloat64, Nat(1)), TNDArray(TFloat64, Nat(2)))
  }
case NDArrayInv(_, _) =>
TNDArray(TFloat64, Nat(2))
case NDArrayWrite(_, _) => TVoid
6 changes: 6 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Parser.scala
@@ -1079,6 +1079,12 @@ object IRParser {
ir_value_expr(env)(it).map { nd =>
NDArraySVD(nd, fullMatrices, computeUV, errorID)
}
case "NDArrayEigh" =>
  val errorID = int32_literal(it)
  val eigvalsOnly = boolean_literal(it)
  ir_value_expr(env)(it).map { nd =>
    NDArrayEigh(nd, eigvalsOnly, errorID)
  }
case "NDArrayInv" =>
val errorID = int32_literal(it)
ir_value_expr(env)(it).map{ nd => NDArrayInv(nd, errorID) }
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/expr/ir/Pretty.scala
@@ -267,6 +267,7 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int,
case NDArrayMatMul(_, _, errorID) => single(s"$errorID")
case NDArrayQR(_, mode, errorID) => FastSeq(errorID.toString, mode)
case NDArraySVD(_, fullMatrices, computeUV, errorID) => FastSeq(errorID.toString, fullMatrices.toString, computeUV.toString)
case NDArrayEigh(_, eigvalsOnly, errorID) => FastSeq(errorID.toString, eigvalsOnly.toString)
case NDArrayInv(_, errorID) => single(s"$errorID")
case ArraySort(_, l, r, _) if !elideBindings => FastSeq(prettyIdentifier(l), prettyIdentifier(r))
case ArrayRef(_,_, errorID) => single(s"$errorID")
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
@@ -687,6 +687,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
requiredness.union(lookup(r).required)
case NDArrayQR(child, mode, _) => requiredness.fromPType(NDArrayQR.pType(mode, lookup(child).required))
case NDArraySVD(child, _, computeUV, _) => requiredness.fromPType(NDArraySVD.pTypes(computeUV, lookup(child).required))
case NDArrayEigh(child, eigvalsOnly, _) => requiredness.fromPType(NDArrayEigh.pTypes(eigvalsOnly, lookup(child).required))
case NDArrayInv(child, _) => requiredness.unionFrom(lookup(child))
case MakeStruct(fields) =>
fields.foreach { case (n, f) =>
4 changes: 4 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala
@@ -268,6 +268,10 @@ object TypeCheck {
val ndType = nd.typ.asInstanceOf[TNDArray]
assert(ndType.elementType == TFloat64)
assert(ndType.nDims == 2)
case x@NDArrayEigh(nd, _, _) =>
val ndType = nd.typ.asInstanceOf[TNDArray]
assert(ndType.elementType == TFloat64)
assert(ndType.nDims == 2)
case x@NDArrayInv(nd, _) =>
val ndType = nd.typ.asInstanceOf[TNDArray]
assert(ndType.elementType == TFloat64)
23 changes: 21 additions & 2 deletions hail/src/main/scala/is/hail/linalg/LAPACK.scala
@@ -2,9 +2,8 @@ package is.hail.linalg

import java.lang.reflect.Method
import java.util.function._

import com.sun.jna.{FunctionMapper, Library, Native, NativeLibrary}
import com.sun.jna.ptr.IntByReference
import com.sun.jna.ptr.{IntByReference, DoubleByReference}

import scala.util.{Failure, Success, Try}
import is.hail.utils._
@@ -200,6 +199,25 @@
INFOref.getValue()
}

def dsyevr(jobz: String, range: String, uplo: String, n: Int, A: Long, ldA: Int, vl: Double, vu: Double, il: Int, iu: Int, abstol: Double, W: Long, Z: Long, ldZ: Int, ISuppZ: Long, Work: Long, lWork: Int, IWork: Long, lIWork: Int): Int = {
  val nRef = new IntByReference(n)
  val ldARef = new IntByReference(ldA)
  val vlRef = new DoubleByReference(vl)
  val vuRef = new DoubleByReference(vu)
  val ilRef = new IntByReference(il)
  val iuRef = new IntByReference(iu)
  val abstolRef = new DoubleByReference(abstol)
  val ldZRef = new IntByReference(ldZ)
  val lWorkRef = new IntByReference(lWork)
  val lIWorkRef = new IntByReference(lIWork)
  val INFOref = new IntByReference(1)
  val mRef = new IntByReference(0)

  libraryInstance.get.dsyevr(jobz, range, uplo, nRef, A, ldARef, vlRef, vuRef, ilRef, iuRef, abstolRef, mRef, W, Z, ldZRef, ISuppZ, Work, lWorkRef, IWork, lIWorkRef, INFOref)

  INFOref.getValue()
}
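A NumPy-only sketch of the contract a successful DSYEVR call (jobz = "V", range = "A") is expected to satisfy: ascending eigenvalues W and an orthonormal Z with A Z = Z diag(W). np.linalg.eigh serves purely as the reference here:

import numpy as np

A = np.array([[6., 3., 1., 5.],
              [3., 0., 5., 1.],
              [1., 5., 6., 2.],
              [5., 1., 2., 2.]])

w, z = np.linalg.eigh(A)                       # reference eigendecomposition
assert np.all(np.diff(w) >= 0)                 # eigenvalues come back in ascending order
np.testing.assert_allclose(z.T @ z, np.eye(4), atol=1e-12)    # Z has orthonormal columns
np.testing.assert_allclose(A @ z, z @ np.diag(w), atol=1e-9)  # A Z = Z diag(W)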

def dtrtrs(UPLO: String, TRANS: String, DIAG: String, N: Int, NRHS: Int,
A: Long, LDA: Int, B: Long, LDB: Int): Int = {
val Nref = new IntByReference(N)
@@ -254,6 +272,7 @@ trait LAPACKLibrary extends Library {
def dgetrf(M: IntByReference, N: IntByReference, A: Long, LDA: IntByReference, IPIV: Long, INFO: IntByReference)
def dgetri(N: IntByReference, A: Long, LDA: IntByReference, IPIV: Long, WORK: Long, LWORK: IntByReference, INFO: IntByReference)
def dgesdd(JOBZ: String, M: IntByReference, N: IntByReference, A: Long, LDA: IntByReference, S: Long, U: Long, LDU: IntByReference, VT: Long, LDVT: IntByReference, WORK: Long, LWORK: IntByReference, IWORK: Long, INFO: IntByReference)
def dsyevr(jobz: String, range: String, uplo: String, n: IntByReference, A: Long, ldA: IntByReference, vl: DoubleByReference, vu: DoubleByReference, il: IntByReference, iu: IntByReference, abstol: DoubleByReference, m: IntByReference, W: Long, Z: Long, ldZ: IntByReference, ISuppZ: Long, Work: Long, lWork: IntByReference, IWork: Long, lIWork: IntByReference, info: IntByReference)
def ilaver(MAJOR: IntByReference, MINOR: IntByReference, PATCH: IntByReference)
def ilaenv(ispec: IntByReference, name: String, opts: String, n1: IntByReference, n2: IntByReference, n3: IntByReference, n4: IntByReference): Int
def dtrtrs(UPLO: String, TRANS: String, DIAG: String, N: IntByReference, NRHS: IntByReference, A: Long, LDA: IntByReference, B: Long, LDB: IntByReference, INFO:IntByReference)