In [43]:
from pathlib import Path

from pyspark.sql import SparkSession
from pyspark.sql.functions import coalesce, col, explode, split
from pyspark.sql.types import StringType, IntegerType, BooleanType, FloatType, TimestampType

# Directory holding the CSV files; file names passed to read_df are resolved
# relative to it.
# NOTE(review): presumably the MovieLens "ml-latest-small" dataset — confirm.
DATA_DIR = Path("data/ml-latest-small")

In [2]:
# Obtain the notebook-wide Spark session (reuses an existing one if present).
# NOTE(review): the app name "Word Count" looks like a template leftover — confirm.
spark = (
    SparkSession.builder
    .master("local")
    .appName("Word Count")
    .getOrCreate()
)

In [3]:
# TODO: Spark schemas?
def read_df(file: str, types: dict):
    """Load one CSV file from ``DATA_DIR`` and cast its columns.

    Args:
        file: File name relative to ``DATA_DIR``.
        types: Column name -> type marker, forwarded unchanged to
            ``set_dtypes``.

    Returns:
        The DataFrame returned by ``set_dtypes`` for the loaded CSV.
    """
    csv_path = str(DATA_DIR / file)
    # The first line of the file is treated as the header; no schema inference
    # is requested here, the casts are applied afterwards by set_dtypes.
    reader = spark.read.format("csv").option("header", "true")
    raw = reader.load(csv_path)
    return set_dtypes(raw, types)

def set_dtypes(df, types: dict):
    """Cast the listed columns of ``df`` to Spark SQL types.

    Args:
        df: Spark DataFrame whose columns are to be cast.
        types: Column name -> type marker. Supported markers are the Python
            types ``str``, ``int``, ``float``, ``bool`` and the string
            ``"time"`` (timestamp).

    Returns:
        A new DataFrame in which every listed column is replaced by its cast
        version; columns that are not listed are left as they are.

    Raises:
        KeyError: if a marker in ``types`` is not one of the supported ones.
    """
    # Built once per call rather than once per column.
    spark_types = {
        str: StringType(),
        int: IntegerType(),
        float: FloatType(),
        bool: BooleanType(),
        "time": TimestampType(),
    }
    for column, marker in types.items():
        if marker not in spark_types:
            raise KeyError(
                f"unsupported type marker {marker!r} for column {column!r}; "
                f"expected one of {list(spark_types)!r}"
            )
        source = col(column)
        if marker == "time":
            # A string holding epoch seconds (e.g. "964982703") is not a valid
            # timestamp literal, so a direct string -> timestamp cast yields
            # NULL. Keep the direct cast first so values that already parse
            # are unchanged, and fall back to string -> long -> timestamp,
            # which interprets the number as seconds since the epoch.
            # NOTE(review): relies on failed casts producing NULL, i.e.
            # spark.sql.ansi.enabled being off — confirm for this session.
            converted = coalesce(
                source.cast(TimestampType()),
                source.cast("long").cast(TimestampType()),
            )
        else:
            converted = source.cast(spark_types[marker])
        df = df.withColumn(column, converted)
    return df

# Load the four CSV tables; each dict maps a column name to the type marker
# that read_df forwards to set_dtypes ("time" selects TimestampType).
df_links = read_df("links.csv", types={"movieId": int, "imdbId": int, "tmdbId": int})
df_movies = read_df("movies.csv", types={"movieId": int, "title": str, "genres": str})
df_ratings = read_df(
    "ratings.csv",
    types={"userId": int, "movieId": int, "rating": float, "timestamp": "time"},
)
df_tags = read_df(
    "tags.csv",
    types={"userId": int, "movieId": int, "tag": str, "timestamp": "time"},
)

In [48]:
# One row per (movieId, genre): the genres column is a single "|"-separated
# string, so split it and explode the resulting array.
# The pattern is a raw string because "\|" is an invalid Python escape
# sequence (it only works by accident and triggers a warning); r"\|" is the
# same regex, matching a literal pipe.
df_genres = df_movies.select(
    df_movies.movieId,
    explode(split(df_movies.genres, r"\|")).alias("genre"),
)

In [61]:
# Search a user by id and show, per genre, the number of movies they have
# rated or tagged.
def search_user(user_id: int):
    """Count, per genre, the distinct movies a user has rated or tagged.

    Args:
        user_id: Value matched against the ``userId`` column of
            ``df_ratings`` and ``df_tags``.

    Returns:
        A list of Rows with fields ``genre`` and ``count``. A movie listed
        under several genres contributes to each of them.
    """
    # Compare against the column instead of interpolating user_id into a SQL
    # string, so the argument is always treated as a value, never as SQL.
    rated_movies = df_ratings.where(df_ratings.userId == user_id).select("movieId")
    tagged_movies = df_tags.where(df_tags.userId == user_id).select("movieId")
    # union() keeps duplicates; a single distinct() afterwards de-duplicates
    # both within and across the two sources.
    watched = rated_movies.union(tagged_movies).distinct()
    movies = watched.join(df_movies, on=["movieId"], how="inner")
    # Raw string: "\|" is an invalid Python escape sequence; r"\|" is the same
    # regex (a literal pipe) without the warning.
    genres = movies.select(
        movies.movieId,
        explode(split(movies.genres, r"\|")).alias("genre"),
    )
    return genres.groupBy("genre").count().collect()

search_user(1)

[Row(genre='Crime', count=45),
 Row(genre='Romance', count=26),
 Row(genre='Thriller', count=55),
 Row(genre='Adventure', count=85),
 Row(genre='Drama', count=68),
 Row(genre='War', count=22),
 Row(genre='Fantasy', count=47),
 Row(genre='Mystery', count=18),
 Row(genre='Musical', count=22),
 Row(genre='Animation', count=29),
 Row(genre='Film-Noir', count=1),
 Row(genre='Horror', count=17),
 Row(genre='Western', count=7),
 Row(genre='Comedy', count=83),
 Row(genre='Children', count=42),
 Row(genre='Action', count=90),
 Row(genre='Sci-Fi', count=40)]