In [139]:
from sparknlp.annotator import AnnotatorModel, AnnotatorType
from pyspark.sql.functions import col, when
import random


class ErrorTransformer(AnnotatorModel):
    inputAnnotatorTypes = [AnnotatorType.DOCUMENT]
    outputAnnotatorType = AnnotatorType.DOCUMENT

    def __init__(self, fraction):
        super(ErrorTransformer, self).__init__(
            classname=None,
            java_model=None,
        )
        self.fraction = fraction
        self.num_affected_rows = 0
        self.affected_rows_indices = []


class CategoricalLabelTransformer(ErrorTransformer):
    def __init__(self, fraction, categories=None):
        super(CategoricalLabelTransformer, self).__init__(fraction=fraction)
        self.categories = categories

    def _transform(self, dataset):
        self.inputCol = self.getInputCols()[0]
        self.outputCol = self.getOutputCol()
        # Determine the number of affected rows
        num_rows = dataset.count()
        print("num_rows:", num_rows)

        # Select random rows to transform
        affected_rows = dataset.sample(False, self.fraction)
        self.num_affected_rows = affected_rows.count()
        print("affected_rows_count:", self.num_affected_rows)

        print(affected_rows.groupBy("language").count().show())

        replaced_column = (
            when(col(self.inputCol) == "en", self._select_random_category("en"))
            .when(col(self.inputCol) == "es", self._select_random_category("es"))
            .when(col(self.inputCol) == "tr", self._select_random_category("tr"))
            .when(col(self.inputCol) == "fr", self._select_random_category("fr"))
            .when(col(self.inputCol) == "de", self._select_random_category("de"))
            .otherwise(col(self.inputCol))
        )

        # Apply transformation to selected rows
        transformed_rows = affected_rows.withColumn(self.inputCol, replaced_column)

        print("transformed_rows:", transformed_rows.show(50))
        print("transformed_rows count:", transformed_rows.count())

        # replace affected_rows in dataset for transformed_rows
        substract_df = dataset.subtract(affected_rows)
        print("dataset count:", dataset.count())
        union_df = substract_df.union(transformed_rows)

        return union_df

    def _select_random_category(self, label):
        other_categories = self.categories.copy()
        other_categories.remove(label)
        category = random.choice(other_categories)
        print(f"selecting random category! {category}")
        return category

In [2]:
from pyspark.ml import Pipeline
import sparknlp

In [None]:
spark = sparknlp.start(apple_silicon=True)

In [3]:
spark.stop()

In [140]:
from pyspark.ml import Pipeline
from pyspark.sql.functions import monotonically_increasing_id

# Define the CategoricalLabelTransformer annotator
cat_label_transformer = (
    CategoricalLabelTransformer(
        fraction=0.1,
        categories=["en", "es", "tr", "fr", "de"],
        # categories=["en"],
    )
    .setInputCols(["language"])
    .setOutputCol("language")
)

# Create the SparkNLP pipeline
pipeline = Pipeline(stages=[cat_label_transformer])
data = spark.read.option("header", True).csv(
    "./data/test_data_consolidated.csv", escape='"'
)

df_with_index = data.withColumn("index", monotonically_increasing_id())
# Create the SparkNLP pipeline
pipeline = Pipeline(stages=[cat_label_transformer])

# Fit the pipeline to your data
model = pipeline.fit(df_with_index)

# Transform your data
transformed_data = model.transform(df_with_index)

num_rows: 52500
affected_rows_count: 5214
+--------+-----+
|language|count|
+--------+-----+
|      en| 1047|
|      tr| 1076|
|      de| 1047|
|      es| 1046|
|      fr|  998|
+--------+-----+

None
selecting random category! fr
selecting random category! tr
selecting random category! en
selecting random category! tr
selecting random category! es
+--------------------+--------+-----+-----+
|         review_body|language|stars|index|
+--------------------+--------+-----+-----+
|No se el resultad...|      tr|    4|    6|
|*Finally!* is all...|      fr|    4|   17|
|Bin sehr zufriede...|      es|    5|   23|
|Las tallas chinas...|      tr|    1|   38|
|...das leider imm...|      es|    3|   60|
|Se pega sobre si ...|      tr|    1|   84|
|ürünü dahayeni ku...|      en|    5|   99|
|fiyatı daha uygun...|      en|    3|  106|
|Calidad en genera...|      tr|    3|  119|
|Ürün genel perfor...|      en|    3|  128|
|Sieht gut aus (fü...|      es|    3|  130|
|I was looking for...|      fr|  

In [141]:
transformed_data.count()

                                                                                

52500

In [142]:
transformed_data.groupBy("language").count().show()

[Stage 564:>                                                        (0 + 3) / 3]

+--------+-----+
|language|count|
+--------+-----+
|      en|10558|
|      tr|11494|
|      de| 9553|
|      es|10364|
|      fr|10531|
+--------+-----+



                                                                                

In [79]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
import pandas as pd

# Initialize SparkSession
spark = SparkSession.builder.appName("State_Prefix").getOrCreate()


def get_state_codes(input_df, prefix=None):
    if prefix is not None:
        output_df = input_df.withColumn(
            "state_code", F.concat(F.lit(prefix), F.col("index"))
        )
    else:
        output_df = input_df.withColumn("state_code", F.col("index"))

    return output_df


data = {"state": ["Alabama", "California", "Maine", "Ohio", "Arizona", "Montana"]}
df1 = spark.createDataFrame(pd.DataFrame(data))

df1 = df1.withColumn(
    "index", F.row_number().over(Window.orderBy(F.monotonically_increasing_id())) - 1
)
df1.show()

df1 = get_state_codes(df1, "S")
df1 = df1.drop("index")

df1.show()

24/04/01 14:29:16 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
24/04/01 14:29:16 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
24/04/01 14:29:16 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
24/04/01 14:29:16 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
24/04/01 14:29:16 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
24/04/01 14:29:16 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.


+----------+-----+
|     state|index|
+----------+-----+
|   Alabama|    0|
|California|    1|
|     Maine|    2|
|      Ohio|    3|
|   Arizona|    4|
|   Montana|    5|
+----------+-----+

+----------+----------+
|     state|state_code|
+----------+----------+
|   Alabama|        S0|
|California|        S1|
|     Maine|        S2|
|      Ohio|        S3|
|   Arizona|        S4|
|   Montana|        S5|
+----------+----------+



In [108]:
from sparknlp.base import Transformer
from sparknlp.annotator import *
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql import DataFrame, Row
import random


class CategoricalLabelTransformer(
    Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable
):
    def __init__(self, inputCol=None, outputCol=None, fraction=0.1, categories=None):
        super(CategoricalLabelTransformer, self).__init__()
        self.fraction = fraction
        self.categories = categories
        self.inputCol = inputCol
        self.outputCol = outputCol

    def setFraction(self, value):
        self.fraction = value

    def setCategories(self, value):
        self.categories = value

    def _transform(self, dataset):
        def corrupt_row(row):
            if random.random() < self.fraction:
                current_val = row[self.inputCol]
                if isinstance(self.categories, dict):
                    new_val = self.categories.get(current_val, current_val)
                else:
                    categories_min_current = [
                        c for c in self.categories if c != current_val
                    ]
                    new_val = random.choice(categories_min_current)
                return Row(**{**row.asDict(), self.outputCol: new_val})
            return row

        return dataset.rdd.map(corrupt_row).toDF()

In [109]:
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from sparknlp.annotator import *
from sparknlp.common import *
from sparknlp.base import DocumentAssembler, Finisher
from sparknlp.base import *
import random

# Initialize SparkSession
spark = SparkSession.builder.appName(
    "CategoricalLabelTransformerPipeline"
).getOrCreate()

# Sample data
data = [("apple", 1), ("banana", 2), ("apple", 3), ("orange", 4), ("banana", 5)]
columns = ["fruit", "id"]
df = spark.createDataFrame(data, columns)

# Initialize other Spark NLP components
document_assembler = DocumentAssembler().setInputCol("fruit").setOutputCol("document")

tokenizer = Tokenizer().setInputCols(["document"]).setOutputCol("tokens")

# Initialize the custom transformer
cat_label_transformer = CategoricalLabelTransformer(
    inputCol="fruit",
    outputCol="corrupted_fruit",
    fraction=0.2,
    categories=["apple", "banana", "orange"],
)

# Create the pipeline
pipeline = Pipeline(stages=[document_assembler, tokenizer, cat_label_transformer])

# Fit and transform the pipeline
pipeline_model = pipeline.fit(df)
transformed_df = pipeline_model.transform(df)

# Show the transformed DataFrame
transformed_df.show()

                                                                                

PySparkValueError: [CANNOT_DETERMINE_TYPE] Some of types cannot be determined after inferring.