# Setup notebook, imports and predefined functions

## Notebook magics

In [1]:
%load_ext autoreload
%autoreload 2

## Some imports

In [2]:
from nhpylm.lexicon import build_fst_for_lexicon
from nhpylm import fst
import os
from nhpylm.c_core import nhpylm
from tqdm import tqdm
from nhpylm.kaldi_data_preparation import convert_transcription, word_to_grapheme
import json
from nhpylm import json_utils as ju
from copy import deepcopy
from nhpylm.kaldi import get_kaldi_env
from nhpylm.process_caller import run_processes

## Output directory

In [3]:
# Directory that receives every file written by this notebook (symbol table and FSTs).
# The trailing slash is required: later cells build paths by plain string concatenation.
output_directory = 'lattice_playground/wcl-l-e-p-m/'
# exist_ok=True keeps this cell re-runnable when the directory already exists.
os.makedirs(output_directory, exist_ok=True)

## Some predefined functions

### Combined write and print/display function

In [4]:
def write_and_print_fst(fst_graph, fst_filename, sym_filename, **kwargs):
    """Write ``fst_graph`` to ``fst_filename`` and return its printed form.

    Extra keyword arguments are forwarded to ``fst_graph.write_fst``;
    ``determinize`` is ``False`` unless the caller overrides it.
    ``sym_filename`` is handed to ``fst.print`` twice (presumably once as
    input and once as output symbol table - confirm against ``fst.print``).

    Returns:
        Whatever ``fst.print`` returns for the written file.
    """
    # Caller-supplied options win over the built-in default.
    write_options = dict({'determinize': False}, **kwargs)
    fst_graph.write_fst(fst_filename, **write_options)
    printed = fst.print(fst_filename, sym_filename, sym_filename)
    return printed

### Convert a list of sentences to a list of lists of lists of units (sentence -> words -> units):

In [5]:
def text_to_splitted_words(text):
    """Convert every line of ``text`` with ``convert_transcription``.

    Each line is converted using a ``word_to_grapheme(join=False)`` unit
    splitter; only the second element of each conversion result is kept.

    Returns:
        A list with one converted entry per input line, in input order.
    """
    splitted_lines = []
    for line in text:
        converted = convert_transcription(line, word_to_units=word_to_grapheme(join=False))
        # Index 1 holds the part of the result this notebook works with.
        splitted_lines.append(converted[1])
    return splitted_lines

### Return all unique units from the converted sentences

In [6]:
def find_symbols(text):
    """Return the set of distinct units found in ``text``.

    ``text`` is walked three levels deep: lines, then words, then the units
    of each word.
    """
    unique_units = set()
    for line in text:
        for word in line:
            unique_units.update(word)
    return unique_units

### Write a symbol list to a symbol file, mapping each symbol to its list index (integers from 0 to N_symbols - 1)

In [7]:
def write_symbols(symbols, sym_file):
    """Write ``symbols`` to ``sym_file`` as ``'<symbol> <index>'`` lines.

    The index is the zero-based position of the symbol in ``symbols``.
    An existing file is overwritten.
    """
    with open(sym_file, 'w') as out:
        out.writelines(
            '{} {}\n'.format(symbol, index) for index, symbol in enumerate(symbols)
        )

## Wrapper to run fst command easily

In [8]:
def run_cmd(cmds, inputs=None, env=None):
    """Run ``cmds`` through ``run_processes`` in a Kaldi environment.

    Args:
        cmds: command(s), forwarded unchanged to ``run_processes``.
        inputs: optional input, forwarded unchanged to ``run_processes``.
        env: environment passed to ``run_processes`` as ``environment``.
            ``None`` (the default) means "use ``get_kaldi_env()``".
    """
    if env is None:
        # Resolve lazily: a call placed in the default-argument position runs only
        # once, when the function is defined, and all later calls share that object.
        env = get_kaldi_env()
    run_processes(cmds, inputs, environment=env)

## For simulation purposes: generate a monophone structure given an HMM prototype
- Creates and adds the hmm input symbols to the end of the word list

In [9]:
def get_monophones(hmm_prob_proto, int_labels, emitions_per_state, word_list):
    """Build one discrete HMM per label from a shared prototype.

    For every label a deep copy of ``hmm_prob_proto`` is made. Each state of
    the copy whose id is ``>= 0`` receives ``emitions_per_state`` newly created
    emission symbols, each with probability ``1 / emitions_per_state``. States
    with a negative id get no observations.

    Args:
        hmm_prob_proto: dict mapping state id -> dict of state properties.
            The prototype itself is not modified (only its copies are).
        int_labels: iterable of integer labels; each label is used as an
            index into ``word_list``.
        emitions_per_state: number of emission symbols to create per state.
        word_list: list of symbol strings. It is extended IN PLACE with one
            ``'<label string>-<state id>-<emission index>'`` entry per new
            emission symbol; the entry's list index is the symbol id.

    Returns:
        dict mapping label -> HMM dict in which every state with id ``>= 0``
        has an ``'observations'`` tuple of ``(symbol id, probability)`` pairs.
    """
    # Ids of the new emission symbols continue after the existing entries.
    next_symbol_id = len(word_list)
    monophones = dict()
    for label in int_labels:
        hmm = deepcopy(hmm_prob_proto)
        for hmm_state in hmm_prob_proto:
            if hmm_state >= 0:
                observations = []
                for emition in range(emitions_per_state):
                    # Use the parameter, not a notebook-level global, for the
                    # uniform emission probability.
                    observations.append((next_symbol_id, 1 / emitions_per_state))
                    next_symbol_id += 1
                    word_list.append('{}-{}-{}'.format(word_list[label], hmm_state, emition))
                hmm[hmm_state]['observations'] = tuple(observations)
        monophones[label] = hmm
    return monophones

# Get train data, split into characters and get symbols

In [10]:
# Toy training corpus: one sentence per list entry.
train_data = ['Martian Marsman', 'man on Mars']
# Converted form of every sentence (see text_to_splitted_words).
train_data_splitted = text_to_splitted_words(train_data)
# Set of all distinct units in the training data. NOTE: a set has no defined order.
symbols = find_symbols(train_data_splitted)

# Instantiate LM, add training sentences and resample hyper parameters

In [11]:
# `symbols` is a set, and the iteration order of a set of strings can differ between
# kernel sessions (string hash randomisation). Sorting fixes the order in which the
# units are handed to the model, so a Restart & Run All passes the same list every time.
# NOTE(review): 2 and 8 are presumably the word- and character-model orders - confirm
# against NHPYLM_wrapper.
lm = nhpylm.NHPYLM_wrapper(sorted(symbols), 2, 8)

# Convert the training sentences to id lists and add them one sentence at a time.
train_data_ids = lm.word_lists_to_id_lists(train_data_splitted)
for line in tqdm(train_data_ids):
    lm.add_id_sentence_to_lm(line)

lm.resample_hyperparameters()



# Get int versions of lexicon and labels

In [12]:
# Integer-id views of the model; these are the inputs for the FST builders below.
int_lexicon = lm.get_word_id_to_char_id()
int_eos_word = lm.sentence_boundary_id
int_labels = lm.get_char_ids()
# Ids of the special symbols, looked up by name in the model's symbol table.
# presumably EPS = epsilon, EOW = end of word, EOS = end of sentence; EOC is used
# alongside EOW as a word end/disambiguation symbol - TODO confirm exact meaning.
int_eps = lm.sym2id('EPS')
int_eow = lm.sym2id('EOW')
int_eoc = lm.sym2id('EOC')
# PHI is passed as `phi=` to the final composition with the language model FST.
int_phi = lm.sym2id('PHI')
int_eos_label = lm.sym2id('EOS')

# Get word list

In [13]:
# Take a copy: get_monophones() appends new emission symbols to this list in place,
# and that must not write into whatever list backs lm.string_ids. Re-running this
# cell therefore always resets word_list to the model's own symbols.
word_list = list(lm.string_ids)

# Build discrete HMM for characters

## Some parameters and model prototype

In [14]:
# Number of emission symbols created for each HMM state with id >= 0.
num_emitions_per_state = 2
# Prototype HMM: state id -> {'transitions': ((target state, probability), ...)}.
# State -1 always moves to state 0; states 0, 1 and 2 each stay with probability 2/3
# and move on with probability 1/3, state 2 moving back to -1.
hmm_prob_proto = { -1: {'transitions': ((0, 1),)}, 0: {'transitions': ((0, 2/3), (1, 1/3))},
                    1: {'transitions': ((1, 2/3), (2, 1/3))}, 2: {'transitions': ((2, 2/3), (-1, 1/3))}}

## Build hmm and extend word list by hmm input symbols

In [15]:
# One HMM per character id. NOTE: get_monophones appends the new emission symbols to
# word_list in place, so re-running this cell alone appends them a second time;
# re-create word_list first.
monophones = get_monophones(hmm_prob_proto, int_labels, num_emitions_per_state, word_list)

# Write symbol file

In [16]:
# Symbol table shared by every fst.print call in this notebook.
sym_filename = output_directory + 'symbols.txt'

# Must run after the monophone cell so the file also lists the emission symbols
# that were appended to word_list.
write_symbols(word_list, sym_filename)

# Write and print monophone transducer

In [17]:
# H: transducer built from the per-character HMMs.
H_fst_filename = output_directory + 'H.fst'
fst_monophones = fst.build_monophone_fst(monophones, int_eps)
# Self-loops for the end-of-sentence, end-of-word and end-of-character symbols.
# presumably these let the special symbols pass through H unchanged - TODO confirm
# the argument meaning of add_self_loops.
fst_monophones.add_self_loops(int_eps, int_eos_label, int_eos_label)
fst_monophones.add_self_loops(int_eps, int_eow, int_eow)
fst_monophones.add_self_loops(int_eps, int_eoc, int_eoc)

# determinize=False is already the default of write_and_print_fst; it is repeated for clarity.
write_and_print_fst(fst_monophones, H_fst_filename, sym_filename, minimize=False, determinize=False, sort_type='olabel')

# Build and print FST for lexicon

In [18]:
# L: lexicon transducer built from the model's integer lexicon.
L_fst_filename = output_directory + 'L.fst'
# Options forwarded to build_fst_for_lexicon; see that function for their exact effect.
mode = 'trie'
build_character_model = True
fst_lexicon = build_fst_for_lexicon(int_lexicon, int_eps, int_eow, build_character_model,
                                    mode, int_labels, eoc=int_eoc)
# Register the end-of-sentence label together with the sentence-boundary word id.
fst_lexicon.add_eos(int_eos_label, int_eos_word)
# Sorted on input labels here, whereas H is written sorted on output labels.
write_and_print_fst(fst_lexicon, L_fst_filename, sym_filename, minimize=False, determinize=False, sort_type='ilabel')



# Get FST for language model

In [19]:
# G: language-model FST, built from the arc list exported by the model.
G_fst_filename = output_directory + 'G.fst'
# Only the second return value (the arc list) is used; the first is discarded.
# NOTE(review): the EOC id is passed as `eow` here - looks intentional, but confirm.
_, arc_list = lm.to_fst_text_format(eow=int_eoc)
G_fst = fst.build_fst_from_arc_list(arc_list)
# Written without minimisation, determinisation or epsilon removal.
G_fst.write_fst(G_fst_filename, minimize=False, determinize=False, rmepsilon=False)
fst.print(G_fst_filename, sym_filename, sym_filename)



# Character sequence

## Some character sequence

In [20]:
# The first training sentence ('Martian Marsman') with the space removed.
character_sequence = 'MartianMarsman'

## Build and print FST for character sequence

In [21]:
# Map every character of the sequence to its integer id in the model.
int_sequence = [lm.sym2id(character) for character in character_sequence]

# I: FST for the id sequence, with the end-of-sentence label appended.
I_fst_filename = output_directory + 'I.fst'
character_sequence_fst = fst.build_fst_for_sequence(int_sequence + [int_eos_label])
write_and_print_fst(character_sequence_fst, I_fst_filename, sym_filename, minimize=False, determinize=False)

## Compose monophone transducer with character sequence

In [22]:
# H o I: compose the monophone transducer with the character-sequence FST.
H_I_fst_filename = output_directory + 'H_I.fst'
fst.compose(H_fst_filename, I_fst_filename, H_I_fst_filename,
            determinize=False, minimize=False, rmepsilon=False, sort_type='olabel')
fst.print(H_I_fst_filename, sym_filename, sym_filename)

## Randomly generate observed sequence

In [23]:
# O: one randomly generated path through H o I, used as the simulated observation.
# NOTE(review): the file name is 'O_fst' (no extension), unlike every other '*.fst'
# file written by this notebook - probably meant to be 'O.fst'.
O_fst_filename = output_directory + 'O_fst'
# NOTE(review): no seed is passed here, so the generated sequence may differ from
# run to run - check whether fst.randgen accepts a seed.
fst.randgen(H_I_fst_filename, O_fst_filename, project=True, project_output=False, rmepsilon=True)
fst.print(O_fst_filename, sym_filename, sym_filename)

## Add self-loops for word end/disambiguation symbols (eow and eoc)

In [24]:
# O_loop: the observed sequence with self-loops for the EOW and EOC symbols added
# by an external command; the same id list is passed twice (presumably as input
# and output labels - TODO confirm against fstaddselfloops_cmd).
O_loop_fst_filename = output_directory + 'O_loop.fst'
run_cmd(fst.fstaddselfloops_cmd(O_fst_filename, O_loop_fst_filename, [int_eow, int_eoc], [int_eow, int_eoc]))
fst.print(O_loop_fst_filename, sym_filename, sym_filename)

# Do the final compositions

## Compose monophone transducer with lexicon

In [25]:
# H o L: compose the monophone transducer with the lexicon transducer.
H_L_fst_filename = output_directory + 'H_L.fst'
fst.compose(H_fst_filename, L_fst_filename, H_L_fst_filename,
            determinize=False, minimize=False, rmepsilon=False, sort_type='olabel')
fst.print(H_L_fst_filename, sym_filename, sym_filename)

## Compose monophone transducer and lexicon with input sequence FST

In [26]:
# O_loop o (H o L): compose the observed sequence (with self-loops) with H o L.
O_loop_H_L_fst_filename = output_directory + 'O_loop_H_L.fst'
fst.compose(O_loop_fst_filename, H_L_fst_filename, O_loop_H_L_fst_filename,
            determinize=False, minimize=False, rmepsilon=False, sort_type="olabel")
fst.print(O_loop_H_L_fst_filename, sym_filename, sym_filename)

## Compose monophone transducer and lexicon and input sequence FST with language model FST

In [27]:
# (O_loop o H o L) o G: final composition with the language-model FST.
# Unlike the earlier compositions, this one passes the PHI symbol id as `phi` and
# no sort_type (presumably enabling phi/failure-arc matching - TODO confirm in fst.compose).
O_loop_H_L_G_fst_filename = output_directory + 'O_loop_H_L_G.fst'
fst.compose(O_loop_H_L_fst_filename, G_fst_filename, O_loop_H_L_G_fst_filename,
            determinize=False, minimize=False, rmepsilon=False, phi=int_phi)
fst.print(O_loop_H_L_G_fst_filename, sym_filename, sym_filename)

## Get shortest path(s)

In [28]:
# Single best path (nshortest=1) through the fully composed FST, with projection
# onto the output side and epsilon removal requested.
O_loop_H_L_G_shortestpath_fst_filename = output_directory + 'O_loop_H_L_G_shortestpath.fst'
fst.shortestpath(O_loop_H_L_G_fst_filename, O_loop_H_L_G_shortestpath_fst_filename, nshortest=1,
            determinize=False, minimize=False, rmepsilon=True, project=True, project_output=True)
fst.print(O_loop_H_L_G_shortestpath_fst_filename, sym_filename, sym_filename)