# Myers-Briggs Personality Type Prediction


# 1 - Packages #

Let's first import all the packages that you will need.

In [1]:
# Package imports
import numpy as np
import tensorflow as tf
import re
from nltk.stem.snowball import SnowballStemmer
import nltk
# nltk.download()

# 2 - Dataset #

Let's get the dataset we will work on.

## 2.1 - Preprocessing helper functions ##

In [8]:
import string
from nltk.corpus import words

# English lexicon: lookup table keyed by every entry of the NLTK `words` corpus
lexicon = dict.fromkeys(words.words(), True)

# Data store for the [post, personality_type] examples
data = []

# Vocabulary: stemmed root word -> integer id, and the next id to hand out
words_dict = {}
word_val = 0

# Personality types: type label -> integer id, and the next id to hand out
personality_type_dict = {}
personality_type_val = 0

# Snowball stemmer
stemmer = SnowballStemmer("english")


# Removes url, punctuation, and digits
def post_clean_up(post):
    """Return `post` stripped of URLs, punctuation, digits and outer whitespace.

    The steps run in this order: URL-like spans (``scheme://host/path``) are
    deleted, then every character in ``string.punctuation`` is dropped, then
    runs of digits are deleted, and finally leading/trailing whitespace is
    trimmed. Interior whitespace is left untouched.
    """
    url_pattern = r'\w+:\/{2}[\d\w-]+(\.[\d\w-]+)*(?:(?:\/[^\s/]*))*'
    punctuation_table = str.maketrans('', '', string.punctuation)

    cleaned = re.sub(url_pattern, '', post)
    cleaned = cleaned.translate(punctuation_table)
    cleaned = re.sub(r'\d+', '', cleaned)
    return cleaned.strip()

# Applies snow ball stemmer and inserts root word to words_dict
def apply_snow_ball_stemmer(post):
    """Stem the lexicon words of `post` and return them space-joined.

    Whitespace-separated tokens that are not keys of the module-level
    `lexicon` are dropped. Each remaining token is passed through the
    module-level `stemmer`; a stem not yet present in `words_dict` is
    registered there under the current `word_val`, which is then incremented.
    The result is the stems in their original order, separated by single
    spaces and trimmed of outer whitespace.
    """
    global word_val
    stems = []
    for token in post.split():
        if token in lexicon:
            stem = stemmer.stem(token)

            # First sighting of this stem: give it the next free id
            if stem not in words_dict:
                words_dict[stem] = word_val
                word_val += 1

            stems.append(stem)

    return " ".join(stems).strip()

## 2.2 - Preprocessing the MBTI dataset ##

In [9]:


# Open mbti file
resource_location = "../data/mbti_1.csv"
# The context manager closes the handle as soon as the rows are read;
# the first line of the file is skipped.
with open(resource_location, 'r') as file:
    lines = file.readlines()[1:]

# Running total of the lengths of the kept posts; divided by the number of
# examples below to obtain the mean.
mean_length = 0

# Collect examples in a fresh list so that re-running this cell does not try
# to append to the ndarray that `data` is rebound to at the end.
examples = []

for line in lines:

    # Everything before the first comma is the personality type; the rest of
    # the line holds the posts, separated by "|||".
    personality_type, _, posts = line.partition(",")

    # Give each newly seen personality type the next free integer id
    if personality_type not in personality_type_dict:
        personality_type_dict[personality_type] = personality_type_val
        personality_type_val += 1

    for post in posts.split("|||"):
        # Removing URLs, punctuation, and digits
        post = post_clean_up(post)

        # Skip posts of 10 characters or fewer
        if len(post) <= 10:
            continue

        # Apply Snowball stemmer
        post = apply_snow_ball_stemmer(post)

        # Skip posts that shrank to 10 characters or fewer after stemming
        if len(post) <= 10:
            continue

        mean_length += len(post)
        examples.append([post, personality_type])

data = np.asarray(examples)

# Guard the average against an input that yields no examples
example_count = data.shape[0]
print(data.shape, mean_length / example_count if example_count else 0.0)

(381453, 2) 106.30440447446999


In [13]:
import json
# Output locations for the processed dataset and the two lookup dictionaries
processed_mbti_dict_resource_location = '../data/processed_mbti.csv'
words_dict_resource_location = '../data/words_dict.json'
personality_type_dict_resource_location = '../data/personality_dict.json'


def save_csv(data, resource_location):
    """Write the two-column array `data` to `resource_location` as text.

    Each row becomes one line of the form ``<first>, <second>`` (the two
    fields are joined by a comma followed by a space).
    """
    row_format = '%s, %s'
    np.savetxt(resource_location, data, fmt=row_format)


def save_json(data, resource_location):
    """Serialize `data` as JSON into the file at `resource_location`.

    The file is opened in write mode, so any existing content is replaced.
    """
    # The `with` block closes the file on exit; no explicit close() is needed.
    with open(resource_location, 'w') as fp:
        json.dump(data, fp)

# Save processed mbti data
save_csv(data=data, resource_location=processed_mbti_dict_resource_location)

# Save words_dict
save_json(data=words_dict, resource_location=words_dict_resource_location)

# Save personality_type_dict
save_json(
    data=personality_type_dict,
    resource_location=personality_type_dict_resource_location,
)

