In [1]:
import json
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.ml.linalg import MatrixUDT, DenseMatrix, VectorUDT, DenseVector
from pyspark.sql.functions import col, expr, lit
from pyspark.sql.types import ArrayType

In [2]:
spark = SparkSession.builder.appName("Sparkify").getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/25 16:53:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
data_path = "data/fewest_events"
data_type = "json"

In [4]:
df_original = spark.read.json("data/mini_sparkify_event_data.json")

                                                                                

In [5]:
schema = ArrayType(df_original.schema)

In [6]:
df = spark.read.format(data_type).load(data_path)

In [7]:
df.show()

+--------------------+------+
|                data|userId|
+--------------------+------+
|[{Coldplay, Logge...|   135|
+--------------------+------+



In [8]:
from src.dataset.transformation import ungroup

In [9]:
ungrouped_df = ungroup(df, alias="data", original_schema=schema)

In [10]:
ungrouped_df.show()

+------+------------+---------+---------+------+-------------+--------+---------+-----+--------------------+------+--------+-------------+---------+--------------------+------+-------------+--------------------+
|userId|      artist|     auth|firstName|gender|itemInSession|lastName|   length|level|            location|method|    page| registration|sessionId|                song|status|           ts|           userAgent|
+------+------------+---------+---------+------+-------------+--------+---------+-----+--------------------+------+--------+-------------+---------+--------------------+------+-------------+--------------------+
|   135|    Coldplay|Logged In|     Zoey|     F|            0|  Nelson|311.27465| free|Las Vegas-Henders...|   PUT|NextSong|1532433959000|      134|       The Scientist|   200|1538661151000|"Mozilla/5.0 (Win...|
|   135|     Ratatat|Logged In|     Zoey|     F|            1|  Nelson|226.53342| free|Las Vegas-Henders...|   PUT|NextSong|1532433959000|      134|    

In [11]:
# Synthesize a second user '136' from user 135's events with the first event dropped.
# Spark's `slice` is 1-based: start at element 2 with length size(data), i.e. run to the end.
# Presumably this gives the pipeline a second, shorter sequence to batch against; TODO confirm.
# NOTE(review): this rebinds `df`, so re-running this cell alone appends another copy of user 136.
df135 = df.filter(df.userId == 135)
df135transformed = df135.withColumn("data", expr("slice(data, 2, size(data))")).withColumn("userId", lit('136'))
df = df.union(df135transformed)

In [12]:
df.show()

+--------------------+------+
|                data|userId|
+--------------------+------+
|[{Coldplay, Logge...|   135|
|[{Ratatat, Logged...|   136|
+--------------------+------+



In [13]:
unique_pages_data = "data/unique_pages.json"

In [14]:
with open(unique_pages_data, 'r') as f:
    unique_pages = json.load(f)

In [15]:
unique_pages

{'Submit Downgrade': 0,
 'Thumbs Down': 1,
 'Home': 2,
 'Downgrade': 3,
 'Roll Advert': 4,
 'Logout': 5,
 'Save Settings': 6,
 'About': 7,
 'Settings': 8,
 'Login': 9,
 'Add to Playlist': 10,
 'Add Friend': 11,
 'NextSong': 12,
 'Thumbs Up': 13,
 'Help': 14,
 'Upgrade': 15,
 'Error': 16,
 'Submit Upgrade': 17,
 'Cancel': 18,
 'Cancellation Confirmation': 19,
 'Submit Registration': 20,
 'Register': 21}

In [16]:
model_name = "all-MiniLM-L6-v2"

In [17]:
model = SentenceTransformer(model_name)

In [18]:
cols = ['artist',
 'auth',
 'firstName',
 'gender',
 'itemInSession',
 'lastName',
 'length',
 'level',
 'location',
 'method',
 'registration',
 'sessionId',
 'song',
 'status',
 'userAgent',
 'userId',
 'page']

In [19]:
target_col = 'page'

In [20]:
column_embeddings = {}

In [21]:
# Encode each column *name* once; compute_embeddings adds these vectors to every field-value embedding.
for column_name in cols:
    column_embeddings[column_name] = model.encode(column_name)

In [22]:
def compute_embeddings(data) -> DenseMatrix:
    """Embed every field of every event in ``data`` as one matrix row.

    Process each event ``row`` in ``data`` and each column name in the
    notebook-level ``cols``, in that order, as follows:

    - Stringify the field value (``None`` is encoded as the text ``"None"``).
    - Encode it with the global ``model``.
    - Add it to that column name's precomputed vector from ``column_embeddings``.

    Rebinding any of ``model``, ``cols`` or ``column_embeddings`` changes or
    breaks this function.

    Args:
        data: Iterable of event records indexable by column name, or ``None``
            for a null cell. Presumably these are the Spark ``Row`` objects of
            the per-user ``data`` array; verify against the UDF input.

    Returns:
        A ``DenseMatrix`` with one row per (event, column) pair and one column
        per embedding dimension, stored row-major (``isTransposed=True``).
        An empty ``data`` yields a 0-row matrix instead of failing inside
        ``np.stack``. A null ``data`` yields ``None``, which Spark stores as a
        null matrix.

    NOTE(review): this makes one ``model.encode`` call per field. Batching a
    row's strings into a single ``encode`` call would be much faster. However,
    it can perturb the floats slightly and could break the equality checks
    against ``Embeddings.process_df`` later in this notebook, so it is kept
    per-field.
    """
    if data is None:
        return None
    # Hoisted: the column-name embeddings do not depend on the row.
    name_embeddings = [column_embeddings[c] for c in cols]
    embeddings = [
        name_embedding + model.encode(str(row[c]))
        for row in data
        for c, name_embedding in zip(cols, name_embeddings)
    ]
    if not embeddings:
        # np.stack([]) raises ValueError; return an empty matrix of the right width instead.
        width = len(name_embeddings[0]) if name_embeddings else 0
        return DenseMatrix(0, width, [], isTransposed=True)
    stacked = np.stack(embeddings)
    return DenseMatrix(stacked.shape[0], stacked.shape[1], stacked.flatten().tolist(), isTransposed=True)

In [23]:
def compute_target_vector(data):
    """Return per-token class labels aligned with ``compute_embeddings``.

    Events and columns are walked in the same ``for row in data`` /
    ``for c in cols`` order as ``compute_embeddings``, emitting one label per
    (event, column) pair. A pair is labelled ``unique_pages[value]`` only when
    its value is a key of ``unique_pages`` and the column is ``target_col``.
    Every other pair is labelled -1.

    Returns:
        A ``DenseVector`` holding those labels.
    """
    def label(column, value):
        # Non-target columns and unknown pages both fall through to -1.
        if value in unique_pages and column == target_col:
            return unique_pages[value]
        return -1

    labels = [label(c, row[c]) for row in data for c in cols]
    return DenseVector(labels)

In [24]:
compute_emb_udf = udf(compute_embeddings, MatrixUDT())
compute_target_udf = udf(compute_target_vector, VectorUDT())

In [25]:
df_emb = df.withColumn("embeddings", compute_emb_udf(col("data")))

In [26]:
df_emb_target = df_emb.withColumn("target", compute_target_udf(col("data")))

In [27]:
df_emb_target.show()

                                                                                

+--------------------+------+--------------------+--------------------+
|                data|userId|          embeddings|              target|
+--------------------+------+--------------------+--------------------+
|[{Coldplay, Logge...|   135|-0.14549641311168...|[-1.0,-1.0,-1.0,-...|
|[{Ratatat, Logged...|   136|-0.12564811110496...|[-1.0,-1.0,-1.0,-...|
+--------------------+------+--------------------+--------------------+



                                                                                

In [28]:
# Replace the raw "data" column of df_emb_target; it is rebuilt below as a struct of (embeddings, target).
df_emb_target = df_emb_target.drop("data")

In [29]:
from pyspark.sql.functions import struct

In [30]:
df_emb_target = df_emb_target.withColumn("data", struct(df_emb_target.embeddings, df_emb_target.target))

In [31]:
df_emb_target.show()

                                                                                

+------+--------------------+--------------------+--------------------+
|userId|          embeddings|              target|                data|
+------+--------------------+--------------------+--------------------+
|   135|-0.14549641311168...|[-1.0,-1.0,-1.0,-...|{-0.1454964131116...|
|   136|-0.12564811110496...|[-1.0,-1.0,-1.0,-...|{-0.1256481111049...|
+------+--------------------+--------------------+--------------------+



In [32]:
from src.dataset.embeddings import Embeddings

In [33]:
embs = Embeddings(model_name)

In [34]:
df_result = embs.process_df(df, cols, target_col, unique_pages, "data", "embeddings", "target")

In [35]:
from pyspark.testing import assertDataFrameEqual
assertDataFrameEqual(df_result, df_emb_target)

                                                                                

In [36]:
df135 = df_emb_target.filter(df_emb_target.userId == 135)

In [37]:
df135_embeddings = df135.select("embeddings").collect()

                                                                                

In [38]:
truth = df135_embeddings[0].embeddings.toArray()

In [39]:
truth = torch.tensor(truth)

In [40]:
from datasets import Dataset, IterableDataset

In [41]:
ds = Dataset.from_spark(df_emb_target)

                                                                                

In [42]:
type(ds)

datasets.arrow_dataset.Dataset

In [43]:
# iterable_ds = IterableDataset.from_spark(df_emb_target)

In [44]:
# next(iter(iterable_ds))

In [45]:
from src.dataset.thin_wrapper import ThinWrapperDataset

In [46]:
wrap = ThinWrapperDataset(ds, max_seq_len=50)

In [47]:
len(ds)

2

In [48]:
ds_sample = ds[0]

In [49]:
type(ds_sample)

dict

In [50]:
emb, targs = ds_sample["embeddings"], ds_sample["target"]

In [51]:
emb

{'type': 1,
 'numRows': 102,
 'numCols': 384,
 'colPtrs': None,
 'rowIndices': None,
 'values': [-0.1454964131116867,
  0.004952546209096909,
  0.023328552022576332,
  0.009444467723369598,
  -0.014912711456418037,
  0.13468433916568756,
  0.2181049883365631,
  -0.08392201364040375,
  0.1363278329372406,
  -0.008958367630839348,
  -0.11979084461927414,
  -0.039625927805900574,
  0.040641024708747864,
  -0.07973599433898926,
  -0.03676091879606247,
  0.07051996886730194,
  0.005219798535108566,
  0.11572129279375076,
  0.012479012832045555,
  -0.05985037237405777,
  -0.13616269826889038,
  0.021026315167546272,
  -0.14290177822113037,
  0.02199457213282585,
  -0.0058483644388616085,
  0.11823181807994843,
  -0.03524499014019966,
  0.062315937131643295,
  0.01347874104976654,
  -0.06588500738143921,
  0.010380551218986511,
  0.06554959714412689,
  0.026107538491487503,
  -0.012775782495737076,
  0.0005982667207717896,
  -0.003033110871911049,
  -0.045946937054395676,
  0.0298049822449684

In [52]:
type(emb)

dict

In [53]:
type(emb["values"])

list

In [54]:
len(targs["values"])

102

In [55]:
type(targs["values"])

list

In [56]:
from torch.utils.data import DataLoader
from src.preprocess.collate_fn import mat_collate_fn

In [65]:
wrapdl = DataLoader(wrap, batch_size=1, collate_fn=mat_collate_fn)

In [66]:
sample = next(iter(wrapdl))

In [67]:
emb, pos_indices, targets, masks = sample

In [68]:
emb.shape

torch.Size([1, 50, 384])

In [62]:
pos_indices.shape

torch.Size([2, 50])

In [63]:
targets.shape

torch.Size([2, 50])

In [64]:
masks.shape

torch.Size([2, 50, 50])

In [None]:
def collate_fn(x):
    """Identity collate function: hand the DataLoader batch back unchanged.

    Lets the raw list of dataset records be inspected before any tensor conversion.
    """
    batch = x
    return batch

In [None]:
dl = DataLoader(ds, batch_size=2, collate_fn=collate_fn)

In [None]:
sampled = next(iter(dl))

In [None]:
sampled

[{'userId': '135',
  'embeddings': {'type': 1,
   'numRows': 102,
   'numCols': 384,
   'colPtrs': None,
   'rowIndices': None,
   'values': [-0.1454964131116867,
    0.004952546209096909,
    0.023328552022576332,
    0.009444467723369598,
    -0.014912711456418037,
    0.13468433916568756,
    0.2181049883365631,
    -0.08392201364040375,
    0.1363278329372406,
    -0.008958367630839348,
    -0.11979084461927414,
    -0.039625927805900574,
    0.040641024708747864,
    -0.07973599433898926,
    -0.03676091879606247,
    0.07051996886730194,
    0.005219798535108566,
    0.11572129279375076,
    0.012479012832045555,
    -0.05985037237405777,
    -0.13616269826889038,
    0.021026315167546272,
    -0.14290177822113037,
    0.02199457213282585,
    -0.0058483644388616085,
    0.11823181807994843,
    -0.03524499014019966,
    0.062315937131643295,
    0.01347874104976654,
    -0.06588500738143921,
    0.010380551218986511,
    0.06554959714412689,
    0.026107538491487503,
    -0.0127

In [None]:
type(sampled)

list

In [None]:
type(sampled[0])

dict

In [None]:
type(sampled[0]["embeddings"]["values"])

list

In [None]:
rows, cols = sampled[0]["embeddings"]["numRows"], sampled[0]["embeddings"]["numCols"]

In [None]:
actual = torch.tensor(sampled[0]["embeddings"]["values"], dtype=truth.dtype).view(rows, cols)

In [None]:
torch.allclose(actual, torch.tensor(truth))

  torch.allclose(actual, torch.tensor(truth))


True

In [None]:
type(sampled[0]["target"]["values"])

list

In [None]:
from src.preprocess.collate_fn import mat_collate_fn

In [None]:
collated_sample = mat_collate_fn(next(iter(dl)))

  return _nested.nested_tensor(


In [None]:
emb, pos_indices, targets, masks = collated_sample

In [None]:
emb.shape

torch.Size([2, 102, 384])

In [None]:
pos_indices.shape

torch.Size([2, 102])

In [None]:
targets.shape

torch.Size([2, 102])

In [None]:
masks.shape

torch.Size([2, 102, 102])