# Imports 

## Application-specific imports 

In [1]:
import sys

In [2]:
sys.path.append("../config/")
import config

In [3]:
sys.path.append("../metaflow/")
import preprocess_fn

## General-purpose libraries

In [4]:
import pickle
import itertools
import pandas as pd
import sklearn
from sklearn import preprocessing

In [5]:
import pyspark
import pyspark.sql.functions as fn
import pyspark.sql.types as t

# Load data from parquet

In [6]:
# Load the M20 card dataset from the artifacts directory configured in `config`.
df = spark.read.format("parquet").load(f'{config.ARTIFACTS}/dataset/M20_cards.parquet')

# Preprocess 

In [7]:
# Working copy of the loaded cards. No filtering is applied here yet; the
# name presumably anticipates a filter step being added at this point.
df_filtered = df

## Replace text with keywords based on a dictionary

In [8]:
# Drop any existing `text_features` column (e.g. from an earlier run of the
# next cell) so it is rebuilt from scratch below.
df_filtered = (
    df_filtered.drop("text_features")
    if "text_features" in df_filtered.columns
    else df_filtered
)

In [9]:
# Derive the keyword list for each card from its 'name' and 'originalText'
# columns via the project helper in preprocess_fn.
df_filtered = df_filtered.withColumn(
    'text_features',
    preprocess_fn.udf_text_to_keywords('name', 'originalText'),
)

## Collect the text features of every card into a single flat list

In [10]:
# Bring each card's `text_features` value to the driver (one entry per row).
all_text_feats = [row[0] for row in df_filtered.select("text_features").collect()]

In [11]:
# Keep only cards that produced at least one keyword. Truthiness skips both
# empty lists and null (None) arrays; the previous `len(items) > 0` would
# raise a TypeError on a null array.
filtered_text_feats = [items for items in all_text_feats if items]

In [12]:
# Flatten the per-card keyword lists into one list of keywords.
filtered_text_feats = [feat for items in filtered_text_feats for feat in items]

## Encode the text features into ints

In [13]:
# Fit an integer encoder over every keyword seen on any card. `fit` returns
# the encoder itself, so constructing first and fitting after is equivalent.
label_encoder = preprocessing.LabelEncoder()
label_encoder.fit(filtered_text_feats)

In [14]:
import pickle

In [15]:
# Persist the fitted encoder so the same keyword -> id mapping can be reloaded.
with open(f"{config.SPARK_MODELS}/labelencoder_text_feats", mode="wb") as fp:
    fp.write(pickle.dumps(label_encoder))

In [16]:
@fn.udf(returnType=t.ArrayType(t.IntegerType()))
def text_to_vector(text_features):
    """Encode one row's keyword list into integer ids using ``label_encoder``.

    Parameters
    ----------
    text_features : list or None
        The ``text_features`` value of a single row (presumably a list of
        keyword strings produced by ``preprocess_fn``). Null arrays arrive
        as ``None``.

    Returns
    -------
    list of int
        The encoded ids in input order. Returns an empty list when the input
        is empty or null.

    Raises
    ------
    ValueError
        Propagated from ``LabelEncoder.transform`` if a keyword was not seen
        when the encoder was fit.
    """
    # Treat a null array like an empty one instead of failing on len(None).
    if not text_features:
        return []
    # One vectorised transform per row rather than one sklearn call (with all
    # its input validation) per keyword; str() keeps the per-item conversion
    # the encoding has always applied.
    encoded = label_encoder.transform([str(item) for item in text_features])
    # .tolist() yields plain Python ints, as the former int() cast did; Spark's
    # IntegerType result conversion does not reliably accept numpy integers.
    return encoded.tolist()

In [17]:
# Drop any existing vector column (e.g. from an earlier run of the next cell)
# so it is rebuilt from the current encoder.
df_filtered = (
    df_filtered.drop("text_features_vect")
    if "text_features_vect" in df_filtered.columns
    else df_filtered
)

In [18]:
# Add the integer-encoded keyword ids alongside the keyword strings.
df_filtered = df_filtered.withColumn(
    "text_features_vect",
    text_to_vector(fn.col("text_features")),
)

In [19]:
# NOTE(review): this re-collects exactly what cell [10] collected. The
# `text_features` column has not changed since then, so it runs another full
# Spark job for the same values. Either it is redundant, or it was meant to
# inspect `text_features_vect`; confirm before relying on it.
all_text_feats = [row[0] for row in df_filtered.select("text_features").collect()]

In [20]:
# Keep only cards that produced at least one keyword. Truthiness skips both
# empty lists and null (None) arrays; the previous `len(items) > 0` would
# raise a TypeError on a null array.
filtered_text_feats = [items for items in all_text_feats if items]

In [21]:
# Flatten the per-card keyword lists into one list of keywords.
filtered_text_feats = [feat for items in filtered_text_feats for feat in items]

In [22]:
# Expose the feature table to Spark SQL under the name `cards_features`.
df_filtered.createOrReplaceTempView(name="cards_features")

In [26]:
# Read the whole `cards_features` temp view back through Spark SQL. As
# written (SELECT *), this returns every column and row of the view.
tbl = spark.sql("""
    SELECT
        *
    FROM
        cards_features
""")

# Save to Parquet

In [27]:
# Write the feature table to the configured temp location, replacing any previous output.
tbl.write.parquet(f"{config.TEMP}/M20_cards_text.parquet", mode="overwrite")

In [28]:
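## NOTE: writing directly to this relative path fails (reason not recorded),
## presumably why the table is written to config.TEMP above and copied below.
## The path also names THB_cards rather than M20 — confirm before re-enabling.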
# tbl.write.mode("overwrite").parquet('dataset/THB_cards.parquet')

In [29]:
import shutil
from pathlib import Path

# Copy the freshly written parquet directory into the artifacts dataset folder.
# - The source is taken from config.TEMP, where the write above put it, rather
#   than a hard-coded /tmp that only matches when config.TEMP == "/tmp".
# - The destination is replaced rather than merged into: `cp -R` into an
#   existing directory keeps part-files from earlier runs, and Spark would read
#   them back as extra rows.
# Assumes config.TEMP and config.ARTIFACTS are local filesystem paths, and that
# config.ARTIFACTS corresponds to the former "../artifacts" — confirm.
src_parquet = Path(config.TEMP) / "M20_cards_text.parquet"
dst_parquet = Path(config.ARTIFACTS) / "dataset" / "M20_cards_text.parquet"
if dst_parquet.exists():
    shutil.rmtree(dst_parquet)
shutil.copytree(src_parquet, dst_parquet)