In [19]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [20]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from xgboost.sklearn import XGBRegressor
from sklearn.metrics import mean_squared_error

In [32]:
data = pd.read_csv("data/cleaned_data.csv")
data
data['gender'].value_counts()

2.0    135203
0.0     67609
1.0     62878
3.0        33
Name: gender, dtype: int64

In [22]:
def parse_features(cleaned_data_top_3_subset_df):
    cleaned_data_top_3_subset_df_copy = cleaned_data_top_3_subset_df.copy()
    genere_by_position = {0: 'Adventure',
     1: 'Animation',
     2: 'Children',
     3: 'Comedy',
     4: 'Fantasy',
     5: 'Romance',
     6: 'Drama',
     7: 'Action',
     8: 'Crime',
     9: 'Thriller',
     10: 'Horror',
     11: 'Mystery',
     12: 'Sci-Fi',
     13: 'War',
     14: 'Musical',
     15: 'Documentary',
     16: 'IMAX',
     17: 'Western',
     18: 'Film-Noir'
    }
    position_by_genere = {'Adventure': 0,
 'Animation': 1,
 'Children': 2,
 'Comedy': 3,
 'Fantasy': 4,
 'Romance': 5,
 'Drama': 6,
 'Action': 7,
 'Crime': 8,
 'Thriller': 9,
 'Horror': 10,
 'Mystery': 11,
 'Sci-Fi': 12,
 'War': 13,
 'Musical': 14,
 'Documentary': 15,
 'IMAX': 16,
 'Western': 17,
 'Film-Noir': 18}
    def arr_to_cols(row):
        genres = set(row['genres'].split('|'))
        for pos in range(19):
            col_name = genere_by_position[pos]
            row[col_name] = 1 if col_name in genres else 0
        return row
    return cleaned_data_top_3_subset_df_copy.apply(arr_to_cols, axis=1)

In [23]:
data1 = data.groupby('movieId').head(1)
data1 = data1[data1['year'] >= 1931]
data1 = data1.head(1000)
data1 = parse_features(data1)
data1

Unnamed: 0.1,Unnamed: 0,gender,name,movieId,title,genres,genres_vectorized,genres_list,year,budget,...,Thriller,Horror,Mystery,Sci-Fi,War,Musical,Documentary,IMAX,Western,Film-Noir
0,0,2.0,Tom Hanks,1.0,Toy Story,Adventure|Animation|Children|Comedy|Fantasy,[1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...,"['Adventure', 'Animation', 'Children', 'Comedy...",1995.0,30000000,...,0,0,0,0,0,0,0,0,0,0
41,41,2.0,Robin Williams,2.0,Jumanji,Adventure|Children|Fantasy,[1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...,"['Adventure', 'Children', 'Fantasy']",1995.0,65000000,...,0,0,0,0,0,0,0,0,0,0
73,73,2.0,Walter Matthau,3.0,Grumpier Old Men,Comedy|Romance,[0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...,"['Comedy', 'Romance']",1995.0,0,...,0,0,0,0,0,0,0,0,0,0
81,81,1.0,Whitney Houston,4.0,Waiting to Exhale,Comedy|Drama|Romance,[0. 0. 0. 1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. ...,"['Comedy', 'Drama', 'Romance']",1995.0,16000000,...,0,0,0,0,0,0,0,0,0,0
110,110,2.0,Steve Martin,5.0,Father of the Bride Part II,Comedy,[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...,['Comedy'],1995.0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
28866,28866,2.0,Kevin Costner,1302.0,Field of Dreams,Children|Drama|Fantasy,[0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. ...,"['Children', 'Drama', 'Fantasy']",1989.0,0,...,0,0,0,0,0,0,0,0,0,0
28897,28897,2.0,Sean Connery,1303.0,"Man Who Would Be King, The",Adventure|Drama,[1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. ...,"['Adventure', 'Drama']",1975.0,8000000,...,0,0,0,0,0,0,0,0,0,0
28909,28909,2.0,Paul Newman,1304.0,Butch Cassidy and the Sundance Kid,Action|Western,[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. ...,"['Action', 'Western']",1969.0,6000000,...,0,0,0,0,0,0,0,0,1,0
28931,28931,2.0,Harry Dean Stanton,1305.0,"Paris, Texas",Drama|Romance,[0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. ...,"['Drama', 'Romance']",1984.0,1746964,...,0,0,0,0,0,0,0,0,0,0


In [31]:
data1['gender'].value_counts()

2.0    727
1.0    257
0.0     16
Name: gender, dtype: int64

In [24]:
data1.columns

Index(['Unnamed: 0', 'gender', 'name', 'movieId', 'title', 'genres',
       'genres_vectorized', 'genres_list', 'year', 'budget', 'runtime',
       'rating', 'Adventure', 'Animation', 'Children', 'Comedy', 'Fantasy',
       'Romance', 'Drama', 'Action', 'Crime', 'Thriller', 'Horror', 'Mystery',
       'Sci-Fi', 'War', 'Musical', 'Documentary', 'IMAX', 'Western',
       'Film-Noir'],
      dtype='object')

### Split Data

In [25]:
X = data1[['gender', 'year', 'budget', 'runtime', 'Adventure', 'Animation', 'Children', 'Comedy', 
                 'Fantasy', 'Romance', 'Drama', 'Action', 'Crime', 'Thriller', 'Horror', 'Mystery', 'Sci-Fi',
                 'War', 'Musical', 'Documentary', 'IMAX', 'Western', 'Film-Noir']]
y = data1['rating']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.25)

### Train

In [26]:
xgb_model = XGBRegressor(n_estimators=1000, max_depth=10, learning_rate=0.001, random_state=0)
xgb_model.fit(X_train, y_train)

XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
             colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
             importance_type='gain', interaction_constraints='',
             learning_rate=0.001, max_delta_step=0, max_depth=10,
             min_child_weight=1, missing=nan, monotone_constraints='()',
             n_estimators=1000, n_jobs=0, num_parallel_tree=1, random_state=0,
             reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,
             tree_method='exact', validate_parameters=1, verbosity=None)

In [29]:
y_predict = xgb_model.predict(X_test)
y_predict

array([2.2582538, 2.0267825, 2.605949 , 2.429799 , 2.334363 , 2.605949 ,
       2.031973 , 2.5982492, 2.156478 , 2.2895453, 2.605949 , 2.0544233,
       2.279317 , 2.605949 , 2.5966938, 2.2614899, 2.1890538, 2.1118665,
       2.1497912, 2.129492 , 2.2243786, 2.2145662, 2.0930667, 2.3224812,
       2.605949 , 2.5966938, 2.6025608, 2.429799 , 2.244153 , 2.2575922,
       2.605949 , 2.1332006, 2.2614899, 2.0196917, 2.4845762, 2.605949 ,
       2.3199105, 2.605949 , 2.3741674, 2.2600327, 2.2168808, 2.5966938,
       2.1812503, 2.2644978, 2.2550735, 2.485744 , 2.605949 , 2.5966938,
       2.605949 , 2.6025608, 2.126449 , 2.4047432, 2.3871057, 2.0090065,
       2.5966938, 2.3850322, 2.381507 , 2.125475 , 2.0014653, 2.0677323,
       2.605949 , 2.605949 , 2.270087 , 2.3738942, 2.3250492, 2.3004646,
       2.2137847, 2.5966938, 2.3199105, 2.2868023, 2.429799 , 2.2941132,
       2.2068372, 2.429799 , 2.3224812, 2.136506 , 2.196981 , 2.4221864,
       2.5535824, 1.8738528, 2.2182841, 2.605949 , 