In [1]:
import json
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.ml.linalg import MatrixUDT, DenseMatrix, VectorUDT, DenseVector
from pyspark.sql.functions import col, expr, lit

In [2]:
# Create (or reuse) the Spark session for this notebook.
spark = (
    SparkSession.builder
    .appName("Sparkify")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/13 16:32:56 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
data_path = "data/fewest_events"
data_type = "json"

In [4]:
df = spark.read.format(data_type).load(data_path)

In [5]:
df.show()

+--------------------+------+
|                data|userId|
+--------------------+------+
|[{Coldplay, Logge...|   135|
+--------------------+------+



In [6]:
df135 = df.filter(df.userId == 135)
df135transformed = df135.withColumn("data", expr("slice(data, 2, size(data))")).withColumn("userId", lit('136'))
df = df.union(df135transformed)

In [7]:
df.show()

+--------------------+------+
|                data|userId|
+--------------------+------+
|[{Coldplay, Logge...|   135|
|[{Ratatat, Logged...|   136|
+--------------------+------+



In [8]:
unique_pages_data = "data/unique_pages.json"

In [9]:
with open(unique_pages_data, 'r') as f:
    unique_pages = json.load(f)

In [10]:
unique_pages

{'Submit Downgrade': 0,
 'Thumbs Down': 1,
 'Home': 2,
 'Downgrade': 3,
 'Roll Advert': 4,
 'Logout': 5,
 'Save Settings': 6,
 'About': 7,
 'Settings': 8,
 'Login': 9,
 'Add to Playlist': 10,
 'Add Friend': 11,
 'NextSong': 12,
 'Thumbs Up': 13,
 'Help': 14,
 'Upgrade': 15,
 'Error': 16,
 'Submit Upgrade': 17,
 'Cancel': 18,
 'Cancellation Confirmation': 19,
 'Submit Registration': 20,
 'Register': 21}

In [11]:
model = SentenceTransformer("all-MiniLM-L6-v2")

In [12]:
cols = ['artist',
 'auth',
 'firstName',
 'gender',
 'itemInSession',
 'lastName',
 'length',
 'level',
 'location',
 'method',
 'registration',
 'sessionId',
 'song',
 'status',
 'userAgent',
 'userId',
 'page']

In [13]:
target_col = 'page'

In [14]:
column_embeddings = {}

In [15]:
for c in cols:
    column_embeddings[c] = model.encode(c)

In [16]:
def compute_embeddings(data) -> DenseMatrix:
    """Embed every (event row, column) cell of one user's event list.

    For each event ``row`` in ``data`` and each column name ``c`` in the global
    ``cols``, the cell value is stringified and encoded with the global ``model``.
    The result is added to the precomputed column-name embedding
    ``column_embeddings[c]``.

    Args:
        data: iterable of event rows supporting ``row[c]`` for every name in
            ``cols`` (the Spark ``data`` array column), or ``None`` for a null.

    Returns:
        A ``DenseMatrix`` with ``len(data) * len(cols)`` rows, ordered as all
        columns of event 0, then event 1, and so on, with one column per
        embedding dimension. Returns ``None`` when ``data`` is ``None``.
    """
    if data is None:
        # Null array column -> null result instead of a TypeError inside the UDF.
        return None

    texts = []
    name_embeddings = []
    for row in data:
        for c in cols:
            texts.append(str(row[c]))
            name_embeddings.append(column_embeddings[c])

    if not texts:
        # np.stack([]) would raise; return an empty matrix with the right width.
        dim = model.get_sentence_embedding_dimension()
        return DenseMatrix(0, dim, [], isTransposed=True)

    # One batched encode call instead of one model invocation per cell.
    body_embeddings = model.encode(texts)
    embeddings = np.stack(name_embeddings) + body_embeddings
    # Values are row-major, hence isTransposed=True.
    return DenseMatrix(embeddings.shape[0], embeddings.shape[1], embeddings.flatten().tolist(), isTransposed=True)

In [17]:
def compute_target_vector(data):
    """Build the per-cell label vector aligned with the embedding rows.

    Emits one entry per (event row, column) pair, iterating ``data`` then the
    global ``cols`` (the same order ``compute_embeddings`` uses). The
    ``target_col`` cell gets its class id from ``unique_pages``, or -1 if the
    page name is unknown. Every other column gets -1.

    Args:
        data: iterable of event rows supporting ``row[c]`` lookups (the Spark
            ``data`` array column), or ``None`` for a null.

    Returns:
        A ``DenseVector`` of length ``len(data) * len(cols)``, or ``None`` when
        ``data`` is ``None`` (instead of raising inside the UDF).
    """
    if data is None:
        return None
    targets = []
    for row in data:
        for c in cols:
            # Check the column first: only the target cell is looked up in
            # unique_pages, so non-target values never need to be hashable.
            targets.append(unique_pages.get(row[c], -1) if c == target_col else -1)
    return DenseVector(targets)

In [18]:
compute_emb_udf = udf(compute_embeddings, MatrixUDT())
compute_target_udf = udf(compute_target_vector, VectorUDT())

In [19]:
df_emb = df.withColumn("embeddings", compute_emb_udf(col("data")))

In [20]:
df_emb_target = df_emb.withColumn("target", compute_target_udf(col("data")))

In [21]:
df_emb_target.show()

                                                                                

+--------------------+------+--------------------+--------------------+
|                data|userId|          embeddings|              target|
+--------------------+------+--------------------+--------------------+
|[{Coldplay, Logge...|   135|-0.14549641311168...|[-1.0,-1.0,-1.0,-...|
|[{Ratatat, Logged...|   136|-0.12564811110496...|[-1.0,-1.0,-1.0,-...|
+--------------------+------+--------------------+--------------------+



                                                                                

In [22]:
# replace the "data" in db_emb_target with tuple (embeddings, target)
df_emb_target = df_emb_target.drop("data")

In [23]:
from pyspark.sql.functions import struct

In [24]:
df_emb_target = df_emb_target.withColumn("data", struct(df_emb_target.embeddings, df_emb_target.target))

In [25]:
df_emb_target.show()

                                                                                

+------+--------------------+--------------------+--------------------+
|userId|          embeddings|              target|                data|
+------+--------------------+--------------------+--------------------+
|   135|-0.14549641311168...|[-1.0,-1.0,-1.0,-...|{-0.1454964131116...|
|   136|-0.12564811110496...|[-1.0,-1.0,-1.0,-...|{-0.1256481111049...|
+------+--------------------+--------------------+--------------------+



                                                                                

In [45]:
df135 = df_emb_target.filter(df_emb_target.userId == 135)

In [46]:
df135_embeddings = df135.select("embeddings").collect()

                                                                                

In [52]:
truth = df135_embeddings[0].embeddings.toArray()

In [59]:
truth = torch.tensor(truth)

In [26]:
# Hugging Face `datasets`; IterableDataset is only referenced in commented-out code.
from datasets import Dataset, IterableDataset

In [27]:
# Build a Hugging Face Dataset from the Spark DataFrame (columns userId, embeddings, target, data).
ds = Dataset.from_spark(df_emb_target)

                                                                                

In [28]:
# iterable_ds = IterableDataset.from_spark(df_emb_target)

In [29]:
# next(iter(iterable_ds))

In [30]:
ds_sample = ds[0]

In [31]:
emb, targs = ds_sample["embeddings"], ds_sample["target"]

In [32]:
from torch.utils.data import DataLoader

In [33]:
sample = None

In [34]:
def collate_fn(x):
    """Identity collate function for the DataLoader.

    Hands back the batch exactly as the DataLoader assembled it (a list of
    dataset records). This is presumably meant to bypass torch's default
    collation so each record stays a plain dict.

    Args:
        x: the batch produced by the DataLoader.

    Returns:
        The very same object that was passed in.
    """
    batch = x
    return batch

In [35]:
dl = DataLoader(ds, batch_size=2, collate_fn=collate_fn)

In [36]:
sampled = next(iter(dl))

In [63]:
# With the identity collate_fn, a batch is a plain Python list.
type(sampled)

list

In [64]:
# Each batch element is one dataset record.
type(sampled[0])

dict

In [41]:
# The matrix column arrives as a nested dict; its flat "values" field is a Python list.
type(sampled[0]["embeddings"]["values"])

list

In [53]:
rows, cols = sampled[0]["embeddings"]["numRows"], sampled[0]["embeddings"]["numCols"]

In [61]:
actual = torch.tensor(sampled[0]["embeddings"]["values"], dtype=truth.dtype).view(rows, cols)

In [62]:
torch.allclose(actual, torch.tensor(truth))

  torch.allclose(actual, torch.tensor(truth))


True

In [43]:
# The target vector also arrives as a dict; its "values" field is a Python list.
type(sampled[0]["target"]["values"])

list