## Setting up a Spark environment in Google Colab

In [1]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q http://archive.apache.org/dist/spark/spark-3.1.1/spark-3.1.1-bin-hadoop3.2.tgz
!tar xf spark-3.1.1-bin-hadoop3.2.tgz
!pip install -q findspark

In [2]:
!pip install spark-nlp==5.3.0

Collecting spark-nlp==5.3.0
  Downloading spark_nlp-5.3.0-py2.py3-none-any.whl (564 kB)
Installing collected packages: spark-nlp
Successfully installed spark-nlp-5.3.0


In [3]:
!pip install pyspark==3.1.1

Collecting pyspark==3.1.1
  Downloading pyspark-3.1.1.tar.gz (212.3 MB)
  Preparing metadata (setup.py) ... done
Collecting py4j==0.10.9 (from pyspark==3.1.1)
  Downloading py4j-0.10.9-py2.py3-none-any.whl (198 kB)
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... done
  Created wheel for pyspark: filename=pyspark-3.1.1-py2.py3-none-any.whl size=212767583 sha256=2c7e03bc40d65a842932824d5558775e3eb2ea21c1f0952512f3ce0d407e9410
  Stored in directory: /root/.cache/pip/wheels/a0/3f/72/8efd988f9ae041f051c75e6834cd92dd6d13a726e206e8b6f3
Successfully built pyspark
Installing collected packages: py4j, pyspark
  Attempting uninstall: py4j
    Found existing installation: py4j 0

In [4]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.1.1-bin-hadoop3.2"

## `Text processing. Spark NLP`

Spark ML can be used for text processing, but its toolset is limited. In particular, it has no pretrained models, including neural networks.

**Spark NLP** provides a large number of transformers and estimators for working with text, including pretrained neural networks.

A text-processing pipeline may look like this:
1. Load the data into a Spark DataFrame
2. Convert it into the data type **Spark NLP** expects (Document)
3. Apply **Spark NLP** transformations (tokenization, stemming, embeddings, model inference)
4. Convert the results back to "regular" data types
5. Apply Spark ML
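To make the five steps concrete, here is a toy sketch in plain Python that pushes one sentence through the same stages. The helpers `to_document`, `tokenize`, `finish` and `bag_of_words` are hypothetical stand-ins written for illustration only. They are not Spark NLP or Spark ML APIs, and the annotation dicts merely imitate the structure Spark NLP uses.

```python
from collections import Counter

# Hypothetical, plain-Python stand-ins for the five pipeline steps.
# They only imitate the data flow; they are not Spark NLP / Spark ML APIs.


def to_document(text):
    # Step 2: wrap raw text into one "document" annotation (end is inclusive)
    return [{"annotatorType": "document", "begin": 0, "end": len(text) - 1,
             "result": text, "metadata": {}}]


def tokenize(documents):
    # Step 3: split every document on whitespace into "token" annotations
    tokens = []
    for doc in documents:
        text, cursor = doc["result"], 0
        for word in text.split():
            start = text.index(word, cursor)
            tokens.append({"annotatorType": "token",
                           "begin": doc["begin"] + start,
                           "end": doc["begin"] + start + len(word) - 1,
                           "result": word, "metadata": {}})
            cursor = start + len(word)
    return tokens


def finish(annotations):
    # Step 4: back to ordinary values, keeping only the `result` field
    return [ann["result"] for ann in annotations]


def bag_of_words(words):
    # Step 5: a stand-in for a Spark ML feature step such as CountVectorizer
    return Counter(word.lower() for word in words)


rows = ["We are very happy about SparkNLP"]  # Step 1: the "DataFrame"
tokens = [tokenize(to_document(text)) for text in rows]
features = [bag_of_words(finish(toks)) for toks in tokens]
print([(t["result"], t["begin"], t["end"]) for t in tokens[0]])
print(features[0])
```

The token offsets this toy produces (`are` at 3..5, `happy` at 12..16) coincide with the Spark NLP token output for the same sentence shown later in this notebook.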

One advantage is highly efficient inference when many executors are available.

For example, with the right configuration, Spark NLP running on a CPU cluster can be about $5$ times faster than a single GPU:

```text
https://github.com/JohnSnowLabs/spark-nlp/issues/570

For BertEmbeddings:
On a local server with 50 cores and 180G memory:
1. 14k sentences take more than 16000 seconds on CPU (27 minutes)
2. 14K sentences take around 500 seconds on Tesla P100 GPU (less than 10)
3. 17K sentences take around 120 seconds on a CPU-based Spark cluster with 10 executors each 5 cores!
This has been tested on a Bert Model with 256 max sentence length and second-to-last-hidden layer which is way slower due to encoding from sentence's context.

Two things are very important, first is that Bert is GPU optimized not CPU. The second is, we distribute TensorFlow over Spark so this boost parallelism into prediction which as you can see it beats a single GPU. (now if you have GPU Spark cluster then this would be flying)
```
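The "about 5 times" figure can be checked against the quote. Because the GPU run and the cluster run processed different numbers of sentences (14K vs 17K), the fair comparison is throughput, not elapsed time. The CPU-only line is deliberately left out: as quoted, 16000 seconds is roughly 4.4 hours rather than 27 minutes, so that figure is ambiguous.

```python
# Figures from the quoted GitHub issue: (sentences processed, seconds taken)
gpu_sentences, gpu_seconds = 14_000, 500            # single Tesla P100 GPU
cluster_sentences, cluster_seconds = 17_000, 120    # 10 executors x 5 cores

# Compare throughput, since the two runs used different numbers of sentences
gpu_rate = gpu_sentences / gpu_seconds              # sentences per second
cluster_rate = cluster_sentences / cluster_seconds
speedup = cluster_rate / gpu_rate

print(f"GPU: {gpu_rate:.1f} sent/s, cluster: {cluster_rate:.1f} sent/s, "
      f"speedup: {speedup:.2f}x")
# GPU: 28.0 sent/s, cluster: 141.7 sent/s, speedup: 5.06x
```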

One drawback is that model training happens only on the driver:
```text
https://github.com/JohnSnowLabs/spark-nlp/issues/9266

* Unfortunately, at the moment the trainable annotators can only use Driver and cannot scale (CPU or GPU).
* Training in Spark NLP happens inside Driver and only on 1 GPU
* Prediction/inference also uses 1 GPU device, however, if you are in a cluster mode (multiple executors) then each machine can have 1 GPU device and that way you can distribute and parallelize the computation over multiple GPU
* Currently, due to TensorFlow limitations especially available APIs in Java only 1 GPU per machine is possible
```

In [5]:
import pyspark.sql.types as T
import pyspark.sql.functions as F
from pyspark.sql.window import Window

from pyspark.sql import SparkSession
from pyspark import SparkConf, SparkContext

conf = (
    SparkConf()
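        .set('spark.ui.port', '4050')                      # port of the Spark web UI
        .set('spark.driver.memory', '15g')                 # driver memory (Spark NLP trains on the driver)
        .set("spark.kryoserializer.buffer.max", "2000M")   # larger Kryo buffer for big Spark NLP objects
        # Packages required by Spark NLP; the version should match the Python package (5.3.0)
        .set("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.3.0")
        .setMaster('local[*]')                             # local mode using all available cores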
)
sc = SparkContext(conf=conf)
spark = SparkSession(sc)

### `Applying a pretrained pipeline`


In [6]:
import sparknlp

from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

Pretrained pipelines can be applied to individual examples:

In [7]:
from sparknlp.pretrained import PretrainedPipeline

explain_document_pipeline = PretrainedPipeline("explain_document_ml")
annotations = explain_document_pipeline.annotate("We are very happy about SparkNLP")
annotations

explain_document_ml download started this may take some time.
Approx size to download 9 MB
[OK!]


{'document': ['We are very happy about SparkNLP'],
 'spell': ['We', 'are', 'very', 'happy', 'about', 'SparkNLP'],
 'pos': ['PRP', 'VBP', 'RB', 'JJ', 'IN', 'NNP'],
 'lemmas': ['We', 'be', 'very', 'happy', 'about', 'SparkNLP'],
 'token': ['We', 'are', 'very', 'happy', 'about', 'SparkNLP'],
 'stems': ['we', 'ar', 'veri', 'happi', 'about', 'sparknlp'],
 'sentence': ['We are very happy about SparkNLP']}
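`annotate` returns an ordinary Python dict that maps each output column to a list of strings. In this pipeline the `token`, `lemmas` and `pos` lists have one entry per token (an observation from the output above, not a guarantee for every pipeline), so they can be zipped into aligned triples. The values below are copied from that output.

```python
# Copied from the `annotate` output above (only three of its keys)
annotations_example = {
    'token': ['We', 'are', 'very', 'happy', 'about', 'SparkNLP'],
    'lemmas': ['We', 'be', 'very', 'happy', 'about', 'SparkNLP'],
    'pos': ['PRP', 'VBP', 'RB', 'JJ', 'IN', 'NNP'],
}

# Align the lists position by position: (token, lemma, POS tag)
triples = list(zip(annotations_example['token'],
                   annotations_example['lemmas'],
                   annotations_example['pos']))
for token, lemma, tag in triples:
    print(f"{token:10} {lemma:10} {tag}")
```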

They can also be applied to distributed data (a Spark DataFrame):



In [8]:
sample_df = spark.createDataFrame([
       ("We are very happy about SparkNLP", )
], ['text'])
sample_df.show()
explain_document_pipeline.transform(sample_df).toPandas()

+--------------------+
|                text|
+--------------------+
|We are very happy...|
+--------------------+



|   | text | document | sentence | token | spell | lemmas | stems | pos |
|---|------|----------|----------|-------|-------|--------|-------|-----|
| 0 | We are very happy about SparkNLP | [(document, 0, 31, We are very happy about Spa... | [(document, 0, 31, We are very happy about Spa... | [(token, 0, 1, We, {'sentence': '0'}, []), (to... | [(token, 0, 1, We, {'sentence': '0', 'confiden... | [(token, 0, 1, We, {'sentence': '0', 'confiden... | [(token, 0, 1, we, {'sentence': '0', 'confiden... | [(pos, 0, 1, PRP, {'sentence': '0', 'word': 'W... |


The library reuses a set of concepts from Spark ML:
* **Annotator Approaches**: the analog of a `Spark ML Estimator`, i.e. a model that can be trained
* **Annotator Models**: the analog of a `Spark ML Transformer`, i.e. a trained model that adds columns to a DataFrame
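The Approach/Model split follows the familiar fit/transform contract. The toy classes below, `VocabularyApproach` and `VocabularyModel`, are hypothetical and exist only to illustrate it: `fit` on the Approach "trains" something and returns a Model, and the Model's `transform` adds a new column while leaving the existing ones untouched.

```python
class VocabularyModel:
    """Plays the role of an Annotator Model / Spark ML Transformer."""

    def __init__(self, vocabulary):
        self.vocabulary = vocabulary

    def transform(self, rows):
        # Return new rows with an extra "known" column; input rows are unchanged
        return [dict(row, known=[t for t in row["tokens"] if t in self.vocabulary])
                for row in rows]


class VocabularyApproach:
    """Plays the role of an Annotator Approach / Spark ML Estimator."""

    def __init__(self, min_count=1):
        self.min_count = min_count

    def fit(self, rows):
        # "Training" here is just counting tokens; the result is a fitted model
        counts = {}
        for row in rows:
            for token in row["tokens"]:
                counts[token] = counts.get(token, 0) + 1
        vocabulary = {t for t, c in counts.items() if c >= self.min_count}
        return VocabularyModel(vocabulary)


toy_rows = [{"tokens": ["spark", "nlp"]}, {"tokens": ["spark", "ml"]}]
vocab_model = VocabularyApproach(min_count=2).fit(toy_rows)
transformed = vocab_model.transform([{"tokens": ["spark", "sql"]}])
print(transformed)  # [{'tokens': ['spark', 'sql'], 'known': ['spark']}]
```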

You can also inspect which processing stages the pipeline contains:

In [9]:
explain_document_pipeline.model.stages

[document_811d40a38b24,
 SENTENCE_ce56851acebe,
 REGEX_TOKENIZER_282781ab961b,
 SPELL_79c88338ef12,
 LEMMATIZER_c62ad8f355f9,
 STEMMER_caf11d1f4d0e,
 POS_dbb704204f6f]


Spark NLP works with columns in a special format, the **annotated text format**. Such a column is a list of **annotations**; each annotation corresponds to a substring of the source text (fields `begin` and `end`) and stores the type of that segment (`annotatorType`), the annotation value itself (`result`) and metadata (`metadata`). As the schema below shows, each annotation also has an `embeddings` field.



In [10]:
sentences = [
  ['Hello, this is an example sentence'],
  ['And this is a second sentence.']
]

data = spark.createDataFrame(sentences).toDF("text")
explained_df = explain_document_pipeline.transform(data)
explained_df.printSchema()

root
 |-- text: string (nullable = true)
 |-- document: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |    |    |-- embeddings: array (nullable = true)
 |    |    |    |-- element: float (containsNull = false)
 |-- sentence: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true

Let's look at a few examples of such annotated columns.

The first step when working with Spark NLP is converting raw text into a document.

The document consists of a single annotation:

In [11]:
[document] = explained_df.select('document').take(1)
document

Row(document=[Row(annotatorType='document', begin=0, end=33, result='Hello, this is an example sentence', metadata={'sentence': '0'}, embeddings=[])])

Applying other models turns one annotation column into another:

In [12]:
[tokens] = explained_df.select('token').take(1)
[pos_tags] = explained_df.select('pos').take(1)

In [13]:
tokens.token

[Row(annotatorType='token', begin=0, end=4, result='Hello', metadata={'sentence': '0'}, embeddings=[]),
 Row(annotatorType='token', begin=5, end=5, result=',', metadata={'sentence': '0'}, embeddings=[]),
 Row(annotatorType='token', begin=7, end=10, result='this', metadata={'sentence': '0'}, embeddings=[]),
 Row(annotatorType='token', begin=12, end=13, result='is', metadata={'sentence': '0'}, embeddings=[]),
 Row(annotatorType='token', begin=15, end=16, result='an', metadata={'sentence': '0'}, embeddings=[]),
 Row(annotatorType='token', begin=18, end=24, result='example', metadata={'sentence': '0'}, embeddings=[]),
 Row(annotatorType='token', begin=26, end=33, result='sentence', metadata={'sentence': '0'}, embeddings=[])]
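The offsets above show that `end` is inclusive: `Hello` has `begin=0, end=4`, so in Python terms the covered substring is `text[begin:end + 1]`. The dataclass below mirrors the annotation struct from the schema; it is our own illustration, not a Spark NLP class, and the token values are copied from the output above.

```python
from dataclasses import dataclass, field


@dataclass
class Annotation:
    # Same fields as the annotation struct in the printed schema
    annotatorType: str
    begin: int
    end: int  # inclusive index of the last character
    result: str
    metadata: dict = field(default_factory=dict)
    embeddings: list = field(default_factory=list)

    def span(self, text):
        # `end` is inclusive, hence +1 for Python slicing
        return text[self.begin:self.end + 1]


sample_text = "Hello, this is an example sentence"
token_annotations = [  # a few tokens from the output above
    Annotation("token", 0, 4, "Hello", {"sentence": "0"}),
    Annotation("token", 5, 5, ",", {"sentence": "0"}),
    Annotation("token", 7, 10, "this", {"sentence": "0"}),
    Annotation("token", 26, 33, "sentence", {"sentence": "0"}),
]
print([ann.span(sample_text) for ann in token_annotations])
# ['Hello', ',', 'this', 'sentence']
```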

In [14]:
pos_tags.pos

[Row(annotatorType='pos', begin=0, end=4, result='UH', metadata={'sentence': '0', 'word': 'Hello'}, embeddings=[]),
 Row(annotatorType='pos', begin=5, end=5, result=',', metadata={'sentence': '0', 'word': ','}, embeddings=[]),
 Row(annotatorType='pos', begin=7, end=10, result='DT', metadata={'sentence': '0', 'word': 'this'}, embeddings=[]),
 Row(annotatorType='pos', begin=12, end=13, result='VBZ', metadata={'sentence': '0', 'word': 'is'}, embeddings=[]),
 Row(annotatorType='pos', begin=15, end=16, result='DT', metadata={'sentence': '0', 'word': 'an'}, embeddings=[]),
 Row(annotatorType='pos', begin=18, end=24, result='NN', metadata={'sentence': '0', 'word': 'example'}, embeddings=[]),
 Row(annotatorType='pos', begin=26, end=33, result='NN', metadata={'sentence': '0', 'word': 'sentence'}, embeddings=[])]

In [15]:
explain_document_pipeline.transform(sample_df).select("token").show(truncate=False)

+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|token                                                                                                                                                                                                                                                                 |
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|[{token, 0, 1, We, {sentence -> 0}, []}, {token, 3, 5, are, {sentence -> 0}, []}, {token, 7, 10, very, {sentence -> 0}, []}, {token, 12, 16, happy, {sentence -> 0}, []}, {token, 18, 22, about, {sentence -

Columns printed as annotations are hard to read during analysis. For convenient viewing, use the `sparknlp.Finisher` transformer: it keeps only the `result` field of each annotation and returns the values as a list:

In [16]:
finisher = sparknlp.Finisher().setInputCols(["token", "lemmas", "pos"])

pipeline = (
    Pipeline()
        .setStages([
            explain_document_pipeline.model,
            finisher
        ])
)

model = pipeline.fit(data)

annotations_finished_df = model.transform(data)
annotations_finished_df.toPandas()

|   | text | finished_token | finished_lemmas | finished_pos |
|---|------|----------------|-----------------|--------------|
| 0 | Hello, this is an example sentence | [Hello, ,, this, is, an, example, sentence] | [Hello, ,, this, be, an, example, sentence] | [UH, ,, DT, VBZ, DT, NN, NN] |
| 1 | And this is a second sentence. | [And, this, is, a, second, sentence, .] | [And, this, be, a, second, sentence, .] | [CC, DT, VBZ, DT, JJ, NN, .] |
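The effect of `Finisher` with its default settings, as seen in the output above, can be imitated in plain Python: the annotation columns disappear, and each requested column becomes a `finished_<column>` list of `result` strings. This is a sketch of the observed behaviour, not the library's code; `is_annotation_column` and `finish_row` are hypothetical helpers, and the input row is hand-built from the second example sentence.

```python
def is_annotation_column(value):
    # Annotation columns are lists of structs that carry an `annotatorType` field
    return isinstance(value, list) and all(
        isinstance(item, dict) and "annotatorType" in item for item in value
    )


def finish_row(row, input_cols, prefix="finished_"):
    # Keep ordinary columns, drop every annotation column
    finished = {k: v for k, v in row.items() if not is_annotation_column(v)}
    # Requested columns become plain lists of `result` strings
    for col in input_cols:
        finished[prefix + col] = [ann["result"] for ann in row[col]]
    return finished


spans = [(0, 2, "And"), (4, 7, "this"), (9, 10, "is"), (12, 12, "a"),
         (14, 19, "second"), (21, 28, "sentence"), (29, 29, ".")]
tags = ["CC", "DT", "VBZ", "DT", "JJ", "NN", "."]

row = {
    "text": "And this is a second sentence.",
    "document": [{"annotatorType": "document", "begin": 0, "end": 29,
                  "result": "And this is a second sentence."}],
    "token": [{"annotatorType": "token", "begin": b, "end": e, "result": w}
              for b, e, w in spans],
    "pos": [{"annotatorType": "pos", "begin": b, "end": e, "result": t}
            for (b, e, _), t in zip(spans, tags)],
}

finished = finish_row(row, ["token", "pos"])
print(finished)
```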


### `Training a classifier`

Let's use Spark NLP to train a model for a multiclass classification task: predicting the category of a news article:

In [17]:
base_url = 'https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category'
! wget -O news_category_test.csv $base_url/news_category_test.csv
! wget -O news_category_train.csv $base_url/news_category_train.csv

--2024-03-02 06:52:59--  https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category/news_category_test.csv
Resolving s3.amazonaws.com (s3.amazonaws.com)... 16.182.97.56, 52.217.107.102, 52.216.133.45, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|16.182.97.56|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1504408 (1.4M) [text/csv]
Saving to: ‘news_category_test.csv’


2024-03-02 06:52:59 (3.83 MB/s) - ‘news_category_test.csv’ saved [1504408/1504408]

--2024-03-02 06:52:59--  https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category/news_category_train.csv
Resolving s3.amazonaws.com (s3.amazonaws.com)... 16.182.97.56, 52.217.107.102, 52.216.133.45, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|16.182.97.56|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 24032125 (23M) [text/csv]
Saving to: ‘news_category_train.csv’


2024-03-02 06:53:01 (24.

In [18]:
! head news_category_train.csv

category,description
Business," Short sellers, Wall Street's dwindling band of ultra cynics, are seeing green again."
Business," Private investment firm Carlyle Group, which has a reputation for making well timed and occasionally controversial plays in the defense industry, has quietly placed its bets on another part of the market."
Business, Soaring crude prices plus worries about the economy and the outlook for earnings are expected to hang over the stock market next week during the depth of the summer doldrums.
Business," Authorities have halted oil export flows from the main pipeline in southern Iraq after intelligence showed a rebel militia could strike infrastructure, an oil official said on Saturday."
Business," Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections."
Business," Stocks ended slightly higher on Friday but stayed near lows for the year as oil prices surged past  #36;

Let's load the downloaded data into Spark DataFrames:

In [19]:
train_df = (
    spark.read
      .option("header", True)
      .csv("news_category_train.csv")
)
test_df = (
    spark.read
      .option("header", True)
      .csv("news_category_test.csv")
)

display(train_df.limit(5).toPandas())
display(test_df.limit(5).toPandas())

train_df.count(), test_df.count()

|   | category | description |
|---|----------|-------------|
| 0 | Business | Short sellers, Wall Street's dwindling band o... |
| 1 | Business | Private investment firm Carlyle Group, which ... |
| 2 | Business | Soaring crude prices plus worries about the e... |
| 3 | Business | Authorities have halted oil export flows from... |
| 4 | Business | Tearaway world oil prices, toppling records a... |

|   | category | description |
|---|----------|-------------|
| 0 | Business | Unions representing workers at Turner Newall... |
| 1 | Sci/Tech | TORONTO, Canada A second team of rocketeer... |
| 2 | Sci/Tech | A company founded by a chemistry researcher a... |
| 3 | Sci/Tech | It's barely dawn when Mike Fitzpatrick starts... |
| 4 | Sci/Tech | Southern California's smog fighting agency we... |

(120000, 7600)

Let's build the following pipeline:
1. Convert the **RAW** (source) text into the annotated format (a document)
2. Compute sentence embeddings
3. Train the classifier

In [23]:
# Convert the raw text into the annotated format (a document)
document = (
    sparknlp.base.DocumentAssembler()
        .setInputCol("description")
        .setOutputCol("document")
)

# Compute sentence embeddings with a pretrained Universal Sentence Encoder
use = (
    sparknlp.annotator.UniversalSentenceEncoder.pretrained()
        .setInputCols(["document"])
        .setOutputCol("sentence_embeddings")
)

# The classifier itself: an Approach (estimator) trained on the embeddings
classifierdl = (
    sparknlp.annotator.ClassifierDLApproach()
        .setInputCols(["sentence_embeddings"])
        .setOutputCol("class")
        .setLabelColumn("category")
        .setMaxEpochs(5)
        .setEnableOutputLogs(True)
)

dl_pipeline = Pipeline(stages=[
    document,
    use,
    classifierdl
])

tfhub_use download started this may take some time.
Approximate size to download 923.7 MB
[OK!]


In [24]:
# Train on a random 10% of the training data to keep training time short
train_subsample = train_df.sample(0.1)
dl_pipeline_model = dl_pipeline.fit(train_subsample)

Trained models can be saved and loaded:

In [25]:
# Save the trained classifier (the last stage of the fitted pipeline)
dl_pipeline_model.stages[-1].write().overwrite().save('./tmp_classifierDL_model')

In [26]:
# Load it back
classifierdl_loaded = (
    sparknlp.annotator.ClassifierDLModel.load("./tmp_classifierDL_model")
      .setInputCols(["sentence_embeddings"])
      .setOutputCol("class")
)

dl_pipeline_eval = Pipeline(stages=[
    document,
    use,
    classifierdl_loaded
])

Let's apply the resulting model to test data:

In [28]:
test_df_sample = spark.createDataFrame([
    "Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.",
    "Scientists have discovered irregular lumps beneath the icy surface of Jupiter's largest moon, Ganymede. These irregular masses may be rock formations, supported by Ganymede's icy shell for billions of years..."
], T.StringType()).toDF("description")

# fit() trains nothing here: every stage is already a transformer or a fitted model
prediction = dl_pipeline_eval.fit(test_df_sample).transform(test_df_sample)

prediction.select('class.result').show()
prediction.select('class.metadata').show(truncate=False)

+----------+
|    result|
+----------+
|[Business]|
|[Sci/Tech]|
+----------+

+----------------------------------------------------------------------------------------------------------------+
|metadata                                                                                                        |
+----------------------------------------------------------------------------------------------------------------+
|[{Sports -> 3.989106E-8, Business -> 0.9999995, World -> 2.3197624E-7, Sci/Tech -> 2.3028691E-7, sentence -> 0}]|
|[{Sports -> 2.2062956E-7, Business -> 7.281296E-8, World -> 6.2745954E-8, Sci/Tech -> 0.9999995, sentence -> 0}]|
+----------------------------------------------------------------------------------------------------------------+



In [29]:
predictions = dl_pipeline_eval.fit(test_df).transform(test_df)

predictions\
.select('category', 'description', 'class.result')\
.sample(0.01)\
.show(10, truncate=50)

+--------+--------------------------------------------------+----------+
|category|                                       description|    result|
+--------+--------------------------------------------------+----------+
|  Sports|Michael Phelps won the gold medal in the 400 in...|  [Sports]|
|  Sports|The Cleveland Indians pulled within one game of...|  [Sports]|
|Business| The dollar extended gains against the  euro on...|[Business]|
|  Sports| Carly Patterson upstaged Russian diva  Svetlan...|  [Sports]|
|  Sports|IT WAS the night of the longest race and the sh...|  [Sports]|
|Sci/Tech|One thing that #39;s always irritated those who...|[Sci/Tech]|
|Sci/Tech|Hawaii #39;s Keck Observatory has confirmed the...|[Sci/Tech]|
|   World| A group calling itself the Secret Islamic Army...|   [World]|
|  Sports| Sweden's Fredrik Jacobson made his bid  for a ...|  [Sports]|
|  Sports| Troy Glaus was activated from the 60 day disab...|  [Sports]|
+--------+-----------------------------------------

### `Evaluating quality`

Let's evaluate the resulting model by computing accuracy.

This can be done in two ways:
1. Compute accuracy manually
2. Use `MulticlassClassificationEvaluator`

Let's do both.

For convenience, we keep only the columns we need and also transform the prediction column: `class.result` is a list with a single element, and `F.explode` unpacks it. Alternatively, you can use `F.element_at(F.col('class.result'), 1)` (indices in `element_at` start at 1):

In [30]:
predictions = predictions.select('category', 'description', F.explode('class.result').alias('prediction'))
predictions.show(10)

+--------+--------------------+----------+
|category|         description|prediction|
+--------+--------------------+----------+
|Business|Unions representi...|  Business|
|Sci/Tech| TORONTO, Canada ...|  Sci/Tech|
|Sci/Tech| A company founde...|  Sci/Tech|
|Sci/Tech| It's barely dawn...|    Sports|
|Sci/Tech| Southern Califor...|  Business|
|Sci/Tech|"The British Depa...|  Sci/Tech|
|Sci/Tech|"confessed author...|  Sci/Tech|
|Sci/Tech|\\FOAF/LOAF  and ...|  Sci/Tech|
|Sci/Tech|"Wiltshire Police...|  Sci/Tech|
|Sci/Tech|In its first two ...|  Sci/Tech|
+--------+--------------------+----------+
only showing top 10 rows
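`F.explode` and `F.element_at(..., 1)` agree only while every prediction list holds exactly one element. The plain-Python imitation below (not Spark code) shows how they differ otherwise; the second and third rows are hypothetical edge cases, and the `None` for an empty list assumes Spark's default non-ANSI mode, in which `element_at` returns NULL for an out-of-range index.

```python
# Each row: (true category, list of predicted labels), imitating `class.result`
toy_predictions = [
    ("Business", ["Business"]),          # the usual case: one prediction
    ("World", []),                       # hypothetical: no prediction at all
    ("Sports", ["Sports", "World"]),     # hypothetical: two predictions
]

# F.explode: one output row per element; rows with an empty list disappear
exploded = [(label, pred) for label, preds in toy_predictions for pred in preds]

# F.element_at(col, 1): one output row per input row, 1-based index,
# None (NULL) when the list is too short
first_elements = [(label, preds[0] if preds else None) for label, preds in toy_predictions]

print(exploded)        # [('Business', 'Business'), ('Sports', 'Sports'), ('Sports', 'World')]
print(first_elements)  # [('Business', 'Business'), ('World', None), ('Sports', 'Sports')]
```

With `explode`, an empty list silently removes a test row and a multi-element list duplicates it, either of which would skew the accuracy computed next.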



In [31]:
(
    predictions
        .select(F.mean(
            # Averaging requires casting the boolean comparison to a numeric type
            (F.col('category') == F.col('prediction')).cast(T.FloatType())
        ).alias('accuracy'))
).show()

+------------------+
|          accuracy|
+------------------+
|0.8793421052631579|
+------------------+
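The Spark expression above is the mean of a 0/1 "prediction is correct" indicator. The same computation in plain Python, applied to the ten `(category, prediction)` pairs printed earlier, gives 0.8 for that small slice; it only illustrates the formula, and the full test-set accuracy is the value shown above.

```python
# (category, prediction) pairs copied from the first 10 rows shown above
pairs = [
    ("Business", "Business"), ("Sci/Tech", "Sci/Tech"), ("Sci/Tech", "Sci/Tech"),
    ("Sci/Tech", "Sports"), ("Sci/Tech", "Business"), ("Sci/Tech", "Sci/Tech"),
    ("Sci/Tech", "Sci/Tech"), ("Sci/Tech", "Sci/Tech"), ("Sci/Tech", "Sci/Tech"),
    ("Sci/Tech", "Sci/Tech"),
]

# Accuracy = mean of the 0/1 indicator, as F.mean(...cast(FloatType())) does in Spark
correct = [float(label == pred) for label, pred in pairs]
accuracy = sum(correct) / len(correct)
print(accuracy)  # 0.8
```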



### `Second way to compute the metric`

`MulticlassClassificationEvaluator` expects numeric (double) label and prediction columns, so category names must first be encoded as indices. One way to do this is a `join` with a table that stores the mapping.

Such a table can be built in several ways, for example:

In [32]:
# A more involved option, useful when there are many classes and the encoding must be deterministic (thanks to sorting)
# First deduplicate the categories, then number them from 1 in sorted order
# (a window without partitionBy moves all rows to one partition; fine for a handful of labels)
wspec = Window().partitionBy().orderBy('category')
label_to_idx = (
    train_df
        .select('category')
        .distinct()
        .select(F.col('category').alias('label'), F.row_number().over(wspec).alias('idx'))
)
label_to_idx.show()

+--------+---+
|   label|idx|
+--------+---+
|Business|  1|
|Sci/Tech|  2|
|  Sports|  3|
|   World|  4|
+--------+---+
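In plain Python the same deterministic mapping is just "deduplicate, sort, number from 1". The input list below is a hypothetical sample; for ASCII labels like these, Python's string sort gives the same order as the `orderBy` above.

```python
# Hypothetical sample of category values (with duplicates, in arbitrary order)
categories = ["Sports", "Business", "World", "Sci/Tech", "Business", "Sports"]

# Deduplicate, sort, number from 1 (like F.row_number() over an ordered window)
label_to_idx_map = {label: idx for idx, label in enumerate(sorted(set(categories)), start=1)}
print(label_to_idx_map)
# {'Business': 1, 'Sci/Tech': 2, 'Sports': 3, 'World': 4}
```

Because the numbering depends only on the set of labels, rerunning it on reshuffled data yields the same indices.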



In [33]:
predictions_with_idx = (
    predictions
        .join(
            F.broadcast(label_to_idx).withColumnRenamed('idx', 'category_idx'),
            on=predictions.category == label_to_idx.label
        )
        .drop('label')
        .join(
            F.broadcast(label_to_idx).withColumnRenamed('idx', 'prediction_idx'),
            on=F.col('prediction') == label_to_idx.label
        )
        .drop('label')
)
predictions_with_idx.show(10)

+--------+--------------------+----------+------------+--------------+
|category|         description|prediction|category_idx|prediction_idx|
+--------+--------------------+----------+------------+--------------+
|Business|Unions representi...|  Business|           1|             1|
|Sci/Tech| TORONTO, Canada ...|  Sci/Tech|           2|             2|
|Sci/Tech| A company founde...|  Sci/Tech|           2|             2|
|Sci/Tech| It's barely dawn...|    Sports|           2|             3|
|Sci/Tech| Southern Califor...|  Business|           2|             1|
|Sci/Tech|"The British Depa...|  Sci/Tech|           2|             2|
|Sci/Tech|"confessed author...|  Sci/Tech|           2|             2|
|Sci/Tech|\\FOAF/LOAF  and ...|  Sci/Tech|           2|             2|
|Sci/Tech|"Wiltshire Police...|  Sci/Tech|           2|             2|
|Sci/Tech|In its first two ...|  Sci/Tech|           2|             2|
+--------+--------------------+----------+------------+--------------+
only s

In [34]:
evaluator = MulticlassClassificationEvaluator(
    predictionCol='prediction_idx', labelCol='category_idx', metricName='accuracy'
)
evaluator.evaluate((
    predictions_with_idx
        .select(
            F.col('category_idx').cast(T.DoubleType()),
            F.col('prediction_idx').cast(T.DoubleType())
        )
))

0.8793421052631579

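Another way to encode strings, one that fits the pipeline style better, is `StringIndexer`. Its indices start at 0, so they differ from the table above, but accuracy is unaffected as long as labels and predictions share one mapping.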

In [35]:
from pyspark.ml.feature import StringIndexer

# Fit the indexer on the training labels (the input column must be named 'label')
indexer = StringIndexer(inputCol='label', outputCol='idx')
indexer_model = indexer.fit(train_df.withColumnRenamed('category', 'label'))

# Encode the true category: rename it to 'label', transform, then restore the names
predictions_with_idx = indexer_model.transform(
    predictions.withColumnRenamed('category', 'label')
).withColumnRenamed('idx', 'category_idx').withColumnRenamed('label', 'category')

# Encode the prediction with the same fitted indexer, so both columns share one mapping
predictions_with_idx = indexer_model.transform(
    predictions_with_idx.withColumnRenamed('prediction', 'label')
).withColumnRenamed('idx', 'prediction_idx').withColumnRenamed('label', 'prediction')
predictions_with_idx.show(10)

+--------+--------------------+----------+------------+--------------+
|category|         description|prediction|category_idx|prediction_idx|
+--------+--------------------+----------+------------+--------------+
|Business|Unions representi...|  Business|         0.0|           0.0|
|Sci/Tech| TORONTO, Canada ...|  Sci/Tech|         1.0|           1.0|
|Sci/Tech| A company founde...|  Sci/Tech|         1.0|           1.0|
|Sci/Tech| It's barely dawn...|    Sports|         1.0|           2.0|
|Sci/Tech| Southern Califor...|  Business|         1.0|           0.0|
|Sci/Tech|"The British Depa...|  Sci/Tech|         1.0|           1.0|
|Sci/Tech|"confessed author...|  Sci/Tech|         1.0|           1.0|
|Sci/Tech|\\FOAF/LOAF  and ...|  Sci/Tech|         1.0|           1.0|
|Sci/Tech|"Wiltshire Police...|  Sci/Tech|         1.0|           1.0|
|Sci/Tech|In its first two ...|  Sci/Tech|         1.0|           1.0|
+--------+--------------------+----------+------------+--------------+
only s

In [36]:
evaluator.evaluate(predictions_with_idx)

0.8793421052631579
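With the default `stringOrderType='frequencyDesc'`, `StringIndexer` gives index 0 to the most frequent label, and the Spark documentation states that ties are broken alphabetically. The alphabetical indices in the output above therefore presumably reflect equally frequent classes in the training set (an assumption not checked in this notebook). The helper `string_indexer_labels` below is a hypothetical plain-Python imitation of that ordering rule, not Spark's implementation.

```python
from collections import Counter


def string_indexer_labels(values, order="frequencyDesc"):
    # Imitates how StringIndexer assigns indices for two stringOrderType options
    counts = Counter(values)
    if order == "frequencyDesc":
        # Most frequent first; ties broken alphabetically
        labels = sorted(counts, key=lambda label: (-counts[label], label))
    elif order == "alphabetAsc":
        labels = sorted(counts)
    else:
        raise ValueError(f"unsupported order: {order}")
    # StringIndexer produces double-valued indices starting from 0
    return {label: float(idx) for idx, label in enumerate(labels)}


balanced = ["World", "Sports", "Business", "Sci/Tech"] * 3
print(string_indexer_labels(balanced))
# {'Business': 0.0, 'Sci/Tech': 1.0, 'Sports': 2.0, 'World': 3.0}

skewed = ["World"] * 3 + ["Sports"] * 2 + ["Business"]
print(string_indexer_labels(skewed))
# {'World': 0.0, 'Sports': 1.0, 'Business': 2.0}
```

If a stable alphabetical mapping is wanted regardless of class frequencies, `StringIndexer` accepts `stringOrderType='alphabetAsc'`.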