# Morfessor Implementation (Inverted Words)



## Outline
  - Objective : find constructions
      - Compounds = words
      - Constructions = morphs
      - Atoms = letters/characters
  - Components 
      - Cost Function
      - Training
      - Decoding
  - Model
      - Lexicon, Grammar
      - Independence assumption vis-a-vis constructions
      - MAP estimate - Likelihood + MDL-Prior
  - Training Algo
      - Greedy & Local Search
      - Starts with an initial lexicon and greedily searches for a segmentation that lowers the total cost


## Estimate
$argmax_M\,P(M|corpus)\,=\,argmax_M\,P(corpus|M)\,P(M)$

$P(M)\,=\,P(lexicon,grammar)=\,P(lexicon)$ (for baseline)

### Prior Probability - MDL Formulation
Supposing that there are $L$ different morphs,

$P(lexicon)\,=\,L!\,P(properties(\mu_1)...properties(\mu_L))$

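The factor $L!$ counts the possible orderings of the morph list, since the lexicon is an unordered set. The properties of each morph are its frequency and its string of letters.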


$P(properties(\mu_1),\ldots,properties(\mu_L))\,=\,P(f_{\mu_1},\ldots,f_{\mu_L})\,P(s_{\mu_1},\ldots,s_{\mu_L})$

$P(f_{\mu_1},\ldots,f_{\mu_L})\,=\,\frac{(L-1)!(N-L)!}{(N-1)!}$ (implicit frequency modeling, see Appendix A)

where $N=\Sigma_{j=1}^{L} f_{\mu_j}$ is the total number of morph tokens.
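One way to read the implicit frequency prior is as follows. $\frac{(L-1)!(N-L)!}{(N-1)!} = 1/\binom{N-1}{L-1}$, and by stars and bars $\binom{N-1}{L-1}$ is the number of ways to split $N$ tokens into $L$ ordered, strictly positive counts. So every frequency vector consistent with $L$ and $N$ receives the same probability. The sketch below is our own brute-force check of that identity on small values. The helper names `implicit_frequency_prob` and `count_compositions` are made up for illustration.

```python
import itertools
import math

def implicit_frequency_prob(L, N):
    # Closed form from the text: P(f_1, ..., f_L) = (L-1)! (N-L)! / (N-1)!
    return math.factorial(L - 1) * math.factorial(N - L) / math.factorial(N - 1)

def count_compositions(L, N):
    # Brute force: ordered L-tuples of positive counts that sum to N
    return sum(1 for f in itertools.product(range(1, N + 1), repeat=L) if sum(f) == N)

for L, N in [(1, 5), (2, 5), (3, 7), (4, 8)]:
    # stars and bars: C(N-1, L-1) frequency vectors, each with equal probability
    assert count_compositions(L, N) == math.comb(N - 1, L - 1)
    assert math.isclose(implicit_frequency_prob(L, N), 1 / count_compositions(L, N))
```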

$P(s_{\mu_1},\ldots,s_{\mu_L})\,=\,\Pi_{i=1}^{L} P(s_{\mu_i})\,=\,\Pi_{i=1}^{L} \Pi_{j=1}^{l_{\mu_i}} P(c_{ij})$: the probability of each morph string is the product of its character probabilities. Length is modeled implicitly by appending an end-of-morph marker '#' to every morph, so $l_{\mu_i}$ includes the marker.
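Below is a minimal sketch of the string cost with the end-of-morph marker. It assumes that character probabilities are estimated by counting characters in a list of strings, with one '#' appended per string. The notebook below estimates `charmap` from corpus characters and does not append '#' yet, so treat `char_costs` and `string_cost` as illustrative helpers, not the notebook's functions.

```python
import math
from collections import Counter

END = '#'  # end-of-morph marker

def char_costs(strings):
    # Estimate -log P(c) from character counts, counting one '#' per string
    counts = Counter()
    for s in strings:
        counts.update(s + END)
    total = sum(counts.values())
    return {c: -math.log(n / total) for c, n in counts.items()}

def string_cost(morph, costs):
    # -log P(s) = sum of the character costs, including the final '#'
    return sum(costs[c] for c in morph + END)

costs = char_costs(["ab", "b"])   # counts: a=1, b=2, '#'=2 out of 5
print(string_cost("ab", costs))   # equals -log(0.2 * 0.4 * 0.4)
```

Each morph pays for '#' exactly once. As a result, string cost grows with morph length, and adding a morph type always costs at least $-\log P(\#)$.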

### Likelihood - MLE Formulation
$P(corpus|M)\,=\,\Pi_{j=1}^{W}\Pi_{k=1}^{n_j}P(\mu_{jk})$, where the corpus has $W$ words and word $j$ is split into $n_j$ morphs.

$P(\mu_i)\,=\,\frac{f_{\mu_i}}{N}\,=\,\frac{f_{\mu_i}}{\Sigma_{j=1}^{L} f_{\mu_j}}$
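Grouping identical morph tokens turns the likelihood into a sum over morph types: $-\log P(corpus|M) = -\sum_{i=1}^{L} f_{\mu_i}\log\frac{f_{\mu_i}}{N} = N\log N - \sum_{i=1}^{L} f_{\mu_i}\log f_{\mu_i}$. This closed form is what `corpus_cost_eff` below computes. The sketch checks it against a direct per-token sum on an invented token list.

```python
import math
from collections import Counter

def corpus_cost_direct(morph_tokens):
    # -log P(corpus|M) = sum over tokens of -log(f(mu) / N)
    freqs = Counter(morph_tokens)
    n = len(morph_tokens)
    return sum(-math.log(freqs[m] / n) for m in morph_tokens)

def corpus_cost_closed_form(freqs):
    # The same quantity as N log N - sum_i f_i log f_i
    n = sum(freqs.values())
    return n * math.log(n) - sum(f * math.log(f) for f in freqs.values())

toy_tokens = ["play", "ing", "play", "ed", "s", "play", "ing"]
print(corpus_cost_direct(toy_tokens), corpus_cost_closed_form(Counter(toy_tokens)))
```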

### Putting it all together
$argmax_M\,P(M|corpus)\,=\,argmax_M\,P(corpus|M)\,P(M)$


$argmax_M\,P(M|corpus)\,=\,argmax_M\,\Pi_{j=1}^{W}\Pi_{k=1}^{n_j}P(\mu_{jk})\,.\,L!\frac{(L-1)!(N-L)!}{(N-1)!}.\Pi_{i=1}^{L} \Pi_{j=1}^{l_{\mu_i}} P(c_{ij})$
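Taking $-\log$ of the expression above gives a cost to minimize. The sketch below sums the four terms for a toy segmented corpus, under these assumptions:

- Words are given as lists of morphs.
- The per-morph string cost is supplied by the caller. Here it is a hypothetical flat 3 nats per character plus '#'.
- `math.lgamma` is used for $\log n!$.

Note that `MorfessorModel` later in this notebook tracks only the corpus, frequency and length terms, without $-\log L!$.

```python
import math
from collections import Counter

def log_factorial(n):
    return math.lgamma(n + 1)  # log(n!) for integer n >= 0

def map_cost(segmented_words, morph_string_cost):
    """-log P(corpus|M) - log P(M) for the baseline model; words are lists of morphs."""
    freqs = Counter(m for word in segmented_words for m in word)
    L, N = len(freqs), sum(freqs.values())
    corpus = N * math.log(N) - sum(f * math.log(f) for f in freqs.values())
    ordering = -log_factorial(L)                                   # the L! factor
    frequency = log_factorial(N - 1) - log_factorial(L - 1) - log_factorial(N - L)
    strings = sum(morph_string_cost(m) for m in freqs)             # each type spelled once
    return corpus + ordering + frequency + strings

def flat_cost(morph):
    # Hypothetical string cost: 3 nats per character plus one '#' marker
    return 3.0 * (len(morph) + 1)

whole = [["playing"], ["played"], ["plays"], ["walking"], ["walked"], ["walks"]]
split = [["play", "ing"], ["play", "ed"], ["play", "s"],
         ["walk", "ing"], ["walk", "ed"], ["walk", "s"]]
print(map_cost(whole, flat_cost), map_cost(split, flat_cost))
```

With these invented numbers, the stem + suffix segmentation comes out cheaper than keeping six whole words. The reason is that `play`, `walk`, `ing`, `ed` and `s` each need to be spelled only once in the lexicon.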

In [131]:
class Node:
    def __init__(self, form, count=0, leaf=False):
        self.form = form
        self.count = count
        self.leaf = leaf
        self.splitloc = 0
        
    def __str__(self):
        bare_form = f"{self.form}/{self.count}"
        if self.splitloc == 0:
            return bare_form
        splits = self.form[:self.splitloc], self.form[self.splitloc:]
        bare_form = bare_form + f" ({self.splitloc}) {splits}"
        return f"({bare_form})"
    
    def __repr__(self):
        return self.__str__()
    

In [151]:
#data_file = 'data/wikitext-2/train.txt'
#data_file = 'test_corpus.txt'
data_file = 'data/tel/train.txt'
#! wc -l data/tel/train.txt

In [152]:
import re

nodes = {}     # word type -> Node holding its corpus frequency
charmap = {}   # character -> corpus frequency
num_tokens = 0
by_space = re.compile(r'\s+')
with open(data_file, errors='ignore') as f:
    for line in f:
        line = line.strip()
        if len(line) > 0:
            # '#' is reserved as the end-of-morph marker, so remove it from the text
            line = line.replace("#", '')
            line = by_space.split(line)

            for tok in line:
                if tok not in nodes:
                    nodes[tok] = Node(tok, leaf=True)
                for c in tok:
                    charmap[c] = charmap.get(c, 0) + 1
                nodes[tok].count += 1
                num_tokens += 1
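The loading cell above can also be written with `collections.Counter`. Below is a sketch of an equivalent, testable version that reads from any iterable of lines; `count_words_and_chars` and `WHITESPACE` are hypothetical names. It differs from the cell in one respect: whitespace is stripped after removing '#', so a line such as `# x` does not produce an empty token.

```python
import re
from collections import Counter

WHITESPACE = re.compile(r'\s+')

def count_words_and_chars(lines):
    # Word-type and character counts, with '#' removed as in the loading cell
    words, chars = Counter(), Counter()
    for line in lines:
        line = line.replace('#', '').strip()
        if not line:
            continue  # skip lines that are empty (or contained only '#')
        toks = WHITESPACE.split(line)
        words.update(toks)
        for tok in toks:
            chars.update(tok)  # counts every character of every occurrence
    return words, chars

toy_words, toy_chars = count_words_and_chars(["a b a", "", "b #c"])
print(toy_words, sum(toy_words.values()))
```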
                

In [153]:
import math
def print_nodes(nodes):
    for tok, node in nodes.items():
        print(node)
        
def counts_to_logprobs(counts):
    """Turn raw counts into costs, i.e. negative log-probabilities -log(count / total)."""
    total_count = sum(counts.values())
    logprobs = {}
    for key, value in counts.items():
        logprobs[key] = -math.log(value/total_count)
    return logprobs


def stirlings_approximation(n):
    return n * math.log(n) - n + 0.5*(math.log(n) + math.log(2*math.pi))

def log_fact(n):
    if n < 2:
        return 0
    if n < 20:
        return math.log(math.factorial(n))
    return stirlings_approximation(n)

def implicit_frequency(num_types, num_tokens):
    """
        P(F) = ((L-1)! x (N-L) !) / (N-1) !
        C(F) = -(log((L-1)!) + log((N-L)!) - log((N-1)!))
    """ 
    return log_fact(num_tokens - 1) - log_fact(num_tokens - num_types) - log_fact(num_types -1)


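# The '#' end-of-morph marker is not added yet, so morph length is not modeled; adding it is still to do.
def implicit_length_cost(types, charmap):
    """
        C(S) = Sigma_i_|types| ( Sigma_j_|t_i| (-log(P(c_ij))) )
        `charmap` must map each character to its cost -log(P(c)), e.g. the output of counts_to_logprobs.
    """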
    total_cost = 0
    for t in types:
        for char in t:
            total_cost += charmap[char]
    return total_cost

def lexicon_cost(num_tokens, nodes, charmap):
    
    return  implicit_frequency(len(nodes), num_tokens) + implicit_length_cost(nodes, charmap) -  log_fact(len(nodes))


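def corpus_cost(node_probs):
    """Sum of token costs over the corpus; `node_probs` maps each word to -log(P(word))."""
    total_cost = 0
    with open(data_file, errors='ignore') as f:
        for line in f:
            line = line.strip()
            if len(line) > 0:
                # preprocess exactly as the loading cell did, so every token has an entry
                line = by_space.split(line.replace("#", ''))
                for tok in line:
                    total_cost += node_probs[tok]
    return total_cost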

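def corpus_cost_eff(nodes, num_tokens):
    """
        Compute C(corpus) = N log N - Sigma_i_L (f(i) * log(f(i)))
        which equals -Sigma_i_L f(i) * log(f(i) / N), without re-reading the corpus.
    """
    node_counts = {k : v.count*math.log(v.count) for k, v in nodes.items() if v.count > 0}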
    return -(sum(node_counts.values()) - num_tokens * math.log(num_tokens))
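`log_fact` switches to Stirling's approximation at $n = 20$. Python's `math.lgamma(n + 1)` returns $\log n!$ directly, so one option is to use it everywhere. The sketch below does two things:

- It compares `lgamma` with the Stirling formula used above.
- It rebuilds `implicit_frequency` on top of `lgamma`.

`stirling_log_fact`, `log_fact_exact` and `implicit_frequency_exact` are hypothetical names, not functions used elsewhere in the notebook.

```python
import math

def stirling_log_fact(n):
    # The formula used by stirlings_approximation above: n log n - n + 0.5 log(2 pi n)
    return n * math.log(n) - n + 0.5 * (math.log(n) + math.log(2 * math.pi))

def log_fact_exact(n):
    # math.lgamma(n + 1) == log(n!) for integers n >= 0
    return math.lgamma(n + 1)

def implicit_frequency_exact(num_types, num_tokens):
    # -log P(F) = log((N-1)!) - log((N-L)!) - log((L-1)!)
    return (log_fact_exact(num_tokens - 1)
            - log_fact_exact(num_tokens - num_types)
            - log_fact_exact(num_types - 1))

for n in (20, 1_000, 1_262_581):
    # Stirling underestimates log(n!) by slightly less than 1/(12n)
    assert abs(log_fact_exact(n) - stirling_log_fact(n)) < 1 / (12 * n) + 1e-6
```

The gap shrinks like $1/(12n)$, so for a corpus of about 1.26 million tokens the approximation is harmless. The main benefit of `lgamma` is removing the special cases.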
    

In [154]:
num_types = len(nodes.keys())
charmap_logs = counts_to_logprobs(charmap)

In [155]:
print(implicit_frequency(num_types, num_tokens))
corpus_cost_eff(nodes, num_tokens)

582064.0383753171


11878260.327121038

## Update Rules

### Removing a node t

1. Frequency cost:
   * The number of types decreases by 1: $L' = L - 1$
   * The number of tokens decreases by $f(t)$: $N' = N - f(t)$
   * $C(F)' = -(\log((L'-1)!) + \log((N'-L')!) - \log((N'-1)!))$

2. Length cost:
   * $C(S)' = C(S) - \sum_{i=1}^{|t|} -\log p(c_i)$

3. Corpus cost:
   * The number of tokens decreases by $f(t)$, so the total count becomes $N' = N - f(t)$
   * $C(corpus) = N\log N - \sum_{i=1}^{L} f(i)\log f(i)$
   * $C(corpus)' = C(corpus) + f(t)\log f(t) - N\log N + N'\log N'$
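The corpus-cost rule for removing a node can be checked against a full recomputation. Below is a small sketch with invented frequencies; `corpus_cost_from_freqs` and `corpus_cost_after_removal` are illustrative helpers.

```python
import math

def corpus_cost_from_freqs(freqs):
    # C(corpus) = N log N - sum_i f(i) log f(i)
    n = sum(freqs.values())
    return n * math.log(n) - sum(f * math.log(f) for f in freqs.values() if f > 0)

def corpus_cost_after_removal(cost, freqs, t):
    # Incremental rule: C' = C + f(t) log f(t) - N log N + N' log N', with N' = N - f(t)
    n = sum(freqs.values())
    n_new = n - freqs[t]
    return cost + freqs[t] * math.log(freqs[t]) - n * math.log(n) + n_new * math.log(n_new)

freqs = {"play": 3, "ing": 2, "ed": 2, "s": 2, "playing": 1}
updated = corpus_cost_after_removal(corpus_cost_from_freqs(freqs), freqs, "playing")
without_t = {k: v for k, v in freqs.items() if k != "playing"}
print(updated, corpus_cost_from_freqs(without_t))
```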

### Adding a node t with parent u

1. Frequency cost:
   * If t is new, the number of types increases by 1: $L' = L + 1$
   * The number of tokens increases by $f(u)$: $N' = N + f(u)$
   * $C(F)' = -(\log((L'-1)!) + \log((N'-L')!) - \log((N'-1)!))$

2. Length cost (only if t is new):
   * $C(S)' = C(S) + \sum_{i=1}^{|t|} -\log p(c_i)$

3. Corpus cost:
   * The number of tokens increases by $f(u)$: $N' = N + f(u)$
   * $C(corpus) = N\log N - \sum_{i=1}^{L} f(i)\log f(i)$
   * If t is new, it enters with frequency $f(u)$: $C(corpus)' = C(corpus) - f(u)\log f(u) - N\log N + N'\log N'$
   * If t already has frequency $f(t)$, it grows to $f(t) + f(u)$: $C(corpus)' = C(corpus) + f(t)\log f(t) - (f(t)+f(u))\log(f(t)+f(u)) - N\log N + N'\log N'$
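Below is a matching sketch for the adding rule. It covers both a new morph and one that already exists. The frequencies are invented, and `corpus_cost_after_adding` is an illustrative helper.

```python
import math

def corpus_cost_from_freqs(freqs):
    # C(corpus) = N log N - sum_i f(i) log f(i)
    n = sum(freqs.values())
    return n * math.log(n) - sum(f * math.log(f) for f in freqs.values() if f > 0)

def corpus_cost_after_adding(cost, freqs, t, f_u):
    # Add f_u occurrences of t (inherited from its parent u) and update C(corpus)
    n = sum(freqs.values())
    n_new = n + f_u
    f_t = freqs.get(t, 0)  # 0 when t is new, which reduces to the new-t formula
    old_term = f_t * math.log(f_t) if f_t > 0 else 0.0
    new_term = (f_t + f_u) * math.log(f_t + f_u)
    return cost + old_term - new_term - n * math.log(n) + n_new * math.log(n_new)

freqs = {"play": 3, "ed": 2}
results = {}
for t in ("ing", "play"):  # a new morph and an existing one
    expected = dict(freqs)
    expected[t] = expected.get(t, 0) + 4
    results[t] = (corpus_cost_after_adding(corpus_cost_from_freqs(freqs), freqs, t, 4),
                  corpus_cost_from_freqs(expected))
print(results)
```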
  
   
   
   


In [156]:
class MorfessorModel:
    def __init__(self, charmap_logs):
        self.frequency_cost = 0
        self.length_cost = 0
        self.corpus_cost = 0
        self.total_cost = 0
        
        self.num_types = 0
        self.num_tokens = 0
        
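        self.charmap_logs = charmap_logs   # character -> -log(P(c))

        self.nodes = {}       # form -> Node
        self.originals = {}   # form -> count at the start of the current pass (set by the caller)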
        
    def add_node(self, token, count, parent=None, debug=False):
        # `parent` is not used yet; resplit() passes it for bookkeeping.
        added_tokens = count   # f(u): tokens entering the corpus
        old_flogf = 0          # f(t) * log(f(t)) before the update (0 if t is new)
        if debug:
            print("A Cost Before", self.corpus_cost, self.frequency_cost, self.length_cost)

        if token not in self.nodes:
            # New type: one more morph, and its spelling must be paid for ('_' only joins words)
            self.num_types = self.num_types + 1
            self.length_cost = self.length_cost + sum([self.charmap_logs[char] for char in token if char != '_'])
            self.nodes[token] = Node(token, count)

        else:
            if self.nodes[token].count == 0:
                # Re-activated type: it counts as a type again, so its spelling is paid for again
                self.num_types += 1
                self.length_cost += sum([self.charmap_logs[char] for char in token if char != '_'])
            else:
                old_flogf = self.nodes[token].count * math.log(self.nodes[token].count)
                count += self.nodes[token].count  # new frequency f(t) + f(u)

            self.nodes[token].count = count

        num_tokens_new = self.num_tokens + added_tokens
        old_nlogn = self.num_tokens * math.log(self.num_tokens) if self.num_tokens > 0 else 0

        self.frequency_cost = implicit_frequency(self.num_types, num_tokens_new)

        if debug:
            print("A components", old_flogf, -count * math.log(count), "delta :", old_flogf - count * math.log(count), -old_nlogn, num_tokens_new * math.log(num_tokens_new))

        # C(corpus) = N log N - Sigma f log f, updated for t's new frequency and the new N
        self.corpus_cost = self.corpus_cost + old_flogf \
                                            - count * math.log(count) \
                                            - old_nlogn \
                                            + (num_tokens_new * math.log(num_tokens_new))

        if debug:
            print("A Cost After", self.corpus_cost, self.frequency_cost, self.length_cost)

        self.num_tokens = num_tokens_new
        self.total_cost = self.corpus_cost + self.frequency_cost + self.length_cost
        
        
    def remove_node(self, token, decrease_by, debug = False):
        # Callers are expected to keep decrease_by <= the node's current count.
        node = self.nodes[token]
        count = node.count - decrease_by   # remaining frequency of t

        if debug:
            print("R Cost Before", self.corpus_cost, self.frequency_cost, self.length_cost)

        if count <= 0:
            # t drops out of the lexicon: one type fewer, and its spelling is no longer paid for
            self.num_types = self.num_types - 1
            self.length_cost = self.length_cost - sum([self.charmap_logs[char] for char in node.form if char != '_'])

        num_tokens_new = self.num_tokens - decrease_by

        self.frequency_cost = implicit_frequency(self.num_types, num_tokens_new)

        # C(corpus) = N log N - Sigma f log f, updated for t's new frequency and the new N
        self.corpus_cost = self.corpus_cost + node.count * math.log(node.count) \
                                            - (count * math.log(count) if count > 0 else 0) \
                                            - (self.num_tokens * math.log(self.num_tokens)) \
                                            + (num_tokens_new * math.log(num_tokens_new) if num_tokens_new > 0 else 0)

        if debug:
            print("R Cost After", self.corpus_cost, self.frequency_cost, self.length_cost)

        self.total_cost = self.corpus_cost + self.frequency_cost + self.length_cost
        self.num_tokens = num_tokens_new

        node.count -= decrease_by

        # Drop exhausted nodes so that membership in self.nodes means "in the lexicon"
        if node.count <= 0:
            del self.nodes[token]
        
        
    def resplit(self, token, debug=False):
        if debug:
            print("Starting point", self.corpus_cost, self.frequency_cost, self.length_cost, self.total_cost)
        node = self.nodes[token]        
        node_count = node.count
        best_split = [self.total_cost, 0]
        
        threshold = 0
        
        self.remove_node(token, decrease_by = node_count)
        
        for i in range(1,len(token)):
            pre, suf = token[:i], token[i:]
            
            if debug:
                print(pre, suf)
                print("Cost before", self.total_cost)
                
            for subnode in (pre, suf):
                self.add_node(subnode, node_count, parent = node)
            
            if debug:
                print("After Adding", self.corpus_cost, self.frequency_cost, self.length_cost, self.total_cost)
                print(self.total_cost, best_split[0])
            
            if self.total_cost < best_split[0] + threshold:
                if debug:
                    print(f"{pre}+{suf} gives the best_split with {self.total_cost} < {best_split}")
                best_split[0] = self.total_cost
                best_split[1] = i
                
                        
            # Restore the data structure: undo the trial split.
            # remove_node() already deletes sub-nodes whose count drops to zero,
            # so no separate clean-up (and no lookup of a deleted key) is needed.
            for subnode in (pre, suf):
                self.remove_node(subnode, decrease_by = node_count)

       
        if best_split[1] == 0:
            # No split 
            self.add_node(node.form, node_count)
            if debug:
                print("Bestsplit", node.form, best_split, node.form)
        else:
            
            prev = None
            for subnode in (node.form[:best_split[1]], node.form[best_split[1]:]):
                self.add_node(subnode, node_count)
                if subnode != prev:
                    if self.nodes[subnode].splitloc == 0:
                        self.resplit(subnode)
                prev = subnode
            node.splitloc = best_split[1]
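A from-scratch reference can help when testing the incremental bookkeeping in `MorfessorModel`. The sketch below does two things:

- It recomputes corpus + frequency + length cost. Like `total_cost` in the class, it has no $-\log L!$ term.
- It tries every binary split of one word, moving all of its occurrences to the two parts.

This is a brute-force illustration, not the notebook's method, and it does not recurse into the parts the way `resplit` does. `recomputed_cost` and `best_split` are hypothetical names, and the frequencies and flat character cost are invented.

```python
import math
from collections import Counter

def recomputed_cost(freqs, char_cost):
    # Corpus + frequency + length cost from scratch (no -log L! term, like MorfessorModel)
    freqs = {k: v for k, v in freqs.items() if v > 0}
    L, N = len(freqs), sum(freqs.values())
    corpus = N * math.log(N) - sum(f * math.log(f) for f in freqs.values())
    frequency = math.lgamma(N) - math.lgamma(N - L + 1) - math.lgamma(L)  # log C(N-1, L-1)
    length = sum(char_cost[c] for m in freqs for c in m if c != '_')
    return corpus + frequency + length

def best_split(word, freqs, char_cost):
    # Try every binary split of `word`, moving all of its occurrences; 0 means "keep whole"
    f = freqs[word]
    best_i, best = 0, recomputed_cost(freqs, char_cost)
    for i in range(1, len(word)):
        trial = Counter(freqs)
        trial[word] -= f
        trial[word[:i]] += f
        trial[word[i:]] += f
        cost = recomputed_cost(trial, char_cost)
        if cost < best:
            best_i, best = i, cost
    return best_i

char_cost = {c: 3.0 for c in "playingxz"}
print(best_split("playing", {"play": 10, "ing": 10, "playing": 3}, char_cost))
```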
            
        

In [157]:
m = MorfessorModel(charmap_logs)
import random
tokens = [(k,v.count) for k, v in nodes.items()]
#random.shuffle(tokens)

In [158]:
import tqdm
m.originals = {}
for token, count in tqdm.tqdm(tokens):
    m.add_node(token, count)
    m.originals[token] = count

100%|██████████| 218755/218755 [00:01<00:00, 124539.43it/s]


In [159]:
m.num_tokens, m.total_cost, m.num_types


(1262581, 19748608.96662646, 218755)

In [160]:
def run_once(data_file, debug):
    new_candidate = False
    with open(data_file, errors='ignore') as f:
        for i, line in tqdm.tqdm(enumerate(f)):
            line = line.strip()
            if len(line) > 0:
                #line = line.replace("#", '')
                line = by_space.split(line)
                #line = "#".join(line)
                prev = line[0]
                for tok in line[1:]:
                    candidate = prev+'_'+tok
                    new_candidate = False

                    if prev not in m.nodes:
                        prev = candidate
                    elif tok in m.nodes:

                        # Cost before trying the merge
                        old_cost = m.total_cost
                        # Number of occurrences to move into the merged unit: 1 if the candidate is
                        # already a node, otherwise all occurrences seen so far this pass plus this one.
                        if candidate in m.nodes:
                            cand_count = 1
                        elif candidate in m.originals:
                            cand_count = m.originals[candidate] + 1
                        else:
                            cand_count = 1

                        m.remove_node(prev, cand_count, debug=debug)
                        if tok not in m.nodes:
                            # prev == tok and removing prev deleted the shared node,
                            # so the pair cannot be merged: undo and move on
                            m.add_node(prev, cand_count)
                            prev = tok
                            continue

                        m.remove_node(tok, cand_count, debug=debug)

                        if candidate not in m.nodes:
                            new_candidate = True
                        # Add the merged unit with the moved occurrences
                        m.add_node(candidate, cand_count, debug = debug)
                        # Keep the merge only if it lowers the total cost
                        if m.total_cost < old_cost:
                            prev = candidate
                        else:
                            # Revert: move the occurrences back to the two parts
                            m.remove_node(candidate, cand_count, debug=debug)
                            m.add_node(prev, cand_count, debug=debug)
                            m.add_node(tok, cand_count, debug=debug)

                            prev = tok

                    if candidate not in m.originals:
                        m.originals[candidate] = 0
                    m.originals[candidate] += 1
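`run_once` keeps a merged unit `prev_tok` only when the model's total cost goes down. The sketch below strips away the incremental updates and the `m.originals` bookkeeping and recomputes the cost from scratch to show that acceptance test. The frequencies and the flat character cost are invented, and `merge_delta` is a hypothetical helper.

```python
import math
from collections import Counter

def recomputed_cost(freqs, char_cost):
    # Corpus + frequency + length cost from scratch; '_' only joins words and costs nothing
    freqs = {k: v for k, v in freqs.items() if v > 0}
    L, N = len(freqs), sum(freqs.values())
    corpus = N * math.log(N) - sum(f * math.log(f) for f in freqs.values())
    frequency = math.lgamma(N) - math.lgamma(N - L + 1) - math.lgamma(L)
    length = sum(char_cost[c] for m in freqs for c in m if c != '_')
    return corpus + frequency + length

def merge_delta(freqs, a, b, k, char_cost):
    # Cost change from rewriting k occurrences of the pair (a, b) as the unit a_b
    trial = Counter(freqs)
    trial[a] -= k
    trial[b] -= k
    assert trial[a] >= 0 and trial[b] >= 0, "cannot move more occurrences than exist"
    trial[a + '_' + b] += k
    return recomputed_cost(trial, char_cost) - recomputed_cost(freqs, char_cost)

char_cost = {c: 3.0 for c in "newyorkth"}
# "new" is always followed by "york": merging all 20 pairs lowers the cost
print(merge_delta({"new": 20, "york": 20, "the": 100}, "new", "york", 20, char_cost))
# a single "the york" pair is not worth a new lexicon entry
print(merge_delta({"the": 100, "york": 20}, "the", "york", 1, char_cost))
```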



In [161]:
run_once(data_file, debug=False)
# Reset Originals
m.originals = {}
for k, v in m.nodes.items():
    m.originals[k] = v.count


# prev_cost = m.total_cost
# it = 0
# while True:
#     run_once(data_file, debug=False)
#     # Reset Originals
#     m.originals = {}
#     for k, v in m.nodes.items():
#         m.originals[k] = v.count
        
#     if m.total_cost >= prev_cost:
#         break
        
#     print(f" Iteration {it}: {m.total_cost}, {m.num_types}")

100000it [00:38, 2606.33it/s]


In [162]:
m.num_tokens, m.total_cost, m.num_types

(1102523, 18998124.517780695, 195490)

In [168]:
sorted({k : v for k, v in m.nodes.items() if k.count('_') > 1 and v.count > 5}.items(), key = lambda x : x[1].count, reverse = True)

[('avakASaM_uMxi_.', avakASaM_uMxi_./124),
 ('waxiwarulu_pAlgoVnnAru_.', waxiwarulu_pAlgoVnnAru_./124),
 ('[_mArcu_]', [_mArcu_]/112),
 ('spaRtaM_ceSAru_.', spaRtaM_ceSAru_./104),
 ('saMgawi_weVlisiMxe_.', saMgawi_weVlisiMxe_./92),
 ('viRayaM_weVlisiMxe_.', viRayaM_weVlisiMxe_./89),
 ("'_'_ani", '_'_ani/81),
 ('dimAMd_ceSAru_.', dimAMd_ceSAru_./81),
 ('avasaraM_lexu_.', avasaraM_lexu_./74),
 ('erpAtu_ceSAru_.', erpAtu_ceSAru_./68),
 ('Ayana_annAru_.', Ayana_annAru_./67),
 ('vyakwaM_ceSAru_.', vyakwaM_ceSAru_./64),
 ('hAmI_iccAru_.', hAmI_iccAru_./51),
 (',_nyUstude_:', ,_nyUstude_:/37),
 ('kaligi_uMtuMxi_.', kaligi_uMtuMxi_./37),
 ('AgrahaM_vyakwaM_ceSAru', AgrahaM_vyakwaM_ceSAru/36),
 (',_namaswe_weVlaMgANa_:', ,_namaswe_weVlaMgANa_:/34),
 ('gurwu_ceSAru_.', gurwu_ceSAru_./33),
 ('Avexana_vyakwaM_ceSAru', Avexana_vyakwaM_ceSAru/32),
 ('(_AMXrajyowi_)_:', (_AMXrajyowi_)_:/31),
 ('AXArapadi_uMtuMxi_.', AXArapadi_uMtuMxi_./31),
 ('spaRtaM_cesiMxi_.', spaRtaM_cesiMxi_./29),
 ('vyakwaM_ces

In [136]:
import math

def write_to_file(data_file):
    """Greedily merge adjacent words into '_'-joined units known to the model and write `<data_file>.wtok`."""
    to_file = f"{data_file}.wtok"

    def unit_cost(unit):
        # -log P(unit) under the model; infinite if the unit is unknown or exhausted
        if unit in m.nodes and m.nodes[unit].count > 0:
            return -math.log(m.nodes[unit].count/m.num_tokens)
        return math.inf

    with open(data_file, errors='ignore') as f, open(to_file, 'w') as tf:
        for i, line in tqdm.tqdm(enumerate(f)):
            line = line.strip()
            if len(line) > 0:
                #line = line.replace("#", '')
                line = by_space.split(line)
                #line = "#".join(line)
                prev = line[0]
                new_line = [prev]

                for tok in line[1:]:
                    candidate = prev+'_'+tok
                    individual_cost = unit_cost(prev) + unit_cost(tok)
                    # Merge only if the joined unit is cheaper than its two parts
                    if unit_cost(candidate) < individual_cost:
                        prev = candidate
                        new_line[-1] = prev
                    else:
                        prev = tok
                        new_line.append(tok)
                    
                tf.write(' '.join(new_line) + '\n')
#                     print('/'.join(new_line), end=' ')
#                 print("\n"+"-"*20)

#             if i > 10:
#                 break
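The decoding loop in `write_to_file` can be pulled out into a pure function, which makes it easy to test without a trained model. Below is a sketch that assumes `counts` maps words and '_'-joined units to frequencies and `num_tokens` is the total count. `merge_line` is a hypothetical name, and the counts are invented.

```python
import math

def merge_line(tokens, counts, num_tokens):
    # Greedy left-to-right merging, mirroring the loop in write_to_file
    def cost(unit):
        c = counts.get(unit, 0)
        return -math.log(c / num_tokens) if c > 0 else math.inf  # unknown units never win

    if not tokens:
        return []
    out = [tokens[0]]
    for tok in tokens[1:]:
        candidate = out[-1] + '_' + tok
        if cost(candidate) < cost(out[-1]) + cost(tok):
            out[-1] = candidate  # absorb tok into the previous unit
        else:
            out.append(tok)
    return out

counts = {"avakASaM": 10, "uMxi": 10, ".": 50, "avakASaM_uMxi": 9, "avakASaM_uMxi_.": 8}
print(merge_line(["avakASaM", "uMxi", "."], counts, 100))  # ['avakASaM_uMxi_.']
```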
                   

In [137]:
write_to_file(data_file)

36718it [00:03, 10249.59it/s]


In [46]:
len([k for k in m.nodes if m.nodes[k].count > 0])

9171

In [47]:
def get_repr(nodes, node):
    
    base_form = f" {node.form}/{node.count} "
    if node.splitloc == 0:
        return base_form
        
    splits = node.form[:node.splitloc], node.form[node.splitloc:]
    
    s_reprs = ""
    if node.form in m.originals:
        s_reprs = f" <{m.originals[node.form]}>"
        
    for subnode in splits:
        subnode = nodes[subnode]
        s_reprs += f" ({get_repr(nodes, subnode)}) "
    
    return f"{base_form} {s_reprs}"
        
    
def recursive_display(nodes):
    for token, node in nodes.items():
        print(get_repr(nodes, node))
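`get_repr` walks the `splitloc` tree to print it. A similar recursion can return just the leaf morphs of a word, which is what a segmenter would output. The sketch uses a stand-in dataclass `SplitNode` rather than the notebook's `Node`, and `leaves` is a hypothetical helper. The example forms and counts are taken from the display output below.

```python
from dataclasses import dataclass

@dataclass
class SplitNode:
    # Stand-in for the notebook's Node: only the fields the traversal needs
    form: str
    count: int = 0
    splitloc: int = 0  # 0 = not split, otherwise the split index into `form`

def leaves(nodes, form):
    # Follow splitloc recursively and return the leaf morphs of `form`
    node = nodes.get(form)
    if node is None or node.splitloc == 0:
        return [form]
    i = node.splitloc
    return leaves(nodes, form[:i]) + leaves(nodes, form[i:])

nodes_example = {
    "predecessors": SplitNode("predecessors", 0, 11),
    "predecessor": SplitNode("predecessor", 8),
    "s": SplitNode("s", 4728),
}
print(leaves(nodes_example, "predecessors"))  # ['predecessor', 's']
```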

In [48]:
recursive_display(m.nodes)

 =/3081 
 Valkyria/54 
 Chronicles/39 
 III/45 
 Senjō/5 
 no/128 
 3/251 
 :/285 
 <unk>/5522 
 (/1064 
 Japanese/8 
 戦場のヴァルキュリア3/3 
 ,/10665 
 lit/14 
 ./7413 
 of/5705 
 the/11370 
 Battlefield/0   <4> ( Battle/35 )  ( field/47 ) 
 )/1066 
 commonly/0   <7> ( common/47 )  ( ly/513 ) 
 referred/24 
 to/4114 
 as/1542 
 outside/32 
 Japan/14 
 is/1189 
 a/3967 
 tactical/0   <3> ( tactic/7 )  ( al/145 ) 
 role/67 
 @-@/1897 
 playing/36 
 video/54 
 game/186 
 developed/48 
 by/1215 
 Sega/5 
 and/5107 
 Media.Vision/3 
 for/1553 
 PlayStation/9 
 Portable/0   <3> ( Port/18 )  ( able/78 ) 
 Released/0   <2> ( Release/8 )  ( d/608 ) 
 in/5392 
 January/68 
 2011/64 
 it/717 
 third/71 
 series/100 
 same/98 
 fusion/4 
 real/39 
 time/301 
 gameplay/10 
 its/327 
 predecessors/0   <5> ( predecessor/8 )  ( s/4728 ) 
 story/34 
 runs/0   <10> ( run/54 )  ( s/4728 ) 
 parallel/7 
 first/376 
 follows/0   <9> ( follow/20 )  ( s/4728 ) 
 "/2582 
 Nameless/12 
 penal/0   <2> ( pen/8 )  ( al/