In [1]:
import org.apache.spark.sql.{Dataset, Encoder, Encoders,SparkSession}
import org.apache.spark.sql.functions._
import scala.util.Random
import org.apache.spark.rdd.RDD
// import scala.reflect.ClassTag

In [2]:
implicit val spark: SparkSession = SparkSession.builder()
  .appName("DRGU Convergence")
  .master("local[*]")
  .getOrCreate()

import spark.implicits._

spark = org.apache.spark.sql.SparkSession@65bf05b7


org.apache.spark.sql.SparkSession@65bf05b7

## simulate medium data

In [4]:
case class Obs(
  i: String,                   // cluster ID
  x: Array[Double],           // covariates
  y: Double,                  // outcome
  timeIndex: Option[Int] = None,      // optional time index
  z: Option[Double] = None       // optional treatment indicator
)

case class ObsPair(left: Obs, right: Obs)

defined class Obs
defined class ObsPair


In [86]:
def sampleAllPairs(obsDS: Dataset[Obs])(implicit spark: SparkSession): Dataset[ObsPair] = {
  import spark.implicits._

  // Add index to each observation
  val indexed: RDD[(Obs, Long)] = obsDS.rdd.zipWithIndex()

  // Cartesian join and filter only (i < j) to avoid duplicates
  val allPairs: RDD[ObsPair] = indexed.cartesian(indexed)
    .filter { case ((_, i), (_, j)) => i < j }
    .map { case ((obs1, _), (obs2, _)) => ObsPair(obs1, obs2) }

  // Convert back to Dataset
  spark.createDataset(allPairs)
}

sampleAllPairs: (obsDS: org.apache.spark.sql.Dataset[Obs])(implicit spark: org.apache.spark.sql.SparkSession)org.apache.spark.sql.Dataset[ObsPair]


In [87]:
val numClusters = 100
val numObsPerCluster = 1
val random = new Random()

val data = (1 to numClusters).flatMap { clusterId =>
  (1 to numObsPerCluster).map { obsId =>
    Obs(
      i = s"c$clusterId",
      x = Array.fill(3)(random.nextDouble()), // 3 covariates
      y = random.nextDouble(),               // random outcome
      timeIndex = Some(obsId),               // time index
      z = Some(if (random.nextDouble() > 0.5) 1.0 else 0.0)          // random treatment indicator
    )
  }
}

// Convert to Dataset
val df: Dataset[Obs] = data.toDS()

// Show the generated data
df.show(false)

numClusters = 100
numObsPerCluster = 1
random = scala.util.Random@cf4ec75
data = Vector(Obs(c1,[D@5b1ed8a4,0.26727861661399677,Some(1),Some(0.0)), Obs(c2,[D@2e42f723,0.5149149160314703,Some(1),Some(0.0)), Obs(c3,[D@7debd857,0.9755815222132482,Some(1),Some(0.0)), Obs(c4,[D@565714db,0.6975674830727124,Some(1),Some(0.0)), Obs(c5,[D@7b212d4d,0.11503153848005032,Some(1),Some(0.0)), Obs(c6,[D@426336e5,0.549017848538372,Some(1),Some(0.0)), Obs(c7,[D@61e9679e,0.3231093820614067,Some(1),Some(1.0)), Obs(c8,[D@f7656b4,0.3172881173615496,Some(1),Some(0.0)), Obs(c9,[D@7d4f5952,0.48898110165869113,Some(1),Some(1.0)), Obs(c10,[D@26409c00,0.6671974452966365,Some(1),Some(0.0)), Obs(c11,[D@cbfcac1,0.3642719686297312,Some(1),Some(1.0...


+---+---------------------------------------------------------------+-------------------+---------+---+
|i  |x                                                              |y                  |timeIndex|z  |
+---+---------------------------------------------------------------+-------------------+---------+---+
|c1 |[0.9918427154549613, 0.983865677645629, 0.7492280517239661]    |0.26727861661399677|1        |0.0|
|c2 |[0.11866828159590148, 0.20402558516094627, 0.9908379988373542] |0.5149149160314703 |1        |0.0|
|c3 |[0.0070273262388163005, 0.4692129118162638, 0.434815518884742] |0.9755815222132482 |1        |0.0|
|c4 |[0.13355639374335848, 0.6079674441559634, 0.16040416784189582] |0.6975674830727124 |1        |0.0|
|c5 |[0.25433411385524496, 0.15869414570899432, 0.15211071681610944]|0.11503153848005032|1        |0.0|
|c6 |[0.45407733423450114, 0.3244968416180056, 0.12241357953110676] |0.549017848538372  |1        |0.0|
|c7 |[0.5849997300425186, 0.5250342098938594, 0.4780013611862648

Vector(Obs(c1,[D@5b1ed8a4,0.26727861661399677,Some(1),Some(0.0)), Obs(c2,[D@2e42f723,0.5149149160314703,Some(1),Some(0.0)), Obs(c3,[D@7debd857,0.9755815222132482,Some(1),Some(0.0)), Obs(c4,[D@565714db,0.6975674830727124,Some(1),Some(0.0)), Obs(c5,[D@7b212d4d,0.11503153848005032,Some(1),Some(0.0)), Obs(c6,[D@426336e5,0.549017848538372,Some(1),Some(0.0)), Obs(c7,[D@61e9679e,0.3231093820614067,Some(1),Some(1.0)), Obs(c8,[D@f7656b4,0.3172881173615496,Some(1),Some(0.0)), Obs(c9,[D@7d4f5952,0.48898110165869113,Some(1),Some(1.0)), Obs(c10,[D@26409c00,0.6671974452966365,Some(1),Some(0.0)), Obs(c11,[D@cbfcac1,0.3642719686297312,Some(1),Some(1.0...

sampleAllPairs: (obsDS: org.apache.spark.sql.Dataset[Obs])(implicit spark: org.apache.spark.sql.SparkSession)org.apache.spark.sql.Dataset[ObsPair]


## function for generate paired data

In [2]:
def assignDeterministicPartition(
    ds: Dataset[Obs],
    idColumn: String = "i",  // cluster ID or any stable ID
    numPartitions: Int
)(implicit spark: SparkSession): Dataset[(Obs, Int)] = {
  import spark.implicits._

  val withBucket = ds.withColumn("bucket", (xxhash64(col(idColumn)) % numPartitions).cast("int"))
  val repartitioned = withBucket.repartition($"bucket")
  repartitioned.select(
    struct($"i", $"x", $"y", $"timeIndex", $"z").as("obs"),
    $"bucket"
  ).as[(Obs, Int)]
}


assignDeterministicPartition: (ds: org.apache.spark.sql.Dataset[Obs], idColumn: String, numPartitions: Int)(implicit spark: org.apache.spark.sql.SparkSession)org.apache.spark.sql.Dataset[(Obs, Int)]


In [3]:
def sampleObsPairsFromRepartitioned(
    partitioned: Dataset[(Obs, Int)],
    samplePerPartition: Int,
    seed: Long = System.currentTimeMillis()
)(implicit spark: SparkSession): Dataset[ObsPair] = {
  import spark.implicits._

  partitioned
    .groupByKey { case (_, bucket) => bucket }
    .flatMapGroups { case (_, iter) =>
      val rand = new scala.util.Random(seed)
      val items = iter.map(_._1).toIndexedSeq
      val n = items.length

      if (n < 2) Iterator.empty
      else {
        val sampledPairs = Iterator.fill(samplePerPartition) {
          var i = 0
          var j = 0
          do {
            i = rand.nextInt(n)
            j = rand.nextInt(n)
          } while (i == j)

          val (ii, jj) = if (i < j) (i, j) else (j, i)
          ObsPair(items(ii), items(jj))
        }
        sampledPairs
      }
    }
}

def sampleObsPairsFromRepartitionedRow(
    partitioned: Dataset[(Obs, Int)],
    matchesPerRow: Int,
    seed: Long = System.currentTimeMillis()
)(implicit spark:SparkSession): Dataset[ObsPair] = {
  import spark.implicits._

  partitioned
    .groupByKey { case (_, bucket) => bucket }
    .flatMapGroups { case (_, iter) =>
      val rand = new scala.util.Random(seed)
      val items = iter.map(_._1).toIndexedSeq
      val n = items.length

      if (n < 2) Iterator.empty
      else {
        val sampledPairs  = (0 until n).flatMap { i =>
          (0 until matchesPerRow).map { _ =>
            var j = i
            while (j == i) j = rand.nextInt(n)
            ObsPair(items(i), items(j))
          }
        }
        sampledPairs
      }
    }
}

sampleObsPairsFromRepartitioned: (partitioned: org.apache.spark.sql.Dataset[(Obs, Int)], samplePerPartition: Int, seed: Long)(implicit spark: org.apache.spark.sql.SparkSession)org.apache.spark.sql.Dataset[ObsPair]
sampleObsPairsFromRepartitionedRow: (partitioned: org.apache.spark.sql.Dataset[(Obs, Int)], matchesPerRow: Int, seed: Long)(implicit spark: org.apache.spark.sql.SparkSession)org.apache.spark.sql.Dataset[ObsPair]


In [4]:
import breeze.linalg._

case class PairFeatures(
  Wt_i: Array[Double],
  Wt_j: Array[Double],
  Xg_ij: Array[Double],
  Xg_ji: Array[Double],
  yi: Double,
  yj: Double,
  zi: Double,
  zj: Double
)

defined class PairFeatures


In [5]:
def toPairFeatures(pair: ObsPair): PairFeatures = {
  val Obs(_, x1, y1, _, Some(z1)) = pair.left
  val Obs(_, x2, y2, _, Some(z2)) = pair.right

  val Wt_i = Array(1.0) ++ x1
  val Wt_j = Array(1.0) ++ x2
  val Xg_ij = Array(1.0) ++ x1 ++ x2
  val Xg_ji = Array(1.0) ++ x2 ++ x1

  PairFeatures(Wt_i, Wt_j, Xg_ij, Xg_ji, y1, y2, z1, z2)
}

toPairFeatures: (pair: ObsPair)PairFeatures


### use function to generate paired data

In [42]:
repartitioned.unpersist()

[obs: struct<i: string, x: array<double> ... 3 more fields>, bucket: int]

In [43]:
val repartitioned = assignDeterministicPartition(df, idColumn = "i", numPartitions = 10).cache()
repartitioned.count() // trigger caching

repartitioned = [obs: struct<i: string, x: array<double> ... 3 more fields>, bucket: int]


100

In [89]:
// all pairs
val sampledPairs: Dataset[ObsPair] = sampleAllPairs(df)

sampledPairs = [left: struct<i: string, x: array<double> ... 3 more fields>, right: struct<i: string, x: array<double> ... 3 more fields>]


[left: struct<i: string, x: array<double> ... 3 more fields>, right: struct<i: string, x: array<double> ... 3 more fields>]

In [10]:
val sampledPairs: Dataset[ObsPair] =
  sampleObsPairsFromRepartitioned(repartitioned, samplePerPartition = 100)

sampledPairs.show()

+--------------------+--------------------+
|                left|               right|
+--------------------+--------------------+
|{c99, [0.24858109...|{c113, [0.9366210...|
|{c180, [0.8132829...|{c974, [0.5968742...|
|{c229, [0.9142135...|{c382, [0.5414748...|
|{c570, [0.0529473...|{c642, [0.5457118...|
|{c466, [0.9257026...|{c855, [0.9445837...|
|{c113, [0.9366210...|{c974, [0.5968742...|
|{c99, [0.24858109...|{c982, [0.1866435...|
|{c473, [0.0114870...|{c836, [0.8790051...|
|{c99, [0.24858109...|{c678, [0.0119801...|
|{c573, [0.9632667...|{c965, [0.2515332...|
|{c238, [0.2204298...|{c573, [0.9632667...|
|{c296, [0.7088656...|{c965, [0.2515332...|
|{c113, [0.9366210...|{c965, [0.2515332...|
|{c68, [0.15071439...|{c573, [0.9632667...|
|{c92, [0.03963054...|{c642, [0.5457118...|
|{c92, [0.03963054...|{c838, [0.3137387...|
|{c187, [0.0267444...|{c974, [0.5968742...|
|{c891, [0.3015660...|{c974, [0.5968742...|
|{c99, [0.24858109...|{c600, [0.0336387...|
|{c410, [0.5248068...|{c542, [0.

sampledPairs = [left: struct<i: string, x: array<double> ... 3 more fields>, right: struct<i: string, x: array<double> ... 3 more fields>]


[left: struct<i: string, x: array<double> ... 3 more fields>, right: struct<i: string, x: array<double> ... 3 more fields>]

In [90]:
// repartitioned.unpersist()

val pairFeatureDS: Dataset[PairFeatures] = sampledPairs.map(toPairFeatures)

pairFeatureDS.show()

pairFeatureDS = [Wt_i: array<double>, Wt_j: array<double> ... 6 more fields]


+--------------------+--------------------+--------------------+--------------------+-------------------+-------------------+---+---+
|                Wt_i|                Wt_j|               Xg_ij|               Xg_ji|                 yi|                 yj| zi| zj|
+--------------------+--------------------+--------------------+--------------------+-------------------+-------------------+---+---+
|[1.0, 0.991842715...|[1.0, 0.118668281...|[1.0, 0.991842715...|[1.0, 0.118668281...|0.26727861661399677| 0.5149149160314703|0.0|0.0|
|[1.0, 0.991842715...|[1.0, 0.007027326...|[1.0, 0.991842715...|[1.0, 0.007027326...|0.26727861661399677| 0.9755815222132482|0.0|0.0|
|[1.0, 0.991842715...|[1.0, 0.133556393...|[1.0, 0.991842715...|[1.0, 0.133556393...|0.26727861661399677| 0.6975674830727124|0.0|0.0|
|[1.0, 0.991842715...|[1.0, 0.254334113...|[1.0, 0.991842715...|[1.0, 0.254334113...|0.26727861661399677|0.11503153848005032|0.0|0.0|
|[1.0, 0.991842715...|[1.0, 0.454077334...|[1.0, 0.991842715..

[Wt_i: array<double>, Wt_j: array<double> ... 6 more fields]

In [91]:
pairFeatureDS.count()

4950

## functions for fisher update

In [25]:
def safeSig(x: Double): Double = {
  val clipped = math.max(-10.0, math.min(10.0, x))
  1.0 / (1.0 + math.exp(-clipped))
}

safeSig: (x: Double)Double


In [26]:
def computeHFDFDHforPair(
  p: PairFeatures,
  theta: Map[String, DenseVector[Double]]
): (DenseVector[Double], DenseVector[Double], DenseVector[Double], DenseVector[Double]) = {
  val delta = theta("delta")(0)
  val beta  = theta("beta")
  val gamma = theta("gamma")

  val Wt_i  = DenseVector(p.Wt_i)
  val Wt_j  = DenseVector(p.Wt_j)
  val Xg_ij = DenseVector(p.Xg_ij)
  val Xg_ji = DenseVector(p.Xg_ji)

  // ---- sigmoid preds ----
  val pi_i = safeSig(Wt_i dot beta)
  val pi_j = safeSig(Wt_j dot beta)
  val g_ij = safeSig(Xg_ij dot gamma)
  val g_ji = safeSig(Xg_ji dot gamma)

  val I_ij = if (p.yi >= p.yj) 1.0 else 0.0
  val I_ji = 1.0 - I_ij

  // ---- compute h ----
  val A    = p.zi * (1 - p.zj) / (2 * pi_i * (1 - pi_j))
  val B    = p.zj * (1 - p.zi) / (2 * pi_j * (1 - pi_i))
  val num1 = A * (I_ij - g_ij)
  val num2 = B * (I_ji - g_ji)
  val h1   = num1 + num2 + 0.5 * (g_ij + g_ji)
  val h2   = 0.5 * (p.zi + p.zj)
  val h3   = 0.5 * (p.zi * (1 - p.zj) * I_ij + p.zj * (1 - p.zi) * I_ji)
  val h    = DenseVector(h1, h2, h3)

  // ---- compute f ----
  val f1 = delta
  val f2 = 0.5 * (pi_i + pi_j)
  val f3 = 0.5 * (pi_i * (1 - pi_j) * g_ij + pi_j * (1 - pi_i) * g_ji)
  val f  = DenseVector(f1, f2, f3)

  // ---- gradient df ----
  // df1
  val df1Delta = 1.0

  // df2
  val dPiI    = pi_i * (1 - pi_i) * Wt_i
  val dPiJ    = pi_j * (1 - pi_j) * Wt_j
  val df2Beta = 0.5 * (dPiI + dPiJ)

  // df3
  val df3Beta = 0.5 * (
    ((1 - pi_j)*pi_i*(1 - pi_i)*Wt_i - pi_i*pi_j*(1 - pi_j)*Wt_j)*g_ij +
    ((1 - pi_i)*pi_j*(1 - pi_j)*Wt_j - pi_j*pi_i*(1 - pi_i)*Wt_i)*g_ji
  )
  val df3Gamma = 0.5 * (
    pi_i*(1 - pi_j)*g_ij*(1 - g_ij)*Xg_ij +
    pi_j*(1 - pi_i)*g_ji*(1 - g_ji)*Xg_ji
  )

  val pb = beta.length
  val qg = gamma.length
  val df = DenseMatrix.vertcat(
    DenseVector(1.0 +: Array.fill(pb + qg)(0.0): _*).toDenseMatrix,
    DenseVector.vertcat(DenseVector(0.0), df2Beta, DenseVector.zeros[Double](qg)).toDenseMatrix,
    DenseVector.vertcat(DenseVector(0.0), df3Beta, df3Gamma).toDenseMatrix
  )

  // ---- gradient dh ----
  // dh1 / dβ
  val dh1Beta = - (I_ij - g_ij)*A*((1.0 - pi_i)*Wt_i - pi_j*Wt_j) -
                 (I_ji - g_ji)*B*((1.0 - pi_j)*Wt_j - pi_i*Wt_i)
  // dh1 / dγ
  val dh1Gamma = -A*(I_ij - g_ij)*g_ij*(1 - g_ij)*Xg_ij -
                  B*(I_ji - g_ji)*g_ji*(1 - g_ji)*Xg_ji +
                  0.5*(g_ij*(1 - g_ij)*Xg_ij + g_ji*(1 - g_ji)*Xg_ji)
  // dh2, dh3 are zero
  val dh2Beta  = DenseVector.zeros[Double](pb)
  val dh2Gamma = DenseVector.zeros[Double](qg)
  val dh3Beta  = DenseVector.zeros[Double](pb)
  val dh3Gamma = DenseVector.zeros[Double](qg)

  // assemble dh matrix
  val dh = DenseMatrix.vertcat(
    DenseVector.vertcat(DenseVector(0.0), dh1Beta, dh1Gamma).toDenseMatrix,
    DenseVector.vertcat(DenseVector(0.0), dh2Beta, dh2Gamma).toDenseMatrix,
    DenseVector.vertcat(DenseVector(0.0), dh3Beta, dh3Gamma).toDenseMatrix
  )

  // return (h, f, df, dh)
  (h, f, df.toDenseVector, dh.toDenseVector)
}

computeHFDFDHforPair: (p: PairFeatures, theta: Map[String,breeze.linalg.DenseVector[Double]])(breeze.linalg.DenseVector[Double], breeze.linalg.DenseVector[Double], breeze.linalg.DenseVector[Double], breeze.linalg.DenseVector[Double])


In [27]:
def computeBUforPair(
  pf: PairFeatures,
  theta: Map[String, DenseVector[Double]]
): (DenseVector[Double], DenseVector[Double]) = {
  // 1) get h, f, D, M
  val (h, f, dfVec, dhVec) = computeHFDFDHforPair(pf, theta) 
  val d = dfVec.length / 3
  val D = new DenseMatrix(3, d, dfVec.data)       // 3×d
  val M = new DenseMatrix(3, d, dhVec.data)       // 3×d
  val Vinv = DenseMatrix.eye[Double](3)           // 3x3, identity matrix for simplicity

  // 2) G = D^T * V⁻¹  (d×3)
  val G = D.t * Vinv

  // 3) B_i = G * (D - M)   (d×d)
  val B_i = G * (D - M)

  // 4) u_i = G * (h - f)   (d)
  val u_i = G * (h - f)

  (B_i.toDenseVector, u_i)
}

computeBUforPair: (pf: PairFeatures, theta: Map[String,breeze.linalg.DenseVector[Double]])(breeze.linalg.DenseVector[Double], breeze.linalg.DenseVector[Double])


In [28]:
def updateTheta(
  theta: Map[String, DenseVector[Double]],
  step: DenseVector[Double]
): Map[String, DenseVector[Double]] = {
  var idx    = 0
  // delta (scalar)
  val newDelta = theta("delta") + step(idx to idx)
  idx        += 1
  // beta
  val bLen    = theta("beta").length
  val newBeta = theta("beta") + step(idx until idx + bLen)
  idx        += bLen
  // gamma
  val gLen     = theta("gamma").length
  val newGamma = theta("gamma") + step(idx until idx + gLen)

  Map(
    "delta" -> newDelta,
    "beta"  -> newBeta,
    "gamma" -> newGamma
  )
}

updateTheta: (theta: Map[String,breeze.linalg.DenseVector[Double]], step: breeze.linalg.DenseVector[Double])Map[String,breeze.linalg.DenseVector[Double]]


In [29]:
def generalizedInverse(matrix: DenseMatrix[Double]): DenseMatrix[Double] = {
  val svd.SVD(u, s, vt) = svd(matrix)
  val sInv = DenseMatrix.zeros[Double](s.length, s.length)
  
  // Invert non-zero singular values
  for (i <- 0 until s.length) {
    if (s(i) > 1e-10) { // Threshold to avoid division by zero
      sInv(i, i) = 1.0 / s(i)
    }
  }
  
  // Compute the pseudoinverse
  vt.t * sInv * u.t
}

generalizedInverse: (matrix: breeze.linalg.DenseMatrix[Double])breeze.linalg.DenseMatrix[Double]


In [30]:
def computeStep(pairFeatureDS: Dataset[PairFeatures], theta: Map[String, DenseVector[Double]]): DenseVector[Double] = {
  val d = theta("beta").length + theta("gamma").length + 1  // total params

  val BURDD: RDD[(DenseVector[Double], DenseVector[Double])] =
    pairFeatureDS.rdd.mapPartitions { iter =>
      iter.grouped(200).flatMap { batch =>
        batch.map(p => computeBUforPair(p, theta))
      }
    }

  val BUsum = BURDD.aggregate(
    (DenseVector.zeros[Double](d * d), DenseVector.zeros[Double](d), 0)
  )(
    (acc, value) => {
      val BAcc = acc._1
      val UAcc = acc._2
      val c = acc._3
      val B = value._1
      val U = value._2
      (BAcc + B, UAcc + U, c + 1)
    },
    (acc1, acc2) => {
      val B1 = acc1._1
      val U1 = acc1._2
      val c1 = acc1._3
      val B2 = acc2._1
      val U2 = acc2._2
      val c2 = acc2._3
      (B1 + B2, U1 + U2, c1 + c2)
    }
  )

  val Btot = new DenseMatrix(d, d, BUsum._1.data) // d×d
  val Utot = BUsum._2 // d
  val countBU = BUsum._3

  val Bmean = Btot / countBU.toDouble
  val Umean = Utot / countBU.toDouble

  // Fisher scoring update

  val J = - Bmean
  val lambda = 0.000
  val Iden = DenseMatrix.eye[Double](d) // Identity matrix of size d
  val step = generalizedInverse(J + lambda * Iden) * (-Umean)

  step
}

computeStep: (pairFeatureDS: org.apache.spark.sql.Dataset[PairFeatures], theta: Map[String,breeze.linalg.DenseVector[Double]])breeze.linalg.DenseVector[Double]


### use functions to update theta

In [92]:
val p = 3
val thetaInit =  Map(
  "delta" -> DenseVector(0.5),
  "beta" -> DenseVector.zeros[Double](p + 1),   // where p is dim(x)
  "gamma" -> DenseVector.zeros[Double](2 * p + 1)
)

// θ holds “delta”, “beta”, “gamma” as DenseVector’s
var theta: Map[String, DenseVector[Double]] = thetaInit

p = 3
thetaInit = Map(delta -> DenseVector(0.5), beta -> DenseVector(0.0, 0.0, 0.0, 0.0), gamma -> DenseVector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))
theta = Map(delta -> DenseVector(0.5), beta -> DenseVector(0.0, 0.0, 0.0, 0.0), gamma -> DenseVector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))


Map(delta -> DenseVector(0.5), beta -> DenseVector(0.0, 0.0, 0.0, 0.0), gamma -> DenseVector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))

In [32]:
val step = computeStep(pairFeatureDS, theta)
theta = updateTheta(theta, step)
println(s"Updated theta: $theta_updated")

step = DenseVector(0.06198679902031275, -0.13865192327808404, 0.09768482532165743, -0.1110750597878438, 0.37614075514617396, 0.32590325936754716, 0.15008069789480064, 0.05127167894737973, -0.40196636289840587, 0.15004688615406966, 0.05126495688003067, -0.4017221217750656)
theta = Map(delta -> DenseVector(0.5619867990203128), beta -> DenseVector(-0.13865192327808404, 0.09768482532165743, -0.1110750597878438, 0.37614075514617396), gamma -> DenseVector(0.32590325936754716, 0.15008069789480064, 0.05127167894737973, -0.40196636289840587, 0.15004688615406966, 0.05126495688003067, -0.4017221217750656))


Updated theta: Map(delta -> DenseVector(0.5308551422833616), beta -> DenseVector(0.28271462397361746, 0.13873782088654424, 0.14926160005343678, 0.13333156094939042), gamma -> DenseVector(-0.005432798717790621, -0.003125994201698155, -0.0029221142288405086, -0.0013614447598632427, -0.003125994201698155, -0.00292211422884051, -0.0013614447598632431))


Map(delta -> DenseVector(0.5619867990203128), beta -> DenseVector(-0.13865192327808404, 0.09768482532165743, -0.1110750597878438, 0.37614075514617396), gamma -> DenseVector(0.32590325936754716, 0.15008069789480064, 0.05127167894737973, -0.40196636289840587, 0.15004688615406966, 0.05126495688003067, -0.4017221217750656))

In [95]:
val tol = 1e-4
var diff = 10.0
var iteration = 0
var theta: Map[String, DenseVector[Double]] = thetaInit
val sampledPairs = sampleAllPairs(df)
val pairedDataset = sampledPairs.map(toPairFeatures)

while (diff > tol && iteration < 20) {
  //val sampledPairs = sampleObsPairsFromRepartitioned(repartitioned, samplePerPartition = 500)
  //val pairedDataset = sampledPairs.map(toPairFeatures)
  val step = computeStep(pairedDataset, theta)
  val theta_updated = updateTheta(theta, step)
  diff = norm(step, 2) // L2 norm of the step
  // println(s"Computed step: $step")
  // println(s"Updated theta: $theta_updated")
  println(s"Iteration $iteration: diff = $diff")
  theta = theta_updated
  iteration += 1
}

Iteration 0: diff = 1.9639061305544048
Iteration 1: diff = 3.3423466854396637
Iteration 2: diff = 0.44654409203274587
Iteration 3: diff = 0.14060876054956517
Iteration 4: diff = 0.03911302264068535
Iteration 5: diff = 0.011702273233117457
Iteration 6: diff = 0.0034896765488990457
Iteration 7: diff = 0.0010550043125069457
Iteration 8: diff = 3.184404990516072E-4


tol = 1.0E-4
diff = 9.679347392356155E-5
iteration = 10
theta = Map(delta -> DenseVector(0.4407139218431786), beta -> DenseVector(0.9447480621142229, -0.3554965085241989, -0.2256921198870797, -1.5930217922594627), gamma -> DenseVector(-0.1728416454986385, -1.052860061081924, -1.422205878818063, -1.839956525554742, 1.3221326314862196, 1.426023655655939, 1.0283232733342655))
sampledPairs = [left: struct<i: string, x: array<double> ... 3 more fields>, right: struct<i: string, x: array<double> ... 3 more fields>]
pairedDataset = [Wt_i: array<double>, Wt_j: array<double> ... 6 more fields]


Iteration 9: diff = 9.679347392356155E-5


[Wt_i: array<double>, Wt_j: array<double> ... 6 more fields]

In [24]:
// θ holds “delta”, “beta”, “gamma” as DenseVector’s
var theta: Map[String, DenseVector[Double]] = thetaInit

val d = theta("beta").length + theta("gamma").length + 1  // total params

val BURDD: RDD[(DenseVector[Double], DenseVector[Double])] =
  pairFeatureDS.rdd.mapPartitions { iter =>
    iter.grouped(200).flatMap { batch =>
      batch.map(p => computeBUforPair(p, theta))
    }
  }

val BUsum = BURDD.aggregate(
  (DenseVector.zeros[Double](d * d), DenseVector.zeros[Double](d), 0)
)(
  (acc, value) => {
    val BAcc = acc._1
    val UAcc = acc._2
    val c = acc._3
    val B = value._1
    val U = value._2
    (BAcc + B, UAcc + U, c + 1)
  },
  (acc1, acc2) => {
    val B1 = acc1._1
    val U1 = acc1._2
    val c1 = acc1._3
    val B2 = acc2._1
    val U2 = acc2._2
    val c2 = acc2._3
    (B1 + B2, U1 + U2, c1 + c2)
  }
)

val Btot = new DenseMatrix(d, d, BUsum._1.data) // d×d
val Utot = BUsum._2 // d
val countBU = BUsum._3

val Bmean = Btot / countBU.toDouble
val Umean = Utot / countBU.toDouble

println(s"Mean B: ${Bmean}")
println(s"Mean U: ${Umean}")

theta = Map(delta -> DenseVector(0.5), beta -> DenseVector(0.0, 0.0, 0.0, 0.0), gamma -> DenseVector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))
d = 12
BURDD = MapPartitionsRDD[3539] at mapPartitions at <console>:55
BUsum = (DenseVector(1900.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 118.75, 57.81135249334939, 59.765317135018435, 58.70424805411324, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0727738293904063, 57.81135249334939, 32.86966917459252, 29.1444579566715, 28.567094696492642, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4120854171027657, 59.765317135018435, 2...


Mean B: 1.0  0.0                   0.0010909335944160033  ... (12 total)
0.0  0.0625                0.030427027628078626   ...
0.0  0.030427027628078626  0.017299825881364483   ...
0.0  0.031455430071062336  0.015339188398248157   ...
0.0  0.030896972660059602  0.015035312998154022   ...
0.0  0.0                   0.0                    ...
0.0  0.0                   0.0                    ...
0.0  0.0                   0.0                    ...
0.0  0.0                   0.0                    ...
0.0  0.0                   0.0                    ...
0.0  0.0                   0.0                    ...
0.0  0.0                   0.0                    ...
Mean U: DenseVector(0.03, 0.0024342105263157896, 0.0014227572856000085, 9.293875752235241E-4, 0.0021970846715712988, 4.93421052631579E-4, 2.8734645485617015E-4, 2.661329751133956E-4, 1.1105092974345534E-4, 2.8734645485617015E-4, 2.661329751133956E-4, 1.1105092974345534E-4)


(DenseVector(1900.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 118.75, 57.81135249334939, 59.765317135018435, 58.70424805411324, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0727738293904063, 57.81135249334939, 32.86966917459252, 29.1444579566715, 28.567094696492642, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4120854171027657, 59.765317135018435, 2...