In [None]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

import os

In [None]:
# Configure the session builder one step at a time, then create (or reuse)
# the SparkSession that every later cell refers to as `spark`.
_session_builder = SparkSession.builder.appName('Prepare data')

# Point the driver host at the loopback address.
_session_builder = _session_builder.config("spark.driver.host", "127.0.0.1")

# Hive support is switched on before the session is obtained; later cells
# persist DataFrames with saveAsTable().
spark = _session_builder.enableHiveSupport().getOrCreate()

In [None]:
# Print the version string reported by the active Spark session.
print(spark.version)

In [None]:
# Path configuration: every location used by the cells below is derived from
# the project root, taken to be two directory levels above the current
# working directory.
base_path = os.getcwd()

# os.path.dirname is applied twice instead of splitting on a hard-coded '/':
# it follows the platform's path separator and never collapses to an empty
# string (which would silently turn every path below into a relative one).
project_path = os.path.dirname(os.path.dirname(base_path))

# Parquet inputs read by the cells below.
users_input_path = os.path.join(project_path, 'data/users')
questions_input_path = os.path.join(project_path, 'data/questions')
answers_input_path = os.path.join(project_path, 'data/answers')

# Storage locations for the tables 'users', 'questions' and 'answers'.
users_table_path = os.path.join(project_path, 'data/tables/users')
questions_table_path = os.path.join(project_path, 'data/tables/questions')
answers_table_path = os.path.join(project_path, 'data/tables/answers')


# Storage location for the bucketed table 'usersB' and output directory for
# the users data partitioned by 'last2_id'.
users_tableB_path = os.path.join(project_path, 'data/tables/usersB')
usersP_output_path =  os.path.join(project_path, 'data/usersP')

# Storage location for the table 'questionsA'.
questions_tableA_path = os.path.join(project_path, 'data/tables/questionsA')


# Output directories for the two users snapshots written near the end of
# the notebook (one with 'location' nulled out, one keeping only rows where
# it is set).
users_base_path = os.path.join(project_path, 'data/users_base')
users_increment_path = os.path.join(project_path, 'data/users_increment')

# Not referenced by any cell visible in this notebook section.
sample_data_with_metadata = os.path.join(project_path, 'data/sample_data_with_metadata')

In [None]:
# Write a copy of the users data partitioned on disk by 'last2_id'.
users_source = spark.read.parquet(users_input_path)

# Rows whose user_id equals 22 get user_id 2; all other ids pass through unchanged.
remapped_user_id = when(col('user_id') == 22, lit(2)).otherwise(col('user_id'))
users_remapped = users_source.withColumn('user_id', remapped_user_id)

# Partition key derived from the (already remapped) user_id via substr(-2, 100);
# judging by the column name this is meant to be the id's last two characters.
users_keyed = users_remapped.withColumn('last2_id', col('user_id').substr(-2, 100))

# Repartition on the key in memory, then lay the files out by the same key.
usersP_writer = (
    users_keyed
    .repartition('last2_id')
    .write
    .partitionBy('last2_id')
    .mode('overwrite')
    .option('path', usersP_output_path)
)
usersP_writer.save()

In [None]:
# Load the users parquet input and save it as the table 'users', stored at
# users_table_path; an existing table is overwritten.
users_raw = spark.read.parquet(users_input_path)
users_raw.write.saveAsTable('users', mode='overwrite', path=users_table_path)

In [None]:
# Load the questions parquet input and save it as the table 'questions',
# stored at questions_table_path; an existing table is overwritten.
questions_raw = spark.read.parquet(questions_input_path)
questions_raw.write.saveAsTable('questions', mode='overwrite', path=questions_table_path)

In [None]:
# Load the answers parquet input and save it as the table 'answers',
# stored at answers_table_path; an existing table is overwritten.
answers_raw = spark.read.parquet(answers_input_path)
answers_raw.write.saveAsTable('answers', mode='overwrite', path=answers_table_path)

In [None]:
# Bucketed copy of the 'users' table, saved as 'usersB'.
# The in-memory repartitioning and the on-disk bucketing use the same column
# and the same count, so both values are named once here.
users_bucket_count = 20
users_bucket_column = 'user_id'

users_tbl = spark.table('users')
usersB_writer = (
    users_tbl
    .repartition(users_bucket_count, users_bucket_column)
    .write
    .bucketBy(users_bucket_count, users_bucket_column)
)
usersB_writer.saveAsTable('usersB', mode='overwrite', path=users_tableB_path)

In [None]:
# Variant of the 'questions' table with user_id cast to int, saved as 'questionsA'.
questions_tbl = spark.table('questions')
questions_int_user_id = questions_tbl.withColumn('user_id', col('user_id').cast('int'))

# Spread the rows over 8 partitions before writing.
questionsA_writer = (
    questions_int_user_id
    .repartition(8)
    .write
    .mode('overwrite')
    .option('path', questions_tableA_path)
)
questionsA_writer.saveAsTable('questionsA')

In [None]:
# Display the rows of the users input whose user_id is 22.
users_with_id_22 = spark.read.parquet(users_input_path).where(col('user_id') == 22)
users_with_id_22.show()

In [None]:
# Users snapshot in which 'location' is replaced by a string-typed null for every row.
null_location = lit(None).cast('string')
users_without_location = (
    spark.read.parquet(users_input_path)
    .withColumn('location', null_location)
)

# Overwrite whatever already exists at users_base_path.
users_without_location.write.mode('overwrite').option('path', users_base_path).save()

In [None]:
# Users snapshot keeping only the rows that have a non-null 'location'.
has_location = col('location').isNotNull()
users_with_location = spark.read.parquet(users_input_path).filter(has_location)

# Overwrite whatever already exists at users_increment_path.
users_with_location.write.mode('overwrite').option('path', users_increment_path).save()

In [None]:
# Stop the Spark session created at the top of the notebook.
spark.stop()