In [None]:
import sys

# Make project modules under /mnt/code importable. Guarded so re-running this
# cell does not append another duplicate sys.path entry each time.
if "/mnt/code" not in sys.path:
    sys.path.append("/mnt/code")

In [None]:
import os,socket,json, time
from pyspark.sql import SparkSession, functions as F, types as T
import boto3
from botocore.exceptions import ClientError


In [None]:
import pyspark

# Show the PySpark version up front: the hadoop-aws jar pinned in main()
# should match the Hadoop minor version bundled with this PySpark build.
print("pyspark:", pyspark.__version__)  # e.g. 3.3.2

In [None]:
## Needs bucket-level s3:ListBucket (plus s3:GetObject/s3:DeleteObject on the key):
## without ListBucket, HEAD on a missing key returns 403 instead of 404.
def delete_if_exists(s3_uri: str):
    """Delete a single S3 object if it exists; otherwise print a notice and return.

    Args:
        s3_uri: ``s3://bucket/key`` URI of ONE object. A "directory" prefix such as
            ``s3://bucket/path/`` is not deleted recursively -- only an object whose
            key is literally that string would be removed.

    Raises:
        ValueError: if the URI does not start with ``s3://`` or lacks a bucket or key.
        botocore.exceptions.ClientError: for any S3 error other than "not found"
            (e.g. 403 AccessDenied when s3:ListBucket is missing).
    """
    if not s3_uri.startswith("s3://"):
        raise ValueError("Path must start with s3://")

    # split into bucket + key
    bucket, _, key = s3_uri[len("s3://"):].partition("/")
    if not bucket or not key:
        # Validate before calling S3 so the caller gets a clear error instead of
        # a botocore parameter-validation failure.
        raise ValueError(f"Expected s3://<bucket>/<key>, got {s3_uri!r}")

    s3 = boto3.client("s3")

    try:
        # HEAD the object to learn whether it exists without downloading it.
        s3.head_object(Bucket=bucket, Key=key)
    except ClientError as e:
        error = e.response.get("Error", {})
        status = e.response.get("ResponseMetadata", {}).get("HTTPStatusCode")
        # HEAD responses have no body, so botocore usually reports the code as
        # "404"; also accept the named codes and the raw HTTP status.
        if error.get("Code") in ("404", "NoSuchKey", "NotFound") or status == 404:
            print(f"{s3_uri} does not exist, nothing to delete.")
            return
        raise
    # delete if it exists
    s3.delete_object(Bucket=bucket, Key=key)
    print(f"Deleted {s3_uri}")

## Deletes without bucket-level actions: DeleteObject on a missing key still
## reports success, so no existence check (and no s3:ListBucket) is needed.
def delete_object_idempotent(uri: str):
    """Request deletion of a single S3 object, succeeding whether or not it exists.

    S3 ``DeleteObject`` reports success even when the key does not exist (on a
    versioned bucket it adds a delete marker), so repeated calls are safe and only
    ``s3:DeleteObject`` on the key is required.

    Args:
        uri: ``s3://bucket/key`` URI of one object.

    Raises:
        ValueError: if the URI does not start with ``s3://`` or lacks a bucket or key.
        botocore.exceptions.ClientError: if S3 rejects the request (logged, then re-raised).
    """
    # Validate up front (same rules as delete_if_exists): a non-s3 URI would
    # otherwise parse to an empty bucket/key and fail inside botocore.
    if not uri.startswith("s3://"):
        raise ValueError("Path must start with s3://")
    bucket, _, key = uri[len("s3://"):].partition("/")
    if not bucket or not key:
        raise ValueError(f"Expected s3://<bucket>/<key>, got {uri!r}")

    s3 = boto3.client("s3")
    try:
        s3.delete_object(Bucket=bucket, Key=key)
        print(f"Delete requested: s3://{bucket}/{key}")
    except ClientError as e:
        print("Delete failed:", e)
        raise

In [None]:
import os
import socket
import boto3
from pyspark.sql import SparkSession, functions as F

# ---------- Helpers ----------

def _log_from_workers(partition_iter):
    """Runs on executors; prints IRSA/env visibility."""
    role = os.environ.get("AWS_ROLE_ARN", "<missing>")
    tok  = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE", "<missing>")
    akid = os.environ.get("AWS_ACCESS_KEY_ID", "<missing>")
    host = socket.gethostname()
    print(f"[EXECUTOR {host}] AWS_ROLE_ARN={role}")
    print(f"[EXECUTOR {host}] AWS_WEB_IDENTITY_TOKEN_FILE={tok}")
    print(f"[EXECUTOR {host}] AWS_ACCESS_KEY_ID set? {'yes' if akid!='<missing>' else 'no'}")
    for row in partition_iter:
        yield row

def create_synthetic_text(spark: SparkSession, s3_text_path: str, n_lines: int = 1000, seed: int = 42):
    """Write ``n_lines`` deterministic synthetic lines as a text dataset.

    Each line is ``"line <id> payload <hash>"``, where ``<hash>`` is the hex
    SHA-256 of the id string concatenated with ``str(seed)``. The rows form a
    one-column DataFrame (``value``) written with ``DataFrameWriter.text`` in
    overwrite mode.

    Args:
        spark: session used to generate the rows.
        s3_text_path: destination path (an ``s3a://`` URI in this notebook).
        n_lines: number of lines to generate.
        seed: mixed into the hash so different seeds give different payloads.
    """
    # Roughly one partition per 100 lines, clamped to [1, 32].
    partitions = max(1, min(32, n_lines // 100))
    id_str = F.col("id").cast("string")
    digest = F.sha2(F.concat(id_str, F.lit(str(seed))), 256)
    line = F.concat_ws(" ", F.lit("line"), id_str, F.lit("payload"), digest)

    ids = spark.range(0, n_lines, 1, numPartitions=partitions)
    lines_df = ids.withColumn("value", line).select("value")

    lines_df.write.mode("overwrite").text(s3_text_path)
    print(f"Wrote synthetic text ({n_lines} lines) to: {s3_text_path}")

'''
When the JVM is already running, it won't see new environment variables (Java reads the environment once at process start). The STS creds end up in the Python process, but S3A keeps using the old, empty credential-provider chain.

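There are two ways out. Since restarting the kernel is annoying, option 1 patches things in place and works even with an existing session:
1. Fix without restarting: inject the STS creds into the Hadoop configuration and flush the cached S3A FileSystem instances.
    Call apply_s3a_temp_creds_to_hadoop() (defined below) right after assume_with_web_identity() and before any s3a:// read/write.
2. Before building the session, stop the old one so that the new .config(...) and IRSA settings are applied. Note that jars and JVM options (e.g. .config("spark.jars.packages", ...)) only load when the JVM itself starts, so they may still require a kernel restart: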
try:
    SparkSession.getActiveSession().stop()
except Exception:
    pass
Then build a fresh session and proceed (main() below does this). However, in Domino/managed notebooks a platform daemon sometimes immediately recreates
a baseline session; when that happens, the Hadoop-conf injection from option 1 is the reliable fix.
'''
def apply_s3a_temp_creds_to_hadoop(spark, creds: dict, region: str = None):
    """
    Inject temporary AWS creds directly into Hadoop config so S3A uses them,
    even if the Spark JVM/session was already running.

    Args:
        spark: active SparkSession whose driver-side Hadoop configuration is patched.
        creds: dict with ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY`` and
            ``AWS_SESSION_TOKEN`` (as returned by ``assume_with_web_identity``);
            an optional ``AWS_REGION`` entry is used when ``region`` is not given.
        region: S3 region. Resolution order: this argument, ``creds["AWS_REGION"]``,
            the ``AWS_REGION`` environment variable, then ``"us-east-1"``.

    Raises:
        KeyError: if one of the three credential keys is missing from ``creds``
            (raised before the Hadoop configuration is touched).

    Side effects:
        Calls ``FileSystem.closeAll()``, which closes ALL cached Hadoop FileSystem
        instances in the driver JVM (not only S3A ones).
    """
    if region is None:
        # Prefer the region the creds were issued for, then the env var, then a default.
        region = creds.get("AWS_REGION") or os.environ.get("AWS_REGION", "us-east-1")

    # Read every required key up front so a missing one raises KeyError before
    # the Hadoop config is partially modified.
    access_key = creds["AWS_ACCESS_KEY_ID"]
    secret_key = creds["AWS_SECRET_ACCESS_KEY"]
    session_token = creds["AWS_SESSION_TOKEN"]

    hconf = spark._jsc.hadoopConfiguration()

    # Tell S3A to use temporary session creds (access key ID + secret + token;
    # STS-issued access key IDs start with "ASIA", not "AKIA").
    hconf.set("fs.s3a.access.key", access_key)
    hconf.set("fs.s3a.secret.key", secret_key)
    hconf.set("fs.s3a.session.token", session_token)
    hconf.set("fs.s3a.aws.region", region)

    # Pin the provider explicitly instead of relying on the default S3A chain,
    # which could resolve some other credential source first.
    hconf.set("fs.s3a.aws.credentials.provider",
              "org.apache.hadoop.fs.s3a.TemporaryAWSCredentialsProvider")

    # Safety: ensure the S3A impl is registered, use path-style URLs, and stop
    # Hadoop from caching S3A FileSystem instances so each lookup re-reads the conf.
    hconf.set("fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
    hconf.set("fs.s3a.path.style.access", "true")
    hconf.set("fs.s3a.impl.disable.cache", "true")

    # Drop any cached FileSystem instances so the new creds take effect.
    # NOTE: this closes every cached FileSystem in the driver JVM, not just S3A.
    jvm = spark._jvm
    jvm.org.apache.hadoop.fs.FileSystem.closeAll()


def assume_with_web_identity(duration_seconds: int = 3600):
    """Use IRSA (web identity) to obtain short-lived env creds via STS.

    Reads ``AWS_ROLE_ARN`` and ``AWS_WEB_IDENTITY_TOKEN_FILE`` (required),
    ``AWS_ROLE_SESSION_NAME`` (default ``"spark-irsa"``) and ``AWS_REGION``
    (default ``"us-east-1"``) from the environment, calls STS
    ``AssumeRoleWithWebIdentity``, exports the resulting creds into
    ``os.environ`` and returns them.

    Args:
        duration_seconds: requested credential lifetime in seconds (default
            3600). STS rejects values above the role's maximum session duration.

    Returns:
        dict with ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``,
        ``AWS_SESSION_TOKEN`` and ``AWS_REGION``.

    Raises:
        KeyError: if ``AWS_ROLE_ARN`` or ``AWS_WEB_IDENTITY_TOKEN_FILE`` is unset.
        OSError: if the token file cannot be read.
        botocore.exceptions.ClientError: if STS rejects the request.

    Note:
        The creds are not refreshed. Because they are written to ``os.environ``,
        later boto3 clients in this kernel will prefer them over the web-identity
        provider and will fail once they expire; call this again to refresh.
    """
    role_arn = os.environ["AWS_ROLE_ARN"]
    token_fn = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"]
    session_name = os.environ.get("AWS_ROLE_SESSION_NAME", "spark-irsa")
    region = os.environ.get("AWS_REGION", "us-east-1")

    with open(token_fn, "r", encoding="utf-8") as f:
        # Strip surrounding whitespace so a trailing newline in the token file
        # is not sent as part of the JWT.
        jwt = f.read().strip()

    sts = boto3.client("sts", region_name=region)
    resp = sts.assume_role_with_web_identity(
        RoleArn=role_arn,
        RoleSessionName=session_name,
        WebIdentityToken=jwt,
        DurationSeconds=duration_seconds,
    )
    c = resp["Credentials"]
    creds = {
        "AWS_ACCESS_KEY_ID": c["AccessKeyId"],
        "AWS_SECRET_ACCESS_KEY": c["SecretAccessKey"],
        "AWS_SESSION_TOKEN": c["SessionToken"],
        "AWS_REGION": region,
    }
    os.environ.update(creds)  # driver process sees them too
    return creds

def to_s3a(u: str) -> str:
    """Rewrite a leading ``s3://`` scheme to ``s3a://`` for the Hadoop S3A connector.

    Only the scheme at the start is rewritten. Any other value, including ``None``,
    ``""`` and URIs that already use ``s3a://``, is returned unchanged.
    """
    prefix = "s3://"
    if not u or not u.startswith(prefix):
        return u
    return "s3a://" + u[len(prefix):]

# ---------- Main ----------

def main(s3_input_text: str, s3_output_parquet: str = None, make_synth: bool = False, synth_lines: int = 1000):
    """End-to-end S3A smoke test: IRSA -> STS creds -> Spark -> read/transform/write.

    Args:
        s3_input_text: ``s3://`` (or ``s3a://``) path of the text input.
        s3_output_parquet: optional destination for the parquet output.
        make_synth: if True, first write ``synth_lines`` synthetic lines to
            ``s3_input_text`` (overwriting it).
        synth_lines: number of synthetic lines when ``make_synth`` is True.

    The Spark session created here is always stopped on exit, including on error.
    The STS creds are not refreshed, so runs longer than their lifetime will fail
    with expired-token errors.
    """
    # 1) Bootstrap short-lived env creds from IRSA
    creds = assume_with_web_identity()

    s3a_input = to_s3a(s3_input_text)
    s3a_output = to_s3a(s3_output_parquet) if s3_output_parquet else None
    region = os.environ.get("AWS_REGION", "us-east-1")

    # Fresh session (avoid "Using an existing Spark session", which ignores new
    # config). This does not restart the JVM: jars and JVM options only apply if
    # the JVM is launched by the builder below.
    active = SparkSession.getActiveSession()
    if active is not None:
        active.stop()

    # 2) Build Spark - pass temp env creds to executors. There is no generic
    #    spark.driverEnv.* setting; in client mode the driver JVM inherits this
    #    Python process's env (already updated by assume_with_web_identity), and
    #    the Hadoop conf is patched explicitly below in case the JVM predates it.
    spark = (
        SparkSession.builder
        .appName("basic-s3a-dataframe")
        # If your image lacks S3A, fetch jars (match Hadoop minor); if outbound blocked, replace with spark.jars=/path1,/path2
        .config("spark.jars.packages",
                "org.apache.hadoop:hadoop-aws:3.3.4,com.amazonaws:aws-java-sdk-bundle:1.12.610")
        # Propagate short-lived creds to executors (works with EnvironmentVariableCredentialsProvider)
        .config("spark.executorEnv.AWS_ACCESS_KEY_ID", creds["AWS_ACCESS_KEY_ID"])
        .config("spark.executorEnv.AWS_SECRET_ACCESS_KEY", creds["AWS_SECRET_ACCESS_KEY"])
        .config("spark.executorEnv.AWS_SESSION_TOKEN", creds["AWS_SESSION_TOKEN"])
        # S3A basics
        .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
        .config("spark.hadoop.fs.s3a.path.style.access", "true")
        .config("spark.hadoop.fs.s3a.aws.region", region)
        # Only effective if this builder launches the driver JVM.
        .config("spark.driver.extraJavaOptions", "-Dcom.amazonaws.sdk.disableEc2Metadata=true")
        .getOrCreate()
    )
    try:
        apply_s3a_temp_creds_to_hadoop(spark, creds, region)

        # Driver-side sanity
        print("\n=== DRIVER ENV CHECK ===")
        print("AWS_ROLE_ARN:", os.environ.get("AWS_ROLE_ARN"))
        print("AWS_WEB_IDENTITY_TOKEN_FILE:", os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE"))
        print("AWS_ACCESS_KEY_ID set?:", "yes" if os.environ.get("AWS_ACCESS_KEY_ID") else "no")
        print("fs.s3a.impl:", spark._jsc.hadoopConfiguration().get("fs.s3a.impl"))
        print("=====================================\n")

        # 3) Create synthetic input (optional)
        if make_synth:
            create_synthetic_text(spark, s3a_input, n_lines=synth_lines)

        # 4) Read text from S3, fan out to executors (logs IRSA/env visibility), do a tiny transform
        df_text = spark.read.text(s3a_input)
        _ = df_text.rdd.mapPartitions(_log_from_workers).count()

        df_non_empty = df_text.filter(F.length(F.trim("value")) > 0)
        print("Non-empty line count:", df_non_empty.count())

        # 5) Optional write to S3
        if s3a_output:
            (df_non_empty
                .withColumn("len", F.length("value"))
                .withColumn("ts", F.current_timestamp())
                .write.mode("overwrite").parquet(s3a_output))
            print(f"Wrote parquet to: {s3a_output}")
    finally:
        # Always release the session, even if a read/write above failed.
        spark.stop()

# Example:
# main("s3://<BUCKET_NAME>/end-to-end/sample-input/test.txt",
#      "s3://<BUCKET_NAME>/end-to-end/sample-output/", make_synth=True, synth_lines=1000)


In [None]:
# S3 locations for the end-to-end run; the bucket name comes from the environment.
s3_bucket = os.environ.get("S3_BUCKET_NAME", "").strip()
if not s3_bucket:
    # Fail early with a clear message instead of building an "s3:///..." path.
    raise KeyError("S3_BUCKET_NAME must be set to the target bucket name (without s3://)")
s3_input_text = f"s3://{s3_bucket}/end-to-end/sample-input/test.txt"
s3_output_parquet = f"s3://{s3_bucket}/end-to-end/sample-output/"
# Pre-deleting is unnecessary: both writes in main() use mode("overwrite").
# Note that Spark writes each of these paths as a directory of part files, and
# delete_if_exists() removes only a single object, so it would not clear them.

In [None]:
# Run the full pipeline: write synthetic input, read it back, write parquet output.
main(
    s3_input_text,
    s3_output_parquet=s3_output_parquet,
    make_synth=True,
    synth_lines=1000,
)