In [None]:
import sys
from random import random
from operator import add
import os

In [None]:
from pyspark.sql import SparkSession
from pyspark import SparkConf

In [None]:
def square_2_random(_: int) -> int:
    """Sample one random point in [-1, 1) x [-1, 1) and test if it hits the unit circle.

    Used as the per-element map function of the Monte Carlo Pi estimate:
    averaged over many calls, the result approximates pi / 4 (the area of the
    unit circle divided by the area of the enclosing square).

    Args:
        _: Ignored. Present only so the function can be passed to ``RDD.map``.

    Returns:
        1 if the point lies inside or on the unit circle, otherwise 0.
    """
    x = random() * 2 - 1
    y = random() * 2 - 1
    return 1 if x ** 2 + y ** 2 <= 1 else 0

In [None]:
def get_spark_on_k8s_conf(spark_conf, sgx_enabled=True,
                          driver_mem="64g", driver_jvm_mem="12g",
                          executor_mem="64g", executor_jvm_mem="12g"):
    """Optionally add the SGX-related Kubernetes settings to a Spark configuration.

    Args:
        spark_conf: A ``SparkConf`` (or any object with a chainable
            ``set(key, value)`` method). For ``SparkConf``, ``set`` updates the
            object in place and returns it.
        sgx_enabled: When False, ``spark_conf`` is returned untouched.
        driver_mem: Value for ``spark.kubernetes.sgx.driver.mem``.
        driver_jvm_mem: Value for ``spark.kubernetes.sgx.driver.jvm.mem``.
        executor_mem: Value for ``spark.kubernetes.sgx.executor.mem``.
        executor_jvm_mem: Value for ``spark.kubernetes.sgx.executor.jvm.mem``.
            The memory defaults reproduce the sizes previously hard-coded here.

    Returns:
        The object returned by the chained ``set`` calls when SGX is enabled,
        otherwise ``spark_conf`` as given.
    """
    if not sgx_enabled:
        return spark_conf
    return (spark_conf.set("spark.kubernetes.sgx.enabled", "true")
            .set("spark.kubernetes.sgx.driver.mem", driver_mem)
            .set("spark.kubernetes.sgx.driver.jvm.mem", driver_jvm_mem)
            .set("spark.kubernetes.sgx.executor.mem", executor_mem)
            .set("spark.kubernetes.sgx.executor.jvm.mem", executor_jvm_mem))

In [None]:
# Cluster/driver settings are injected through the environment.
master = os.getenv("RUNTIME_SPARK_MASTER")
print("k8s master url is " + str(master))
image = os.getenv("RUNTIME_K8S_SPARK_IMAGE")
print("executor image is " + str(image))
driver_ip = os.getenv("LOCAL_IP")
print("driver ip is " + str(driver_ip))

# Fail fast with a clear message rather than passing None into SparkConf and
# failing later with a much less obvious error.
missing_env = [name for name, value in (("RUNTIME_SPARK_MASTER", master),
                                        ("RUNTIME_K8S_SPARK_IMAGE", image),
                                        ("LOCAL_IP", driver_ip))
               if not value]
if missing_env:
    raise RuntimeError("Missing required environment variable(s): " + ", ".join(missing_env))

# Shared secret for spark.authenticate. Executors receive it through the
# SPARK_AUTHENTICATE_SECRET env var populated from the k8s secret
# "spark-secret" (key "secret"), so this driver-side value has to agree with it.
# Prefer supplying it via the environment; the literal fallback is only the
# demo value and must not be relied on outside the example setup.
auth_secret = os.getenv("SPARK_AUTHENTICATE_SECRET", "1234qwer")

conf = (SparkConf().setMaster(master)
        .setAppName("sgx-pyspark-pi-notebook-example")
        .set("spark.submit.deployMode", "client")
        .set("spark.kubernetes.container.image", image)
        .set("spark.driver.host", driver_ip)
        .set("spark.driver.memory", "32g")
        .set("spark.executor.cores", "8")
        .set("spark.executor.memory", "32g")
        .set("spark.executor.instances", "2")
        .set("spark.cores.max", "32")
        .set("spark.kubernetes.driver.podTemplateFile", "/ppml/spark-driver-template.yaml")
        .set("spark.kubernetes.executor.podTemplateFile", "/ppml/spark-executor-template.yaml")
        .set("spark.kubernetes.authenticate.driver.serviceAccountName", "spark")
        .set("spark.kubernetes.executor.deleteOnTermination", "false")
        # Deliberately huge timeouts. Unit-less values are read as seconds for
        # spark.network.timeout and as milliseconds for
        # spark.executor.heartbeatInterval, so the network timeout stays larger
        # than the heartbeat interval as Spark requires.
        .set("spark.network.timeout", "10000000")
        .set("spark.executor.heartbeatInterval", "10000000")
        .set("spark.python.use.daemon", "false")
        .set("spark.python.worker.reuse", "false")
        .set("spark.authenticate", "true")
        .set("spark.authenticate.secret", auth_secret)
        .set("spark.kubernetes.executor.secretKeyRef.SPARK_AUTHENTICATE_SECRET", "spark-secret:secret")
        .set("spark.kubernetes.driver.secretKeyRef.SPARK_AUTHENTICATE_SECRET", "spark-secret:secret"))

spark_on_k8s_conf = get_spark_on_k8s_conf(spark_conf=conf, sgx_enabled=True)

In [None]:
if __name__ == '__main__':
    # Inside a notebook kernel __name__ is '__main__', so this cell always runs.
    spark = SparkSession.builder.config(conf=spark_on_k8s_conf).getOrCreate()
    try:
        partition_num = 2
        # 100k samples per partition; the fraction that lands inside the unit
        # circle approximates pi / 4.
        n = 100000 * partition_num
        count = (spark.sparkContext
                 .parallelize(range(1, n + 1), partition_num)
                 .map(square_2_random)
                 .reduce(add))
        print("[INFO] Successful! Pi is roughly %f" % (4.0 * count / n))
    finally:
        # Stop the session even if the job raises, so the cluster resources
        # requested for it are not left behind.
        spark.stop()