In [1]:
from nltk import PCFG, Nonterminal
from nltk.parse.generate import generate

def equal_production(terminals, total=1):
    """Shorthand to write a list of terminals that are all equally likely.

    Builds the right-hand side of an NLTK PCFG rule in which every terminal
    gets probability ``total / k`` (``k`` = number of terminals), e.g.
    ``equal_production('a | b')`` -> ``"'a' [0.5] | 'b' [0.5]"``.

    Args:
        terminals: terminals separated by ``|``; whitespace around each
            terminal is ignored, so ``'a|b'`` and ``'a | b'`` are equivalent.
        total: probability mass shared among the terminals (default 1), so the
            result can be combined with other alternatives of the same rule.

    Returns:
        A string of ``'terminal' [p]`` alternatives joined by ``' | '``.

    Raises:
        ValueError: if ``terminals`` contains no terminal.
    """
    # Split on bare '|' and strip, so irregular spacing cannot silently merge
    # several terminals into one.
    terminal_list = [t.strip() for t in terminals.split('|') if t.strip()]
    if not terminal_list:
        raise ValueError("equal_production() needs at least one terminal")
    prob = total / len(terminal_list)
    return ' | '.join(f"'{terminal}' [{prob}]" for terminal in terminal_list)

equal_production('man | woman | girl')

"'man' [0.3333333333333333] | 'woman' [0.3333333333333333] | 'girl' [0.3333333333333333]"

In [22]:
import random
from tqdm import tqdm

# Randomly generate sentences using CFG
def weighted_choice(choices):
    """Pick one item from ``choices`` with probability proportional to its weight.

    Args:
        choices: list of ``(item, weight)`` pairs with non-negative weights.

    Returns:
        The selected item.

    Raises:
        ValueError: if ``choices`` is empty.
    """
    if not choices:
        raise ValueError("weighted_choice() requires at least one choice")
    total = sum(w for _, w in choices)
    r = random.uniform(0, total)
    upto = 0
    for c, w in choices:
        if upto + w >= r:
            return c
        upto += w
    # Floating-point rounding can leave r above every running cumulative weight
    # (e.g. sum() uses compensated summation on Python 3.12+, so `total` may
    # exceed the naive running `upto`). Fall back to the last item with positive
    # weight instead of silently returning None.
    for c, w in reversed(choices):
        if w > 0:
            return c
    return choices[-1][0]

def generate_sentence(grammar, symbol=Nonterminal('S')):
    """Expand ``symbol`` top-down by sampling productions of a PCFG.

    At each nonterminal one production is chosen with ``weighted_choice``,
    weighted by the production probabilities, and its right-hand side is
    expanded left to right.

    Args:
        grammar: an NLTK ``PCFG`` (anything exposing ``productions(lhs=...)``
            whose productions provide ``prob()`` and ``rhs()``).
        symbol: the nonterminal to expand; defaults to ``Nonterminal('S')``.

    Returns:
        A flat list of terminal strings.
    """
    candidates = [(production, production.prob())
                  for production in grammar.productions(lhs=symbol)]
    expansion = weighted_choice(candidates)

    tokens = []
    for child in expansion.rhs():
        if isinstance(child, Nonterminal):
            tokens += generate_sentence(grammar, child)
        else:
            tokens.append(child)
    return tokens

# Morphology
from pyfoma import *

# Finite-state morphology: a lexicon over letters and the '+' morpheme boundary,
# composed with spelling rules and a final cleanup.
# Patterns containing "\+" are raw strings: "\+" is an invalid escape sequence in
# a normal Python string (SyntaxWarning on 3.12+); r"\+" has the same two chars.
fsts = {}
fsts['lex'] = FST.re(r"[a-zA-Z\+]*")

fsts['sib']       = FST.re("s|sh|z|zh|ch|x")
fsts['C']         = FST.re("[a-z] - [aeiou]")
# Insert 'e' between a sibilant-final stem and '+s' (kiss+s -> kisses).
fsts['sibrk']     = FST.re(r"$^rewrite('':e / $sib _ \+ s)", fsts)
# Rewrite 'y' after a consonant as 'ie' before '+s' (lady+s -> ladies).
fsts['yrule']     = FST.re(r"$^rewrite(y:(ie) / $C _ \+ s)", fsts)
# Delete the morpheme-boundary marker.
fsts['cleanup']   = FST.re(r"$^rewrite(\+:'')")
fsts['grammar']   = FST.re("$lex @ $sibrk @ $yrule @ $cleanup", fsts)

def fix_morphology(words):
    """Combines morpheme clusters into proper words using an FST.

    Tokens starting with ``+`` (e.g. ``'+s'``) are glued onto the preceding
    token, then each resulting word is run through ``fsts['grammar']`` and the
    first form it generates is kept.

    Args:
        words: list of tokens, as produced by ``generate_sentence``.

    Returns:
        List of surface word strings.

    Raises:
        ValueError: if a suffix token has no preceding stem, or the FST
            produces no output for a word.
    """
    combined_words = []
    for word in words:
        # startswith() is safe on empty tokens, unlike word[0].
        if word.startswith("+"):
            if not combined_words:
                raise ValueError(f"suffix {word!r} has no preceding stem")
            combined_words[-1] += word
        else:
            combined_words.append(word)

    grammar_fst = fsts['grammar']
    surface_words = []
    for word in combined_words:
        # Take only the first generated form instead of materialising them all.
        surface = next(iter(grammar_fst.generate(word)), None)
        if surface is None:
            raise ValueError(f"morphology FST produced no output for {word!r}")
        surface_words.append(surface)
    return surface_words

def sample_sentences(grammar, n):
    """Draw ``n`` random sentences from ``grammar`` and render them as strings.

    All sentences are generated first (with a tqdm progress bar), then each
    token list is passed through ``fix_morphology`` and joined with spaces.
    """
    raw_sentences = []
    for _ in tqdm(range(n)):
        raw_sentences.append(generate_sentence(grammar))
    return [' '.join(fix_morphology(tokens)) for tokens in raw_sentences]

In [23]:
# Probabilistic grammar for simple clauses with subject-verb agreement.
# Third-person-singular subjects (NP_3Sg_nom) pair with VP_3Sg, whose verbs take
# the '+s' suffix; other subjects (NP_nom) pair with bare-form VP. Relative
# clauses attached to a singular/plural noun use Rel_Sg / Rel_Pl respectively.
# '+s' tokens are suffix morphemes merged onto their stem by fix_morphology.
# Each noun is listed once in N_common so all nouns are equally likely
# (a duplicated entry would get double weight).
pcfg = PCFG.fromstring(f"""
S             -> NP_3Sg_nom VP_3Sg [0.5] | NP_nom VP [0.5]

VP_3Sg        -> VT '+s' NP_acc [0.475] | VI '+s' [0.475] | VP_3Sg 'and' VP_3Sg [0.05]
VP            -> VT      NP_acc [0.475] | VI      [0.475] | VP     'and' VP     [0.05]

NP_3Sg_nom    -> 'he' [0.25] | 'she' [0.25] | NP_common_Sg [0.5]
NP_common_Sg  -> Det_Sg N_bar_common_Sg [1]
Det_Sg        -> {equal_production('the | a')}

NP_nom        -> {equal_production('I | you | we | they', total=0.5)} | NP_common_Pl [0.5]
NP_common_Pl  -> Det_Pl N_bar_common_Pl [0.8] | NP_common_Pl 'and' NP_common_Pl [0.2]
Det_Pl        -> {equal_production('the | those | these')}

NP_acc        -> {equal_production('me | you | us | them', total=0.30)} | NP_common_Pl [0.35] | NP_common_Sg [0.35]

N_bar_common_Sg  -> Adj N_bar_common_Sg [0.2] | N_common Rel_Sg [0.2] | N_common [0.6]
N_bar_common_Pl  -> Adj N_bar_common_Pl [0.2] | N_common '+s' Rel_Pl [0.15] | N_common '+s' [0.65]
N_common      -> {equal_production('girl | boy | cat | turtle | rutabaga | duck | cheese | dude | rabbit | wug | linguist | physicist | lady | dog | bird')}

Rel_Sg         -> 'that' VP_3Sg [1]
Rel_Pl         -> 'that' VP [1]

VI            -> {equal_production('run | walk | think | laugh | ponder')}
VT            -> {equal_production('kick | kiss | hug | punch | fight | love')}

Adj           -> {equal_production('big | small | happy | mad | red | blue | sparkling | shiny')}
""")

sample_sentences(pcfg, 20)

100%|██████████| 20/20 [00:00<00:00, 23039.30it/s]


['those wugs love a boy',
 'the bird thinks',
 'a wug walks',
 'they love a dude',
 'a girl hugs them',
 'these wugs kick us',
 'the happy cheese laughs',
 'she kicks the girl',
 'a lady that ponders loves a cat',
 'they run',
 'a rabbit fights these sparkling dogs',
 'a cat ponders',
 'he fights me',
 'the cat laughs',
 'the wug ponders',
 'she loves me',
 'the sparkling cats that run kiss these rabbits',
 'a cheese that kisses these mad girls runs',
 'I walk and laugh',
 'we ponder']

In [46]:
# agreement_violations = PCFG.fromstring(f"""
# S             -> NP_3Sg_nom VP_3Sg [0.5] | NP_nom VP [0.5]
#
# VP_3Sg        -> VT      NP_acc [0.475] | VI      [0.475] | VP_3Sg 'and' VP_3Sg [0.05]
# VP            -> VT '+s' NP_acc [0.475] | VI '+s' [0.475] | VP     'and' VP     [0.05]
#
# NP_3Sg_nom    -> 'he' [0.25] | 'she' [0.25] | NP_common_Sg [0.5]
# NP_common_Sg  -> Det_Sg N_bar_common_Sg [1]
# Det_Sg        -> {equal_production('the | a')}
#
# NP_nom        -> {equal_production('I | you | we | they', total=0.5)} | NP_common_Pl [0.5]
# NP_common_Pl  -> Det_Pl N_bar_common_Pl [0.8] | NP_common_Pl 'and' NP_common_Pl [0.2]
# Det_Pl        -> {equal_production('the | those | these')}
#
# NP_acc        -> {equal_production('me | you | us | them', total=0.30)} | NP_common_Pl [0.35] | NP_common_Sg [0.35]
#
# N_bar_common_Sg  -> Adj N_bar_common_Sg [0.2] | N_common Rel_Sg [0.15] | N_common [0.65]
# N_bar_common_Pl  -> Adj N_bar_common_Pl [0.2] | N_common '+s' Rel_Pl [0.15] | N_common '+s' [0.65]
# N_common      -> {equal_production('girl | boy | cat | turtle | rutabaga | duck | cheese | dude | rabbit | wug | linguist | physicist | lady | dog | cat | bird')}
#
# Rel_Sg         -> 'that' VP_3Sg [1]
# Rel_Pl         -> 'that' VP [1]
#
# VI            -> {equal_production('run | walk | think | laugh | ponder')}
# VT            -> {equal_production('kick | kiss | hug | punch | fight | love')}
#
# Adj           -> {equal_production('big | small | happy | mad | red | blue | sparkling | shiny')}
# """)
#
# sample_sentences(agreement_violations, 20)

verb_sing = ["runs", "walks", "thinks", "laughs", "ponders", "kicks", "kisses", "hugs", "punches", "fights", "loves"]
verb_pl = ["run", "walk", "think", "laugh", "ponder", "kick", "kiss", "hug", "punch", "fight", "love"]

def deform_sentence(sentence: str):
    """Deforms a sentence by randomly switching one or more verbs from singular to plural or vice versa.

    A random, non-empty subset of the verbs (words found in ``verb_sing`` or
    ``verb_pl``) is flipped to the other number, so at least one verb always
    changes. Randomness comes from the ``random`` module.

    Args:
        sentence: space-separated sentence.

    Returns:
        The sentence with the chosen verbs swapped.

    Raises:
        ValueError: if the sentence contains no verb from the lists.
    """
    # Map each verb form to its counterpart in the other number (runs <-> run).
    swap = dict(zip(verb_sing, verb_pl))
    swap.update(zip(verb_pl, verb_sing))

    words = sentence.split(' ')
    verb_indices = [i for i, word in enumerate(words) if word in swap]
    if not verb_indices:
        raise ValueError(f"no verb from verb_sing/verb_pl found in {sentence!r}")

    indices_to_deform = random.sample(verb_indices, random.randint(1, len(verb_indices)))
    for index in indices_to_deform:
        words[index] = swap[words[index]]

    return ' '.join(words)

deform_sentence('the cat that runs ponders and hugs the duck')

'the cat that runs ponder and hug the duck'

In [59]:
import datasets

# Python's `random` drives sentence sampling and deformation, but it does not
# control `datasets` shuffles/splits, so those get an explicit seed as well.
SEED = 1
random.seed(SEED)
valid_num = 2000
invalid_num = 2000
valid = sample_sentences(pcfg, valid_num)
invalid = [deform_sentence(sentence) for sentence in sample_sentences(pcfg, invalid_num)]

dataset = datasets.Dataset.from_dict({"text": valid + invalid, "labels": [1] * valid_num + [0] * invalid_num}).shuffle(seed=SEED)

dataset = dataset.class_encode_column('labels')
# 60% train; the remaining 40% is halved into eval and test (20% each).
dataset = dataset.train_test_split(test_size=0.4, stratify_by_column='labels', seed=SEED)
dataset_eval_test = dataset['test'].train_test_split(test_size=0.5, stratify_by_column='labels', seed=SEED)
dataset = datasets.DatasetDict({
    'train': dataset['train'],
    'eval': dataset_eval_test['train'],
    'test': dataset_eval_test['test']
})

dataset.push_to_hub("michaelginn/latent-trees-agreement-ID")

100%|██████████| 2000/2000 [00:00<00:00, 39714.84it/s]
100%|██████████| 2000/2000 [00:00<00:00, 54913.28it/s]


Stringifying the column:   0%|          | 0/4000 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/4000 [00:00<?, ? examples/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading metadata:   0%|          | 0.00/749 [00:00<?, ?B/s]

In [60]:
# Words grouped by the verb form they take: "he"/"she" take the -s form, while
# "I"/"you"/"we"/"they" and the object pronouns are grouped with the plurals.
noun_sing = ["girl", "boy", "cat", "turtle", "rutabaga", "duck", "cheese", "dude", "rabbit", "wug", "linguist", "physicist", "lady", "dog", "bird", "he", "she"]
noun_pl = ["girls", "boys", "cats", "turtles", "rutabagas", "ducks", "cheeses", "dudes", "rabbits", "wugs", "linguists", "physicists", "ladies", "dogs", "birds", "I", "you",  "me", "us", "them", "they", "we"]

# Same verb lists as the deformation cell, repeated so this cell is self-contained.
verb_sing = ["runs", "walks", "thinks", "laughs", "ponders", "kicks", "kisses", "hugs", "punches", "fights", "loves"]
verb_pl = ["run", "walk", "think", "laugh", "ponder", "kick", "kiss", "hug", "punch", "fight", "love"]

def check_linear_heuristic(sentence: str):
    """Check the linear agreement heuristic: each verb agrees with the most recent noun.

    Scans the space-separated words left to right, remembering the number
    ('sg' or 'pl') of the last noun/pronoun seen. Verbs that appear before any
    noun are not checked.

    Returns:
        False as soon as a singular verb follows a plural noun or a plural verb
        follows a singular noun; True otherwise.
    """
    most_recent_noun_num = None  # 'sg' or 'pl' once a noun has been seen

    for word in sentence.split(' '):
        if word in noun_sing:
            most_recent_noun_num = 'sg'
        elif word in noun_pl:
            most_recent_noun_num = 'pl'
        elif word in verb_sing:
            if most_recent_noun_num == 'pl':
                return False
        elif word in verb_pl:
            if most_recent_noun_num == 'sg':
                return False
    return True

print(check_linear_heuristic('the linguist that punches these rutabagas thinks'))
print(check_linear_heuristic('the linguist that punches this rutabaga thinks'))
print(check_linear_heuristic('the linguist that punches these rutabagas think'))
print(check_linear_heuristic('they thinks'))

False
True
True
False


In [61]:
import datasets

random.seed(1)

valid_num = 20000
invalid_num = 200000
valid = sample_sentences(pcfg, valid_num)
invalid = [deform_sentence(sentence) for sentence in sample_sentences(pcfg, invalid_num)]

# Cross-tabulate the true (hierarchical) label against the linear heuristic.
# Sentences where both agree form an ambiguous training pool; sentences where
# they disagree are kept separately for the generalization split.
both_valid, true_valid_but_linear_invalid = [], []
for sent in tqdm(valid):
    bucket = both_valid if check_linear_heuristic(sent) else true_valid_but_linear_invalid
    bucket.append(sent)

both_invalid, true_invalid_but_linear_valid = [], []
for sent in tqdm(invalid):
    bucket = true_invalid_but_linear_valid if check_linear_heuristic(sent) else both_invalid
    bucket.append(sent)

print(f"Passes both heuristics: {len(both_valid)}")
print(f"Fails both heuristics: {len(both_invalid)}")
print(f"Passes hierarch, fails linear: {len(true_valid_but_linear_invalid)}")
print(f"Fails hierarch, passes linear: {len(true_invalid_but_linear_valid)}")

100%|██████████| 20000/20000 [00:00<00:00, 51620.11it/s]
100%|██████████| 200000/200000 [00:03<00:00, 53286.97it/s]
100%|██████████| 20000/20000 [00:00<00:00, 379442.82it/s]
100%|██████████| 200000/200000 [00:00<00:00, 544110.98it/s]

Passes both heuristics: 19116
Fails both heuristics: 198239
Passes hierarch, fails linear: 884
Fails hierarch, passes linear: 1761





In [66]:
# Create generalization scenario:
# train/eval use only sentences where the hierarchical and linear heuristics
# agree; test uses only sentences where they disagree.
SEED = 1  # explicit seed: `datasets` shuffles/splits ignore random.seed()
train_examples = 3200    # train + eval, half valid / half invalid
gen_test_examples = 800  # test, half valid / half invalid
n_train_per_class = train_examples // 2
n_test_per_class = gen_test_examples // 2

for pool_name, pool, needed in [
    ("both_valid", both_valid, n_train_per_class),
    ("both_invalid", both_invalid, n_train_per_class),
    ("true_valid_but_linear_invalid", true_valid_but_linear_invalid, n_test_per_class),
    ("true_invalid_but_linear_valid", true_invalid_but_linear_valid, n_test_per_class),
]:
    assert len(pool) >= needed, f"{pool_name} has {len(pool)} sentences, need {needed}"

dataset_gen_train = datasets.Dataset.from_dict({
    "text": both_valid[:n_train_per_class] + both_invalid[:n_train_per_class],
    "labels": [1] * n_train_per_class + [0] * n_train_per_class,
}).shuffle(seed=SEED)
dataset_gen_eval = datasets.Dataset.from_dict({
    "text": true_valid_but_linear_invalid[:n_test_per_class] + true_invalid_but_linear_valid[:n_test_per_class],
    "labels": [1] * n_test_per_class + [0] * n_test_per_class,
}).shuffle(seed=SEED)

# 75% / 25% stratified split of the agreeing sentences into train and eval.
dataset_gen_train_eval = dataset_gen_train.class_encode_column('labels').train_test_split(test_size=0.25, stratify_by_column='labels', seed=SEED)
dataset_gen_test = dataset_gen_eval.class_encode_column('labels')

gen_dataset_dict = datasets.DatasetDict({
    'train': dataset_gen_train_eval['train'],
    'eval': dataset_gen_train_eval['test'],
    'test': dataset_gen_test
})
gen_dataset_dict.push_to_hub("michaelginn/latent-trees-agreement-GEN")

Stringifying the column:   0%|          | 0/3200 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/3200 [00:00<?, ? examples/s]

Stringifying the column:   0%|          | 0/800 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/800 [00:00<?, ? examples/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading metadata:   0%|          | 0.00/748 [00:00<?, ?B/s]

In [67]:
# Inspect the split sizes of the generalization dataset (train / eval / test).
gen_dataset_dict

DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 2400
    })
    eval: Dataset({
        features: ['text', 'labels'],
        num_rows: 800
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 800
    })
})

In [64]:
# Inspect the split sizes of the in-distribution dataset pushed as latent-trees-agreement-ID.
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 2400
    })
    eval: Dataset({
        features: ['text', 'labels'],
        num_rows: 800
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 800
    })
})