In [107]:
from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np
import pandas as pd
import ast

In [108]:
# Raw training data; the path is relative to the notebook's working directory.
TRAIN_CSV_PATH = 'Data/train.csv'
train_df = pd.read_csv(TRAIN_CSV_PATH)


In [109]:
# Number of non-missing entries in the raw 'genres' column (same as Series.count()).
train_df['genres'].notna().sum()

2993

In [110]:
# Columns removed from train_df; only budget, genres, title and revenue are kept.
drop_cols = [
    'id',
    'belongs_to_collection',
    'homepage',
    'imdb_id',
    'original_language',
    'original_title',
    'overview',
    'popularity',
    'poster_path',
    'production_companies',
    'production_countries',
    'release_date',
    'spoken_languages',
    'runtime',
    'status',
    'tagline',
    'Keywords',
    'cast',
    'crew',
]

In [111]:
# Narrow the raw frame down to the columns this analysis uses.
df_2 = train_df.drop(columns=drop_cols)
df_2.head(3)

Unnamed: 0,budget,genres,title,revenue
0,14000000,"[{'id': 35, 'name': 'Comedy'}]",Hot Tub Time Machine 2,12314651
1,40000000,"[{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam...",The Princess Diaries 2: Royal Engagement,95149435
2,3300000,"[{'id': 18, 'name': 'Drama'}]",Whiplash,13092000


In [112]:
# Drop bad budget data: budgets below the cutoff are unlikely/impossible.
budg_cutoff = 1000

# Index labels of the rows whose budget falls under the cutoff
budg_drop_indices = df_2.loc[df_2['budget'] < budg_cutoff].index

# Rebind instead of dropping in place; the surviving rows are the same.
df_2 = df_2.drop(index=budg_drop_indices)
df_2.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 2170 entries, 0 to 2999
Data columns (total 4 columns):
budget     2170 non-null int64
genres     2165 non-null object
title      2170 non-null object
revenue    2170 non-null int64
dtypes: int64(2), object(2)
memory usage: 84.8+ KB


In [113]:
# Remove rows whose 'genres' value is missing (rebinding rather than inplace).
df_2 = df_2.dropna(subset=['genres'])
df_2.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 2165 entries, 0 to 2999
Data columns (total 4 columns):
budget     2165 non-null int64
genres     2165 non-null object
title      2165 non-null object
revenue    2165 non-null int64
dtypes: int64(2), object(2)
memory usage: 84.6+ KB


In [114]:
# Raw 'genres' string for the row whose index label is 19 (label lookup, not position).
df_2.loc[19, 'genres']

"[{'id': 28, 'name': 'Action'}, {'id': 53, 'name': 'Thriller'}, {'id': 80, 'name': 'Crime'}]"

In [115]:
#Create new column containing a list of genres for each movie
def to_list(cell):
    """Return the genre names stored in one raw ``genres`` cell.

    Parameters
    ----------
    cell : str, list, or missing value
        Normally a string holding a Python-literal list of dicts, e.g.
        ``"[{'id': 35, 'name': 'Comedy'}, {'id': 18, 'name': 'Drama'}]"``.
        An already-parsed list/tuple of such dicts is accepted as well.

    Returns
    -------
    list of str
        The ``'name'`` of every dict, in their original order. Names are NOT
        deduplicated. Missing values (NaN/None) and blank strings give ``[]``
        instead of making ``ast.literal_eval`` raise.

    Raises
    ------
    ValueError / SyntaxError
        If a string cell is not a valid Python literal.
    KeyError
        If a dict has no ``'name'`` key.
    """
    if isinstance(cell, str):
        if not cell.strip():
            return []
        # literal_eval only evaluates literals, so it is safe on file data (unlike eval).
        cell = ast.literal_eval(cell)  # cells are str, need list
    elif not isinstance(cell, (list, tuple)) and pd.isna(cell):
        # NaN/None would otherwise reach literal_eval and raise ValueError.
        return []
    return [d['name'] for d in cell]

# Replace the raw 'genres' strings with a parsed 'genre_names' list column.
# Guarded so the cell is re-runnable: after the first run 'genres' no longer
# exists, and an unguarded second run would fail with KeyError.
if 'genres' in df_2.columns:
    df_2 = (df_2
            .assign(genre_names=df_2['genres'].map(to_list))
            .drop(columns=['genres']))  # drop duplicate info, i.e. old 'genres'
df_2.head()

Unnamed: 0,budget,title,revenue,genre_names
0,14000000,Hot Tub Time Machine 2,12314651,[Comedy]
1,40000000,The Princess Diaries 2: Royal Engagement,95149435,"[Comedy, Drama, Family, Romance]"
2,3300000,Whiplash,13092000,[Drama]
3,1200000,Kahaani,16000000,"[Thriller, Drama]"
5,8000000,Pinocchio and the Emperor of the Night,3261638,"[Animation, Adventure, Family]"


In [116]:
# Rows at POSITIONS 10..14 (index labels differ because rows were dropped);
# .iloc makes the positional intent explicit instead of relying on [] slicing
# semantics on an integer index.
df_2['genre_names'].iloc[10:15]

14    [Action, Thriller, Science Fiction, Mystery]
15                          [Action, Crime, Drama]
16                              [Horror, Thriller]
18                               [Comedy, Romance]
19                       [Action, Thriller, Crime]
Name: genre_names, dtype: object

In [117]:
# Create a MultiLabelBinarizer object to transform genre_names into one
# 0/1 indicator column per genre name.
mlb = MultiLabelBinarizer()
genre_dummies = pd.DataFrame(mlb.fit_transform(df_2['genre_names']),
                             columns=mlb.classes_,
                             index=df_2.index)
# Build a new frame instead of df_2.pop(...): df_2 is left untouched, so this
# cell can be re-run without a KeyError and earlier views of df_2 stay valid.
df_2_transf = df_2.drop(columns=['genre_names']).join(genre_dummies)
df_2_transf.describe()

Unnamed: 0,budget,revenue,Action,Adventure,Animation,Comedy,Crime,Documentary,Drama,Family,...,History,Horror,Music,Mystery,Romance,Science Fiction,TV Movie,Thriller,War,Western
count,2165.0,2165.0,2165.0,2165.0,2165.0,2165.0,2165.0,2165.0,2165.0,2165.0,...,2165.0,2165.0,2165.0,2165.0,2165.0,2165.0,2165.0,2165.0,2165.0,2165.0
mean,31218570.0,87808170.0,0.280831,0.171363,0.048961,0.333025,0.173672,0.009238,0.498845,0.091455,...,0.048037,0.111778,0.031871,0.084065,0.184758,0.11455,0.000462,0.292841,0.037413,0.014319
std,40356470.0,156079800.0,0.449509,0.376913,0.215836,0.471404,0.378915,0.095691,0.500114,0.288322,...,0.213894,0.315166,0.175696,0.277549,0.38819,0.318551,0.021492,0.455171,0.189816,0.118828
min,2500.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,5952000.0,7096000.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,16800000.0,29400000.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,40000000.0,100491700.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
max,380000000.0,1519558000.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
