In [None]:
%load_ext autoreload
%autoreload 2

from tqdm import tqdm

# Model Analysis section of paper

## 6.1 Meta-linguistic properties
Table 3b is produced from the statistics that will be outputted in baselines/baseline_t5 and
baselines/baseline_knn.

Nothing else is required

## 6.2 Disjointness
This section contains two parts:
1. Subset analysis
2. Analysis on distinct splits

Producing the table rows:
1. Row 1 in Table 3b (naive, entire split) is the same as T5 (lengths) in the main
 results table and is produced by running baselines/baseline_t5.
1. Subset not in train - see below for setup
1. Naive disjoint - run in the same way as baselines/baseline_t5, but on the naive_disjoint split, i.e.
    - Train on the naive disjoint split:
```python
python train_clues.py --default_train=base --name=baseline_naive_disjoint --project=baseline --wandb_dir='./wandb' --data_dir='../data/clue_json/guardian/naive_disjoint'
```
    - Eval to get 100 generations (on the best epoch checkpoint), with and without `--test` flag
```python
python train_clues.py --default_val=base --name=baseline_naive_disjoint_val --project=baseline --data_dir='../data/clue_json/guardian/naive_disjoint' --ckpt_path='./wandb/run_name/files/epoch_10.pth.tar'
```
    - run the eval script
`vt.load_and_run_t5('decrypt/t5_outputs/baseline_naive_disjoint_val.preds')`
1. word initial disjoint split - same as T5 (lengths) row in main results table (see baselines/baselines_t5)


In [None]:
# eval on the subset not seen in train
import random
from decrypt import config
from decrypt.scrape_parse import load_guardian_splits
from decrypt.common import validation_tools as vt
from decrypt.scrape_parse.util import str_hash as safe_hash

In [None]:
# Build a predicate that keeps only those input-output pairs whose answer
# was never seen as a solution during training.
def make_filter_fn():
    """Return a filter that omits predictions whose target occurs in train.

    The training split is loaded once and its ``soln_with_spaces`` values are
    collected into a set. The returned closure yields ``False`` (omit) when a
    prediction's ``target`` is in that set and ``True`` (keep) otherwise.
    """
    _, _, (train_clues, _, _) = load_guardian_splits(config.DataDirs.Guardian.json_folder, verify=True)
    seen_answers = {clue.soln_with_spaces for clue in train_clues}

    def filter_fcn(mp: vt.ModelPrediction):
        # False means "omit this prediction"
        return mp.target not in seen_answers

    return filter_fcn


# Evaluate the naive-split predictions, keeping only targets that do not
# appear as a solution in the training split.
vt.load_and_run_t5('outputs/naive_baseline.preds',
                   filter_fcn=make_filter_fn())

In [None]:
# Not reported in the paper.
# Same idea as the exact filter, but the overlap test is fuzzy: simple
# plural variants ('s' / 'es') of training answers also count as seen.
def make_set_inclusion_filter_fcn_fuzz():
    """Return a filter that omits predictions fuzzily overlapping with train.

    For every training answer the set of "seen" strings receives the answer
    itself, the answer with 's' and 'es' appended, and (when the answer ends
    in 's' / 'es') the answer with that suffix removed. A prediction is
    omitted (``False``) when its target, or its target minus the last one or
    two characters, is in that set.
    """
    _, _, (train_clues, _, _) = load_guardian_splits(config.DataDirs.Guardian.json_folder, verify=True)

    seen = set()
    for clue in train_clues:
        answer = clue.soln_with_spaces
        # the answer and its naive plural forms
        seen.update((answer, answer + 's', answer + 'es'))
        # naive singular forms
        if answer.endswith('s'):
            seen.add(answer[:-1])
            if answer.endswith('es'):
                seen.add(answer[:-2])

    def filter_fcn(mp: vt.ModelPrediction):
        # False means "omit this prediction"
        is_seen = (mp.target in seen
                   or mp.target[:-1] in seen
                   or mp.target[:-2] in seen)
        return not is_seen

    return filter_fcn

# Same evaluation as above, but with the fuzzy (plural-aware) overlap filter.
vt.load_and_run_t5('outputs/naive_baseline.preds',
                   filter_fcn=make_set_inclusion_filter_fcn_fuzz())

## 6.3 Wordplay minimal task
Here we provide code to
1. produce the descramble dataset
1. run the two variations of the descramble task
    - with and without adding the definition
    - on random and word-initial disjoint splits
    - copy task (i.e. no scrambling)

In [None]:
# first prepare the descrambling dataset
from decrypt.scrape_parse.acw_load import get_clean_xd_clues
from sklearn.model_selection import train_test_split
from decrypt.common.util_data import write_json_tuple

# source file handed to get_clean_xd_clues
k_xd_orig_tsv = config.DataDirs.OriginalData.k_xd_cw
# export directories for the random and the word-initial disjoint descramble splits
k_descramble_rand = config.DataDirs.DataExport.descramble_random
k_descramble_disj = config.DataDirs.DataExport.descramble_word_init_disjoint

# method will produce a version of the ACW dataset that is
# - single words that appear in our dictionary
# - without exact duplicates (but note that some answers will occur multiple times with different clues)
# - filtered to answer words with between 4 and 14 characters
# - downsampled to 10%
def make_descramble_json(seed=42):
    """Build the descramble dataset and export a random and a disjoint split.

    Two (train, val) pairs are written with ``write_json_tuple``: a random
    90/10 split to ``k_descramble_rand`` and a split keyed on the hash of the
    first two letters of the answer to ``k_descramble_disj``.

    Args:
        seed: seeds the downsampling, the shuffling of the disjoint split and
            ``train_test_split`` for the random split.
    """
    _, all_clues = get_clean_xd_clues(k_xd_orig_tsv,
                                      remove_if_not_in_dict=True,     # only single words
                                      do_filter_dupes=True)

    # keep only answers with 4..14 characters
    all_clues_len = [x for x in all_clues if 3 < len(x.soln) < 15]
    print(len(all_clues_len))

    # downsample to 10 percent
    # NOTE(review): the sample size is 10% of the *unfiltered* clue count
    # (len(all_clues)), not of the length-filtered list. Left as is so the
    # exported dataset does not change - confirm this is intended.
    rng = random.Random(seed)
    all_clues_10per = rng.sample(all_clues_len, k=int(len(all_clues)*.1))
    print(len(all_clues_10per))
    print(all_clues_10per[0])

    # logic the same as scrape_parse.guardian_load.make_disjoint_split()
    def make_dataset(disj: bool):
        # one json dict per clue; answers are lower-cased without touching
        # the clue objects themselves
        json_all = [dict(defn=c.clue, target=c.soln.lower())
                    for c in all_clues_10per]

        train, val = [], []     # list of json dicts
        # re-seed so the shuffles do not depend on how often rng was used before
        rng.seed(seed)
        if disj:
            # an answer's first two letters decide its split, so no
            # two-letter prefix is shared between train and val
            for c in json_all:
                idx = safe_hash(c['target'][:2]) % 10
                if idx == 0:
                    val.append(c)
                else:
                    train.append(c)
            rng.shuffle(train)
            rng.shuffle(val)
        else:   # not disj
            train, val = train_test_split(json_all, test_size=.1, random_state=seed)

        return train, val

    write_json_tuple(list(make_dataset(disj=False)),    # train, val
                     comment="",
                     export_dir=k_descramble_rand)

    write_json_tuple(list(make_dataset(disj=True)),     # train, val
                     comment="",
                     export_dir=k_descramble_disj)

# write both descramble splits with the default seed
make_descramble_json()

To run (e.g. on random split with phrasal definition)
```python
python train_descramble.py --default_train=base --name=descramble_random_phrase --project=descramble --wandb_dir='./wandb' --data_dir='../data/clue_json/descramble/random_split/' --generation_beams=10 --no_save --randomize_train_scramble --add_defn
```

The other variations we present are
- word initial split ('./data/clue_json/descramble/word_initial')
- --no_defn (do not append phrase), instead of --add_defn
- --copy (to test the identity task)

Model analysis (eval) runs are not necessary: we can just read off the metric from the best model.

## 6.4 Wordplay systematic learning
- identify clues with anagram of first name
- two sets: scramble / substitute
- run / evaluate
- average character level overlap

In [None]:
from decrypt.common.label_anagrams import make_label_set
from glob import glob
from collections import defaultdict
from typing import *
from decrypt.common.substitution import ClueWithSubstitutions, Substitution
from decrypt.common.puzzle_clue import BaseClue
import json
import os

In [None]:
def load_train_and_val_direct_anag_sets() -> "Tuple[List[BaseClue], List[BaseClue]]":
    """Return the train and val clues that carry the 'anag_direct' label.

    The label set maps a label name to clue indices; each index is looked up
    among the train clues and among the val clues of the Guardian splits, and
    indices missing from a split are skipped for that split. Both list sizes
    are printed.

    Returns:
        ``(train_set, val_set)``: two lists of clue objects, each ordered as
        the indices appear in the label set.
    """
    # this will load things twice, but whatever
    labels = make_label_set()
    _, _, (train, val, _) = load_guardian_splits(config.DataDirs.Guardian.json_folder)

    # pick the pre-labeled clues that are present in the given split
    def select(clues, label_name: str):
        idx_to_clue = {c.idx: c for c in clues}
        return [idx_to_clue[idx] for idx in labels[label_name] if idx in idx_to_clue]

    train_set = select(train, 'anag_direct')
    val_set = select(val, 'anag_direct')
    print(len(train_set))
    print(len(val_set))
    return train_set, val_set


In [None]:
# Filter to clues that work for our substitution
# - only one word targets
# - no punctuation
# Produce clues with 20 substitutions (10 scramble, 10 sub)
def get_clues_with_subs(
        clue_list: List[BaseClue],
        num_subs=10) -> List[ClueWithSubstitutions]:
    """Create scrambled and name-swapped variants of first-name anagram clues.

    A clue qualifies when its answer is a single word and one word of the
    clue text is purely alphabetic, is a letter-for-letter anagram of the
    answer and occurs in the list of first names. For each qualifying clue,
    ``num_subs`` variants with that word's letters shuffled are produced,
    followed by ``num_subs`` variants with that word replaced by a sampled
    first name of the same length.

    Args:
        clue_list: candidate clues.
        num_subs: number of substitutions of each kind.

    Returns:
        One ``ClueWithSubstitutions`` per qualifying clue, holding
        ``2 * num_subs`` substitutions (scrambles first, then names).

    Raises:
        ValueError: from ``random.sample`` when fewer than ``num_subs`` names
            have the same length as the matched word.
    """
    random.seed(42)

    # load all first names and group them by length
    def get_names_map() -> Tuple[List[str], Dict[int, List[str]]]:
        names = []
        dir_glob = glob(str(config.DataDirs.OriginalData.k_names / "*.txt"))
        for name_file in dir_glob:
            with open(name_file, 'r') as f:
                names.extend(x.strip() for x in f)

        # Dedupe, then sort. Iterating a set of strings has an order that
        # depends on hash randomisation, so without the sort the seeded
        # random.sample below would pick different names on every run.
        all_names = sorted(set(names))

        length_to_names = defaultdict(list)
        for name in all_names:
            length_to_names[len(name)].append(name)
        return all_names, length_to_names

    all_names_list, names_map = get_names_map()
    all_names_set = set(all_names_list)

    # For each of the possible clues, look for those that have a first name
    clues_with_subs = []        # we accumulate ClueWithSubstitution
    for sc in tqdm(clue_list):
        if len(sc.lengths) > 1:     # skip multiword outputs
            continue

        s = sorted(sc.soln.lower())
        anagram_substrate = None
        word_idx = None

        # find the word that was actually the anagram substrate (i.e. we are looking for
        # anagrams that are a full word)
        words: List[str] = sc.clue_with_lengths().split(' ')
        for idx, w in enumerate(words):
            if len(w) != len(s):        # cannot be the substrate if the length differs
                continue
            if not w.isalpha():         # skip, e.g. any anagram substrate that has some punctuation
                continue
            if sorted(w.lower()) != s:  # verify that the word is in fact an anagram
                continue
            if w in all_names_set:      # case-sensitive match against the name list
                anagram_substrate = w
                word_idx = idx
                break
        # skip clue if we didn't find the anagram substrate
        if anagram_substrate is None:
            continue

        # success: we found the anagram substrate!
        print(sc)

        # create substitutions
        # we will fill in the substitutions part below with num_subs each of
        # - scramble
        # - new name
        sub_clue = ClueWithSubstitutions(
            orig_input=sc.clue_with_lengths(),
            word_to_be_swapped=anagram_substrate,
            target=sc.soln,
            substitutions=[]
        )

        # scramble substitution
        for _ in range(num_subs):
            letters = list(anagram_substrate.lower())
            random.shuffle(letters)
            new_word = "".join(letters).capitalize()
            words[word_idx] = new_word
            new_clue_str = " ".join(words)
            sub_clue.substitutions.append(Substitution(new_clue_str, new_word))

        # real name substitutions
        # NOTE(review): the pool still contains the original name, so a
        # sampled "substitution" can leave the clue unchanged.
        subs_list = random.sample(names_map[len(anagram_substrate)], k=num_subs)
        for new_word in subs_list:
            words[word_idx] = new_word
            new_clue_str = " ".join(words)
            sub_clue.substitutions.append(Substitution(new_clue_str, new_word))
        clues_with_subs.append(sub_clue)
    return clues_with_subs

# serialise the substitution clues so the model run can read them back
def write_json(input_list, name):
    """Dump ``input_list`` (via each item's ``to_dict()``) to ``<name>_subs.json``.

    The file is written into the wordplay export directory, which is created
    if it does not exist yet.
    """
    os.makedirs(config.DataDirs.DataExport.wordplay_dir, exist_ok=True)
    out_path = config.DataDirs.DataExport.wordplay_dir / f'{name}_subs.json'
    with open(out_path, 'w') as out_file:
        json.dump([item.to_dict() for item in input_list], out_file)

In [None]:
# the train / val clues labeled 'anag_direct'
train_list, val_list = load_train_and_val_direct_anag_sets()

In [None]:
# build the substitution sets; each call re-seeds the global `random` with 42
val_subs = get_clues_with_subs(val_list)
train_subs = get_clues_with_subs(train_list)
print(len(val_subs))
print(len(train_subs))

In [None]:
# verify that we can round-trip to / from json:
# to_dict followed by from_dict must compare equal to the original objects
v_subs_json = [v.to_dict() for v in val_subs]
v_subs_orig = [ClueWithSubstitutions.from_dict(v) for v in v_subs_json]
assert v_subs_orig == val_subs
print('ok')

In [None]:
# writes train_subs.json and val_subs.json into the wordplay export directory
write_json(train_subs, 'train')
write_json(val_subs, 'val')

At this point, we should have the json on which we will run the experiment.
Now we load and evaluate on these cluesets. For the paper we used the best model for T5 (lengths),
i.e., second to last row in Table 2

In [None]:
from seq2seq.model_runner import ModelRunner

# both values are passed to ModelRunner in run_model
k_model_name = 't5-base'
k_ckpt_path = ''    # substitute the checkpoint path

def run_model(name):
    """Generate model outputs for every clue group stored in ``<name>_subs.json``.

    For each stored ``ClueWithSubstitutions`` the original input followed by
    all substituted clue strings is passed to the model in one
    ``generate`` call.

    Returns:
        One list per clue group; its entries are the ``tolist()`` form of
        what ``generate`` returned for each input, in input order.
    """
    runner = ModelRunner(k_model_name, k_ckpt_path)

    subs_path = config.DataDirs.DataExport.wordplay_dir / f'{name}_subs.json'
    with open(subs_path, 'r') as subs_file:
        clue_groups: List[ClueWithSubstitutions] = [
            ClueWithSubstitutions.from_dict(entry) for entry in json.load(subs_file)
        ]

    all_outputs: List[List[str]] = []
    for group in clue_groups:
        # unmodified clue first, then every substituted variant
        prompts = [group.orig_input]
        prompts.extend(sub.new_clue_str for sub in group.substitutions)
        generated = runner.generate(prompts)
        all_outputs.append([g.tolist() for g in generated])
    # to get off a cluster, e.g.
    # with open(f'./data_util/{name}_results.json', 'w') as f:
    #     json.dump(all_outputs, f)
    return all_outputs

# alternatively, could write to json and then read back in, e.g. if on cluster
train_results = run_model('train')
val_results = run_model('val')
# one output group per clue group
assert len(train_results) == len(train_subs)
assert len(val_results) == len(val_subs)


# each output grouping should have 21 entries (1 unchanged, 10 scramble, 10 sub)
for i in train_results:
    assert len(i) == 21
for i in val_results:
    assert len(i) == 21

In [None]:
import multiset
from collections import Counter

# adapted from decrypt.common.validation_tools
def calc_metrics(substrate_word: str,
                 sampled: List[str],
                 tgt: str,
                 ctr: Counter,
                 truncate=None):
    """Accumulate results in the counter.

    Samples, ``substrate_word`` and ``tgt`` are compared case-insensitively;
    spaces are removed from the samples.

    Args:
        substrate_word: the word in the clue that was (or replaced) the
            anagram substrate.
        sampled: model outputs for one input, best first.
        tgt: the correct answer.
        ctr: counter that is updated in place. Keys: ``total`` (+1 per call),
            ``top`` (first sample equals the target), ``in_res`` (target is
            among the samples), ``first_letter`` (share of samples starting
            with the substrate's first letter), ``char_overlap_sub`` /
            ``char_overlap_tgt`` (mean over samples of the multiset character
            overlap with the substrate / target, divided by sample length).
        truncate: if given, only the first ``truncate`` samples are used.

    Empty samples contribute 0 to the per-sample averages, and an empty
    sample list is counted with all metrics 0 instead of raising.
    """
    if truncate is not None:
        sampled = sampled[:truncate]

    substrate_word = substrate_word.lower()
    # samples are lower-cased below, so the target must be as well
    tgt = tgt.lower()
    samples_no_spaces = [x.lower().replace(' ', '').strip() for x in sampled]

    top_res = 1 if samples_no_spaces and samples_no_spaces[0] == tgt else 0
    in_res = 1 if tgt in samples_no_spaces else 0

    # Counter is a multiset: `&` keeps the minimum count per character
    substrate_chars = Counter(substrate_word)
    target_chars = Counter(tgt)

    ct_match_first_letter = 0
    avg_char_overlap_with_substrate = 0.0
    avg_char_overlap_with_target = 0.0
    for s in samples_no_spaces:
        if not s:       # an empty generation matches nothing
            continue
        if s[0] == substrate_word[:1]:
            ct_match_first_letter += 1
        sample_chars = Counter(s)
        avg_char_overlap_with_substrate += sum((substrate_chars & sample_chars).values()) / len(s)
        avg_char_overlap_with_target += sum((target_chars & sample_chars).values()) / len(s)

    num_samples = len(sampled) or 1     # avoid dividing by zero when nothing was sampled
    ctr['total'] += 1
    ctr['top'] += top_res
    ctr['in_res'] += in_res
    ctr['first_letter'] += ct_match_first_letter / num_samples
    ctr['char_overlap_sub'] += avg_char_overlap_with_substrate / num_samples
    ctr['char_overlap_tgt'] += avg_char_overlap_with_target / num_samples


def eval_(pairs, sample_truncate=3, num_subs=10):
    """Print averaged metrics for the original, scrambled and name-swapped clues.

    Args:
        pairs: iterable of ``(sub_group, result_list)``. ``result_list`` must
            hold the samples for the unmodified clue at index 0, followed by
            ``num_subs`` scramble results and then the name-substitution
            results, in the same order as ``sub_group.substitutions``.
        sample_truncate: number of samples per input passed on to
            ``calc_metrics`` as ``truncate``.
        num_subs: number of substitutions of each kind per clue group.

    Three blocks are printed (original, scramble, new name); every counter
    entry is divided by that counter's ``total``.
    """
    ctr_orig = Counter()
    ctr_scramble = Counter()
    ctr_new = Counter()
    for sub_group, result_list in pairs:
        orig_substrate = sub_group.word_to_be_swapped
        tgt = sub_group.target

        # unpack the base result (unmodified), scramble results, and substitution results
        # 0, 1..num_subs, num_subs+1..
        result_base = result_list[0]
        result_scramble = result_list[1:num_subs + 1]
        result_sub = result_list[num_subs + 1:]

        # base
        calc_metrics(orig_substrate, result_base, tgt, ctr_orig, truncate=sample_truncate)
        # scramble
        for i in range(num_subs):
            calc_metrics(substrate_word=sub_group.substitutions[i].substituted_word,
                         sampled=result_scramble[i],
                         tgt=tgt,
                         ctr=ctr_scramble,
                         truncate=sample_truncate)
        # substitution
        for i in range(num_subs):
            calc_metrics(substrate_word=sub_group.substitutions[i + num_subs].substituted_word,
                         sampled=result_sub[i],
                         tgt=tgt,
                         ctr=ctr_new,
                         truncate=sample_truncate)
    # print the counter results
    for c in [ctr_orig, ctr_scramble, ctr_new]:
        print()
        for k, v in c.items():
            print(f'{k}: {v/c["total"]}')

In [None]:
# pair every substitution group with the model outputs generated for it
val_pairs = list(zip(val_subs, val_results))
train_pairs = list(zip(train_subs, train_results))

truncate=10     # i.e. keep all 10; equivalent to not truncating
# prints three metric blocks per call: original, scramble, name substitution
eval_(val_pairs, sample_truncate=truncate)
eval_(train_pairs, sample_truncate=truncate)

## 6.5  Efrat comparison
The Efrat comparison has 3 parts:
1. Replication of their results
1. Creation and analysis on a word-initial disjoint version of their dataset
1. Training and eval of a curricular model with t5-large

### Replication
1. Download and unzip the cryptonite dataset from https://github.com/aviaefrat/cryptonite/tree/main/data.
    - We will use only the naive and official-split
    - Place in data/original/cryptonite
1. Convert to format for our model
1. Create the word initial disjoint split


In [None]:
import json_lines
from typing import *
from decrypt.common.puzzle_clue import BaseClue
from decrypt.common.util_data import clue_list_tuple_to_train_split_json

In [None]:
# load one split of a cryptonite dataset
def load_split(split_type: str, loc):
    """Read ``cryptonite-<split_type>.jsonl`` from ``loc``.

    Args:
        split_type: used verbatim in the file name (called below with
            'train', 'val' and 'test').
        loc: directory that contains the split files (path object or string).

    Returns:
        List with one decoded entry per line of the file.
    """
    from pathlib import Path

    # Build the file name before joining: `/` binds tighter than `+`, so
    # `loc / "cryptonite-" + split_type` would try to add a str to a path.
    file_loc = Path(loc) / f"cryptonite-{split_type}.jsonl"
    with open(file_loc, "rb") as f:
        return list(json_lines.reader(f))

def baseclue_from_cryptonite_clue(cryptonite_clue: Dict) -> BaseClue:
    """Convert one cryptonite entry into a ``BaseClue``.

    The trailing " (…)" enumeration is cut off the clue text, spaces are
    removed from the answer, and the ``enumeration`` field is parsed into a
    list of ints (comma separated, surrounding parentheses stripped).
    """
    raw_clue = cryptonite_clue['clue']
    answer = cryptonite_clue['answer'].replace(" ", "")
    enumeration = cryptonite_clue['enumeration']

    # the last "(" starts the enumeration; it must be preceded by a space
    paren_at = raw_clue.rfind("(")
    assert raw_clue[paren_at-1] == " "
    surface = raw_clue[:paren_at-1]

    lens = [int(part) for part in enumeration.strip("()").split(",")]
    return BaseClue(surface, lens, answer)

# # some tests of above method
# def test_baseclue_from_cryptonite():
#     pp(list(map(baseclue_from_cryptonite_clue, cryptonite_data_all[:5])))
#     # a multiword clue
#     for c in cryptonite_data_all:
#         if ',' in c['enumeration']:
#             pp(baseclue_from_cryptonite_clue(c))
#             break
# # test_baseclue_from_cryptonite()


def crypto_to_our_json_format(crypto_tuple, label: str,
                              export_dir,
                              verify=False):
    """Convert cryptonite (train, val, test) lists and export them as json.

    Args:
        crypto_tuple: three lists of raw cryptonite entries.
        label: appended to 'Cryptonite ' to form the export comment.
        export_dir: target directory handed to the export helper.
        verify: when true, assert for every entry that the converted clue
            reproduces the original clue text (``clue_with_lengths()``) and
            the original answer (``soln_with_spaces``).
    """
    tuple_base_clue_list = tuple(
        [baseclue_from_cryptonite_clue(entry) for entry in split]
        for split in crypto_tuple
    )

    # make sure our changes make sense
    if verify:
        for split_no in range(3):
            originals = crypto_tuple[split_no]
            converted = tuple_base_clue_list[split_no]
            for orig, new in zip(originals, converted):
                assert orig['clue'] == new.clue_with_lengths(), f'{orig}, {new}'
                assert orig['answer'].strip() == new.soln_with_spaces, f'{orig}, {new}'

    clue_list_tuple_to_train_split_json(tuple_base_clue_list,
                                        comment='Cryptonite ' + label,
                                        export_dir=export_dir,
                                        overwrite=False)

In [None]:
# load [train, val, test] for the official split ...
train_val_test_official = list(map(lambda x:
                                   load_split(x,
                                              loc=config.DataDirs.OriginalData.k_cryptonite_offical),
                                   ['train', 'val', 'test']))
# ... and for the naive split
train_val_test_naive = list(map(lambda x:
                                load_split(x,
                                           loc=config.DataDirs.OriginalData.k_cryptonite_naive),
                                ['train', 'val', 'test']))

# quick sanity check - should be the same length
all_official = [clue for clue_list in train_val_test_official
                for clue in clue_list]
all_naive = [x for l in train_val_test_naive
             for x in l]
assert len(all_official) == len(all_naive)
print(len(all_official))

# produce our json format for their splits
# (the official split is exported to the `crypto_naive_disjoint` directory)
crypto_to_our_json_format(train_val_test_official, label='official, theirs', verify=True,
                          export_dir=config.DataDirs.DataExport.crypto_naive_disjoint)
crypto_to_our_json_format(train_val_test_naive, label='naive', verify=True,
                          export_dir=config.DataDirs.DataExport.crypto_naive)

In [None]:
# Build our word-initial disjoint split: all clues whose answers share the
# same first two letters end up in the same split.
def make_disjoint_json(all_clue_list: List[BaseClue]):
    """Partition clues into [train, val, test] by hashing the answer prefix.

    Clues are grouped by ``soln``; each group goes to the split selected by
    ``safe_hash(soln[:2]) % 1000``: values below 899 to train, 899..948 to
    val and 949..999 to test. Each split is then shuffled with a fixed seed.
    Split sizes and the first three train clues (before and after the
    shuffle) are printed.

    Returns:
        ``[train, val, test]``, together containing every input clue.
    """
    clues_by_soln: defaultdict[str, List[BaseClue]] = defaultdict(list)
    for bc in tqdm(all_clue_list):
        clues_by_soln[bc.soln].append(bc)

    train, val, test = [], [], []
    for soln, clues in clues_by_soln.items():
        # 1000 buckets were used to get better split proportions
        bucket = safe_hash(soln[:2]) % 1000
        if bucket >= 949:
            destination = test
        elif bucket >= 899:
            destination = val
        else:
            destination = train
        destination.extend(clues)

    train_val_test = [train, val, test]
    print([len(split) for split in train_val_test])

    # now shuffle the lists
    rng = random.Random(42)
    print(train[:3])
    for split in train_val_test:
        rng.shuffle(split)
    print(train[:3])

    assert sum(len(split) for split in train_val_test) == len(all_clue_list)
    return train_val_test

# export the hash-based split built from the clues of the official split
clue_list_tuple_to_train_split_json(make_disjoint_json(all_official),
                                    comment='Cryptonite disjoint hash, ours',
                                    export_dir=config.DataDirs.DataExport.crypto_word_init_disjoint,
                                    overwrite=False)

### Running their model
For the naive split training can go to 20 epochs
This will use t5-large

1. for naive
```python
python train_clues.py --default_train=cryptonite --name=cryptonite_naive --project=cryptonite --wandb_dir='./wandb' --data_dir='../data/clue_json/cryptonite/naive'
```
1. also run the official split by modifying the data_dir

1. for word initial disjoint split, best performing model is around step 100k, so we need intraepoch eval
```python
python train_clues.py --default_train=cryptonite --name=cryptonite_word_init_disjoint --project=cryptonite --wandb_dir='./wandb' --data_dir='../data/clue_json/cryptonite/word_init_disjoint' --val_freq=100
```

Then we run evaluation on the test set
1.  for example,
```python
python train_clues.py --default_val=cryptonite --name=cryptonite_word_init_disjoint_val --project=cryptonite --wandb_dir='./wandb' --data_dir='../data/clue_json/cryptonite/word_init_disjoint' --test --ckpt_path='./path_to_model_ckpt'
```

### Curricular
Finally, we train their model using our curricular approach

For example, on word initial disjoint
```python
python train_clues.py --default_train=cryptonite --name=cryptonite_curr_word_init_disj --project=cryptonite --wandb_dir='./wandb' --data_dir='../data/clue_json/cryptonite/word_init_disjoint' --multitask=cfg_crypto_acw_acwdesc
```
This curricular approach has one fewer epoch of curricular pretraining. It takes
a long time (24 hours) to train since the crossword dataset (ACW) is so large.
For this reason, it is better to do curricular training only once (i.e. three epochs), and then
for each of the three datasplits, train from the pretrained checkpoint.
Pretraining takes roughly 6 hours per epoch.

And evaluation is unchanged from above (just adding --test and a checkpoint path)


