In [1]:
## Here is a useful tutorial for how to use Doc2Vec:
## https://medium.com/@mishra.thedeepak/doc2vec-simple-implementation-example-df2afbbfbad5

## Import all the dependencies
import pandas as pd
import numpy as np
from sklearn.linear_model import LogisticRegression
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from nltk.tokenize import word_tokenize
from skmultilearn.problem_transform import BinaryRelevance
from skmultilearn.problem_transform import LabelPowerset
from sklearn.neural_network import MLPClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC

In [2]:
## Read the three input CSVs: training set, test set and sample submission
train, test, subm = map(
    pd.read_csv,
    ('input/train.csv', 'input/test.csv', 'input/sample_submission.csv'),
)

## Columns of `train` that are used as the prediction targets
label_cols = [
    'toxic',
    'severe_toxic',
    'obscene',
    'threat',
    'insult',
    'identity_hate',
]

In [3]:
## Settings that are embedded in the saved model's file name
max_epochs = 100
vec_size = 50

## Assemble the path of the saved Doc2Vec model and load it from disk
model_path = "".join(["models/d2v.model_maxepochs_", str(max_epochs), "_vecsize_", str(vec_size)])
model = Doc2Vec.load(model_path)

In [4]:
## Number of rows in each split
ntrain = len(train['id'])
ntest = len(test['id'])


def _docvec_matrix(offset, nrows):
    """Copy `nrows` consecutive document vectors, starting at position
    `offset` in `model.docvecs`, into a new (nrows, vec_size) array."""
    features = np.zeros((nrows, vec_size))
    for k in range(nrows):
        features[k, :] = model.docvecs[offset + k]
    return features


## Positions 0 .. ntrain-1 are read as the training rows and the following
## ntest positions as the test rows.
## NOTE(review): assumes the model was trained on the train documents followed
## by the test documents, in file order -- TODO confirm against the training notebook.
X_train = _docvec_matrix(0, ntrain)
X_test = _docvec_matrix(ntrain, ntest)

## One target column per entry of label_cols
y_train = train[label_cols].values

In [5]:
## Fixed seed for the MLP so that the fitted weights, and therefore the
## predicted probabilities, are identical on every Restart & Run All.
## Without it each run of this cell produces a different submission.
random_seed = 42

## Binary Relevance Method (the base classifier is applied to each label column)
#classifier = BinaryRelevance(classifier = SVC(probability=True), require_dense = [False, True])
classifier = BinaryRelevance(
    classifier = MLPClassifier(hidden_layer_sizes=(5,), max_iter=1000, random_state=random_seed),
    require_dense = [False, True],
)
#classifier = BinaryRelevance(classifier = GaussianNB())

## Label Powerset Method
#classifier = LabelPowerset(GaussianNB())

## train on the document vectors of the training split
classifier.fit(X_train, y_train)

## predict per-label probabilities for the test split; the result is
## converted to a dense array so it can be placed in a DataFrame later
predictions = classifier.predict_proba(X_test).toarray()

In [6]:
## Single-column frame holding the ids from the sample submission
submid = subm[["id"]]

## Predicted probabilities, one column per label, placed next to the ids
proba_frame = pd.DataFrame(predictions, columns = label_cols)
submission = pd.concat([submid, proba_frame], axis=1)

## The output file name records which Doc2Vec model produced the features
filename = "".join(["submissions/submission_maxepochs_", str(max_epochs), "_vecsize_", str(vec_size), '.csv'])
submission.to_csv(filename, index=False)

In [7]:
## Preview only the first rows instead of rendering the whole frame
submission.head(10)

Unnamed: 0,id,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,00001cee341fdb12,0.964078,7.490767e-02,8.544754e-01,1.182386e-01,6.647984e-01,3.373291e-01
1,0000247867823ef7,0.003191,6.680416e-06,7.742312e-04,4.662123e-06,5.882194e-04,1.312609e-04
2,00013b17ad220c46,0.001215,2.527521e-06,1.635789e-04,1.035251e-05,1.043496e-04,4.596804e-05
3,00017563c3f7919a,0.000443,4.707854e-07,3.495501e-05,2.622167e-05,8.364274e-05,1.096681e-07
4,00017695ad8997eb,0.039892,5.332047e-04,7.962707e-03,1.681830e-04,8.392733e-03,5.372617e-04
5,0001ea8717f6de06,0.001589,7.798498e-07,1.729872e-04,8.644654e-06,2.062112e-04,1.875656e-05
6,00024115d4cbde0f,0.004437,7.067669e-06,1.468856e-04,1.802703e-05,7.324488e-04,4.338912e-06
7,000247e83dcc1211,0.496792,2.785718e-02,3.396626e-01,9.771156e-04,1.963225e-01,1.111161e-02
8,00025358d4737918,0.196460,8.624355e-06,1.947773e-02,4.814422e-03,5.259188e-02,3.026364e-05
9,00026d1092fe71cc,0.022408,3.737656e-05,2.104809e-03,1.850600e-04,2.247622e-02,2.532283e-04
