In [1]:
from nltk import PCFG, Nonterminal
from nltk.parse.generate import generate

def equal_production(terminals, total=1):
    """Shorthand to write a list of terminals that are all equally likely"""
    terminals = terminals.split(' | ')
    return ' | '.join([f"'{terminal}' [{total/len(terminals)}]" for terminal in terminals])

equal_production('man | woman | girl')

"'man' [0.3333333333333333] | 'woman' [0.3333333333333333] | 'girl' [0.3333333333333333]"

In [33]:
import random
from tqdm import tqdm

# Randomly generate sentences using CFG
def weighted_choice(choices):
    total = sum(w for c, w in choices)
    r = random.uniform(0, total)
    upto = 0
    for c, w in choices:
        if upto + w >= r:
            return c
        upto += w

def generate_sentence(grammar, symbol=Nonterminal('S'), depth=0):
    productions = grammar.productions(lhs=symbol)
    chosen_prod = weighted_choice([(prod, prod.prob()) for prod in productions])
    
    sentence = []
    max_depth = depth

    # print(symbol)
    for sym in chosen_prod.rhs():
        if isinstance(sym, Nonterminal):
            nonterminal_is_relative_clause = sym == Nonterminal('Rel_Sg') or sym == Nonterminal('Rel_Pl')
            phrase, new_max_depth = generate_sentence(grammar, sym, depth + 1 if nonterminal_is_relative_clause else depth)
            sentence.extend(phrase)
            max_depth = max(max_depth, new_max_depth)
        else:
            sentence.append(sym)
            
    return sentence, max_depth

# Morphology
from pyfoma import *

fsts = {}
fsts['lex'] = FST.re("[a-zA-Z\+]*")

fsts['sib']       = FST.re("s|sh|z|zh|ch|x")
fsts['C']         = FST.re("[a-z] - [aeiou]")
fsts['sibrk']     = FST.re("$^rewrite('':e / $sib _ \+ s)", fsts)
fsts['yrule']     = FST.re("$^rewrite(y:(ie) / $C _ \+ s)", fsts)
fsts['cleanup']   = FST.re("$^rewrite(\+:'')")
fsts['grammar']   = FST.re("$lex @ $sibrk @ $yrule @ $cleanup", fsts)

def fix_morphology(words):
    """Combines morpheme clusters into proper words using an FST"""
    combined_words = []
    for word in words:
        if word[0] == "+":
            combined_words[-1] += word
        else:
            combined_words.append(word)
    return [list(fsts['grammar'].generate(word))[0] for word in combined_words]

def sample_sentences(grammar, n):
    sents = [generate_sentence(grammar) for _ in tqdm(range(n))]
    sents = [(' '.join(fix_morphology(sent)), max_depth) for (sent, max_depth) in sents]
    # sents = [sent.capitalize() + "." for sent in sents]
    return sents

In [34]:
pcfg = PCFG.fromstring(f"""
S             -> NP_3Sg_nom VP_3Sg [0.5] | NP_nom VP [0.5]

VP_3Sg        -> VT '+s' NP_acc [0.475] | VI '+s' [0.475] | VP_3Sg 'and' VP_3Sg [0.05]
VP            -> VT      NP_acc [0.475] | VI      [0.475] | VP     'and' VP     [0.05]

NP_3Sg_nom    -> 'he' [0.25] | 'she' [0.25] | NP_common_Sg [0.5]
NP_common_Sg  -> Det_Sg N_bar_common_Sg [1]
Det_Sg        -> {equal_production('the | a')}

NP_nom        -> {equal_production('I | you | we | they', total=0.5)} | NP_common_Pl [0.5]
NP_common_Pl  -> Det_Pl N_bar_common_Pl [0.8] | NP_common_Pl 'and' NP_common_Pl [0.2]
Det_Pl        -> {equal_production('the | those | these')}

NP_acc        -> {equal_production('me | you | us | them', total=0.30)} | NP_common_Pl [0.35] | NP_common_Sg [0.35]

N_bar_common_Sg  -> Adj N_bar_common_Sg [0.2] | N_common Rel_Sg [0.2] | N_common [0.6]
N_bar_common_Pl  -> Adj N_bar_common_Pl [0.2] | N_common '+s' Rel_Pl [0.15] | N_common '+s' [0.65]
N_common      -> {equal_production('girl | boy | cat | turtle | rutabaga | duck | cheese | dude | rabbit | wug | linguist | physicist | lady | dog | cat | bird')}

Rel_Sg         -> 'that' VP_3Sg [1]
Rel_Pl         -> 'that' VP [1]

VI            -> {equal_production('run | walk | think | laugh | ponder')}
VT            -> {equal_production('kick | kiss | hug | punch | fight | love')}

Adj           -> {equal_production('big | small | happy | mad | red | blue | sparkling | shiny')}
""")

sample_sentences(pcfg, 20)

100%|██████████| 20/20 [00:00<00:00, 25739.82it/s]


[('she kicks a red red cat', 0),
 ('you run', 0),
 ('a girl that kisses the duck fights the ducks', 1),
 ('he hugs these cats and these cats that think', 1),
 ('they laugh', 0),
 ('he punches me', 0),
 ('she ponders and thinks', 0),
 ('the girls that kick those boys run', 1),
 ('they punch me', 0),
 ('you run', 0),
 ('a physicist punches us', 0),
 ('a cheese that loves the dude and laughs runs', 1),
 ('they laugh and run', 0),
 ('the cats and the rutabagas walk', 0),
 ('he kisses the girls', 0),
 ('a linguist that ponders thinks', 1),
 ('a turtle fights a cat that fights them', 1),
 ('he walks', 0),
 ('you walk', 0),
 ('we love those physicists', 0)]

In [35]:
# agreement_violations = PCFG.fromstring(f"""
# S             -> NP_3Sg_nom VP_3Sg [0.5] | NP_nom VP [0.5]
#
# VP_3Sg        -> VT      NP_acc [0.475] | VI      [0.475] | VP_3Sg 'and' VP_3Sg [0.05]
# VP            -> VT '+s' NP_acc [0.475] | VI '+s' [0.475] | VP     'and' VP     [0.05]
#
# NP_3Sg_nom    -> 'he' [0.25] | 'she' [0.25] | NP_common_Sg [0.5]
# NP_common_Sg  -> Det_Sg N_bar_common_Sg [1]
# Det_Sg        -> {equal_production('the | a')}
#
# NP_nom        -> {equal_production('I | you | we | they', total=0.5)} | NP_common_Pl [0.5]
# NP_common_Pl  -> Det_Pl N_bar_common_Pl [0.8] | NP_common_Pl 'and' NP_common_Pl [0.2]
# Det_Pl        -> {equal_production('the | those | these')}
#
# NP_acc        -> {equal_production('me | you | us | them', total=0.30)} | NP_common_Pl [0.35] | NP_common_Sg [0.35]
#
# N_bar_common_Sg  -> Adj N_bar_common_Sg [0.2] | N_common Rel_Sg [0.15] | N_common [0.65]
# N_bar_common_Pl  -> Adj N_bar_common_Pl [0.2] | N_common '+s' Rel_Pl [0.15] | N_common '+s' [0.65]
# N_common      -> {equal_production('girl | boy | cat | turtle | rutabaga | duck | cheese | dude | rabbit | wug | linguist | physicist | lady | dog | cat | bird')}
#
# Rel_Sg         -> 'that' VP_3Sg [1]
# Rel_Pl         -> 'that' VP [1]
#
# VI            -> {equal_production('run | walk | think | laugh | ponder')}
# VT            -> {equal_production('kick | kiss | hug | punch | fight | love')}
#
# Adj           -> {equal_production('big | small | happy | mad | red | blue | sparkling | shiny')}
# """)
#
# sample_sentences(agreement_violations, 20)

verb_sing = ["runs", "walks", "thinks", "laughs", "ponders", "kicks", "kisses", "hugs", "punches", "fights", "loves"]
verb_pl = ["run", "walk", "think", "laugh", "ponder", "kick", "kiss", "hug", "punch", "fight", "love"]

def deform_sentence(sentence: str):
    """Deforms a sentence by randomly switching one or more verbs from singular to plural or vice versa"""
    words = sentence.split(' ')
    verb_indices = []
    for i, word in enumerate(words):
        if word in verb_sing or word in verb_pl:
            verb_indices.append(i)

    indices_to_deform = random.sample(verb_indices, random.randint(1, len(verb_indices)))
    for index in indices_to_deform:
        word = words[index]
        if word in verb_sing:
            verb_index = verb_sing.index(word)
            words[index] = verb_pl[verb_index]
        elif word in verb_pl:
            verb_index = verb_pl.index(word)
            words[index] = verb_sing[verb_index]

    return ' '.join(words)

deform_sentence('the cat that runs ponders and hugs the duck')

'the cat that runs ponders and hug the duck'

In [39]:
import datasets

random.seed(1)
valid_num = 2000
invalid_num = 2000
valid = sample_sentences(pcfg, valid_num)
invalid = [(deform_sentence(sentence), max_depth) for sentence, max_depth in sample_sentences(pcfg, invalid_num)]

texts = [sent for sent, _ in (valid + invalid)]
labels = [1] * valid_num + [0] * invalid_num
depths = [depth for _, depth in (valid + invalid)]

dataset = datasets.Dataset.from_dict({"text": texts, "labels": labels, "depth": depths}).shuffle()

dataset = dataset.class_encode_column('labels')
dataset = dataset.train_test_split(test_size=0.4, stratify_by_column='labels')
dataset_eval_test = dataset['test'].train_test_split(test_size=0.5, stratify_by_column='labels')
dataset = datasets.DatasetDict({
    'train': dataset['train'],
    'eval': dataset_eval_test['train'],
    'test': dataset_eval_test['test']
})

dataset.push_to_hub("michaelginn/latent-trees-agreement-ID")

100%|██████████| 2000/2000 [00:00<00:00, 33469.42it/s]
100%|██████████| 2000/2000 [00:00<00:00, 37011.61it/s]


Stringifying the column:   0%|          | 0/4000 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/4000 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/748 [00:00<?, ?B/s]

In [40]:
noun_sing = ["girl", "boy", "cat", "turtle", "rutabaga", "duck", "cheese", "dude", "rabbit", "wug", "linguist", "physicist", "lady", "dog", "cat", "bird", "he", "she"]
noun_pl = ["girls", "boys", "cats", "turtles", "rutabagas", "ducks", "cheeses", "dudes", "rabbits", "wugs", "linguists", "physicists", "ladies", "dogs", "cats", "birds", "I", "you",  "me", "us", "them", "they", "we"]

verb_sing = ["runs", "walks", "thinks", "laughs", "ponders", "kicks", "kisses", "hugs", "punches", "fights", "loves"]
verb_pl = ["run", "walk", "think", "laugh", "ponder", "kick", "kiss", "hug", "punch", "fight", "love"]

def check_linear_heuristic(sentence: str):
    # Returns false if the sentence fails the linear heuristic (verb should agree with most recent noun)
    most_recent_noun = None
    most_recent_noun_num = None # 'pl' or 'sg'

    for word in sentence.split(' '):
        if word in noun_sing:
            most_recent_noun = word
            most_recent_noun_num = 'sg'
        elif word in noun_pl:
            most_recent_noun = word
            most_recent_noun_num = 'pl'
        elif word in verb_sing:
            if most_recent_noun_num == 'pl':
                return False
        elif word in verb_pl:
            if most_recent_noun_num == 'sg':
                return False
    return True

print(check_linear_heuristic('the linguist that punches these rutabagas thinks'))
print(check_linear_heuristic('the linguist that punches this rutabaga thinks'))
print(check_linear_heuristic('the linguist that punches these rutabagas think'))
print(check_linear_heuristic('they thinks'))

False
True
True
False


In [43]:
import datasets

random.seed(1)

def generate_with_heuristic(grammar, n_valid, n_invalid):
    valid = sample_sentences(grammar, n_valid)
    invalid = [(deform_sentence(sentence), depth) for sentence, depth in sample_sentences(grammar, n_invalid)]

    # Filter sents that pass/fail true heuristic to also pass/fail linear heuristic so we have an ambiguous training dataset
    both_valid = []
    true_valid_but_linear_invalid = []
    for valid_sent, depth in tqdm(valid):
        if check_linear_heuristic(valid_sent):
            both_valid.append((valid_sent, depth))
        else:
            true_valid_but_linear_invalid.append((valid_sent, depth))

    both_invalid = []
    true_invalid_but_linear_valid = []
    for invalid_sent, depth in tqdm(invalid):
        if not check_linear_heuristic(invalid_sent):
            both_invalid.append((invalid_sent, depth))
        else:
            true_invalid_but_linear_valid.append((invalid_sent, depth))

    print(f"Passes both heuristics: {len(both_valid)}")
    print(f"Fails both heuristics: {len(both_invalid)}")
    print(f"Passes hierarch, fails linear: {len(true_valid_but_linear_invalid)}")
    print(f"Fails hierarch, passes linear: {len(true_invalid_but_linear_valid)}")

    return both_valid, true_valid_but_linear_invalid, both_invalid, true_invalid_but_linear_valid

In [44]:
# Create generalization scenario
train_examples = 3200

valid_num = 20000
invalid_num = 200000

both_valid, true_valid_but_linear_invalid, both_invalid, true_invalid_but_linear_valid = generate_with_heuristic(pcfg, valid_num, invalid_num)

train_sents = [sent for sent, _ in both_valid[:1600] + both_invalid[:1600]]
train_depths = [depth for _, depth in both_valid[:1600] + both_invalid[:1600]]
train_labels = [1] * 1600 + [0] * 1600

eval_sents = [sent for sent, _ in true_valid_but_linear_invalid[:400] + true_invalid_but_linear_valid[:400]]
eval_depths = [depth for _, depth in true_valid_but_linear_invalid[:400] + true_invalid_but_linear_valid[:400]]
eval_labels = [1] * 400 + [0] * 400

dataset_gen_train = datasets.Dataset.from_dict({"text": train_sents, "labels": train_labels, "depth": train_depths}).shuffle()
dataset_gen_eval = datasets.Dataset.from_dict({"text": eval_sents, "labels": eval_labels, "depth": eval_depths}).shuffle()

dataset_gen_train_eval = dataset_gen_train.class_encode_column('labels').train_test_split(test_size=0.25, stratify_by_column='labels')
dataset_gen_test = dataset_gen_eval.class_encode_column('labels')

gen_dataset_dict = datasets.DatasetDict({
    'train': dataset_gen_train_eval['train'],
    'eval': dataset_gen_train_eval['test'],
    'test': dataset_gen_test
})
# dataset.train_test_split(test_size=0.3, stratify_by_column='labels')
gen_dataset_dict.push_to_hub("michaelginn/latent-trees-agreement-GEN")

100%|██████████| 20000/20000 [00:00<00:00, 35069.87it/s]
100%|██████████| 200000/200000 [00:05<00:00, 34615.10it/s]
100%|██████████| 20000/20000 [00:00<00:00, 402988.47it/s]
100%|██████████| 200000/200000 [00:00<00:00, 566816.60it/s]

Passes both heuristics: 19116
Fails both heuristics: 198239
Passes hierarch, fails linear: 884
Fails hierarch, passes linear: 1761





Stringifying the column:   0%|          | 0/3200 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/3200 [00:00<?, ? examples/s]

Stringifying the column:   0%|          | 0/800 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/800 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/747 [00:00<?, ?B/s]

In [19]:
gen_dataset_dict

DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 2400
    })
    eval: Dataset({
        features: ['text', 'labels'],
        num_rows: 800
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 800
    })
})

# Extreme Generalization

In [46]:
extreme_pcfg = PCFG.fromstring(f"""
S             -> NP_3Sg_nom VP_3Sg [0.5] | NP_nom VP [0.5]

VP_3Sg        -> VT '+s' NP_acc [0.8] | VI '+s' [0.1] | VP_3Sg 'and' VP_3Sg [0.1]
VP            -> VT      NP_acc [0.8] | VI      [0.1] | VP     'and' VP     [0.1]

NP_3Sg_nom    -> 'he' [0.05] | 'she' [0.05] | NP_common_Sg [0.9]
NP_common_Sg  -> Det_Sg N_bar_common_Sg [1]
Det_Sg        -> {equal_production('the | a')}

NP_nom        -> {equal_production('I | you | we | they', total=0.1)} | NP_common_Pl [0.9]
NP_common_Pl  -> Det_Pl N_bar_common_Pl [0.8] | NP_common_Pl 'and' NP_common_Pl [0.2]
Det_Pl        -> {equal_production('the | those | these')}

NP_acc        -> {equal_production('me | you | us | them', total=0.1)} | NP_common_Pl [0.45] | NP_common_Sg [0.45]

N_bar_common_Sg  -> Adj N_bar_common_Sg [0.05] | N_common Rel_Sg [0.8] | N_common [0.15]
N_bar_common_Pl  -> Adj N_bar_common_Pl [0.05] | N_common '+s' Rel_Pl [0.8] | N_common '+s' [0.15]
N_common      -> {equal_production('girl | boy | cat | turtle | rutabaga | duck | cheese | dude | rabbit | wug | linguist | physicist | lady | dog | cat | bird')}

Rel_Sg         -> 'that' VP_3Sg [1]
Rel_Pl         -> 'that' VP [1]

VI            -> {equal_production('run | walk | think | laugh | ponder')}
VT            -> {equal_production('kick | kiss | hug | punch | fight | love')}

Adj           -> {equal_production('big | small | happy | mad | red | blue | sparkling | shiny')}
""")

sents = sample_sentences(extreme_pcfg, 20)

print(sents)
print("Average depth", sum(depth for _, depth in sents) / len(sents))

100%|██████████| 20/20 [00:00<00:00, 1949.34it/s]

[('a cat that kicks the wugs that fight these ladies that fight those rabbits that kiss a dude and these boys that kick these physicists and those birds that kiss the birds that love the bird that hugs the dog that hugs a rabbit that loves those ducks that love those dudes and those wugs that fight the lady that kicks these linguists that kiss the boy that kisses those physicists that hug a boy that hugs these ducks and those girls and the birds that fight a girl that hugs a shiny boy and the boys that kick me and those physicists that hug a boy that punches the cats and those ladies that walk kisses those birds that think', 10), ('those boys that walk and kiss you and ponder and love these physicists that kiss the linguists that love the wugs that hug the dude that kisses the boy and laugh and ponder and the rabbits that walk punch a duck that hugs these cats that love those cheeses and these girls that kiss these blue physicists that punch a cat that hugs those girls that love the bo




In [68]:
# Create extreme generalization scenario
train_examples = 3200

# both_valid, true_valid_but_linear_invalid, both_invalid, true_invalid_but_linear_valid = generate_with_heuristic(extreme_pcfg, 2000, 300000)

# dataset_extreme_train = datasets.Dataset.from_dict({"text": both_valid[:1600] + both_invalid[:1600], "labels": [1] * 1600 + [0] * 1600}).shuffle()

true_valid_but_linear_invalid = [(sent, depth) for sent, depth in true_valid_but_linear_invalid if depth > 3]
true_invalid_but_linear_valid = [(sent, depth) for sent, depth in true_invalid_but_linear_valid if depth > 3]

# Sort by depth
true_valid_but_linear_invalid.sort(key=lambda x: x[1])
true_invalid_but_linear_valid.sort(key=lambda x: x[1])

extreme_eval_sents = [sent for sent, _ in true_valid_but_linear_invalid[:400] + true_invalid_but_linear_valid[:400]]
extreme_eval_depths = [depth for _, depth in true_valid_but_linear_invalid[:400] + true_invalid_but_linear_valid[:400]]
extreme_eval_labels = [1] * 400 + [0] * 400

dataset_extreme_eval = datasets.Dataset.from_dict({"text": extreme_eval_sents, "labels": extreme_eval_labels, "depth": extreme_eval_depths}).shuffle()

# dataset_extreme_train_eval = dataset_extreme_train.class_encode_column('labels').train_test_split(test_size=0.25, stratify_by_column='labels')
dataset_extreme_eval = dataset_extreme_eval.class_encode_column('labels')

extreme_dataset_dict = datasets.DatasetDict({
    'train': dataset_gen_train_eval['train'],
    'eval': dataset_gen_train_eval['test'],
    'test': dataset_extreme_eval
})
extreme_dataset_dict.push_to_hub("michaelginn/latent-trees-agreement-GENX")

Stringifying the column:   0%|          | 0/800 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/800 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/603 [00:00<?, ?B/s]

In [69]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# max_length = 100
def tokenize_function(example):
    return tokenizer(example['text'])
dataset_extreme_eval_tok = dataset_extreme_eval.map(tokenize_function, batched=True, load_from_cache_file=False)

# Get longest token sequence length
max_length = max(len(x) for x in dataset_extreme_eval_tok["input_ids"])
print(max_length)

# Average length
print(sum(len(x) for x in dataset_extreme_eval_tok["input_ids"]) / len(dataset_extreme_eval_tok["input_ids"]))

# Number of rows with more than 512 tokens
print(sum(len(x) > 512 for x in dataset_extreme_eval_tok["input_ids"]))

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

252
52.41625
0
