## Setting up for curricular experiment

This assumes you have already followed the instructions in `baselines/baseline_t5`, which will set up the baseline clue files for model input

### Datasets
1. Download and unzip the xd cw crossword set from http://xd.saul.pw/xd-clues.zip.
    - Save it as './data/original/xd/clues.tsv'
2. Preprocess the dataset using this notebook
3. The dataset will be saved to k_acw_export_dir (as a single train.json file)
4. We will also produce the anagram dataset


In [1]:
%load_ext autoreload
%autoreload 2

from decrypt.scrape_parse.acw_load import get_clean_xd_clues
from decrypt import config
from decrypt.common.util_data import clue_list_tuple_to_train_split_json
from decrypt.common import validation_tools as vt

k_xd_orig_tsv = config.DataDirs.OriginalData.k_xd_cw        # ./data/original/xd/clues.tsv
k_acw_export_dir = config.DataDirs.DataExport.xd_cw_json



In [6]:
# defaults to strip periods, remove questions, remove abbrevs, remove fillin
stc_map, all_clues = get_clean_xd_clues(k_xd_orig_tsv,
                                        remove_if_not_in_dict=False,
                                        do_filter_dupes=True)
clue_list_tuple_to_train_split_json((all_clues,),
                                    comment='ACW set; xd cw set, all',
                                    export_dir=k_acw_export_dir,
                                    overwrite=False)

INFO:root:loading xd (ACW) set from /Users/jsrozner/MOUNT/scdt/decrypt/data/original/xd/clues.tsv
INFO:root:Reading file into dict: /Users/jsrozner/MOUNT/scdt/decrypt/data/generated/twl_dict.txt
8719it [00:00, 87173.46it/s]

Initialized a spellchecker
This will fail if you have not downloaded or generated twl_dict.txt


178691it [00:00, 262449.77it/s]
INFO:root:Done reading file: /Users/jsrozner/MOUNT/scdt/decrypt/data/generated/twl_dict.txt
INFO:root:Reading file into dict: /Users/jsrozner/MOUNT/scdt/decrypt/data/original/us/US.dic
118619it [00:00, 224014.82it/s]
INFO:root:Done reading file: /Users/jsrozner/MOUNT/scdt/decrypt/data/original/us/US.dic
INFO:root:Done setting up spellchecker
6338031it [13:17, 7949.77it/s] 
INFO:root:Counter({'not_in_dict': 1315223, 'removed_trailing_period': 501854, 'fillin': 381519, 'removed_likely_abbrev': 100730, 'question word': 92650, 'empty': 75929, 'ref': 30220})
INFO:root:Filtered to 4341760 clues
  1%|          | 50438/4341760 [00:00<00:08, 504345.61it/s]

Counter({'not_in_dict': 1315223, 'removed_trailing_period': 501854, 'fillin': 381519, 'removed_likely_abbrev': 100730, 'question word': 92650, 'empty': 75929, 'ref': 30220})
DEL called for spellchecker


100%|██████████| 4341760/4341760 [00:15<00:00, 280043.52it/s]
100%|██████████| 86633/86633 [01:54<00:00, 755.28it/s]  


removed 2545682 exact dupes
1796078


INFO:root:Counter({1: 1796078})
100%|██████████| 1796078/1796078 [00:20<00:00, 85671.75it/s] 
INFO:decrypt.common.util_data:Source target mapping:
	Litigator's group (3) => aba



FileExistsError: Cannot write since file_name already exists and overwrite not specified

In [11]:
# produce anagram datasets
# roughly 3 minutes to complete
from decrypt.common import anagrammer
anagrammer.gen_db_with_both_inputs(update_flag="overwrite")

from decrypt.common.util_data import (
    get_anags,
    write_json_tuple
)
import json
import os

INFO:root:Overwriting database at /Users/jsrozner/MOUNT/scdt/decrypt/data/generated/anag_db
INFO:root:Adding to db /Users/jsrozner/MOUNT/scdt/decrypt/data/generated/anag_db with updateflag overwrite
118619it [00:29, 4053.51it/s]


Counter({1: 118609})
Done.


In [8]:
def make_anag_sets_json():
    all_anags = get_anags(max_num_words=-1)
    json_list = []
    for idx, a_list in enumerate(all_anags):
        json_list.append(dict(idx=idx,
                              anag_list=a_list))
    print(json_list[0])

    # normally would be (idx, input, tgt)
    output_tuple = [json_list,]

    os.makedirs(config.DataDirs.DataExport.anag_dir)
    write_json_tuple(output_tuple,
                     comment="List of all anagram groupings",
                     export_dir=config.DataDirs.DataExport.anag_dir,
                     overwrite=False)

def make_anag_indic_list_json():
    # make the indicator list
    with open(config.DataDirs.OriginalData.k_deits_anagram_list, 'r') as f:
        all_anag_indicators = f.readlines()
        print(len(all_anag_indicators))

    final_indic_list = []
    for a in all_anag_indicators:
        final_indic_list.append(a.replace('_', " ").strip())
    with open(config.DataDirs.DataExport.anag_indics, 'w') as f:
        json.dump(final_indic_list,f)

In [9]:
make_anag_sets_json()

INFO:root:Initializing (non-singleton) Anagrammer from /Users/jsrozner/MOUNT/scdt/decrypt/data/generated/anag_db
INFO:root:DONE: Initialized Anagrammer from /Users/jsrozner/MOUNT/scdt/decrypt/data/generated/anag_db
100%|██████████| 187438/187438 [00:04<00:00, 41704.05it/s]
INFO:root:Total anagramable: 13535


13535
13535
32305
['titan', 'taint', 'tat in']
{'idx': 0, 'anag_list': ['titan', 'taint', 'tat in']}


FileExistsError: [Errno 17] File exists: '/Users/jsrozner/MOUNT/scdt/decrypt/data/clue_json/curricular/anagram'

In [None]:
make_anag_indic_list_json()



## Curricular training
1. At this point you should have a files at
 - `./data/clue_json/curricular/ACW/train.json`
 - `./data/clue_json/curricular/anagram/[train.json, anag_indics.json]`

2. Running curricular training is the same as running main t5 vanilla train, except that we pass an extra multitask flag, which specifies the curriculum to use. See `seq2seq/multitask_config`. You should pass one of the names from  `multi_config` dict in that file

For example, to train the naive split with the top performing curricular approach (i.e. the result in table 3 that is ACW + ACW-descramble)
```python
python train_clues.py --default_train=base --name=naive_top_curricular --project=curricular --wandb_dir='./wandb' --data_dir='../data/clue_json/guardian/naive_random' --multitask=ACW__ACW_descramble
```

Note that the modifications on the dataset are done at the

3. To produce Table 3 of the results
    -  we don't need to do a model_eval run since the outputted predictions have 5 generations
       (which is all we report for that table (for faster experimental iteration).
    - we need to run `load_and_run_t5` on all outputs (column 1) and on the anagram subset (column 2)
      See below for how we do this.

4. For our top result in Table 2 (main resuls) we
    1. scale up the curricular period (to 4 total epochs)
```python
python train_clues.py --default_train=base --name=naive_top_curricular --project=curricular --wandb_dir='./wandb' --data_dir='../data/clue_json/guardian/naive_random' --multitask=final_top_result_scaled_up
```
    2. eval with full 100 generations, as before:
e.g., if epoch 10 is best (you'll need to set the run_name)
This runs the eval set (change the run_name)
```python
python train_clues.py --default_val=base --name=curricular_naive_top --project=curricular --data_dir='../data/clue_json/guardian/naive_random' --ckpt_path='./wandb/run_name/files/epoch_10.pth.tar
```


In [7]:
from decrypt.common.label_anagrams import make_label_set

labels = make_label_set()

INFO:decrypt.scrape_parse.guardian_load:loading from /Users/jsrozner/MOUNT/scdt/cryptic_nlp/decrypt_root/data/puzzles
INFO:decrypt.scrape_parse.guardian_load:Using file glob at /Users/jsrozner/MOUNT/scdt/cryptic_nlp/decrypt_root/data/puzzles/cryptic*.json
INFO:decrypt.scrape_parse.guardian_load:Glob has size 5518
INFO:decrypt.scrape_parse.guardian_load:Glob size matches the expected one from Decrypting paper
100%|██████████| 5518/5518 [01:23<00:00, 66.35it/s]
  5%|▍         | 6769/143991 [00:00<00:06, 19804.26it/s]

[("length punct: '", 1),
 ('invalid: clue group', 7687),
 ('invalid: invalid start char (most are continuation clues)', 607),
 ('invalid: number in clue (commonly references another clue)', 7066),
 ('invalid: regexp', 75),
 ('invalid: soln length does not match specified lens (multi box soln)', 56),
 ('invalid: unrecognized char in clue (e.g. html)', 85),
 ('invalid: zero-len clue text after regexp', 15),
 ('length punct: ,', 24644),
 ('length punct: -', 4148),
 ('length punct: .', 8),
 ('length punct: /', 1),
 ('stat: parsed_puzzle', 5518),
 ('stat: total_clues', 143991),
 (1, 119956),
 (2, 20272),
 (3, 2957),
 (4, 686),
 (5, 112),
 (6, 8)]
Total clues: len(puzz_list)


100%|██████████| 143991/143991 [00:00<00:00, 228548.02it/s]
100%|██████████| 55783/55783 [00:03<00:00, 15482.08it/s]


removed 1611 exact dupes
142380


INFO:decrypt.scrape_parse.guardian_load:Counter({1: 118540, 2: 20105, 3: 2929, 4: 686, 5: 112, 6: 8})
INFO:decrypt.scrape_parse.guardian_load:Clue list length matches Decrypting paper expected length
INFO:decrypt.scrape_parse.guardian_load:Got splits of lenghts [85428, 28476, 28476]
INFO:decrypt.scrape_parse.guardian_load:First three clues of train set:
	[GuardianClue(clue='Suffering to grasp edge of plant', lengths=[8], soln='agrimony', soln_with_spaces='agrimony', idx=85002, dataset=PosixPath('/Users/jsrozner/MOUNT/scdt/cryptic_nlp/decrypt_root/data/puzzles'), across_or_down='across', pos=(7, 4), unique_clue_id='cryptic_25415_11-across', type='cryptic', number=25415, id='crosswords/cryptic/25415', creator='Chifonie', orig_lengths='8', lengths_punctuation=set()), GuardianClue(clue='Honour Ben and Noel with new order', lengths=[7], soln='ennoble', soln_with_spaces='ennoble', idx=3432, dataset=PosixPath('/Users/jsrozner/MOUNT/scdt/cryptic_nlp/decrypt_root/data/puzzles'), across_or_down=

In [11]:
# note that this should be run directly on the top model output from curricular training
# otherwise (eg. if 100 beams were used), the top 5 output
# sequences would be expected to change
# remember not to append .json

# eval on the full output (5 beams / 5 sequences)
# this is column 1 of table 3
vt.load_and_run_t5('outputs/model_output.preds',
                   # pre_truncate=5,        # should not be needed since we have only 5 outputs
                   do_length_filter=True)

# run on the anagram subset
# this is column 2 of table 3
vt.load_and_run_t5('outputs/model_output.preds',
                   filter_fcn=vt.make_set_filter(labels, 'anag_direct'),
                   # pre_truncate=5,
                   do_length_filter=True)

# we are looking at agg_top_match (which is after filter)

28476
[('agg_filter_len_pre_truncate', 4.572166034555415),
 ('agg_filtered_few', 1.0),
 ('agg_generate_few', 1.0),
 ('agg_generate_none', 0.0),
 ('agg_in_filtered', 0.3338952100014047),
 ('agg_in_sample', 0.3338952100014047),
 ('agg_sample_len', 5.0),
 ('agg_sample_len_correct', 0.919244275881444),
 ('agg_sample_len_pre_truncate', 5.0),
 ('agg_sample_wordct_correct', 0.9760640539401602),
 ('agg_top_10_after_filter', 0.3338952100014047),
 ('agg_top_match', 0.20213513133867117),
 ('agg_top_match_len_correct', 0.9922039612305099),
 ('agg_top_match_none', 0.007796038769490097),
 ('agg_top_match_wordct_correct', 0.9868310155920775),
 ('agg_top_sample_result_len_correct', 0.9403708386009271),
 ('agg_top_sample_result_wordct_correct', 0.9830734653743504),
 ('filter_len_pre_truncate', 130197),
 ('filtered_few', 28476),
 ('generate_few', 28476),
 ('generate_none', 0),
 ('in_filtered', 9508),
 ('in_sample', 9508),
 ('sample_len', 142380),
 ('sample_len_correct', 130882),
 ('sample_len_pre_trunca