In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, row_number, abs, count
from pyspark.sql import Window, types

# Spark session on a "local" master with Hive support enabled;
# getOrCreate() hands back an existing session if one is already running.
spark_session = (
    SparkSession.builder
    .enableHiveSupport()
    .master("local")
    .getOrCreate()
)

# Location of the parquet dataset loaded by the analysis cell below.
graph_path = "/data/sample264"

In [None]:
def norm(df, key1, key2, field, n):
    """Keep the top ``n`` rows per ``key1`` and normalize ``field`` within each group.

    Rows are ranked inside every ``key1`` partition by ``field`` descending;
    rows that tie on ``field`` are ordered by ``key2`` so the cut at ``n`` is
    reproducible between runs. Each kept value of ``field`` is then divided
    by the sum of the kept values for the same ``key1``.

    Note: ``sum`` here is ``pyspark.sql.functions.sum`` (imported at the top
    of the notebook), not the Python builtin.

    Args:
        df: Spark DataFrame containing the columns ``key1``, ``key2`` and ``field``.
        key1: Name of the column that defines the groups.
        key2: Name of the column used to break ranking ties within a group.
        field: Name of the numeric column to rank by and normalize.
        n: Maximum number of rows to keep per ``key1`` value.

    Returns:
        A cached DataFrame with the kept rows of ``df`` plus two extra
        columns: ``'sum_' + field`` (per-group total over the kept rows) and
        ``'norm_' + field`` (``field`` divided by that total).
    """
    # key2 is a secondary sort key: without it, rows tied on `field` would be
    # numbered in arbitrary order and the top-n cut could differ between runs.
    ranking = Window.partitionBy(key1).orderBy(col(field).desc(), col(key2))

    tops_df = (
        df.withColumn('row_number', row_number().over(ranking))
        .filter(col('row_number') <= n)
        .drop(col('row_number'))
    )

    # groupBy() already emits key1, so only the aggregate is requested here;
    # re-selecting key1 inside agg() would give the totals frame a duplicate
    # key column right before it is joined on that key.
    totals_df = tops_df.groupBy(col(key1)).agg(sum(col(field)).alias('sum_' + field))

    normalized_df = (
        tops_df.join(totals_df, key1, 'inner')
        .withColumn('norm_' + field, col(field) / col('sum_' + field))
        .cache()
    )

    return normalized_df

In [None]:
# Tunables for this cell: per-user cap passed to norm() and number of pairs printed.
TOP_TRACKS_PER_USER = 1000
OUTPUT_ROWS = 40

graph = spark_session.read.parquet(graph_path)

# Number of rows in the dataset for every (userId, trackId) pair.
unnorm = graph.groupBy(col('userId'), col('trackId')).count()

# The result gets its own name: assigning it back to `norm` would replace the
# function with a DataFrame, so this cell could not be run a second time.
# Sorting happens before the projection, while `norm_count` is still a column
# of the frame being sorted.
top_pairs = (
    norm(unnorm, 'userId', 'trackId', 'count', TOP_TRACKS_PER_USER)
    .orderBy(col('norm_count').desc(), col('userId'), col('trackId'))
    .select(col('userId'), col('trackId'))
    .limit(OUTPUT_ROWS)
)

result = top_pairs.collect()

# One tab-separated "userId<TAB>trackId" line per collected row.
for user_id, track_id in result:
    print('%s\t%s' % (user_id, track_id))

66	965774
116	867268
128	852564
131	880170
195	946408
215	860111
235	897176
300	857973
321	915545
328	943482
333	818202
346	864911
356	961308
428	943572
431	902497
445	831381
488	841340
542	815388
617	946395
649	901672
658	937522
662	881433
698	935934
708	952432
746	879259
747	879259
776	946408
784	806468
806	866581
811	948017
837	799685
901	871513
923	879322
934	940714
957	945183
989	878364
999	967768
1006	962774
1049	849484
1057	920458
