In [None]:
import os
# Point the Hugging Face and balaur caches at the scratch filesystem.
# This cell runs before the cells that import those libraries.
os.environ.update({
    "TRANSFORMERS_CACHE": "/network/scratch/m/mirceara/.cache/huggingface/transformers",
    "HF_DATASETS_CACHE": "/network/scratch/m/mirceara/.cache/huggingface/datasets",
    "BALAUR_CACHE": "/network/scratch/m/mirceara/.cache/balaur",
})

In [None]:
import sys
sys.path.append("/home/mila/m/mirceara/balaur/experiments/pretrain/")
from run_bort import MlmModel, MlmModelConfig, MlmWnreDataModule, MlmWnreDataModuleConfig, WNRE

In [None]:
from itertools import chain
import pandas as pd
from pathlib import Path
from tqdm import tqdm, trange

import numpy as np
import torch
import torch.nn.functional as F
import datasets as ds
import transformers as tr

from typing import List, Union

pd.set_option('display.max_columns', None)

## Extract hypernymy-related synsets from WordNet

In [None]:
from balaur.wordnet.utils import import_wordnet
# WordNet handle; used further down as wn.synset(name).lexname() to look up each synset's lexname.
wn = import_wordnet()

In [None]:
# One WNRE instance per tokenizer vocabulary. The next cell reads their
# synset_to_tokens / token_to_synsets mappings and related_synsets['hypernymy'].
# NOTE(review): rel_depth=1 presumably limits relations to direct (one-hop) neighbours — confirm in run_bort.WNRE.
wnre_bert = WNRE('bert-base-uncased', rel_depth=1)
wnre_roberta = WNRE('roberta-base', rel_depth=1)

In [None]:
# Restrict everything to the synsets and tokens known to BOTH tokenizers' WNREs.
overlapping_synsets = set(wnre_bert.synset_to_tokens.keys()) & set(wnre_roberta.synset_to_tokens.keys())
overlapping_tokens = set(wnre_bert.token_to_synsets.keys()) & set(wnre_roberta.token_to_synsets.keys())

# synset -> its shared tokens (in BERT's order); synsets left with no shared token are omitted.
synset_to_tokens = {}
for synset, tokens in wnre_bert.synset_to_tokens.items():
    if synset not in overlapping_synsets:
        continue
    shared_tokens = [tok for tok in tokens if tok in overlapping_tokens]
    if shared_tokens:
        synset_to_tokens[synset] = shared_tokens

# hyponym synset -> retained hypernym synsets; both ends must have at least one shared token.
hypernymy_synsets = {}
for synset, hypernyms in wnre_bert.related_synsets['hypernymy'].items():
    if synset not in synset_to_tokens:
        continue
    retained = [hyper for hyper in hypernyms if hyper in synset_to_tokens]
    if retained:
        hypernymy_synsets[synset] = retained

## Create examples from extracted hypernym pairs

In [None]:
import inflect
# Shared inflect engine used by the two helpers below.
p = inflect.engine()


def get_num(x):
    """Return "plural" if inflect can singularize `x` (truthy result), else "singular"."""
    if p.singular_noun(x):
        return "plural"
    return "singular"


def get_article(x):
    """Return the first whitespace-separated token of `p.a(x)`, i.e. the article inflect puts before `x`."""
    return p.a(x).split()[0]

In [None]:
def create_example(hypernym: str, hyponym: str,
                   mask_token: str = "[MASK]",
                   **kwargs):
    """Build the two cloze sentences for one (hypernym, hyponym) word pair.

    Args:
        hypernym: the more general word.
        hyponym: the more specific word.
        mask_token: placeholder substituted for the word being masked.
        **kwargs: accepted and ignored.

    Returns:
        A dict with the two words, their grammatical number
        ("singular"/"plural", as reported by `get_num`), and the keys
        "masked_hypernym" / "masked_hyponym" holding the sentence with the
        respective word replaced by `mask_token`.
    """
    number_hypernym = get_num(hypernym)
    number_hyponym = get_num(hyponym)
    # The article is always derived from the real hyponym, even when the
    # hyponym itself is the masked word.
    hyponym_article = get_article(hyponym)

    out = {
        "hypernym": hypernym,
        "hyponym": hyponym,
        "number_hypernym": number_hypernym,
        "number_hyponym": number_hyponym,
    }

    for to_mask in ("hypernym", "hyponym"):
        # replace the selected word with the mask token
        hyper = mask_token if to_mask == "hypernym" else hypernym
        hypo = mask_token if to_mask == "hyponym" else hyponym

        # templates keyed by (hyponym number, hypernym number)
        templates = {
            ("singular", "singular"): f"{hyponym_article} {hypo} is a type of {hyper}.",
            ("plural", "singular"): f"{hypo} are a type of {hyper}.",
            ("plural", "plural"): f"{hypo} are types of {hyper}.",
            ("singular", "plural"): f"types of {hyper} include {hyponym_article} {hypo}.",
        }
        out[f"masked_{to_mask}"] = templates[(number_hyponym, number_hypernym)]

    return out

# Expand every (hyponym synset, hypernym synset) pair into word-level examples.
# A given (hypernym word, hyponym word) pair is emitted only once; the first
# synset pair that produces it supplies the lex_*/syn_* metadata.
rows = []
seen_pairs = set()
for hypo_syn, hyper_syns in tqdm(hypernymy_synsets.items()):
    hypo_lex = wn.synset(hypo_syn).lexname()
    hypo_toks = synset_to_tokens[hypo_syn]
    for hyper_syn in hyper_syns:
        hyper_lex = wn.synset(hyper_syn).lexname()
        synset_info = dict(
            lex_hypernym=hyper_lex, lex_hyponym=hypo_lex,
            syn_hypernym=hyper_syn, syn_hyponym=hypo_syn
        )
        for hyper_tok in synset_to_tokens[hyper_syn]:
            for hypo_tok in hypo_toks:
                k = f"{hyper_tok}, {hypo_tok}"
                # skip identical words and word pairs that were already emitted
                if hypo_tok == hyper_tok or k in seen_pairs:
                    continue
                seen_pairs.add(k)
                rows.append({**create_example(hyper_tok, hypo_tok), **synset_info})

## Filter noisy examples

In [None]:
# Collect all generated examples into one DataFrame (one row per unique hypernym/hyponym word pair).
df = pd.DataFrame(rows)
print(f"Num pairs: {len(df)}")
# Random 20-row preview; no random_state is passed, so the rows shown differ between runs.
df.sample(20)

### Remove noun.Tops
Tops are the top-level entries of the WordNet hierarchy (lexname `noun.Tops`). Pairs whose hypernym is one of them are often noisy because the hypernym is too general (e.g. "Mortal"), so we remove them.

In [None]:
# Drop every pair whose hypernym synset has the 'noun.Tops' lexname.
is_top_level = df['lex_hypernym'] == 'noun.Tops'
df2 = df.loc[~is_top_level]
print(f"Num pairs: {len(df2)}")
df2.sample(20)

### Remove disproportionately represented hypernyms
Certain hypernyms have a disproportionate number of hyponyms, often because their word senses are too general or because of lemmatization errors. We manually inspect the most over-represented hypernyms in HypCC and filter out the noisy hypernym word forms.

In [None]:
# Print the 100 most frequent hypernym word forms with their pair counts.
top_hypernym_counts = df2['hypernym'].value_counts().head(100).to_dict()
for h, c in top_hypernym_counts.items():
    print(h, c)

In [None]:
# Manual inspection cell: swap 'force' for each over-represented hypernym listed
# above to eyeball its pairs and identify noisy cases.
# The sample size is capped at the number of matching rows: a plain sample(20)
# raises ValueError when fewer than 20 rows match, e.g. when this cell is re-run
# after the stop-list cell below has already removed the hypernym from df2.
df2.loc[df2['hypernym'] == 'force'].pipe(lambda matches: matches.sample(min(20, len(matches))))

In [None]:
# Hypernym word forms removed after manual inspection because they are
# overloaded, unnatural, etc. Only hypernyms occurring in >= 50 pairs were
# inspected.
# Note: WordNet/NLTK lemmatization is unreliable, which is presumably why odd
# forms such as 'mans' or 'informations' have to be listed explicitly.
stop_hypers = [
    'content', 'contents',
    'instrumentation',
    'part', 'parts',
    'line',
    'condition', 'conditions',
    'section', 'sections',
    'substance', 'substances',
    'point', 'points',
    'action', 'actions',
    'statement', 'statements',
    'work', 'works',
    'force', 'forces',
    'spot', 'spots',
    'set', 'sets',
    'men', 'man', 'mans',
    'women', 'woman',
    'portion', 'portions',
    'piece', 'pieces',
    'thought', 'thoughts',
    'people', 'peoples',
    'instrument', 'instruments',
    'country', 'countries',
    'paper', 'papers',
    'information', 'informations',
    'land', 'lands',
    'field', 'fields',
    'form', 'forms',
    'situation', 'situations',
    'way', 'ways',
    'play', 'plays',
    'parcel', 'parcels',
    'expert', 'experts',
]
# keep only the pairs whose hypernym is not on the stop list
df2 = df2.loc[~df2['hypernym'].isin(stop_hypers)]
print(f"Num pairs: {len(df2)}")
df2.sample(20)

## Save dataset

In [None]:
# Persist the filtered pairs. to_csv is called with its defaults, so the
# DataFrame index is written as an extra, unnamed first column.
df2.to_csv("hypcc.csv")

## Load and process dataset

In [None]:
# Column names of the grouped datasets produced by group_dataset.
TEXT_COL = 'masked_text'      # cloze sentence containing the mask placeholder
LABEL_COL = 'masked_tokens'   # list of words that are valid fillers for the mask
REL_COL = "context_token"     # the unmasked word of the pair that appears in the sentence

def group_dataset(df: pd.DataFrame, rel: str, number_rel: str, number_ctx: str, prefix: str = ""):
    """Turn the flat pair table into one cloze dataset for a given masking setup.

    Args:
        df: pair table with the columns hypernym, hyponym, number_*, masked_*,
            lex_* and syn_* (for both "hypernym" and "hyponym"). Not modified.
        rel: which word is masked and predicted, 'hypernym' or 'hyponym'.
        number_rel: keep only rows whose masked word has this grammatical
            number ('singular' or 'plural'); 'all' disables the filter.
        number_ctx: same filter for the context (unmasked) word.
        prefix: text prepended to every cloze sentence.

    Returns:
        A DataFrame with one row per unique (TEXT_COL, REL_COL) combination.
        Every remaining column (LABEL_COL, plus any extra columns present in
        `df`) is aggregated into a list over the rows of that group.

    Raises:
        AssertionError: if `rel`, `number_rel` or `number_ctx` is not one of
            the accepted values.
    """
    assert rel in ['hypernym', 'hyponym'], "rel must be 'hypernym' or 'hyponym'."
    assert number_rel in ['singular', 'plural', 'all'], "number_rel must be 'singular', 'plural' or 'all'."
    assert number_ctx in ['singular', 'plural', 'all'], "number_ctx must be 'singular', 'plural' or 'all'."
    ctx = "hypernym" if rel == "hyponym" else "hyponym"
    text_col = f"masked_{rel}"
    label_col = rel

    # Filter by grammatical number. 'all' means "no restriction"; comparing the
    # column against the literal 'all' would discard every row.
    # Positional boolean masks are used so that duplicated index labels cannot
    # remove unrelated rows.
    selected = df
    if number_rel != 'all':
        selected = selected.loc[(selected[f"number_{rel}"] == number_rel).to_numpy()]
    if number_ctx != 'all':
        selected = selected.loc[(selected[f"number_{ctx}"] == number_ctx).to_numpy()]
    df = selected.copy()  # never mutate the caller's frame

    df[text_col] = df[text_col].apply(lambda text: prefix + text)
    df = df.drop(
        columns=[
            f"masked_{ctx}", f"number_{rel}", f"number_{ctx}",
            f"lex_{ctx}", f"lex_{rel}", f"syn_{ctx}", f"syn_{rel}",
        ]
    )
    df = df.rename(columns={text_col: TEXT_COL, label_col: LABEL_COL, ctx: REL_COL})

    # Collect all labels sharing a sentence and context word into one list.
    # agg(list) replaces the deprecated agg(tuple).applymap(list) round trip.
    agg_cols = [TEXT_COL, REL_COL]
    return df.groupby(agg_cols).agg(list).reset_index()

In [None]:
# Reload the saved pairs. keep_default_na=False keeps words such as "nan" or
# "null" as literal strings; by default read_csv parses them as missing values.
df2 = pd.read_csv("hypcc.csv", keep_default_na=False)
# If the CSV carries the unnamed index column that to_csv writes by default,
# drop it; otherwise it would be aggregated into every grouped dataset as a
# spurious list column.
df2 = df2.drop(columns=["Unnamed: 0"], errors="ignore")

# sentence prefixes, keyed by the suffix used in the dataset names
prefixes = dict(
    no_ctx="",
    ctx="In the context of hypernymy, ",
)

# One grouped dataset per (masked relation, number of masked word, number of
# context word, prefix) combination, keyed "<rel>_<num_rel>_<num_ctx>_<prefix name>".
# The loop variable is deliberately not called `p`: that global holds the
# inflect engine used by get_num/get_article, and overwriting it would break
# any later re-run of the example-creation cell.
datasets = {}
for prefix_name, prefix in prefixes.items():
    for rel in ['hypernym', 'hyponym']:
        for num_rel in ['singular', 'plural']:
            for num_ctx in ['singular', 'plural']:
                key = "_".join([rel, num_rel, num_ctx, prefix_name])
                datasets[key] = group_dataset(
                    df=df2,
                    rel=rel,
                    number_rel=num_rel,
                    number_ctx=num_ctx,
                    prefix=prefix
                )

In [None]:
TOKENIZERS = ['bert-base-uncased', 'roberta-base']
tknzrs = {t: tr.AutoTokenizer.from_pretrained(t) for t in TOKENIZERS}

# tokenized[tokenizer name][dataset key]  -> tokenized and padded dataset
# categories[tokenizer name][dataset key] -> sorted unique label token ids of that dataset
tokenized = {}
categories = {}
for t in TOKENIZERS:
    tknzr = tknzrs[t]
    tokenized[t] = {}
    categories[t] = {}
    for k, d in datasets.items():
        d = ds.Dataset.from_pandas(d)
        # replace the generic "[MASK]" placeholder with the tokenizer-specific mask token
        d = d.map(lambda e: {TEXT_COL: e[TEXT_COL].replace("[MASK]", tknzr.mask_token)})

        # tokenize
        d = d.map(lambda e: tknzr(e[TEXT_COL]), batched=True)

        # pad every example to the longest sequence of the whole dataset.
        # padding='max_length' is required for max_length to take effect: with the
        # default padding=True, pad() pads to the longest sequence of each map
        # batch and ignores max_length, so datasets spanning several batches
        # would end up with rows of different lengths.
        max_length = max([len(x) for x in d['input_ids']])
        d = d.map(lambda e: tknzr.pad(e, padding='max_length', max_length=max_length), batched=True)

        # encode label without special tokens (add prefix space for gpt/roberta tokenizers 
        # as the target is not a start of sentence, see:
        # https://discuss.huggingface.co/t/bpe-tokenizers-and-spaces-before-words/475/2)
        d = d.map(lambda e: {'labels': [tknzr.encode(f' {x}', add_special_tokens=False) for x in e[LABEL_COL]]})
        d = d.map(lambda e: {REL_COL: tknzr.encode(f' {e[REL_COL]}', add_special_tokens=False)})

        # find invalid targets (i.e. multi-token words). Each entry of d['labels'] is a
        # list of encodings (one per label of that example), so it is flattened to
        # individual encodings; decoding a whole row at once would fail.
        invalid_targets = [ids for row in d['labels'] for ids in row if len(ids) > 1]
        invalid_targets += [ids for ids in d[REL_COL] if len(ids) > 1]
        if invalid_targets:
            invalid_targets = set([tknzr.decode(x) for x in invalid_targets])
            raise ValueError(f"Invalid tokenizer {t}, the following targets are multi-token: {invalid_targets}.")

        # now that we know every label is single token, remove list formatting
        d = d.map(lambda e: {'labels': [x[0] for x in e['labels']]})
        d = d.map(lambda e: {REL_COL: e[REL_COL][0]})

        # let ds handle casting data to torch tensors, but specify output_all_columns to obtain non-tensor labels
        d.set_format(type='torch', columns=['input_ids', 'attention_mask'], output_all_columns=True)

        tokenized[t][k] = d
        categories[t][k] = list(sorted(set(chain.from_iterable(d['labels']))))

In [None]:
# Display the RoBERTa entry of `tokenized` (a dict: dataset key -> tokenized dataset) as a sanity check.
tokenized['roberta-base']

In [None]:
# Write every tokenized dataset to preprocessed/hypcc_<tokenizer>_<dataset key>.
for t, per_dataset in tokenized.items():
    for k, dataset in per_dataset.items():
        dataset.save_to_disk(f"preprocessed/hypcc_{t}_{k}")