Skip to content

Commit

Permalink
scala wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jey committed May 26, 2015
1 parent 66575c7 commit 29ddf24
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 22 deletions.
2 changes: 2 additions & 0 deletions cx_spark/build.sbt
Expand Up @@ -8,6 +8,8 @@ libraryDependencies += "org.apache.spark" %% "spark-core" % "1.3.1"

libraryDependencies += "org.apache.spark" %% "spark-mllib" % "1.3.1"

// netlib-java's "all" module is a POM-only aggregator that pulls in the native BLAS/LAPACK
// backends; it is requested with pomOnly() so sbt does not look for a jar artifact.
libraryDependencies += ("com.github.fommil.netlib" % "all" % "1.1.2").pomOnly()

lazy val submit = taskKey[Unit]("Submits Spark job")
submit <<= (Keys.`package` in Compile) map {
(jarFile: File) => s"spark-submit --verbose --driver-memory 64G ${jarFile}" !
Expand Down
76 changes: 54 additions & 22 deletions cx_spark/src/main/scala/CX.scala
Expand Up @@ -24,48 +24,80 @@ object CX {
fromBreeze(result)
}

/**
 * Computes `mat.transpose * rhs` in a single pass over the rows of `mat`.
 *
 * Every row contributes a rank-1 update, outer(row.vector, rhs(row.index, ::)), to a
 * dense accumulator; the per-partition accumulators are then summed by `treeAggregate`.
 *
 * @param mat row-distributed matrix; each row vector must convert to a Breeze sparse
 *            vector, otherwise the cast below throws ClassCastException
 * @param rhs local dense matrix that is indexed by the row indices of `mat`
 * @return a local dense matrix with `mat.numCols` rows and `rhs.numCols` columns
 */
def transposeMultiply(mat: IndexedRowMatrix, rhs: DenseMatrix): DenseMatrix = {
  val rhsBrz = rhs.toBreeze.asInstanceOf[BDM[Double]]
  // Hoisted so the inner loop bound is a plain Int.
  val numRhsCols = rhs.numCols
  val result =
    mat.rows.treeAggregate(BDM.zeros[Double](mat.numCols.toInt, numRhsCols))(
      seqOp = (U: BDM[Double], row: IndexedRow) => {
        val rowIdx = row.index.toInt
        val rowBrz = row.vector.toBreeze.asInstanceOf[BSV[Double]]
        // performs a rank-1 update:
        //   U += outer(row.vector, rhs(row.index, ::))
        // Only the first `activeSize` slots of index/data hold stored entries; `length`
        // is the logical dimension of the vector and must not be used as the bound here.
        for (ipos <- 0 until rowBrz.activeSize) {
          val i = rowBrz.index(ipos)
          val ival = rowBrz.data(ipos)
          for (j <- 0 until numRhsCols) {
            U(i, j) += ival * rhsBrz(rowIdx, j)
          }
        }
        U
      },
      combOp = (U1, U2) => U1 += U2
    )
  fromBreeze(result)
}

/**
 * Right-multiplies `mat` by a freshly drawn random matrix from `DenseMatrix.randn`.
 *
 * @param mat  the row-distributed matrix to project
 * @param rank number of columns of the random matrix, and hence of the result
 * @return `mat` times a (mat.numCols x rank) random matrix
 */
def gaussianProjection(mat: IndexedRowMatrix, rank: Int): IndexedRowMatrix = {
  // Unseeded generator: the projection differs from run to run.
  val generator = new java.util.Random
  val projector = DenseMatrix.randn(mat.numCols.toInt, rank, generator)
  mat.multiply(projector)
}

/**
 * Driver: creates a SparkContext, loads a matrix from a CSV of (index, index, value)
 * triples, projects it onto a random subspace, factorizes the projected matrix and
 * prints values derived from the factors.
 *
 * NOTE(review): this span is a rendered commit diff in which removed and added lines
 * appear together without +/- markers. `prefix`, `name`, `numIters`, `rank`, `slack`
 * and `U` are each declared twice below, so the text does not compile as shown; the
 * comments describe each line on its own.
 *
 * @param args command-line arguments; not read anywhere in this body
 */
def main(args: Array[String]) {
val prefix = "hdfs:///"
val name = "Lewis_Dalisay_Peltatum_20131115_hexandrum_1_1-masked-100x100"
//val prefix = "hdfs:///"
//val name = "Lewis_Dalisay_Peltatum_20131115_hexandrum_1_1-masked-100x100"
val name = "Lewis_Dalisay_Peltatum_20131115_PDX_Std_1"
// Absolute path on one developer's machine; the data set name is also the directory name.
val prefix = s"/home/jey/proj/openmsi/data/2014Nov15_PDX_IMS_imzML/$name"
val inpath = s"$prefix/$name.mat.csv"
val conf = new SparkConf().setAppName("CX")
// Master and driver memory are hard-coded here rather than left to spark-submit.
conf.setMaster("local[4]").set("spark.driver.memory", "8G")
val sc = new SparkContext(conf)

/* params */
val numIters = 2
val rank = 8 // rank of approximation
val slack = 10 // extra slack to improve the approximation
val reo = 4 // reorthogonalize after this many iters
val numIters = 1
val rank = 4 // rank of approximation
val slack = rank // extra slack to improve the approximation

/* load matrix */
// Each input line is split on "," into three fields that become one MatrixEntry.
// NOTE(review): the two `map` lines below disagree on which of fields 0 and 1 is the
// row index, and the two `coomat` lines use different dimension literals — presumably
// the old and new versions for different data sets; confirm.
val nonzeros = sc.textFile(inpath).map(_.split(",")).
map(x => new MatrixEntry(x(1).toLong, x(0).toLong, x(2).toDouble))
val coomat = new CoordinateMatrix(nonzeros, 3743324, 9574) // FIXME: magics
map(x => new MatrixEntry(x(0).toLong, x(1).toLong, x(2).toDouble))
val coomat = new CoordinateMatrix(nonzeros, 951, 781210) // FIXME: magics
val mat = coomat.toIndexedRowMatrix()
// Cached because the rows are traversed by several of the operations below.
mat.rows.cache()

/* approximate principal subspace */
var B = gaussianProjection(mat, rank + slack).toBreeze
// `0 to numIters` is inclusive, so this loop body runs numIters + 1 times.
for(i <- 0 to numIters) {
if(i % reo == reo-1) {
println("reorth")
B = qr.justQ(B)
}
B = multiplyGramianBy(mat, fromBreeze(B)).toBreeze.asInstanceOf[BDM[Double]]
/* perform randomized SVD */
var Y = gaussianProjection(mat, rank + slack).toBreeze
// `0 until numIters` is exclusive, so this loop body runs exactly numIters times.
for(i <- 0 until numIters) {
Y = multiplyGramianBy(mat, fromBreeze(Y)).toBreeze.asInstanceOf[BDM[Double]]
}
assert(Y.cols == rank + slack)
// 951 repeats the row-count literal passed to CoordinateMatrix above.
assert(Y.rows == 951)
val Q = qr.justQ(Y)
// B is the transpose of transposeMultiply(mat, Q), converted to a Breeze dense matrix.
val B = transposeMultiply(mat, fromBreeze(Q)).transpose.toBreeze.asInstanceOf[BDM[Double]]
val Bsvd = svd.reduced(B)
// Keep only the first `rank` columns of Q * U, entries of S, and rows of Vt.
val U = (Q * Bsvd.U).apply(::, 0 until rank)
val S = Bsvd.S(0 until rank)
val Vt = Bsvd.Vt(0 until rank, ::)

/* compute leverage scores */
val U = svd.reduced(B).U(::, 0 until rank)
// lev = np.sum(U[:,:k]**2,axis=1)
val lev = sum(U :^ 2.0, Axis._0)
val p = lev / rank.toDouble
val rowlev = sum(U :^ 2.0, Axis._1)
val rowp = rowlev / rank.toDouble
// NOTE(review): Vt was sliced to `rank` rows above, so reducing along Axis._1 presumably
// yields one value per row of Vt (rank values) rather than one per matrix column —
// confirm whether Axis._0 was intended for the column scores.
val collev = sum(Vt :^ 2.0, Axis._1)
val colp = collev / rank.toDouble

println(p)
println("S:\n")
println(S)
println("RowP:")
println(rowp)
println("\nColP:")
println(colp)
}
}

0 comments on commit 29ddf24

Please sign in to comment.