# MNLI Diagnostic Example

## Setup

#### Install dependencies

In [None]:
%%capture
!git clone https://github.com/nyu-mll/jiant.git
%cd jiant && git checkout tags/2.0.0

In [None]:
%%capture
# Install jiant's requirements from its no-torch requirements file.
# This Colab notebook already has its CUDA-runtime compatible versions of torch and torchvision installed
!pip install -r jiant/requirements-no-torch.txt
# Install a pinned pyarrow (0.16.0) for the `nlp` package
!pip install pyarrow==0.16.0

#### Download data

In [None]:
%%capture
# Download/preprocess MNLI and Diagnostic data
# Runs jiant's download script for the three tasks used in this notebook and
# writes its output under /content/tasks/ (later cells expect the task configs
# at ./tasks/configs/<task>_config.json).
!PYTHONPATH=/content/jiant python jiant/jiant/scripts/download_data/runscript.py \
    download \
    --tasks mnli mnli_mismatched glue_diagnostics \
    --output_path=/content/tasks/

## `jiant` Pipeline

In [None]:
import sys

# Put the cloned repository at the front of the module search path so the
# `import jiant...` statements that follow resolve to /content/jiant.
sys.path[:0] = ["/content/jiant"]

In [None]:
import jiant.proj.main.tokenize_and_cache as tokenize_and_cache
import jiant.proj.main.export_model as export_model
import jiant.proj.main.scripts.configurator as configurator
import jiant.proj.main.runscript as main_runscript
import jiant.shared.caching as caching
import jiant.utils.python.io as py_io
import jiant.utils.display as display
import os
import torch

#### Download model

In [None]:
# Export the "roberta-base" model into ./models/roberta-base; later cells read
# the model weights, model config and tokenizer from under this directory.
roberta_export_kwargs = {
    "model_type": "roberta-base",
    "output_base_path": "./models/roberta-base",
}
export_model.lookup_and_export_model(**roberta_export_kwargs)

#### Tokenize and cache


In [None]:
# Tokenize and cache each task
# Each task is cached only for the phases this notebook uses: MNLI for
# training and validation, MNLI-mismatched for validation only, and the GLUE
# diagnostic set for test only. Driving the calls from one table keeps the
# shared arguments (model type, tokenizer path, path layout) in a single place.
task_phases = {
    "mnli": ["train", "val"],
    "mnli_mismatched": ["val"],
    "glue_diagnostics": ["test"],
}
for task_name, phases in task_phases.items():
    tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(
        task_config_path=f"./tasks/configs/{task_name}_config.json",
        model_type="roberta-base",
        model_tokenizer_path="./models/roberta-base/tokenizer",
        output_dir=f"./cache/{task_name}",
        phases=phases,
    ))

In [None]:
# Sanity check: show the token ids and tokens of the first example in the
# first chunk of the cached MNLI training data.
mnli_train_cache = caching.ChunkedFilesDataCache("./cache/mnli/train")
first_chunk = mnli_train_cache.load_chunk(0)
row = first_chunk[0]["data_row"]
print(row.input_ids, row.tokens, sep="\n")

In [None]:
# Sanity check: show the token ids and tokens of the first example in the
# first chunk of the cached MNLI-mismatched validation data.
mismatched_val_cache = caching.ChunkedFilesDataCache("./cache/mnli_mismatched/val")
first_chunk = mismatched_val_cache.load_chunk(0)
row = first_chunk[0]["data_row"]
print(row.input_ids, row.tokens, sep="\n")

In [None]:
# Sanity check: show the token ids and tokens of the first example in the
# first chunk of the cached GLUE diagnostics test data.
diagnostics_test_cache = caching.ChunkedFilesDataCache("./cache/glue_diagnostics/test")
first_chunk = diagnostics_test_cache.load_chunk(0)
row = first_chunk[0]["data_row"]
print(row.input_ids, row.tokens, sep="\n")

#### Writing a run config

In [None]:
# Build the multi-task run configuration: train on MNLI, validate on both MNLI
# validation sets, and run the GLUE diagnostic set as the test task.
run_configurator = configurator.SimpleAPIMultiTaskConfigurator(
    # Where the task configs and the tokenized caches were written.
    task_config_base_path="./tasks/configs",
    task_cache_base_path="./cache",
    # Which tasks take part in each phase.
    train_task_name_list=["mnli"],
    val_task_name_list=["mnli", "mnli_mismatched"],
    test_task_name_list=["glue_diagnostics"],
    # Batch sizes, training length and hardware.
    train_batch_size=8,
    eval_batch_size=16,
    epochs=0.1,  # NOTE(review): presumably a fraction of one epoch to keep the demo short — confirm
    num_gpus=1,
)
jiant_run_config = run_configurator.create_config()
display.show_json(jiant_run_config)

Configure all three tasks to use an `mnli` head.

In [None]:
# Route every task through the single "mnli" taskmodel (head).
# Only "mnli" is in the training task list, so it is the only head that gets
# trained. "glue_diagnostics" is a test-only task; mapping it to its own
# "glue_diagnostics" taskmodel would produce test predictions from a head that
# was never trained, so it must share the "mnli" head like "mnli_mismatched".
jiant_run_config["taskmodels_config"]["task_to_taskmodel_map"] = {
    "mnli": "mnli",
    "mnli_mismatched": "mnli",
    "glue_diagnostics": "mnli",
}
# Persist the edited config; the training cell loads it from this path.
os.makedirs("./run_configs/", exist_ok=True)
py_io.write_json(jiant_run_config, "./run_configs/jiant_run_config.json")

#### Start training

In [None]:
# Assemble the arguments for the main jiant run loop and launch it.
run_args = main_runscript.RunConfiguration(
    # Inputs: the saved run config plus the exported roberta-base artifacts.
    jiant_task_container_config_path="./run_configs/jiant_run_config.json",
    model_type="roberta-base",
    model_path="./models/roberta-base/model/roberta-base.p",
    model_config_path="./models/roberta-base/model/roberta-base.json",
    model_tokenizer_path="./models/roberta-base/tokenizer",
    # Output location; the predictions cell reads test_preds.p from here.
    output_dir="./runs/run1",
    force_overwrite=True,
    # Optimisation and evaluation cadence.
    learning_rate=1e-5,
    eval_every_steps=500,
    # Phases to run and artifacts to write.
    do_train=True,
    do_val=True,
    do_save=True,
    write_test_preds=True,
)
main_runscript.run_loop(run_args)

In [None]:
# Load the test predictions saved in the run's output directory and display
# the entry for the GLUE diagnostics task.
test_preds = torch.load("./runs/run1/test_preds.p")
diagnostics_preds = test_preds["glue_diagnostics"]
diagnostics_preds