In [1]:
from pathlib import Path

import pandas as pd
import numpy as np
from numpy import nan, ndarray
from pandas import DataFrame
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import (
    FunctionTransformer,
    MultiLabelBinarizer,
    OneHotEncoder,
)


In [2]:
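# Data paths below are relative to the working directory (presumably the project root).
# Alternative base directory when running from a nested folder: DATA_DIR = Path.cwd().resolve().parents[1]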

In [3]:
# Raw movies snapshot; relative to the working directory (presumably the project root).
RAW_MOVIES_PATH = "data/01_raw/movies_dataset_2025-05-07.parquet"
df = pd.read_parquet(RAW_MOVIES_PATH)

In [4]:
# Columns consumed by the preprocessing pipeline; everything else is dropped.
use_features: list[str] = [
    "original_language",
    # Scalar numeric / flag columns.
    "popularity", "vote_average", "vote_count", "is_popular",
    "runtime", "budget", "revenue",
    # List-valued (multi-label) columns.
    "genres", "spoken_languages",
]

df = df.loc[:, use_features]

In [5]:
# `genres` and `spoken_languages` hold list-like values, which are unhashable,
# so duplicate rows are detected on the scalar columns only.
dedup_subset = [
    "original_language",
    "popularity",
    "vote_average",
    "vote_count",
    "is_popular",
    "runtime",
    "budget",
    "revenue",
]
df = df.drop_duplicates(subset=dedup_subset)

In [6]:
cat_cols = ["original_language"]
# Category dtype matters downstream: `map_lang` treats non-object columns as
# single-valued codes rather than lists of labels.
df = df.astype(dict.fromkeys(cat_cols, "category"))

In [7]:
# Downcast the popularity flag to a compact integer type.
# NOTE(review): if `is_popular` is the modelling target (or derived from
# `popularity`), keeping it among the features leaks label information — confirm.
df = df.astype({"is_popular": "int8"})

## Reducing high cardinality
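Let's reduce the high cardinality of the language-related features (`original_language` and `spoken_languages`) by grouping individual languages into a small set of broad language families / regions.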

In [8]:
# ruff: noqa: RUF001
# Group the native-language names that appear in `spoken_languages` into broad
# language families / regions to shrink the multi-hot width. Several keys contain
# characters that resemble ASCII letters, hence the RUF001 suppression above.
# `map_lang` falls back to "Unknown/Other" for names not listed here.
spoken_languages_mappings = {
    "Français": "European (Romance)",
    "Español": "European (Romance)",
    "English": "European (Germanic)",
    "Deutsch": "European (Germanic)",
    "हिन्दी": "South Asian",
    "广州话 / 廣州話": "East Asian",
    "日本語": "East Asian",
    "Italiano": "European (Romance)",
    # NOTE(review): the leading "P" below appears to be a Latin P, not Cyrillic "Р";
    # presumably this mirrors the raw data exactly — confirm before "fixing" it.
    "Pусский": "European (Slavic)",
    "Nederlands": "European (Germanic)",
    "isiZulu": "African",
    "ภาษาไทย": "Southeast Asian",
    "普通话": "East Asian",
    "Bahasa indonesia": "Southeast Asian",
    "": "Unknown/Other",
    "தமிழ்": "South Asian",
    "suomi": "European (Other)",
    "한국어/조선말": "East Asian",
    "български език": "European (Slavic)",
    "Català": "European (Romance)",
    "Türkçe": "Middle Eastern/Central Asian",
    "Português": "European (Romance)",
    "Norsk": "European (Germanic)",
    "Dansk": "European (Germanic)",
    "svenska": "European (Germanic)",
    "Lietuvių": "European (Other)",
    "Polski": "European (Slavic)",
    "తెలుగు": "South Asian",
    "עִבְרִית": "Middle Eastern/Central Asian",
    "Український": "European (Slavic)",
    "Latin": "European (Other)",  # Could also be considered Historical
    "?????": "Unknown/Other",
    "No Language": "Unknown/Other",
    "اردو": "South Asian",
    "العربية": "Middle Eastern/Central Asian",
    "Română": "European (Romance)",
    "Íslenska": "European (Germanic)",
    "Magyar": "European (Other)",
    "فارسی": "Middle Eastern/Central Asian",
    "Bahasa melayu": "Southeast Asian",
    "Galego": "European (Romance)",
    "ქართული": "European (Other)",  # Kartvelian is a unique family, grouped here for simplicity
    "euskera": "European (Other)",  # Language Isolate, grouped here
    "Èdè Yorùbá": "African",
    "Wolof": "African",
    "Gaeilge": "European (Other)",  # Celtic, grouped here
    "Hrvatski": "European (Slavic)",
    "ελληνικά": "European (Other)",  # Hellenic, grouped here
    "Slovenčina": "European (Slavic)",
    # NOTE(review): the first character below appears to be Greek "π" rather than
    # Gurmukhi "ਪ"; confirm it matches the raw data, otherwise Punjabi entries
    # fall through to the "Unknown/Other" fallback.
    "πੰਜਾਬੀ": "South Asian",
    "Český": "European (Slavic)",
    "Tiếng Việt": "Southeast Asian",
    "Fulfulde": "African",
    "қазақ": "Middle Eastern/Central Asian",
    "Esperanto": "Unknown/Other",  # Constructed language
    "Èʋegbe": "African",
    "বাংলা": "South Asian",
    "پښتو": "Middle Eastern/Central Asian",
    "shqip": "European (Other)",  # Albanian, grouped here
    "Srpski": "European (Slavic)",
    "Afrikaans": "European (Germanic)",
    "Kiswahili": "African",
    "Eesti": "European (Other)",  # Uralic, grouped here
    "Slovenščina": "European (Slavic)",
    "Bamanankan": "African",
    "Azərbaycan": "Middle Eastern/Central Asian",
    "Bosanski": "European (Slavic)",
    "සිංහල": "South Asian",
    "Latviešu": "European (Other)",  # Baltic, grouped here
    "Malti": "Middle Eastern/Central Asian",
    # Dict lookup of NaN only matches this exact `numpy.nan` object (NaN != NaN);
    # other NaN objects miss this key and get the caller's default instead.
    nan: "Unknown/Other",
}

In [9]:
# Broad language-family / region groups for the two-letter codes seen in
# `original_language` ("xx" presumably means "no language"; "cn" is grouped with
# Chinese as in the source data).
_ORIGINAL_LANGUAGE_GROUPS: dict[str, tuple[str, ...]] = {
    "European (Romance)": ("fr", "es", "ht", "it", "ca", "pt", "ro", "gl"),
    "European (Germanic)": ("en", "de", "nl", "no", "da", "sv", "is", "lb"),
    "European (Slavic)": (
        "bg", "pl", "uk", "ru", "sk", "cs", "hr", "mk", "sr", "bs", "sl",
    ),
    "European (Other)": (
        "fi", "hu", "lv", "ka", "ga", "el", "ab", "et", "kl", "lt", "se", "eu",
    ),
    "South Asian": ("te", "hi", "ta", "ml", "kn", "ur", "pa", "bn", "dz", "gu"),
    "East Asian": ("ja", "ko", "zh", "cn", "mn", "bo"),
    # NOTE(review): "mi" (Maori) is Polynesian; it is grouped as Southeast Asian here.
    "Southeast Asian": ("th", "id", "tl", "jv", "vi", "km", "ms", "mi", "su"),
    "Middle Eastern/Central Asian": ("tr", "ar", "fa", "kk", "tt", "he", "hy", "mt"),
    "African": ("zu", "yo", "xh", "ig", "ff"),
    "Unknown/Other": ("xx",),
}

# Invert the table into the flat code -> group lookup consumed by `map_lang`.
original_language_mappings = {
    code: group
    for group, codes in _ORIGINAL_LANGUAGE_GROUPS.items()
    for code in codes
}

In [10]:
# Peek at one cell: list-valued columns come back from parquet as numpy object arrays.
df.iloc[0]["genres"]

array(['Acción', 'Crimen', 'Suspense'], dtype=object)

In [11]:
def get_features_names(_, feature_names) -> ndarray:
    """Echo the input feature names back as output names for ``FunctionTransformer``.

    Used as the ``feature_names_out`` callable of the language-mapping steps,
    which rewrite values in place and keep the set of columns unchanged.

    Args:
        _: The calling transformer (unused).
        feature_names: The input feature names (any array-like of str).

    Returns:
        ndarray: ``feature_names`` as an object-dtype array, so the declared
        return type holds even when a plain list is passed in.
    """
    return np.asarray(feature_names, dtype=object)


def _map_label_list(labels, mappings: dict[str, str], default: str) -> list[str]:
    """Map every label of one multi-label cell to its group.

    A missing cell (None/NaN) becomes ``[default]``; a bare scalar label is
    treated as a one-element list.
    """
    if pd.api.types.is_list_like(labels):
        return [mappings.get(label, default) for label in labels]
    if pd.isna(labels):
        return [default]
    return [mappings.get(labels, default)]


def map_lang(
    X: DataFrame,
    col: str,
    mappings: dict[str, str],
    default: str = "Unknown/Other",
) -> DataFrame:
    """Map language values in ``X[col]`` to broader categories.

    Two column shapes are supported:

    * list-valued cells (e.g. ``spoken_languages``): every label in each list is
      mapped, unknown labels become ``default`` and a missing cell becomes
      ``[default]``;
    * single-valued cells (e.g. ``original_language``): values are mapped with
      ``Series.map``; unknown values become NaN, left for a downstream imputer.

    Args:
        X: Input frame; it is copied, never modified.
        col: Name of the column to map.
        mappings: Lookup from raw language value to group label.
        default: Group used for unknown or missing labels in list-valued cells.

    Returns:
        DataFrame: A copy of ``X`` with ``col`` replaced by the mapped values.
    """
    X = X.copy()
    values = X[col]
    # Decide by content, not dtype alone: plain string columns are also ``object``
    # dtype, and iterating a string would map it character by character.
    is_multi_label = values.dtype == object and values.map(pd.api.types.is_list_like).any()
    if is_multi_label:
        X[col] = values.apply(lambda labels: _map_label_list(labels, mappings, default))
    else:
        X[col] = values.map(mappings)
    return X

In [12]:
class MultiLabelBinarizerTransformer(BaseEstimator, TransformerMixin):
    """A custom transformer to apply MultiLabelBinarizer within a scikit-learn pipeline.

    This transformer is designed to be used with `ColumnTransformer` on a single
    column whose cells are lists (or arrays) of labels (multi-label data). It wraps
    `sklearn.preprocessing.MultiLabelBinarizer` and provides a
    `get_feature_names_out` method compatible with scikit-learn pipelines.

    Only the first column of the input is used; it may be given as a pandas
    DataFrame or as a 2-D array.
    """

    def __init__(self):
        """Initializes the MultiLabelBinarizerTransformer."""
        # The wrapped binarizer is created unfitted here; `fit` learns the label set.
        self.mlb = MultiLabelBinarizer()

    @staticmethod
    def _first_column(X):
        """Returns the first column of a 2-D DataFrame or array as a 1-D sequence."""
        # DataFrame slices (e.g. df[['genres']]) are indexed positionally via
        # .iloc; plain 2-D arrays by slicing.
        if hasattr(X, "iloc"):
            return X.iloc[:, 0]
        return X[:, 0]

    def fit(self, X, y=None):
        """Fits the MultiLabelBinarizer on the input data.

        Args:
            X: A DataFrame (or 2-D array) whose first column contains lists of labels.
            y: Ignored.

        Returns:
            self: Returns the instance itself.
        """
        self.mlb.fit(self._first_column(X))
        return self

    def transform(self, X):
        """Transforms the input data using the fitted MultiLabelBinarizer.

        Args:
            X: A DataFrame (or 2-D array) whose first column contains lists of labels.

        Returns:
            numpy.ndarray: A dense 0/1 indicator matrix of shape
            (n_samples, n_classes); the wrapped binarizer is built with its
            default ``sparse_output=False``.
        """
        return self.mlb.transform(self._first_column(X))

    def get_feature_names_out(self, input_features=None):
        """Gets the output feature names after binarization.

        Args:
            input_features: Ignored.

        Returns:
            numpy.ndarray: Object-dtype array of the learned labels, in the same
            order as the output columns (scikit-learn's convention for
            ``get_feature_names_out``).
        """
        return np.asarray(self.mlb.classes_, dtype=object)

In [13]:
# Column groups routed to the ColumnTransformer branches below.
# Numeric features: median-imputed.
num_cols: list[str] = [
    "popularity", "vote_average", "vote_count",
    "runtime", "budget", "revenue",
    # NOTE(review): binary flag; if it is the target (or derived from
    # `popularity`), keeping it here leaks label information — confirm.
    "is_popular",
]
# Single-valued categorical feature: mapped to a language group, then one-hot encoded.
cat_cols: list[str] = ["original_language"]
# List-valued features: expanded into one indicator column per label.
multi_label_cat_cols: list[str] = ["genres", "spoken_languages"]

In [14]:
num_pipe = Pipeline(steps=[("imputer", SimpleImputer(strategy="median"))])


def _language_mapper(col: str, mappings: dict) -> FunctionTransformer:
    """Build the FunctionTransformer step that maps one language column to its group."""
    return FunctionTransformer(
        map_lang,
        kw_args={"col": col, "mappings": mappings},
        feature_names_out=get_features_names,
    )


# original_language: group -> fill gaps (unmapped codes become NaN) -> one-hot.
cat_pipe = Pipeline(
    steps=[
        ("language mapper", _language_mapper("original_language", original_language_mappings)),
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("one-hot", OneHotEncoder(drop="first")),
    ]
)

# genres: multi-hot encode the raw labels.
multi_label_genres_pipe = Pipeline(steps=[("binarizer", MultiLabelBinarizerTransformer())])

# spoken_languages: group every listed language, then multi-hot encode the groups.
multi_label_spoken_languages_pipe = Pipeline(
    steps=[
        ("language mapper", _language_mapper("spoken_languages", spoken_languages_mappings)),
        ("binarizer", MultiLabelBinarizerTransformer()),
    ]
)

# Columns not listed here are dropped (ColumnTransformer's default remainder).
preprocessor = ColumnTransformer(
    transformers=[
        ("num", num_pipe, num_cols),
        ("cat", cat_pipe, cat_cols),
        ("genres", multi_label_genres_pipe, ["genres"]),
        ("spoken_languages", multi_label_spoken_languages_pipe, ["spoken_languages"]),
    ],
)

In [15]:
# Show the assembled (not yet fitted) preprocessing pipeline.
preprocessor

In [20]:
transformed = preprocessor.fit_transform(df)
# ColumnTransformer may return a scipy sparse matrix when the stacked output is
# sparse enough (see its `sparse_threshold`); densify so pandas can wrap it.
if hasattr(transformed, "toarray"):
    transformed = transformed.toarray()
feature_names = preprocessor.get_feature_names_out()
# Reuse df's index: drop_duplicates can leave gaps, so a fresh RangeIndex would
# silently misalign these features with the rows of `df`.
preprocessed = pd.DataFrame(transformed, columns=feature_names, index=df.index)

In [21]:
# Preview the preprocessed feature frame.
preprocessed

Unnamed: 0,num__popularity,num__vote_average,num__vote_count,num__runtime,num__budget,num__revenue,num__is_popular,cat__original_language_East Asian,cat__original_language_European (Germanic),cat__original_language_European (Other),...,spoken_languages__African,spoken_languages__East Asian,spoken_languages__European (Germanic),spoken_languages__European (Other),spoken_languages__European (Romance),spoken_languages__European (Slavic),spoken_languages__Middle Eastern/Central Asian,spoken_languages__South Asian,spoken_languages__Southeast Asian,spoken_languages__Unknown/Other
0,219.2462,8.100,10.0,112.0,0.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
1,98.7524,5.100,12.0,91.0,0.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
2,18.2723,7.100,10.0,101.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,9.9606,8.800,10.0,157.0,7000000.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,599.2458,6.735,213.0,109.0,0.0,0.0,1.0,0.0,1.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7995,0.2817,5.000,19.0,89.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7996,0.4137,6.700,23.0,84.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7997,0.2691,5.691,47.0,90.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
7998,0.6625,6.571,85.0,112.0,5800000.0,2440653.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0


In [22]:
# First five rows of the cleaned input, for comparison with the encoded features.
df.iloc[:5]

Unnamed: 0,original_language,popularity,vote_average,vote_count,is_popular,runtime,budget,revenue,genres,spoken_languages
0,fr,219.2462,8.1,10,1,112,0,0,"[Acción, Crimen, Suspense]",[Français]
1,es,98.7524,5.1,12,1,91,0,0,[Comedia],[Español]
2,en,18.2723,7.1,10,0,101,0,0,"[Drama, Crimen]",[English]
3,te,9.9606,8.8,10,0,157,7000000,0,"[Crimen, Suspense, Acción]",[]
4,de,599.2458,6.735,213,1,109,0,0,"[Suspense, Acción]","[Deutsch, English]"
