In [1]:
from pathlib import Path
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, split, explode
from pyspark.sql.types import StringType, IntegerType, BooleanType, FloatType, TimestampType

DATA_DIR = Path("data/ml-latest-small")

In [2]:
spark = SparkSession.builder\
    .master("local")\
    .appName("Word Count")\
    .getOrCreate()

In [3]:
# TODO: Spark schemas?
def read_df(file: str, types: dict):
    df = spark.read\
        .format("csv")\
        .option("header", "true")\
        .load(str(DATA_DIR / file))
    return set_dtypes(df, types)

def set_dtypes(df, types: dict):
    for column, dtype in types.items():
        dtype = {
            str: StringType(),
            int: IntegerType(),
            float: FloatType(),
            bool: BooleanType(),
            "time": TimestampType() 
        }[dtype]
        df = df.withColumn(column, col(column).cast(dtype))
    return df

df_links = read_df("links.csv", {
    "movieId": int,
    "imdbId": int,
    "tmdbId": int
})
df_movies = read_df("movies.csv", {
    "movieId": int,
    "title": str,
    "genres": str
})
df_ratings = read_df("ratings.csv", {
    "userId": int,
    "movieId": int,
    "rating": float,
    "timestamp": "time"
})
df_tags = read_df("tags.csv", {
    "userId": int,
    "movieId": int,
    "tag": str,
    "timestamp": "time"
})

In [7]:
# Given a user, get the number of movies watched per genre
def search_user(user_id: int):
    rated_movies = df_ratings.filter(df_ratings.userId == user_id).select("movieId").distinct()
    tagged_movies = df_tags.filter(df_tags.userId == user_id).select("movieId").distinct()
    movies = rated_movies.union(tagged_movies).distinct().join(df_movies, on=["movieId"], how="inner")
    movies = movies.select(movies.movieId, explode(split(movies.genres, "\|")).alias("genre"))
    movies = movies.groupBy("genre").count()
    return movies.collect()
search_user(1)

[Row(genre='Crime', count=45),
 Row(genre='Romance', count=26),
 Row(genre='Thriller', count=55),
 Row(genre='Adventure', count=85),
 Row(genre='Drama', count=68),
 Row(genre='War', count=22),
 Row(genre='Fantasy', count=47),
 Row(genre='Mystery', count=18),
 Row(genre='Musical', count=22),
 Row(genre='Animation', count=29),
 Row(genre='Film-Noir', count=1),
 Row(genre='Horror', count=17),
 Row(genre='Western', count=7),
 Row(genre='Comedy', count=83),
 Row(genre='Children', count=42),
 Row(genre='Action', count=90),
 Row(genre='Sci-Fi', count=40)]

In [24]:
df_ratings.filter("userId = 21").filter("movieId=1").collect()

[Row(userId=21, movieId=1, rating=3.5, timestamp=None)]

In [28]:
# Given a list of users, search all movies watched by each user
def search_movies_by_users(user_ids: [str]):
    rated_movies = df_ratings.filter(df_ratings.userId.isin(user_ids))
    tagged_movies = df_tags.filter(df_tags.userId.isin(user_ids))
    movies = rated_movies.join(tagged_movies, on=["userId", "movieId"], how="outer")
    movies = movies.select("userId", "movieId").distinct()
    movies = movies.groupBy("movieId").count().join(df_movies, on=["movieId"]).select("title", "count")
    return movies.collect()
search_movies_by_users([1, 21])

[Row(title='Toy Story (1995)', count=2)]

In [9]:
# Search movie by id/title, show the average rating, the number of users that have watched the movie
def search_movies_watched_by_id(movie_id: int):
    rated_movies = df_ratings.where(f"movieId = {movie_id}").select("movieId")
    tagged_movies = df_tags.where(f"movieId = {movie_id}").select("movieId")
    movies = rated_movies.union(tagged_movies).join(df_movies, on=["movieId"], how="inner")
    watched = movies.count()
    return watched
search_movies_watched_by_id(1) # Toy Story

218

In [44]:
# Search movie by id/title, show the average rating, the number of users that have watched the movie
def search_movies_avg_rating_by_id(movie_id: int):
    rated_movies = df_ratings.where(f"movieId = {movie_id}").select("rating")
    return rated_movies.collect()
search_movies_avg_rating_by_id(1) # Toy Story

[Row(rating=4.0),
 Row(rating=4.0),
 Row(rating=4.5),
 Row(rating=2.5),
 Row(rating=4.5),
 Row(rating=3.5),
 Row(rating=4.0),
 Row(rating=3.5),
 Row(rating=3.0),
 Row(rating=5.0),
 Row(rating=3.0),
 Row(rating=3.0),
 Row(rating=5.0),
 Row(rating=5.0),
 Row(rating=3.0),
 Row(rating=4.0),
 Row(rating=5.0),
 Row(rating=3.0),
 Row(rating=3.0),
 Row(rating=5.0),
 Row(rating=5.0),
 Row(rating=4.0),
 Row(rating=4.0),
 Row(rating=2.5),
 Row(rating=5.0),
 Row(rating=4.5),
 Row(rating=0.5),
 Row(rating=4.0),
 Row(rating=2.5),
 Row(rating=4.0),
 Row(rating=3.0),
 Row(rating=3.0),
 Row(rating=4.0),
 Row(rating=3.0),
 Row(rating=5.0),
 Row(rating=4.5),
 Row(rating=4.0),
 Row(rating=4.0),
 Row(rating=3.0),
 Row(rating=3.5),
 Row(rating=4.0),
 Row(rating=4.0),
 Row(rating=3.0),
 Row(rating=2.0),
 Row(rating=3.0),
 Row(rating=4.0),
 Row(rating=4.0),
 Row(rating=3.0),
 Row(rating=4.0),
 Row(rating=3.5),
 Row(rating=5.0),
 Row(rating=5.0),
 Row(rating=2.0),
 Row(rating=3.0),
 Row(rating=4.0),
 Row(ratin