# 🧪 Experiment: Gazetteer Modification
This notebook runs various experiments on gazetteer modification: we set the lexicons of the gazetteer to all entities that appear in the `train set` and `dev set`.

**Note**: Before conducting experiments, you need to install the `kaner` package first. Otherwise, this notebook will raise an *import error*.

```bash
cd ../
python setup.py install
```

In [None]:
import os
from datetime import datetime
from typing import Any, Dict, List, Tuple

import torch.nn as nn

from kaner.context import GlobalContext as gctx
from kaner.adapter.in_adapter import split_dataset
from kaner.adapter.out_adapter import BaseOutAdapter
from kaner.adapter.tokenizer import CharTokenizer
from kaner.adapter.knowledge import Gazetteer
from kaner.trainer import NERTrainer, TrainerConfig
from kaner.tracker import NERTrackerRow, NERTracker
from kaner.common.func import query_time


# Initialise kaner's global context; presumably required before any of the gctx.* lookups/factories used below.
gctx.init()

## 1 Define the `trainall` Function

In [None]:
def train_with_gazetteer_modification(config: TrainerConfig) -> Dict[str, Any]:
    """
    Given a configuration, train a model on a dataset with gazetteer modification.

    Args:
        config (TrainerConfig): Trainer Configuration.
    """

    def update_hyperparameters(tokenizer: CharTokenizer, out_adapter: BaseOutAdapter, gazetteer: Gazetteer):
        """
        Update hyper parameters.

        Args:
            tokenizer (CharTokenizer): Tokenizer.
            out_adapter (BaseOutAdapter): Output adapter.
            gazetteer (Gazetteer): Gazetteer.
        """
        partial_configs = {"n_tags": len(out_adapter)}
        partial_configs.update(tokenizer.configs())
        partial_configs.update(gazetteer.configs())

        return partial_configs

    raw_datasets = split_dataset(config.dataset_folder, dataset_pp=config.dataset_pp)
    tokenizer = CharTokenizer(config.tokenizer_model_folder)
    tokenizer.save(config.output_folder)
    gazetteer = Gazetteer(config.gazetteer_model_folder)
    lexicons = ["[PAD]\tEntity\tDataset{0}:train+dev".format(config.dataset)]
    lexicons_without_repeating = set()
    for sample in raw_datasets[0] + raw_datasets[1]:
        for span in sample["spans"]:
            lexicons_without_repeating.add("{0}\t{1}\tDataset{1}:train+dev".format(span["text"], span["label"], config.dataset))
            # lexicons_without_repeating.add("{0}\tEntity\tDataset{1}:train+dev".format(span["text"], config.dataset))
    gazetteer.update(lexicons + list(lexicons_without_repeating))
    gazetteer.save(config.output_folder)
    out_adapter = gctx.create_outadapter(config.out_adapter, dataset_folder=config.dataset_folder, file_name="labels")
    out_adapter.save(config.output_folder, "labels")
    for raw_dataset in raw_datasets:
        gazetteer.count_freq(raw_dataset)
    in_adapters = (
        gctx.create_inadapter(
            config.in_adapter, dataset=dataset, tokenizer=tokenizer, out_adapter=out_adapter, gazetteer=gazetteer,
            **config.hyperparameters
        )
        for dataset in raw_datasets
    )
    token_embeddings = tokenizer.embeddings()
    lexicon_embeddings = gazetteer.embeddings()
    config.hyperparameters = update_hyperparameters(tokenizer, out_adapter, gazetteer)
    collate_fn = gctx.create_batcher(
        config.model, input_pad=tokenizer.pad_id, output_pad=out_adapter.unk_id, lexicon_pad=gazetteer.pad_id, device=config.device
    )
    model = gctx.create_model(config.model, **config.hyperparameters, token_embeddings=token_embeddings, lexicon_embeddings=lexicon_embeddings)
    trainer = NERTrainer(
        config, tokenizer, in_adapters, out_adapter, collate_fn, model, nn.CrossEntropyLoss(),
        gctx.create_traincallback(config.model), gctx.create_testcallback(config.model)
    )
    results = trainer.train()

    return results, trainer


def trainall(labpath: str, cfgdir: str, m: List[str], d: List[str], n: int, tag: str, **kwargs) -> None:
    """
    Experiments for all model's training.

    Args:
        labpath (str): The file path of recording experimental results.
        cfgdir (str): Configuration folder.
        m (List[str]): All specific models to be trained.
        d (List[str]): All specific datasets to be tested.
        n (int): The number of training repeating times.
        tag (str): Experimental tags.
    """

    def update_names(names: List[str], all_names: List[str], name_type: str) -> List[str]:
        """
        Check whether the name that user inputs is correct.

        Args:
            names (List[str]): The names (dataset, model, gazetteer) that user inputs.
            all_names (List[str]): All names (dataset, model, gazetteer) that this libary provides.
            name_type (str): The type of the name (Dataset, Model, Gazetteer).
        """
        if len(names) == 0:
            names = all_names
        else:
            for name in names:
                if name not in all_names:
                    print("[{0}] {1} is not in {2}".format(name_type, name, all_names))
                    exit(0)
        return names

    tracker = NERTracker.load(labpath)
    models = update_names(m, gctx.get_model_names(), "Model")
    datasets = update_names(d, gctx.get_dataset_names(), "Dataset")

    print("--------------------- Laboratory Configuration ---------------------")
    print("Models: {0}".format(models))
    print("Datasets: {0}".format(datasets))
    print("--------------------------------------------------------------------")

    for dataset in datasets:
        for model in models:
            for _ in range(n):
                if len(tracker.query(dataset=dataset, model=model, tag=tag)) >= n:
                    continue
                config = TrainerConfig(os.path.join(cfgdir, model + ".yml"), dataset=dataset, **kwargs)
                start = str(datetime.now())
                try:
                    results, trainer = train_with_gazetteer_modification(config)
                except RuntimeError as error:
                    print(error)
                    continue
                tracker.insert(
                    NERTrackerRow(
                        start, model, dataset, config.tokenizer_model, "entity-from-train+dev", config.output_folder, query_time(trainer.train),
                        results["f1-score"], results["precision-score"], results["recall-score"], results["epoch_count"], results["test-loss"], tag
                    )
                )
                tracker.save(labpath)
                del trainer

## 2 Train the Given Models on the Given Datasets
You can list all available models, datasets and gazetteers with the following code block:

```python
models = gctx.get_model_names()
datasets = gctx.get_dataset_names()
gazetteers = gctx.get_gazetteer_names()
```

In [None]:
# Results of every run are recorded in `labpath`; each model is configured by `<cfgdir>/<model>.yml`.
labpath = "../data/logs/experiment_gazetteer_modification[train+dev]_with_types.csv"
cfgdir = "../configs"
models = ["mdgg", "ses", "cgn"]  # models to train
datasets = ["chip"]  # datasets to train on
n = 5  # training runs per (dataset, model) pair
tag = "gazetteer_modification"  # tag stored with every tracked run
kwargs = dict(data_folder="../data")  # extra options forwarded to TrainerConfig

trainall(labpath=labpath, cfgdir=cfgdir, m=models, d=datasets, n=n, tag=tag, **kwargs)