In [18]:
import spacy
import utils
import re
import config

from collections import defaultdict, Counter
from minicons.utils import get_batch
from tqdm import tqdm

In [2]:
aochildes_dev = utils.read_file("../../smolm/data/corpora/babylm_data/babylm_dev/aochildes.dev")
aochildes_test = utils.read_file("../../smolm/data/corpora/babylm_data/babylm_test/aochildes.test")

aochildes_unseen = aochildes_dev + aochildes_test
len(aochildes_unseen)

130000

In [8]:
def get_postags(texts, batch_size, processor):
    taglist = []
    for doc in tqdm(processor.pipe(texts, disable=["tok2vec", "parser", "attribute_ruler", "lemmatizer", "ner"], batch_size=batch_size)):
        sentlist = []
        for entity in doc:
            sentlist += [(entity.text, entity.tag_)]
        taglist += [sentlist]
    return taglist

# recursively get children of a verb using spacy
def get_children_flatten(token, depth=0, dep=False, return_tokens=False):
    children = []
    for child in token.children:
        if dep:
            if return_tokens:
                children.append((child.text.lower(), child.dep_, child.tag_, depth, child.i, child))
            else:
                children.append((child.text.lower(), child.dep_, child.tag_, depth, child.i))
        else:
            children.append(child.text.lower())
        children.extend(get_children_flatten(child, depth+1, dep, return_tokens))
    return children

def collect_args(children_obj, hyp = 'do'):
    args = {'theme': '', 'recipient': '', 'theme_pos': '', 'recipient_pos': ''}
    hyp_args = []
    if hyp == "do":
        for child in children_obj:
            if child[1] == "dobj" or child[1] == "dative":
                hyp_args.append((child[0], child[-1], child[2]))
        
        # sort by index
        hyp_args = sorted(hyp_args, key=lambda x: x[1])
        args['recipient'] = hyp_args[0][0]
        args['theme'] = hyp_args[1][0]
        args['recipient_pos'] = hyp_args[0][-1]
        args['theme_pos'] = hyp_args[1][-1]

    elif hyp == "pp":
        for child in children_obj:
            if child[1] == "pobj" or child[1] == "dobj":
                hyp_args.append((child[0], child[-1], child[2]))
        
        # sort by index
        hyp_args = sorted(hyp_args, key=lambda x: x[1])
        args['recipient'] = hyp_args[1][0]
        args['theme'] = hyp_args[0][0]
        args['recipient_pos'] = hyp_args[1][-1]
        args['theme_pos'] = hyp_args[0][-1]

    return args

def get_datives(texts, batch_size, processor):
    dos, pps = [], []
    for doc in tqdm(processor.pipe(texts, disable = ["ner"], batch_size = batch_size)):
        do = False
        pp = False
        for entity in doc:
            if entity.pos_ == "VERB":
                children = get_children_flatten(entity, 0, dep=True)
                if len(children) > 0:
                    tokens, dep, pos_string, depth, index = list(zip(*children))
                    if "to" in tokens:
                        # possibility for pp
                        dep_depth = [f"{d}_{str(depth[i])}" for i, d in enumerate(dep)]
                        tok_dep = [f"{tokens[i]}_{dep[i]}" for i in range(len(tokens))]
                        if ("dobj_0" in dep_depth and "dative_0" in dep_depth and "pobj_1" in dep_depth) or ("dobj_0" in dep_depth and "prep_0" in dep_depth and "pobj_1" in dep_depth):
                            if "to_dative" in tok_dep or "to_prep" in tok_dep:
                                pp = True
                                # pps.append(sentence)
                                # print(children)
                                # args = collect_args(children, "pp")
                                pps.append((doc.text, entity.lemma_, entity.text, entity.tag_, children))
                                break
                    else:
                        # possibility for DO
                        # concatenation of dep and depth
                        dep_depth = [f"{d}_{str(depth[i])}" for i, d in enumerate(dep)]
                        # pos_dep = [f"{pos_string[i]}_{dep[i]}" for i in range(len(pos_string))]
                        tokens_dep = [f"{tokens[i]}_{dep[i]}" for i in range(len(tokens))]
                        if ("dobj_0" in dep_depth and "dative_0" in dep_depth) or Counter(dep_depth)['dobj_0'] >= 2:
                            if 'for_dative' not in tokens_dep and 'for_dobj' not in tokens_dep:
                                do = True
                                # dos.append(sentence)
                                # print(children)
                                # args = collect_args(children)
                                dos.append((doc.text, entity.lemma_, entity.text, entity.tag_, children))
                                break

    return dos, pps

In [4]:
# spacy setup
gpu = spacy.prefer_gpu(2)
nlp = spacy.load("en_core_web_trf")

In [9]:
DOS, PPS = [], []
for batch in get_batch(aochildes_unseen, batch_size=8192):
    dos, pps = get_datives(batch, 8192, nlp)
    DOS.extend(dos)
    PPS.extend(pps)

8192it [00:05, 1457.77it/s]
8192it [00:05, 1445.09it/s]
8192it [00:05, 1437.17it/s]
8192it [00:05, 1444.57it/s]
8192it [00:05, 1436.25it/s]
8192it [00:05, 1390.85it/s]
8192it [00:05, 1476.63it/s]
8192it [00:05, 1511.55it/s]
8192it [00:05, 1462.97it/s]
8192it [00:05, 1498.94it/s]
8192it [00:05, 1524.20it/s]
8192it [00:05, 1493.88it/s]
8192it [00:05, 1463.55it/s]
8192it [00:05, 1439.72it/s]
8192it [00:05, 1578.05it/s]
7120it [00:05, 1412.56it/s]


In [10]:
len(DOS), len(PPS)

(1617, 809)

In [36]:
# PPS
vbds = [pp[0] for pp in DOS if pp[3] == 'VB' and pp[1] in config.DATIVE_VERBS]
len(vbds)


857

In [37]:
vbds

['you could make him some tea ?',
 'you could give him some medicine .',
 'you gonna give mommy a good morning kiss !',
 'bring them the muffins !',
 "and we're not quite set up yet just give me a minute .",
 "so we're gonna give you some food .",
 "yeah i'm gonna give you tomato sauce also ?",
 "here i'll give her a little bowl ?",
 'and give him a ticket .',
 "i can read you the book william but we're not gonna watch that right now okay .",
 "alright well why don't i give you call when ?",
 'hopefully looking for a morning class if you could give me a call .',
 'that will give you an even spread of tape .',
 "don't give me that cookie .",
 'better give me kiss .',
 'hand me that boot .',
 'show me the red light .',
 "do you wanna give sally a bath when you're done ?",
 'yeah i think that means that you wanna give sally a bath too .',
 "actually i'll bring you some more of your cards .",
 "okay i'll find you another letter ?",
 'are you gonna give your coffee shop a name .',
 'do you 

In [28]:
len(set(vbds))

97

In [13]:
DO_verbs, PP_verbs = [do[3] for do in DOS], [pp[3] for pp in PPS]

In [14]:
Counter(DO_verbs), Counter(PP_verbs) 

(Counter({'VB': 1124,
          'VBD': 258,
          'VBG': 119,
          'VBP': 70,
          'VBZ': 33,
          'VBN': 13}),
 Counter({'VBP': 47, 'VB': 545, 'VBG': 70, 'VBD': 127, 'VBZ': 18, 'VBN': 2}))

In [38]:
DOS_FILTERED, PPS_FILTERED = [], []

DOS_DISCARDED, PPS_DISCARDED = [], []

for sentence, lemma, verb, verb_pos, children in DOS:
    if lemma in config.DATIVE_VERBS:
        args = collect_args(children)
        # if args['theme_pos'] in ['NN', 'NNS', 'NNP', 'PRP', 'DT'] and args['recipient_pos'] in ['NN', 'NNS', 'NNP', 'PRP']:
        #     DOS_FILTERED.append((sentence, lemma, verb, args['theme'], args['recipient']))
        # else:
        #     DOS_DISCARDED.append((sentence, lemma, verb, args['theme'], args['recipient'], args['theme_pos'], args['recipient_pos']))

        DOS_FILTERED.append((sentence, lemma, verb, verb_pos, args['theme'], args['recipient'], args['theme_pos'], args['recipient_pos']))

for sentence, lemma, verb, verb_pos, children in PPS:
    if lemma in config.DATIVE_VERBS:
        args = collect_args(children, "pp")
        # tag = children[2]
        # if args['theme_pos'] in ['NN', 'NNS', 'NNP', 'PRP', 'DT'] and args['recipient_pos'] in ['NN', 'NNS', 'NNP', 'PRP']:
        #     PPS_FILTERED.append((sentence, lemma, verb, args['theme'], args['recipient']))
        # else:
        #     PPS_DISCARDED.append((sentence, lemma, verb, args['theme'], args['recipient'], args['theme_pos'], args['recipient_pos']))
        PPS_FILTERED.append((sentence, lemma, verb, verb_pos, args['theme'], args['recipient'], args['theme_pos'], args['recipient_pos']))

In [40]:
len(DOS_FILTERED), len(PPS_FILTERED)

(1239, 579)

In [43]:
vbds = [pp[0] for pp in PPS_FILTERED if pp[3] == 'VBD' and pp[1] in config.DATIVE_VERBS]
len(vbds)

104

In [46]:
# write both to csv in data/
import csv

with open("../data/aochildes_unseen_dos.csv", "w") as f:
    writer = csv.writer(f)
    writer.writerow(["sentence", "lemma", "verb", "verb_pos", "theme", "recipient", "theme_pos", "recipient_pos"])
    writer.writerows(set(DOS_FILTERED))

with open("../data/aochildes_unseen_pps.csv", "w") as f:
    writer = csv.writer(f)
    writer.writerow(["sentence", "lemma", "verb", "verb_pos", "theme", "recipient", "theme_pos", "recipient_pos"])
    writer.writerows(set(PPS_FILTERED))