# Training Emely

This notebook is for training Emely with different configurations.
Use the blender_opts dictionary for the standard options. 

### Configuration

#### Base Config
We'll call the base configuration "Blender base config": the Blender 90M model fine-tuned on the internal and external tasks.

### Required config for a run

- task
- multitask_weights
- model_file

### Optional config
- mutators
- lr


### Different mutators for different tasks?
`--task internal:mutators=word_shuffle,internal:mutators=last_turn`


### Evaluation
All models are evaluated on the internal and external tasks

In [1]:
from parlai.scripts.train_model import TrainModel
from pathlib import Path
from copy import deepcopy
import shutil

In [2]:
# Baseline ParlAI options shared by every run in this notebook.
# run_training() deep-copies this dict and overlays the per-run settings.
blender_opts = {
    # Starting weights and dictionary (ParlAI model zoo paths)
    'init_model': 'zoo:blender/blender_90M/model',
    'dict_file': 'zoo:blender/blender_90M/model.dict',

    # Model architecture
    'model': 'transformer/generator',
    'embedding_size': 512,
    'n_layers': 8,
    'ffn_size': 2048,
    'dropout': 0.1,
    'n_heads': 16,
    'learn_positional_embeddings': True,
    'n_positions': 512,
    'variant': 'xlm',
    'activation': 'gelu',

    # Precision, truncation and tokenization
    'fp16': True,
    'text_truncate': 512,
    'label_truncate': 128,
    'dict_tokenizer': 'bpe',

    # Optimization
    'optimizer': 'adamax',
    'lr_scheduler': 'reduceonplateau',
    'betas': '0.9,0.999',
    'update_freq': 1,
    'attention_dropout': 0.0,
    'relu_dropout': 0.0,
    'dict_lower': True,
    'lr': 1e-06,
    'gradient_clip': 0.1,

    # Validation and checkpointing (ParlAI short option names)
    'veps': 0.25,            # validation_every_n_epochs
    'skip_generation': False,
    'vp': 15,                # validation_patience
    'stim': 60,              # presumably save_every_n_secs -- TODO confirm
    'vme': 20000,            # validation_max_exs
    'bs': 16,                # presumably batchsize -- TODO confirm
    'vmt': 'ppl',            # validation_metric
    'vmm': 'min',            # validation_metric_mode
    'save_after_valid': True,

    # Logging
    'wblog': True,           # wandb_log
    'wandb_project': 'parlaiemely',
    'tensorboard_log': True,

    # Evaluation metrics and tasks
    'metrics': 'ppl,bleu-4,rouge-L',
    'evaltask': 'internal,external',

    # Generation settings
    'inference': 'beam',
    'beam_size': 10,
    'beam_min_length': 10,
    'beam_block_ngram': 3,
}

#              'dict_file': 'zoo:blender/blender_90M/model.dict'

In [3]:
def run_training(tasks, weights, mutators=None):
    """Train one model on top of the shared ``blender_opts`` configuration.

    A run name is built from the arguments and used both for the model
    directory (``models/model-runs/<name>/model``, resolved two levels above
    the current working directory) and as the W&B run name.

    Args:
        tasks: ParlAI task string, e.g. ``'internal,external'``.
        weights: Sampling weights string passed as ``multitask_weights``,
            e.g. ``'6,3'``.
        mutators: Optional mutator string. When ``None`` the ``mutators``
            option is not passed at all.
    """
    # The run name only mentions mutators when some were requested.
    name = f'blender-{tasks}-{weights}'
    if mutators is not None:
        name = f'{name}-{mutators}'

    model_file = Path.cwd().parents[1].joinpath(f'models/model-runs/{name}/model')

    run_opts = {
        'task': tasks,
        'multitask_weights': weights,
        'model_file': model_file.as_posix(),
        # Without this the W&B run is not labelled with the configuration.
        'wandb_name': name,
    }
    if mutators is not None:
        run_opts['mutators'] = mutators

    # Copy the standard opts so the shared dict is never modified by a run.
    opts = deepcopy(blender_opts)
    opts.update(run_opts)

    TrainModel.main(**opts)

## 0. Test: running two training runs back to back

In [None]:
# Smoke-test configuration.
tasks = 'internal,external'
weights = '5,1'
mutators = 'word_shuffle'

# run_training() names its output directory after its arguments, so the same
# name is built here; a fixed name such as 'TEST' would point the cleanup at a
# directory the run never writes to.
name = f'blender-{tasks}-{weights}-{mutators}'
mf = Path.cwd().parents[1].joinpath(f'models/model-runs/{name}/model')

# Remove files left by a previous test run so this one starts from scratch.
if mf.parent.exists():
    shutil.rmtree(mf.parent)

run_training(tasks=tasks, weights=weights, mutators=mutators)

## 1. Blender base config

### Datasets with sampling weights:
- internal - 6
- external - 3

### Mutators
None

In [None]:
# Base config: internal and external with weights 6 and 3, no mutators.
tasks = 'internal,external'
weights = '6,3'
run_training(tasks, weights)

18:25:42 | building dictionary first...
18:25:42 | No model with opt yet at: /home/alex/ParlaiEmely/models/model-runs/blender-internal,external-6,3/model(.opt)
18:25:42 | [33myour model is being loaded with opts that do not exist in the model you are initializing the weights with: allow_missing_init_opts: False,download_path: None,loglevel: info,dynamic_batching: None,verbose: False,is_debug: False,datapath: /home/alex/ParlaiEmely/ParlAI/data,eval_dynamic_batching: None,num_workers: 0,max_train_steps: -1,log_every_n_steps: 50,validation_every_n_steps: -1,load_from_checkpoint: True,tensorboard_logdir: None,wandb_log: True,wandb_name: None,wandb_project: parlaiemely,wandb_entity: None,mutators: None,n_encoder_layers: -1,n_decoder_layers: -1,model_parallel: False,beam_block_full_context: True,beam_delay: 30,beam_block_list_filename: None,temperature: 1.0,interactive_mode: False,history_reversed: False,history_add_global_end_token: None,special_tok_lst: None,bpe_vocab: None,bpe_merge: Non

18:25:46 |     tensorboard_logdir: None
18:25:46 |     text_truncate: 512
18:25:46 |     topk: 10
18:25:46 |     topp: 0.9
18:25:46 |     truncate: -1
18:25:46 |     update_freq: 1
18:25:46 |     use_reply: label
18:25:46 |     validation_cutoff: 1.0
18:25:46 |     validation_every_n_epochs: 0.25
18:25:46 |     validation_every_n_secs: -1
18:25:46 |     validation_every_n_steps: -1
18:25:46 |     validation_max_exs: 20000
18:25:46 |     validation_metric: ppl
18:25:46 |     validation_metric_mode: min
18:25:46 |     validation_patience: 15
18:25:46 |     validation_share_agent: False
18:25:46 |     variant: xlm
18:25:46 |     verbose: False
18:25:46 |     wandb_entity: None
18:25:46 |     wandb_log: True
18:25:46 |     wandb_name: None
18:25:46 |     wandb_project: parlaiemely
18:25:46 |     warmup_rate: 0.0001
18:25:46 |     warmup_updates: -1
18:25:46 |     weight_decay: None
18:25:46 | Current ParlAI commit: e3c1edbef397de2c084a521fb0bb81489a432c74
18:25:46 | creating task(s): inter

[34m[1mwandb[0m: W&B API key is configured (use `wandb login --relogin` to force relogin)


18:26:09 | training...
18:26:11 | time:25s total_exs:256 total_steps:16 epochs:0.26
             clen  clip  ctpb  ctps  ctrunc  ctrunclen  exps  exs  fp16_loss_scalar  gnorm  gpu_mem  llen  loss    lr  ltpb  \
   all      55.46     1 884.4  7331       0          0 132.6  256             32768  12.56    .1370 15.49 2.844 1e-06 245.9   
   external 56.23                         0          0         97                                   15.99 3.041               
   internal 54.69                         0          0        159                                   14.99 2.647               
             ltps  ltrunc  ltrunclen   ppl  token_acc  token_em  total_train_updates  tpb  tps   ups  
   all       2038       0          0 17.51      .4114         0                   16 1130 9370 8.348  
   external             0          0 20.92      .3830         0                                       
   internal             0          0 14.11      .4398         0

18:26:11 | creating task(s): inter

18:27:16 | running eval: valid
18:27:31 | eval completed in 15.48s
18:27:31 | [1mvalid:
             accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gpu_mem  llen  loss    lr  ltpb  ltps  \
   all              0 .007787 53.69 635.2 512.6       0          0 11.39  171 .1783   .06922 14.68 2.767 1e-06   190 153.3   
   external         0 .009779 62.36                   0          0         44 .1576          15.89 3.042                     
   internal         0 .005794 45.02                   0          0        127 .1990          13.46 2.492                     
             ltrunc  ltrunclen   ppl  rouge_L  token_acc  token_em  total_train_updates   tpb   tps  
   all            0          0 16.52    .1777      .4361         0                   64 825.2 665.9  
   external       0          0 20.96    .1570      .4049         0                                   
   internal       0          0 12.09    .1984      .4673         0
[0m
18:27:31 | saving model check

18:28:37 | running eval: valid
18:28:51 | eval completed in 14.89s
18:28:51 | [1mvalid:
             accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gpu_mem  llen  loss    lr  ltpb  ltps  \
   all              0  .01258 53.69 635.2 528.4       0          0 11.74  171 .1957   .06921 14.68 2.724 1e-06   190   158   
   external         0  .01791 62.36                   0          0         44 .1775          15.89 3.003                     
   internal         0 .007256 45.02                   0          0        127 .2139          13.46 2.444                     
             ltrunc  ltrunclen   ppl  rouge_L  token_acc  token_em  total_train_updates   tpb   tps  
   all            0          0 15.83    .1923      .4397         0                  128 825.2 686.4  
   external       0          0 20.15    .1774      .4034         0                                   
   internal       0          0 11.52    .2072      .4760         0
[0m
18:28:52 | saving model check



18:29:53 | running eval: valid
18:30:07 | eval completed in 13.98s
18:30:07 | [1mvalid:
             accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gpu_mem  llen  loss    lr  ltpb  ltps  \
   all              0 .009824 53.69 635.2 562.2       0          0 12.49  171 .2000   .06906 14.68 2.697 1e-06   190 168.1   
   external         0 .009793 62.36                   0          0         44 .1722          15.89 2.979                     
   internal         0 .009855 45.02                   0          0        127 .2278          13.46 2.414                     
             ltrunc  ltrunclen   ppl  rouge_L  token_acc  token_em  total_train_updates   tpb   tps  
   all            0          0 15.43    .1988      .4457         0                  192 825.2 730.3  
   external       0          0 19.67    .1728      .4077         0                                   
   internal       0          0 11.18    .2248      .4836         0
[0m
18:30:07 | saving model check

18:31:06 | running eval: valid
18:31:20 | eval completed in 13.51s
18:31:20 | [1mvalid:
             accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gpu_mem  llen  loss    lr  ltpb  ltps  \
   all              0  .01164 53.69 635.2 583.5       0          0 12.96  171 .1960   .06912 14.68 2.677 1e-06   190 174.5   
   external         0 .009789 62.36                   0          0         44 .1634          15.89 2.961                     
   internal         0   .0135 45.02                   0          0        127 .2286          13.46 2.393                     
             ltrunc  ltrunclen   ppl  rouge_L  token_acc  token_em  total_train_updates   tpb  tps  
   all            0          0 15.13    .1975      .4500         0                  256 825.2  758  
   external       0          0 19.31    .1676      .4134         0                                  
   internal       0          0 10.94    .2275      .4865         0
[0m
18:31:20 | saving model checkpoi

18:32:19 | running eval: valid
18:32:32 | eval completed in 13.24s
18:32:32 | [1mvalid:
             accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gpu_mem  llen  loss    lr  ltpb  ltps  \
   all              0 .009105 53.69 635.2 598.3       0          0 13.29  171 .2062   .06923 14.68 2.661 1e-06   190 178.9   
   external         0 .004705 62.36                   0          0         44 .1690          15.89 2.946                     
   internal         0  .01351 45.02                   0          0        127 .2434          13.46 2.376                     
             ltrunc  ltrunclen   ppl  rouge_L  token_acc  token_em  total_train_updates   tpb   tps  
   all            0          0  14.9    .2045      .4543         0                  320 825.2 777.2  
   external       0          0 19.04    .1710      .4192         0                                   
   internal       0          0 10.76    .2380      .4895         0
[0m
18:32:32 | saving model check

18:33:30 | running eval: valid
18:33:44 | eval completed in 13.05s
18:33:44 | [1mvalid:
             accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gpu_mem  llen  loss    lr  ltpb  ltps  \
   all              0 .009104 53.69 635.2 602.4       0          0 13.38  171 .1993   .06939 14.68 2.649 1e-06   190 180.2   
   external         0 .004705 62.36                   0          0         44 .1598          15.89 2.935                     
   internal         0   .0135 45.02                   0          0        127 .2389          13.46 2.363                     
             ltrunc  ltrunclen   ppl  rouge_L  token_acc  token_em  total_train_updates   tpb   tps  
   all            0          0 14.72    .1947      .4547         0                  384 825.2 782.6  
   external       0          0 18.82    .1569      .4206         0                                   
   internal       0          0 10.62    .2324      .4889         0
[0m
18:33:44 | saving model check

18:34:43 | running eval: valid
18:34:56 | eval completed in 13.16s
18:34:56 | [1mvalid:
             accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gpu_mem  llen  loss    lr  ltpb  ltps  \
   all              0 .009932 53.69 635.2 602.5       0          0 13.38  171 .2054   .06916 14.68 2.638 1e-06   190 180.2   
   external         0 .004709 62.36                   0          0         44 .1684          15.89 2.924                     
   internal         0  .01516 45.02                   0          0        127 .2425          13.46 2.352                     
             ltrunc  ltrunclen   ppl  rouge_L  token_acc  token_em  total_train_updates   tpb   tps  
   all            0          0 14.57    .2038      .4558         0                  448 825.2 782.7  
   external       0          0 18.62    .1656      .4220         0                                   
   internal       0          0 10.51    .2421      .4895         0
[0m
18:34:56 | saving model check

18:35:54 | running eval: valid
18:36:07 | eval completed in 12.87s
18:36:07 | [1mvalid:
             accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gpu_mem  llen  loss    lr  ltpb  ltps  \
   all              0 .008659 53.69 635.2 615.2       0          0 13.67  171 .2050   .06917 14.68 2.629 1e-06   190   184   
   external         0 .004709 62.36                   0          0         44 .1716          15.89 2.914                     
   internal         0  .01261 45.02                   0          0        127 .2383          13.46 2.343                     
             ltrunc  ltrunclen   ppl  rouge_L  token_acc  token_em  total_train_updates   tpb   tps  
   all            0          0 14.42    .2026      .4597         0                  512 825.2 799.2  
   external       0          0 18.43    .1685      .4235         0                                   
   internal       0          0 10.41    .2368      .4959         0
[0m
18:36:07 | saving model check

18:37:03 | running eval: valid
18:37:16 | eval completed in 12.65s
18:37:16 | [1mvalid:
             accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gpu_mem  llen  loss    lr  ltpb  ltps  \
   all              0 .009612 53.69 635.2 620.6       0          0 13.78  171 .2094   .06933 14.68  2.62 1e-06   190 185.6   
   external         0 .004709 62.36                   0          0         44 .1739          15.89 2.904                     
   internal         0  .01451 45.02                   0          0        127 .2450          13.46 2.336                     
             ltrunc  ltrunclen   ppl  rouge_L  token_acc  token_em  total_train_updates   tpb   tps  
   all            0          0 14.29    .2043      .4600         0                  576 825.2 806.2  
   external       0          0 18.25    .1670      .4235         0                                   
   internal       0          0 10.34    .2416      .4965         0
[0m
18:37:16 | saving model check

18:38:13 | running eval: valid
18:38:26 | eval completed in 13.31s
18:38:26 | [1mvalid:
             accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gpu_mem  llen  loss    lr  ltpb  ltps  \
   all              0 .009738 53.69 635.2   573       0          0 12.73  171 .2155   .06905 14.68 2.613 1e-06   190 171.4   
   external         0 .004959 62.36                   0          0         44 .1807          15.89 2.898                     
   internal         0  .01452 45.02                   0          0        127 .2503          13.46 2.329                     
             ltrunc  ltrunclen   ppl  rouge_L  token_acc  token_em  total_train_updates   tpb   tps  
   all            0          0  14.2    .2102      .4611         0                  640 825.2 744.3  
   external       0          0 18.14    .1743      .4263         0                                   
   internal       0          0 10.27    .2462      .4959         0
[0m
18:38:26 | saving model check

18:39:22 | running eval: valid
18:39:35 | eval completed in 12.25s
18:39:35 | [1mvalid:
             accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gpu_mem  llen  loss    lr  ltpb  ltps  \
   all              0  .01373 53.69 635.2 641.3       0          0 14.25  171 .2195   .06902 14.68 2.606 1e-06   190 191.8   
   external         0  .01308 62.36                   0          0         44 .1899          15.89 2.891                     
   internal         0  .01438 45.02                   0          0        127 .2491          13.46 2.321                     
             ltrunc  ltrunclen   ppl  rouge_L  token_acc  token_em  total_train_updates   tpb   tps  
   all            0          0 14.11    .2163      .4620         0                  704 825.2 833.1  
   external       0          0 18.02    .1850      .4263         0                                   
   internal       0          0 10.19    .2477      .4977         0
[0m
18:39:35 | saving model check

## 2. Blender base config + bst

### Datasets with sampling weights:
- internal: 6
- external: 3
- blended_skill_talk: 1

### Mutators
None

In [None]:
# Base config plus blended_skill_talk, weights 6/3/1, no mutators.
tasks = 'internal,external,blended_skill_talk'
weights = '6,3,1'
run_training(tasks, weights)

## 3. Blender base + otter

### Datasets with sampling weights:
- internal: 6
- external: 3
- otter: 1

### Mutators
None

In [None]:
# Base config plus otter, weights 6/3/1, no mutators.
tasks = 'internal,external,otter'
weights = '6,3,1'
run_training(tasks, weights)

## 4. Blender base + otter and bst

### Datasets with sampling weights:
- internal: 6
- external: 3
- otter: 1
- bst: 1

### Mutators
None

In [None]:
# One explicit weight per task, in task order:
# internal 6, external 3, otter 1, blended_skill_talk 1.
tasks = 'internal,external,otter,blended_skill_talk'
weights = '6,3,1,1'

run_training(tasks=tasks, weights=weights)

# Mutator runs below

Mutators change the dataset

## 5. Blender base +  word_shuffle

### Datasets with sampling weights:
- internal: 6
- external: 3


### Mutators
- word_shuffle


In [None]:
# Base config with the word_shuffle mutator.
tasks = 'internal,external'
weights = '6,3'
mutators = 'word_shuffle'
run_training(tasks, weights, mutators=mutators)

## 6. Blender base +  last_turn

### Datasets with sampling weights:
- internal: 6
- external: 3


### Mutators
- last_turn

In [None]:
# Base config with the last_turn mutator.
tasks = 'internal,external'
weights = '6,3'
mutators = 'last_turn'
run_training(tasks, weights, mutators=mutators)

## 7. Blender base + flatten

In [None]:
# Base config with the flatten mutator.
tasks = 'internal,external'
weights = '6,3'
mutators = 'flatten'
run_training(tasks, weights, mutators=mutators)

## 8. Blender base + mixed mutators

### Datasets with sampling weights:
- internal: 6
- external: 3


### Mutators
- None
- word_shuffle
- flatten
- last_turn

In [None]:
# Base tasks with three mutators passed together in one mutators string.
# NOTE(review): the markdown above lists weights 6/3 while this run uses '2,1'
# (same ratio, different run name) -- confirm which is intended.
tasks = 'internal,external'
weights = '2,1'
mutators = 'word_shuffle,flatten,last_turn'
run_training(tasks, weights, mutators=mutators)