In [1]:
import org.apache.spark.sql.{Dataset, Encoder, Encoders,SparkSession}
import org.apache.spark.sql.functions._
import scala.util.Random
import org.apache.spark.rdd.RDD
// import scala.reflect.ClassTag

In [4]:
implicit val spark: SparkSession = SparkSession.builder()
  .appName("DRGU Debugging")
  .master("local[*]")
  .getOrCreate()

import spark.implicits._

spark = org.apache.spark.sql.SparkSession@2e87ab56


org.apache.spark.sql.SparkSession@2e87ab56

## simulate small data

In [3]:
case class Obs(
  i: String,                   // cluster ID
  x: Array[Double],           // covariates
  y: Double,                  // outcome
  timeIndex: Option[Int] = None,      // optional time index
  z: Option[Double] = None       // optional treatment indicator
)

case class ObsPair(left: Obs, right: Obs)

defined class Obs
defined class ObsPair


In [4]:


// Generate small synthetic data
val numClusters = 10
val numObsPerCluster = 1
val random = new Random()

val data = (1 to numClusters).flatMap { clusterId =>
  (1 to numObsPerCluster).map { obsId =>
    Obs(
      i = s"c$clusterId",
      x = Array.fill(3)(random.nextDouble()), // 3 covariates
      y = random.nextDouble(),               // random outcome
      timeIndex = Some(obsId),               // time index
      z = Some(random.nextDouble())          // random treatment indicator
    )
  }
}

// Convert to Dataset
val df: Dataset[Obs] = data.toDS()

// Show the generated data
df.show(false)

+---+---------------------------------------------------------------+-------------------+---------+-------------------+
|i  |x                                                              |y                  |timeIndex|z                  |
+---+---------------------------------------------------------------+-------------------+---------+-------------------+
|c1 |[0.21117451386379826, 0.7337139693881689, 0.13668762864606376] |0.3935042601155878 |1        |0.91852898534016   |
|c2 |[0.5551291598312791, 0.7765361928694281, 0.27366372362138724]  |0.884753148989808  |1        |0.06225156631387341|
|c3 |[0.4409948958419089, 0.6634485942467294, 0.2879274360179592]   |0.20146890333974976|1        |0.9822233109903844 |
|c4 |[0.5501508382113921, 0.6020954024318038, 0.07284599681938175]  |0.07301750449391564|1        |0.42330505560451925|
|c5 |[0.893311039469886, 0.5841559004165905, 0.38774190907515693]   |0.14553253420089685|1        |0.8061781651296719 |
|c6 |[0.2652286521155143, 0.029188125259

numClusters = 10
numObsPerCluster = 1
random = scala.util.Random@61ad06c6
data = Vector(Obs(c1,[D@5a1e4103,0.3935042601155878,Some(1),Some(0.91852898534016)), Obs(c2,[D@3f5ac56f,0.884753148989808,Some(1),Some(0.06225156631387341)), Obs(c3,[D@36c87f74,0.20146890333974976,Some(1),Some(0.9822233109903844)), Obs(c4,[D@427eebe4,0.07301750449391564,Some(1),Some(0.42330505560451925)), Obs(c5,[D@6ac0bcf1,0.14553253420089685,Some(1),Some(0.8061781651296719)), Obs(c6,[D@3d87d7b,0.4223080076185145,Some(1),Some(0.7455777139544757)), Obs(c7,[D@761e91ab,0.9851781329252142,Some(1),Some(0.36573628678447456)), Obs(c8,[D@430a6547,0.13527752328098452,Some(1),Some(0.898594495805197)), Obs(c9,[D@2c426ee3,0.6455018108568834,Some(1),Some...


Vector(Obs(c1,[D@5a1e4103,0.3935042601155878,Some(1),Some(0.91852898534016)), Obs(c2,[D@3f5ac56f,0.884753148989808,Some(1),Some(0.06225156631387341)), Obs(c3,[D@36c87f74,0.20146890333974976,Some(1),Some(0.9822233109903844)), Obs(c4,[D@427eebe4,0.07301750449391564,Some(1),Some(0.42330505560451925)), Obs(c5,[D@6ac0bcf1,0.14553253420089685,Some(1),Some(0.8061781651296719)), Obs(c6,[D@3d87d7b,0.4223080076185145,Some(1),Some(0.7455777139544757)), Obs(c7,[D@761e91ab,0.9851781329252142,Some(1),Some(0.36573628678447456)), Obs(c8,[D@430a6547,0.13527752328098452,Some(1),Some(0.898594495805197)), Obs(c9,[D@2c426ee3,0.6455018108568834,Some(1),Some...

## test

In [5]:
def assignDeterministicPartition(
    ds: Dataset[Obs],
    idColumn: String = "i",  // cluster ID or any stable ID
    numPartitions: Int
)(implicit spark: SparkSession): Dataset[(Obs, Int)] = {
  import spark.implicits._

  val withBucket = ds.withColumn("bucket", (xxhash64(col(idColumn)) % numPartitions).cast("int"))
  val repartitioned = withBucket.repartition($"bucket")
  repartitioned.select(
    struct($"i", $"x", $"y", $"timeIndex", $"z").as("obs"),
    $"bucket"
  ).as[(Obs, Int)]
}



assignDeterministicPartition: (ds: org.apache.spark.sql.Dataset[Obs], idColumn: String, numPartitions: Int)(implicit spark: org.apache.spark.sql.SparkSession)org.apache.spark.sql.Dataset[(Obs, Int)]


In [6]:
def sampleObsPairsFromRepartitioned(
    partitioned: Dataset[(Obs, Int)],
    samplePerPartition: Int,
    seed: Long = System.currentTimeMillis()
)(implicit spark: SparkSession): Dataset[ObsPair] = {
  import spark.implicits._

  partitioned
    .groupByKey { case (_, bucket) => bucket }
    .flatMapGroups { case (_, iter) =>
      val rand = new scala.util.Random(seed)
      val items = iter.map(_._1).toIndexedSeq
      val n = items.length

      if (n < 2) Iterator.empty
      else {
        val sampledPairs = Iterator.fill(samplePerPartition) {
          var i = 0
          var j = 0
          do {
            i = rand.nextInt(n)
            j = rand.nextInt(n)
          } while (i == j)

          val (ii, jj) = if (i < j) (i, j) else (j, i)
          ObsPair(items(ii), items(jj))
        }
        sampledPairs
      }
    }
}

def sampleObsPairsFromRepartitionedRow(
    partitioned: Dataset[(Obs, Int)],
    matchesPerRow: Int,
    seed: Long = System.currentTimeMillis()
)(implicit spark:SparkSession): Dataset[ObsPair] = {
  import spark.implicits._

  partitioned
    .groupByKey { case (_, bucket) => bucket }
    .flatMapGroups { case (_, iter) =>
      val rand = new scala.util.Random(seed)
      val items = iter.map(_._1).toIndexedSeq
      val n = items.length

      if (n < 2) Iterator.empty
      else {
        val sampledPairs  = (0 until n).flatMap { i =>
          (0 until matchesPerRow).map { _ =>
            var j = i
            while (j == i) j = rand.nextInt(n)
            ObsPair(items(i), items(j))
          }
        }
        sampledPairs
      }
    }
}

sampleObsPairsFromRepartitioned: (partitioned: org.apache.spark.sql.Dataset[(Obs, Int)], samplePerPartition: Int, seed: Long)(implicit spark: org.apache.spark.sql.SparkSession)org.apache.spark.sql.Dataset[ObsPair]
sampleObsPairsFromRepartitionedRow: (partitioned: org.apache.spark.sql.Dataset[(Obs, Int)], matchesPerRow: Int, seed: Long)(implicit spark: org.apache.spark.sql.SparkSession)org.apache.spark.sql.Dataset[ObsPair]


In [7]:
val repartitioned = assignDeterministicPartition(df, idColumn = "i", numPartitions = 2).cache()
repartitioned.count() // trigger caching

val sampledPairs: Dataset[ObsPair] =
  sampleObsPairsFromRepartitioned(repartitioned, samplePerPartition = 3)

sampledPairs.show()


+--------------------+--------------------+
|                left|               right|
+--------------------+--------------------+
|{c2, [0.555129159...|{c4, [0.550150838...|
|{c2, [0.555129159...|{c4, [0.550150838...|
|{c2, [0.555129159...|{c4, [0.550150838...|
|{c5, [0.893311039...|{c7, [0.590306759...|
|{c5, [0.893311039...|{c6, [0.265228652...|
|{c1, [0.211174513...|{c5, [0.893311039...|
|{c8, [0.046584247...|{c9, [0.112529737...|
|{c8, [0.046584247...|{c9, [0.112529737...|
|{c8, [0.046584247...|{c9, [0.112529737...|
+--------------------+--------------------+



repartitioned = [obs: struct<i: string, x: array<double> ... 3 more fields>, bucket: int]
sampledPairs = [left: struct<i: string, x: array<double> ... 3 more fields>, right: struct<i: string, x: array<double> ... 3 more fields>]


[left: struct<i: string, x: array<double> ... 3 more fields>, right: struct<i: string, x: array<double> ... 3 more fields>]

In [8]:
val sampledPairsRow: Dataset[ObsPair] =
  sampleObsPairsFromRepartitionedRow(repartitioned, matchesPerRow = 3)

sampledPairsRow.show()

+--------------------+--------------------+
|                left|               right|
+--------------------+--------------------+
|{c2, [0.555129159...|{c4, [0.550150838...|
|{c2, [0.555129159...|{c4, [0.550150838...|
|{c2, [0.555129159...|{c4, [0.550150838...|
|{c4, [0.550150838...|{c2, [0.555129159...|
|{c4, [0.550150838...|{c2, [0.555129159...|
|{c4, [0.550150838...|{c2, [0.555129159...|
|{c1, [0.211174513...|{c5, [0.893311039...|
|{c1, [0.211174513...|{c3, [0.440994895...|
|{c1, [0.211174513...|{c3, [0.440994895...|
|{c3, [0.440994895...|{c7, [0.590306759...|
|{c3, [0.440994895...|{c7, [0.590306759...|
|{c3, [0.440994895...|{c5, [0.893311039...|
|{c5, [0.893311039...|{c1, [0.211174513...|
|{c5, [0.893311039...|{c6, [0.265228652...|
|{c5, [0.893311039...|{c3, [0.440994895...|
|{c6, [0.265228652...|{c3, [0.440994895...|
|{c6, [0.265228652...|{c7, [0.590306759...|
|{c6, [0.265228652...|{c7, [0.590306759...|
|{c7, [0.590306759...|{c1, [0.211174513...|
|{c7, [0.590306759...|{c3, [0.44

sampledPairsRow = [left: struct<i: string, x: array<double> ... 3 more fields>, right: struct<i: string, x: array<double> ... 3 more fields>]


[left: struct<i: string, x: array<double> ... 3 more fields>, right: struct<i: string, x: array<double> ... 3 more fields>]

In [9]:
repartitioned.unpersist()

[obs: struct<i: string, x: array<double> ... 3 more fields>, bucket: int]

## Change data

In [10]:
import breeze.linalg._

case class PairFeatures(
  Wt_i: Array[Double],
  Wt_j: Array[Double],
  Xg_ij: Array[Double],
  Xg_ji: Array[Double],
  yi: Double,
  yj: Double,
  zi: Double,
  zj: Double
)

defined class PairFeatures


In [11]:
def toPairFeatures(pair: ObsPair): PairFeatures = {
  val Obs(_, x1, y1, _, Some(z1)) = pair.left
  val Obs(_, x2, y2, _, Some(z2)) = pair.right

  val Wt_i = Array(1.0) ++ x1
  val Wt_j = Array(1.0) ++ x2
  val Xg_ij = Array(1.0) ++ x1 ++ x2
  val Xg_ji = Array(1.0) ++ x2 ++ x1

  PairFeatures(Wt_i, Wt_j, Xg_ij, Xg_ji, y1, y2, z1, z2)
}

toPairFeatures: (pair: ObsPair)PairFeatures


In [12]:
val pairFeatureDS: Dataset[PairFeatures] = sampledPairs.map(toPairFeatures)

pairFeatureDS = [Wt_i: array<double>, Wt_j: array<double> ... 6 more fields]


[Wt_i: array<double>, Wt_j: array<double> ... 6 more fields]

In [13]:
pairFeatureDS.show()

+--------------------+--------------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+
|                Wt_i|                Wt_j|               Xg_ij|               Xg_ji|                 yi|                 yj|                 zi|                 zj|
+--------------------+--------------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+
|[1.0, 0.555129159...|[1.0, 0.550150838...|[1.0, 0.555129159...|[1.0, 0.550150838...|  0.884753148989808|0.07301750449391564|0.06225156631387341|0.42330505560451925|
|[1.0, 0.555129159...|[1.0, 0.550150838...|[1.0, 0.555129159...|[1.0, 0.550150838...|  0.884753148989808|0.07301750449391564|0.06225156631387341|0.42330505560451925|
|[1.0, 0.555129159...|[1.0, 0.550150838...|[1.0, 0.555129159...|[1.0, 0.550150838...|  0.884753148989808|0.07301750449391564|0.06225156631387341|0.42330505560451925|
|[1.

## compute HF

In [14]:
def safeSig(x: Double): Double = {
  val clipped = math.max(-10.0, math.min(10.0, x))
  1.0 / (1.0 + math.exp(-clipped))
}

safeSig: (x: Double)Double


### row wise

In [15]:
def computeHFforPair(
  p: PairFeatures,
  theta: Map[String, DenseVector[Double]]
): (DenseVector[Double], DenseVector[Double]) = {
  val delta = theta("delta")(0)
  val beta  = theta("beta")
  val gamma = theta("gamma")

  val Wt_i = DenseVector(p.Wt_i)
  val Wt_j = DenseVector(p.Wt_j)
  val Xg_ij = DenseVector(p.Xg_ij)
  val Xg_ji = DenseVector(p.Xg_ji)

  val pi_i = safeSig(Wt_i dot beta)
  val pi_j = safeSig(Wt_j dot beta)
  val g_ij = safeSig(Xg_ij dot gamma)
  val g_ji = safeSig(Xg_ji dot gamma)

  val I_ij = if (p.yi >= p.yj) 1.0 else 0.0
  val I_ji = 1.0 - I_ij

  // h
  val num1 = p.zi * (1 - p.zj) / (2 * pi_i * (1 - pi_j)) * (I_ij - g_ij)
  val num2 = p.zj * (1 - p.zi) / (2 * pi_j * (1 - pi_i)) * (I_ji - g_ji)
  val h1 = num1 + num2 + 0.5 * (g_ij + g_ji)
  val h2 = 0.5 * (p.zi + p.zj)
  val h3 = 0.5 * (p.zi * (1 - p.zj) * I_ij + p.zj * (1 - p.zi) * I_ji)
  val h = DenseVector(h1, h2, h3)

  // f
  val f1 = delta
  val f2 = 0.5 * (pi_i + pi_j)
  val f3 = 0.5 * (pi_i * (1 - pi_j) * g_ij + pi_j * (1 - pi_i) * g_ji)
  val f = DenseVector(f1, f2, f3)

  (h, f)
}

computeHFforPair: (p: PairFeatures, theta: Map[String,breeze.linalg.DenseVector[Double]])(breeze.linalg.DenseVector[Double], breeze.linalg.DenseVector[Double])


In [16]:
val p = 3
val theta = Map(
  "delta" -> DenseVector(0.1),
  "beta" -> DenseVector.zeros[Double](p + 1),   // where p is dim(x)
  "gamma" -> DenseVector.zeros[Double](2 * p + 1)
)

p = 3
theta = Map(delta -> DenseVector(0.1), beta -> DenseVector(0.0, 0.0, 0.0, 0.0), gamma -> DenseVector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))


Map(delta -> DenseVector(0.1), beta -> DenseVector(0.0, 0.0, 0.0, 0.0), gamma -> DenseVector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))

In [17]:
val hfRDD:RDD[(DenseVector[Double], DenseVector[Double])] = pairFeatureDS.rdd.map(p => computeHFforPair(p, theta))

// Specify the size of the DenseVector (e.g., 3 for h and f)
val vectorSize = 3

val (hSum, fSum, count) = hfRDD.aggregate(
  (DenseVector.zeros[Double](vectorSize), DenseVector.zeros[Double](vectorSize), 0)
)(
  { case ((hAcc, fAcc, c), (h, f)) => (hAcc + h, fAcc + f, c + 1) },
  { case ((h1, f1, c1), (h2, f2, c2)) => (h1 + h2, f1 + f2, c1 + c2) }
)

val hMean = hSum / count.toDouble
val fMean = fSum / count.toDouble

println(s"Mean h: $hMean")
println(s"Mean f: $fMean")

Mean h: DenseVector(0.07351268421514602, 0.49611558095403585, 0.02969533926451304)
Mean f: DenseVector(0.09999999999999999, 0.5, 0.125)


hfRDD = MapPartitionsRDD[60] at map at <console>:40
vectorSize = 3
hSum = DenseVector(0.6616141579363142, 4.465040228586322, 0.26725805338061737)
fSum = DenseVector(0.8999999999999999, 4.5, 1.125)
count = 9
hMean = DenseVector(0.07351268421514602, 0.49611558095403585, 0.02969533926451304)
fMean = DenseVector(0.09999999999999999, 0.5, 0.125)


DenseVector(0.09999999999999999, 0.5, 0.125)

### batch wise

In [18]:
def safeSigmoid(x: DenseVector[Double]): DenseVector[Double] = 
  x.map(v => 1.0 / (1.0 + math.exp(-math.max(-10.0, math.min(10.0, v)))))

def safeSigmoid(x: DenseMatrix[Double]): DenseMatrix[Double] =
  x.map(v => 1.0 / (1.0 + math.exp(-math.max(-10.0, math.min(10.0, v)))))

safeSigmoid: (x: breeze.linalg.DenseVector[Double])breeze.linalg.DenseVector[Double] <and> (x: breeze.linalg.DenseMatrix[Double])breeze.linalg.DenseMatrix[Double]
safeSigmoid: (x: breeze.linalg.DenseVector[Double])breeze.linalg.DenseVector[Double] <and> (x: breeze.linalg.DenseMatrix[Double])breeze.linalg.DenseMatrix[Double]


In [19]:
def computeHFMatrixBatch(
  batch: Seq[PairFeatures],
  theta: Map[String, DenseVector[Double]]
): (Seq[DenseVector[Double]], Seq[DenseVector[Double]]) = {

  val n = batch.size
  val p = theta("beta").length
  val d = theta("gamma").length

  val delta = theta("delta")(0)
  val beta  = theta("beta")
  val gamma = theta("gamma")

  // Convert to matrix form
  val Wt_i   = DenseMatrix(batch.map(_.Wt_i): _*)         // n x p
  val Wt_j   = DenseMatrix(batch.map(_.Wt_j): _*)         // n x p
  val Xg_ij  = DenseMatrix(batch.map(_.Xg_ij): _*)        // n x d
  val Xg_ji  = DenseMatrix(batch.map(_.Xg_ji): _*)        // n x d

  val zi     = DenseVector(batch.map(_.zi).toArray)
  val zj     = DenseVector(batch.map(_.zj).toArray)
  val yi     = DenseVector(batch.map(_.yi).toArray)
  val yj     = DenseVector(batch.map(_.yj).toArray)

  // Predictions
  val pi_i = safeSigmoid(Wt_i * beta)
  val pi_j = safeSigmoid(Wt_j * beta)
  val g_ij = safeSigmoid(Xg_ij * gamma)
  val g_ji = safeSigmoid(Xg_ji * gamma)

  val I_ij = (yi <:< yj).map(v => if (v) 0.0 else 1.0) // Use map for element-wise comparison
  val I_ji = 1.0 - I_ij

  // Compute h
  val h1 = ((zi *:* (1.0 - zj)) *:* (1.0/(2.0 * (pi_i *:* (1.0 - pi_j))))) *:* (I_ij - g_ij) +
           ((zj *:* (1.0 - zi)) *:* (1.0/(2.0 * (pi_j *:* (1.0 - pi_i))))) *:* (I_ji - g_ji) +
           0.5 * (g_ij + g_ji)

  val h2 = 0.5 * (zi + zj)
  val h3 = 0.5 * (zi *:* (1.0 - zj) *:* I_ij + zj *:* (1.0 - zi) *:* I_ji)

  val hList = (0 until n).map(i => DenseVector(h1(i), h2(i), h3(i)))

  // Compute f
  val f1 = DenseVector.fill(n)(delta)
  val f2 = 0.5 * (pi_i + pi_j)
  val f3 = 0.5 * (pi_i *:* (1.0- pi_j) *:* g_ij + pi_j *:* (1.0 - pi_i) *:* g_ji)

  val fList = (0 until n).map(i => DenseVector(f1(i), f2(i), f3(i)))

  (hList, fList)
}

computeHFMatrixBatch: (batch: Seq[PairFeatures], theta: Map[String,breeze.linalg.DenseVector[Double]])(Seq[breeze.linalg.DenseVector[Double]], Seq[breeze.linalg.DenseVector[Double]])


In [20]:
val hfRDD: RDD[(DenseVector[Double], DenseVector[Double])] =
  pairFeatureDS.rdd.mapPartitions { iter =>
    iter.grouped(200).flatMap { batch =>
      val (hList, fList) = computeHFMatrixBatch(batch, theta)
      hList.zip(fList) // Explicitly zip the two sequences
    }
  }

hfRDD = MapPartitionsRDD[61] at mapPartitions at <console>:40


MapPartitionsRDD[61] at mapPartitions at <console>:40

In [21]:
val firstElement = hfRDD.take(1).headOption

firstElement match {
  case Some((h, f)) =>
    println(s"First h: $h")
    println(s"First f: $f")
  case None =>
    println("hfRDD is empty.")
}

First h: DenseVector(0.13894651070935415, 0.24277831095919633, 0.017950081786955403)
First f: DenseVector(0.1, 0.5, 0.125)


firstElement = Some((DenseVector(0.13894651070935415, 0.24277831095919633, 0.017950081786955403),DenseVector(0.1, 0.5, 0.125)))


Some((DenseVector(0.13894651070935415, 0.24277831095919633, 0.017950081786955403),DenseVector(0.1, 0.5, 0.125)))

In [22]:
// Specify the size of the DenseVector (e.g., 3 for h and f)
val vectorSize = 3

val (hSum, fSum, count) = hfRDD.aggregate(
  (DenseVector.zeros[Double](vectorSize), DenseVector.zeros[Double](vectorSize), 0)
)(
  { case ((hAcc, fAcc, c), (h, f)) => (hAcc + h, fAcc + f, c + 1) },
  { case ((h1, f1, c1), (h2, f2, c2)) => (h1 + h2, f1 + f2, c1 + c2) }
)

val hMean = hSum / count.toDouble
val fMean = fSum / count.toDouble

println(s"Mean h: $hMean")
println(s"Mean f: $fMean")

Mean h: DenseVector(0.07351268421514602, 0.49611558095403585, 0.02969533926451304)
Mean f: DenseVector(0.09999999999999999, 0.5, 0.125)


vectorSize = 3
hSum = DenseVector(0.6616141579363142, 4.465040228586322, 0.26725805338061737)
fSum = DenseVector(0.8999999999999999, 4.5, 1.125)
count = 9
hMean = DenseVector(0.07351268421514602, 0.49611558095403585, 0.02969533926451304)
fMean = DenseVector(0.09999999999999999, 0.5, 0.125)


DenseVector(0.09999999999999999, 0.5, 0.125)

## Compute $\partial h$ and $\partial f$

### row wise

In [23]:


def computeHFDFforPair(
  p: PairFeatures,
  theta: Map[String, DenseVector[Double]]
): (DenseVector[Double], DenseVector[Double], DenseVector[Double]) = {
  val delta = theta("delta")(0)
  val beta  = theta("beta")
  val gamma = theta("gamma")

  val Wt_i = DenseVector(p.Wt_i)
  val Wt_j = DenseVector(p.Wt_j)
  val Xg_ij = DenseVector(p.Xg_ij)
  val Xg_ji = DenseVector(p.Xg_ji)

  val pi_i = safeSig(Wt_i dot beta)
  val pi_j = safeSig(Wt_j dot beta)
  val g_ij = safeSig(Xg_ij dot gamma)
  val g_ji = safeSig(Xg_ji dot gamma)

  val I_ij = if (p.yi >= p.yj) 1.0 else 0.0
  val I_ji = 1.0 - I_ij

  // h
  val num1 = p.zi * (1 - p.zj) / (2 * pi_i * (1 - pi_j)) * (I_ij - g_ij)
  val num2 = p.zj * (1 - p.zi) / (2 * pi_j * (1 - pi_i)) * (I_ji - g_ji)
  val h1 = num1 + num2 + 0.5 * (g_ij + g_ji)
  val h2 = 0.5 * (p.zi + p.zj)
  val h3 = 0.5 * (p.zi * (1 - p.zj) * I_ij + p.zj * (1 - p.zi) * I_ji)
  val h = DenseVector(h1, h2, h3)

  // f
  val f1 = delta
  val f2 = 0.5 * (pi_i + pi_j)
  val f3 = 0.5 * (pi_i * (1 - pi_j) * g_ij + pi_j * (1 - pi_i) * g_ji)
  val f = DenseVector(f1, f2, f3)


  // gradient of f1
  val df1Delta = 1.0

  // gradient of f2 w.r.t. beta
  val dPiI = pi_i * (1 - pi_i) * Wt_i
  val dPiJ = pi_j * (1 - pi_j) * Wt_j
  val df2Beta = 0.5 * (dPiI + dPiJ)

  // gradient of f3
  val df3Beta =
    0.5 * (
      ((1 - pi_j) * pi_i * (1 - pi_i) * Wt_i - pi_i * pi_j * (1 - pi_j) * Wt_j)  * g_ij +
      ((1 - pi_i) * pi_j * (1 - pi_j) * Wt_j - pi_j * pi_i * (1 - pi_i) * Wt_i)  * g_ji
    )

  val df3Gamma =
    0.5 * (
      pi_i * (1 - pi_j) * g_ij * (1 - g_ij) * Xg_ij +
      pi_j * (1 - pi_i) * g_ji * (1 - g_ji) * Xg_ji
    )

  // final gradient
  val pb = beta.length
  val qg = gamma.length
  val df = DenseMatrix.vertcat(
    DenseVector(1.0 +: Array.fill(pb + qg)(0.0): _*).toDenseMatrix,
    DenseVector.vertcat(DenseVector(0.0), df2Beta, DenseVector.zeros[Double](qg)).toDenseMatrix,
    DenseVector.vertcat(DenseVector(0.0), df3Beta, df3Gamma).toDenseMatrix
  )

  (h, f, df.toDenseVector)
}

computeHFDFforPair: (p: PairFeatures, theta: Map[String,breeze.linalg.DenseVector[Double]])(breeze.linalg.DenseVector[Double], breeze.linalg.DenseVector[Double], breeze.linalg.DenseVector[Double])


In [24]:
val hfDfRDD: RDD[(DenseVector[Double], DenseVector[Double], DenseVector[Double])] =
  pairFeatureDS.rdd.mapPartitions { iter =>
    iter.grouped(200).flatMap { batch =>
      batch.map(p => computeHFDFforPair(p, theta))
    }
  }

val thetaLength = theta("beta").length + theta("gamma").length + 1 // +1 for delta
val vectorSize = 3 // Size of h and f vectors

val (hSumDf, fSumDf, dfSum, countDf) = hfDfRDD.aggregate(
  (DenseVector.zeros[Double](vectorSize), DenseVector.zeros[Double](vectorSize), DenseVector.zeros[Double](vectorSize*thetaLength), 0)
)(
  { case ((hAcc, fAcc, dfAcc, c), (h, f, df)) => (hAcc + h, fAcc + f, dfAcc + df, c + 1) },
  { case ((h1, f1, df1, c1), (h2, f2, df2, c2)) => (h1 + h2, f1 + f2, df1 + df2, c1 + c2) }
)

val hMeanDf = hSumDf / countDf.toDouble
val fMeanDf = fSumDf / countDf.toDouble
val dfMeanDf = dfSum / countDf.toDouble

println(s"Mean h: $hMeanDf")
println(s"Mean f: $fMeanDf")
println(s"Mean df: $dfMeanDf")


Mean h: DenseVector(0.07351268421514602, 0.49611558095403585, 0.02969533926451304)
Mean f: DenseVector(0.09999999999999999, 0.5, 0.125)
Mean df: DenseVector(1.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.10471979154322772, 0.0, 0.0, 0.14610285571637865, 0.0, 0.0, 0.09207643841461047, 0.0, 0.0, 0.0, 0.0625, 0.0, 0.0, 0.02617994788580693, 0.0, 0.0, 0.036525713929094664, 0.0, 0.0, 0.023019109603652617, 0.0, 0.0, 0.02617994788580693, 0.0, 0.0, 0.036525713929094664, 0.0, 0.0, 0.023019109603652617)


hfDfRDD = MapPartitionsRDD[62] at mapPartitions at <console>:43
thetaLength = 12
vectorSize = 3
hSumDf = DenseVector(0.6616141579363142, 4.465040228586322, 0.26725805338061737)
fSumDf = DenseVector(0.8999999999999999, 4.5, 1.125)
dfSum = DenseVector(9.0, 0.0, 0.0, 0.0, 2.25, 0.0, 0.0, 0.9424781238890495, 0.0, 0.0, 1.3149257014474078, 0.0, 0.0, 0.8286879457314942, 0.0, 0.0, 0.0, 0.5625, 0.0, 0.0, 0.23561953097226238, 0.0, 0.0, 0.32873142536185196, 0.0, 0.0, 0.20717198643287354, 0.0, 0.0, 0.23561953097226238, 0.0, 0.0, 0.32873142536185196, 0.0, ...


DenseVector(9.0, 0.0, 0.0, 0.0, 2.25, 0.0, 0.0, 0.9424781238890495, 0.0, 0.0, 1.3149257014474078, 0.0, 0.0, 0.8286879457314942, 0.0, 0.0, 0.0, 0.5625, 0.0, 0.0, 0.23561953097226238, 0.0, 0.0, 0.32873142536185196, 0.0, 0.0, 0.20717198643287354, 0.0, 0.0, 0.23561953097226238, 0.0, 0.0, 0.32873142536185196, 0.0, ...

In [25]:
def computeHFDFDHforPair(
  p: PairFeatures,
  theta: Map[String, DenseVector[Double]]
): (DenseVector[Double], DenseVector[Double], DenseVector[Double], DenseVector[Double]) = {
  val delta = theta("delta")(0)
  val beta  = theta("beta")
  val gamma = theta("gamma")

  val Wt_i  = DenseVector(p.Wt_i)
  val Wt_j  = DenseVector(p.Wt_j)
  val Xg_ij = DenseVector(p.Xg_ij)
  val Xg_ji = DenseVector(p.Xg_ji)

  // ---- sigmoid preds ----
  val pi_i = safeSig(Wt_i dot beta)
  val pi_j = safeSig(Wt_j dot beta)
  val g_ij = safeSig(Xg_ij dot gamma)
  val g_ji = safeSig(Xg_ji dot gamma)

  val I_ij = if (p.yi >= p.yj) 1.0 else 0.0
  val I_ji = 1.0 - I_ij

  // ---- compute h ----
  val A    = p.zi * (1 - p.zj) / (2 * pi_i * (1 - pi_j))
  val B    = p.zj * (1 - p.zi) / (2 * pi_j * (1 - pi_i))
  val num1 = A * (I_ij - g_ij)
  val num2 = B * (I_ji - g_ji)
  val h1   = num1 + num2 + 0.5 * (g_ij + g_ji)
  val h2   = 0.5 * (p.zi + p.zj)
  val h3   = 0.5 * (p.zi * (1 - p.zj) * I_ij + p.zj * (1 - p.zi) * I_ji)
  val h    = DenseVector(h1, h2, h3)

  // ---- compute f ----
  val f1 = delta
  val f2 = 0.5 * (pi_i + pi_j)
  val f3 = 0.5 * (pi_i * (1 - pi_j) * g_ij + pi_j * (1 - pi_i) * g_ji)
  val f  = DenseVector(f1, f2, f3)

  // ---- gradient df ----
  // df1
  val df1Delta = 1.0

  // df2
  val dPiI    = pi_i * (1 - pi_i) * Wt_i
  val dPiJ    = pi_j * (1 - pi_j) * Wt_j
  val df2Beta = 0.5 * (dPiI + dPiJ)

  // df3
  val df3Beta = 0.5 * (
    ((1 - pi_j)*pi_i*(1 - pi_i)*Wt_i - pi_i*pi_j*(1 - pi_j)*Wt_j)*g_ij +
    ((1 - pi_i)*pi_j*(1 - pi_j)*Wt_j - pi_j*pi_i*(1 - pi_i)*Wt_i)*g_ji
  )
  val df3Gamma = 0.5 * (
    pi_i*(1 - pi_j)*g_ij*(1 - g_ij)*Xg_ij +
    pi_j*(1 - pi_i)*g_ji*(1 - g_ji)*Xg_ji
  )

  val pb = beta.length
  val qg = gamma.length
  val df = DenseMatrix.vertcat(
    DenseVector(1.0 +: Array.fill(pb + qg)(0.0): _*).toDenseMatrix,
    DenseVector.vertcat(DenseVector(0.0), df2Beta, DenseVector.zeros[Double](qg)).toDenseMatrix,
    DenseVector.vertcat(DenseVector(0.0), df3Beta, df3Gamma).toDenseMatrix
  )

  // ---- gradient dh ----
  // dh1 / dβ
  val dh1Beta = - (I_ij - g_ij)*A*((1.0 - pi_i)*Wt_i - pi_j*Wt_j) -
                 (I_ji - g_ji)*B*((1.0 - pi_j)*Wt_j - pi_i*Wt_i)
  // dh1 / dγ
  val dh1Gamma = -A*(I_ij - g_ij)*g_ij*(1 - g_ij)*Xg_ij -
                  B*(I_ji - g_ji)*g_ji*(1 - g_ji)*Xg_ji +
                  0.5*(g_ij*(1 - g_ij)*Xg_ij + g_ji*(1 - g_ji)*Xg_ji)
  // dh2, dh3 are zero
  val dh2Beta  = DenseVector.zeros[Double](pb)
  val dh2Gamma = DenseVector.zeros[Double](qg)
  val dh3Beta  = DenseVector.zeros[Double](pb)
  val dh3Gamma = DenseVector.zeros[Double](qg)

  // assemble dh matrix
  val dh = DenseMatrix.vertcat(
    DenseVector.vertcat(DenseVector(0.0), dh1Beta, dh1Gamma).toDenseMatrix,
    DenseVector.vertcat(DenseVector(0.0), dh2Beta, dh2Gamma).toDenseMatrix,
    DenseVector.vertcat(DenseVector(0.0), dh3Beta, dh3Gamma).toDenseMatrix
  )

  // return (h, f, df, dh)
  (h, f, df.toDenseVector, dh.toDenseVector)
}

computeHFDFDHforPair: (p: PairFeatures, theta: Map[String,breeze.linalg.DenseVector[Double]])(breeze.linalg.DenseVector[Double], breeze.linalg.DenseVector[Double], breeze.linalg.DenseVector[Double], breeze.linalg.DenseVector[Double])


In [26]:
val hfDfDhRDD: RDD[(DenseVector[Double], DenseVector[Double], DenseVector[Double], DenseVector[Double])] =
  pairFeatureDS.rdd.mapPartitions { iter =>
    iter.grouped(200).flatMap { batch =>
      batch.map(p => computeHFDFDHforPair(p, theta))
    }
  }

val (hSumDf, fSumDf, dfSum, dhSum, countDf) = hfDfDhRDD.aggregate(
  (DenseVector.zeros[Double](vectorSize), DenseVector.zeros[Double](vectorSize), DenseVector.zeros[Double](vectorSize * thetaLength), DenseVector.zeros[Double](vectorSize * thetaLength), 0)
)(
  { case ((hAcc, fAcc, dfAcc, dhAcc, c), (h, f, df, dh)) => (hAcc + h, fAcc + f, dfAcc + df, dhAcc + dh, c + 1) },
  { case ((h1, f1, df1, dh1, c1), (h2, f2, df2, dh2, c2)) => (h1 + h2, f1 + f2, df1 + df2, dh1 + dh2, c1 + c2) }
)

val hMeanDf = hSumDf / countDf.toDouble
val fMeanDf = fSumDf / countDf.toDouble
val dfMeanDf = dfSum / countDf.toDouble
val dhMeanDf = dhSum / countDf.toDouble

println(s"Mean h: $hMeanDf")
println(s"Mean f: $fMeanDf")
println(s"Mean df: $dfMeanDf")
println(s"Mean dh: $dhMeanDf")


Mean h: DenseVector(0.07351268421514602, 0.49611558095403585, 0.02969533926451304)
Mean f: DenseVector(0.09999999999999999, 0.5, 0.125)
Mean df: DenseVector(1.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.10471979154322772, 0.0, 0.0, 0.14610285571637865, 0.0, 0.0, 0.09207643841461047, 0.0, 0.0, 0.0, 0.0625, 0.0, 0.0, 0.02617994788580693, 0.0, 0.0, 0.036525713929094664, 0.0, 0.0, 0.023019109603652617, 0.0, 0.0, 0.02617994788580693, 0.0, 0.0, 0.036525713929094664, 0.0, 0.0, 0.023019109603652617)
Mean dh: DenseVector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02196209092548669, 0.0, 0.0, 0.05765826238362978, 0.0, 0.0, -0.04267139663254082, 0.0, 0.0, 0.35662182894621347, 0.0, 0.0, 0.14039517278092065, 0.0, 0.0, 0.22089843059257508, 0.0, 0.0, 0.12950809670868135, 0.0, 0.0, 0.12941412731817734, 0.0, 0.0, 0.19206929940076023, 0.0, 0.0, 0.15084379502495174, 0.0, 0.0)


hfDfDhRDD = MapPartitionsRDD[63] at mapPartitions at <console>:51
hSumDf = DenseVector(0.6616141579363142, 4.465040228586322, 0.26725805338061737)
fSumDf = DenseVector(0.8999999999999999, 4.5, 1.125)
dfSum = DenseVector(9.0, 0.0, 0.0, 0.0, 2.25, 0.0, 0.0, 0.9424781238890495, 0.0, 0.0, 1.3149257014474078, 0.0, 0.0, 0.8286879457314942, 0.0, 0.0, 0.0, 0.5625, 0.0, 0.0, 0.23561953097226238, 0.0, 0.0, 0.32873142536185196, 0.0, 0.0, 0.20717198643287354, 0.0, 0.0, 0.23561953097226238, 0.0, 0.0, 0.32873142536185196, 0.0, 0.0, ...


DenseVector(9.0, 0.0, 0.0, 0.0, 2.25, 0.0, 0.0, 0.9424781238890495, 0.0, 0.0, 1.3149257014474078, 0.0, 0.0, 0.8286879457314942, 0.0, 0.0, 0.0, 0.5625, 0.0, 0.0, 0.23561953097226238, 0.0, 0.0, 0.32873142536185196, 0.0, 0.0, 0.20717198643287354, 0.0, 0.0, 0.23561953097226238, 0.0, 0.0, 0.32873142536185196, 0.0, 0.0, ...

## compute B, u, and perform fisher update

In [36]:
def computeBUforPair(
  pf: PairFeatures,
  theta: Map[String, DenseVector[Double]]
): (DenseVector[Double], DenseVector[Double]) = {
  // 1) get h, f, D, M
  val (h, f, dfVec, dhVec) = computeHFDFDHforPair(pf, theta) 
  val d = dfVec.length / 3
  val D = new DenseMatrix(3, d, dfVec.data)       // 3×d
  val M = new DenseMatrix(3, d, dhVec.data)       // 3×d
  val Vinv = DenseMatrix.eye[Double](3)           // 3x3, identity matrix for simplicity

  // 2) G = D^T * V⁻¹  (d×3)
  val G = D.t * Vinv

  // 3) B_i = G * (D - M)   (d×d)
  val B_i = G * (D - M)

  // 4) u_i = G * (h - f)   (d)
  val u_i = G * (h - f)

  (B_i.toDenseVector, u_i)
}

lastException = null


computeBUforPair: (pf: PairFeatures, theta: Map[String,breeze.linalg.DenseVector[Double]])(breeze.linalg.DenseVector[Double], breeze.linalg.DenseVector[Double])


null

In [22]:
val thetaInit =  Map(
  "delta" -> DenseVector(0.5),
  "beta" -> DenseVector.zeros[Double](p + 1),   // where p is dim(x)
  "gamma" -> DenseVector.zeros[Double](2 * p + 1)
)

thetaInit = Map(delta -> DenseVector(0.5), beta -> DenseVector(0.0, 0.0, 0.0, 0.0), gamma -> DenseVector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))


Map(delta -> DenseVector(0.5), beta -> DenseVector(0.0, 0.0, 0.0, 0.0), gamma -> DenseVector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))

In [23]:
// θ holds “delta”, “beta”, “gamma” as DenseVector’s
var theta: Map[String, DenseVector[Double]] = thetaInit

val d = theta("beta").length + theta("gamma").length + 1  // total params

val BURDD: RDD[(DenseVector[Double], DenseVector[Double])] =
  pairFeatureDS.rdd.mapPartitions { iter =>
    iter.grouped(200).flatMap { batch =>
      batch.map(p => computeBUforPair(p, theta))
    }
  }

val BUsum = BURDD.aggregate(
  (DenseVector.zeros[Double](d * d), DenseVector.zeros[Double](d), 0)
)(
  (acc, value) => {
    val BAcc = acc._1
    val UAcc = acc._2
    val c = acc._3
    val B = value._1
    val U = value._2
    (BAcc + B, UAcc + U, c + 1)
  },
  (acc1, acc2) => {
    val B1 = acc1._1
    val U1 = acc1._2
    val c1 = acc1._3
    val B2 = acc2._1
    val U2 = acc2._2
    val c2 = acc2._3
    (B1 + B2, U1 + U2, c1 + c2)
  }
)

val Btot = new DenseMatrix(d, d, BUsum._1.data) // d×d
val Utot = BUsum._2 // d
val countBU = BUsum._3

val Bmean = Btot / countBU.toDouble
val Umean = Utot / countBU.toDouble

println(s"Mean B: ${Bmean}")
println(s"Mean U: ${Umean}")

Mean B: 1.0  0.0                   -0.02196209092548669  ... (12 total)
0.0  0.0625                0.02617994788580693   ...
0.0  0.02617994788580693   0.01476407592104698   ...
0.0  0.036525713929094664  0.01653795314897124   ...
0.0  0.023019109603652617  0.008142367117191476  ...
0.0  0.0                   0.0                   ...
0.0  0.0                   0.0                   ...
0.0  0.0                   0.0                   ...
0.0  0.0                   0.0                   ...
0.0  0.0                   0.0                   ...
0.0  0.0                   0.0                   ...
0.0  0.0                   0.0                   ...
Mean U: DenseVector(-0.42648731578485394, -9.71104761491062E-4, -4.940142386359029E-5, -0.0037512510970289497, 0.0028359487909702156, -0.005956541295967935, -0.002241982718161384, -0.00349778272868647, -0.0022908522261908468, -0.002241982718161384, -0.00349778272868647, -0.0022908522261908468)


theta = Map(delta -> DenseVector(0.5), beta -> DenseVector(0.0, 0.0, 0.0, 0.0), gamma -> DenseVector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))
d = 12
BURDD = MapPartitionsRDD[69] at mapPartitions at <console>:59
BUsum = (DenseVector(9.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5625, 0.23561953097226238, 0.32873142536185196, 0.20717198643287354, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.1976588183293802, 0.23561953097226238, 0.13287668328942281, 0.14884157834074116, 0.07328130405472329, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.5189243614526681, 0.328731425...


(DenseVector(9.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5625, 0.23561953097226238, 0.32873142536185196, 0.20717198643287354, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.1976588183293802, 0.23561953097226238, 0.13287668328942281, 0.14884157834074116, 0.07328130405472329, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.5189243614526681, 0.328731425...

In [24]:
Bmean

1.0  0.0                   -0.02196209092548669  ... (12 total)
0.0  0.0625                0.02617994788580693   ...
0.0  0.02617994788580693   0.01476407592104698   ...
0.0  0.036525713929094664  0.01653795314897124   ...
0.0  0.023019109603652617  0.008142367117191476  ...
0.0  0.0                   0.0                   ...
0.0  0.0                   0.0                   ...
0.0  0.0                   0.0                   ...
0.0  0.0                   0.0                   ...
0.0  0.0                   0.0                   ...
0.0  0.0                   0.0                   ...
0.0  0.0                   0.0                   ...


In [25]:
Umean

DenseVector(-0.42648731578485394, -9.71104761491062E-4, -4.940142386359029E-5, -0.0037512510970289497, 0.0028359487909702156, -0.005956541295967935, -0.002241982718161384, -0.00349778272868647, -0.0022908522261908468, -0.002241982718161384, -0.00349778272868647, -0.0022908522261908468)

In [30]:
// Fisher scoring update

val J = - Bmean
val lambda = 0.1
val Iden = DenseMatrix.eye[Double](d) // Identity matrix of size d
val step = (J + lambda * Iden) \ (-Umean)

lastException = null
J = 
lambda = 0.1
Iden = 


-1.0  -0.0                   0.02196209092548669    ... (12 total)
-0.0  -0.0625                -0.02617994788580693   ...
-0.0  -0.02617994788580693   -0.01476407592104698   ...
-0.0  -0.036525713929094664  -0.01653795314897124   ...
-0.0  -0.023019109603652617  -0.008142367117191476  ...
-0.0  -0.0                   -0.0                   ...
-0.0  -0.0                   -0.0                   ...
-0.0  -0.0                   -0.0                   ...
-0.0  -0.0                   -0.0                   ...
-0.0  -0.0                   -0.0                   ...
-0.0  -0.0                   -0.0                   ...
-0.0  -0.0                   -0.0                   ...
1.0  0.0  0....


In [31]:
step

DenseVector(-0.42153383177574355, -0.32043466144285104, -0.14294038128948897, -0.15860999367901168, -0.1499575629414855, 0.06541702450124144, 0.02500222524641803, 0.03848536905458946, 0.02505503874790853, 0.02500222524641804, 0.03848536905458946, 0.025055038747908525)

In [32]:
theta

Map(delta -> DenseVector(0.5), beta -> DenseVector(0.0, 0.0, 0.0, 0.0), gamma -> DenseVector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))

In [34]:
def updateTheta(
  theta: Map[String, DenseVector[Double]],
  step: DenseVector[Double]
): Map[String, DenseVector[Double]] = {
  var idx    = 0
  // delta (scalar)
  val newDelta = theta("delta") + step(idx to idx)
  idx        += 1
  // beta
  val bLen    = theta("beta").length
  val newBeta = theta("beta") + step(idx until idx + bLen)
  idx        += bLen
  // gamma
  val gLen     = theta("gamma").length
  val newGamma = theta("gamma") + step(idx until idx + gLen)

  Map(
    "delta" -> newDelta,
    "beta"  -> newBeta,
    "gamma" -> newGamma
  )
}

updateTheta: (theta: Map[String,breeze.linalg.DenseVector[Double]], step: breeze.linalg.DenseVector[Double])Map[String,breeze.linalg.DenseVector[Double]]


In [35]:
theta = updateTheta(theta, step)

theta = Map(delta -> DenseVector(0.07846616822425645), beta -> DenseVector(-0.32043466144285104, -0.14294038128948897, -0.15860999367901168, -0.1499575629414855), gamma -> DenseVector(0.06541702450124144, 0.02500222524641803, 0.03848536905458946, 0.02505503874790853, 0.02500222524641804, 0.03848536905458946, 0.025055038747908525))


Map(delta -> DenseVector(0.07846616822425645), beta -> DenseVector(-0.32043466144285104, -0.14294038128948897, -0.15860999367901168, -0.1499575629414855), gamma -> DenseVector(0.06541702450124144, 0.02500222524641803, 0.03848536905458946, 0.02505503874790853, 0.02500222524641804, 0.03848536905458946, 0.025055038747908525))

In [36]:
theta

Map(delta -> DenseVector(0.07846616822425645), beta -> DenseVector(-0.32043466144285104, -0.14294038128948897, -0.15860999367901168, -0.1499575629414855), gamma -> DenseVector(0.06541702450124144, 0.02500222524641803, 0.03848536905458946, 0.02505503874790853, 0.02500222524641804, 0.03848536905458946, 0.025055038747908525))

In [2]:
def computeStep(pairFeatureDS: Dataset[PairFeatures], theta: Map[String, DenseVector[Double]]): DenseVector[Double] = {
  val BURDD: RDD[(DenseVector[Double], DenseVector[Double])] =
    pairFeatureDS.rdd.mapPartitions { iter =>
      iter.grouped(200).flatMap { batch =>
        batch.map(p => computeBUforPair(p, theta))
      }
    }

  val BUsum = BURDD.aggregate(
    (DenseVector.zeros[Double](d * d), DenseVector.zeros[Double](d), 0)
  )(
    (acc, value) => {
      val BAcc = acc._1
      val UAcc = acc._2
      val c = acc._3
      val B = value._1
      val U = value._2
      (BAcc + B, UAcc + U, c + 1)
    },
    (acc1, acc2) => {
      val B1 = acc1._1
      val U1 = acc1._2
      val c1 = acc1._3
      val B2 = acc2._1
      val U2 = acc2._2
      val c2 = acc2._3
      (B1 + B2, U1 + U2, c1 + c2)
    }
  )

  val Btot = new DenseMatrix(d, d, BUsum._1.data) // d×d
  val Utot = BUsum._2 // d
  val countBU = BUsum._3

  val Bmean = Btot / countBU.toDouble
  val Umean = Utot / countBU.toDouble

  // Fisher scoring update

  val J = - Bmean
  val lambda = 0.1
  val Iden = DenseMatrix.eye[Double](d) // Identity matrix of size d
  val step = (J + lambda * Iden) \ (-Umean)

  step
}

computeStep: (pairFeatureDS: org.apache.spark.sql.Dataset[PairFeatures], theta: Map[String,breeze.linalg.DenseVector[Double]])breeze.linalg.DenseVector[Double]


In [3]:
val step = computeStep(pairFeatureDS, theta)
theta = updateTheta(theta, step)
println(s"Updated theta: $theta")

Updated theta: Map(delta -> DenseVector(-0.8599798772551582), beta -> DenseVector(-3.461557492968856, -1.499875041358923, -1.9856162244277393, -1.3263636445788052), gamma -> DenseVector(0.019980803132977315, 8.425610394920229E-4, 0.010308204784821178, 0.009590864647366998, 2.2034939143689494E-4, 0.00973069743719486, 0.00913878995438562))


step = DenseVector(-0.9384460454794147, -3.1411228315260047, -1.356934660069434, -1.8270062307487276, -1.1764060816373196, -0.045436221368264126, -0.02415966420692601, -0.028177164269768285, -0.01546417410054153, -0.024781875854981147, -0.028754671617394603, -0.015916248793522905)
theta = Map(delta -> DenseVector(-0.8599798772551582), beta -> DenseVector(-3.461557492968856, -1.499875041358923, -1.9856162244277393, -1.3263636445788052), gamma -> DenseVector(0.019980803132977315, 8.425610394920229E-4, 0.010308204784821178, 0.009590864647366998, 2.2034939143689494E-4, 0.00973069743719486, 0.00913878995438562))


Map(delta -> DenseVector(-0.8599798772551582), beta -> DenseVector(-3.461557492968856, -1.499875041358923, -1.9856162244277393, -1.3263636445788052), gamma -> DenseVector(0.019980803132977315, 8.425610394920229E-4, 0.010308204784821178, 0.009590864647366998, 2.2034939143689494E-4, 0.00973069743719486, 0.00913878995438562))