In [None]:
import pandas as pd
import nltk
nltk.download('cmudict')
from nltk.corpus import cmudict
from profanity_check import predict, predict_prob

In [None]:
def num_syllables_word(word, pron_dict=None, not_found=None):
    """Return the number of syllables in `word` according to a CMUdict-style dictionary.

    Syllables are counted on the word's first listed pronunciation: every phoneme
    ending in a digit is counted as one syllable.
    Adapted from:
    https://datascience.stackexchange.com/questions/23376/how-to-get-the-number-of-syllables-in-a-word

    Parameters
    ----------
    word : str
        The word to look up; it is lower-cased before lookup.
    pron_dict : dict, optional
        Mapping of lower-case word -> list of pronunciations (each a list of
        phoneme strings). Defaults to the module-level ``cmu_dict``, resolved at
        call time.
    not_found : list, optional
        List that receives ``word`` when it cannot be looked up. Defaults to the
        module-level ``words_not_in_cmu``, resolved at call time.

    Returns
    -------
    int
        The syllable count, or -1 if ``word`` has no pronunciation in ``pron_dict``.
    """
    if pron_dict is None:
        pron_dict = cmu_dict
    if not_found is None:
        not_found = words_not_in_cmu

    pronunciations = pron_dict.get(word.lower())
    # Missing key and empty pronunciation list both mean "unknown word".
    if not pronunciations:
        not_found.append(word)
        return -1
    return sum(1 for phoneme in pronunciations[0] if phoneme[-1].isdigit())

def has_correct_syllables(haiku, pattern=(5, 7, 5), count_syllables=None):
    """Return True if each line of `haiku` has exactly the expected syllable count.

    Parameters
    ----------
    haiku : iterable of str
        The haiku's lines in order, e.g. one row of the '0', '1', '2' columns.
    pattern : sequence of int, default (5, 7, 5)
        Required syllable count per line. A haiku whose number of lines differs
        from ``len(pattern)`` is rejected.
    count_syllables : callable, optional
        Maps a word to its syllable count, or -1 when the count is unknown.
        Defaults to ``num_syllables_word``, resolved at call time.

    Returns
    -------
    bool
        False if the line count is wrong, a line is not a string (e.g. NaN from
        an empty CSV cell), any word's syllable count is unknown, or any line's
        total differs from ``pattern``; True otherwise.
    """
    if count_syllables is None:
        count_syllables = num_syllables_word

    lines = list(haiku)
    if len(lines) != len(pattern):
        return False

    for line, expected in zip(lines, pattern):
        # Missing cells arrive as float NaN rather than str.
        if not isinstance(line, str):
            return False
        total = 0
        for word in line.split():
            n_syllables = count_syllables(word)
            if n_syllables == -1:
                return False
            total += n_syllables
        if total != expected:
            return False
    return True

In [None]:
# --- Configuration -----------------------------------------------------------
filepath = '../data/'
dataset = 'kaggle-jhalini-all_haiku.csv'
line_columns = ['0', '1', '2']   # one CSV column per haiku line
max_rows = None                  # set to an int (e.g. 1000) to speed up testing; None reads every row
profanity_threshold = 0.75       # drop a haiku if any line's profanity probability exceeds this

cmu_dict = cmudict.dict()        # read by num_syllables_word
words_not_in_cmu = []            # num_syllables_word appends words it could not look up

# --- Load --------------------------------------------------------------------
# nrows limits the read itself instead of loading everything and truncating afterwards.
raw_df = pd.read_csv(filepath + dataset, nrows=max_rows)
df = raw_df.drop(columns=raw_df.columns[0])  # drop the first (unnamed) column
nrows = len(df)

# --- Clean -------------------------------------------------------------------
# Remove misc characters (still incomplete: inspect words_not_in_cmu for leftovers).
# Raw string so the regex receives '\.' without an invalid-escape warning from Python.
df = df.replace(r'-{2,}|-$|- |~|"|\.|;|^ +| +$|\'$', '', regex=True)

# Make haikus lowercase
for col in line_columns:
    df[col] = df[col].str.lower()

# --- Filter: 5-7-5 syllable structure ------------------------------------------
all_haikus = df[line_columns].to_numpy()
# A bool Series (not a plain list) so an empty result still filters rows, not columns.
syllable_ok = pd.Series([has_correct_syllables(haiku) for haiku in all_haikus],
                        index=df.index, dtype=bool)
df = df[syllable_ok]

# --- Filter: profane language ---------------------------------------------------
all_haikus = df[line_columns].to_numpy()
if len(all_haikus):
    # One batched model call over every line instead of one call per haiku;
    # reshape restores one row of line probabilities per haiku.
    line_probs = predict_prob(all_haikus.ravel()).reshape(all_haikus.shape)
    profane_ids = pd.Series((line_probs > profanity_threshold).any(axis=1), index=df.index)
else:
    profane_ids = pd.Series(dtype=bool, index=df.index)
profane_haikus = df[profane_ids]
df = df[~profane_ids]

# --- Report & save ------------------------------------------------------------
kept_pct = round(df.shape[0] / nrows * 100, 2) if nrows else 0.0
print('Total haikus: {}'.format(nrows))
print('Filtered Haikus: {} ({}%)\n'.format(df.shape[0], kept_pct))
df.to_csv(filepath+'filtered_haikus.csv')

In [None]:
# Write the haikus removed by the profanity filter to their own CSV next to the other outputs.
profane_haikus.to_csv(path_or_buf=filepath + 'profane_haikus.csv')

In [None]:
# Words num_syllables_word could not look up, de-duplicated and most frequent first;
# useful for spotting characters the cleaning regex still misses.
pd.Series(words_not_in_cmu, dtype=object).value_counts()