Skip to content

Commit

Permalink
v1.2 switched back to double
Browse files Browse the repository at this point in the history
  • Loading branch information
dedztbh committed Sep 22, 2020
1 parent 76a69c9 commit f2957b5
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 44 deletions.
13 changes: 9 additions & 4 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ plugins {
id("com.github.johnrengelman.shadow") version "6.0.0"
}
group = "com.dedztbh"
version = "1.1"
version = "1.2"

val ejmlVersion = "0.39"

Expand All @@ -17,11 +17,16 @@ repositories {
dependencies {
implementation("org.ejml:ejml-core:${ejmlVersion}")
implementation("org.ejml:ejml-kotlin:${ejmlVersion}")
implementation("org.ejml:ejml-fdense:${ejmlVersion}")
// testImplementation(kotlin("test-junit"))
implementation("org.ejml:ejml-ddense:${ejmlVersion}")
implementation("org.ejml:ejml-simple:${ejmlVersion}")
}
tasks.withType<KotlinCompile>() {
tasks.withType<KotlinCompile> {
kotlinOptions.jvmTarget = "1.8"
kotlinOptions.freeCompilerArgs = listOf(
"-Xno-param-assertions",
"-Xno-call-assertions",
"-Xno-receiver-assertions"
)
}

tasks {
Expand Down
4 changes: 2 additions & 2 deletions src/main/kotlin/operator/Prob0Init.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ class Prob0Init(N: Int) : ProbFinder(N) {
override fun printResult() {
val probs = eval(getZeroVec())
allStates(N).forEach { endState ->
var prob = 1f
var prob = 1.0
endState.forEach { i ->
val x = probs[i]
prob *= if (i > 0) x else 1f - x
prob *= if (i > 0) x else 1.0 - x
}
println("Pr[%s] = %.10f".format(endState.joinToString(","), prob))
}
Expand Down
8 changes: 4 additions & 4 deletions src/main/kotlin/operator/ProbAllInit.kt
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
package operator

import org.ejml.data.FMatrixRMaj
import org.ejml.data.DMatrixRMaj
import kotlin.math.pow

/**
* Created by DEDZTBH on 2020/09/15.
* Project CMU_Coin-flipping_Experience
*/

fun allStatesFloat(n: Int) =
fun allStatesDouble(n: Int) =
Array(2.0.pow(n).toInt()) {
FloatArray(n) { i -> ((it shr i) and 1).toFloat() }.apply { reverse() }
DoubleArray(n) { i -> ((it shr i) and 1).toDouble() }.apply { reverse() }
}

class ProbAllInit(N: Int) : ProbFinder(N) {
override fun printResult() {
eval(
// a 2^n by n matrix
FMatrixRMaj(allStatesFloat(N))
DMatrixRMaj(allStatesDouble(N))
).print()
}
}
44 changes: 22 additions & 22 deletions src/main/kotlin/operator/ProbFinder.kt
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
package operator

import org.ejml.data.FMatrixRMaj
import org.ejml.dense.row.CommonOps_FDRM
import org.ejml.data.DMatrixRMaj
import org.ejml.dense.row.CommonOps_DDRM
import org.ejml.kotlin.minus
import org.ejml.kotlin.plus
import org.ejml.kotlin.plusAssign
import org.ejml.kotlin.times
import util.*

sealed class Operation
data class MatrixOp(val opVec: FMatrixRMaj, val opBias: FMatrixRMaj) : Operation()
data class MatrixOp(val opVec: DMatrixRMaj, val opBias: DMatrixRMaj) : Operation()
data class CNot(val i: Int, val j: Int) : Operation()
data class CSwap(val i: Int, val j: Int, val k: Int) : Operation()
data class CCNot(val i: Int, val j: Int, val k: Int) : Operation()
data class Gen1Bit(val i: Int, val p: Float, val q: Float) : Operation()
data class Gen1Bit(val i: Int, val p: Double, val q: Double) : Operation()

/**
* Created by DEDZTBH on 2020/09/15.
* Project CMU_Coin-flipping_Experience
*/
abstract class ProbFinder(val N: Int) : Operator {
fun getOneVec(): FMatrixRMaj = FMatrixRMaj(arrayOf(FloatArray(N) { 1f }))
fun getZeroVec(): FMatrixRMaj = FMatrixRMaj(1, N)
fun getOneVec(): DMatrixRMaj = DMatrixRMaj(arrayOf(DoubleArray(N) { 1.0 }))
fun getZeroVec(): DMatrixRMaj = DMatrixRMaj(1, N)
fun saveMatrix() {
if (matrixDirty) {
operations.add(MatrixOp(opVec, opBias))
Expand All @@ -41,13 +41,13 @@ abstract class ProbFinder(val N: Int) : Operator {
val i = readInt()
when (cmd) {
"Flip" -> {
opVec[i] = 0f
opBias[i] = 0.5f
opVec[i] = 0.0
opBias[i] = 0.5
matrixDirty = true
}
"Not" -> {
opVec[i] *= -1f
opBias[i] = 1f - opBias[i]
opVec[i] *= -1.0
opBias[i] = 1.0 - opBias[i]
matrixDirty = true
}
"CNot" -> { // cannot represent in matrix op
Expand All @@ -68,14 +68,14 @@ abstract class ProbFinder(val N: Int) : Operator {
operations.add(CCNot(i, j, k))
}
"GenFlip" -> {
val j = readFloat()
opVec[i] = 0f
val j = readDouble()
opVec[i] = 0.0
opBias[i] = j
matrixDirty = true
}
"Gen1Bit" -> { // cannot represent in matrix op
val p = readFloat()
val q = readFloat()
val p = readDouble()
val q = readDouble()
saveMatrix()
operations.add(Gen1Bit(i, p, q))
}
Expand All @@ -86,21 +86,21 @@ abstract class ProbFinder(val N: Int) : Operator {

override fun done() = saveMatrix()

fun eval(probs: FMatrixRMaj): FMatrixRMaj = probs.apply {
fun eval(probs: DMatrixRMaj): DMatrixRMaj = probs.apply {
val broadcastVec =
FMatrixRMaj(probs.numRows, 1, true, *FloatArray(probs.numRows) { 1f })
DMatrixRMaj(probs.numRows, 1, true, *DoubleArray(probs.numRows) { 1.0 })
operations.forEach {
it.apply {
when (this) {
is MatrixOp -> {
CommonOps_FDRM.elementMult(probs, broadcastVec * opVec)
CommonOps_DDRM.elementMult(probs, broadcastVec * opVec)
probs += broadcastVec * opBias
}
is CNot -> {
val x = probs getColumn i
val y = probs getColumn j
//(1 - x) * y + x * (1 - y)
probs.putColumn(j, (-2f * x mul y) + x + y)
probs.putColumn(j, (-2.0 * x mul y) + x + y)
}
is CSwap -> {
val x = probs getColumn i
Expand All @@ -118,17 +118,17 @@ abstract class ProbFinder(val N: Int) : Operator {
val y = probs getColumn j
val z = probs getColumn k
//(1 - x) * (1 - y)
val probBoth0 = (x mul y) - x - y + 1f
val probBoth0 = (x mul y) - x - y + 1.0
//probBoth0 * z + (1 - probBoth0) * (1 - z)
probs.putColumn(k, (2f * probBoth0 mul z) - probBoth0 - z + 1f)
probs.putColumn(k, (2.0 * probBoth0 mul z) - probBoth0 - z + 1.0)
}
is Gen1Bit -> {
val x = probs getColumn i
val xx = x mul x
val pComp = 1f - p
val pComp = 1.0 - p
//x * (1 - q) * x + (1 - x) * (p + (1 - p) * x)
//= xx(1-q) + p + (1-p)x - xp - xx(1-p)
probs.putColumn(i, (xx * (1f - q)) + p + (x * pComp) - (x * p) - (xx * pComp))
probs.putColumn(i, (xx * (1.0 - q)) + p + (x * pComp) - (x * p) - (xx * pComp))
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/main/kotlin/operator/Tester.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package operator

import util.readFloat
import util.readDouble
import util.readInt
import kotlin.random.Random

Expand Down Expand Up @@ -32,10 +32,10 @@ class Tester(N: Int) : Operator {
val k = readInt()
if (coins[i] or coins[j] != 0) coins flip k
}
"GenFlip" -> coins[i] = if (Random.nextDouble() < readFloat()) 1 else 0
"GenFlip" -> coins[i] = if (Random.nextDouble() < readDouble()) 1 else 0
"Gen1Bit" -> {
val p = readFloat()
val q = readFloat()
val p = readDouble()
val q = readDouble()
if (coins[i] == 0) {
if (Random.nextDouble() < p) coins[i] = 1
} else
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/util/io.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ fun read(): String {
}

fun readInt() = read().toInt()
fun readFloat() = read().toFloat()
fun readDouble() = read().toDouble()

lateinit var reader: BufferedReader
16 changes: 9 additions & 7 deletions src/main/kotlin/util/matrix.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package util

import org.ejml.data.FMatrixRMaj
import org.ejml.dense.row.CommonOps_FDRM
import org.ejml.data.DMatrixRMaj
import org.ejml.dense.row.CommonOps_DDRM
import org.ejml.simple.ops.SimpleOperations_DDRM

/**
* Created by DEDZTBH on 2020/09/18.
Expand All @@ -11,20 +12,21 @@ import org.ejml.dense.row.CommonOps_FDRM
/**
* Retrieve column i
*/
infix fun FMatrixRMaj.getColumn(i: Int) = CommonOps_FDRM.extractColumn(this, i, null)
infix fun DMatrixRMaj.getColumn(i: Int) = CommonOps_DDRM.extractColumn(this, i, null)

/**
* Replace column i
*/
fun FMatrixRMaj.putColumn(i: Int, col: FMatrixRMaj) = col.data.forEachIndexed { rowi, fl -> set(rowi, i, fl) }
val simpleOps = SimpleOperations_DDRM()
fun DMatrixRMaj.putColumn(i: Int, col: DMatrixRMaj) = simpleOps.setColumn(this, i, 0, *col.data)

/**
* Element-wise multiplication (pure)
*/
infix fun FMatrixRMaj.mul(other: FMatrixRMaj) = createLike().also { CommonOps_FDRM.elementMult(this, other, it) }
infix fun DMatrixRMaj.mul(other: DMatrixRMaj) = createLike().also { CommonOps_DDRM.elementMult(this, other, it) }

/**
* Scalar Multiplication (pure)
*/
operator fun FMatrixRMaj.times(f: Float) = createLike().also { CommonOps_FDRM.scale(f, this, it) }
operator fun Float.times(f: FMatrixRMaj) = f * this
operator fun DMatrixRMaj.times(f: Double) = createLike().also { CommonOps_DDRM.scale(f, this, it) }
operator fun Double.times(f: DMatrixRMaj) = f * this

0 comments on commit f2957b5

Please sign in to comment.