In [1]:
import ast
import os
from argparse import Namespace

from gensim.models.fasttext import FastText
from gensim.models.callbacks import CallbackAny2Vec
import matplotlib.pyplot as plt

In [2]:
# Single source of truth for where model artefacts live; model_path below is
# derived from it so the directory and the file path cannot drift apart.
MODEL_DIR = './models'

args = Namespace(

    ## Input params
    # NOTE: the misspelling "traning" is kept on purpose because other cells
    # read args.traning_corpus_filepath.
    traning_corpus_filepath='./data/clean_data/word2vec_training_corpus.txt',

    ## Model params
    epochs=5,        # passed as epochs= to model.train()
    embed_dim=100,   # passed as size= to FastText(...)
    min_count=40,    # passed as min_count= to FastText(...)
    window=40,       # passed as window= to FastText(...)

    ## Output params
    model_dir=MODEL_DIR,
    model_path=os.path.join(MODEL_DIR, 'fastText_model.model'),
)

# Make sure the output directory exists before the model is saved into it.
os.makedirs(args.model_dir, exist_ok=True)

In [3]:
class ReadSentence:
    """Restartable, streaming iterator over a pre-tokenised training corpus.

    Each non-blank line of ``filepath`` is parsed as a Python literal and
    yielded as-is (presumably a list of token strings such as
    ``['this', 'movie', ...]`` -- TODO confirm against the corpus writer).

    Every call to ``__iter__`` re-opens the file and streams it from the
    start, so the same instance can be consumed more than once (this notebook
    passes it to both ``build_vocab`` and ``train``).

    Args:
        filepath: path of the corpus text file to read.
    """

    def __init__(self, filepath):
        self.filepath = filepath

    def __iter__(self):
        # Read from this instance's own path rather than a global config
        # object, so the reader works for any corpus file it is given.
        with open(self.filepath) as f:
            # Iterate the file object lazily instead of calling readlines(),
            # so the whole corpus is never held in memory at once.
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines / a trailing empty line
                # literal_eval parses list/str literals without executing
                # arbitrary code; plain eval() would run whatever is in the file.
                yield ast.literal_eval(line)

In [4]:
# NOTE(review): disabled experiment. This loss-tracking callback is entirely
# commented out, and the training cell calls model.train() without any
# callbacks= argument, so nothing here runs. TODO: either delete this cell or
# restore it and pass an instance via callbacks=[...]; also confirm that the
# installed gensim version populates running_training_loss for FastText
# before relying on it.
# class EpochSaver(CallbackAny2Vec):
    
#     def __init__(self):
#         self.batch_loss = []
#         self.epoch_loss = []
#         self.prev_loss = 0
        
#     def on_batch_end(self,model):
#         loss = model.running_training_loss
#         self.batch_loss.append(loss)
        
#     def on_epoch_end(self, model):
#         loss = model.running_training_loss
#         self.epoch_loss.append(loss)

In [5]:
## Restartable corpus reader, consumed by both build_vocab() and train() below.
sentence_reader = ReadSentence(filepath=args.traning_corpus_filepath)

In [6]:
## Instantiate the model from the config, then scan the corpus once to build its vocabulary.
# NOTE(review): `size` is the gensim<4 keyword (gensim>=4 calls it `vector_size`).
fasttext_hparams = dict(
    size=args.embed_dim,
    window=args.window,
    min_count=args.min_count,
)
model = FastText(**fasttext_hparams)
model.build_vocab(sentences=sentence_reader)

# Sentence count recorded by build_vocab(); handed to train() as total_examples.
total_examples = model.corpus_count

In [7]:
## Train for the configured number of epochs, streaming the corpus again.
train_kwargs = {
    'sentences': sentence_reader,
    'total_examples': total_examples,
    'epochs': args.epochs,
}
model.train(**train_kwargs)

In [8]:
## Persist the trained model to the configured path so it can be reloaded later.
output_path = args.model_path
model.save(output_path)

## Evaluation

In [9]:
## Load pretrained model 
# Reload the persisted model from disk, replacing the in-memory one.
saved_model_path = args.model_path
model = FastText.load(saved_model_path)

In [10]:
# Ten nearest neighbours of 'awful' in the embedding space.
model.wv.most_similar(positive=['awful'], topn=10)

[('terrible', 0.8721861839294434),
 ('horrible', 0.8077470064163208),
 ('awfulness', 0.7925478219985962),
 ('terribly', 0.7722146511077881),
 ('awfully', 0.7679882049560547),
 ('horribly', 0.757067859172821),
 ('dreadful', 0.7355570793151855),
 ('atrocious', 0.6969504356384277),
 ('lousy', 0.6741313934326172),
 ('amateurish', 0.6654798984527588)]

In [11]:
# Ten nearest neighbours of 'bad'; in the recorded run several hits share
# characters rather than meaning (e.g. 'badge', 'sinbad').
model.wv.most_similar(positive=['bad'], topn=10)

[('badge', 0.8372039794921875),
 ('terrible', 0.66477370262146),
 ('sinbad', 0.6606861352920532),
 ('horrible', 0.6519021391868591),
 ('baddie', 0.6486789584159851),
 ('baddies', 0.6227812767028809),
 ('awful', 0.6187651753425598),
 ('badly', 0.617999792098999),
 ('crappy', 0.5904452800750732),
 ('bag', 0.5749733448028564)]

In [12]:
# Ten nearest neighbours of 'good'; in the recorded run most hits share the
# prefix "goo"/"good" (e.g. 'goodbye', 'google') rather than the sentiment.
model.wv.most_similar(positive=['good'], topn=10)

[('goodbye', 0.8806805610656738),
 ('goods', 0.7243682146072388),
 ('gooding', 0.6606674194335938),
 ('goodman', 0.6574000716209412),
 ('goofs', 0.6515835523605347),
 ('google', 0.6307097673416138),
 ('decent', 0.6285086870193481),
 ('goof', 0.6167135834693909),
 ('goodfellas', 0.6052292585372925),
 ('ok', 0.5743190050125122)]