In [1]:
from sklearn import linear_model
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import json
from sklearn.preprocessing import LabelEncoder
import sklearn.metrics as metrics

## Read, Clean Data

In [2]:
# Data source: https://www.kaggle.com/tmdb/tmdb-movie-metadata
# NOTE(review): dropna() is called without a subset, so a row is discarded when
# ANY column is missing (including text columns such as homepage/tagline that
# the model never uses) -- confirm this is intended.
df = (
    pd.read_csv('../data/tmdb_5000_movies.csv.gz')
    .dropna()
    .assign(
        # Rescale both monetary columns by a factor of one million.
        budget=lambda frame: frame.budget / 1000000,
        revenue=lambda frame: frame.revenue / 1000000,
    )
)
df.head(2)

Unnamed: 0,budget,genres,homepage,id,keywords,original_language,original_title,overview,popularity,production_companies,production_countries,release_date,revenue,runtime,spoken_languages,status,tagline,title,vote_average,vote_count
0,237.0,"[{""id"": 28, ""name"": ""Action""}, {""id"": 12, ""nam...",http://www.avatarmovie.com/,19995,"[{""id"": 1463, ""name"": ""culture clash""}, {""id"":...",en,Avatar,"In the 22nd century, a paraplegic Marine is di...",150.437577,"[{""name"": ""Ingenious Film Partners"", ""id"": 289...","[{""iso_3166_1"": ""US"", ""name"": ""United States o...",2009-12-10,2787.965087,162.0,"[{""iso_639_1"": ""en"", ""name"": ""English""}, {""iso...",Released,Enter the World of Pandora.,Avatar,7.2,11800
1,300.0,"[{""id"": 12, ""name"": ""Adventure""}, {""id"": 14, ""...",http://disney.go.com/disneypictures/pirates/,285,"[{""id"": 270, ""name"": ""ocean""}, {""id"": 726, ""na...",en,Pirates of the Caribbean: At World's End,"Captain Barbossa, long believed to be dead, ha...",139.082615,"[{""name"": ""Walt Disney Pictures"", ""id"": 2}, {""...","[{""iso_3166_1"": ""US"", ""name"": ""United States o...",2007-05-19,961.0,169.0,"[{""iso_639_1"": ""en"", ""name"": ""English""}]",Released,"At the end of the world, the adventure begins.",Pirates of the Caribbean: At World's End,6.9,4500


## Identify Predictors and scale data

In [6]:
# Numeric columns used as model features.
predictors = ['budget','popularity','runtime','vote_average','vote_count']

# Raw feature matrix: one row per movie, one column per predictor.
X_orig = df.loc[:, predictors].values
print('X_orig.shape:',X_orig.shape)

# Fit the scaler on the raw features, then apply it to the same matrix.
scaler = StandardScaler().fit(X_orig)
X_scaled = scaler.transform(X_orig)
print('X_scaled.shape:',X_scaled.shape)

X_orig.shape: (1493, 5)
X_scaled.shape: (1493, 5)


## Get the list of all genres, identify unique genres

In [7]:
def _genre_names(raw_genres):
    """Return the 'name' of every entry in a JSON-encoded list of genre objects."""
    return [entry['name'] for entry in json.loads(raw_genres)]

# One list of genre names per movie (array of lists).
movie_genre_lists = df['genres'].apply(_genre_names).values

# Flatten all per-movie lists into the set of distinct genre names.
genres_set = {name for names in movie_genre_lists for name in names}

# Alphabetical order gives the genre columns a stable, reproducible ordering.
uniq_genres_list = sorted(genres_set)
print('Number of unique genres:',len(uniq_genres_list))
uniq_genres_list

Number of unique genres: 20


['Action',
 'Adventure',
 'Animation',
 'Comedy',
 'Crime',
 'Documentary',
 'Drama',
 'Family',
 'Fantasy',
 'Foreign',
 'History',
 'Horror',
 'Music',
 'Mystery',
 'Romance',
 'Science Fiction',
 'TV Movie',
 'Thriller',
 'War',
 'Western']

In [13]:
# Build a binary indicator matrix: one row per genre, one entry per movie,
# 1 when the movie is tagged with that genre and 0 otherwise.
genre_indicator_rows = [
    [int(genre in movie_genres) for movie_genres in movie_genre_lists]
    for genre in uniq_genres_list
]

# Transpose so rows are movies and columns are genres (aligned with X_scaled).
y_all = np.array(genre_indicator_rows).T

# NOTE(review): X_scaled was standardised using every row before this split,
# so test-set statistics influence the scaling -- consider fitting the scaler
# on the training split only.
X_train, X_test, y_all_train, y_all_test = train_test_split(
    X_scaled,
    y_all,
    test_size=0.2,
    random_state=42,
)

y_all.shape

(1493, 20)

In [14]:
# Train one independent binary logistic-regression classifier per genre and
# score each on the held-out split with ROC AUC.
rocs = []     # ROC AUC of every genre that could be scored
logregs = []  # fitted classifiers, one per genre, in uniq_genres_list order;
              # initialised here so the cell runs on a fresh kernel
for genre, y_train, y_test in zip(uniq_genres_list, y_all_train.T, y_all_test.T):
    logreg = linear_model.LogisticRegression()
    logreg.fit(X_train, y_train)
    logregs.append(logreg)
    preds = logreg.predict_proba(X_test)
    # roc_auc_score needs both classes in the true labels; skip genres whose
    # test split contains only one class.
    if len(set(y_test)) <= 1:
        continue
    # Column 1 of predict_proba holds the probability of the positive class.
    roc = metrics.roc_auc_score(y_test, preds[:, 1])
    rocs.append(roc)
    print('ROC,',genre, ' = ' ,np.round(roc,2))

# Unweighted mean over the genres that were actually scored.
print('\n\nROCs mean', np.round(np.mean(rocs),2))

ROC, Action  =  0.7
ROC, Adventure  =  0.81
ROC, Animation  =  0.96
ROC, Comedy  =  0.69
ROC, Crime  =  0.66
ROC, Documentary  =  0.93
ROC, Drama  =  0.78
ROC, Family  =  0.86
ROC, Fantasy  =  0.69
ROC, Foreign  =  1.0
ROC, History  =  0.68
ROC, Horror  =  0.8
ROC, Music  =  0.61
ROC, Mystery  =  0.54
ROC, Romance  =  0.57
ROC, Science Fiction  =  0.76
ROC, Thriller  =  0.65
ROC, War  =  0.84
ROC, Western  =  0.51


ROCs mean 0.74
