In [0]:
import pyspark.sql.functions as F

import torch
from transformers import pipeline
from transformers.utils import is_flash_attn_2_available
import json



In [0]:
# Select the compute target once: first CUDA GPU with half precision when CUDA
# is available, otherwise CPU with full precision.
device, torch_dtype = (
    ("cuda:0", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
)

In [0]:
# Partition count passed to `repartition` below.
num_workers = 4

# Read the source table, spread it over `num_workers` partitions and mark the
# result for caching so later cells reuse it.
transcriptions_table = (
    spark.table("marcell.call_centre_processing.transcriptions_udf")
    .repartition(num_workers)
    .cache()
)

# NOTE(review): this renders raw (unredacted) rows in the notebook output.
transcriptions_table.display()

In [0]:
# Step 1: decode the JSON string in `transcription` into an array of
#         (text, timestamp) structs.
# Step 2: fold the `text` field of every array element, in order and without a
#         separator, into a single string column.
transcriptions_table = transcriptions_table.withColumn(
    "transcription_array",
    F.from_json("transcription", "array<struct<text:string,timestamp:double>>"),
).withColumn(
    "concatenated_text",
    F.expr("aggregate(transcription_array, '', (acc, x) -> acc || x.text)"),
)

In [0]:
# Preview the table, now including `transcription_array` and `concatenated_text`.
# NOTE(review): this renders unredacted transcript text in the notebook output —
# confirm that is acceptable before sharing or committing the notebook.
transcriptions_table.display()


In [0]:
import pandas as pd
from typing import Iterator

In [0]:
@F.pandas_udf("string")
def redact_udf(text_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Mask organisation, person and location entities in each input string.

    Iterator-of-Series pandas UDF: the NER pipeline is built once per
    invocation and reused for every batch the iterator yields.

    Args:
        text_iterator: Batches of strings to redact. Null or empty values are
            passed through unchanged.

    Yields:
        One Series per input batch, of the same length, in which the character
        span of every ORG/PER/LOC entity is replaced by the same number of
        ``*`` characters.
    """
    # Imported here so the names are resolved on the worker that runs the UDF.
    import torch
    from transformers import pipeline

    # Pick the device on the worker itself: a hard-coded "cuda:0" fails on
    # CPU-only executors.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # NOTE(review): no model is pinned, so the library default for "ner" is used;
    # long inputs may exceed that model's maximum sequence length — confirm.
    pipe = pipeline("ner", device=device)

    # Entity types to mask, compared without the B-/I- tagging prefix so that
    # entity-initial tokens (B-PER, B-ORG, B-LOC) are redacted as well.
    redacted_types = {"ORG", "PER", "LOC"}

    def redact_text(text: str) -> str:
        """Return `text` with every ORG/PER/LOC span overwritten by asterisks."""
        # Nulls arrive as non-str values; nothing to redact in them or in "".
        if not isinstance(text, str) or not text:
            return text
        for entity in pipe(text):
            entity_type = entity["entity"].rpartition("-")[2]
            if entity_type in redacted_types:
                start, end = entity["start"], entity["end"]
                # Same-length replacement keeps later entity offsets valid.
                text = text[:start] + "*" * (end - start) + text[end:]
        return text

    for text_chunk in text_iterator:
        yield text_chunk.apply(redact_text)

In [0]:

# Add the redacted text as a new column and overwrite the target table,
# replacing its schema as well.
# NOTE(review): the original text columns (`transcription`, `transcription_array`,
# `concatenated_text`) are still written alongside the redacted one — confirm
# this is intended for a table named "transcriptions_redacted".
(
    transcriptions_table
    .withColumn("redacted_transcription", redact_udf("concatenated_text"))
    .write
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable("marcell.call_centre_processing.transcriptions_redacted")
)