In [None]:
# USe this if running inside Databricks
#data = sc.textFile("dbfs:/FileStore/tables/friend_user_ids.txt")
# data = sc.textFile("./friend_user_ids.txt")
from pyspark.sql import SparkSession


spark = SparkSession.builder.appName("MoviePlotSearchEngine").getOrCreate()
spark.sparkContext
data = spark.sparkContext.textFile("./friend_user_ids.txt")


In [10]:
def extract_user_friends(line):
    parts = line.split("\t")
    if len(parts) == 2:
        user_id = parts[0].strip()
        friends_list = parts[1].split(",")
        if user_id.isdigit():
            user_id = int(user_id)
            friends = [int(friend.strip()) for friend in friends_list if friend.strip().isdigit()]
            return user_id, friends
    return None

In [11]:
parsed_users = data.map(extract_user_friends).filter(lambda x: x is not None)
print(parsed_users.take(5)) 

[Stage 2:>                                                          (0 + 1) / 1]

[(0, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94]), (1, [0, 5, 20, 135, 2409, 8715, 8932, 10623, 12347, 12846, 13840, 13845, 14005, 20075, 21556, 22939, 23520, 28193, 29724, 29791, 29826, 30691, 31232, 31435, 32317, 32489, 34394, 35589, 35605, 35606, 35613, 35633, 35648, 35678, 38737, 43447, 44846, 44887, 49226, 49985, 623, 629, 4999, 6156, 13912, 14248, 15190, 17636, 19217, 20074, 27536, 29481, 29726, 29767, 30257, 33060, 34250, 34280, 34392, 34406, 34418, 34420, 34439, 34450, 34651, 45054, 49592]), (2, [0, 117, 135, 1220, 2755, 12453, 24539, 24714, 41456, 45046, 49927, 6893, 13795, 16659, 32828, 41878]), (3, [0, 12, 41, 55, 1532, 12636, 13185, 27552, 38737]), (4, [0,

                                                                                

In [12]:
import itertools

def map_friends_to_connections(user_with_friends):
    user_id, friends = user_with_friends
    connections = []

    for friend_id in friends:
        connection = (min(user_id, friend_id), max(user_id, friend_id))
        connections.append((connection, 0))

    for friend_pair in itertools.combinations(friends, 2):
        mutual_connection = (min(friend_pair[0], friend_pair[1]), max(friend_pair[0], friend_pair[1]))
        connections.append((mutual_connection, 1))

    return connections

friend_connections = parsed_users.flatMap(map_friends_to_connections).cache()

mutual_friend_counts = friend_connections.reduceByKey(lambda a, b: a + b).filter(lambda edge: edge[1] > 0)

In [13]:
def generate_recommendations(mutual_friend_info):
    connection, mutual_count = mutual_friend_info
    user_1, user_2 = connection
    return [(user_1, (user_2, mutual_count)), (user_2, (user_1, mutual_count))]

import heapq

def sort_and_truncate_recommendations(recs):
    return heapq.nlargest(10, recs, key=lambda x: (x[1], -x[0]))

# def sort_and_truncate_recommendations(recommendations):
#     if len(recommendations) > 1024:
#         recommendations = sorted(recommendations, key=lambda x: (-x[1], x[0]))[:10]  
#     else:
#         recommendations.sort(key=lambda x: (-x[1], x[0]))  
#     return list(map(lambda x: x[0], recommendations))[:10]  

recommendations = mutual_friend_counts.flatMap(generate_recommendations) \
                                      .groupByKey() \
                                      .map(lambda user_recs: (user_recs[0], sort_and_truncate_recommendations(list(user_recs[1]))))
import random
all_user_ids = recommendations.map(lambda rec: rec[0]).distinct().cache().collect()

                                                                                

In [14]:
random_subset_user_ids = random.sample(all_user_ids, 10)
filtered_recommendations = recommendations.filter(lambda rec: rec[0] in random_subset_user_ids)
formatted_output = filtered_recommendations.map(lambda rec: f"{rec[0]}\t{','.join([str(r[0]) for r in rec[1]])}")

In [15]:
for line in formatted_output.collect():
    print(line)

[Stage 9:>                                                          (0 + 2) / 2]

2692	2637,1688,2647,2659,2689,2691,2666,2667,2668,2677
36488	26735,26736,36486,36483,36487,36489,36495,36497,8956,11903
49230	49187,49200,49201,49206,49209,49211,49213,49217,49222,49228
12270	12248,12263,12278,12262,12265,12271,12279,12282,12290,12292
22711	22705,22706,22707,22713,22718,22708,14596,22712,22714,22715
10359	20997,7029,10332,10362,10435,11162,11250,12250,16624,20966
23661	23660,13998,44891,13966,14044,14078,14130,14138,14188,14055
10773	10794,10800,10784,10810,10811,4828,10788,10813,10805,10807
3775	3772,3761,3766,3767,3770,3773,3762,3765,3768,3771
6891	1918,4755,6318,6883,6884,6885,6886,6887,6888,6889


                                                                                