# Task I - naive recommendation system

* Suppose you receive a new question with some tags and you want to find a list of relevant users who are likely to answer a question with these tags.  
* Use the information about users that you saved in the ETL-III notebook.
* Create a function that takes the tags of the new question as input and returns a list of n relevant users.

In [20]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, explode, count, struct, collect_list, array_sort, reverse, array, lit, desc, broadcast, slice, sum
)

import os

In [2]:
# Obtain the Spark session for this notebook (reuses an existing one if present).
spark = SparkSession.builder.appName('Analytical app').getOrCreate()

In [3]:
# The project root is taken to be three directory levels above the current
# working directory; the input datasets live under its `output/` folder.
base_path = os.getcwd()
project_path = '/'.join(base_path.split('/')[:-3])

users_with_tag_output_path = os.path.join(project_path, 'output/users_with_tag')
questions_input_path = os.path.join(project_path, 'output/questions-transformed')

#answers_input_path = os.path.join(project_path, 'data/answers')

In [11]:
#local_tags = answersDF.select('tags').take(1)[0]['tags']

In [12]:
# Hand-written example of the tags a new question might carry.
local_tags = [
    'homework-and-exercises',
    'special-relativity',
    'field-theory',
    'lorentz-symmetry',
    'phase-space',
]

In [4]:
# Per-user tag statistics; cached because several cells below reuse them.
users_with_tags = spark.read.option('path', users_with_tag_output_path).load().cache()

# Transformed questions dataset (not cached).
questionsDF = spark.read.option('path', questions_input_path).load()

In [14]:


# Take a single question (id + tags) to serve as the sample "new question".
question = (
    questionsDF
    .select('question_id', 'tags')
    .limit(1)
)

In [24]:
# Rank every user against the sample question:
#   1. pair the question with every user (cross join),
#   2. keep only the user's tag entries whose tag appears in the question,
#   3. sum the frequencies of those entries -> question_relevancy,
#   4. collect (relevancy, user_id) structs per question and sort them
#      descending, then keep only the user ids.
# No DataFrame-level orderBy is needed: collect_list does not promise to keep
# row order, and the final order is fully determined by array_sort + reverse.
(
    question
    .crossJoin(users_with_tags)
    .withColumnRenamed('tags', 'question_tags')
    .selectExpr(
        'question_id',
        'user_id',
        "FILTER(tag_info, x -> array_contains(question_tags, x.tag)) AS matched_tags"
    )
    .withColumn('tag_frequencies', col('matched_tags.frequency'))
    .selectExpr(
        'question_id',
        'user_id',
        # AGGREGATE's lambda receives (accumulator, element), in that order.
        "AGGREGATE(tag_frequencies, CAST(0 AS long), (acc, freq) -> acc + freq) AS question_relevancy"
    )
    # Relevancy goes first in the struct so that it is the primary sort key.
    .withColumn('users', struct('question_relevancy', 'user_id'))
    .groupBy('question_id')
    .agg(collect_list('users').alias('users'))
    .withColumn('users', reverse(array_sort('users')))
    .withColumn('users', col('users.user_id'))
).show(truncate=80)

+-----------+--------------------------------------------------------------------------------+
|question_id|                                                                           users|
+-----------+--------------------------------------------------------------------------------+
|     167813|[1325, 26969, 2451, 104696, 9887, 392, 114696, 26076, 1492, 8563, 50583, 1236...|
+-----------+--------------------------------------------------------------------------------+



In [62]:
def get_relevant_users(users_with_tags, question, n_users):
    """Return the `n_users` most relevant users for each question.

    A user's relevancy for a question is the sum of the `frequency` values of
    the user's `tag_info` entries whose `tag` appears in the question's
    `tags`. Users with no matching tag get relevancy 0 and can therefore still
    appear in the result when fewer than `n_users` users match.

    Args:
        users_with_tags: DataFrame with a `user_id` column and a `tag_info`
            column holding an array of structs with `tag` and `frequency`
            fields. It is given a broadcast hint for the cross join.
        question: DataFrame with `question_id` and `tags` (array) columns;
            every row is ranked independently.
        n_users: Maximum number of user ids to keep per question.

    Returns:
        DataFrame with one row per `question_id` and a `users` column: an
        array of user ids ordered by descending relevancy (ties broken by
        descending `user_id`), truncated to `n_users` entries.
    """
    relevancy = (
        question
        .crossJoin(broadcast(users_with_tags))
        .withColumnRenamed('tags', 'question_tags')
        .selectExpr(
            'question_id',
            'user_id',
            "FILTER(tag_info, x -> array_contains(question_tags, x.tag)) AS matched_tags"
        )
        .withColumn('tag_frequencies', col('matched_tags.frequency'))
        .selectExpr(
            'question_id',
            'user_id',
            # AGGREGATE's lambda receives (accumulator, element), in that order.
            "AGGREGATE(tag_frequencies, CAST(0 AS long), (acc, freq) -> acc + freq) AS question_relevancy"
        )
    )

    # No DataFrame-level orderBy here: collect_list does not promise to keep
    # row order, so the ranking is produced by sorting the collected structs.
    # Relevancy goes first in the struct so that it is the primary sort key.
    return (
        relevancy
        .withColumn('users', struct('question_relevancy', 'user_id'))
        .groupBy('question_id')
        .agg(collect_list('users').alias('users'))
        .withColumn('users', reverse(array_sort('users')))
        .withColumn('users', col('users.user_id'))
        .withColumn('users', slice('users', 1, n_users))
    )

In [63]:
# Display the 4 most relevant users for each row of `question`.
get_relevant_users(users_with_tags, question, 4).show(truncate=70)

+-----------+----------------------------+
|question_id|                       users|
+-----------+----------------------------+
|      32954|[36793, 70392, 21146, 20564]|
|      65848| [21146, 36793, 70392, 4521]|
|     167813|[36793, 112190, 6764, 21146]|
|     121997|[36793, 70392, 20564, 26686]|
|     360185|[112190, 5841, 70392, 21146]|
+-----------+----------------------------+



In [15]:
# Inspect the sample question (id and tags) used for the ranking.
question.show()

+-----------+--------------------+
|question_id|                tags|
+-----------+--------------------+
|     167813|[homework-and-exe...|
+-----------+--------------------+



In [16]:
# Flatten the per-user tag array into one row per (user, tag) and expose the
# struct fields as plain columns; cached for reuse in the join below.
users_tag = (
    users_with_tags
    .withColumn('tag_info', explode(col('tag_info')))
    .withColumn('tag_freq', col('tag_info').getField('frequency'))
    .withColumn('tag', col('tag_info').getField('tag'))
    .cache()
)

In [17]:
# Peek at the first 5 flattened (user, tag) rows.
users_tag.show(n=5)

+-------+--------------------+--------+----------------+
|user_id|            tag_info|tag_freq|             tag|
+-------+--------------------+--------+----------------+
| 101552|        [1, voltage]|       1|         voltage|
| 101552|          [1, power]|       1|           power|
| 101552|    [1, electricity]|       1|     electricity|
| 101552|[1, electric-curr...|       1|electric-current|
|  16530|        [1, gravity]|       1|         gravity|
+-------+--------------------+--------+----------------+
only showing top 5 rows



In [26]:
# Join-based alternative: explode the question's tags, join them to the
# flattened (user, tag) rows and sum the matching frequencies per user.
# The top 5 are taken per question by sorting the collected
# (relevance, user_id) structs. A DataFrame-level orderBy + limit(5) would
# keep only 5 rows in total across all questions, and collect_list after it
# would not be guaranteed to preserve the sorted order.
# NOTE: `sum` here is pyspark.sql.functions.sum (imported at the top of the
# file), not the Python builtin.
(
    question
    .withColumn('tag', explode('tags'))
    .join(users_tag, 'tag')
    .groupBy('question_id', 'user_id')
    .agg(
        sum('tag_freq').alias('relevance')
    )
    .groupBy('question_id')
    .agg(
        # Relevance goes first in the struct so it is the primary sort key.
        collect_list(struct('relevance', 'user_id')).alias('users')
    )
    .withColumn('users', slice(reverse(array_sort('users')), 1, 5))
    .withColumn('users', col('users.user_id'))
).show(truncate=False)

+-----------+---------------------------------+
|question_id|users                            |
+-----------+---------------------------------+
|167813     |[1325, 26969, 2451, 104696, 9887]|
+-----------+---------------------------------+

