In [None]:
####################################################
#
# All the needed data loading helper functions
# and libs.
#
####################################################
from __future__ import absolute_import, division, print_function

import collections
import unicodedata
import six
from torch.utils.data import Dataset

def glove2dict(src_filename):
    """
    Read a GloVe vectors file into a dictionary.

    Parameters
    ----------
    src_filename : str
        Full path to the GloVe file to be processed.

    Returns
    -------
    dict
        Maps each word (first field of a line) to its vector, stored as an
        `np.array` of dtype `np.float64`.
    """
    # The 840B.300d distribution has some words containing spaces, so for
    # that file we assume 300 dimensions and split from the right instead
    # of splitting on arbitrary whitespace.
    has_spaced_words = '840B.300d' in src_filename

    def split_fields(raw_line):
        if has_spaced_words:
            return raw_line.rsplit(" ", 300)
        return raw_line.strip().split()

    vectors = {}
    with open(src_filename, encoding='utf8') as handle:
        while True:
            try:
                fields = split_fields(next(handle))
                vectors[fields[0]] = np.array(fields[1:], dtype=np.float64)
            except StopIteration:
                # End of file reached.
                break
            except UnicodeDecodeError:
                # Undecodable input is ignored and reading continues.
                pass
    return vectors


def convert_to_unicode(text):
    """Return `text` as a Unicode string, decoding byte input as utf-8.

    Undecodable bytes are dropped (the "ignore" error handler). A
    `ValueError` is raised for any input that is neither a text nor a byte
    string, or when the interpreter is neither Python 2 nor Python 3.
    """
    if six.PY3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        raise ValueError("Unsupported string type: %s" % (type(text)))

    if six.PY2:
        # On Python 2, `str` is the byte type and `unicode` the text type.
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        if isinstance(text, unicode):
            return text
        raise ValueError("Unsupported string type: %s" % (type(text)))

    raise ValueError("Not running on Python2 or Python 3?")

def convert_tokens_to_ids(vocab, tokens):
    """Map every token to its vocabulary id.

    Tokens missing from `vocab` are mapped to the id of the '[UNK]' entry;
    that entry is only looked up when an unknown token is encountered.
    """
    return [
        vocab[tok] if tok in vocab else vocab['[UNK]']
        for tok in tokens
    ]
        
def load_vocab(vocab_file):
    """Load a vocabulary file into an ordered token -> index dictionary.

    The file is read as UTF-8 text, one token per line. Each line is
    stripped of surrounding whitespace and assigned its zero-based line
    number as index. A blank line yields the empty-string token, and a
    token that appears more than once keeps the index of its last
    occurrence.

    Parameters
    ----------
    vocab_file : str
        Path to the vocabulary file.

    Returns
    -------
    collections.OrderedDict
        Mapping from token to index, in file order.
    """
    vocab = collections.OrderedDict()
    # The encoding is explicit so the result does not depend on the
    # platform's default locale. Text-mode reads already yield `str`, so
    # no further unicode conversion is required.
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            vocab[line.strip()] = index
    return vocab

class WordLevelTokenizer(object):
    """Runs end-to-end word-level tokenization.

    Text is split on a fixed delimiter, each piece is looked up in a
    vocabulary loaded from `vocab_file`, and the resulting id sequence is
    wrapped in BOS/EOS ids taken from `config`.

    Parameters
    ----------
    vocab_file : str
        Path of the vocabulary file, passed to `load_vocab`.
    config : object
        Must expose `pad_token_id`, `bos_token_id`, `eos_token_id`,
        `unk_token_id` and `mask_token_id`.
    delimiter : str
        String used both to split text and to join decoded tokens.
    max_seq_len : int
        Maximum length of an encoded sequence, BOS and EOS included.
    """

    def __init__(self, vocab_file, config, delimiter=" ", max_seq_len=128):
        self.vocab = load_vocab(vocab_file)
        # Inverse mapping (id -> token) used when decoding.
        self.vocab_reverse = collections.OrderedDict(
            (index, token) for token, index in self.vocab.items()
        )
        self.pad_token_id = config.pad_token_id
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.unk_token_id = config.unk_token_id
        self.mask_token_id = config.mask_token_id
        self.special_token_ids = {
            config.pad_token_id, config.bos_token_id, config.eos_token_id,
            config.unk_token_id, config.mask_token_id,
        }

        self.max_seq_len = max_seq_len
        self.delimiter = delimiter

    def tokenize(self, text):
        """Split `text` on the delimiter and return the list of pieces."""
        return text.split(self.delimiter)

    def convert_tokens_to_ids(self, tokens):
        """Map tokens to ids via the module-level `convert_tokens_to_ids`."""
        return convert_tokens_to_ids(self.vocab, tokens)

    def __call__(self, text):
        """Encode `text` as `[BOS] + ids + [EOS]`.

        The ids are truncated so that the final sequence never exceeds
        `max_seq_len` entries.
        """
        token_ids = self.convert_tokens_to_ids(self.tokenize(text))
        token_ids = token_ids[:(self.max_seq_len - 2)]
        return [self.bos_token_id] + token_ids + [self.eos_token_id]

    def batch_decode(self, pred_labels, skip_special_tokens=True):
        """Decode a batch of id sequences into delimiter-joined strings.

        Parameters
        ----------
        pred_labels : iterable
            Each element must provide `.tolist()` returning a list of ids
            (e.g. one row of a 2-D tensor).
        skip_special_tokens : bool
            When True (default) ids in `special_token_ids` are omitted from
            the output. When False they are looked up in the vocabulary
            like any other id.

        Returns
        -------
        list of str
            One decoded string per input sequence.

        Notes
        -----
        Decoding of a sequence always stops at the first EOS id, and the
        EOS token itself is never emitted. A `KeyError` is raised for an
        emitted id that is not in the vocabulary.
        """
        decoded_batch = []
        for labels in pred_labels:
            tokens = []
            for label_id in labels.tolist():
                if label_id == self.eos_token_id:
                    break
                if skip_special_tokens and label_id in self.special_token_ids:
                    continue
                tokens.append(self.vocab_reverse[label_id])
            decoded_batch.append(self.delimiter.join(tokens))
        return decoded_batch

class COGSDataset(Dataset):
    """Dataset over a COGS-style TSV partition file.

    Each non-blank line of `{cogs_path}/{partition}.tsv` must contain
    exactly three tab-separated fields: source text, target logical form
    and a category label. Sources and targets are encoded with the given
    tokenizers at construction time.

    Parameters
    ----------
    cogs_path : str
        Directory containing the partition files.
    src_tokenizer, tgt_tokenizer : callable
        Called with the raw source / target string; must return a list of
        token ids.
    partition : str
        File stem to load. The special value "gen-dev" loads "gen" and
        keeps only one tenth of it (see below).
    max_len : int
        Accepted for interface compatibility; not used by this class.
    max_examples : int
        Maximum number of examples to read, or -1 for no limit.
    least_to_most : bool
        Sort examples by ascending source length (applies to partitions
        whose name contains "train", and to "gen-dev").

    Attributes
    ----------
    eval_cat : list of str
        Category label of every example; `eval_cat[i]` always belongs to
        the example returned by `self[i]`.
    """

    def __init__(
        self, cogs_path,
        src_tokenizer, tgt_tokenizer,
        partition, max_len=512, max_examples=-1,
        least_to_most=False
    ):
        self._items = []  # list of (src_input_ids, tgt_input_ids)
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.eval_cat = []

        is_gen_dev = partition == "gen-dev"
        if is_gen_dev:
            partition = "gen"

        # Stream the file inside a context manager so the handle is closed.
        with open(f"{cogs_path}/{partition}.tsv", "r") as reader:
            for line in reader:
                # Stop once exactly `max_examples` examples were collected.
                if max_examples != -1 and len(self._items) >= max_examples:
                    break
                if not line.strip():
                    continue  # tolerate blank (e.g. trailing) lines
                text, sparse, cat = line.split("\t")
                self._items.append(
                    (src_tokenizer(text), tgt_tokenizer(sparse))
                )
                self.eval_cat.append(cat.strip())

        if "train" in partition:
            self._shuffle()
            if least_to_most:
                self._sort_by_src_len(reverse=False)

        if is_gen_dev:
            # This is an unusual partition inherited from previous work:
            # shuffle, sort by source length (longest first unless
            # `least_to_most`), then keep the first tenth.
            self._shuffle()
            self._sort_by_src_len(reverse=not least_to_most)
            keep = len(self._items) // 10
            self._items = self._items[:keep]
            self.eval_cat = self.eval_cat[:keep]

    def _set_pairs(self, pairs):
        """Store (example, category) pairs back into the parallel lists."""
        self._items = [item for item, _ in pairs]
        self.eval_cat = [cat for _, cat in pairs]

    def _shuffle(self):
        """Shuffle examples and their categories with one permutation."""
        pairs = list(zip(self._items, self.eval_cat))
        random.shuffle(pairs)
        self._set_pairs(pairs)

    def _sort_by_src_len(self, reverse):
        """Stable-sort examples (and categories) by source length."""
        pairs = sorted(
            zip(self._items, self.eval_cat),
            key=lambda pair: len(pair[0][0]),
            reverse=reverse,
        )
        self._set_pairs(pairs)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, item):
        return self._items[item]

    def collate_batch(self, batch):
        """Pad a list of examples into tensors.

        Sources and targets are right-padded with 0 to the longest
        sequence in the batch; the attention mask is 1 on real source
        positions and 0 on padding.

        Returns a dict with `input_ids`, `labels` and `attention_mask`
        tensors. Raises `ValueError` on an empty batch.
        """
        # NOTE(review): the padding id is hard-coded to 0 rather than
        # taken from the tokenizers' pad_token_id - confirm they agree.
        src_seq_lens = [len(example[0]) for example in batch]
        tgt_seq_lens = [len(example[1]) for example in batch]
        max_src_len = max(src_seq_lens)
        max_tgt_len = max(tgt_seq_lens)

        input_ids_batch = []
        mask_batch = []
        labels_batch = []
        for example, src_len, tgt_len in zip(batch, src_seq_lens, tgt_seq_lens):
            src_pad = max_src_len - src_len
            input_ids_batch.append(example[0] + [0] * src_pad)
            mask_batch.append([1] * src_len + [0] * src_pad)
            labels_batch.append(example[1] + [0] * (max_tgt_len - tgt_len))

        return {"input_ids": torch.tensor(input_ids_batch),
                "labels": torch.tensor(labels_batch),
                "attention_mask": torch.tensor(mask_batch)}

In [None]:
####################################################
#
# All the needed evaluation related helper functions
# and libs.
#
####################################################
import random, torch, re, os
import numpy as np
from transformers import EncoderDecoderConfig, EncoderDecoderModel
from tqdm import tqdm
from torch.utils.data import DataLoader, SequentialSampler

def set_seed(seed: int):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def find_partition_name(name, lf):
    """Return the partition name for logical-form variant `lf`.

    The "cogs" variant uses `name` unchanged; every other variant gets
    `_{lf}` appended.
    """
    return name if lf == "cogs" else name+f"_{lf}"
    
def check_equal(left_lf, right_lf):
    """Compare two logical forms up to a consistent renaming of indices.

    In each form, numeric tokens are renumbered 0, 1, 2, ... in order of
    first appearance; the forms are equal when the whitespace-normalised
    results are identical strings.
    """
    def _renumber(lf):
        # Maps each distinct integer value to the rank of its first use.
        first_seen = {}
        rewritten = []
        for tok in lf.split():
            if tok.isnumeric():
                rank = first_seen.setdefault(int(tok), len(first_seen))
                rewritten.append(str(rank))
            else:
                rewritten.append(tok)
        return " ".join(rewritten)

    canonical_left = _renumber(left_lf)
    canonical_right = _renumber(right_lf)
    return canonical_right == canonical_left

# Patterns used to parse the pieces of a logical form. All are compiled in
# VERBOSE mode, are anchored at both ends, and tolerate surrounding
# whitespace.

# Noun term: an optional leading "*", a word, then a parenthesised
# argument. Groups: (star or None, word, argument).
recogs_neoD_np_re = re.compile(r"""
    ^
    \s*(\*)?
    \s*(\w+?)\s*
    \(
    \s*(.+?)\s*
    \)
    \s*$""", re.VERBOSE)

# Unary predicate: a word followed by a parenthesised run of digits.
# Groups: (word, digits).
recogs_neoD_verb_re = re.compile(r"""
    ^
    \s*(\w+?)\s*
    \(
    \s*([0-9]+?)\s*
    \)
    \s*$""", re.VERBOSE)

# Binary predicate: a word followed by two comma-separated arguments in
# parentheses. Groups: (word, first argument, second argument).
recogs_neoD_pred_re = re.compile(r"""
    ^
    \s*(\w+?)\s*
    \(
    \s*(.+?)\s*
    ,
    \s*(.+?)\s*
    \)
    \s*$""", re.VERBOSE)

# Dotted binary predicate: "word . word ( arg , arg )".
# Groups: (first word, second word, first argument, second argument).
recogs_neoD_mod_re = re.compile(r"""
    ^
    \s*(\w+?)\s*
    \.
    \s*(\w+?)\s*
    \(
    \s*(.+?)\s*
    ,
    \s*(.+?)\s*
    \)
    \s*$""", re.VERBOSE)

def translate_invariant_form_neoD(lf):
    """Translate a logical form into a variable-name-invariant set of strings.

    The input is split on " ; ": every segment but the last (taken from the
    text before the first " AND ") is treated as a noun term, and the last
    segment is split on " AND " into conjuncts. Noun variables are
    renumbered in order of appearance, and conjuncts are grouped by their
    first argument so that the result does not depend on the original
    variable names.

    Parameters
    ----------
    lf : str
        The logical form to translate.

    Returns
    -------
    set of str, or dict
        On success, a set with one " AND "-joined expression per grouping
        variable plus one entry per "nmod" conjunct. On any detected format
        error, index collision or duplicate expression, an empty dict `{}`
        is returned instead.

    Notes
    -----
    Dictionary lookups on `nouns_map` / `vp_conjs_map` /
    `childen_count_map` are not guarded, so an unexpected variable can raise
    `KeyError` out of this function.
    """
    nouns = lf.split(" AND ")[0].split(" ; ")[:-1]
    # Whitespace-separated tokens of the final " ; " segment; used to check
    # that every noun variable is referenced by some conjunct.
    complements = set(lf.split(" ; ")[-1].split())
    # original noun variable -> noun term rewritten with its new index
    nouns_map = {}
    new_var = 0
    for noun in nouns:
        # check format.
        if not recogs_neoD_np_re.search(noun):
            return {} # this is format error, we cascade the error.
        _, _, original_var = recogs_neoD_np_re.search(noun).groups()
        if original_var not in complements:
            return {} # var must be used, we cascade the error.
        # str.replace substitutes every occurrence of the variable text
        # inside the noun term.
        new_noun = noun.replace(str(original_var), str(new_var))
        nouns_map[original_var] = new_noun
        new_var += 1

    nmod_conjs_set = set([])
    conjs = lf.split(" ; ")[-1].split(" AND ")
    # first argument (variable) -> list of expressions attached to it
    vp_conjs_map = {}
    # binary conjuncts whose second argument is a numeric non-noun variable;
    # resolved after this loop
    nested_conjs = []
    # variable -> number of its nested conjuncts not yet resolved
    childen_count_map = {}
    for conj in conjs:
        if "nmod" in conj:
            if not recogs_neoD_mod_re.search(conj):
                return {} # this is format error, we cascade the error.
            role, pred, first_arg, second_arg = recogs_neoD_mod_re.search(conj).groups()
            # Both arguments are replaced by their rewritten noun terms.
            new_conj = f"{role} . {pred} ( {nouns_map[first_arg]} , {nouns_map[second_arg]} )"
            nmod_conjs_set.add(new_conj)
        else:
            if recogs_neoD_verb_re.search(conj):
                # candidate for mapping verb.
                pred, arg = recogs_neoD_verb_re.search(conj).groups()
                if not arg.isnumeric():
                    return {}
                # Only the predicate name is kept; the variable becomes the key.
                new_conj = f"{pred}"
                if arg in vp_conjs_map:
                    vp_conjs_map[arg].append(new_conj)
                else:
                    vp_conjs_map[arg] = [new_conj]
                continue
            if not recogs_neoD_pred_re.search(conj):
                return {} # this is format error, we cascade the error.

            role, first_arg, second_arg = recogs_neoD_pred_re.search(conj).groups()
            if first_arg == second_arg or first_arg in nouns_map or not first_arg.isnumeric():
                return {} # this is index collision, we cascade the error.
            if second_arg.isnumeric() and second_arg in nouns_map:
                # Second argument is a noun variable: inline the noun term.
                second_arg = nouns_map[second_arg]
                new_conj = f"{role} ( {second_arg} )"
                if first_arg in vp_conjs_map:
                    vp_conjs_map[first_arg].append(new_conj)
                else:
                    vp_conjs_map[first_arg] = [new_conj]
            elif second_arg.isnumeric():
                # Second argument is numeric but not a noun variable: defer
                # it until that variable's own expressions are complete.
                if first_arg not in childen_count_map:
                    childen_count_map[first_arg] = 1
                else:
                    childen_count_map[first_arg] += 1
                nested_conjs.append({
                    "role": role,
                    "first_arg": first_arg,
                    "second_arg": second_arg,
                })
            else:
                return {}

    # Resolve nested conjuncts from the innermost outwards. The iteration
    # cap makes the loop give up (returning {}) instead of spinning forever.
    while_loop_count = 0
    while len(nested_conjs) > 0:
        while_loop_count += 1
        if while_loop_count > 100:
            return {}
        conj = nested_conjs.pop(0)
        if conj['second_arg'] not in childen_count_map or childen_count_map[conj['second_arg']] == 0:
            # The child has no pending nested conjuncts: embed its
            # expressions, joined in their current (not yet sorted) order.
            core = " AND ".join(vp_conjs_map[conj['second_arg']])
            vp_conjs_map[conj['first_arg']].append(f"{conj['role']} ( {core} )")
            childen_count_map[conj['first_arg']] -= 1
        else:
            # if the conj is corrupted, then we abandon just let it go and fail to compare.
            if conj['first_arg'] == conj['second_arg']:
                return {}
            # Child not ready yet: move this conjunct to the back of the queue.
            nested_conjs.append(conj)

    filtered_conjs_set = set([])
    # Sort each variable's expressions so their order does not matter.
    for k, v in vp_conjs_map.items():
        vp_conjs_map[k].sort()
    for k, v in vp_conjs_map.items():
        vp_expression = " AND ".join(v)
        if vp_expression in filtered_conjs_set:
            return {} # this is not allowed. exact same VP expression is not allowed this time.
        filtered_conjs_set.add(vp_expression)
    for conj in nmod_conjs_set:
        if conj in filtered_conjs_set:
            return {} # this is not allowed. exact same VP expression is not allowed this time.
        filtered_conjs_set.add(conj)
    return filtered_conjs_set

def check_set_equal_neoD(left_lf, right_lf):
    """Return True when both logical forms have the same invariant form.

    Each form is passed through `translate_invariant_form_neoD` and the two
    results are compared for equality. Any ordinary exception raised during
    translation counts as a mismatch and yields False.
    """
    # NOTE(review): the translator signals a detected format error by
    # returning an empty dict, so two forms that both fail that way compare
    # equal here - confirm this is intended.
    try:
        return translate_invariant_form_neoD(left_lf) == \
            translate_invariant_form_neoD(right_lf)
    except Exception:
        # `Exception` rather than a bare `except`, so KeyboardInterrupt and
        # SystemExit still propagate.
        return False
    
# Run on the first GPU when CUDA is available; otherwise fall back to the
# CPU so the notebook still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

#### HF Model Loading with Evaluation

In [None]:
DATASET_NAME = "ReCOGS" # COGS or ReCOGS
####################################################
#
# Different evaluation function for COGS and ReCOGS:
# "COGS" uses check_equal, anything else uses
# check_set_equal_neoD.
#
####################################################
if DATASET_NAME == "COGS":
    exact_match_func = check_equal
else:
    exact_match_func = check_set_equal_neoD

In [None]:
# Load the pretrained encoder-decoder model, move it to `device` and switch
# it to evaluation mode.
model = (
    EncoderDecoderModel.from_pretrained(
        f"ReCOGS/{DATASET_NAME}-model",
        cache_dir="../huggingface_cache/",
    )
    .to(device)
    .eval()
)
set_seed(123) # should be seed invariant, but just to make sure!

In [None]:
# Load the tokenizers, then build the IID ("test") and OOD ("gen")
# evaluation datasets together with their dataloaders.
model_data_path = "./model/"
data_path = "./cogs/" if DATASET_NAME == "COGS" else "./recogs/"
using_set_equal = DATASET_NAME != "COGS"
src_tokenizer = WordLevelTokenizer(
    os.path.join(model_data_path, "src_vocab.txt"),
    model.config.encoder,
    max_seq_len=512
)
tgt_tokenizer = WordLevelTokenizer(
    os.path.join(model_data_path, "tgt_vocab.txt"),
    model.config.decoder,
    max_seq_len=512
)

# Sequential samplers keep the file order, which the per-category
# evaluation relies on when it indexes `eval_cat` by example position.
test_dataset = COGSDataset(
    cogs_path=data_path,
    src_tokenizer=src_tokenizer,
    tgt_tokenizer=tgt_tokenizer,
    partition=find_partition_name("test", "cogs"),
)
test_dataloader = DataLoader(
    test_dataset, batch_size=256,
    sampler=SequentialSampler(test_dataset),
    collate_fn=test_dataset.collate_batch
)

gen_dataset = COGSDataset(
    cogs_path=data_path,
    src_tokenizer=src_tokenizer,
    tgt_tokenizer=tgt_tokenizer,
    partition=find_partition_name("gen", "cogs"),
)
gen_dataloader = DataLoader(
    gen_dataset, batch_size=256,
    sampler=SequentialSampler(gen_dataset),
    # Each loader collates with its own dataset (this previously borrowed
    # the test dataset's method).
    collate_fn=gen_dataset.collate_batch
)

IID Test Set Evaluation

In [None]:
# Exact-match accuracy on the IID test set; the running accuracy (rounded
# to two decimals) is shown in the progress bar.
epoch_iterator = tqdm(test_dataloader, desc="Iteration", position=0, leave=True)
total_count = 0
correct_count = 0
for step, inputs in enumerate(epoch_iterator):
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    labels = inputs["labels"].to(device)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        eos_token_id=model.config.eos_token_id,
        max_length=512,
    )
    decoded_preds = tgt_tokenizer.batch_decode(outputs)
    decoded_labels = tgt_tokenizer.batch_decode(labels)

    # Compare every gold label with the prediction at the same position.
    for gold, pred in zip(decoded_labels, decoded_preds):
        correct_count += 1 if exact_match_func(gold, pred) else 0
        total_count += 1
    current_acc = round(correct_count/total_count, 2)
    epoch_iterator.set_postfix({'acc': current_acc})

OOD Test Set Evaluation

In [None]:
# Exact-match accuracy on the OOD generalization set, overall and per
# category.
per_cat_eval = {}
for cat in set(gen_dataset.eval_cat):
    per_cat_eval[cat] = [0, 0] # correct, total

epoch_iterator = tqdm(gen_dataloader, desc="Iteration", position=0, leave=True)
total_count = 0
correct_count = 0
for step, inputs in enumerate(epoch_iterator):
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    labels = inputs["labels"].to(device)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        eos_token_id=model.config.eos_token_id,
        max_length=512,
    )
    decoded_preds = tgt_tokenizer.batch_decode(outputs)
    decoded_labels = tgt_tokenizer.batch_decode(labels)

    # Decoded source sentences; not used below, kept for inspection.
    input_labels = src_tokenizer.batch_decode(input_ids)
    for i in range(len(decoded_preds)):
        # `total_count` is the position of this example in the dataset
        # because the dataloader iterates sequentially.
        cat = gen_dataset.eval_cat[total_count]
        if exact_match_func(decoded_labels[i], decoded_preds[i]):
            correct_count += 1
            per_cat_eval[cat][0] += 1
        total_count += 1
        per_cat_eval[cat][1] += 1
    current_acc = correct_count/total_count
    epoch_iterator.set_postfix({'acc': current_acc})

# Every reported accuracy starts at 0 so that a category missing from the
# data is printed as 0 instead of raising NameError.
struct_pp_acc = 0
struct_cp_acc = 0
struct_obj_subj_acc = 0
subj_to_obj_proper_acc = 0
prim_to_obj_proper_acc = 0
prim_to_subj_proper_acc = 0

# All remaining categories are pooled into a single "LEX" accuracy.
lex_acc = 0
lex_count = 0
for k, v in per_cat_eval.items():
    if k  == "pp_recursion":
        struct_pp_acc = 100 * v[0]/v[1]
    elif k  == "cp_recursion":
        struct_cp_acc = 100 * v[0]/v[1]
    elif k  == "obj_pp_to_subj_pp":
        struct_obj_subj_acc = 100 * v[0]/v[1]
    elif k  == "subj_to_obj_proper":
        subj_to_obj_proper_acc = 100 * v[0]/v[1]
    elif k  == "prim_to_obj_proper":
        prim_to_obj_proper_acc = 100 * v[0]/v[1]
    elif k  == "prim_to_subj_proper":
        prim_to_subj_proper_acc = 100 * v[0]/v[1]
    else:
        lex_acc += v[0]
        lex_count += v[1]
# Guard against an empty pool to avoid ZeroDivisionError.
if lex_count:
    lex_acc /= lex_count
    lex_acc *= 100
current_acc *= 100

print(f"obj_pp_to_subj_pp: {struct_obj_subj_acc}")
print(f"cp_recursion: {struct_cp_acc}")
print(f"pp_recursion: {struct_pp_acc}")
print(f"subj_to_obj_proper: {subj_to_obj_proper_acc}")
print(f"prim_to_obj_proper: {prim_to_obj_proper_acc}")
print(f"prim_to_subj_proper: {prim_to_subj_proper_acc}")
print(f"LEX: {lex_acc}")
print(f"OVERALL: {current_acc}")