In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, row_number, abs, count, lit, when
from pyspark.sql import Window, types

spark_session = SparkSession.builder.enableHiveSupport().master("local").getOrCreate()

data_path = '/data/sample264'
meta_path = '/data/meta'

user_id = 776748

steps = 5

alpha             = 0.15
beta_user_track   = 0.5
beta_user_artist  = 0.5
beta_artist_track = 1.0
beta_track_track  = 1.0

data = spark_session.read.parquet(data_path)
copy = spark_session.read.parquet(data_path)

meta  = spark_session.read.parquet(meta_path)

In [2]:
def norm(df, key1, key2, field, n): 
    
    window = Window.partitionBy(key1).orderBy(col(field).desc())
        
    tops_df = df.withColumn('row_number', row_number().over(window))\
        .filter(col('row_number') <= n)\
        .drop(col('row_number')) 
        
    tmp_df = tops_df.groupBy(col(key1)).agg(col(key1), sum(col(field)).alias('sum_' + field))
   
    normalized_df = tops_df.join(tmp_df, key1, 'inner')\
        .withColumn('norm_' + field, col(field) / col('sum_' + field))\
        .cache()

    return normalized_df

In [3]:
unnorm_tt_edges = data\
    .join(copy, copy.userId == data.userId, 'inner')\
    .where(abs(data.timestamp - copy.timestamp) <= 420)\
    .where(data.trackId != copy.trackId)\
    .select(data.trackId.alias('source'), copy.trackId.alias('dest'))\
    .groupBy(col('source'), col('dest'))\
    .count()
    
norm_tt_edges = norm(unnorm_tt_edges, 'source', 'dest', 'count', 1000)\
    .withColumn('nxt', col('norm_count') * beta_track_track)\
    .withColumn('source_type', lit('t'))\
    .withColumn('dest_type', lit('t'))\
    .select(col('source'), col('dest'), col('source_type'), col('dest_type'), col('nxt'))

In [4]:
unnorm_ut_edges = data\
    .select(col('userId').alias('source'), col('trackId').alias('dest'))\
    .groupBy(col('source'), col('dest'))\
    .count()
    
norm_ut_edges = norm(unnorm_ut_edges, 'source', 'dest', 'count', 1000)\
    .withColumn('nxt', col('norm_count') * beta_user_track)\
    .withColumn('source_type', lit('u'))\
    .withColumn('dest_type', lit('t'))\
    .select(col('source'), col('dest'), col('source_type'), col('dest_type'), col('nxt'))

In [5]:
unnorm_ua_edges = data\
    .select(col('userId').alias('source'), col('artistId').alias('dest'))\
    .groupBy(col('source'), col('dest'))\
    .count()
    
norm_ua_edges = norm(unnorm_ua_edges, 'source', 'dest', 'count', 100)\
    .withColumn('nxt', col('norm_count') * beta_user_artist)\
    .withColumn('source_type', lit('u'))\
    .withColumn('dest_type', lit('a'))\
    .select(col('source'), col('dest'), col('source_type'), col('dest_type'), col('nxt'))

In [6]:
unnorm_at_edges = data\
    .select(col('artistId').alias('source'), col('trackId').alias('dest'))\
    .groupBy(col('source'), col('dest'))\
    .count()
    
norm_at_edges = norm(unnorm_at_edges, 'source', 'dest', 'count', 100)\
    .withColumn('nxt', col('norm_count') * beta_artist_track)\
    .withColumn('source_type', lit('a'))\
    .withColumn('dest_type', lit('t'))\
    .select(col('source'), col('dest'), col('source_type'), col('dest_type'), col('nxt'))

In [7]:
norm_edges = norm_tt_edges.union(norm_ut_edges).union(norm_ua_edges).union(norm_at_edges).cache()

In [8]:
user_tracks = data\
    .where(col('userId') == user_id)\
    .groupBy(col('trackId'))\
    .agg(col('trackId'))\
    .withColumn('type', lit('t'))\
    .select(col('trackId').alias('id'), col('type'))
    
user_artists = data\
    .where(col('userId') == user_id)\
    .groupBy(col('artistId'))\
    .agg(col('artistId'))\
    .withColumn('type', lit('a'))\
    .select(col('artistId').alias('id'), col('type'))

In [9]:
user_x = data\
    .groupBy(col('userId'))\
    .agg(col('userId').alias('id'))\
    .withColumn('type', lit('u'))\
    .withColumn('v', when(col('id') == user_id, 1.0).otherwise(0.0))\
    .select(col('id'), col('type'), col('v'))

In [10]:
artist_x = data\
    .groupBy(col('artistId'))\
    .agg(col('artistId'))\
    .join(user_artists, data.artistId == user_artists.id, how = 'left')\
    .withColumn('v', when(user_artists.id.isNull(), 0.0).otherwise(1.0))\
    .select(data.artistId.alias('id'), col('type'), col('v'))

In [11]:
track_x = data\
    .groupBy(col('trackId'))\
    .agg(col('trackId'))\
    .join(user_tracks, col('trackId') == user_tracks.id, 'left')\
    .withColumn('v', when(user_tracks.id.isNull(), 0.0).otherwise(1.0))\
    .select(col('trackId').alias('id'), col('type'), col('v'))

In [12]:
x = user_x.union(artist_x).union(track_x).cache()

In [13]:
for _ in range(steps):
    s = x\
        .join(norm_edges, (x.id == norm_edges.source) & (x.type == norm_edges.source_type), 'inner')\
        .withColumn('nxt', x.v * col('nxt'))\
        .groupBy(col('dest').alias('s_id'), col('dest_type').alias('s_type'))\
        .agg(sum(col('nxt')).alias('sum_nxt'))\
        .withColumn('sum_nxt', (1 - alpha) * col('sum_nxt'))\
        .select(col('s_id'), col('s_type'), col('sum_nxt').alias('s_v')).cache()
    
    x = x\
        .join(s, (x.id == s.s_id) & (x.type == s.s_type), 'left')\
        .withColumn('id', when(s.s_id.isNull(), x.id).otherwise(s.s_id))\
        .withColumn('type', when(s.s_type.isNull(), x.type).otherwise(s.s_type))\
        .withColumn('v', when(s.s_v.isNull(), x.v).otherwise(s.s_v))\
        .withColumn('v', when((col('id') == user_id) & (col('type') == 'u'), alpha + col('v')).otherwise(col('v')))\
        .select(col('id'), col('type'), col('v')).cache()

In [14]:
result = x\
    .where(col('type') != 'u')\
    .orderBy(col('v').desc())\
    .limit(40)\
    .select(col('id'), col('type'), col('v'))\
    .join(meta, x.id == meta.Id, 'inner')\
    .select(col('Name'), col('Artist'), col('v'))

for name, artist, v in result.collect():
    print('%s\t%s\t%s' % (name, artist, v))

Kill The DJ	Artist: Green Day	1.72376154141
Come Out and Play	Artist: The Offspring	1.65353255809
I Hate Everything About You	Artist: Three Days Grace	1.51685667824
Prayer Of The Refugee	Artist: Rise Against	1.49739596049
21 Guns	Artist: Green Day	1.42402966101
Eagle	Artist: Gotthard	1.35873805033
Beautiful disaster	Artist: 311	1.02079963919
Wait And Bleed	Artist: Slipknot	1.02079963919
Here To Stay	Artist: Korn	1.01963284722
Hard Rock Hallelujah	Artist: Lordi	1.01963284722
Nothing Going On	Artist: Clawfinger	0.988808012274
In The End	Artist: Linkin Park	0.902170093838
Numb	Artist: Linkin Park	0.902170093838
Sky is Over	Artist: Serj Tankian	0.771020931713
Kryptonite	Artist: 3 Doors Down	0.771020931713
Take It Out On Me	Artist: Thousand Foot Krutch	0.539536150013
Girls and Boys	Artist: Blur	0.471747838381
Cocaine	Artist: Nomy	0.274487103787
Getting Away With Murder	Artist: Papa Roach	0.270584005378
Artist: Green Day	Artist: Green Day	0.125925925926
Artist: Linkin Park	Artist: Linkin Par