# Predicting the Primary Genre of Movies and TV Shows Using Plot Text

## Introduction

Many people enjoy binge-watching their favorite movies and TV shows. Modern titles often draw on several genres (action, adventure, comedy, and so on), weaving multiple themes into a single plot. A movie can be primarily an action movie while also carrying undercurrents of romance and comedy, as in Thor: Ragnarok or Thor: Love and Thunder. Most online platforms that maintain movie and TV-show details (e.g., IMDb, Rotten Tomatoes) list all of a title's genres but do not single out a "primary genre." We tackle this problem by using machine learning to assign each movie or TV show a single primary genre that best represents its plot.

## Commercial Applications

Highlighting a movie or TV show's primary genre has several commercial applications, including improved content recommendation and a more precise understanding of how individual actors perform in, and gravitate toward, particular genres (and, consequently, how they resonate with niche fan bases).

## Methodology

As stated, we will use **machine learning** to **predict the primary genre** of movies and TV shows. Most movies come with a plot summary of a few lines of text, which can be used to predict genres. We frame this as a **multi-label classification task**, since each movie can carry several genres and the model predicts all of them from the plot text. A Keras **neural network** is trained for this task on the publicly available **CMU Movie Summary Corpus**. The plot text is encoded with the **TF-IDF** method, and the dataset is split into train (90%) and test (10%) subsets. The model outputs a probability for each genre, and we take the genre with the highest probability as the "primary genre" of the title. The entire pipeline lives in a single class, PredictPrimaryGenre, which is run by instantiating it and calling its RunPredictor() method.
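
To make the TF-IDF encoding step concrete, here is a minimal pure-Python sketch. It uses the smoothed IDF, ln((1 + n) / (1 + df)) + 1, and L2 row normalization, which scikit-learn documents as the TfidfVectorizer defaults. The function name `tfidf_vectors`, the whitespace tokenizer, and the three toy plots are assumptions made for illustration; the notebook itself relies on TfidfVectorizer with English stop-word removal.

```python
import math
from collections import Counter


def tfidf_vectors(documents):
    """Return (vocabulary, vectors) using smoothed IDF and L2 normalization."""
    # Simplified tokenizer (assumption): lowercase and split on whitespace.
    tokenized = [doc.lower().split() for doc in documents]
    vocabulary = sorted({token for tokens in tokenized for token in tokens})
    n_docs = len(tokenized)

    # Document frequency: number of documents that contain each term.
    doc_freq = Counter(token for tokens in tokenized for token in set(tokens))
    # Smoothed inverse document frequency: ln((1 + n) / (1 + df)) + 1.
    idf = {t: math.log((1 + n_docs) / (1 + doc_freq[t])) + 1 for t in vocabulary}

    vectors = []
    for tokens in tokenized:
        counts = Counter(tokens)
        raw = [counts[t] * idf[t] for t in vocabulary]
        # Scale each row to unit Euclidean length.
        norm = math.sqrt(sum(v * v for v in raw)) or 1.0
        vectors.append([v / norm for v in raw])
    return vocabulary, vectors


# Toy plots, invented for this example.
plots = [
    "a detective hunts a killer",
    "a wizard hunts a dragon",
    "two friends fall in love",
]
vocabulary, vectors = tfidf_vectors(plots)

for term, weight in zip(vocabulary, vectors[0]):
    if weight > 0:
        print("{:10s} {:.3f}".format(term, weight))
```

Within the first plot, "detective" appears in only one document and so receives a higher weight than "hunts", which appears in two. That is the property that lets rare, genre-specific words stand out to the classifier.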

### Next Steps

In [None]:
import tensorflow as tf

print("Tensorflow Version : {}".format(tf.__version__))

In [None]:
import scipy

print(scipy.__version__)

In [None]:
from PIL import Image
import requests

In [None]:
import pandas as pd
import numpy as np
import json
import nltk
import re
import csv
import matplotlib.pyplot as plt 
import seaborn as sns
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import f1_score
from sklearn.metrics import classification_report, accuracy_score

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.text import Tokenizer, text_to_word_sequence

import scikitplot as skplt
import matplotlib.pyplot as plt
%matplotlib inline

import scipy
#from scipy.sparse.csr_matrix import sort_indices

class PredictPrimaryGenre:
    
    def __init__(self, threshold = 0, file_name="PeerLogix Titles (IMDb Metadata).csv"):
        
        self.file_to_predict = file_name
        
        self.meta_file = "MovieSummaries/movie.metadata.tsv"
        self.plot_file = "MovieSummaries/plot_summaries.txt"
        self.peerlogix_genres_file = "peerlogix_genres.csv"
        
        self.only_peerlogix_valid_genres = True
        
        self.valid_genres = [
            'Action', 'Action & Adventure', 'Adventure', 'Animation', 'Comedy', 'Crime', 'Documentary',
            'Drama', 'Family', 'Fantasy', 'Horror', 'Musical', 'Mystery', 'Romance', 'Science Fiction',
            'Thriller', 'War', ]

        #self.valid_genres = pd.read_csv(self.peerlogix_genres_file).genre.unique().tolist() ## This one has around 8-10 more unique labels
        
        self.genre_corrections = {
                            'Action/Adventure' : 'Action & Adventure',
                            'Crime Fiction' : 'Crime',
                            'Family Film' : 'Family',
                            'Romance Film' : 'Romance',
                            'War Film' : 'War',
                            'Comedy Film' : 'Comedy'
                            }
        
        self.threshold = threshold 
        self.max_df = 1.0 #0.8
        self.stop_words = "english"
        self.max_features = 75000
        
        print('Predictor Initialized.')
        
    
    def loadMetaData(self):
        self.meta = pd.read_csv(self.meta_file, sep = '\t', header = None)
        self.meta.columns = ["movie_id",1,"movie_name",3,4,5,6,7,"genre"]
        self.meta['movie_id'] = self.meta['movie_id'].astype(str)
        
        print("Loading of Meta Data Complete.")
        
    def loadMoviePlotsData(self):
        plots = []

        with open(self.plot_file, 'r', encoding = 'UTF-8') as f:
            reader = csv.reader(f, dialect='excel-tab') 
            for row in tqdm(reader):
                plots.append(row)
                
        movie_id = []
        plot = []

        # extract movie Ids and plot summaries
        for i in tqdm(plots):
            movie_id.append(i[0])
            plot.append(i[1])

        # create dataframe
        self.movies = pd.DataFrame({'movie_id': movie_id, 'plot': plot})
        
        print("Loading of Movies Plot Data Complete.")
    
    def loadExistingPeerLogixGenres(self):
        existing_peerlogix_genres = pd.read_csv(self.peerlogix_genres_file)
        existing_peerlogix_genres = existing_peerlogix_genres.groupby("imdb_id").aggregate(lambda x: list(x)).reset_index()
        self.existing_peerlogix_genres = dict(zip(existing_peerlogix_genres["imdb_id"].values.tolist(), existing_peerlogix_genres["genre"].values.tolist()))
        
        print("Loading of Existing PeerLogix Genres Complete.")
        
    def mergeMovieAndPlotsData(self):
        self.movies = pd.merge(self.movies, self.meta[['movie_id', 'movie_name', 'genre']], on = 'movie_id')
        
        print("Merging of Genre and Plots Data Complete.")
        
    def cleanGenreData(self):
        genres = [] 

        for i in self.movies['genre']: 
            movie_genres = list(json.loads(i).values())

            movie_genres = [self.genre_corrections.get(genre, genre) for genre in movie_genres] ## Genre Correction in Data

            if self.only_peerlogix_valid_genres: ## Keep only PeerLogix valid genres
                movie_genres = [genre for genre in movie_genres if genre in self.valid_genres]

            genres.append(movie_genres) 

        self.movies['genre_new'] = genres
    
        print("Cleaning of Genre Data Complete.")
    
    def cleanPlotData(self):
        self.movies['clean_plot'] = self.movies['plot'].apply(lambda x: " ".join(re.findall("[a-zA-Z]+", x.lower())))
        
        print("Cleaning of Plots Data Complete.")
        
    def LoadPeerLogixFileToPredictPrimaryGenre(self):
        self.peerlogix = pd.read_csv(self.file_to_predict)
        
        print("Loading of Prediction File Complete.")
    
    def LoadAndCleanData(self):
        self.loadMetaData()
        self.loadMoviePlotsData()
        self.loadExistingPeerLogixGenres()
        self.mergeMovieAndPlotsData()
        self.cleanGenreData()
        self.cleanPlotData()
        self.LoadPeerLogixFileToPredictPrimaryGenre()
        
    def PrepareMultiLabelTargetValues(self):
        self.multilabel_binarizer = MultiLabelBinarizer()
        self.multilabel_binarizer.fit(self.movies['genre_new'])

        # transform target variable
        self.Y = self.multilabel_binarizer.transform(self.movies['genre_new'])
        #print(self.Y.shape)
        print("Preparation of Multi-Label Target Complete.")
    
    def PrepareTrainTestSplit(self):
        self.X_train, self.X_test, self.Y_train, self.Y_test = train_test_split(self.movies['clean_plot'], self.Y, test_size=0.1, random_state=123)
        
        print("Train Test Split Complete.")
        
    def TransformPlotsFromTextToFloats(self):
        self.tfidf_vectorizer = TfidfVectorizer(max_df=self.max_df, stop_words=self.stop_words, max_features=self.max_features)
        
        self.X_train_tfidf = self.tfidf_vectorizer.fit_transform(self.X_train)
        self.X_test_tfidf = self.tfidf_vectorizer.transform(self.X_test)
        self.X_train_tfidf.sort_indices()
        self.X_test_tfidf.sort_indices()
        
        print("Train/Test Shape : {}/{}".format(self.X_train_tfidf.shape, self.X_test_tfidf.shape))
        
        print("Text Data Vectorization Complete.")
        
    def PrepreDataForMLModel(self):
        self.PrepareMultiLabelTargetValues()
        self.PrepareTrainTestSplit()
        self.TransformPlotsFromTextToFloats()
        
        
    def TrainClassifier(self):
        #lr = LogisticRegression(max_iter=500)
        #gb = GradientBoostingClassifier()
        #self.clf = OneVsRestClassifier(gb)
        
        self.clf = Sequential([
                layers.Dense(256, activation="relu", input_shape=(self.X_train_tfidf.shape[1],)),
                layers.Dense(128, activation="relu"),
                layers.Dense(64, activation="relu"),
                layers.Dense(self.Y.shape[1], activation="sigmoid"),
            ])
        
        #print(self.clf.summary())
        self.clf.compile("adam", "binary_crossentropy", metrics=["categorical_crossentropy"])

        self.clf.fit(self.X_train_tfidf, self.Y_train, batch_size=256, epochs=5)#, validation_data=(self.X_test_tfidf, self.Y_test))
        
    def EvaluateClassifier(self):
        Y_test_pred = self.clf.predict(self.X_test_tfidf)
        Y_test_pred = np.where(Y_test_pred > 0.5, 1, 0)
        print("\nF1-Score : {:.3f}".format(f1_score(self.Y_test, Y_test_pred, average="micro")))
        
    def PredictGenre(self, text):
        if isinstance(text, str):
            cleaned_text = " ".join(re.findall("[a-zA-Z]+", text.lower()))
            X = self.tfidf_vectorizer.transform([cleaned_text])
            X.sort_indices()
            #probs = self.clf.predict_proba(X) ## Probabilities / likelihood
            probs = self.clf.predict(X) ## Probabilities / likelihood

            # Isolate highest probable genre
            primary_genre_idx = probs.argsort()[0][-1] ## Taking highest probability/liklihood
            idx2, idx3, idx4, idx5 = probs.argsort()[0][-2], probs.argsort()[0][-3], probs.argsort()[0][-4], probs.argsort()[0][-5]

            # Discard results if none were above threshold 
            if probs[0][primary_genre_idx] < self.threshold:
                return ["NA", ] * 5

            # Else, return top genre
            primary_genre = self.multilabel_binarizer.classes_[primary_genre_idx] 
            genre2, genre3, genre4, genre5 = self.multilabel_binarizer.classes_[idx2], self.multilabel_binarizer.classes_[idx3], self.multilabel_binarizer.classes_[idx4], self.multilabel_binarizer.classes_[idx5]

            primary_genre = self.genre_corrections.get(primary_genre, primary_genre)
            genre2, genre3, genre4, genre5 = self.genre_corrections.get(genre2, genre2), self.genre_corrections.get(genre3, genre3), self.genre_corrections.get(genre4, genre4), self.genre_corrections.get(genre5, genre5)
            return primary_genre, genre2, genre3, genre4, genre5
        else:
            #print(text)
            return ["NA", ] * 5    
    
    def PredictGenreForFile(self):
        primary_genre = []

        for i, (imdb_id, plot) in enumerate(self.peerlogix[["imdb_id","description"]].values):
            existing_genres = self.existing_peerlogix_genres.get(imdb_id, []) ## Retrieve Existing Genres for id
            predicted_genres = self.PredictGenre(plot) ## Make Prediction on Plot

            if existing_genres: ## If Genres present for IMDB ID then choose from it else append predicted one.
                if len(existing_genres) == 1: ## If single Genre then it'll be primary Genre
                    primary_genre.append(existing_genres[0])
                else:
                    selected_genre = None
                    for genre in predicted_genres: ### Check for predicted Genre in existing Genres
                        if genre in existing_genres:
                            selected_genre = genre
                            break    
                    primary_genre.append(selected_genre if selected_genre else predicted_genres[0]) 

            else:
                ## Category NA gets appended here only for movies that have neither a plot nor existing genres
                primary_genre.append(predicted_genres[0]) ## Append first one which is primary

            if (i+1)%5000 == 0:
                print("{} iteration completed".format(i+1))
        
        self.peerlogix["Predicted_Primary_Genre1"] = primary_genre
        self.peerlogix["Actual_Primary_Genre1"] = [self.existing_peerlogix_genres.get(imdb_id, ["NA"])[0] for imdb_id in self.peerlogix.imdb_id]
    
    def RunPredictor(self):
        
        # Load and clean the training data
        self.LoadAndCleanData()
        self.PrepreDataForMLModel()
        
        print("\n=========== Data Loading and Cleaning Complete =============\n")
        
        # train the classifier
        self.TrainClassifier()
        self.EvaluateClassifier()
        
        print("\n=========== Model Training and Evaluation Complete =============\n")
        
        # Loop through each and choose primary genre 
        self.PredictGenreForFile()
        
        print("\n=========== Prediction of Genre Complete =============\n")
        
        # Save CSV locally 
        self.peerlogix.to_csv("file_with_genre.csv")
        
        return self.peerlogix
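
The cleaning and target-preparation steps in the class above can be tried in isolation. The sketch below mirrors the logic of cleanGenreData, cleanPlotData, and the multi-label binarization without pandas or scikit-learn. The function names, the reduced genre sets, and the sample record (including its placeholder ids) are assumptions made for illustration, not values from the corpus.

```python
import json
import re

# Assumed subsets of the corrections and valid genres defined in the class.
GENRE_CORRECTIONS = {"Crime Fiction": "Crime", "Romance Film": "Romance"}
VALID_GENRES = {"Action", "Comedy", "Crime", "Drama", "Romance", "Thriller"}


def clean_genres(raw_json):
    """Parse a genre field (JSON object of id -> name), rename, then filter."""
    names = json.loads(raw_json).values()
    corrected = [GENRE_CORRECTIONS.get(name, name) for name in names]
    return [name for name in corrected if name in VALID_GENRES]


def clean_plot(text):
    """Lowercase the plot and keep only runs of ASCII letters."""
    return " ".join(re.findall("[a-zA-Z]+", text.lower()))


def binarize(genre_lists):
    """Turn lists of genres into 0/1 rows, one column per sorted class."""
    classes = sorted({g for genres in genre_lists for g in genres})
    rows = [[1 if c in genres else 0 for c in classes] for genres in genre_lists]
    return classes, rows


# Hypothetical record; the ids are placeholders, not real corpus entries.
raw = '{"/m/x1": "Crime Fiction", "/m/x2": "Thriller", "/m/x3": "Black-and-white"}'
genres = clean_genres(raw)
plot = clean_plot("In 1947, a P.I. takes the case!")
classes, rows = binarize([genres, ["Comedy"]])

print(genres)   # ['Crime', 'Thriller']
print(plot)     # in a p i takes the case
print(classes)  # ['Comedy', 'Crime', 'Thriller']
print(rows)     # [[0, 1, 1], [1, 0, 0]]
```

Note that the letter-only regular expression drops digits and splits abbreviations such as "P.I." into single letters, so years and numeric titles never reach the vectorizer.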

In [None]:
%%time

genre_predictor = PredictPrimaryGenre()

final_prediction_df = genre_predictor.RunPredictor()

final_prediction_df.head()
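
The rule that PredictGenreForFile applies when a title already has genres on record is easier to follow as a standalone function. The sketch below is a simplified restatement under the assumption that predictions arrive ranked from most to least probable; the name `choose_primary_genre` and the sample inputs are hypothetical.

```python
def choose_primary_genre(existing_genres, ranked_predictions):
    """Pick one primary genre from existing labels and ranked model output."""
    # A single existing genre is taken as the primary genre outright.
    if len(existing_genres) == 1:
        return existing_genres[0]
    # Otherwise prefer the highest-ranked prediction that is already on record.
    for genre in ranked_predictions:
        if genre in existing_genres:
            return genre
    # No overlap (or no existing genres): fall back to the top prediction.
    return ranked_predictions[0]


ranked = ["Comedy", "Romance", "Drama", "Action", "Horror"]

print(choose_primary_genre(["Drama"], ranked))             # Drama
print(choose_primary_genre(["Drama", "Romance"], ranked))  # Romance
print(choose_primary_genre([], ranked))                    # Comedy
print(choose_primary_genre(["War", "Musical"], ["NA"] * 5))  # NA
```

The last call shows a consequence worth knowing: a title with several existing genres but no usable plot still ends up labeled "NA", because the fallback is the top prediction rather than one of the existing genres.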

In [None]:
from collections import Counter

actual = Counter(final_prediction_df["Actual_Primary_Genre1"].values)
predicted = Counter(final_prediction_df["Predicted_Primary_Genre1"].values)

for genre in actual.keys():
    print("{:20s} - Actual: {:5d}, Predicted: {:5d}, Actual: {:5.2f} %, Predicted: {:5.2f} %".format(genre,
                                                            actual[genre],
                                                            predicted[genre],
                                                            100 * actual[genre] / final_prediction_df.shape[0],
                                                            100 * predicted[genre] / final_prediction_df.shape[0],
                                                           ))
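
The count comparison above complements the micro-averaged F1 score that EvaluateClassifier prints during training. As a reminder of what that number measures, here is a minimal sketch that thresholds probabilities at 0.5 and pools true positives, false positives, and false negatives across all genres before computing F1 = 2TP / (2TP + FP + FN). The helper names and the toy arrays are assumptions; the notebook itself uses sklearn.metrics.f1_score with average="micro".

```python
def binarize_at(probabilities, cutoff=0.5):
    """Convert per-genre probabilities to 0/1 decisions (strictly above cutoff)."""
    return [[1 if p > cutoff else 0 for p in row] for row in probabilities]


def micro_f1(y_true, y_pred):
    """Micro-averaged F1: counts are pooled over every (title, genre) cell."""
    tp = fp = fn = 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            if t == 1 and p == 1:
                tp += 1
            elif t == 0 and p == 1:
                fp += 1
            elif t == 1 and p == 0:
                fn += 1
    denominator = 2 * tp + fp + fn
    return 2 * tp / denominator if denominator else 0.0


# Toy data: two titles, three genres.
y_true = [[1, 0, 1], [0, 1, 0]]
probs = [[0.9, 0.2, 0.4], [0.1, 0.8, 0.6]]

y_pred = binarize_at(probs)
score = micro_f1(y_true, y_pred)
print("Micro F1: {:.3f}".format(score))  # Micro F1: 0.667
```

Because the counts are pooled, frequent genres such as Drama dominate the score, so a high micro F1 does not by itself show that rarer genres are predicted well.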

In [None]:
## Rerun this cell if it fails with a URL-related error, as some ids don't have poster URLs

imdb_ids = genre_predictor.peerlogix.sample(6).imdb_id.values

genres_l, urls, plots, imgs, p_genres = [], [], [], [], []
for imdb_id in imdb_ids:
    genres = genre_predictor.existing_peerlogix_genres.get(imdb_id, [])  # Empty list if the id has no existing genres
    poster_url = genre_predictor.peerlogix[genre_predictor.peerlogix.imdb_id== imdb_id].poster_url.values[0]
    plot = genre_predictor.peerlogix[genre_predictor.peerlogix.imdb_id== imdb_id].description.values[0]
    img = Image.open(requests.get(poster_url, stream=True).raw)
    predicted_genres = list(genre_predictor.PredictGenre(plot))
    
    genres_l.append(genres)
    urls.append(poster_url)
    plots.append(plot)
    imgs.append(img)
    p_genres.append(predicted_genres)

In [None]:
from PIL import Image, ImageFont, ImageDraw
from IPython.display import display
import requests
import numpy as np
import matplotlib.pyplot as plt
import textwrap

%matplotlib inline

In [None]:
fig = plt.figure(figsize=(22,16))

for i in range(6):
    ax = fig.add_subplot(2,3,i+1)

    ax.imshow(imgs[i]);
    ax.text(0,0, "Actual: {}\nPredicted: {}\n".format(genres_l[i], p_genres[i]), fontdict={"fontsize": 15}, ha="left", wrap=True);
    ax.set_xticks([],[]);ax.set_yticks([],[]);
    
plt.tight_layout()

In [None]:
fig = plt.figure(figsize=(22,16))

for i in range(6):
    ax = fig.add_subplot(2,3,i+1)

    ax.imshow(imgs[i]);
    ax.text(0,0, "{}\n\nPrimary Genre: {}\nOther Genres : {}\n".format("\n".join(textwrap.wrap(plots[i], width=35)), p_genres[i][0], p_genres[i][1:]), 
            fontdict={"fontsize": 15},
            ha="left", 
            wrap=True);
    
    ax.set_xticks([],[]);ax.set_yticks([],[]);
    
plt.tight_layout()

In [None]:
vectorized_arr = genre_predictor.tfidf_vectorizer.transform(plots[:1])
probs = genre_predictor.clf.predict(vectorized_arr) ## Same call as in PredictGenre; returns per-genre sigmoid outputs
probs_idxs = probs.argsort()[0][::-1]

cats = genre_predictor.multilabel_binarizer.classes_[probs_idxs]
ordered_probs = probs[0][probs_idxs]
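
The ordering step in this cell is repeated for each of the six sampled titles. One way to avoid the repetition is a small ranking helper; the sketch below is a hypothetical pure-Python version (the name `rank_genres`, its parameters, and the sample probabilities are assumptions). Unlike PredictGenre, which returns five "NA" entries when the top probability falls below the threshold, this helper returns an empty list in that case.

```python
def rank_genres(classes, probabilities, k=5, threshold=0.0):
    """Return up to k (genre, probability) pairs, highest probability first."""
    ranked = sorted(
        zip(classes, probabilities), key=lambda pair: pair[1], reverse=True
    )
    # Discard everything if even the best genre is below the threshold.
    if not ranked or ranked[0][1] < threshold:
        return []
    return ranked[:k]


# Invented sigmoid outputs for one title; they need not sum to 1.
classes = ["Action", "Comedy", "Drama", "Horror"]
probabilities = [0.62, 0.08, 0.71, 0.03]

top_two = rank_genres(classes, probabilities, k=2)
print(top_two)  # [('Drama', 0.71), ('Action', 0.62)]
print(rank_genres(classes, probabilities, threshold=0.9))  # []
```

Since the output layer uses a sigmoid per genre, the values are independent scores rather than a distribution, which is why the bar charts can show several genres with high probability at once.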

In [None]:
fig = plt.figure(figsize=(18, 6))

ax1 = fig.add_subplot(121)
ax1.imshow(imgs[0])
ax1.text(0,0, "{}\n\nPrimary Genre: {}\nOther Genres : {}\n".format("\n".join(textwrap.wrap(plots[0], width=50)), p_genres[0][0], p_genres[0][1:]), 
            fontdict={"fontsize": 15},
            ha="left", 
            wrap=True);
    
ax1.set_xticks([],[]);ax1.set_yticks([],[]);

ax2 = fig.add_subplot(122)
ax2.bar(x=cats[:5], height=ordered_probs[:5], color="tomato");
ax2.set_ylabel("Probability");
ax2.set_xlabel("Predicted Genres");
ax2.set_title("Probability of Categories");

In [None]:
vectorized_arr = genre_predictor.tfidf_vectorizer.transform(plots[1:2])
probs = genre_predictor.clf.predict(vectorized_arr)
probs_idxs = probs.argsort()[0][::-1]

cats = genre_predictor.multilabel_binarizer.classes_[probs_idxs]
ordered_probs = probs[0][probs_idxs]

In [None]:
fig = plt.figure(figsize=(18, 6))

ax1 = fig.add_subplot(121)
ax1.imshow(imgs[1])
ax1.text(0,0, "{}\n\nPrimary Genre: {}\nOther Genres : {}\n".format("\n".join(textwrap.wrap(plots[1], width=50)), p_genres[1][0], p_genres[1][1:]), 
            fontdict={"fontsize": 15},
            ha="left", 
            wrap=True);
    
ax1.set_xticks([],[]);ax1.set_yticks([],[]);

ax2 = fig.add_subplot(122)
ax2.bar(x=cats[:5], height=ordered_probs[:5], color="dodgerblue");
ax2.set_ylabel("Probability");
ax2.set_xlabel("Predicted Genres");
ax2.set_title("Probability of Categories");

In [None]:
vectorized_arr = genre_predictor.tfidf_vectorizer.transform(plots[2:3])
probs = genre_predictor.clf.predict(vectorized_arr)
probs_idxs = probs.argsort()[0][::-1]

cats = genre_predictor.multilabel_binarizer.classes_[probs_idxs]
ordered_probs = probs[0][probs_idxs]

In [None]:
fig = plt.figure(figsize=(18, 6))

ax1 = fig.add_subplot(121)
ax1.imshow(imgs[2])
ax1.text(0,0, "{}\n\nPrimary Genre: {}\nOther Genres : {}\n".format("\n".join(textwrap.wrap(plots[2], width=50)), p_genres[2][0], p_genres[2][1:]), 
            fontdict={"fontsize": 15},
            ha="left", 
            wrap=True);
    
ax1.set_xticks([],[]);ax1.set_yticks([],[]);

ax2 = fig.add_subplot(122)
ax2.bar(x=cats[:5], height=ordered_probs[:5], color="lime");
ax2.set_ylabel("Probability");
ax2.set_xlabel("Predicted Genres");
ax2.set_title("Probability of Categories");

In [None]:
vectorized_arr = genre_predictor.tfidf_vectorizer.transform(plots[3:4])
probs = genre_predictor.clf.predict(vectorized_arr)
probs_idxs = probs.argsort()[0][::-1]

cats = genre_predictor.multilabel_binarizer.classes_[probs_idxs]
ordered_probs = probs[0][probs_idxs]

In [None]:
fig = plt.figure(figsize=(18, 6))

ax1 = fig.add_subplot(121)
ax1.imshow(imgs[3])
ax1.text(0,0, "{}\n\nPrimary Genre: {}\nOther Genres : {}\n".format("\n".join(textwrap.wrap(plots[3], width=50)), p_genres[3][0], p_genres[3][1:]), 
            fontdict={"fontsize": 15},
            ha="left", 
            wrap=True);
    
ax1.set_xticks([],[]);ax1.set_yticks([],[]);

ax2 = fig.add_subplot(122)
ax2.bar(x=cats[:5], height=ordered_probs[:5], color="dodgerblue");
ax2.set_ylabel("Probability");
ax2.set_xlabel("Predicted Genres");
ax2.set_title("Probability of Categories");

In [None]:
vectorized_arr = genre_predictor.tfidf_vectorizer.transform(plots[4:5])
probs = genre_predictor.clf.predict(vectorized_arr)
probs_idxs = probs.argsort()[0][::-1]

cats = genre_predictor.multilabel_binarizer.classes_[probs_idxs]
ordered_probs = probs[0][probs_idxs]

In [None]:
fig = plt.figure(figsize=(18, 6))

ax1 = fig.add_subplot(121)
ax1.imshow(imgs[4])
ax1.text(0,0, "{}\n\nPrimary Genre: {}\nOther Genres : {}\n".format("\n".join(textwrap.wrap(plots[4], width=50)), p_genres[4][0], p_genres[4][1:]), 
            fontdict={"fontsize": 15},
            ha="left", 
            wrap=True);
    
ax1.set_xticks([],[]);ax1.set_yticks([],[]);

ax2 = fig.add_subplot(122)
ax2.bar(x=cats[:5], height=ordered_probs[:5], color="dodgerblue");
ax2.set_ylabel("Probability");
ax2.set_xlabel("Predicted Genres");
ax2.set_title("Probability of Categories");

In [None]:
vectorized_arr = genre_predictor.tfidf_vectorizer.transform(plots[5:])
probs = genre_predictor.clf.predict(vectorized_arr)
probs_idxs = probs.argsort()[0][::-1]

cats = genre_predictor.multilabel_binarizer.classes_[probs_idxs]
ordered_probs = probs[0][probs_idxs]

In [None]:
fig = plt.figure(figsize=(18, 6))

ax1 = fig.add_subplot(121)
ax1.imshow(imgs[5])
ax1.text(0,0, "{}\n\nPrimary Genre: {}\nOther Genres : {}\n".format("\n".join(textwrap.wrap(plots[5], width=50)), p_genres[5][0], p_genres[5][1:]), 
            fontdict={"fontsize": 15},
            ha="left", 
            wrap=True);
    
ax1.set_xticks([],[]);ax1.set_yticks([],[]);

ax2 = fig.add_subplot(122)
ax2.bar(x=cats[:5], height=ordered_probs[:5], color="dodgerblue");
ax2.set_ylabel("Probability");
ax2.set_xlabel("Predicted Genres");
ax2.set_title("Probability of Categories");

In [None]:
from lime import lime_text

classes = genre_predictor.multilabel_binarizer.classes_.tolist()
explainer = lime_text.LimeTextExplainer(class_names=classes, verbose=True)

def make_predictions(X_batch_text):
    X_batch = genre_predictor.tfidf_vectorizer.transform(X_batch_text)
    preds = genre_predictor.clf.predict(X_batch)  # Per-genre sigmoid outputs, shape (n_samples, n_classes)
    return preds    

In [None]:
idx = 0

explanation = explainer.explain_instance(plots[idx],
                                         classifier_fn=make_predictions,
                                         labels=[classes.index(p_genres[idx][0])], num_features=15)
explanation.show_in_notebook()

In [None]:
idx = 1

explanation = explainer.explain_instance(plots[idx],
                                         classifier_fn=make_predictions,
                                         labels=[classes.index(p_genres[idx][0])], num_features=15)
explanation.show_in_notebook()

In [None]:
idx = 2

explanation = explainer.explain_instance(plots[idx],
                                         classifier_fn=make_predictions,
                                         labels=[classes.index(p_genres[idx][0])], num_features=15)
explanation.show_in_notebook()

In [None]:
idx = 3

explanation = explainer.explain_instance(plots[idx],
                                         classifier_fn=make_predictions,
                                         labels=[classes.index(p_genres[idx][0])], num_features=15)
explanation.show_in_notebook()

In [None]:
idx = 4

explanation = explainer.explain_instance(plots[idx],
                                         classifier_fn=make_predictions,
                                         labels=[classes.index(p_genres[idx][0])], num_features=15)
explanation.show_in_notebook()

In [None]:
idx = 5

explanation = explainer.explain_instance(plots[idx],
                                         classifier_fn=make_predictions,
                                         labels=[classes.index(p_genres[idx][0])], num_features=15)
explanation.show_in_notebook()

In [None]:
final_prediction_df[final_prediction_df['imdb_id'] == 'tt0082517']

In [None]:
final_prediction_df.Predicted_Primary_Genre1.value_counts(normalize=True)