In [1]:
import os, sys, time
import multiprocessing
import pickle
import re, string
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

## Preprocess

In [2]:
data = pd.read_csv("mbti_1.csv")
n_users = len(data)
posts = data["posts"]
labels = data["type"].unique()
n_class = len(labels)
type2num = {label: i for i,label in enumerate(labels)}
Y = np.array(list(map(lambda s: type2num[s], data["type"].to_numpy())))

In [3]:
def plot_distribution():
    fig, ax = plt.subplots(figsize=(10,4))
    type_val = data["type"].value_counts()
    labels = type_val.keys()
    x = np.arange(len(labels))
    ax.bar(x, type_val.values)
    ax.set_ylabel("# of people")
    ax.set_xticks(x)
    ax.set_xticklabels(labels,rotation='45')
    ax.set_axisbelow(True)
    ax.yaxis.grid(color='gray', linestyle='dashed')
    fig.tight_layout()
    plt.show()

In [4]:
def generate_posts(path=""):
    filename = os.path.join(path,"posts.pkl")
    user_posts = []
    if not os.path.isfile(filename):
        stopwords = pd.read_csv("stopwords.csv").to_numpy().reshape(-1)
        stopwords = np.array(list(map(lambda s: s.replace("'",""),stopwords)))
        for uid in range(n_users):
            # add empty space first (better used for regex parsing)
            new_post = posts[uid].replace("|||"," ||| ")
            new_post = new_post.replace(",",", ")
            # remove url links
            new_post = re.sub("(http|https):\/\/.*?( |'|\")","",new_post)
            # avoid words in two sentences merged together after removing spaces
            new_post = new_post.replace(".",". ")
            # remove useless numbers and punctuations
            new_post = re.sub(r"[0-9]+", "", new_post)
            new_post = new_post.translate(str.maketrans('', '', string.punctuation))
            # remove redundant empty spaces
            new_post = re.sub(" +"," ",new_post).strip()
            # make all characters lower
            new_post = new_post.lower()
            new_post
            temp = []
            # remove stopping words
            for word in new_post.split():
                if len(word) != 1 and word not in stopwords:
                    temp.append(word)
            user_posts.append(temp)
            if uid * 100 % n_users == 0:
                print("Done {}/{} = {}%".format(uid,n_users,uid*100/n_users))
        print("Finished generating word list")
        pickle.dump(user_posts,open(filename,"wb"))
    else:
        user_posts = pickle.load(open(filename,"rb"))
        print("Loaded user posts")
    return user_posts

## Generate BoW model

In [5]:
def generate_dict(user_posts,path=""):
    filename = os.path.join(path,"word_dict.npz")
    if not os.path.isfile(filename):
        word_lst = []
        for post in user_posts:
            word_lst += post

        # make dictionary (used for bag of words, BOW)
        word_counts = Counter(word_lst)
        word_counts["<UNK>"] = max(word_counts.values()) + 1
        # remove words that don't occur too frequently
        print("# of words before:",len(word_counts))
        for word in list(word_counts): # avoid changing size
            if word_counts[word] < 6:
                del word_counts[word]
        print("# of words after:",len(word_counts))
        # sort based on counts, but only remain the word strings
        sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)

        # make embedding based on the occurance frequency of the words
        int_to_word = {k: w for k, w in enumerate(sorted_vocab)}
        word_to_int = {w: k for k, w in int_to_word.items()}
        np.savez(filename,int2word=int_to_word,word2int=word_to_int)
    else:
        infile = np.load(filename,allow_pickle=True)
        int_to_word = infile["int2word"].item()
        word_to_int = infile["word2int"].item()
        print("Loaded {}".format(filename))
    n_words = len(int_to_word)
    print('Vocabulary size:', n_words)
    return word_to_int, int_to_word

In [6]:
def generate_bow(user_posts,word_to_int):
    filename = "bow.npy"
    if not os.path.isfile(filename):
        n_users = len(user_posts)
        n_words = len(word_to_int)
        feature = np.zeros((n_users,n_words))
        print(feature.shape)
        for uid, post in enumerate(user_posts):
            count = Counter(post)
            for key in count:
                feature[uid][word_to_int.get(key,0)] = count[key]
            if uid * 100 % n_users == 0:
                print("Done {}/{} = {}%".format(uid,n_users,uid*100/n_users))
        print("Finished generating BoW model")
        np.save(filename,feature)
        print("Saved {}".format(filename))
    else:
        feature = np.load(filename)
        print("Loaded BoW model")
    return feature

In [7]:
user_posts = generate_posts()
word2int, int2word = generate_dict(user_posts)
X = generate_bow(user_posts,word2int)

Loaded user posts
Loaded word_dict.npz
Vocabulary size: 27129
(8675, 27129)
Done 0/8675 = 0.0%
Done 347/8675 = 4.0%
Done 694/8675 = 8.0%
Done 1041/8675 = 12.0%
Done 1388/8675 = 16.0%
Done 1735/8675 = 20.0%
Done 2082/8675 = 24.0%
Done 2429/8675 = 28.0%
Done 2776/8675 = 32.0%
Done 3123/8675 = 36.0%
Done 3470/8675 = 40.0%
Done 3817/8675 = 44.0%
Done 4164/8675 = 48.0%
Done 4511/8675 = 52.0%
Done 4858/8675 = 56.0%
Done 5205/8675 = 60.0%
Done 5552/8675 = 64.0%
Done 5899/8675 = 68.0%
Done 6246/8675 = 72.0%
Done 6593/8675 = 76.0%
Done 6940/8675 = 80.0%
Done 7287/8675 = 84.0%
Done 7634/8675 = 88.0%
Done 7981/8675 = 92.0%
Done 8328/8675 = 96.0%
Finished generating BoW model
Saved bow.npy


## Random Forest model

In [8]:
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn.metrics import classification_report
from sklearn import svm

In [9]:
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3)
unique_X_val = []
for attr in range(X.shape[1]): # disadvantage
    unique_X_val.append(np.unique(X[:,attr]))

In [10]:
clf = svm.SVC(verbose=True)
clf.fit(X_train, y_train)
predict = clf.score(X_test, y_test)
print("Support Vector Machine (SVM) acc: {:.2f}%".format(predict * 100))

[LibSVM]Support Vector Machine (SVM) acc: 65.39%


In [11]:
pickle.dump(clf,open("svm.pkl","wb"))

In [12]:
Y_pred = clf.predict(X_test)
print(classification_report(Y_pred,Y_test,target_names=labels))

NameError: name 'Y_test' is not defined

In [13]:
print(classification_report(Y_pred,y_test,target_names=labels))

              precision    recall  f1-score   support

        INFJ       0.69      0.65      0.67       438
        ENTP       0.57      0.68      0.62       179
        INTP       0.82      0.59      0.69       561
        INTJ       0.69      0.67      0.68       326
        ENTJ       0.28      0.80      0.42        25
        ENFJ       0.22      0.61      0.32        18
        INFP       0.84      0.64      0.73       720
        ENFP       0.60      0.79      0.68       173
        ISFP       0.22      0.68      0.34        28
        ISTP       0.54      0.75      0.63        76
        ISFJ       0.42      0.85      0.56        26
        ISTJ       0.28      0.52      0.36        27
        ESTP       0.21      0.83      0.33         6
        ESFP       0.00      0.00      0.00         0
        ESTJ       0.00      0.00      0.00         0
        ESFJ       0.00      0.00      0.00         0

    accuracy                           0.65      2603
   macro avg       0.40   

  _warn_prf(average, modifier, msg_start, len(result))
