In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from metal.mmtl.trainer import MultitaskTrainer
from metal.mmtl.glue.glue_tasks import create_glue_tasks_payloads
from metal.mmtl.metal_model import MetalModel
from metal.mmtl.slicing.slice_model import SliceModel

In [3]:
# Random seed handed to the MetalModel constructors in the cells below.
SEED = 1

### Initialize normal payloads (to load the base model) 

In [4]:
# Keyword arguments forwarded to create_glue_tasks_payloads; "dl_kwargs"
# carries the data-loader settings (batch size), the rest configure the
# BERT-based task.
task_kwargs = dict(
    dl_kwargs=dict(batch_size=8),
    freeze_bert=False,
    bert_model="bert-base-uncased",
    max_len=128,
    attention=False,
    dropout=0.1,
)

# GLUE tasks to build payloads for (only RTE in this notebook).
task_names = ["RTE"]

In [5]:
%%time

# Build the RTE task plus its train/valid/test payloads from the base kwargs.
# NOTE(review): the captured output reports "Using random seed: <n>" with a
# different value on each call, so payload creation does not appear to use
# SEED — consider passing a fixed seed here if the API supports it (confirm).
tasks, payloads = create_glue_tasks_payloads(task_names, **task_kwargs)

Using random seed: 275349
Loading RTE Dataset


HBox(children=(IntProgress(value=0, max=2490), HTML(value='')))




HBox(children=(IntProgress(value=0, max=277), HTML(value='')))




HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))


CPU times: user 14.2 s, sys: 1.1 s, total: 15.3 s
Wall time: 16.2 s


In [6]:
# Display the created task list and the train/valid/test payloads.
tasks, payloads

([ClassificationTask(name=RTE, loss_multiplier=1.00)],
 [Payload(RTE_train: labels_to_tasks=[{'RTE_gold': 'RTE'}], split=train),
  Payload(RTE_valid: labels_to_tasks=[{'RTE_gold': 'RTE'}], split=valid),
  Payload(RTE_test: labels_to_tasks=[{'RTE_gold': 'RTE'}], split=test)])

### Load existing baseline model

In [7]:
# Build a fresh MetalModel over the RTE task list.
# NOTE(review): `model` is not referenced again in the visible cells (the
# `baseline` model loaded below is the one that gets scored) — confirm whether
# this cell is still needed.
model = MetalModel(tasks, seed=SEED, verbose=False)

In [8]:
import os
import torch

# MetalModel is already imported in the imports cell at the top of the
# notebook, so it is not re-imported here.

# Dedicated seed for the baseline model. Kept separate from the notebook-wide
# SEED so that running this cell does not silently change the seed used by
# every other cell that reads SEED.
BASELINE_SEED = 321
baseline = MetalModel(tasks, seed=BASELINE_SEED, verbose=False)

# Directory holding the trained baseline checkpoint. Defaults to the original
# run directory; set METAL_BASELINE_MODEL_DIR to point at a different copy
# without editing the notebook.
model_dir = os.environ.get(
    "METAL_BASELINE_MODEL_DIR",
    '/dfs/scratch0/vschen/metal-mmtl/logs/2019_04_30/17_57_08',
)
model_path = os.path.join(model_dir, 'best_model.pth')

# Fail early with an actionable message instead of an error from deep inside
# the weight-loading code when the checkpoint is not reachable.
if not os.path.isfile(model_path):
    raise FileNotFoundError(
        f"Baseline checkpoint not found at {model_path!r}; "
        "set METAL_BASELINE_MODEL_DIR to the directory containing best_model.pth."
    )

# Overwrite the freshly initialised weights with the trained baseline.
baseline.load_weights(model_path)

#### Validate baseline performance

In [9]:
# Score the loaded baseline on payloads[0] (the RTE_train payload).
baseline.score(payloads[0])

{'RTE/RTE_train/RTE_gold/accuracy': 0.9943775100401606}

In [10]:
# Score the loaded baseline on payloads[1] (the RTE_valid payload).
baseline.score(payloads[1])

{'RTE/RTE_valid/RTE_gold/accuracy': 0.740072202166065}

### Visualize some slices

In [11]:
import random
from collections import defaultdict
import spacy
# Load spaCy's small English pipeline.
# NOTE(review): `nlp` is not used directly in the visible cells — presumably
# kept for ad-hoc inspection of examples; confirm before removing.
nlp = spacy.load("en_core_web_sm")

In [12]:
from pprint import pprint

In [13]:
# Set to True to walk through slice examples via pdb
VISUALIZE_EXAMPLES = False

# Import the module itself instead of `import *`: slice functions are then
# looked up explicitly on the module, the notebook namespace is not flooded
# with every name the module exports, and a notebook variable that happens to
# share a slice's name can no longer be picked up by mistake.
import metal.mmtl.glue.glue_slices as glue_slices

# Names of the slicing functions to inspect for RTE. This list is reused by
# later cells to build the slice payloads.
rte_slices = [
    "has_temporal_preposition",
    "has_possessive_preposition",
    "is_comparative",
    "is_quantification",
    "is_quantification_hypothesis",
    "has_multiple_articles",
    "has_wh_words",
    "short_hypothesis",
    "long_hypothesis",
    "short_premise",
    "long_premise",
    "has_coordinating_conjunction_hypothesis",
    "has_but",
    "common_negation"
]

# Resolve every slice name up front so a misspelled name fails immediately
# with an AttributeError rather than part-way through the loop.
slice_fns = {name: getattr(glue_slices, name) for name in rte_slices}

# payloads[2] is the RTE_test payload; slices are counted over its dataset.
ds = payloads[2].data_loader.dataset
rand_idx = list(range(len(ds)))
# Shuffle with a dedicated, seeded generator so the order in which examples
# are shown during the pdb walk-through is reproducible across runs. The
# counts themselves do not depend on the order.
random.Random(SEED).shuffle(rand_idx)

# slice name -> number of examples for which the slice function is truthy.
counter = defaultdict(int)
for idx in rand_idx:
    # The sentence depends only on the example, so fetch it once per example
    # rather than once per (example, slice) pair.
    sentence = ds.sentences[idx]
    for slice_name, slice_fn in slice_fns.items():
        in_slice = slice_fn(ds, idx)
        if not in_slice:
            continue
        counter[slice_name] += 1

        if VISUALIZE_EXAMPLES:
            # Interactive inspection only: pauses on every matching example.
            print(sentence, "->", in_slice)
            import pdb; pdb.set_trace()

pprint(dict(counter))

{'common_negation': 421,
 'has_but': 243,
 'has_coordinating_conjunction_hypothesis': 1220,
 'has_multiple_articles': 823,
 'has_possessive_preposition': 715,
 'has_temporal_preposition': 414,
 'has_wh_words': 745,
 'is_comparative': 300,
 'is_quantification': 868,
 'is_quantification_hypothesis': 145,
 'long_hypothesis': 186,
 'long_premise': 238,
 'short_hypothesis': 223,
 'short_premise': 233}


### Initialize slice payloads

In [14]:
# Create tasks and payloads with slice labelsets attached to RTE.
#
# The slice configuration goes into a *new* kwargs dict rather than being
# written into `task_kwargs` in place. Mutating the shared dict would make the
# earlier base-payload cell produce slice payloads if it were re-run after
# this one (hidden notebook state).
# NOTE(review): "attention" is switched from False to None for the slice
# payloads; the reason is not visible in this notebook — confirm.
slice_task_kwargs = {
    **task_kwargs,
    "slice_dict": {"RTE": rte_slices},
    "attention": None,
}

tasks_slice, payloads_slice = create_glue_tasks_payloads(
    task_names, **slice_task_kwargs
)

Using random seed: 55356
Loading RTE Dataset


HBox(children=(IntProgress(value=0, max=2490), HTML(value='')))




HBox(children=(IntProgress(value=0, max=277), HTML(value='')))




HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))


Added label_set with 2490/2490 labels for task RTE_slice:has_temporal_preposition:ind to payload RTE_train.
Added label_set with 363/2490 labels for task RTE_slice:has_temporal_preposition:pred to payload RTE_train.
Added label_set with 2490/2490 labels for task RTE_slice:has_possessive_preposition:ind to payload RTE_train.
Added label_set with 619/2490 labels for task RTE_slice:has_possessive_preposition:pred to payload RTE_train.
Added label_set with 2490/2490 labels for task RTE_slice:is_comparative:ind to payload RTE_train.
Added label_set with 263/2490 labels for task RTE_slice:is_comparative:pred to payload RTE_train.
Added label_set with 2490/2490 labels for task RTE_slice:is_quantification:ind to payload RTE_train.
Added label_set with 749/2490 labels for task RTE_slice:is_quantification:pred to payload RTE_train.
Added label_set with 2490/2490 labels for task RTE_slice:is_quantification_hypothesis:ind to payload RTE_train.
Added label_set with 123/2490 labels for task RTE_sli

### Evaluate baseline slices

In [15]:
import copy

# Work on a deep copy of the slice validation payload (payloads_slice[1]) so
# the retargeting below leaves the original payload untouched.
eval_payload = copy.deepcopy(payloads_slice[1])

# NOTE: we need to retarget slices to the original RTE head.
# Each slice's ":pred" labelset is pointed at the 'RTE' task, and its ":ind"
# labelset is pointed at no task (None).
for slice_name in rte_slices:
    label_name = "RTE_slice:" + slice_name
    for suffix, target_task in ((":pred", 'RTE'), (":ind", None)):
        eval_payload._retarget_labelset(label_name + suffix, target_task)

labelset 'RTE_slice:has_temporal_preposition:pred' -> task 'RTE' (originally, RTE_slice:has_temporal_preposition:pred).
labelset 'RTE_slice:has_temporal_preposition:ind' -> task 'None' (originally, RTE_slice:has_temporal_preposition:ind).
labelset 'RTE_slice:has_possessive_preposition:pred' -> task 'RTE' (originally, RTE_slice:has_possessive_preposition:pred).
labelset 'RTE_slice:has_possessive_preposition:ind' -> task 'None' (originally, RTE_slice:has_possessive_preposition:ind).
labelset 'RTE_slice:is_comparative:pred' -> task 'RTE' (originally, RTE_slice:is_comparative:pred).
labelset 'RTE_slice:is_comparative:ind' -> task 'None' (originally, RTE_slice:is_comparative:ind).
labelset 'RTE_slice:is_quantification:pred' -> task 'RTE' (originally, RTE_slice:is_quantification:pred).
labelset 'RTE_slice:is_quantification:ind' -> task 'None' (originally, RTE_slice:is_quantification:ind).
labelset 'RTE_slice:is_quantification_hypothesis:pred' -> task 'RTE' (originally, RTE_slice:is_quantific

In [16]:
# Display the retargeted payload to check the labelset -> task mapping.
eval_payload

Payload(RTE_valid: labels_to_tasks=[{'RTE_gold': 'RTE', 'RTE_slice:has_temporal_preposition:ind': None, 'RTE_slice:has_temporal_preposition:pred': 'RTE', 'RTE_slice:has_possessive_preposition:ind': None, 'RTE_slice:has_possessive_preposition:pred': 'RTE', 'RTE_slice:is_comparative:ind': None, 'RTE_slice:is_comparative:pred': 'RTE', 'RTE_slice:is_quantification:ind': None, 'RTE_slice:is_quantification:pred': 'RTE', 'RTE_slice:is_quantification_hypothesis:ind': None, 'RTE_slice:is_quantification_hypothesis:pred': 'RTE', 'RTE_slice:has_multiple_articles:ind': None, 'RTE_slice:has_multiple_articles:pred': 'RTE', 'RTE_slice:has_wh_words:ind': None, 'RTE_slice:has_wh_words:pred': 'RTE', 'RTE_slice:short_hypothesis:ind': None, 'RTE_slice:short_hypothesis:pred': 'RTE', 'RTE_slice:long_hypothesis:ind': None, 'RTE_slice:long_hypothesis:pred': 'RTE', 'RTE_slice:short_premise:ind': None, 'RTE_slice:short_premise:pred': 'RTE', 'RTE_slice:long_premise:ind': None, 'RTE_slice:long_premise:pred': 'RTE'

In [17]:
# Score the baseline on the retargeted validation payload: the result holds
# the RTE_gold accuracy plus one accuracy entry per slice ":pred" labelset.
baseline.score(eval_payload)

{'RTE/RTE_valid/RTE_gold/accuracy': 0.740072202166065,
 'RTE/RTE_valid/RTE_slice:has_temporal_preposition:pred/accuracy': 0.8333333333333334,
 'RTE/RTE_valid/RTE_slice:has_possessive_preposition:pred/accuracy': 0.6363636363636364,
 'RTE/RTE_valid/RTE_slice:is_comparative:pred/accuracy': 0.7741935483870968,
 'RTE/RTE_valid/RTE_slice:is_quantification:pred/accuracy': 0.6931818181818182,
 'RTE/RTE_valid/RTE_slice:is_quantification_hypothesis:pred/accuracy': 0.7692307692307693,
 'RTE/RTE_valid/RTE_slice:has_multiple_articles:pred/accuracy': 0.8026315789473685,
 'RTE/RTE_valid/RTE_slice:has_wh_words:pred/accuracy': 0.7222222222222222,
 'RTE/RTE_valid/RTE_slice:short_hypothesis:pred/accuracy': 0.4,
 'RTE/RTE_valid/RTE_slice:long_hypothesis:pred/accuracy': 0.8181818181818182,
 'RTE/RTE_valid/RTE_slice:short_premise:pred/accuracy': 0.7272727272727273,
 'RTE/RTE_valid/RTE_slice:long_premise:pred/accuracy': 0.7727272727272727,
 'RTE/RTE_valid/RTE_slice:has_coordinating_conjunction_hypothesis:pre

## Next
1) Choose the slices above that seem to have accuracy gaps compared to the `_gold` labelset

2) Attempt to overfit to those slices: 
```
python launch.py --seed 1 --tasks RTE --slice_dict '{"RTE": ["short_hypothesis"]}' --model_type hard_param --model_weights /dfs/scratch0/vschen/metal-mmtl/logs/2019_04_30/17_57_08/best_model.pth --lr 5e-05 --lr_scheduler linear --optimizer adam --l2 1e-3 --n_epochs 50 --min_lr 1e-07
```