In [1]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import nltk
import sklearn

In [2]:
# Load the dialog table from a local pickle; later cells read its 'movie_id' and 'text' columns.
# NOTE(review): unpickling executes arbitrary code, so this file must come from a trusted source.
dialog_texts = pd.read_pickle('./data/dialog_texts')

In [33]:
# Column names for the tab-separated metadata file, which is read without a header row.
meta_cols = [
    'movie_id',
    'title',
    'year',
    'rating',
    'no. votes',
    'genres',
]
# Index by movie_id so that rows can be looked up per movie.
meta = pd.read_table(
    './datasets/movie-dialog-corpus/movie_titles_metadata.tsv',
    sep='\t',
    header=None,
    names=meta_cols,
    index_col='movie_id',
)
def str_to_list(str: str):
    """Parse a genre cell such as "['comedy' 'romance']" into a list of unique genres.

    Args:
        str: Raw cell text: quoted names separated by a single space and wrapped
            in square brackets. (The parameter name shadows the built-in and is
            kept only so existing callers continue to work.)

    Returns:
        The distinct genre names in order of first appearance. An empty list is
        returned for an empty cell ("[]" or "") and for any non-string value,
        e.g. the NaN pandas produces for a missing field.
    """
    raw = str
    # The built-in `str` is shadowed by the parameter, so recover the type from a literal.
    if not isinstance(raw, type('')):
        return []
    s = raw.strip('\'[]')
    if not s:
        return []
    # dict.fromkeys de-duplicates while keeping a deterministic (first-seen) order,
    # unlike a set, whose iteration order for strings varies between interpreter runs.
    return list(dict.fromkeys(s.split("\' \'")))

# Parse each raw genre string into a Python list, stored alongside the original column.
meta['genre_list'] = meta['genres'].apply(str_to_list)

In [25]:
def merge_texts(g):
    """Concatenate the `text` values of one group into a single space-separated string.

    Args:
        g: Object exposing a `text` attribute that iterates over strings
            (here, the per-movie sub-frame produced by a groupby).

    Returns:
        All text entries joined by a single space, in their existing order.
    """
    separator = ' '
    pieces = g.text
    return separator.join(pieces)

# Build one document per movie by joining all of that movie's dialog lines.
movie_texts = (
    dialog_texts.loc[:, ['movie_id', 'text']]
    .groupby('movie_id')
    .apply(merge_texts)
)

In [29]:
from sklearn.feature_extraction.text import TfidfVectorizer
# Learn a TF-IDF vocabulary capped at 2000 terms and encode each movie's text as one row of X.
vectorizer = TfidfVectorizer(
    max_features=2000,
)
X = vectorizer.fit_transform(raw_documents=movie_texts)

In [44]:
# Shape of the TF-IDF matrix: (number of documents, number of vocabulary terms).
# NOTE(review): the recorded output shows 1000 columns, but the vectorizer is built with
# max_features=2000. The saved output looks stale (execution counts are also out of order);
# restart the kernel and run all cells to confirm.
X.shape

(617, 1000)

In [None]:
# Inspect the vocabulary terms kept by the fitted vectorizer (one per column of X).
vectorizer.get_feature_names_out()

In [59]:
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split

# Count how many metadata rows list each genre, then keep the five most frequent genres.
genre_counts = meta['genre_list'].explode().value_counts()
top_genres = genre_counts.nlargest(n=5).index
print('top genres:', top_genres.tolist())


def get_genre_sets(X, labels, genres, test_size=None, random_state=1):
    """Build multilabel genre targets and split features and targets into train and test sets.

    Args:
        X: Feature matrix with one row per entry of `labels`.
        labels: Iterable of per-sample genre collections.
        genres: Collection of genre names to keep; every other genre is dropped
            from each sample's label set before binarisation.
        test_size: Forwarded to `train_test_split`; None keeps its default split.
        random_state: Seed forwarded to `train_test_split` (defaults to 1 so
            existing calls produce the same split as before).

    Returns:
        A list `[X_train, X_test, y_train, y_test, mlb]`, where `mlb` is the fitted
        `MultiLabelBinarizer` whose `classes_` name the columns of the y arrays.
        Samples that have none of the kept genres get an all-zero label row.
    """
    # Materialise the allowed genres once: constant-time lookups, and a pandas Series
    # is matched on its values (a bare `in` on a Series tests its index labels instead).
    allowed = set(genres)
    filtered_genres = [
        [g for g in sample_genres if g in allowed]
        for sample_genres in labels
    ]
    mlb = MultiLabelBinarizer()
    y = mlb.fit_transform(filtered_genres)
    splits = train_test_split(X, y, test_size=test_size, random_state=random_state)
    return splits + [mlb]


# Look up genre labels in the order of movie_texts.index, the same documents X was built
# from, so that feature rows and label rows stay aligned.
X_train, X_test, y_train, y_test, mlb = get_genre_sets(
    X=X,
    labels=meta.loc[movie_texts.index, 'genre_list'],
    genres=top_genres,
)


top genres: ['drama', 'thriller', 'comedy', 'action', 'crime']


In [61]:
from sklearn.ensemble import RandomForestClassifier 
# Fix the seed so the fitted forest, and every metric computed from it, is reproducible
# across reruns; the train/test split is already seeded the same way.
# NOTE: previously recorded outputs came from an unseeded forest, so rerun values may differ.
rf_clf = RandomForestClassifier(n_estimators=100, random_state=1)
rf_clf.fit(X_train, y_train)

In [62]:
# With multilabel targets, `score` is subset accuracy: a sample only counts as correct
# when its entire genre set is predicted exactly, which is why the value is low.
rf_clf.score(X_test, y_test)

0.18064516129032257

In [63]:
from sklearn.metrics import classification_report

def report_clf(clf, X, y, classes):
    """Return the text classification report for `clf` evaluated on (X, y).

    Args:
        clf: Fitted estimator exposing `predict`.
        X: Features to predict on.
        y: True labels to compare the predictions against.
        classes: Display names for the label columns, passed as `target_names`.

    Returns:
        The string produced by `classification_report`.
    """
    # zero_division=1 makes a metric with a zero denominator (e.g. precision for a class
    # that is never predicted) show as 1.0, so such entries should be read with care.
    report_options = {'target_names': classes, 'zero_division': 1}
    predictions = clf.predict(X)
    return classification_report(y, predictions, **report_options)

# Per-genre precision/recall/F1 of the random forest on the held-out split.
print(
    report_clf(
        clf=rf_clf,
        X=X_test,
        y=y_test,
        classes=mlb.classes_,
    )
)

              precision    recall  f1-score   support

      action       1.00      0.04      0.09        45
      comedy       1.00      0.00      0.00        45
       crime       1.00      0.03      0.05        36
       drama       0.62      0.71      0.66        82
    thriller       0.80      0.42      0.55        67

   micro avg       0.68      0.32      0.44       275
   macro avg       0.88      0.24      0.27       275
weighted avg       0.84      0.32      0.35       275
 samples avg       0.76      0.39      0.40       275

