# Load the set from JSON

In [1]:
import json
import numpy as np

In [2]:
from pyspark.ml.feature import StringIndexer, IndexToString

# Clean the data 

In [3]:
import pyspark.sql.functions as fn
import pyspark.sql.types as t

In [5]:
df = spark.read.json('../dataset/M20_cards.json')

## Filter out duplicate cards

In [6]:
pd_names = df.select(['number', 'name']).toPandas()
unique_names, indices, counts = np.unique(pd_names['name'], return_index=True, return_counts=True)

In [7]:
pd_unique_names = pd_names.loc[indices]

In [8]:
df_filter = spark.createDataFrame(pd_unique_names)

In [9]:
df_filtered = df_filter.join(df, on='number', how='left').drop(df_filter.name)

In [10]:
num_cards  = df_filtered.count()

In [11]:
print(f'Final number of cards {num_cards}')

Final number of cards 329


## Drop columns 

In [12]:
keep_cols = ['colorIdentity','convertedManaCost','colors','manaCost','name','number','text','power','rarity','subtypes','supertypes','toughness', 'types']

In [13]:
remove_cols = list(set(df.columns) - set(keep_cols))

In [14]:
df_filtered = df_filtered.drop(*remove_cols)

## Filter the text

In [15]:
rules = {
    "When {CARDNAME} enters the battlefield": "ETB_EFFECT",
    
    "Flash (You may cast this spell any time you could cast an instant.)": "FLASH",
    "Flash": "FLASH",
    
    "Reach (This creature can block creatures with flying.)": "REACH",
    "Reach": "REACH",
    
    "Flying (This creature can't be blocked except by creatures with flying or reach.)": "FLYING",
    "Flying": "FLYING",
    
    "Haste (This creature can attack and {T} as soon as it comes under your control.)": "HASTE",
    "Haste": "HASTE",
    
    "Trample (This creature can deal excess combat damage to the player or planeswalker it's attacking.)": "TRAMPLE",
    "Trample": "TRAMPLE",
    
    "Vigilance (Attacking doesn't cause this creature to tap.)": "VIGILANCE",
    "Vigilance": "VIGILANCE",

    "Double strike (This creature deals both first-strike and regular combat damage.)": "DOUBLE_STRIKE",
    "Double strike": "DOUBLE_STRIKE",

    "Deathtouch (Any amount of damage this deals to a creature is enough to destroy it.)": "DEATHTOUCH",
    "Deathtouch": "DEATHTOUCH",
    
    "Protection from green (This creature can't be blocked, targeted, dealt damage, enchanted, or equipped by anything green.)": "PROTECTION_FROM_GREEN",
    "Protection from red (This creature can't be blocked, targeted, dealt damage, enchanted, or equipped by anything red.)": "PROTECTION_FROM_RED",
    "Protection from black (This creature can't be blocked, targeted, dealt damage, enchanted, or equipped by anything black.)": "PROTECTION_FROM_BLACK",
    "Protection from blue (This creature can't be blocked, targeted, dealt damage, enchanted, or equipped by anything blue.)": "PROTECTION_FROM_BLUE",
    "Protection from white (This creature can't be blocked, targeted, dealt damage, enchanted, or equipped by anything white.)": "PROTECTION_FROM_WHITE",
    
    "(As this Saga enters and after your draw step, add a lore counter. Sacrifice after III.)": "SAGA_3",
    "(As this Saga enters and after your draw step, add a lore counter. Sacrifice after IV.)": "SAGA_4"
}

In [16]:
@fn.udf(returnType=t.ArrayType(t.StringType()))
def udf_filter_text(name, text):
    feats = list()
    if isinstance(text, str):
        new_text = text.replace(name, 'CARDNAME')
        for line in new_text.split('\n'):
            for rule, replace in rules.items():
                if line.startswith(rule):
                    line = line.replace(rule, replace)
                    feats.append(replace)
    return feats

In [17]:
if "text_features" in df_filtered.columns:
    df_filtered = df_filtered.drop("text_features")

df_filtered = df_filtered.withColumn('text_features', udf_filter_text('name', 'text'))

In [18]:
# df_filtered.printSchema()

In [19]:
all_text_feats = df_filtered.select("text_features").rdd.flatMap(lambda x: x).collect()

In [20]:
filtered_text_feats = [items for items in all_text_feats if len(items) > 0]

In [21]:
import itertools
filtered_text_feats = list(itertools.chain.from_iterable(filtered_text_feats))

In [22]:
import pandas as pd
import sklearn
from sklearn import preprocessing

In [23]:
lenc = preprocessing.LabelEncoder().fit(filtered_text_feats)

In [24]:
@fn.udf(returnType=t.ArrayType(t.IntegerType()))
def text_to_vector(text_features):
    if len(text_features) > 0:
        enc_list = list()
        for item in text_features:
            item = str(item)
            encoded = lenc.transform([item])
            encoded = int(encoded[0])
            enc_list.append(encoded)
            
            print(f"{item} \t {encoded}")
        return enc_list
    return list()

In [25]:
if "text_features_vect" in df_filtered.columns:
    df_filtered = df_filtered.drop("text_features_vect")

df_filtered = df_filtered.withColumn("text_features_vect", text_to_vector("text_features"))

In [26]:
df_filtered.select("text_features_vect").distinct().show()

+------------------+
|text_features_vect|
+------------------+
|            [2, 3]|
|              [12]|
|                []|
|               [1]|
|               [6]|
|               [3]|
|               [5]|
|               [9]|
|               [4]|
|               [7]|
|              [10]|
|           [3, 12]|
|              [11]|
|            [3, 8]|
|            [3, 4]|
|               [2]|
|               [0]|
+------------------+



## Explode the selected arrays in a string, separated by "," 

In [27]:
def explode_to_strs(df, cols):
    for col in cols:
        df_edited = df.selectExpr(["number", col]).select('number', fn.expr(f"concat_ws(',', {col})").alias(f"str_{col}"))
        df = df.join(df_edited, on='number')
    return df

In [28]:
df_filtered = explode_to_strs(df_filtered, ["colorIdentity", "types", "subtypes", "supertypes"])

## Encode newly created strings

In [29]:
from pyspark.ml.feature import StringIndexer, IndexToString

In [30]:
def encode_strings(df, cols):
    for col in cols:
        indexer = StringIndexer(inputCol=f"{col}", outputCol=f"encoded_{col}", stringOrderType='alphabetAsc')
        model = indexer.fit(df)
        df = model.transform(df)
        
#         indexer.save(f"/tmp/pyspark/stringindexer_{col}")
#         model.save(f"/tmp/pyspark/stringindexer_model_{col}")
    return df

In [31]:
df_filtered = encode_strings(df_filtered, ["rarity", "str_colorIdentity", "str_types", "str_subtypes", "str_supertypes"])

In [32]:
df_filtered.select(["types", "str_types", "encoded_str_types"]).distinct().show()

+--------------------+-----------------+-----------------+
|               types|        str_types|encoded_str_types|
+--------------------+-----------------+-----------------+
|          [Artifact]|         Artifact|              0.0|
|      [Planeswalker]|     Planeswalker|              6.0|
|[Artifact, Creature]|Artifact,Creature|              1.0|
|              [Land]|             Land|              5.0|
|           [Instant]|          Instant|              4.0|
|           [Sorcery]|          Sorcery|              7.0|
|          [Creature]|         Creature|              2.0|
|       [Enchantment]|      Enchantment|              3.0|
+--------------------+-----------------+-----------------+



## Count the number of colors

In [33]:
df_filtered = df_filtered.withColumn("num_colors", fn.size("colors"))

# Create an SQL table 

In [34]:
df_filtered.createOrReplaceTempView("cards_features")

In [35]:
tbl = spark.sql("""
    SELECT
        CAST(number as Integer),
        rarity,
        text_features_vect,
        CAST(convertedManaCost as Integer),
        CAST(num_colors as Integer) as numColors,
        str_colorIdentity as colorIdentity,
        CAST(encoded_str_colorIdentity as Integer) as encodedColorIdentity,
        str_types as types,
        CAST(encoded_str_types as Integer) as encodedTypes,
        str_subtypes as subTypes,
        CAST(encoded_str_subtypes as Integer) as encodedSubTypes,
        str_supertypes as superTypes,
        CAST(encoded_str_supertypes as Integer) as encodedSuperTypes,
        CAST(power as Integer),
        CAST(toughness as Integer)

    FROM
        cards_features
""")

In [36]:
spark.sql("""
    SELECT
        text_features_vect
    FROM
        cards_features
""").show()

+------------------+
|text_features_vect|
+------------------+
|                []|
|                []|
|                []|
|                []|
|                []|
|                []|
|                []|
|                []|
|                []|
|                []|
|                []|
|                []|
|                []|
|               [3]|
|                []|
|               [3]|
|                []|
|               [3]|
|              [11]|
|               [0]|
+------------------+
only showing top 20 rows



In [37]:
tbl.first()

Row(number=296, rarity='uncommon', text_features_vect=[], convertedManaCost=5, numColors=1, colorIdentity='R', encodedColorIdentity=16, types='Creature', encodedTypes=2, subTypes='Elemental', encodedSubTypes=25, superTypes='', encodedSuperTypes=0, power=5, toughness=4)

# Save to Parquet

In [38]:
tbl.write.mode("overwrite").parquet('/tmp/M20_cards_features.parquet')

In [39]:
rm -rf "../dataset/M20_cards_features.parquet/"

In [40]:
mv "/tmp/M20_cards_features.parquet" "../dataset/"