In [1]:
from tqdm import tqdm_notebook

from data_reader import DataReader
from collections import defaultdict
import numpy as np
from aer import read_naacl_alignments, AERSufficientStatistics

In [4]:
NULL_TOKEN = 'NULL'

In [16]:
class IBM2:
    def __init__(self,longest_s, source_path: str, target_path,source_path_valid,target_path_valid,gold_path_valid) -> None:
        self.data_reader = DataReader(source_path, target_path)
        self.validation_data = DataReader(source_path_valid, target_path_valid)
        self.gold_links = read_naacl_alignments(gold_path_valid)
        init_ef = 1 / (self.data_reader.n_source_tokens * self.data_reader.n_target_tokens)
        
        init_align = 1 / (longest_s)
        
        self.probs_ef: Dict[Tuple[str, str], float] = defaultdict(lambda: init_ef)
        
        self.align_probs : Dict[int, float] = defaultdict(lambda: init_align)

    def jump_func(self,e_pos,f_pos,len_e,len_f):
                    
            
        return int(e_pos - np.floor(f_pos * (len_e/len_f)))
    
        
    def train(self, n_iter: int):
        for s in range(n_iter):
            
            counts_ef = defaultdict(float)
            counts_e = defaultdict(float)
            counts_align = defaultdict(int)
            #len(self.data_reader)
            for k in tqdm_notebook(range(8000)):
                e, f = self.data_reader[k]
                f = [NULL_TOKEN] + f

                e_normalizer = defaultdict(float)
                
                for e_pos,we in enumerate(e):
                    for f_pos,wf in enumerate(f):
                        x = self.jump_func(e_pos,f_pos,len(e),len(f))
                        align_prob = self.align_probs[x]
                        
                        e_normalizer[we] += self.probs_ef[we, wf] * align_prob

                for e_pos,we in enumerate(e):
                    for f_pos,wf in enumerate(f):
                        x = self.jump_func(e_pos,f_pos,len(e),len(f))
                        align_prob = self.align_probs[x]
                        
                        delta = (self.probs_ef[we, wf] * align_prob) / e_normalizer[we]

                        counts_ef[we, wf] += delta
                        
                        counts_e[wf] += delta
                        
                        counts_align[x] += delta
                        
                        

            for (we, wf), c in counts_ef.items():
                self.probs_ef[we, wf] = c / counts_e[wf]
            
            norm_align_probs = np.sum(list(counts_align.values()))
            
            for x,c in counts_align.items():
            
                self.align_probs[x] = c/norm_align_probs
                

            # DECODING
            print('Validation...')
            metric = AERSufficientStatistics()
            predictions = []

            for (source, target) in tqdm_notebook(self.validation_data.get_parallel_data(), total=len(self.validation_data)):

                source = [NULL_TOKEN] + source

                l = len(source)
                m = len(target)
                
                links = set()
                for i, t in enumerate(target, start=1):
                    link = (
                        1 + np.argmax(
                            [self.probs_ef[t, s] * self.align_probs[self.jump_func(i,j,l,m)]
                             for (j, s)
                             in enumerate(source, start=1)]),
                        i
                    )
                    links.add(link)
                predictions.append(links)

            for gold, pred in zip(self.gold_links, predictions):
                metric.update(sure=gold[0], probable=gold[1], predicted=pred)
            
            aer_result = metric.aer()
            print('AER: {}'.format(aer_result))


In [17]:
ibm2 = IBM2(80,'training/hansards.36.2.e', 'training/hansards.36.2.f','validation/dev.e','validation/dev.f','validation/dev.wa.nonullalign')
ibm2.train(2)

HBox(children=(IntProgress(value=0, max=8000), HTML(value='')))


Validation...


HBox(children=(IntProgress(value=0, max=38), HTML(value='')))


AER: 0.8451369216241738


HBox(children=(IntProgress(value=0, max=8000), HTML(value='')))


Validation...


HBox(children=(IntProgress(value=0, max=38), HTML(value='')))


AER: 0.846081208687441


In [19]:
1 - 0.8451369216241738

0.15486307837582625

In [18]:
1 - 0.846081208687441

0.15391879131255903