In [2]:
from nltk import PCFG, Nonterminal
from nltk.parse.generate import generate

def equal_production(terminals, total=1):
    """Shorthand to write a list of terminals that are all equally likely"""
    terminals = terminals.split(' | ')
    return ' | '.join([f"'{terminal}' [{total/len(terminals)}]" for terminal in terminals])

equal_production('man | woman | girl')

"'man' [0.3333333333333333] | 'woman' [0.3333333333333333] | 'girl' [0.3333333333333333]"

In [3]:
import random

# Randomly generate sentences using CFG
def weighted_choice(choices):
    total = sum(w for c, w in choices)
    r = random.uniform(0, total)
    upto = 0
    for c, w in choices:
        if upto + w >= r:
            return c
        upto += w

def generate_sentence(grammar, symbol=Nonterminal('S')):
    productions = grammar.productions(lhs=symbol)
    chosen_prod = weighted_choice([(prod, prod.prob()) for prod in productions])
    
    sentence = []
    # print(symbol)
    for sym in chosen_prod.rhs():
        if isinstance(sym, Nonterminal):
            sentence.extend(generate_sentence(grammar, sym))
        else:
            sentence.append(sym)
            
    return sentence

# Morphology
from pyfoma import *

fsts = {}
fsts['lex'] = FST.re("[a-zA-Z\+]*")

fsts['sib']       = FST.re("s|sh|z|zh|ch|x")
fsts['C']         = FST.re("[a-z] - [aeiou]")
fsts['sibrk']     = FST.re("$^rewrite('':e / $sib _ \+ s)", fsts)
fsts['yrule']     = FST.re("$^rewrite(y:(ie) / $C _ \+ s)", fsts)
fsts['cleanup']   = FST.re("$^rewrite(\+:'')")
fsts['grammar']   = FST.re("$lex @ $sibrk @ $yrule @ $cleanup", fsts)

def fix_morphology(words):
    """Combines morpheme clusters into proper words using an FST"""
    combined_words = []
    for word in words:
        if word[0] == "+":
            combined_words[-1] += word
        else:
            combined_words.append(word)
    return [list(fsts['grammar'].generate(word))[0] for word in combined_words]

def sample_sentences(grammar, n):
    sents = [generate_sentence(grammar) for _ in range(n)]
    sents = [' '.join(fix_morphology(sent)) for sent in sents]
    # sents = [sent.capitalize() + "." for sent in sents]
    return sents

In [8]:
pcfg = PCFG.fromstring(f"""
S             -> NP_3Sg_nom VP_3Sg [0.5] | NP_nom VP [0.5]

VP_3Sg        -> VT '+s' NP_acc [0.475] | VI '+s' [0.475] | VP_3Sg 'and' VP_3Sg [0.05]
VP            -> VT      NP_acc [0.475] | VI      [0.475] | VP     'and' VP     [0.05]

NP_3Sg_nom    -> 'he' [0.25] | 'she' [0.25] | NP_common_Sg [0.5]
NP_common_Sg  -> Det_Sg N_bar_common_Sg [1]
Det_Sg        -> {equal_production('the | a')}

NP_nom        -> {equal_production('I | you | we | they', total=0.5)} | NP_common_Pl [0.5]
NP_common_Pl  -> Det_Pl N_bar_common_Pl [0.8] | NP_common_Pl 'and' NP_common_Pl [0.2]
Det_Pl        -> {equal_production('the | those | these')}

NP_acc        -> {equal_production('me | you | us | them', total=0.30)} | NP_common_Pl [0.35] | NP_common_Sg [0.35]

N_bar_common_Sg  -> Adj N_bar_common_Sg [0.2] | N_common Rel_Sg [0.2] | N_common [0.6]
N_bar_common_Pl  -> Adj N_bar_common_Pl [0.2] | N_common '+s' Rel_Pl [0.15] | N_common '+s' [0.65]
N_common      -> {equal_production('girl | boy | cat | turtle | asparagus | duck | cheese | dude | rabbit | wug | linguist | physicist | lady | dog | cat | bird')}

Rel_Sg         -> 'that' VP_3Sg [1]
Rel_Pl         -> 'that' VP [1]

VI            -> {equal_production('run | walk | think | laugh | ponder')}
VT            -> {equal_production('kick | kiss | hug | punch | fight | love')}

Adj           -> {equal_production('big | small | happy | mad | red | blue | sparkling | shiny')}
""")

sample_sentences(pcfg, 20)

['the dude walks',
 'I hug them',
 'those cats kiss those ducks and the ladies',
 'the linguist that punches these asparaguses thinks',
 'he fights a cat',
 'they love a linguist',
 'the physicist thinks',
 'the girl laughs',
 'I walk',
 'those shiny small dogs and these ladies and those asparaguses and the dudes that run and the cats and those boys ponder',
 'he thinks',
 'these girls fight a cat',
 'they ponder',
 'those ladies run',
 'the lady that runs punches those boys',
 'the asparaguses that punch the ladies run',
 'a asparagus that laughs runs',
 'he runs',
 'a linguist punches us',
 'we run']

In [6]:
agreement_violations = PCFG.fromstring(f"""
S             -> NP_3Sg_nom VP_3Sg [0.5] | NP_nom VP [0.5]

VP_3Sg        -> VT      NP_acc [0.475] | VI      [0.475] | VP_3Sg 'and' VP_3Sg [0.05]
VP            -> VT '+s' NP_acc [0.475] | VI '+s' [0.475] | VP     'and' VP     [0.05]

NP_3Sg_nom    -> 'he' [0.25] | 'she' [0.25] | NP_common_Sg [0.5]
NP_common_Sg  -> Det_Sg N_bar_common_Sg [1]
Det_Sg        -> {equal_production('the | a')}

NP_nom        -> {equal_production('I | you | we | they', total=0.5)} | NP_common_Pl [0.5]
NP_common_Pl  -> Det_Pl N_bar_common_Pl [0.8] | NP_common_Pl 'and' NP_common_Pl [0.2]
Det_Pl        -> {equal_production('the | those | these')}

NP_acc        -> {equal_production('me | you | us | them', total=0.30)} | NP_common_Pl [0.35] | NP_common_Sg [0.35]

N_bar_common_Sg  -> Adj N_bar_common_Sg [0.2] | N_common Rel_Sg [0.15] | N_common [0.65]
N_bar_common_Pl  -> Adj N_bar_common_Pl [0.2] | N_common '+s' Rel_Pl [0.15] | N_common '+s' [0.65]
N_common      -> {equal_production('girl | boy | cat | turtle | rutabaga | duck | cheese | dude | rabbit | wug | linguist | physicist | lady | dog | cat | bird')}

Rel_Sg         -> 'that' VP_3Sg [1]
Rel_Pl         -> 'that' VP [1]

VI            -> {equal_production('run | walk | think | laugh | ponder')}
VT            -> {equal_production('kick | kiss | hug | punch | fight | love')}

Adj           -> {equal_production('big | small | happy | mad | red | blue | sparkling | shiny')}
""")

sample_sentences(agreement_violations, 20)

['she walk',
 'I kisses these cats',
 'these ducks hugs those blue cheeses',
 'a linguist punch you',
 'I fights a cat',
 'I hugs those rabbits',
 'a linguist ponder',
 'these turtles kisses us',
 'these ladies hugs the linguist',
 'I kicks those cheeses and these dogs and walks',
 'the sparkling dudes runs',
 'he hug those cats',
 'she walk and love the dudes and these girls and punch you',
 'she punch the dogs that kisses us',
 'he laugh',
 'they ponders',
 'you fights us',
 'she ponder',
 'those cats that punches the dog that walk loves us',
 'I loves a cat']

In [14]:
import datasets

random.seed(1)
valid_num = 2000
invalid_num = 2000
valid = sample_sentences(pcfg, valid_num)
invalid = sample_sentences(agreement_violations, invalid_num)

dataset = datasets.Dataset.from_dict({"text": valid + invalid, "labels": [1] * valid_num + [0] * invalid_num}).shuffle()

dataset = dataset.class_encode_column('labels')
dataset.train_test_split(test_size=0.3, stratify_by_column='labels')
dataset.push_to_hub("michaelginn/latent-trees-agreement-ID")

Stringifying the column:   0%|          | 0/4000 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/4000 [00:00<?, ? examples/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading metadata:   0%|          | 0.00/547 [00:00<?, ?B/s]

In [None]:
def check_linear_heuristic(sentence: str):
    for word in sentence.split(' '):
