In [149]:
import numpy as np
import random
import string
import math
from collections import defaultdict
import json

In [150]:
def gen_word(word_length):
    """Return a random lowercase pseudo-word made of distinct letters.

    Args:
        word_length: arguments forwarded to ``np.random.randint`` — for a
            ``(low, high)`` pair the length is drawn from ``low .. high - 1``.

    Returns:
        str of the drawn length; letters are sampled without replacement
        from ``string.ascii_lowercase``, so no letter repeats.
    """
    # Length comes from numpy's RNG, the letters from the stdlib RNG.
    n_letters = np.random.randint(*word_length)
    letters = random.sample(string.ascii_lowercase, n_letters)
    return ''.join(letters)

In [151]:
# Seed both RNGs so the generated vocabulary and data files are reproducible:
# gen_word draws word lengths from numpy and letters from `random`, and the
# sentence generators further down use `random` as well.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

topic_size = 6   # words per topic
num_topics = 20  # number of topics

# topics[t] is the list of pseudo-words belonging to topic t.
# (2, 6) is forwarded to np.random.randint, so words are 2-5 letters long.
topics = [
    [gen_word((2, 6)) for _ in range(topic_size)]
    for _ in range(num_topics)
]

In [152]:
# Inspect the generated vocabulary: one inner list of words per topic.
topics

[['ikm', 'bxorj', 'kz', 'pw', 'vryts', 'qza'],
 ['pckid', 'xym', 'lhipy', 'ot', 'hz', 'rdv'],
 ['ip', 'zegrw', 'abs', 'oe', 'cyvb', 'ick'],
 ['yshl', 'jc', 'pdbl', 'tuh', 'ibfp', 'bz'],
 ['ysior', 'de', 'szlo', 'yn', 'gxc', 'wihz'],
 ['kspq', 'hj', 'dthmj', 'lmjvw', 'hay', 'yel'],
 ['fmbsc', 'azvh', 'pb', 'vrns', 'ibs', 'eab'],
 ['xze', 'xbfqy', 'ptems', 'dwz', 'hnfp', 'usa'],
 ['bdx', 'zb', 'ga', 'she', 'se', 'vw'],
 ['cqg', 'ysjeb', 'jnofp', 'hmtu', 'dop', 'eq'],
 ['vons', 'ilckw', 'kzedo', 'hlzb', 'rjb', 'tue'],
 ['fbx', 'dweup', 'qs', 'va', 'wyu', 'jvnrd'],
 ['bn', 'dfvs', 'vr', 'ax', 'vkb', 'kaqly'],
 ['ics', 'qlcw', 'lfvdi', 'ox', 'nw', 'qml'],
 ['anbit', 'xqiy', 'fspyj', 'wmqnh', 'awlt', 'fvb'],
 ['sy', 'lmwz', 'zgu', 'ksieq', 'tvwqx', 'dpfi'],
 ['czqbf', 'wvo', 'hq', 'lmowd', 'ufb', 'izc'],
 ['mtuqg', 'gbhc', 'uh', 'ye', 'uhpg', 'eycs'],
 ['mskrg', 'rmvzo', 'rljbq', 'tdjvo', 'xfs', 'wdgkm'],
 ['utlio', 'bp', 'roct', 'ivuf', 'iral', 'bt']]

In [153]:
def generate_sentence(m_topics, topic_pool=None, prefix_topics=1, suffix_topics=1):
    """Build one synthetic sentence: context words surrounding the marker 'M'.

    Layout, where every "chunk" is 0-2 distinct words sampled from one topic:

        [prefix chunks, each from a uniformly chosen topic]
        [chunk from a topic in m_topics] 'M' [chunk from a topic in m_topics]
        [suffix chunks, each from a uniformly chosen topic]

    Args:
        m_topics: non-empty sequence of topic indices allowed next to 'M'.
            The topic is picked independently for each side of the marker.
        topic_pool: list of topics (each a list of words). Defaults to the
            notebook-level ``topics``; random-topic chunks draw from the
            whole pool.
        prefix_topics: number of random-topic chunks placed before the
            m-topic chunk.
        suffix_topics: number of random-topic chunks placed after it.

    Returns:
        list[str]: the tokens of the sentence, including the 'M' marker.
    """
    pool = topics if topic_pool is None else topic_pool

    def chunk(topic_idx):
        # Up to two words; clamped so topics with fewer than two words
        # do not make random.sample raise.
        words = pool[topic_idx]
        return random.sample(words, random.randint(0, min(2, len(words))))

    sent = []

    for _ in range(prefix_topics):
        sent += chunk(random.randrange(len(pool)))

    sent += chunk(random.choice(m_topics)) + ['M'] + chunk(random.choice(m_topics))

    for _ in range(suffix_topics):
        sent += chunk(random.randrange(len(pool)))

    return sent
    

In [154]:
# Draw two topic indices for the words around 'M'. random.choices samples
# with replacement, so both indices may refer to the same topic.
m_topics = random.choices(range(num_topics), k=2)
# Show the chosen indices next to one sentence generated from them.
m_topics, generate_sentence(m_topics)

([5, 7], ['mskrg', 'wdgkm', 'ptems', 'M', 'kspq', 'yel', 'qza', 'bxorj'])

In [157]:
def generate_to_file(file_path, count=1000):
    """Write labelled synthetic sentences to ``file_path`` as TSV.

    For each of ``count`` iterations two topic indices are drawn (with
    replacement) and 2-4 sentences are generated from them. Every output
    line is ``<space-joined sentence>\\t<idx>_<idx>``.

    Args:
        file_path: destination path; the file is overwritten.
        count: number of topic pairs to draw (not the number of lines).
    """
    with open(file_path, 'w') as sink:
        for _ in range(count):
            pair = random.choices(range(num_topics), k=2)
            label = "_".join(str(idx) for idx in pair)
            n_sentences = random.randint(2, 4)
            for _ in range(n_sentences):
                words = generate_sentence(pair)
                sink.write(' '.join(words) + "\t" + label + '\n')

In [158]:
# Write the three labelled splits; the number is how many topic pairs are
# drawn per file (each pair yields 2-4 lines).
for split_path, n_pairs in (
    ('../data/fake_data_train.tsv', 1000),
    ('../data/fake_data_valid.tsv', 500),
    ('../data/fake_data_test.tsv', 500),
):
    generate_to_file(split_path, count=n_pairs)

In [159]:
# Unlabelled corpus used as fastText training input: the same sentence groups
# as generate_to_file produces, but one bare sentence per line with no
# topic-label column (so no label string is built here).
with open('../data/fake_ft_data.tsv', 'w') as out:
    for i in range(1000):
        m_topics = random.choices(range(num_topics), k=2)
        for j in range(random.randint(2, 4)):
            sent = generate_sentence(m_topics)
            out.write(' '.join(sent) + '\n')

In [160]:
#!pip install git+https://github.com/facebookresearch/fastText.git

In [161]:
import fastText

In [181]:
# Train unsupervised fastText embeddings on the unlabelled sentence file.
# NOTE(review): minCount=0, bucket=1000 and dim=30 are presumably "keep every
# word", "1000 n-gram hash buckets" and "30-dimensional vectors" — confirm
# against the fastText documentation.
model = fastText.train_unsupervised(
    input='../data/fake_ft_data.tsv', minCount=0, bucket=1000, dim=30
)

In [182]:
# Number of words in the trained fastText model's vocabulary.
len(model.get_words())

122

In [183]:
# Show the subword strings fastText derives for an example word, together
# with the array of ids it returns alongside them.
model.get_subwords('bpxre')

(['<bp',
  '<bpx',
  '<bpxr',
  '<bpxre',
  'bpx',
  'bpxr',
  'bpxre',
  'bpxre>',
  'pxr',
  'pxre',
  'pxre>',
  'xre',
  'xre>',
  're>'],
 array([ 963, 1117,  165,  196,  367,  815,  962,  676,  599,  842,  988,
         728,  658,  620]))

In [184]:
# Save the trained fastText model to disk so gensim can load it below.
model.save_model("../data/fake_ft_model.bin")

In [185]:
from gensim.models import FastText
from gensim.models.utils_any2vec import ft_ngram_hashes
from gensim.models.utils_any2vec import compute_ngrams

In [186]:
# Load the fastText model file written by model.save_model into gensim.
ft = FastText.load_fasttext_format("../data/fake_ft_model.bin")

In [187]:
# Re-save the model through gensim's own save method.
ft.save('../data/gensim_fake_ft.model')

In [188]:
# Reload from the gensim file; ft2 is the model the remaining cells read from.
ft2 = FastText.load('../data/gensim_fake_ft.model')

In [170]:
# Hashing settings read off the loaded gensim model; get_ids forwards them
# to ft_ngram_hashes as keyword arguments.
params = dict(
    minn=ft2.wv.min_n,
    maxn=ft2.wv.max_n,
    num_buckets=ft2.wv.bucket,
    fb_compatible=ft2.wv.compatible_hash,
)

# word -> the `.index` gensim stores for that vocabulary entry
vocab = {word: entry.index for word, entry in ft2.wv.vocab.items()}
# n-gram hash -> index mapping, taken as-is from the gensim model
hash2index = ft2.wv.hash2index

In [171]:
def get_ids(word):
    """Map a word to integer ids: its vocab index, or one id per n-gram.

    A word present in ``vocab`` resolves to a one-element array holding its
    vocabulary index. Any other word is passed to ``ft_ngram_hashes``; each
    resulting hash is looked up in ``hash2index`` (falling back to the raw
    hash when absent) and shifted by the vocabulary size, so n-gram ids come
    after the vocabulary ids — the same order as concatenating
    ``vectors_vocab`` followed by ``vectors_ngrams``.

    Args:
        word: the token to look up.

    Returns:
        1-D ``np.ndarray`` of dtype int64; empty if an unknown word yields
        no n-grams.
    """
    if word in vocab:
        return np.array([vocab[word]], dtype=np.int64)

    # The offset is identical for every n-gram, so compute it once.
    ngram_offset = len(ft2.wv.vocab)
    ids = [
        hash2index.get(ngram_hash, ngram_hash) + ngram_offset
        for ngram_hash in ft_ngram_hashes(word, **params)
    ]
    # Explicit dtype: np.array([]) would otherwise default to float64, which
    # cannot be used to index an embedding table.
    return np.array(ids, dtype=np.int64)

In [172]:
# Example lookup: get_ids returns a single vocab index for a known word,
# otherwise one offset id per n-gram hash of the word.
get_ids('brace')

array([ 725,  832,  201,  128,  414, 1015,  602,  716,  503,  842,  180,
        594, 1012,  667])

In [173]:
# Shape of the combined lookup table: vocabulary vectors stacked on top of
# n-gram vectors, matching the vocab-size offset applied in get_ids.
np.concatenate([ft2.wv.vectors_vocab, ft2.wv.vectors_ngrams], axis=0).shape

(1122, 100)

In [174]:
# Shape of the vocabulary-vector matrix on its own.
ft2.wv.vectors_vocab.shape

(122, 100)

In [175]:
# Total row count of vocabulary vectors plus n-gram vectors.
ft2.wv.vectors_vocab.shape[0] + ft2.wv.vectors_ngrams.shape[0]

1122

In [180]:
# torch is not imported by any earlier cell, so without this import the cell
# raises NameError on a fresh kernel (Restart & Run All).
import torch

# Scratch check: elementwise comparison of 10 uniform samples against 0.5,
# giving a random 0/1 mask.
torch.rand(10) > 0.5

tensor([1, 0, 1, 0, 1, 0, 0, 0, 0, 0], dtype=torch.uint8)