# 🧪 Experiment: Lexicon Masking
This notebook evaluates trained models on the test set for the `Lexicon Masking` task.

**Note**: Before running the experiments, install the `kaner` package; otherwise, this notebook will raise an *import error*.

```bash
cd ../
python setup.py install
```

In [None]:
import csv
import gc
import json
import os
import pprint
from copy import deepcopy
from typing import List

import tqdm

from kaner.context import GlobalContext as gctx
from kaner.adapter.tokenizer import CharTokenizer
from kaner.adapter.knowledge import Gazetteer
from kaner.adapter.in_adapter import split_dataset
from kaner.trainer import NERTrainer, TrainerConfig
from kaner.common import load_json, load_jsonl, save_json


# Initialize kaner's global context before any of the gctx.create_* factories below are used
# (presumably this registers the available models/adapters/batchers — TODO confirm in kaner.context).
gctx.init()

## 1. Define Interventions

In [None]:
def get_matlexs(datasets: List[List[dict]], gazetteer: Gazetteer, mode: str, max_seq_len: int = 512) -> List[set]:
    """
    Given a matching mode, return the gazetteer lexicons matched in each dataset.

    For every datapoint, each start position of its (truncated) character sequence is passed to
    ``gazetteer.search`` and the returned lexicons are collected into one set per dataset.

    Args:
        datasets: a list of datasets; each dataset is an iterable of datapoints holding a ``"text"``
            string and a ``"spans"`` list whose items carry a ``"text"`` field.
        gazetteer: lexicon gazetteer; ``search(tokens)`` presumably returns the lexicons that match
            a prefix of ``tokens`` (TODO confirm in kaner.adapter.knowledge).
        mode: ``"all"`` keeps every match; ``"entity"`` keeps only matches equal to some annotated
            span text from any of ``datasets``; ``"non-entity"`` keeps only matches that are not.
        max_seq_len: only the first ``max_seq_len`` characters of each text are searched.

    Returns:
        One set of matched lexicons per dataset, in the same order as ``datasets``.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """
    # Explicit check rather than `assert`, which is stripped when Python runs with -O.
    if mode not in ("all", "entity", "non-entity"):
        raise ValueError("mode must be one of 'all', 'entity', 'non-entity', got {0!r}".format(mode))
    # Span texts annotated in ANY dataset define what counts as an "entity" lexicon.
    all_spans = set()
    for dataset in datasets:
        for datapoint in dataset:
            for span in datapoint["spans"]:
                all_spans.add(span["text"])
    matched_lexicons = []
    for dataset in datasets:
        lexicons = set()
        for datapoint in dataset:
            tokens = list(datapoint["text"])[:max_seq_len]
            for start in range(len(tokens)):
                items = gazetteer.search(tokens[start:])
                if mode == "all":
                    lexicons.update(items)
                elif mode == "entity":
                    lexicons.update(item for item in items if item in all_spans)
                else:
                    lexicons.update(item for item in items if item not in all_spans)
        matched_lexicons.append(lexicons)

    return matched_lexicons


def _test_with_masked_lexicons(config, tokenizer, out_adapter, gazetteer, collate_fn, model, test_dataset, lexicons):
    """
    Run the test split with ``lexicons`` masked in ``gazetteer`` and return ``trainer._test``'s result.

    The lexicons are passed to ``gazetteer.mask(..., True)`` before evaluation and to
    ``gazetteer.mask(..., False)`` afterwards. The unmasking happens in a ``finally``, so a failing
    run never leaks its mask into the next intervention.
    """
    masked = list(lexicons)
    gazetteer.mask(masked, True)
    try:
        # Train and dev inputs are left empty (presumably the trainer expects train/dev/test order);
        # only the test split is evaluated below.
        in_adapters = (
            gctx.create_inadapter(
                config.in_adapter, dataset=dataset, tokenizer=tokenizer, out_adapter=out_adapter, gazetteer=gazetteer,
                **config.hyperparameters
            )
            for dataset in [[], [], test_dataset]
        )
        trainer = NERTrainer(
            config, tokenizer, in_adapters, out_adapter, collate_fn, model, None,
            gctx.create_traincallback(config.model), gctx.create_testcallback(config.model)
        )
        trainer.startp()
        result = trainer._test(trainer._test_loader)
        # Drop the trainer before the next intervention so its memory can be reclaimed.
        del trainer
        gc.collect()
    finally:
        gazetteer.mask(masked, False)
    return result


def evaluate(model_folder: str) -> dict:
    """
    Evaluate the model stored in ``model_folder`` on its test split under four lexicon interventions.

    Args:
        model_folder: folder from which ``config.json``, the tokenizer, the gazetteer and the
            ``labels`` file are loaded; it is also used as the trainer's output folder.

    Returns:
        A dict with one test result per intervention:
            ``"I"``: lexicons matched in both the train and the test split are masked;
            ``"S"``: lexicons matched only in the test split are masked;
            ``"E"``: test-split lexicons equal to some annotated span text are masked;
            ``"N"``: test-split lexicons that are not annotated span texts are masked.
    """
    options = load_json("utf-8", model_folder, "config.json")
    options["output_folder"] = model_folder
    options["identity"] = os.path.basename(os.path.normpath(model_folder))
    config = TrainerConfig(options, data_folder="../data")
    tokenizer = CharTokenizer(model_folder)
    gazetteer = Gazetteer(model_folder)
    # Assumed to be [train, dev, test]; only the last split is evaluated.
    datasets = split_dataset(config.dataset_folder)
    test_dataset = datasets[2]
    out_adapter = gctx.create_outadapter(config.out_adapter, dataset_folder=model_folder, file_name="labels")
    collate_fn = gctx.create_batcher(
        config.model, input_pad=tokenizer.pad_id, output_pad=out_adapter.unk_id, lexicon_pad=gazetteer.pad_id, device=config.device
    )
    model = gctx.create_model(config.model, **config.hyperparameters)

    def run(lexicons):
        # Same model/config for every intervention; only the masked lexicon set changes.
        return _test_with_masked_lexicons(
            config, tokenizer, out_adapter, gazetteer, collate_fn, model, test_dataset, lexicons
        )

    result = {}
    # IS intervention: I = lexicons matched in both train and test, S = lexicons matched only in test.
    train_lexicons, _, test_lexicons = get_matlexs(datasets, gazetteer, "all")
    result["I"] = run(train_lexicons & test_lexicons)
    result["S"] = run(test_lexicons - train_lexicons)

    # Entity vs. Non-Entity: split the test-split matches by whether they are annotated span texts.
    _, _, entity_lexicons = get_matlexs(datasets, gazetteer, "entity")
    _, _, non_entity_lexicons = get_matlexs(datasets, gazetteer, "non-entity")
    result["E"] = run(entity_lexicons)
    result["N"] = run(non_entity_lexicons)

    return result

## 2. Execute the `do` Operator

In [None]:
def load_experiments(folder: str = "../data/logs", excluded_models: tuple = ("blcrf", "plmtg")) -> List[dict]:
    """
    Load the experiment index ``experiments.csv`` from ``folder``.

    The first row is the header. Every following non-empty row becomes a dict that maps each column
    name to its string value. Rows whose ``model`` column is in ``excluded_models`` are dropped.

    Args:
        folder: directory containing ``experiments.csv``.
        excluded_models: model names whose rows are left out (default: ``blcrf`` and ``plmtg``).

    Returns:
        The remaining rows, in file order.
    """
    file_path = os.path.join(folder, "experiments.csv")
    logs = []
    # newline="" lets the csv module handle line endings and quoted fields (e.g. paths with commas).
    with open(file_path, "r", encoding="utf-8", newline="") as fin:
        reader = csv.reader(fin)
        columns = next(reader, [])
        for row in reader:
            # A blank line (e.g. a trailing newline at EOF) yields an empty row; it is not an experiment.
            if not row:
                continue
            # zip() truncates to the shorter of header/row, matching the previous hand-rolled parser.
            log = dict(zip(columns, row))
            if log["model"] not in excluded_models:
                logs.append(log)

    return logs


# Resume from the partial results file when it exists; otherwise start from the experiment index.
dolog_path = os.path.join("../data", "do_full_logs.json")
logs = load_json("utf-8", dolog_path) if os.path.isfile(dolog_path) else load_experiments()
for idx, log in enumerate(logs):
    print("## Log {0}...........................".format(idx))
    # Entries that already carry a "do" result were finished by an earlier run.
    if "do" in log:
        continue
    log_dir = log["log_dir"]
    # Paths are resolved relative to the parent folder unless they already start with "../".
    folder = log_dir if log_dir.startswith("../") else os.path.join("../", log_dir)
    folder = folder.replace("tmp/", "")
    log["do"] = evaluate(folder)
    # Checkpoint after every model so an interrupted run can resume where it stopped.
    save_json(logs, dolog_path)