In [71]:
import os
import sys

os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable

In [72]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode, split, collect_set
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType

class SequenceIndexer(Transformer, HasInputCol, HasOutputCol):
    
    def __init__(self, dictionary, maxLength=20, inputCol=None, outputCol=None):
        super(SequenceIndexer, self).__init__()
        self.dictionary = dictionary
        self.maxLength = maxLength
        self.broadcastDict = SparkSession.builder.getOrCreate().sparkContext.broadcast(self.dictionary)
        
        # Initialize Params
        self.inputCol = Param(self, "inputCol", "")
        self.outputCol = Param(self, "outputCol", "")
        
        # Set Params
        if inputCol:
            self._set(inputCol=inputCol)
        if outputCol:
            self._set(outputCol=outputCol)
    
    def _transform(self, dataset):
        brDict = self.broadcastDict.value
        input_col = self.getInputCol()
        output_col = self.getOutputCol()
        
        def map_sequence(seq):
            t = [brDict.get(token, 0) for token in seq]
            if len(t) >= self.maxLength:
                return t[:self.maxLength]
            else:
                return t + [len(brDict)] * (self.maxLength - len(t))
        
        map_udf = udf(map_sequence, ArrayType(IntegerType()))
        return dataset.withColumn(output_col, map_udf(col(input_col)))


In [73]:
# Khởi tạo SparkSession
spark = SparkSession.builder \
    .appName("SequenceIndexerExample") \
    .master("local[*]")\
    .config("spark.driver.memory", "16g") \
    .config("spark.executor.memory", "16g") \
    .config("spark.executor.cores", "4") \
    .config("spark.task.cpus", "2") \
    .config("spark.sql.shuffle.partitions", "8") \
    .getOrCreate()

# Đọc dữ liệu từ file parquet
df = spark.read.parquet("dat/sa.parquet").sample(fraction=0.1, seed=42)
df.head(10)

[Row(tweet='$GM: Deutsche Bank cuts to Hold https://t.co/7Fv1ZiFZBS', sentiment=2, url='https://huggingface.co/datasets/zeroshot/twitter-financial-news-sentiment'),
 Row(tweet='$MDCO: Oppenheimer cuts to Perform', sentiment=2, url='https://huggingface.co/datasets/zeroshot/twitter-financial-news-sentiment'),
 Row(tweet='$MSGN - Imperial downgrades MSG Networks amid sports-free airwaves https://t.co/Ul2S6XNXw8', sentiment=2, url='https://huggingface.co/datasets/zeroshot/twitter-financial-news-sentiment'),
 Row(tweet='Canada Goose stock price target cut to $50 from $65 at CFRA', sentiment=2, url='https://huggingface.co/datasets/zeroshot/twitter-financial-news-sentiment'),
 Row(tweet="Disney downgraded as analyst says parks attendance could take 2 years to 'normalize' https://t.co/InJKourtW3", sentiment=2, url='https://huggingface.co/datasets/zeroshot/twitter-financial-news-sentiment'),
 Row(tweet='Evercore downgrades bluebird bio and Clovis in premarket analyst action', sentiment=2, url='

In [74]:
# Tách các từ trong cột "tweet" thành các từ đơn lẻ và tạo dictionary
words_df = df.select(explode(split(col("tweet"), "\\s+")).alias("word"))
words_df.head(10)

[Row(word='$GM:'),
 Row(word='Deutsche'),
 Row(word='Bank'),
 Row(word='cuts'),
 Row(word='to'),
 Row(word='Hold'),
 Row(word='https://t.co/7Fv1ZiFZBS'),
 Row(word='$MDCO:'),
 Row(word='Oppenheimer'),
 Row(word='cuts')]

In [75]:
words_list = words_df.select("word").distinct().agg(collect_set("word")).collect()[0][0]
dictionary = {word: idx + 1 for idx, word in enumerate(words_list)}

In [76]:
df = df.withColumn("words", split(col("tweet"), "\\s+"))

# Sử dụng SequenceIndexer
maxLength = 20
indexer = SequenceIndexer(dictionary, maxLength, inputCol="words", outputCol="indexed")

# Áp dụng SequenceIndexer lên DataFrame
indexed_df = indexer.transform(df)

In [77]:
# Hiển thị kết quả để kiểm tra
indexed_df.select("tweet", "indexed").show(truncate=False)


+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------+
|tweet                                                                                                       |indexed                                                                                                                                    |
+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------+
|$GM: Deutsche Bank cuts to Hold https://t.co/7Fv1ZiFZBS                                                     |[863, 10450, 17038, 14660, 3801, 6134, 17433, 19514, 19514, 19514, 19514, 19514, 19514, 19514, 19514, 19514, 19514, 19514, 19514, 19514] 

In [78]:

# Dừng SparkSession
spark.stop()