In [None]:
from pyspark.sql.functions import *
import urllib


class AWSManager:
    """Load AWS credentials from a Delta table and mount an S3 bucket.

    The methods use the names ``spark`` and ``dbutils``, which are not defined
    in this file and are expected to exist as globals at call time.

    Args:
        bucket_name: Name of the S3 bucket to mount. The mount point is
            derived from it as ``/mnt/<bucket_name>``.
        credentials_table_path: Path of the Delta table holding the
            ``"Access key ID"`` and ``"Secret access key"`` columns.
    """

    def __init__(
        self,
        bucket_name="user-0a5040edb649-bucket",
        credentials_table_path="dbfs:/user/hive/warehouse/authentication_credentials",
    ):
        self.bucket_name = bucket_name
        self.mount_name = f"/mnt/{bucket_name}"
        self.credentials_table_path = credentials_table_path

    def load_aws_credentials(self):
        """Read the AWS key pair from the credentials Delta table.

        Returns:
            A tuple ``(access_key, secret_key, encoded_secret_key)`` where the
            last element is the secret key percent-encoded with no safe
            characters, so it can be embedded in a URL.

        Raises:
            IndexError: If the credentials table contains no rows.
        """
        # Import the submodule explicitly: a bare ``import urllib`` does not
        # guarantee that ``urllib.parse`` has been loaded.
        from urllib.parse import quote

        # Fetch both columns in one query so the two keys always come from
        # the same row.
        row = (
            spark.read.format("delta")
            .load(self.credentials_table_path)
            .select("Access key ID", "Secret access key")
            .first()
        )
        if row is None:
            raise IndexError(
                f"No credentials found in {self.credentials_table_path}"
            )

        access_key = row["Access key ID"]
        secret_key = row["Secret access key"]
        encoded_secret_key = quote(secret_key, safe="")
        return access_key, secret_key, encoded_secret_key

    def mount_s3_bucket(self):
        """Mount the S3 bucket at ``self.mount_name`` via ``dbutils.fs.mount``.

        Mount failures are reported with ``print`` rather than raised; the
        "already mounted" case gets its own hint on how to unmount.
        """
        access_key, _secret_key, encoded_secret_key = self.load_aws_credentials()
        source_url = f"s3n://{access_key}:{encoded_secret_key}@{self.bucket_name}"
        # ``source_url`` embeds the key pair, so only this credential-free
        # form is ever printed.
        display_url = f"s3n://{self.bucket_name}"

        try:
            dbutils.fs.mount(source_url, self.mount_name)
        except Exception as e:
            # Scrub the URL in case the error text echoes it back.
            detail = str(e).replace(source_url, display_url)
            if "already mounted" in detail:
                print(
                    f"{self.mount_name} already mounted. You can unmount using dbutils.fs.unmount('{self.mount_name}') if needed."
                )
            else:
                print(f"Error mounting {display_url} to {self.mount_name}: {detail}")
        else:
            print(f"Successfully mounted {display_url} to {self.mount_name}.")


class DataTransformer:
    """Hold a DataFrame and clean it with one routine per topic.

    Every ``process_*`` method reassigns ``self.df`` after each step and
    returns the final DataFrame.
    """

    def __init__(self, df):
        self.df = df

    def _replace_column(self, name, expression):
        """Set column *name* on ``self.df`` to *expression*."""
        self.df = self.df.withColumn(name, expression)

    def process_pindata(self):
        """Clean the pin DataFrame and return it with a fixed column order."""
        # (column, LIKE pattern) pairs: a value matching the pattern is
        # replaced with null.
        placeholders = (
            ("follower_count", "User Info Error"),
            ("description", "No description available%"),
            ("tag_list", "N,o, ,T,a,g,s, ,A,v,a,i,l,a,b,l,e"),
            ("image_src", "Image src error."),
            ("poster_name", "User Info Error"),
            ("title", "No Title Data Available"),
        )
        for name, pattern in placeholders:
            current = col(name)
            self._replace_column(
                name, when(current.like(pattern), None).otherwise(current)
            )

        # Replace the "k" / "M" markers with zeros, then cast the result to int.
        for marker, zeros in (("k", "000"), ("M", "000000")):
            self._replace_column(
                "follower_count", regexp_replace("follower_count", marker, zeros)
            )
        self._replace_column("follower_count", col("follower_count").cast("int"))

        # Drop the "Local save in " prefix text from the save location.
        self._replace_column(
            "save_location", regexp_replace("save_location", "Local save in ", "")
        )

        self.df = self.df.withColumnRenamed("index", "ind")
        self.df = self.df.select(
            [
                "ind",
                "unique_id",
                "title",
                "description",
                "follower_count",
                "poster_name",
                "tag_list",
                "is_image_or_video",
                "image_src",
                "save_location",
                "category",
            ]
        )
        return self.df

    def process_geodata(self):
        """Clean the geo DataFrame and return it with a fixed column order."""
        # Fold latitude and longitude into a single array column.
        self._replace_column("coordinates", array(col("latitude"), col("longitude")))
        self.df = self.df.drop("latitude", "longitude")
        self._replace_column("timestamp", to_timestamp("timestamp"))
        self.df = self.df.select(["ind", "country", "coordinates", "timestamp"])
        return self.df

    def process_userdata(self):
        """Clean the user DataFrame and return it with a fixed column order."""
        # Join first and last name with a single space.
        self._replace_column("user_name", concat_ws(" ", "first_name", "last_name"))
        self.df = self.df.drop("first_name", "last_name")
        self._replace_column("date_joined", to_timestamp("date_joined"))
        self.df = self.df.select(["ind", "user_name", "age", "date_joined"])
        return self.df


class BatchProcessor:
    """Load the batch JSON files of every topic and clean them."""

    def __init__(self):
        self.topics = ["pin", "geo", "user"]

    def load_topics_into_dataframe(self, topic):
        """Read the JSON files stored for *topic* into a DataFrame.

        Args:
            topic: Topic suffix used to build the file path.

        Returns:
            The loaded DataFrame, or ``None`` if loading failed (the error is
            printed, not raised).
        """
        file_path = f"/mnt/user-0a5040edb649-bucket/topics/0a5040edb649.{topic}/partition=0/*.json"
        try:
            return (
                spark.read.format("json").option("inferSchema", "true").load(file_path)
            )
        except Exception as e:
            print(f"Failed to load data for topic {topic}: {e}")
            return None

    def start_batch_pipeline(self):
        """Load and clean the pin, geo and user data.

        Returns:
            A tuple ``(pin_df, geo_df, user_df)`` of cleaned DataFrames.

        Raises:
            RuntimeError: If any topic could not be loaded. All topics are
                attempted first so the message lists every failure.
        """
        transformers = {}
        failed_topics = []
        for topic in self.topics:
            df = self.load_topics_into_dataframe(topic)
            if df is None:
                # Loading reports failure by returning None; remember it here
                # instead of wrapping None and failing later on an unrelated
                # attribute error.
                failed_topics.append(topic)
            else:
                transformers[topic] = DataTransformer(df)

        if failed_topics:
            raise RuntimeError(
                f"Failed to load data for topic(s): {', '.join(failed_topics)}"
            )

        return (
            transformers["pin"].process_pindata(),
            transformers["geo"].process_geodata(),
            transformers["user"].process_userdata(),
        )


class StreamProcessor:
    """Read the per-topic Kinesis streams, parse them and clean them.

    Args:
        aws_manager: Object whose ``load_aws_credentials()`` returns
            ``(access_key, secret_key, encoded_secret_key)``.
    """

    def __init__(self, aws_manager):
        self.aws_manager = aws_manager

    def get_stream_data(self, topic):
        """Return a streaming DataFrame reading the Kinesis stream of *topic*.

        The stream name is ``streaming-0a5040edb649-<topic>``; reading starts
        from the earliest position in region ``us-east-1``.
        """
        access_key, secret_key, _encoded_secret_key = (
            self.aws_manager.load_aws_credentials()
        )
        return (
            spark.readStream.format("kinesis")
            .option("streamName", f"streaming-0a5040edb649-{topic}")
            .option("initialPosition", "earliest")
            .option("region", "us-east-1")
            .option("awsAccessKey", access_key)
            .option("awsSecretKey", secret_key)
            .load()
        )

    def create_df_from_stream(self, stream, schema):
        """Parse the ``data`` column of *stream* as JSON using *schema*.

        Returns:
            A DataFrame with one column per field of the parsed JSON.
        """
        # Cast the 'data' column to a string so it can be parsed as JSON.
        json_df = stream.selectExpr("CAST(data as STRING) as json_data")
        # Parse the JSON text into a struct column named 'data'.
        structured_df = json_df.select(
            "json_data", from_json(json_df.json_data, schema).alias("data")
        )
        # Flatten the struct into top-level columns.
        return structured_df.select("data.*")

    def create_topic_dataframe(self, topic):
        """Read the stream of *topic* and parse it with that topic's schema.

        The schemas ``pin_schema``, ``geo_schema`` and ``user_schema`` are not
        defined in this class and must exist as globals at call time.
        """
        schema_dict = {"pin": pin_schema, "geo": geo_schema, "user": user_schema}
        topic_stream = self.get_stream_data(topic)
        return self.create_df_from_stream(topic_stream, schema_dict[topic])

    def write_df_to_table(self, df, table_name, checkpoint_location=None):
        """Append the streaming DataFrame *df* to the Delta table *table_name*.

        Args:
            df: Streaming DataFrame to write.
            table_name: Name of the target table.
            checkpoint_location: Checkpoint directory for this query. Defaults
                to a directory named after the table, so that queries writing
                to different tables do not share one checkpoint.

        Returns:
            Whatever the writer's ``table()`` call returns.
        """
        if checkpoint_location is None:
            checkpoint_location = f"/tmp/kinesis/_checkpoints/{table_name}/"
        return (
            df.writeStream.format("delta")
            .outputMode("append")
            .option("checkpointLocation", checkpoint_location)
            .table(table_name)
        )

    def start_stream_pipeline(self):
        """Build and clean the pin, geo and user streaming DataFrames.

        Returns:
            A tuple ``(pin_df, geo_df, user_df)`` of cleaned DataFrames.

        Raises:
            RuntimeError: If a topic's DataFrame could not be created; the
                original exception is attached as the cause.
        """
        topics = ["pin", "geo", "user"]
        transformers = {}

        for topic in topics:
            try:
                transformers[topic] = DataTransformer(
                    self.create_topic_dataframe(topic)
                )
            except Exception as e:
                # Fail here with the real cause rather than continuing to a
                # missing-key error in the return statement below.
                raise RuntimeError(f"Error processing data for {topic}: {e}") from e

        return (
            transformers["pin"].process_pindata(),
            transformers["geo"].process_geodata(),
            transformers["user"].process_userdata(),
        )