<a href="https://colab.research.google.com/github/chungminhtu/nnsplit/blob/master/train/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook shows how to train [NNSplit](https://github.com/chungminhtu/nnsplit/) on a custom dataset and load it for inference.

# Setup

First, clone the GitHub repo and install the requirements. If you are running this on Colab, you will likely have to restart the runtime after installing the requirements because of version mismatches.

In [None]:
!git clone https://www.github.com/chungminhtu/nnsplit

# Data preparation

Training NNSplit is not limited to a specific dataset. However, I have found the [Linguatools Wikipedia Dumps](https://linguatools.org/tools/corpora/wikipedia-monolingual-corpora/) to work well, so there is built-in functionality to load those. Feel free to use other data!

First, download the `.xml.bz2` file and decompress it.

In [None]:
!wget https://www.dropbox.com/s/cnrhd11zdtc1pic/enwiki-20181001-corpus.xml.bz2?dl=1

In [None]:
!mv "enwiki-20181001-corpus.xml.bz2?dl=1" enwiki-20181001-corpus.xml.bz2

In [None]:
!bzip2 -d enwiki-20181001-corpus.xml.bz2

Now we can create the dataset. `xml_dump_iter` is one of the built-in functions. It returns an iterator over all texts in the Wikipedia dump and tries to remove tags and other markup.
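
To make the idea of streaming texts out of an XML dump concrete, here is a minimal sketch using only the standard library. It is not the implementation of `xml_dump_iter`: the function name `iter_texts`, the `content` tag name, and the choice to drop (rather than truncate) texts outside the length bounds are all assumptions made for illustration.

```python
import io
import xml.etree.ElementTree as ET


def iter_texts(xml_file, min_text_length=0, max_text_length=None, tag="content"):
    """Yield the text of every <tag> element whose length is within the bounds.

    `tag="content"` is an assumption about the dump layout, not taken from NNSplit.
    """
    for _event, element in ET.iterparse(xml_file, events=("end",)):
        if element.tag != tag:
            continue
        # Joining itertext() drops nested markup and keeps only the text.
        text = "".join(element.itertext()).strip()
        element.clear()  # free the parsed subtree so memory use stays low
        if len(text) < min_text_length:
            continue
        if max_text_length is not None and len(text) > max_text_length:
            continue
        yield text


# Tiny in-memory stand-in for a real dump file.
sample = io.BytesIO(
    b"<articles>"
    b"<article><content>Too short.</content></article>"
    b"<article><content>This text is long enough to keep.</content></article>"
    b"</articles>"
)
texts = list(iter_texts(sample, min_text_length=20))
print(texts)
```

Because the function is a generator, texts are produced one at a time, which is what makes it possible to process a dump that does not fit into memory.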

In [None]:
import sys
sys.path.append("nnsplit/train")
from text_data import MemoryMapDataset, xml_dump_iter

In [None]:
xml_iter = xml_dump_iter("enwiki-20181001-corpus.xml", 
                         min_text_length=300, 
                         max_text_length=5000)
next(xml_iter)

`MemoryMapDataset` is another convenient built-in class, but it is not specific to the Wikipedia dump. It is a `torch.utils.data.Dataset` which can be created from a `texts.txt` and a `slices.pkl` file. The `texts.txt` file is [memory-mapped](https://en.wikipedia.org/wiki/Memory-mapped_file), and `slices.pkl` contains a Python array of indices that determine which range of the text file belongs to which position in the dataset. This allows accessing each text without ever loading all the data into memory.
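
The following is a rough, standard-library sketch of that idea, not the actual `MemoryMapDataset` code. The helper `write_texts_and_slices`, the class `TinyMemoryMapTexts`, and the use of `(start, end)` byte offsets as the slice format are assumptions for illustration; the real `slices.pkl` layout may differ.

```python
import mmap
import os
import pickle
import tempfile


def write_texts_and_slices(texts, text_path, slices_path):
    """Concatenate texts into one file and record each text's (start, end) byte offsets."""
    slices = []
    offset = 0
    with open(text_path, "wb") as f:
        for text in texts:
            encoded = text.encode("utf-8")
            f.write(encoded)
            slices.append((offset, offset + len(encoded)))
            offset += len(encoded)
    with open(slices_path, "wb") as f:
        pickle.dump(slices, f)


class TinyMemoryMapTexts:
    """Hypothetical reader: looks up one text at a time through a memory map."""

    def __init__(self, text_path, slices_path):
        with open(slices_path, "rb") as f:
            self.slices = pickle.load(f)
        self._file = open(text_path, "rb")
        # Length 0 maps the whole file; pages are only read when accessed.
        self._mmap = mmap.mmap(self._file.fileno(), 0, access=mmap.ACCESS_READ)

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, index):
        start, end = self.slices[index]
        return self._mmap[start:end].decode("utf-8")

    def close(self):
        self._mmap.close()
        self._file.close()


with tempfile.TemporaryDirectory() as tmp:
    text_path = os.path.join(tmp, "texts.txt")
    slices_path = os.path.join(tmp, "slices.pkl")
    write_texts_and_slices(["First text.", "Zweiter Text mit Ümlaut."], text_path, slices_path)
    dataset = TinyMemoryMapTexts(text_path, slices_path)
    second = dataset[1]
    n_texts = len(dataset)
    dataset.close()

print(n_texts, repr(second))
```

Only the small list of offsets lives in memory; the text itself is paged in by the operating system when a slice is read.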

To create `texts.txt` and `slices.pkl` from an iterator over text, use `MemoryMapDataset.iterator_to_text_and_slices`.

Note that this will be quite slow since iterating over the XML dump takes a significant amount of time, so I would recommend caching `texts.txt` and `slices.pkl` somewhere.

`max_n_texts=10_000_000` is only needed in Colab to keep disk usage in check. Feel free to remove it otherwise.

In [None]:
xml_iter = xml_dump_iter("enwiki-20181001-corpus.xml", 
                         min_text_length=300,
                         max_text_length=5000)
MemoryMapDataset.iterator_to_text_and_slices(xml_iter, 
                                             "texts.txt", 
                                             "slices.pkl",
                                             max_n_texts=10_000_000)

Here, I am saving the outputs to my Drive. You will have to adjust these paths.

In [None]:
!cp -a slices.pkl "drive/My Drive/Projects/nnsplit/slices.pkl"
!cp -a texts.txt "drive/My Drive/Projects/nnsplit/texts.txt"

# Training

Now we can get started with training!

In [None]:
import sys
sys.path.append("nnsplit/train")

In [None]:
import json
from pytorch_lightning.trainer import Trainer
from tqdm.auto import tqdm
from model import Network
from text_data import MemoryMapDataset

NNSplit has a `Network` class which is a `pl.LightningModule` specifying network architecture, data loading logic etc. To instantiate a new network, we need to first get the default hyperparameters.

In [None]:
parser = Network.get_parser()
hparams = parser.parse_args([])
hparams

## Load text data

Next, we can load the text data created previously.

In [None]:
text_dataset = MemoryMapDataset("texts.txt", "slices.pkl")

Keep in mind that this can be any `torch.utils.data.Dataset` with `str` entries, so you can completely customize it.

In [None]:
text_dataset[0]
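
As a minimal sketch of such a custom dataset, the class below wraps a plain Python list. `ListTextDataset` is a hypothetical name, and the sketch assumes that a map-style dataset with `__len__` and `__getitem__` returning `str` is all that is needed. The import fallback only exists so the snippet also runs where PyTorch is not installed.

```python
try:
    from torch.utils.data import Dataset
except ImportError:  # lets the sketch run without PyTorch; training needs the real class
    Dataset = object


class ListTextDataset(Dataset):
    """Hypothetical map-style dataset holding texts in a plain Python list."""

    def __init__(self, texts):
        self.texts = list(texts)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        return self.texts[index]  # each entry is a plain str


custom_dataset = ListTextDataset(
    ["First document. It has two sentences.", "Second document."]
)
print(len(custom_dataset), repr(custom_dataset[0]))
```

A list-backed dataset like this is only sensible for small corpora; for large ones, a memory-mapped approach avoids holding everything in RAM.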

## Load labeler

Next, create a `Labeler`, which is used to annotate the text from above. Any spaCy model which supports sentencization can be used. You will have to install the appropriate spaCy model with `python -m spacy ...` when running this in Colab.

In [None]:
from labeler import Labeler, SpacySentenceTokenizer, SpacyWordTokenizer

In [None]:
labeler = Labeler(
    [
        SpacySentenceTokenizer(
            "en_core_web_sm", lower_start_prob=0.7, remove_end_punct_prob=0.7
        ),
        SpacyWordTokenizer("en_core_web_sm"),
    ]
)

`Labeler.visualize` shows you what the network sees: 
- `byte` is the UTF-8 encoded text. This has changed in the newest version of NNSplit. Previously characters were used, but using bytes allows NNSplit to work for any language regardless of the characters used to represent it.
- The other rows depend on the `Labeler` and determine what the neural network tries to predict.

In [None]:
labeler.visualize("This is a test. This is another test.")
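
To illustrate what byte-level input means, the sketch below encodes sentences as UTF-8 and marks the last byte of each sentence. This is only an illustration: `encode_with_boundaries` is a hypothetical helper, and where exactly the real `Labeler` places its labels is an assumption here.

```python
def encode_with_boundaries(sentences):
    """Return (byte_values, labels); labels[i] is 1 at the last byte of each sentence."""
    byte_values, labels = [], []
    for sentence in sentences:
        encoded = sentence.encode("utf-8")
        if not encoded:
            continue  # skip empty sentences, they have no byte to label
        byte_values.extend(encoded)
        labels.extend([0] * (len(encoded) - 1) + [1])
    return byte_values, labels


sentences = ["Hi. ", "Ünï."]
byte_values, labels = encode_with_boundaries(sentences)

# Non-ASCII characters take more than one byte, so there are more bytes than characters.
n_chars = sum(len(s) for s in sentences)
print(n_chars, len(byte_values))
print(labels)
```

Every value in `byte_values` is between 0 and 255, so the input vocabulary is fixed regardless of the script the text is written in.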

## Start training!

Now we can finally start training. 

`train_size` determines how many entries in the dataset to sample for each epoch. 

Using spaCy with multiprocessing leaks memory, so memory usage will continuously increase during each epoch and reset at the end. You will therefore have to set `train_size` to a value that fits the amount of memory available. `500_000` works well in Colab.
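
Conceptually, drawing a subset of the dataset for each epoch can be sketched as follows. The function `sample_epoch_indices` is hypothetical, and sampling without replacement is an assumption; the notebook does not state how `Network` draws its samples.

```python
import random


def sample_epoch_indices(dataset_size, train_size, seed=None):
    """Pick the dataset indices used in one epoch, without replacement."""
    rng = random.Random(seed)
    n = min(train_size, dataset_size)  # cannot draw more entries than exist
    return rng.sample(range(dataset_size), n)


indices = sample_epoch_indices(dataset_size=10_000, train_size=500, seed=0)
print(len(indices), min(indices) >= 0, max(indices) < 10_000)
```

With a scheme like this, a smaller `train_size` shortens each epoch, which in turn bounds how much leaked memory can accumulate before it is reset.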


In [None]:
hparams.gpus = 1
hparams.max_epochs = 4
hparams.train_size = 500_000
hparams.predict_indices = [0, 1] # which split levels of the labeler to predict
# how to weigh the selected indices
# in general sentence boundary detection should be weighed the highest
hparams.level_weights = [0.1, 2.0]
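
One plausible reading of `level_weights` is a weighted sum of the per-level losses, in the same order as `predict_indices`. The sketch below shows that arithmetic only; `combine_level_losses` is a hypothetical helper and the loss values are made up, so this is not necessarily how `Network` computes its loss.

```python
def combine_level_losses(level_losses, level_weights):
    """Weighted sum of per-level losses; both lists follow the order of predict_indices."""
    if len(level_losses) != len(level_weights):
        raise ValueError("need exactly one weight per predicted level")
    return sum(weight * loss for weight, loss in zip(level_weights, level_losses))


# Made-up loss values for the two predicted levels.
total = combine_level_losses([0.30, 0.50], [0.1, 2.0])
print(round(total, 4))
```

Under this reading, a level with a larger weight contributes more to the gradient, so the network is pushed harder to get that level right.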

Instantiate the network.

In [None]:
model = Network(
  text_dataset,
  labeler,
  hparams,
)
model

Instantiate the `pl.trainer.Trainer`.

In [None]:
trainer = Trainer.from_argparse_args(hparams)

And fit the model. Each row of the F1 and precision scores corresponds to one tokenizer of the `Labeler`.

In [None]:
trainer.fit(model)

Finally, store the trained model somewhere. This saves a `.onnx` export of the model in the specified directory.

In [None]:
# onnx metadata which determines how to use the prediction indices to split text
metadata = {
    "split_sequence": json.dumps(
        {
            "instructions": [
                ["Sentence", {"PredictionIndex": 0}],
                ["Token", {"PredictionIndex": 1}],
                ["_Whitespace", {"Function": "whitespace"}],
            ]
        }
    )
}
model.store("en", metadata)
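
Note that the value of `split_sequence` is itself a JSON string nested inside the metadata dictionary. The sketch below, which uses only the standard library, decodes it again to show the structure; how NNSplit interprets the instructions at inference time is not covered here.

```python
import json

metadata = {
    "split_sequence": json.dumps(
        {
            "instructions": [
                ["Sentence", {"PredictionIndex": 0}],
                ["Token", {"PredictionIndex": 1}],
                ["_Whitespace", {"Function": "whitespace"}],
            ]
        }
    )
}

# The metadata value is a JSON string, so it has to be decoded before use.
split_sequence = json.loads(metadata["split_sequence"])

# Names of all split levels, in order.
level_names = [name for name, _ in split_sequence["instructions"]]

# Levels that are driven by a network output, mapped to their prediction index.
prediction_levels = {
    name: how["PredictionIndex"]
    for name, how in split_sequence["instructions"]
    if "PredictionIndex" in how
}
print(level_names)
print(prediction_levels)
```

The `PredictionIndex` values here line up with the entries of `hparams.predict_indices` set earlier.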

# Load the model in NNSplit

First, install NNSplit.

In [None]:
!pip install nnsplit

In [None]:
from nnsplit import NNSplit

Instantiate the splitter.

In [None]:
splitter = NNSplit("en/model.onnx", use_cuda=True)

And split a text!

In [None]:
splits = splitter.split(["This is a test This is another test."])[0]
splits

The public API of NNSplit has changed significantly, making it much easier to use now. Everything is a `nnsplit.Split` which can be iterated over or stringified with `str(...)`.

In [None]:
for sentence in splits:
    print(str(sentence).ljust(30), type(sentence))

Or if you want to go token-level:

In [None]:
for sentence in splits:
    for token in sentence:
        print(str(token).ljust(10), repr(token).ljust(30), type(token))

    print()

This continues down to the smallest unit, where iterating yields `str` values instead of `nnsplit.Split` objects.

In [None]:
for sentence in splits:
    for [text, whitespace] in sentence:
        print(text.ljust(10), type(text))
        print(f'"{whitespace}"'.ljust(10), type(whitespace))
        print()

Finally, some benchmarks. If you are running `NNSplit` on a GPU, you can increase the speed on large datasets by using a big batch size.

In [None]:
splitter = NNSplit("en/model.onnx", use_cuda=True, batch_size=2**14)

In [None]:
text = "This is a test This is another test."

%timeit splitter.split([text])[0]
%timeit splitter.split([text] * 100)[0]
%timeit splitter.split([text] * 1000)[0]
%timeit splitter.split([text] * 10_000)[0]

And voilà! Splitting 10000 short texts in less than 400 milliseconds.
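
Outside a notebook, where the `%timeit` magic is not available, a similar measurement can be sketched with the standard `timeit` module. `seconds_per_call` is a hypothetical helper, and `whitespace_split` is only a stand-in so the snippet runs without a trained model; to benchmark the real thing you would pass `splitter.split` instead. The timings it prints say nothing about NNSplit itself.

```python
import timeit


def seconds_per_call(split_fn, texts, repeats=5):
    """Return the best wall-clock time of `split_fn(texts)` over `repeats` runs."""
    return min(timeit.repeat(lambda: split_fn(texts), number=1, repeat=repeats))


def whitespace_split(texts):
    # Stand-in for `splitter.split` so the sketch runs without a trained model.
    return [text.split() for text in texts]


text = "This is a test This is another test."
for n in (1, 100, 1000):
    elapsed = seconds_per_call(whitespace_split, [text] * n)
    print(f"{n:>5} texts: {elapsed * 1000:.3f} ms")
```

Taking the minimum over several repeats reduces the influence of one-off delays such as warm-up or other processes competing for the machine.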