In [3]:
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain
from sklearn.preprocessing import MultiLabelBinarizer
from xgboost import XGBClassifier
from data_preprocessing.data_preprocessor import DataPreprocessor
from modeling.model_handler import ModelHandler


In [4]:
# Build the project's DataPreprocessor around the processed CSV. Later cells read
# its `.data['plot_summary']` and `.data['genres']` columns.
# NOTE(review): presumably the constructor reads the file eagerly — confirm in
# data_preprocessing/data_preprocessor.py.
data_preprocessor = DataPreprocessor('data/processed_data.csv')

In [5]:
# Restrict the dataset by genre with k=5.
# NOTE(review): presumably this keeps only the 5 most frequent genres and updates
# `data_preprocessor.data` in place (re-running the cell would then filter already
# filtered data) — confirm in DataPreprocessor.filter_genres.
# As the last expression of the cell, whatever the call returns is rendered as the
# cell output.
data_preprocessor.filter_genres(k=5)

Unnamed: 0,plot_summary,genres
0,Murugappa is a small time farm labourer who l...,[Drama]
1,A hyper-vigilant agent of the Department of Pu...,"[Thriller, Drama]"
2,"Four friends- Gangu , Abdul , Nihal and Gary ...","[Action, Drama]"
3,A married man is having an affair with another...,[Drama]
4,"The movie concerns the life of Tomasina ""Tommy...","[Romance Film, Comedy]"
...,...,...
29990,Jimmy Boland has been sentenced to a Californ...,[Action]
29991,Englishman Ronald Quayle was accused of murder...,[Drama]
29995,Managing Editor Sam Gatlin arrives in the afte...,[Drama]
29997,The film is about a woman named Jennefer who ...,"[Thriller, Drama]"


In [6]:
# Encode each movie's genre list as one row of a binary indicator matrix
# (one column per distinct genre). `mlb` is kept around because later cells
# hand it to the evaluation, saving and inference helpers.
mlb = MultiLabelBinarizer()
mlb_labels = mlb.fit_transform(data_preprocessor.data['genres'])

# Hold out 20% of the plot summaries (with their label rows) for testing;
# the fixed seed makes the split reproducible across runs.
x_train, x_test, y_train, y_test = train_test_split(
    data_preprocessor.data['plot_summary'],
    mlb_labels,
    random_state=42,
    test_size=0.2,
)

In [7]:
# Linear model fitted by stochastic gradient descent; loss='log_loss' makes it a
# logistic-regression-style classifier that can output probabilities.
base_classifier = SGDClassifier(
    random_state=42,
    loss='log_loss',
    tol=1e-3,
    max_iter=1000,
)

# One-vs-Rest: an independent copy of the base model is fitted for every
# genre column of the multi-label target.
classifier = OneVsRestClassifier(estimator=base_classifier)

In [8]:
# Fit the current `classifier` (the One-vs-Rest SGD model when cells run in order)
# on the training split through the project's ModelHandler.
# NOTE(review): x_train is raw plot text, so train_model_skl presumably wraps the
# classifier in a text-vectorising pipeline and returns it fitted — confirm in
# modeling/model_handler.py.
trained_clf_pipeline = ModelHandler.train_model_skl(classifier, x_train, y_train)

In [9]:
# Evaluate the fitted pipeline on the held-out split.
# NOTE(review): `mlb` is passed along presumably so the reports can be labelled
# with genre names instead of column indices — confirm in evaluate_model_skl.
reports = ModelHandler.evaluate_model_skl(trained_clf_pipeline, mlb, x_test, y_test)

In [10]:
# Hand the pipeline, the label binarizer and the evaluation reports to
# ModelHandler for saving under the name 'SGD_One_vs_Rest'.
# NOTE(review): storage location and file format are decided inside
# save_model_skl and are not visible here.
ModelHandler.save_model_skl(trained_clf_pipeline, mlb, reports, 'SGD_One_vs_Rest')

In [11]:
# Load the artefacts stored under 'SGD_One_vs_Rest'. The call yields a pair that is
# unpacked into the pipeline and its label binarizer, so inference below runs on
# the persisted model rather than on the in-memory one.
pre_trained_clf_pipeline, pre_trained_mlb = ModelHandler.load_model_skl('SGD_One_vs_Rest')

In [12]:
# Run the reloaded SGD model on the first plot summary of the test split.
# The displayed result is a list of (genre, score) pairs.
# NOTE(review): `threshold=0.2` presumably drops genres whose score is below 0.2 —
# confirm in ModelHandler.inference_model_skl.
results = ModelHandler.inference_model_skl(pre_trained_clf_pipeline, pre_trained_mlb, x_test.iloc[0], threshold=0.2)
results

[('Drama', 0.8108802845929739), ('Romance Film', 0.3871883515917228)]

In [13]:
# Gradient-boosted trees as the per-label base model, monitored with log loss.
base_xgb = XGBClassifier(eval_metric='logloss')

# Classifier chain: one model per genre, each one also fed the predictions for
# the genres earlier in the chain. The chain order is a random permutation,
# made reproducible by the fixed seed.
# Note: this rebinds `classifier`, replacing the One-vs-Rest SGD model from above.
classifier = ClassifierChain(
    base_xgb,
    random_state=42,
    order='random',
)

In [14]:
# Fit the current `classifier` (the XGBoost classifier chain when cells run in
# order) on the same training split. This rebinds `trained_clf_pipeline`, so the
# SGD pipeline is only reachable through its saved copy from here on.
# NOTE(review): return value is presumably a fitted text pipeline — confirm in
# ModelHandler.train_model_skl.
trained_clf_pipeline = ModelHandler.train_model_skl(classifier, x_train, y_train)

In [15]:
# Evaluate the newly fitted pipeline on the held-out split; overwrites the
# `reports` produced for the previous model.
# NOTE(review): the structure of `reports` is defined by evaluate_model_skl and is
# not visible here.
reports = ModelHandler.evaluate_model_skl(trained_clf_pipeline, mlb, x_test, y_test)

In [16]:
# Hand the pipeline, the label binarizer and the evaluation reports to
# ModelHandler for saving under the name 'XGB_Chain'.
# NOTE(review): storage location and file format are decided inside
# save_model_skl and are not visible here.
ModelHandler.save_model_skl(trained_clf_pipeline, mlb, reports, 'XGB_Chain')

In [17]:
# Load the artefacts stored under 'XGB_Chain'. The returned pair is unpacked into
# the pipeline and its label binarizer, replacing the previously loaded objects
# bound to these names.
pre_trained_clf_pipeline, pre_trained_mlb = ModelHandler.load_model_skl('XGB_Chain')

In [18]:
# Run the reloaded XGBoost chain on the same first test-set plot summary, with the
# same threshold as the SGD run, so the two outputs can be compared directly.
# The displayed result is a list of (genre, score) pairs.
# NOTE(review): `threshold=0.2` presumably drops genres whose score is below 0.2 —
# confirm in ModelHandler.inference_model_skl.
results = ModelHandler.inference_model_skl(pre_trained_clf_pipeline, pre_trained_mlb, x_test.iloc[0], threshold=0.2)
results

[('Action', 0.3328200578689575),
 ('Drama', 0.8717947006225586),
 ('Romance Film', 0.5780348181724548)]