# Ingest Data Warehouse

### Import


In [28]:
import glob
import polars as pl
import pyspark
import opendatasets as od
from pyspark.sql import SparkSession, DataFrame
from pyspark.ml import Pipeline

from xgboost.spark import SparkXGBClassifier

# from pyspark.ml.feature import BucketedRandomProjectionLSH
from pyspark.ml.feature import (
    StringIndexer,
    OneHotEncoder,
    VectorAssembler,
    Imputer,
)

from pyspark.sql.types import (
    BooleanType,
)


# from PyMovieDb import IMDB
from pyspark.sql import functions as F
import duckdb
import json

In [29]:
print(pyspark.__version__)
print(duckdb.__version__)

3.5.1
0.9.2


### Config


In [30]:
# duckdb_database = "../orchestration/db/bigdata.duckdb"

#### Setting up Cluster Connection


In [31]:
# Connect to Existing Spark Cluster
# spark = (
#     SparkSession.builder.master("spark://spark:7077")
#     .appName("Spark-ETL")
#     .config("spark.sql.debug.maxToStringFields", 1000)
#     .getOrCreate()
# )

# Connect to local Spark Sessions
spark = (
    SparkSession.builder.appName("YourAppName")
    # .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
    .config("spark.driver.memory", "4g")
    .getOrCreate()
)

ConnectionRefusedError: [Errno 111] Connection refused

# Add to Data Warehouse

## Initial Data

#### Train


In [None]:
# Get a list of all CSV files that match the pattern
csv_files = glob.glob("../../data/*.csv")

total_rows = 0
imdb_data = None

for file_path in csv_files:
    print("Reading file:", file_path)
    df = (
        spark.read.option("header", "true")
        .option("inferSchema", "true")
        .csv(file_path, nullValue="\\N")
    )
    # df.show(5)
    total_rows += df.count()

    if "label" not in df.columns:
        df = df.withColumn("label", F.lit(None))

    # Rename the column 'primaryTitle' to 'title'
    if "primaryTitle" in df.columns:
        df = df.withColumnRenamed("primaryTitle", "title")

    if imdb_data is None:
        imdb_data = df
    else:
        imdb_data = imdb_data.union(df)

if imdb_data is not None:
    imdb_data = imdb_data.drop("_c0")
    imdb_data.show(5)
    print(f"Counted Rows {total_rows} vs Final DF {imdb_data.count()}")

# Normalize the strings in the 'title' columns
imdb_data = imdb_data.withColumn(
    "title",
    F.lower(F.trim(F.regexp_replace(F.col("title"), "[^a-zA-Z0-9\\s]", ""))),
)
# Normalize the strings in the 'title' columns
imdb_data = imdb_data.withColumn(
    "originalTitle",
    F.lower(F.trim(F.regexp_replace(F.col("originalTitle"), "[^a-zA-Z0-9\\s]", ""))),
)

#### Validation & Test


In [None]:
# Validation Data
validation_hidden = (
    (
        spark.read.option("header", "true")
        .option("inferSchema", "true")
        .csv("../../data/validation_hidden.csv", nullValue="\\N")
    )
    .drop("_c0")
    .select(
        "tconst",
        "primaryTitle",
    )
)
# # Select only tconst column
# validation_hidden = validation_hidden
# Print the dataframe
# validation_hidden.show(5)


test_hidden = (
    (
        spark.read.option("header", "true")
        .option("inferSchema", "true")
        .csv("../../data/test_hidden.csv", nullValue="\\N")
    )
    .drop("_c0")
    .select(
        "tconst",
        "primaryTitle",
    )
)

# Select only tconst column
# test_hidden = test_hidden.select("tconst")
# Print the dataframe
# test_hidden.show(5)

#### Directing


In [None]:
# Using Polars to retrieve the directing data
# Load and parse the JSON file
with open("../../data/directing.json") as f:
    data = json.load(f)

movies_polars_df = pl.from_dict(data["movie"]).transpose().rename({"column_0": "movie"})
directors_polars_df = (
    pl.from_dict(data["director"]).transpose().rename({"column_0": "director"})
)
directing_polars_df = pl.concat(
    [
        movies_polars_df,
        directors_polars_df,
    ],
    how="horizontal",
)
# directing_polars_df.head(5)

#### Writing


In [None]:
with open("../../data/writing.json") as f:
    data = json.load(f)
writing_json = spark.sparkContext.parallelize(data)
writing_spark_df = spark.read.json(writing_json)
# writing_spark_df.show(5)

## Extra Data


In [None]:
imdb_data.show(5)
imdb_data.count()

### Kaggle Data


In [None]:
data_dir = "../../data"


# # 10000 Data About Movies (1915-2023)
# od.download(
#     "https://www.kaggle.com/datasets/willianoliveiragibin/10000-data-about-movies-1915-2023?select=data.csv",
#     data_dir=data_dir,
# )

# # Movie Industry
# od.download("https://www.kaggle.com/datasets/danielgrijalvas/movies", data_dir=data_dir)

# # TMDB 10000 Movie Dataset
# od.download(
#     "https://www.kaggle.com/datasets/muqarrishzaib/tmdb-10000-movies-dataset",
#     data_dir=data_dir,
# )


# Full TMDB Movies Dataset 2024
# od.download(
#     "https://www.kaggle.com/datasets/asaniczka/tmdb-movies-dataset-2023-930k-movies",
#     data_dir=data_dir,
# )

od.download(
    "https://www.kaggle.com/datasets/gayu14/tv-and-movie-metadata-with-genres-and-ratings-imbd",
    data_dir=data_dir,
)

# Golden Globes Data
od.download(
    "https://www.kaggle.com/datasets/unanimad/golden-globe-awards/data",
    data_dir=data_dir,
)

# # Oscar Award Data
od.download(
    "https://www.kaggle.com/datasets/unanimad/the-oscar-award",
    data_dir=data_dir,
)

# FilmTV Movies Dataset
od.download(
    "https://www.kaggle.com/datasets/stefanoleone992/filmtv-movies-dataset",
    data_dir=data_dir,
)

# Rotten Tomatoes Top Movies Ratings and Technical Data
od.download(
    "https://www.kaggle.com/datasets/andrezaza/clapper-massive-rotten-tomatoes-movies-and-reviews",
    data_dir=data_dir,
)

In [None]:
# Get the maximum value
max_value = imdb_data.agg(F.max("startYear")).collect()[0][0]
print("Max startYear:", max_value)

# Get the minimum value
min_value = imdb_data.agg(F.min("startYear")).collect()[0][0]
print("Min startYear:", min_value)

imdb_data = imdb_data.select(
    "tconst",
    "title",
    "originalTitle",
    "startYear",
    "runtimeMinutes",
    "numVotes",
    "label",
)

rename_dict = {}
for col_name in imdb_data.columns:
    rename_dict[col_name] = "imdb_" + col_name

imdb_data_rename = imdb_data.withColumnsRenamed(rename_dict)
imdb_data_rename.show(5)
print(imdb_data_rename.columns)

imdb_data_rename.filter(F.col("title") == "son rise").show()

### TMDB Movies


In [None]:
# tv-and-movie-metadata-with-genres-and-ratings-imbd
# tmdb_movies = (
#     spark.read.option("header", "true")
#     .option("inferSchema", "true")
#     .csv("../../data/tmdb-movies-dataset-2023-930k-movies/TMDB_movie_dataset_v11.csv")
#     .drop("_c0")
# )
# tmdb_movies.show(5)
# print(tmdb_movies.columns)

# result = imdb_data_rename.join(
#     tmdb_movies,
#     imdb_data_rename.imdb_title == tmdb_movies.title,
#     how="left",
# )
# result.show(5)
# result.count()


# result = imdb_data_rename.join(
#     tmdb_movies,
#     (
#         (imdb_data_rename.imdb_title == tmdb_movies.title)
#         | (imdb_data_rename.imdb_originalTitle == tmdb_movies.title)
#     )
#     & (
#         (
#             F.isnull(imdb_data_rename.imdb_startYear)
#             | ~imdb_data_rename(imdb_data_rename.imdb_startYear)
#         )
#         | (
#             (imdb_data_rename.imdb_startYear == tmdb_movies.releaseYearTheaters)
#             | (imdb_data_rename.imdb_startYear == tmdb_movies.releaseYearStreaming)
#         )
#     ),
#     "left",  # or "left", "right", "outer", depending on what kind of join you want
# )


# # Count the occurrences of each row
# row_counts = result.groupby(result.columns).count()

# # Filter the rows that have a count greater than 1
# duplicates = row_counts.filter(F.col("count") > 1)

# duplicates.show(5)

# # Drop duplicates based on the imdb_data_new columns
# result = result.dropDuplicates(subset=imdb_data_rename.columns)

# result.show(5)
# result.count()

### Metadata Ratings IMDB


In [None]:
# # tv-and-movie-metadata-with-genres-and-ratings-imbd
# movie_metadata_rating = (
#     spark.read.option("header", "true")
#     .option("inferSchema", "true")
#     .csv("../../data/tv-and-movie-metadata-with-genres-and-ratings-imbd/IMBD.csv")
#     .drop("_c0")
# )
# movie_metadata_rating.show(5)

# # Extract the numeric part of the runtime string and convert it to integer
# movie_metadata_rating = movie_metadata_rating.withColumn(
#     "runtime", F.regexp_extract(F.col("runtime"), r"(\d+)", 1).cast("integer")
# )
# print(f"Movie Metadata Rating Columsn: {movie_metadata_rating.columns}")

# # Join the dataframes using title and runtime, and also check if a movie name can be joined on imdb_originalTitle if imdb_title is no match
# result = imdb_data_rename.join(
#     movie_metadata_rating,
#     (
#         (imdb_data_rename.imdb_title == movie_metadata_rating.movie)
#         | (imdb_data_rename.imdb_originalTitle == movie_metadata_rating.movie)
#     )
#     & (imdb_data_rename.imdb_runtimeMinutes == movie_metadata_rating.runtime),
#     how="left",
# )
# result.show(5)
# print(f"TOTAL ROWS:{result.count()}")

# # # Drop duplicates based on the imdb_data_new columns
# result = result.dropDuplicates(subset=imdb_data_rename.columns)
# # print(f"Resulting columns: {result.columns}")
# # result.show(5)
# # print(f"TOTAL ROWS FILTERED DUPLICATES:{result.count()}")

# result = result.select(
#     *imdb_data_rename.columns, "genre", "rating", "votes", "stars", "director"
# )
# print(f"Resulting columns: {result.columns}")
# result.show(5)
# print(f"TOTAL ROWS FILTERED DUPLICATES:{result.count()}")

# # Perform a left anti join to get the rows that are not joined
# not_joined = imdb_data_rename.join(
#     movie_metadata_rating,
#     (
#         (imdb_data_rename.imdb_title == movie_metadata_rating.movie)
#         | (imdb_data_rename.imdb_originalTitle == movie_metadata_rating.movie)
#     )
#     & (imdb_data_rename.imdb_runtimeMinutes == movie_metadata_rating.runtime),
#     how="left_anti",
# )
# not_joined.show(5)
# print(not_joined.count())

# # Count the occurrences of each row
# row_counts = result.groupby(result.columns).count()

# # Filter the rows that have a count greater than 1
# duplicates = row_counts.filter(F.col("count") > 1)

# duplicates.show(5)

### FilmtTV Movies Dataset


In [None]:
filmt_tv = (
    spark.read.option("header", "true")
    .option("inferSchema", "true")
    .csv("../../data/filmtv-movies-dataset/filmtv_movies.csv")
    .drop("_c0")
)
# Normalize the strings in the 'title' columns
filmt_tv = filmt_tv.withColumn(
    "title", F.lower(F.trim(F.regexp_replace(F.col("title"), "[^a-zA-Z0-9\\s]", "")))
)

print(filmt_tv.count())
filmt_tv.show(5)


# Join the dataframes using title and also check if a movie name can be joined on imdb_originalTitle if imdb_title is no match
result = imdb_data_rename.join(
    filmt_tv,
    (
        (imdb_data_rename.imdb_title == filmt_tv.title)
        | (imdb_data_rename.imdb_originalTitle == filmt_tv.title)
    )
    & (imdb_data_rename.imdb_startYear == filmt_tv.year),
    how="left",
)
result.show(5)
print(result.count())

# Drop duplicates based on the imdb_data_new columns
result = result.dropDuplicates(subset=imdb_data_rename.columns).drop(
    "notes", "description", "filmtv_id", "year", "duration"
)
print(f"Resulting columns: {result.columns}")
result.show(5)
print(f"TOTAL ROWS FILTERED DUPLICATES:{result.count()}")

### Rotten Tomatoes


In [None]:
rotten_tomato_reviews_data = (
    spark.read.option("header", "true")
    .option("inferSchema", "true")
    .csv(
        "../../data/clapper-massive-rotten-tomatoes-movies-and-reviews/rotten_tomatoes_movies.csv"
    )
    .drop("_c0")
)

# Normalize the strings in the 'title' columns
rotten_tomato_reviews_data = rotten_tomato_reviews_data.withColumn(
    "title", F.lower(F.trim(F.regexp_replace(F.col("title"), "[^a-zA-Z0-9\\s]", "")))
)

rotten_tomato_reviews_data.show(5)

In [None]:
# Define a UDF to check if a value is a valid year
def is_valid_year(value):
    if value is None:
        return False
    try:
        year = int(value)
        return 1800 <= year <= 2100
    except ValueError:
        return False


is_valid_year_udf = F.udf(is_valid_year, BooleanType())

In [None]:
selection_rotten_tomato_reviews_data = rotten_tomato_reviews_data.select(
    "title",
    "writer",
    "director",
    "genre",
    "releaseDateTheaters",
    "releaseDateStreaming",
    "audienceScore",
    "tomatoMeter",
    "rating",
    "ratingContents",
)
# Extract the year from the release dates
selection_rotten_tomato_reviews_data = selection_rotten_tomato_reviews_data.withColumn(
    "releaseYearTheaters", F.year(F.col("releaseDateTheaters"))
)
selection_rotten_tomato_reviews_data = selection_rotten_tomato_reviews_data.withColumn(
    "releaseYearStreaming", F.year(F.col("releaseDateStreaming"))
)
# Add a prefix to the column names
for col_name in selection_rotten_tomato_reviews_data.columns:
    selection_rotten_tomato_reviews_data = (
        selection_rotten_tomato_reviews_data.withColumnRenamed(
            col_name, "rtm_" + col_name
        )
    )

# Handle null values in genre and directors fields
result = result.withColumn(
    "genre", F.when(F.col("genre").isNull(), "").otherwise(F.col("genre"))
)
result = result.withColumn(
    "directors", F.when(F.col("directors").isNull(), "").otherwise(F.col("directors"))
)

selection_rotten_tomato_reviews_data = selection_rotten_tomato_reviews_data.withColumn(
    "rtm_genre", F.when(F.col("rtm_genre").isNull(), "").otherwise(F.col("rtm_genre"))
)
selection_rotten_tomato_reviews_data = selection_rotten_tomato_reviews_data.withColumn(
    "rtm_director",
    F.when(F.col("rtm_director").isNull(), "").otherwise(F.col("rtm_director")),
)

# Split the genre and director strings into arrays
result = result.withColumn("genre", F.split(F.col("genre"), ","))
result = result.withColumn("directors", F.split(F.col("directors"), ","))

selection_rotten_tomato_reviews_data = selection_rotten_tomato_reviews_data.withColumn(
    "rtm_genre", F.split(F.col("rtm_genre"), ",")
)
selection_rotten_tomato_reviews_data = selection_rotten_tomato_reviews_data.withColumn(
    "rtm_director", F.split(F.col("rtm_director"), ",")
)

# Join the dataframes
result = result.join(
    selection_rotten_tomato_reviews_data,
    (
        (result.imdb_title == selection_rotten_tomato_reviews_data.rtm_title)
        | (result.imdb_originalTitle == selection_rotten_tomato_reviews_data.rtm_title)
    )
    & (
        (F.isnull(result.imdb_startYear) | ~is_valid_year_udf(result.imdb_startYear))
        | (
            (
                result.imdb_startYear
                == selection_rotten_tomato_reviews_data.rtm_releaseYearTheaters
            )
            | (
                result.imdb_startYear
                == selection_rotten_tomato_reviews_data.rtm_releaseYearStreaming
            )
        )
    ),
    "left",  # or "left", "right", "outer", depending on what kind of join you want
)

# Merge genre and director arrays without creating duplicates
result = result.withColumn(
    "genre",
    F.array_join(F.array_distinct(F.concat(F.col("genre"), F.col("rtm_genre"))), ","),
)

result = result.withColumn(
    "directors",
    F.array_join(
        F.array_distinct(F.concat(F.col("directors"), F.col("rtm_director"))), ","
    ),
)

# Remove leading commas from genre and directors
result = result.withColumn("genre", F.expr("substring(genre, 2)"))
result = result.withColumn("directors", F.expr("substring(directors, 2)"))

# Drop duplicates based on the imdb_data_new columns
result = result.dropDuplicates(subset=imdb_data_rename.columns).drop(
    "rtm_title",
    "rtm_director",
    "rtm_genre",
    "rtm_releaseDateTheaters",
    "rtm_releaseDateStreaming",
    "rtm_releaseYearTheaters",
    "rtm_releaseYearStreaming",
)
print(f"Resulting columns: {result.columns}")
result.show(5)
print(f"TOTAL ROWS FILTERED DUPLICATES:{result.count()}")

In [None]:
result.select("imdb_tconst", "imdb_title", "genre", "directors").where(
    (F.col("genre").isNotNull()) | (F.col("directors").isNotNull())
).show(5)

### Oscars and Golden Globes


In [None]:
# Oscars
oscars = (
    spark.read.option("header", "true")
    .option("inferSchema", "true")
    .csv("../../data/the-oscar-award/the_oscar_award.csv")
    .drop("_c0")
)
# oscars.show(5)
# Golden Globes
golden_globes = (
    spark.read.option("header", "true")
    .option("inferSchema", "true")
    .csv("../../data/golden-globe-awards/golden_globe_awards.csv")
    .drop("_c0")
)
# golden_globes.show(5)

# Normalize the strings in the 'title' columns
oscars = oscars.withColumn(
    "title", F.lower(F.trim(F.regexp_replace(F.col("film"), "[^a-zA-Z0-9\\s]", "")))
)

# Normalize the strings in the 'title' columns
golden_globes = golden_globes.withColumn(
    "title", F.lower(F.trim(F.regexp_replace(F.col("film"), "[^a-zA-Z0-9\\s]", "")))
)


# Count the number of Oscars won
oscars_won = (
    oscars.filter(F.col("winner") == True)
    .groupBy("title")
    .count()
    .withColumnRenamed("count", "oscars_won")
)

# Count the number of Golden Globes won
golden_globes_won = (
    golden_globes.filter(F.col("win") == True)
    .groupBy("title")
    .count()
    .withColumnRenamed("count", "golden_globes_won")
)

# Join the result DataFrame with the oscars_won and golden_globes_won DataFrames
result = result.join(oscars_won, result.imdb_title == oscars_won.title, "left").drop(
    "title"
)
result = result.join(
    golden_globes_won, result.imdb_title == golden_globes_won.title, "left"
).drop("title")

# Replace null values with 0
result = result.fillna(0, subset=["oscars_won", "golden_globes_won"])
print("RESULTING TABLE")
result.show(5)

# Filter for movies that won 1 or more Oscars or Golden Globes
awards_selection = result.filter(
    (F.col("oscars_won") >= 1) | (F.col("golden_globes_won") >= 1)
)
print("FILTER ON WON AWARDS")
awards_selection.show(5)

# Selecting


In [None]:
# List of columns to cast
columns_to_cast = [
    "avg_vote",
    "critics_vote",
    "public_vote",
    "total_votes",
    "humor",
    "rhythm",
    "effort",
    "tension",
    "erotism",
    "rtm_audienceScore",
    "rtm_tomatoMeter",
    "rtm_rating",
    "imdb_runtimeMinutes",
    "imdb_numVotes",
]

# Cast each column to float
for column in columns_to_cast:
    result = result.withColumn(column, F.col(column).cast("float"))

# Replace empty strings with NULL in all columns
for column in result.columns:
    result = result.withColumn(
        column, F.when(F.col(column) == "", None).otherwise(F.col(column))
    )

result.show(5)
print(result.columns)
print(result.dtypes)

# Splitting


In [None]:
validation_data = result.join(
    validation_hidden, result.imdb_tconst == validation_hidden.tconst, "inner"
).drop(*validation_hidden.columns)
print(validation_data.count())
validation_data.show(5)

test_data = result.join(
    test_hidden, result.imdb_tconst == test_hidden.tconst, "inner"
).drop(*test_hidden.columns)
test_data.show(5)
print(test_data.count())


# Create the training data by excluding the validation and test data
train_data = result.join(validation_data, ["imdb_tconst"], "left_anti")
train_data = train_data.join(test_data, ["imdb_tconst"], "left_anti")

print(train_data.count())
train_data.show(5)

print("Total Rows:", train_data.count() + validation_data.count() + test_data.count())

In [None]:
def preprocess_datasets(
    datasets: list[DataFrame], train_data: DataFrame
) -> list[DataFrame]:
    # Define the stages of the Pipeline
    stages = [
        StringIndexer(inputCol="actors", outputCol="actorsIndex", handleInvalid="keep"),
        OneHotEncoder(inputCol="actorsIndex", outputCol="actorsVec"),
        StringIndexer(inputCol="country", outputCol="countryIndex", handleInvalid="keep"),
        OneHotEncoder(inputCol="countryIndex", outputCol="countryVec"),
        StringIndexer(inputCol="genre", outputCol="genreIndex", handleInvalid="keep"),
        OneHotEncoder(inputCol="genreIndex", outputCol="genreVec"),
        StringIndexer(
            inputCol="directors", outputCol="directorsIndex", handleInvalid="keep"
        ),
        OneHotEncoder(inputCol="directorsIndex", outputCol="directorsVec"),
        StringIndexer(
            inputCol="rtm_writer", outputCol="writerIndex", handleInvalid="keep"
        ),
        OneHotEncoder(inputCol="writerIndex", outputCol="writerVec"),
        Imputer(
            inputCols=[
                "avg_vote",
                "critics_vote",
                "public_vote",
                "total_votes",
                "humor",
                "rhythm",
                "effort",
                "tension",
                "erotism",
                "rtm_audienceScore",
                "rtm_tomatoMeter",
                "imdb_runtimeMinutes",
                "imdb_numVotes",
            ],
            outputCols=[
                "avg_vote",
                "critics_vote",
                "public_vote",
                "total_votes",
                "humor",
                "rhythm",
                "effort",
                "tension",
                "erotism",
                "rtm_audienceScore",
                "rtm_tomatoMeter",
                "imdb_runtimeMinutes",
                "imdb_numVotes",
            ],
            strategy="median",
        ),
        VectorAssembler(
            inputCols=[
                "genreVec",
                "directorsVec",
                "writerVec",
                "countryVec",
                "actorsVec",
                "avg_vote",
                "critics_vote",
                "public_vote",
                "total_votes",
                "humor",
                "rhythm",
                "effort",
                "tension",
                "erotism",
                "oscars_won",
                "golden_globes_won",
                "rtm_audienceScore",
                "rtm_tomatoMeter",
                "imdb_runtimeMinutes",
                "imdb_numVotes",
            ],
            outputCol="features",
            handleInvalid="keep",
        ),
    ]

    # Define the Pipeline
    pipeline = Pipeline(stages=stages)

    # Fit the Pipeline to the training data
    model = pipeline.fit(train_data)

    # Transform all datasets using the fitted Pipeline
    for i in range(len(datasets)):
        datasets[i] = model.transform(datasets[i])

    # Transform the training data
    train_data = model.transform(train_data)

    return [train_data] + datasets


# Apply the preprocessing function to each dataset
train_data, validation_data, test_data = preprocess_datasets(
    [validation_data, test_data], train_data
)

In [None]:
# XGBOOST CLASSFIER
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Convert boolean to integer
train_data = train_data.withColumn("imdb_label", F.col("imdb_label").cast("integer"))


# Define the model
model = SparkXGBClassifier(
    features_col="features",
    label_col="imdb_label",
    num_workers=5,
    verbosity=1,
)

# Define the parameter grid
param_grid = (
    ParamGridBuilder()
    .addGrid(model.max_depth, [3, 4])
    .addGrid(model.learning_rate, [0.01, 0.1, 0.2])
    .addGrid(model.n_estimators, [100])
    .addGrid(model.colsample_bytree, [0.3, 0.7])
    .build()
)

# Define the evaluator
evaluator = BinaryClassificationEvaluator(labelCol="imdb_label")

# Define the cross-validation
cv = CrossValidator(
    estimator=model,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=2,
)

# Fit the model
cv_model = cv.fit(train_data)

In [None]:
predict_validation = model.transform(validation_data).select(
    "imdb_tconst",
    "imdb_title",
    "rawPrediction",
    "probability",
    F.when(F.col("prediction") == 1, "True").otherwise("False").alias("prediction"),
)

predict_validation.show(5)
predict_validation.select("prediction").coalesce(1).write.mode("overwrite").csv(
    "../../output/validate_predictions.csv",
    header=False,
)

In [32]:
predict_test = model.transform(test_data).select(
    "imdb_tconst",
    "imdb_title",
    "rawPrediction",
    "probability",
    F.when(F.col("prediction") == 1, "True").otherwise("False").alias("prediction"),
)
predict_test.show(5)

predict_test.select("prediction").coalesce(1).write.mode("overwrite").csv(
    "../../output/test_predictions.csv", header=False
)

AttributeError: 'SparkXGBClassifier' object has no attribute 'transform'