# Develop a New Operator

## Coding for Your Operator

- Add a new key to `StatsKeys` in `data_juicer/utils/constant.py` to store the statistic computed by the new operator.

In [None]:
class StatsKeys:
    """Keys of the statistics that operators store in each sample's stats
    field."""
    # ... other keys

    # total text length of a sample
    text_len = 'text_len'

- Create a new operator file, such as `your_text_length_filter.py` in the corresponding `data_juicer/ops/filter/` directory as follows.

In [None]:
import sys
from jsonargparse.typing import PositiveInt

from data_juicer.utils.constant import Fields
# NOTE: use a new definition above
# from data_juicer.utils.constant import StatsKeys
from data_juicer.ops.base_op import OPERATORS, Filter


@OPERATORS.register_module('your_text_length_filter')
class YourTextLengthFilter(Filter):
    """Filter to keep samples with total text length within a specific
    range (both bounds inclusive)."""

    def __init__(self,
                 min_len: PositiveInt = 10,
                 max_len: PositiveInt = sys.maxsize,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param min_len: The min text length in the filtering. Samples
            whose text length is below this value are filtered out.
        :param max_len: The max text length in the filtering. Samples
            whose text length exceeds this value are filtered out.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self.min_len = min_len
        self.max_len = max_len

    def compute_stats(self, sample):
        """Store the text length of ``sample`` in its stats dict.

        :param sample: a single sample; it must already contain a
            ``Fields.stats`` dict.
        :return: the same sample, with ``StatsKeys.text_len`` set in its
            stats dict.
        """
        stats = sample[Fields.stats]
        # Only compute the length if an earlier pass has not stored it yet.
        if StatsKeys.text_len not in stats:
            stats[StatsKeys.text_len] = len(sample[self.text_key])
        return sample

    def process(self, sample):
        """Decide whether ``sample`` is kept.

        :param sample: a sample whose stats were filled by
            ``compute_stats``.
        :return: True if ``min_len <= text_len <= max_len``, else False.
        """
        text_len = sample[Fields.stats][StatsKeys.text_len]
        return self.min_len <= text_len <= self.max_len

- After implementation, add it to the OP dictionary in the `__init__.py` file in the `data_juicer/ops/filter/` directory.

```python
    from . import your_text_length_filter
```

## Testing for Your Operator

It's good practice to add corresponding tests for your own OPs. For `YourTextLengthFilter` above, you can add `test_your_text_length_filter.py` into the `tests/ops/filter/` directory as below.

In [None]:
import unittest

from datasets import Dataset

# NOTE: use a new definition above
# from data_juicer.ops.filter.your_text_length_filter import YourTextLengthFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase

class YourTextLengthFilterTest(DataJuicerTestCaseBase):
    """Unit tests for ``YourTextLengthFilter``."""

    def _run_text_length_filter(self, dataset: Dataset, target_list, op):
        """Run ``op`` over ``dataset`` and compare the surviving texts with
        ``target_list``."""
        # The OP reads and writes ``sample[Fields.stats]``, so every row
        # needs a stats dict before ``compute_stats`` is mapped.
        if Fields.stats not in dataset.features:
            empty_stats = [{}] * dataset.num_rows
            dataset = dataset.add_column(name=Fields.stats,
                                         column=empty_stats)
        kept = dataset.map(op.compute_stats).filter(op.process)
        kept_texts = kept.select_columns(column_names=['text']).to_list()
        print(kept_texts)
        self.assertEqual(kept_texts, target_list)

    def test_case1(self):
        """Only the text whose length lies in [4, 6] is kept."""
        texts = ['123', '12345', '1234567']
        dataset = Dataset.from_list([{'text': t} for t in texts])
        expected = [{'text': '12345'}]
        op = YourTextLengthFilter(min_len=4, max_len=6)
        self._run_text_length_filter(dataset, expected, op)

if __name__ == '__main__':
    # NOTE: for run in Jupyter
    # A plain ``unittest.main()`` parses ``sys.argv`` (which, in a notebook,
    # holds the kernel's own arguments) and calls ``sys.exit`` when done.
    # Passing a dummy argv and ``exit=False`` lets the tests run in a cell.
    # unittest.main()
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

## Some Supports for Your Operator

- If Hugging Face models are used within an operator, you might want to leverage GPU acceleration. To achieve this, declare `self._accelerator = 'cuda'` in the constructor, and make sure the `compute_stats` and `process` methods accept an additional optional argument `rank`, the index of the GPU to use. Note that `rank` may be `0`, so compare it with `None` rather than testing its truthiness.
```python
    # ... (same as above)
    from data_juicer.utils.model_utils import get_model, prepare_model

    @OPERATORS.register_module('your_text_length_filter_with_cuda')
    class YourTextLengthFilterWithCuda(Filter):
        def __init__(self,
                    min_len: PositiveInt = 10,
                    max_len: PositiveInt = sys.maxsize,
                    *args,
                    **kwargs):
            # ... (same as above)
            self._accelerator = 'cuda'

        def compute_stats(self, sample, rank=None):
            # ... (some codes)
            # compare with None: rank 0 (the first GPU) is falsy
            if rank is not None:
                model.to(f'cuda:{rank}')
            
        def process(self, sample, rank=None):
            # ... (same as above)
```

- If an operator takes one sample as input and produces multiple samples, it must work on batches: declare `self._batched_op = True` so that its `process` method receives and returns samples in batched ("dict of lists") form. This feature is currently only supported by mapper operators.

In [None]:
import sys
from jsonargparse.typing import PositiveInt

from data_juicer.ops.base_op import OPERATORS, Mapper


@OPERATORS.register_module('your_batch_mapper')
class YourBatchMapper(Mapper):
    """A mapper operator processing batched samples.

    Every input sample is emitted twice, so a batch of N samples becomes a
    batch of 2 * N samples.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialization method.

        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        # ``process`` consumes and produces whole batches ("dict of lists").
        self._batched_op = True

    def process(self, samples):
        """Duplicate every sample of a batch.

        :param samples: a batch in "dict of lists" form. The number of
            samples is taken from ``len(samples[self.text_key])``; other
            columns are assumed to be at least that long.
        :return: a new batch in "dict of lists" form with the same keys,
            where each input sample appears twice, in order.
        """
        keys = list(samples)

        # reconstruct samples from "dict of lists" to "list of dicts"
        reconstructed_samples = [
            {key: samples[key][i] for key in keys}
            for i in range(len(samples[self.text_key]))
        ]

        # duplicate: each input sample yields two output samples
        samples_after_generation = [
            sample
            for sample in reconstructed_samples
            for _ in range(2)
        ]

        # reconstruct samples from "list of dicts" to "dict of lists".
        # Keys come from the input batch rather than from
        # ``samples_after_generation[0]``, so an empty batch does not raise
        # IndexError.
        res_samples = {
            key: [s[key] for s in samples_after_generation]
            for key in keys
        }
        # Return the generated batch (not the untouched input batch).
        return res_samples

In [None]:
from data_juicer.core.data import NestedDataset

ds_list = [{
    'text': '123'
}, {
    'text': '12345'
}, {
    'text': '1234567'
}]
print('unbatched samples', ds_list)
dataset = NestedDataset.from_list(ds_list)
op = YourBatchMapper()
# ``op.process`` expects a batch ("dict of lists"). Plain HF ``Dataset.map``
# defaults to ``batched=False``, so request batched mode explicitly.
dataset = dataset.map(op.process, batched=True)
print(dataset.to_list())

- Call `transfer_filename` and `add_suffix_to_filename` to get unique paths for saving extra data, such as images and videos, so that files produced by different OPs or different parameter settings do not overwrite each other.

In [None]:
from data_juicer.utils.file_utils import add_suffix_to_filename, transfer_filename
from data_juicer.ops.op_fusion import LOADED_VIDEOS
from data_juicer.ops.base_op import OPERATORS, Mapper
# ... (import some other libraries)

OP_NAME = 'your_video_split_by_key_frame_mapper'


@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
class YourVideoSplitByKeyFrameMapper(Mapper):
    """Example mapper showing how to derive unique output paths for the
    splits of a video. The actual splitting logic is elided."""

    def __init__(self,
                 # ... (OP parameters)
                 split_num=1,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param split_num: number of split paths generated for a video.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        # Snapshot the init arguments via ``locals()``. This must happen
        # before any other local name is bound in this method, or it would
        # be captured too. ``remove_extra_parameters`` presumably drops the
        # non-OP entries (self/args/kwargs) -- see its definition.
        self._init_parameters = self.remove_extra_parameters(locals())
        print(f'init parameters: {self._init_parameters}')
        self.split_num = split_num

    def process(self, sample):
        """Compute the output path of each split of the first video.

        :param sample: a sample whose ``'videos'`` list holds at least one
            path.
        :return: the sample, as a mapper's ``process`` is expected to
            return the processed sample.
        """
        # ... (some codes)
        original_video_path = sample['videos'][0]
        # The base path depends on OP_NAME and the init parameters, so
        # different settings map to different files.
        base_video_path = transfer_filename(
            original_video_path, OP_NAME, **self._init_parameters)
        print(f'base path: {base_video_path}')
        for count in range(self.split_num):
            split_video_path = add_suffix_to_filename(base_video_path,
                                                      f'_{count}')
            print(f'split {count} path: {split_video_path}')
        # ... (some codes)
        return sample


In [None]:
sample = {'videos': ['./video.mp4']}
# Build the OP with different ``split_num`` values to show that each
# parameter setting yields its own set of output paths.
for num_splits in (2, 3):
    print(f'------ {num_splits} splits ------')
    op = YourVideoSplitByKeyFrameMapper(split_num=num_splits)
    op.process(sample)

## Finish the Documents

To make the new operator easy for other users to find and use, we also need to add its information to the corresponding documents.

- `configs/config_all.yaml`: this complete config file contains a list of all OPs and their arguments, serving as an
   important document for users to refer to all available OPs. Therefore, after adding the new OP, we need to add it to the process
   list (grouped by the OP type and sorted in alphabetical order):
   
   ```yaml
   ...
   - your_text_length_filter:                                # filter text with length out of specific range
       min_len: 10                                             # the min length of filter range
       max_len: 10000                                          # the max length of filter range
   ...
   ```

- `docs/Operators.md`: this doc maintains categorized lists of available OPs. We can add the information of new OP to the list
   of corresponding type of OPs (sorted in alphabetical order). At the same time, in the Overview section at the top of this doc,
   we also need to update the number of OPs for the corresponding OP type:

   ```markdown
   ## Overview
   ...
   | [ Filter ]( #filter )             |   21 (+1 HERE)   | Filters out low-quality samples                 |
   ...
   ## Filter <a name="filter"/>
   ...
   | suffix_filter                  | General | en, zh | Keeps samples with specified suffixes                                                      |
   | token_num_filter               | General | en, zh | Keeps samples with token count within the specified range                                  |
   ...
   | your_text_length_filter        | General | en, zh | Keeps samples with total text length within the specified range                            |
   ...
   ```

- `docs/Operators_ZH.md`: this doc is the Chinese version of `docs/Operators.md`, so we need to update the Chinese content at
   the same positions.

- `docs/sphinx_doc/source/data_juicer.ops.{filter | mapper | deduplicator | selector}.rst`: this doc is the index of API reference. When the operator file name is modified or an operator file is added or deleted, the corresponding entries in the file need to be updated accordingly.

## Coding Style

We define our styles in `.pre-commit-config.yaml`. Before committing,
please install `pre-commit` tool to check and modify accordingly:

```shell
# ===========install pre-commit tool===========
pip install pre-commit

cd <path_to_data_juicer>
# install pre-commit script for data_juicer
pre-commit install


# ===========check all files===========
git add .
pre-commit run --all-files

# commit after all checks pass
git commit -m "xxxx"
```

**Note**: We have configured pre-commit checks in the GitHub workflow. If this
check fails in your PR, please, locally, ① make sure that the pre-commit
dependencies are consistent with the project configuration (e.g. by running
`pre-commit clean` and `pre-commit install`); and ② run
`pre-commit run --all-files` before pushing.


## (Optional) Make your OP fusible

- If the new OP computes some intermediate variables that other existing OPs also compute, it can be added to the fusible
OPs so that OP fusion speeds up the whole data processing (e.g. both `word_num_filter` and `word_repetition_filter` need to
split the input text into words).
- When OP fusion is enabled, these shared intermediate variables are stored in the `context` and reused between OPs, which
avoids repeated calculations.
- OPs that contain common intermediate variables can be fused in OP fusion through the following steps:

1. (Optional) If a new intermediate variable is generated in the new OP, we need to add this new intermediate variable name to 
the `InterVars` class in `utils/constant.py`. In general, we need to add a prefix `DEFAULT_PREFIX` before the name.

```python
    class InterVars(object):
        # text
        lines = DEFAULT_PREFIX + 'lines'
        words = DEFAULT_PREFIX + 'words'  # add the new intermediate variable here
        ...
```

2. (Optional) We need to define a registry group in `ops/op_fusion.py` for the new intermediate variable in the 1st step, and add
this registry group to the registry group list that stores all groups of intermediate variables. This facilitates the OP Fusion module
to track OPs involving these intermediate variables.

```python
    ...
    # Type of intermediate vars
    # text
    INTER_LINES = Registry(InterVars.lines)
    INTER_WORDS = Registry(InterVars.words)  # define registry group for the new intermediate variable

    # images
    LOADED_IMAGES = Registry(InterVars.loaded_images)

    # all
    ALL_INTER_VARS = [INTER_LINES, INTER_WORDS, LOADED_IMAGES]  # and add it to the registry group list
    ...
```

3. Before the OP class definition that involves the intermediate variable, register this OP in the registry group corresponding
to this intermediate variable, indicating that the intermediate variable may be calculated and used in this OP.

```python
    ...
    @OPERATORS.register_module(OP_NAME)
    @INTER_WORDS.register_module(OP_NAME)  # register this new OP into the registry group
    class WordNumFilter(Filter):
    ...
```

4. In the calculation process of this intermediate variable of the new OP, we can modify the calculation logic to:
   1. If the argument `context` is True, OP fusion is enabled, so we first try to read the value of this intermediate variable
   from `context`, where a previous OP may already have stored it.
   2. If this intermediate variable doesn't exist in `context` yet, this OP is the first to compute it, so after computing it
   we store it in `context` under a unique key for the subsequent OPs to reuse.
   3. If the argument `context` is False, just follow the normal calculation process.

```python
    # before modification
    ...
    tokenizer = get_model(self.model_key)
    words = get_words_from_document(
        sample[self.text_key],
        token_func=tokenizer.encode_as_pieces if tokenizer else None)
    ...        

    # after modification
    ...
    words_key = f'{InterVars.words}-{self.model_key}'
    if context and words_key in sample[Fields.context]:
        # get the value of intermediate variable from context directly
        words = sample[Fields.context][words_key]
    else:
        # normal calculation process
        tokenizer = get_model(self.model_key)
        words = get_words_from_document(
            sample[self.text_key],
            token_func=tokenizer.encode_as_pieces if tokenizer else None)
        if context:
            # After calculating the intermediate variable for the first time,
            # store it in the context for subsequent OPs.
            sample[Fields.context][words_key] = words
    ...
```