In [47]:
import csv
from collections import defaultdict
from pathlib import Path

import nltk
import numpy as np
import pandas as pd
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
from sklearn.svm import LinearSVC

from scripts.preprocess_data import preprocess_data
from utils.helpers import list_from_yaml
from utils.metrics import confusion_matrix, classifcation_metrics
from utils.serialization import load_json, save_pickle, save_json
from utils.text_cleaning import remove_stopwords

# Constants

In [48]:
DATA_PATH = Path('./data')
SAVE_PATH = Path('./trained_models')

# Pre-process data

In [29]:
# preprocess_data(DATA_PATH)


# Load Pre-processed data

In [49]:
movie_data = pd.read_csv(DATA_PATH / 'prepared_movie_data.csv')
movie_data.head()

Unnamed: 0,movie_name,freebase_movie_id,movie_release_date,movie_box_office_revenue,movie_runtime,movie_languages,movie_countries,movie_genres,plot_summary
0,Ghosts of Mars,/m/03vyhn,2001-08-24,14010832.0,98.0,"{""/m/02h40lc"": ""English Language""}","{""/m/09c7w0"": ""United States of America""}","['science fiction', 'supernatural', 'action', ...",Set in the second half of the 22nd century the...
1,White Of The Eye,/m/0285_cd,1987,,110.0,"{""/m/02h40lc"": ""English Language""}","{""/m/07ssc"": ""United Kingdom""}","['psychological thriller', 'thriller']",A series of murders of rich young women throug...
2,A Woman in Flames,/m/01mrr1,1983,,106.0,"{""/m/04306rv"": ""German Language""}","{""/m/0345h"": ""Germany""}",['drama'],Eva an upper class housewife becomes frustrate...
3,The Sorcerer's Apprentice,/m/04jcqvw,2002,,86.0,"{""/m/02h40lc"": ""English Language""}","{""/m/0hzlz"": ""South Africa""}","['family', 'adventure', 'world cinema', 'fanta...",Every hundred years the evil Morgana returns t...
4,Little city,/m/0gffwj,1997-04-04,,93.0,"{""/m/02h40lc"": ""English Language""}","{""/m/09c7w0"": ""United States of America""}","['comedy', 'romance', 'drama', 'comedy-drama',...",Adam a San Francisco-based artist who works as...


# Continue Pre-processing

## Lower & Tokenize

In [50]:
# to lower case
movie_data.loc[:, 'plot_summary'] = movie_data['plot_summary'].apply(lambda x: x.lower())
# tokenizing
movie_data.loc[:, 'tokenized_summary'] = movie_data['plot_summary'].apply(word_tokenize)

## Add 100 most frequent words to list of stopwords

In [51]:
# add stop words
frequent_words = defaultdict(lambda: 0)
for summary in movie_data['tokenized_summary']:
    for word in summary:
        frequent_words[word] += 1
frequent_words = sorted(frequent_words.items(), key=lambda item: item[1], reverse=True)

In [52]:
most_frequent_words = pd.DataFrame(frequent_words).head(100)
most_frequent_words = set(most_frequent_words[0]).union(set(stopwords.words('English')))

## Remove stopwords from data

In [34]:
stemmer = SnowballStemmer('english')

In [35]:
movie_data.loc[:, 'cleaned_plot_summary'] = movie_data['tokenized_summary'].apply(remove_stopwords, stemmer=stemmer, stopwords=most_frequent_words)

## Convert genres to multilabel one hot encoding

In [11]:
genre_mapping = load_json(DATA_PATH, 'genre_mapping')

Loaded file: genre_mapping.json successfully


In [12]:
def get_genre_index(x, mapping=genre_mapping):
    x = list_from_yaml(x)
    return [genre_mapping[genre] for genre in x]
        

In [13]:
movie_data.loc[:, 'genre_indices'] = movie_data['movie_genres'].apply(get_genre_index)

# Training

## Parameters

In [14]:
TRAIN_SPLIT = 0.8
VALIDATION_SPLIT = 0.5
RANDOM_STATE = 42 # seed

## Train/Test split

In [15]:
train_data, test_data = train_test_split(movie_data, train_size=TRAIN_SPLIT, shuffle=True,
                                         random_state=RANDOM_STATE)
val_data, test_data = train_test_split(test_data, train_size=TRAIN_SPLIT, shuffle=True,
                                       random_state=RANDOM_STATE)

### Generate Data

In [16]:
mlb = MultiLabelBinarizer()
tfidf = TfidfVectorizer()

In [17]:
train_X, train_y = tfidf.fit_transform(train_data['cleaned_plot_summary'].to_numpy()), mlb.fit_transform(train_data['genre_indices'].to_numpy())
test_X, test_y = tfidf.transform(test_data['cleaned_plot_summary'].to_numpy()), mlb.transform(test_data['genre_indices'].to_numpy())
print(f'train data has {train_X.shape[0]} movies,'
      f' test data has {test_X.shape[0]} movies')

train data has 33143 movies, test data has 1658 movies


## Logistic Regression

In [24]:
params = {'estimator__C': np.arange(1.0, 5, 1)}
lr_ovr = OneVsRestClassifier(LogisticRegression(multi_class='ovr', solver='sag', max_iter=100))
gs_classifier = GridSearchCV(lr_ovr, param_grid=params, scoring='f1_micro')
gs_classifier.fit(train_X, train_y)
print(f'Best training F1 score: {gs_classifier.best_score_:.5f}')
print(f'Best parameters: {gs_classifier.best_params_}')

Best training F1 score: 0.39500
Best parameters: {'estimator__C': 4.0}


In [27]:
lr_ovr = OneVsRestClassifier(LogisticRegression(multi_class='ovr', solver='sag',
                                                max_iter=1000, C= 4.0))
lr_ovr.fit(train_X, train_y)

OneVsRestClassifier(estimator=LogisticRegression(C=4.0, max_iter=1000,
                                                 multi_class='ovr',
                                                 solver='sag'))

In [28]:
pred_y = lr_ovr.predict(test_X)
print('Test Metrics')
TP, TN, FP, FN = confusion_matrix(pred_y, test_y)
test_metrics = classifcation_metrics(TP, TN, FP, FN)
test_metrics

Test Metrics


{'precision': 0.6730118443316413,
 'recall': 0.30696507814007334,
 'accuracy': 0.9503266040011835,
 'f1': 0.4216244865509474}

## SVM

In [38]:
svm_ovr = OneVsRestClassifier(LinearSVC(C=3))
svm_ovr.fit(train_X, train_y)

OneVsRestClassifier(estimator=LinearSVC(C=3))

In [39]:
pred_y = svm_ovr.predict(test_X)
TP, TN, FP, FN = confusion_matrix(pred_y, test_y)
test_metrics = classifcation_metrics(TP, TN, FP, FN)
test_metrics

{'precision': 0.5634132086499123,
 'recall': 0.37198533667759986,
 'accuracy': 0.9459567107449303,
 'f1': 0.4481115630447414}

# Saving models

In [None]:
metadata = {'tokenizer': word_tokenize,
            'genre_mapping': genre_mapping,
            'stopwords': most_frequent_words,
            'tfidf': tfidf,
            'mlb': mlb,
            'stemmer': stemmer,
            'model_type': 'sklearn'}

In [None]:
save_pickle(metadata, SAVE_PATH, 'lr_metadata')
save_pickle(lr_ovr, SAVE_PATH, 'lr')
save_pickle(metadata, SAVE_PATH, 'svm_metadata')
save_pickle(svm_ovr, SAVE_PATH, 'svm')