In [1]:
import json
from typing import Optional

import numpy as np
import torch
from pyspark.ml.linalg import MatrixUDT, DenseMatrix
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.functions import udf
from sentence_transformers import SentenceTransformer

In [2]:
# Reuse (or start) the Spark session for this notebook.
spark = (
    SparkSession.builder
    .appName("Sparkify")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


24/04/12 08:50:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Location and format of the per-user event dataset read below.
data_path, data_type = "data/fewest_events", "json"

In [4]:
# Load the raw events; the reader format comes from the config cell.
df = spark.read.load(data_path, format=data_type)

                                                                                

In [5]:
# Preview the loaded frame: each row holds a userId and its nested `data`
# event list (presumably one row per user -- confirm against the data prep).
df.show()

+--------------------+------+
|                data|userId|
+--------------------+------+
|[{Coldplay, Logge...|   135|
+--------------------+------+



In [6]:
# JSON file holding the page-name -> integer-index mapping loaded in the next cell.
unique_pages_data = "data/unique_pages.json"

In [7]:
# Load the page-name -> index mapping. JSON is UTF-8 by specification, so the
# encoding is passed explicitly instead of relying on the platform locale.
with open(unique_pages_data, 'r', encoding="utf-8") as f:
    unique_pages = json.load(f)

In [8]:
# Show the page-name -> index mapping (it appears unused by the remaining
# visible cells; TODO confirm whether it is still needed here).
unique_pages

{'Submit Downgrade': 0,
 'Thumbs Down': 1,
 'Home': 2,
 'Downgrade': 3,
 'Roll Advert': 4,
 'Logout': 5,
 'Save Settings': 6,
 'About': 7,
 'Settings': 8,
 'Login': 9,
 'Add to Playlist': 10,
 'Add Friend': 11,
 'NextSong': 12,
 'Thumbs Up': 13,
 'Help': 14,
 'Upgrade': 15,
 'Error': 16,
 'Submit Upgrade': 17,
 'Cancel': 18,
 'Cancellation Confirmation': 19,
 'Submit Registration': 20,
 'Register': 21}

In [9]:
# Sentence encoder used for both column names and cell values.
model = SentenceTransformer(model_name_or_path="all-MiniLM-L6-v2")

In [10]:
# Event fields that are embedded for every event. 'ts' is not listed here;
# compute_embeddings reads it only for its ordering check.
cols = """
    artist auth firstName gender itemInSession lastName length level location
    method page registration sessionId song status userAgent userId
""".split()

In [11]:
# Cache of column-name embeddings, keyed by column name.
column_embeddings = dict()

In [12]:
# Encode each column name once; compute_embeddings adds these to every value embedding.
column_embeddings.update((name, model.encode(name)) for name in cols)

In [13]:
# NOTE(review): `out` is not referenced by any visible cell -- TODO remove if unused.
out = list()

In [14]:
def compute_embeddings(data) -> Optional[DenseMatrix]:
    """Embed every (column, value) pair of a user's events into one matrix.

    For each event and each column name in ``cols``, the value is converted
    with ``str`` and encoded with ``model``. The result is added to the
    precomputed embedding of the column name (``column_embeddings``). Output
    rows are ordered event by event, then column by column in ``cols`` order.

    Args:
        data: Iterable of event rows (e.g. Spark ``Row`` objects) indexable by
            column name, each with a ``'ts'`` field. Events must appear in
            non-decreasing ``ts`` order.

    Returns:
        A ``DenseMatrix`` with ``len(data) * len(cols)`` rows and one column
        per embedding dimension. Returns ``None`` (a null cell in Spark) when
        ``data`` is null or empty.

    Raises:
        ValueError: if the events are not sorted by ``ts``.
    """
    # A null/empty event list would otherwise crash on iteration or in np.stack.
    if not data:
        return None

    embeddings = []
    prev_ts = float("-inf")
    for row in data:
        ts = row['ts']
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not ts >= prev_ts:
            raise ValueError(f"events are not sorted by ts: {ts} follows {prev_ts}")
        prev_ts = ts
        for c in cols:
            # Values are stringified as-is, so a null is embedded as the text "None".
            body_embedding = model.encode(str(row[c]))
            embeddings.append(column_embeddings[c] + body_embedding)

    matrix = np.stack(embeddings)
    # flatten() is row-major, so isTransposed=True preserves the (rows, dims) layout.
    return DenseMatrix(matrix.shape[0], matrix.shape[1], matrix.flatten().tolist(), isTransposed=True)

In [15]:
# Wrap compute_embeddings as a Spark UDF producing an ML matrix column.
compute_emb_udf = udf(f=compute_embeddings, returnType=MatrixUDT())

In [16]:
# Lazily attach the per-user embedding matrix computed from the `data` events.
df_emb = df.withColumn(
    "embeddings",
    compute_emb_udf(df["data"]),
)

In [17]:
# Triggers the UDF; the displayed `embeddings` matrix column is truncated.
df_emb.show()

[Stage 2:>                                                          (0 + 1) / 1]

+--------------------+------+--------------------+
|                data|userId|          embeddings|
+--------------------+------+--------------------+
|[{Coldplay, Logge...|   135|-0.14549641311168...|
+--------------------+------+--------------------+



                                                                                

In [18]:
# Pull the first row's embedding matrix back to the driver. NOTE(review): this
# and the `sample_data` cell use separate first() calls, which only line up
# because the frame has a single row here.
sample = df_emb.select("embeddings").first()["embeddings"]

[Stage 3:>                                                          (0 + 1) / 1]

                                                                                

In [19]:
# Raw event list of the first row, used to recompute embeddings locally.
sample_data = df.select("data").first()["data"]

In [20]:
# Convert the Spark DenseMatrix to a NumPy array for comparison below.
sample_ndarray = sample.toArray()

In [21]:
# Buffer for the locally recomputed embeddings compared against the UDF output.
test_emb = list()

In [22]:
# Inspect the raw event rows the sample embeddings were computed from.
sample_data

[Row(artist='Coldplay', auth='Logged In', firstName='Zoey', gender='F', itemInSession=0, lastName='Nelson', length=311.27465, level='free', location='Las Vegas-Henderson-Paradise, NV', method='PUT', page='NextSong', registration=1532433959000, sessionId=134, song='The Scientist', status=200, ts=1538661151000, userAgent='"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/37.0.2062.94 Safari/537.36"', userId='135'),
 Row(artist='Ratatat', auth='Logged In', firstName='Zoey', gender='F', itemInSession=1, lastName='Nelson', length=226.53342, level='free', location='Las Vegas-Henderson-Paradise, NV', method='PUT', page='NextSong', registration=1532433959000, sessionId=134, song='Loud Pipes', status=200, ts=1538661462000, userAgent='"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/37.0.2062.94 Safari/537.36"', userId='135'),
 Row(artist='Fergie', auth='Logged In', firstName='Zoey', gender='F', itemInSession=2, lastName='Nelson'

In [23]:
# Recompute the embeddings locally (outside Spark) using the same per-event,
# per-column recipe as compute_embeddings, to cross-check the UDF output.
# The rows go into a fresh list rather than being appended to the global
# `test_emb`, which this cell rebinds to an ndarray. That keeps the cell
# safe to re-run.
reference_rows = []
for event in sample_data:
    for c in cols:
        body_embedding = model.encode(str(event[c]))
        reference_rows.append(column_embeddings[c] + body_embedding)
test_emb = np.stack(reference_rows)

In [24]:
# np.allclose broadcasts its arguments, so a shape mismatch could still report
# True (or fail with an opaque error). Check the shapes explicitly first.
assert sample_ndarray.shape == test_emb.shape, (sample_ndarray.shape, test_emb.shape)
np.allclose(sample_ndarray, test_emb)

True