# Learnability Project: PFA Phonotactic Learner

------


# General functions and other stuff

In [11]:
from pyfoma import FST, State
from math import log, exp
import csv
import numpy as np

lam = "λ"


def data_iterator(path):
    """Yields a word as a list of tokens from a corpus"""
    with open (path, "r") as fin:
        for line in fin.readlines():
            yield line.split()
              
def make_alphabet(path) -> list:
    """This function extracts an alphabet from a corpusl"""
    alph = []
    for word in data_iterator(path):
        for sym in word:
            if sym not in alph:
                alph.append(sym)
    return alph

# Strictly 2-Local

In [2]:
def make_SL2_dfa(alphabet:list) -> list[FST]:
    """This function initializes a SL2 DFA according to an alphabet. All weights are set to 0 by default"""
    
    # create the FST object and instantiate the initial state
    dfa = FST()
    q0 = dfa.initialstate
    q0.finalweight = 0
    q0.name = lam
    states = {q0,}
    
    # initialize all states and connect q0 to them
    for sym in alphabet:
        q = State()
        q.name = sym
        q0.add_transition(q, sym, 0)
        states.add(q)
        q.finalweight = 0
    
    # loopify the states, excluding q0
    states_no_q0 = set([s for s in states if s.name != lam])
    for state1 in states_no_q0:
        for state2 in states_no_q0:
            state1.add_transition(state2, state2.name, 0)
    
    dfa.states      = states
    dfa.finalstates = states
    dfa.alphabet    = alphabet
    
    return [dfa]

# Strictly 2-Piecewise

In [3]:
def make_SP2_dfas(alphabet:list) -> list[FST]:
    """This function intitializes SP2 machines according the length of the alphabet. Weights are 0 by default"""
    SP2_machines = []
    
    # initializes all 2 state machines and loopify them
    for sym1 in alphabet:
        dfa             = FST()
        q0              = dfa.initialstate
        q0.finalweight  = 0
        q0.name         = lam
        q1              = State()
        q1.finalweight  = 0
        q1.name         = sym1
        dfa.states      = {q0, q1}
        dfa.finalstates = {q0, q1}
        dfa.alphabet    = alphabet
        q0.add_transition(q1, sym1, 0)
        
        # loopification
        for sym2 in alphabet:
            if sym1 != sym2:
                q0.add_transition(q0, sym2, 0)
            
            q1.add_transition(q1, sym2, 0)
        
        SP2_machines.append(dfa)
     
    return SP2_machines

# MLE

In [4]:
def update_transitions(dfas:list[FST], path:str) -> list[FST]:
    """This function updates one or more DFAs with relative frequencies from a given corpus"""
    for word in data_iterator(path):
        for dfa in dfas:
            cs = dfa.initialstate # gets set back to q0 when a new word is processed
            
            for token in word:
                for _, trans in cs.all_transitions():

                    if trans.label == token: # find the transition that matches the current token
                        trans.weight += 1
                        
                    if token == trans.targetstate.name: # update cs to transition's target state. This is how we "traverse" the states
                        cs = trans.targetstate
                                 
            cs.finalweight += 1 # update the weight of the last state the current string ends in
            
    return dfas

def get_trans_sum(state):
    """sums all the transition weights (counts) and the finalweight for given state """
    return sum([tran.weight for _, tran in state.all_transitions()])


def assert_sum_to_1(dfa:FST) -> None:
    """This function asserts that each state's transition weights sum to 1 for given DFA"""
    for state in dfa.states:
        try:
            total = get_trans_sum(state) + state.finalweight
            assert total > .99 # accounts for floating point precision

        except AssertionError:
            print(f"State: {state.name} is misbehaving!\nWeighted sum: {total} != 1")
            raise RuntimeError

def MLE(dfas:list[FST], path:str) -> list[FST]:
    """This function takes one or more dfas, updates the weights from a corpus, and normalizes"""
    dfas = update_transitions(dfas, path)
    
    for dfa in dfas:
        all_state_trans_sums = [get_trans_sum(state) + state.finalweight for state in dfa.states]# list of state weight sums for current dfa

        for i, state in enumerate(dfa.states):
            for _, trans in state.all_transitions():

                #  normalize transition weight by sum of the state's transition weights
                trans.weight = (trans.weight) / all_state_trans_sums[i]

            state.finalweight = (state.finalweight / all_state_trans_sums[i])

        assert_sum_to_1(dfa)
    return dfas

# Evaluation

For the datasets, the evaluation procedures are as follows:
- For the toy dataset, in the `TestingData_toy.txt` file, there are ten (legal, illegal) word pairs. Performance is measured based on the average of the **legal-illegal difference** scores. A more positive value indicates that the legal string is receiving higher probability (Dai & Futrell, 2022).

- For the Navajo and Quechua datasets, the `TestingData_navajo.txt` and `TestingData_quecha.txt` contain either legal or illegal words. To measure performance, the sum of the log likelihood of all legal words is compared to the sum of the log likelihood of all illegal words. The hypothesis is that the sum of legal words should be *greater than* the sum of illegal words.

In [33]:
# Learning data paths
toy     = "data/LearningData_toy.txt"
navajo  = "data/LearningData_navajo.txt"
quechua = "data/LearningData_quechua.txt"

toy_alph     = make_alphabet(toy)
navajo_alph  = make_alphabet(navajo)
quechua_alph = make_alphabet(quechua)

def process_word(dfas:list[FST], word:list[str]):
    """Retrieves the log likelihood of a single word given one or more DFAs"""
    logprobs = []
    
    for dfa in dfas:
        cs = dfa.initialstate 
        for token in word:
            for _, trans in cs.all_transitions():
                
                if trans.label == token:
                    try:
                        logprobs.append(log(trans.weight))
                    except ValueError: # if sequence doesn't exist, trans.weight == 0, so logging this crashes
                        logprobs.append(trans.weight)
                        
                if token == trans.targetstate.name: 
                    cs = trans.targetstate
                
    return sum(logprobs)

def read_eval(path, delim):
    
    with open(path, "r") as fin:
        reader = csv.reader(fin, delimiter=delim)
        for row in reader:
            yield row  

## Toy 

In [34]:
toy_SL2     = make_SL2_dfa(toy_alph)
toy_SL2_MLE = MLE(toy_SL2, toy)

toy_SP2     = make_SP2_dfas(toy_alph)
toy_SP2_MLE = MLE(toy_SP2, toy)

In [36]:
SL2_scores = []
SP2_scores = []

for legal, illegal in read_eval("data/TestingData_toy.txt", delim=","):
    SL2_scores.append(process_word(toy_SL2_MLE, legal) - process_word(toy_SL2_MLE, illegal))
    SP2_scores.append(process_word(toy_SP2_MLE, legal) - process_word(toy_SP2_MLE, illegal))
    

print(f"SL2 Score: {np.mean(SL2_scores)}")
print(f"SP2 Score: {np.mean(SP2_scores)}")

SL2 Score: -0.6333024198097303
SP2 Score: -0.5260450150019949


## Navajo

In [9]:
nav_SL2     = make_SL2_dfa(navajo_alph)
nav_SL2_MLE = MLE(nav_SL2, navajo)

nav_SP2     = make_SP2_dfas(navajo_alph)
nav_SP2_MLE = MLE(nav_SP2, navajo)

In [43]:
legal_SL2_scores = []
illegal_SL2_scores = []

legal_SP2_scores = []
illegal_SP2_scores = []

for word, label in read_eval("data/TestingData_navajo.txt", delim="\t"):
    
    if label == "legal":
        legal_SL2_scores.append(process_word(nav_SL2_MLE, word))
        legal_SP2_scores.append(process_word(nav_SP2_MLE, word))
    elif label == "illegal":
        illegal_SL2_scores.append(process_word(nav_SL2_MLE, word))
        illegal_SP2_scores.append(process_word(nav_SP2_MLE, word))
        
print(f"SL2: \nLegal word score: {sum(legal_SL2_scores)}\nIllegal word score: {sum(illegal_SL2_scores)}\n\n")
print(f"SP2: \nLegal word score: {sum(legal_SP2_scores)}\nIllegal word score: {sum(illegal_SP2_scores)}")

SL2: 
Legal word score: -35156.98666530183
Illegal word score: -66877.02525153007


SP2: 
Legal word score: -2602164.941207403
Illegal word score: -4872983.739561507


## Quechua 

In [10]:
que_SL2     = make_SL2_dfa(quechua_alph)
que_SL2_MLE = MLE(que_SL2, quechua)

que_SP2     = make_SP2_dfas(quechua_alph)
que_SP2_MLE = MLE(que_SP2, quechua)

In [44]:
legal_SL2_scores = []
illegal_SL2_scores = []

legal_SP2_scores = []
illegal_SP2_scores = []

for word, label in read_eval("data/TestingData_quechua.txt", delim="\t"):
    
    if label.startswith("legal"):
        legal_SL2_scores.append(process_word(nav_SL2_MLE, word))
        legal_SP2_scores.append(process_word(nav_SP2_MLE, word))
    elif label.startswith("illegal"):
        illegal_SL2_scores.append(process_word(nav_SL2_MLE, word))
        illegal_SP2_scores.append(process_word(nav_SP2_MLE, word))
        
print(f"SL2: \nLegal word score: {sum(legal_SL2_scores)}\nIllegal word score: {sum(illegal_SL2_scores)}\n\n")
print(f"SP2: \nLegal word score: {sum(legal_SP2_scores)}\nIllegal word score: {sum(illegal_SP2_scores)}")

SL2: 
Legal word score: -155396.00004299358
Illegal word score: -43021.74540275179


SP2: 
Legal word score: -8449660.308918558
Illegal word score: -2227730.152033898
