In [1]:
import pyspark as ps    # for the pyspark suite
from pyspark.sql.types import StructType, StructField
from pyspark.sql.types import IntegerType, StringType, FloatType, DateType, TimestampType
import pyspark.sql.functions as F
import re
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

spark = ps.sql.SparkSession.builder \
            .master("local[4]") \
            .appName("df lecture") \
            .getOrCreate()

sc = spark.sparkContext  # for the pre-2.0 sparkContext

## 1. Reading ratings from `ratings.dat` (obsolete)

In [2]:
ratings_raw = sc.textFile("data/ratings.dat")

def ratings_cleaning_function(inputstring):
    """ to be used inside flatMap, returns a list """
    user_id, movie_id, rating, timestamp = inputstring.split("::")
    return((int(user_id), int(movie_id), int(rating), int(timestamp)))

ratings_clean = ratings_raw.map(ratings_cleaning_function)

ratings_schema = StructType( [
    StructField('user_id',IntegerType(),True),
    StructField('movie_id',IntegerType(),True),
    StructField('rating',IntegerType(),True),
    StructField('timestamp',IntegerType(),True) ] )

ratings = spark.createDataFrame(ratings_clean, ratings_schema)

ratings.repartition(20,F.col('user_id'))

ratings.show(5)
ratings.printSchema()

print(ratings.count())

+-------+--------+------+---------+
|user_id|movie_id|rating|timestamp|
+-------+--------+------+---------+
|      1|    1193|     5|978300760|
|      1|     661|     3|978302109|
|      1|     914|     3|978301968|
|      1|    3408|     4|978300275|
|      1|    2355|     5|978824291|
+-------+--------+------+---------+
only showing top 5 rows

root
 |-- user_id: integer (nullable = true)
 |-- movie_id: integer (nullable = true)
 |-- rating: integer (nullable = true)
 |-- timestamp: integer (nullable = true)

1000209


**Just a quick check, YES each rating is unique.**

## 2. Creating Training/Testing sets and request file

In [3]:
training = ratings.sort('timestamp', ascending=True).limit(800000)

testing = ratings.sort('timestamp', ascending=False).limit(200209)

In [4]:
#testing.sort(F.col('user_id'),F.col('movie_id')).show(20)

+-------+--------+------+---------+
|user_id|movie_id|rating|timestamp|
+-------+--------+------+---------+
|      1|       1|     5|978824268|
|      1|      48|     5|978824351|
|      1|     150|     5|978301777|
|      1|     260|     4|978300760|
|      1|     527|     5|978824195|
|      1|     531|     4|978302149|
|      1|     588|     4|978824268|
|      1|     594|     4|978302268|
|      1|     595|     5|978824268|
|      1|     608|     4|978301398|
|      1|     661|     3|978302109|
|      1|     720|     3|978300760|
|      1|     745|     3|978824268|
|      1|     783|     4|978824291|
|      1|     914|     3|978301968|
|      1|     919|     4|978301368|
|      1|     938|     4|978301752|
|      1|    1022|     5|978300055|
|      1|    1028|     5|978301777|
|      1|    1029|     5|978302205|
+-------+--------+------+---------+
only showing top 20 rows



Restrict testing to users we know about

## 3. Writing `training.csv` and `testing.csv`

In [7]:
with open('data/training.csv', 'w') as trainingfile:
    trainingfile.write("user,movie,rating,timestamp\n")
    for row in training.collect():
        trainingfile.write("{},{},{},{}\n".format(row['user_id'],
                                                   row['movie_id'],
                                                   row['rating'],
                                                   row['timestamp']))

with open('data/testing.csv', 'w') as testingfile:
    with open('data/requests.csv', 'w') as requestfile:
        testingfile.write("user,movie,actualrating\n")
        requestfile.write("user,movie\n")
        for row in testing.collect():
            testingfile.write("{},{},{}\n".format(row['user_id'],
                                                       row['movie_id'],
                                                       row['rating']))
            requestfile.write("{},{}\n".format(row['user_id'],
                                               row['movie_id']))