# Assignment 3 - Transliteration

### Setup [do not change]

In [1]:
# !git clone https://github.com/butoialexandra/eth-nlp-f22-hw3.git

import os
os.chdir('eth-nlp-f22-hw3')

# !pip install -e .

In [2]:
# !wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1wlHbFChZFcBgEU8knAeaUIqE0W4EaJEC' -O 'test_cases.pkl'

In [3]:
from rayuela.base.semiring import Semiring, Real, Tropical, Boolean
from rayuela.base.symbol import Sym, ε, ε_1, ε_2
from rayuela.base.misc import epsilon_filter

from rayuela.fsa.fst import FST
from rayuela.fsa.state import State, PairState

import dill
import numpy as np
from math import exp, log

from itertools import product

from frozendict import frozendict

### Subquestion a) - Log Semiring

In [4]:
class Log(Semiring):
    def __init__(self, score):
        self.score = score

    def star(self):
        if self.score > 0:
            print(self.score)
        return Log(-1 * log(1 - exp(self.score)))

    def __float__(self):
        return float(self.score)

    def __add__(self, other):
        # TODO: implement log addition
        # For numerical stability you should use the using the log-sum-exp trick
        # https://en.wikipedia.org/wiki/LogSumExp
        return Log(np.logaddexp(self.score, other.score))

    def __mul__(self, other):
        # TODO: implement log multiplication
        return Log(self.score + other.score)

    def __repr__(self):
        return f"{round(self.score, 15)}"

    def __eq__(self, other):
        return np.allclose(float(self.score), float(other.score), atol=1e-3)

    def __hash__(self):
        return hash(self.score)

# TODO: implement the log semiring one and zero
Log.zero = Log(-np.Inf)
Log.one = Log(0.0)


Log.idempotent = False
Log.cancellative = True

### Before moving to the next subquestion, we give some examples of WFSTs to illustrate how to use the library `rayuela`.

In [5]:
# A simple language model encoded by a WFST

sem = Real # the Real semiring
fst = FST(sem) # a WFST in the Real semiring

# add transitions
fst.add_arc(State(0), Sym('formal'), Sym('formal'), State(1), Real(0.2))
fst.add_arc(State(0), Sym('natural'), Sym('natural'), State(1), Real(0.3))
fst.add_arc(State(0), Sym('learning'), Sym('learning'), State(2), Real(0.2))
fst.add_arc(State(0), Sym('data'), Sym('data'), State(3), Real(0.3))

fst.add_arc(State(1), Sym('language'), Sym('language'), State(2), Real(0.6))
fst.add_arc(State(1), Sym('languages'), Sym('languages'), State(4), Real(0.4))

fst.add_arc(State(2), Sym('is'), Sym('is'), State(5), Real(1.0))

fst.add_arc(State(5), Sym('fun'), Sym('fun'), State(6), Real(1.0))

fst.add_arc(State(3), Sym('is'), Sym('is'), State(5), Real(0.5))
fst.add_arc(State(3), Sym('are'), Sym('are'), State(5), Real(0.5))

fst.add_arc(State(4), Sym('are'), Sym('are'), State(5), Real(1.0))

# set initial weights
# if not set, it defaults to semiring zero
fst.set_I(State(0), Real(1.0))
# set final weights
fst.set_F(State(6), Real(1.0))

# visualize
fst

In [6]:
# A simple WFSA for transliteration

sem = Real # semiring
fst = FST(sem) # initialize WFST in the chosen semiring
one, zero = sem.one, sem.zero # semiring one and zero

# add transitions
fst.add_arc(State(0), Sym('d'), Sym('data'), State(1), Real(0.5))
fst.add_arc(State(0), Sym('d'), Sym('dew'), State(5), Real(0.5))

fst.add_arc(State(1), Sym('ey'), Sym('ÃŽÂµ'), State(2), Real(0.5))
fst.add_arc(State(1), Sym('ae'), Sym('ÃŽÂµ'), State(2), Real(0.5))

fst.add_arc(State(2), Sym('t'), Sym('ÃŽÂµ'), State(3), Real(0.7))
fst.add_arc(State(2), Sym('dx'), Sym('ÃŽÂµ'), State(3), Real(0.3))

fst.add_arc(State(3), Sym('ax'), Sym('ÃŽÂµ'), State(4), one)

fst.add_arc(State(5), Sym('uw'), Sym('ÃŽÂµ'), State(6), one)

# initial and final weights
fst.set_I(State(0), one)
fst.set_F(State(4), one)
fst.set_F(State(6), one)

fst

### Subquestion b) - String Transducer

In [7]:
def string_fst(s, semiring=Log):
    """ Takes a string as input and a semiring
    and returns a transducer encoding of the
    string
    """
    fst = FST(semiring)

    # TODO: add arcs, initial and final weights
    Σ = list(s)
    for q, a in enumerate(Σ):
        fst.add_arc(State(q), Sym(a), Sym(a), State(q+1), semiring.one)

    fst.set_I(State(0), semiring.one)
    fst.set_F(State(len(Σ)), semiring.one)


    return fst

In [8]:
# you can visualise the transducer
fst = string_fst("abc", semiring=Log)
fst

### Subquestion c) - Edit Distance Transducer

As we will not learn the weights on the arcs, you can set them to some pre-defined constants. You should choose the weights such that the pathsum that needs to be computed in e) does not diverge (this is caused by the weights of the self-loops in the transducer). 

In [9]:
def edit_distance_fst(source_alphabet, target_alphabet, semiring=Log):
    """ Takes as input two sets of characters (source and target 
    alphabets) and a semiring and returns an edit-distance transducer
    """
    
    fst = FST(semiring)
    
    for q, d in enumerate(target_alphabet):
        # Num insertions + Num deletions + Num substitutions
        num_outgoing_arcs = len(target_alphabet) + len(source_alphabet) + (len(source_alphabet) * len(target_alphabet))
        num_outgoing_arcs_from_initial_state = num_outgoing_arcs - len(target_alphabet)
        # Weights chosen so that we get local normalization to semiring one
        if semiring == Log:
            w = Log(-log(num_outgoing_arcs))
            w_zero = Log(-log(num_outgoing_arcs_from_initial_state))
            # w = Log(-0.0001)
            # w_zero = Log(-0.0001)
        elif semiring == Real:
            w = Real(1 / num_outgoing_arcs)
            w_zero = Real(1 / num_outgoing_arcs_from_initial_state)
        elif semiring == Tropical:
            w = semiring.one
            w_zero = semiring.one
        elif semiring == Boolean:
            w = Boolean(True)
            w_zero = Boolean(True)
        # Insertions
        fst.set_arc(State(0), ε, Sym(d), State(q + 1), w_zero)
        fst.add_arc(State(q + 1), ε, Sym(d), State(q + 1), w)
        for p, e in enumerate(target_alphabet):
            if p != q:
                fst.add_arc(State(q + 1), ε, Sym(e), State(p + 1), w)
                for a in source_alphabet:
                    # Substitutions Part I
                    fst.set_arc(State(0), a, Sym(e), State(p + 1), w_zero)
                    fst.add_arc(State(q + 1), a, Sym(e), State(p + 1), w)
                
        for a in source_alphabet:
            # Substitutions Part II
            fst.set_arc(State(0), a, ε, State(0), w_zero)
            fst.add_arc(State(q + 1), Sym(a), Sym(d), State(q + 1), w)
            # Deletions
            fst.add_arc(State(q + 1), a, ε, State(q + 1), w)
        
        fst.set_F(State(q + 1), semiring.one)

    fst.set_I(State(0), semiring.one)

    return fst

In [10]:
edit_fst = edit_distance_fst({'a', 'b'}, {'c', 'd'}, semiring=Log)
edit_fst

### Subquestion d) - Composition

We provide some composition test cases for you to inspect.

In [11]:
with open('test_cases.pkl', 'rb') as in_file:
    examples = dill.load(in_file)

In [12]:
# FST 1
fst1 = examples[0]['fst1']['fst']
fst1

In [13]:
# FST 2
fst2 = examples[0]['fst2']['fst']
fst2

In [14]:
# composition
examples[0]['composition']['fst']

In [15]:
def augment_and_relabel(fst1, fst2):
    """Augment with self loops and relabel."""
    semiring = fst1.R
    fst1_relabeled = FST(R=fst1.R)
    for q in fst1.Q:
        # Augmentation
        fst1_relabeled.add_arc(q, ε, ε_1, q, semiring.one)
        for arc in fst1.arcs(q):
            (a, b), j, w = arc
            if b == ε:
                fst1_relabeled.add_arc(q, a, ε_2, j, w)
            else:
                fst1_relabeled.add_arc(q, a, b, j, w)

    fst2_relabeled = FST(R=fst2.R)
    for q in fst2.Q:
        # Augmentation
        fst2_relabeled.add_arc(q, ε_2, ε, q, semiring.one)
        for arc in fst2.arcs(q):
            (a, b), j, w = arc
            if a == ε:
                fst2_relabeled.add_arc(q, ε_1, b, j, w)
            else:
                fst2_relabeled.add_arc(q, a, b, j, w)


    for q in fst1.I:
        fst1_relabeled.set_I(*q)

    for q in fst1.F:
        fst1_relabeled.set_F(*q)

    for q in fst2.I:
        fst2_relabeled.set_I(*q)

    for q in fst2.F:
        fst2_relabeled.set_F(*q)

    return fst1_relabeled, fst2_relabeled

In [16]:
fst1_relabeled, fst2_relabeled = augment_and_relabel(fst1, fst2)
fst2_relabeled

In [17]:
def compose(fst1, fst2):
    """Takes as input two transducers and returns
    their composition.
    """
    # The machines must be in the same semiring
    assert fst1.R == fst2.R
    semiring = fst1.R
    fst1, fst2 = augment_and_relabel(fst1, fst2)
    comp_fst = FST(R=fst1.R)
    stack = [(q1[0], q2[0], State("0")) for q1 in fst1.I for q2 in fst2.I]
    # F = [(qw1[0], qw2[0]) for qw1 in fst1.F for qw2 in fst2.F]
    visited = set(stack)
    while len(stack) > 0:
        q1, q2, q3 = stack.pop()
        E1 = fst1.arcs(q1)
        E2 = fst2.arcs(q2)
        # Set final weights
        # if (q1, q2) in F:
        #     comp_fst.set_F(PairState(q1, q2), fst1.ρ[q1] * fst1.ρ[q1] * semiring.one)
        for e1, e2 in product(E1, E2):
            (a, b), q1_prime, w1 = e1
            (c, d), q2_prime, w2 = e2
            q3_prime = epsilon_filter(b, c, q3)
            if q3_prime != State("⊥"):  # If our eps filter is not in the blocking state
                comp_fst.set_arc(PairState(q1, q2), a, d, PairState(q1_prime, q2_prime), w1*w2)
                if (q1_prime, q2_prime, q3_prime) not in visited:
                    visited.add((q1_prime, q2_prime, q3_prime))
                    stack.append((q1_prime, q2_prime, q3_prime))
    for q in comp_fst.Q:
        q1, q2 = q.idx
        comp_fst.λ[q] = fst1.λ[q1] * fst2.λ[q2]
        comp_fst.ρ[q] = fst1.ρ[q1] * fst2.ρ[q2]

    return comp_fst, fst1, fst2

Here is a more efficient way of doing the composition, that implicitly re-labels the fsts.  
We can do this because the re-labeling is one-to-one. E.g we don't have to explicitly   
re-label (a, ε) in fst1 to (a, ε_2), because we know that whenever the output of fst1 is ε,  
it would get re-labeled to ε_2, and therefore we can just adjust our epsilon filter, replacing  
all occurances of $\bullet : ε_2$ with $\bullet : ε$.

In [18]:
ε_L = State("ε_L")
def implicit_epsilon_filter(b, c, q3):
    """Modified to implicitly re-label eps transitions."""
    if b == c and b != ε_L and q3 == State("0"):
        return State("0")
    if b == c and b not in [ε, ε_L]:
        return State("0")
    elif (b, c) == (ε_L, ε) and q3 != State("2"):
        return State("1")
    elif (b, c) == (ε, ε_L) and q3 != State("1"):
        return State("2")
    else:
        return State("⊥")

In [19]:
def implicit_compose(fst1, fst2):
    """Takes as input two transducers and returns their composition.
    Uses a modified epsilon filter to do implict relabeling and then also
    does implicit augmentation.
    """
    # The machines must be in the same semiring
    assert fst1.R == fst2.R
    semiring = fst1.R
    comp_fst = FST(R=fst1.R)
    # Implicit augmentation
    E1 = lambda q1: set(list(fst1.arcs(q1)) + [((ε, ε_L), q1, semiring.one)])
    E2 = lambda q2: set(list(fst2.arcs(q2)) + [((ε_L, ε), q2, semiring.one)])
    F = [(qw1[0], qw2[0]) for qw1 in fst1.F for qw2 in fst2.F]
    stack = [(qw1[0], qw2[0], State("0")) for qw1 in fst1.I for qw2 in fst2.I]
    visited = set(stack)
    while len(stack) > 0:
        q1, q2, q3 = stack.pop()
        # Set final weights
        if (q1, q2) in F:
            comp_fst.set_F(PairState(q1, q2), fst1.ρ[q1] * fst1.ρ[q1] * semiring.one)
        for e1, e2 in product(E1(q1), E2(q2)):
            (a, b), q1_prime, w1 = e1
            (c, d), q2_prime, w2 = e2
            q3_prime = implicit_epsilon_filter(b, c, q3)
            if q3_prime != State("⊥"):  # If our eps filter is not in the blocking state
                comp_fst.set_arc(PairState(q1, q2), a, d, PairState(q1_prime, q2_prime), w1*w2)
                if (q1_prime, q2_prime, q3_prime) not in visited:
                    visited.add((q1_prime, q2_prime, q3_prime))
                    stack.append((q1_prime, q2_prime, q3_prime))

    for q in comp_fst.Q:
        q1, q2 = q.idx
        comp_fst.λ[q] = fst1.λ[q1] * fst2.λ[q2]
        comp_fst.ρ[q] = fst1.ρ[q1] * fst2.ρ[q2]

    return comp_fst

In [20]:
print("My result with explicit augmentation and re-labeling:")
comp_fst, fst1_relabeled, fst2_relabeled = compose(fst1, fst2)
comp_fst

My result with explicit augmentation and re-labeling:


In [21]:
print("My result with implicit augmentation and re-labeling:")
comp_fst = implicit_compose(fst1, fst2)
comp_fst

My result with implicit augmentation and re-labeling:


In [22]:
print("Expected:")
examples[0]['composition']['fst']

Expected:


### Subquestion e) - Lehmann's Algorithm

In [23]:
class Pathsum:
    def __init__(self, fsa):

        # basic FSA stuff
        self.fsa = fsa
        self.R = fsa.R
        self.N = self.fsa.num_states

        # state dictionary
        self.I = {}
        for n, q in enumerate(self.fsa.Q):
            self.I[q] = n

        # lift into the semiring
        self.W, self.W_dict, self.alpha, self.beta = self.lift()

        
    def lift(self):
        """ creates the weight matrix, initial weight vector,
        final weight vector from the automaton """
        W = self.R.zeros((self.N, self.N))
        W_dict = {a: {} for a in self.fsa.Sigma}
        alpha = self.R.zeros((self.N,))
        beta = self.R.zeros((self.N,))
        for a in self.fsa.Sigma:
            for b in self.fsa.Delta:
                W_dict[a][b] = self.R.zeros((self.N, self.N))
        for p in self.fsa.Q:
            alpha[self.I[p]] += self.fsa.λ[p]
            beta[self.I[p]] += self.fsa.ρ[p]
            for (a, b), q, w in self.fsa.arcs(p):
                W[self.I[p], self.I[q]] += w
                W_dict[a][b][self.I[p], self.I[q]] = w
        return W, W_dict, alpha, beta
    
    def lehmann(self):
        """
        Lehmann's (1977) algorithm.
        """
        N = self.N
        semiring = self.R
        # Create semiring identity matrix
        I = semiring.diag(N)
        # Initialize R_0
        R_j = self.W.copy()
        for j in range(N):
            R_prev = R_j.copy()
            for i in range(N):
                for k in range(N):
                    R_j[i, k] = R_prev[i, k] + (R_prev[i, j] * semiring.star(R_prev[j, j]) * R_prev[j, k])

        return I + R_j

    def _iterate(self, K):
        P = self.R.diag(self.N)
        for n in range(K):
            P += self.W @ P
        return P

    def _fixpoint(self, K=200):
        if self.fsa.R.idempotent:
            return self._iterate(self.fsa.num_states)

        diag = self.R.zeros((self.fsa.num_states, self.fsa.num_states))
        for n in range(self.fsa.num_states):
            diag[n, n] = self.fsa.R.one
        P_old = diag

        for _ in range(K):
            P_new = diag + self.W @ P_old
            P_old = P_new

        return P_old

    def fixpoint(self):

        P = self._fixpoint()
        W = {}

        for p in self.fsa.Q:
            for q in self.fsa.Q:
                if p in self.I and q in self.I:
                    W[p, q] = P[self.I[p], self.I[q]]
                elif p == q:
                    W[p, q] = self.R.one
                else:
                    W[p, q] = self.R.zero

        return frozendict(W)

In [24]:
# Testing >> We only test 0-closed semirings, because our naive implementation
# is only correct for 0-closed semirings.

def is_zero_closed(fst):
    if fst.R in [Tropical, Boolean]:
        return True
    elif fst.R in [Real, Log]:
        return False
    else:
        raise NotImplementedError(f"0-closed Test not implemented for the {fst.R} Semiring!")

def naive_star(fst):
    if not is_zero_closed(fst):
        raise ValueError(f"{fst.R} is not a zero-closed semiring!")
    pathsum = Pathsum(fst)
    W, W_dict, alpha, beta = pathsum.lift()
    N = W.shape[0]
    semiring = fst.R
    I = fst.R.diag(N)
    prod = I.copy()
    W_star = I.copy()
    for n in range(1, N):
        prod = W @ prod
        W_star += prod
    return W_star


def test_lehmann_against_naive(fst):
    return np.array_equal(naive_star(fst), Pathsum(fst).lehmann())

def test_lehmann_against_fixpoint(fst):
    test_pathsum = Pathsum(fst)
    fixpoint_dict = test_pathsum.fixpoint()
    fixpoint_arr = np.array(list(fixpoint_dict.values())).reshape(test_pathsum.W.shape)
    return np.array_equal(fixpoint_arr, Pathsum(fst).lehmann())

passed = 0
total = 0
for example_num, example in enumerate(examples):
    fst1 = example["fst1"]["fst"]
    fst2 = example["fst2"]["fst"]
    fsts = [fst1, fst2]
    for fst in fsts:
        if is_zero_closed(fst):
            if test_lehmann_against_naive(fst):
                print(f"Naive test \033[1;32m[PASSED]\033[0m for example {example_num} / {len(examples)-1} | Semiring = {fst.R}")
                passed += 1
            else:
                print(f"Naive test \033[1;31m[FAILED]\033[0m for example {example_num} / {len(examples)-1} | Semiring = {fst.R}")
            print("----------------------------------------------------------------------------------")
            total += 1
        else:
            print("fst not 0-closed, skipping...")

for example_num, example in enumerate(examples):
    fst1 = example["fst1"]["fst"]
    fst2 = example["fst2"]["fst"]
    fsts = [fst1, fst2]
    for fst in fsts:
        if test_lehmann_against_fixpoint(fst):
            print(f"Fixpoint test \033[1;32m[PASSED]\033[0m for example {example_num} / {len(examples)-1} | Semiring = {fst.R}")
            passed += 1
        else:
            print(f"Fixpoint test \033[1;31m[FAILED]\033[0m for example {example_num} / {len(examples)-1} | Semiring = {fst.R}")
        print("----------------------------------------------------------------------------------")
        total += 1
print(f"{passed} / {total} tests passed.")

if passed == total:
    print(":)")

fst not 0-closed, skipping...
fst not 0-closed, skipping...
Naive test [1;32m[PASSED][0m for example 1 / 4 | Semiring = <class 'rayuela.base.semiring.Tropical'>
----------------------------------------------------------------------------------
Naive test [1;32m[PASSED][0m for example 1 / 4 | Semiring = <class 'rayuela.base.semiring.Tropical'>
----------------------------------------------------------------------------------
Naive test [1;32m[PASSED][0m for example 2 / 4 | Semiring = <class 'rayuela.base.semiring.Boolean'>
----------------------------------------------------------------------------------
Naive test [1;32m[PASSED][0m for example 2 / 4 | Semiring = <class 'rayuela.base.semiring.Boolean'>
----------------------------------------------------------------------------------
fst not 0-closed, skipping...
fst not 0-closed, skipping...
Naive test [1;32m[PASSED][0m for example 4 / 4 | Semiring = <class 'rayuela.base.semiring.Boolean'>
------------------------------------

### Subquestion f) - Normalizer

We provide some test cases for you to inspect.

In [25]:
examples[0]['fst1']['Z'], examples[0]['fst2']['Z'], examples[0]['composition']['Z']

(0.00138425, 0.000997195294118, 7.07192371e-07)

In [26]:
def normalizer(fst):
    """Takes as input a finite-state transducer and returns its normalizer."""
    semiring = fst.R
    pathsum = Pathsum(fst)
    R = pathsum.lehmann()
    N = pathsum.N
    Z = semiring.zero
    for i in range(N):
        for k in range(N):
            qi = list(fst.Q)[i]
            qk = list(fst.Q)[k]
            Z += fst.λ[qi] * R[i, k] * fst.ρ[qk]

    return Z

In [27]:
# Test normalizer

def test_normalizer(fst_dict):
    return normalizer(fst_dict["fst"]) == fst_dict["Z"]

passed = 0
total = 0
for example_num, example in enumerate(examples):
    fst_dicts = [example["fst2"], example["fst1"]]
    for fst_dict in fst_dicts:
        if test_normalizer(fst_dict):
            print(f"Test \033[1;32m[PASSED]\033[0m for example {example_num} / {len(examples)-1}")
            passed += 1
        else:
            print(f"Test \033[1;31m[FAILED]\033[0m for example {example_num} / {len(examples)-1}")
        print("----------------------------------------------------------------------------------")
        total += 1

print(f"{passed} / {total} tests passed.")

if passed == total:
    print(":)")

Test [1;32m[PASSED][0m for example 0 / 4
----------------------------------------------------------------------------------
Test [1;32m[PASSED][0m for example 0 / 4
----------------------------------------------------------------------------------
Test [1;32m[PASSED][0m for example 1 / 4
----------------------------------------------------------------------------------
Test [1;32m[PASSED][0m for example 1 / 4
----------------------------------------------------------------------------------
Test [1;32m[PASSED][0m for example 2 / 4
----------------------------------------------------------------------------------
Test [1;32m[PASSED][0m for example 2 / 4
----------------------------------------------------------------------------------
Test [1;32m[PASSED][0m for example 3 / 4
----------------------------------------------------------------------------------
Test [1;32m[PASSED][0m for example 3 / 4
---------------------------------------------------------------------------

### Subquestion g) - Log-likelihood

In [28]:
def log_likelihood(source_alphabet, target_alphabet, source_str, target_str):
    """ Takes as input the source and target alphabets (sets of characters)
    and a pair of source and target strings and returns the log-likelihood
    """
    Tx = string_fst(source_str)
    Ty = string_fst(target_str)
    T = edit_distance_fst(source_alphabet, target_alphabet)    
    Tx_T = implicit_compose(Tx, T)
    Tx_T_Ty = implicit_compose(Tx_T, Ty)
    return normalizer(Tx_T_Ty).score - normalizer(Tx_T).score  # Division in log space

In [29]:
log_likelihood(source_alphabet={'y', 'e', 's'}, target_alphabet={'j', 'a', 'b'}, source_str='yes', target_str='ja')

-3.6566107459570265

### Subquestion h) - Decoding

In [30]:
class Arctic(Tropical):
    pass

In [31]:
def to_semiring(fst, semiring):
    """Lifts an WFST into a different semiring.
    Modified to account for the Arctic semiring case.
    """
    oone, one = fst.R.one, semiring.one
    nfst = FST(semiring)
    if semiring == Arctic:  # Negate weights
        score_multiplier = -1
    else:
        score_multiplier = 1

    # initial weights
    for q, w in fst.I:
        if w == oone:
            nfst.set_I(q, one)
        else:
            nfst.set_I(q, semiring(score_multiplier * w.score))

    # final weights
    for q, w in fst.F:
        if w == oone:
            nfst.set_F(q, one)
        else:
            nfst.set_F(q, semiring(score_multiplier * w.score))

    # arcs
    for p in fst.Q:
        for (a, b), q, w in fst.arcs(p):
            if w == oone:
                nfst.add_arc(p, a, b, q, one)
            else:
                nfst.add_arc(p, a, b, q, semiring(score_multiplier * w.score))

    return nfst

In [32]:
def decode(in_str, edit_distance_fst):
    """ Gets as input a source string and an
    edit-distance transducer and returns the weight of the most
    probable target string
    """
    Tx = string_fst(in_str)
    T = edit_distance_fst
    Tx_T = implicit_compose(Tx, T)
    return -1 * normalizer(to_semiring(Tx_T, Arctic)).score


In [33]:
decode("yes", edit_fst)

-inf