In [1]:
import pyspark
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

In [2]:
sc = SparkSession.builder.appName('fear-eng').getOrCreate()

24/01/10 22:05:50 WARN Utils: Your hostname, hewens-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 192.168.0.108 instead (on interface en0)
24/01/10 22:05:50 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


24/01/10 22:05:51 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
anime_df = sc.read.csv('/Users/hewenpang/Documents/tf_explore/RecommendSystem_test/archive/anime.csv',header=True,inferSchema=True)
anime_df.printSchema()

root
 |-- anime_id: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- genre: string (nullable = true)
 |-- type: string (nullable = true)
 |-- episodes: string (nullable = true)
 |-- rating: double (nullable = true)
 |-- members: integer (nullable = true)



In [15]:
anime_df.select(['anime_id','genre']).show(10)

+--------+--------------------+
|anime_id|               genre|
+--------+--------------------+
|   32281|Drama, Romance, S...|
|    5114|Action, Adventure...|
|   28977|Action, Comedy, H...|
|    9253|    Sci-Fi, Thriller|
|    9969|Action, Comedy, H...|
|   32935|Comedy, Drama, Sc...|
|   11061|Action, Adventure...|
|     820|Drama, Military, ...|
|   15335|Action, Comedy, H...|
|   15417|Action, Comedy, H...|
+--------+--------------------+
only showing top 10 rows



Multi-hot encode

In [13]:
from pyspark.ml.feature import StringIndexer

genre_df = anime_df.withColumn('genre_item',explode(split(col('genre'),','))).withColumn('genre_item',trim(col('genre_item')))
genre_df.select(['anime_id','genre_item']).show(10)

+--------+------------+
|anime_id|  genre_item|
+--------+------------+
|   32281|       Drama|
|   32281|     Romance|
|   32281|      School|
|   32281|Supernatural|
|    5114|      Action|
|    5114|   Adventure|
|    5114|       Drama|
|    5114|     Fantasy|
|    5114|       Magic|
|    5114|    Military|
+--------+------------+
only showing top 10 rows



In [16]:
string_indexer = StringIndexer(inputCol='genre_item',outputCol='genre_index')
genre_indexed_df  = string_indexer.fit(genre_df).transform(genre_df).withColumn('genre_index',col('genre_index').cast('int'))
genre_indexed_df[['anime_id','genre_item','genre_index']].show(10)

+--------+------------+-----------+
|anime_id|  genre_item|genre_index|
+--------+------------+-----------+
|   32281|       Drama|          5|
|   32281|     Romance|          8|
|   32281|      School|          9|
|   32281|Supernatural|         12|
|    5114|      Action|          1|
|    5114|   Adventure|          2|
|    5114|       Drama|          5|
|    5114|     Fantasy|          3|
|    5114|       Magic|         16|
|    5114|    Military|         23|
+--------+------------+-----------+
only showing top 10 rows



In [17]:
pre_multigot_df = genre_indexed_df.groupBy('anime_id').agg(collect_list('genre_index').alias('genre_indexes'))
pre_multigot_df.show(10)

+--------+--------------------+
|anime_id|       genre_indexes|
+--------+--------------------+
|       1| [1, 2, 0, 5, 4, 25]|
|       5|   [1, 5, 21, 4, 25]|
|       6|           [1, 0, 4]|
|       7|[1, 5, 16, 21, 32...|
|       8|       [2, 3, 6, 12]|
|      15|       [1, 0, 6, 20]|
|      16|       [0, 5, 40, 8]|
|      17|      [0, 6, 10, 20]|
|      18|  [1, 37, 5, 19, 20]|
|      19|[5, 26, 21, 32, 3...|
+--------+--------------------+
only showing top 10 rows



In [18]:
max_genre_index = genre_indexed_df.agg(max(col('genre_index'))).head()['max(genre_index)']
max_genre_index

42

In [20]:
import numpy as np 

@udf(returnType='array<int>')
def multihot_list(l,max_index):
    fill = np.zeros(max_index+1,dtype=np.int32)
    for i in l:
        fill[i] = 1
    return fill.tolist()

In [22]:
multihot_df = pre_multigot_df.withColumn('genre_multihot',multihot_list(col('genre_indexes'),lit(max_genre_index)))
multihot_df.show(10)

+--------+--------------------+--------------------+
|anime_id|       genre_indexes|      genre_multihot|
+--------+--------------------+--------------------+
|       1| [1, 2, 0, 5, 4, 25]|[1, 1, 1, 0, 1, 1...|
|       5|   [1, 5, 21, 4, 25]|[0, 1, 0, 0, 1, 1...|
|       6|           [1, 0, 4]|[1, 1, 0, 0, 1, 0...|
|       7|[1, 5, 16, 21, 32...|[0, 1, 0, 0, 0, 1...|
|       8|       [2, 3, 6, 12]|[0, 0, 1, 1, 0, 0...|
|      15|       [1, 0, 6, 20]|[1, 1, 0, 0, 0, 0...|
|      16|       [0, 5, 40, 8]|[1, 0, 0, 0, 0, 1...|
|      17|      [0, 6, 10, 20]|[1, 0, 0, 0, 0, 0...|
|      18|  [1, 37, 5, 19, 20]|[0, 1, 0, 0, 0, 1...|
|      19|[5, 26, 21, 32, 3...|[0, 0, 0, 0, 0, 1...|
+--------+--------------------+--------------------+
only showing top 10 rows

