<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#测试dataset.yield_batch" data-toc-modified-id="测试dataset.yield_batch-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>测试dataset.yield_batch</a></span></li><li><span><a href="#测试model.generate_model_inputs" data-toc-modified-id="测试model.generate_model_inputs-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>测试model.generate_model_inputs</a></span></li><li><span><a href="#测试pipeline.easy_inference_pipeline" data-toc-modified-id="测试pipeline.easy_inference_pipeline-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>测试pipeline.easy_inference_pipeline</a></span></li></ul></div>

In [3]:
import os

# Move the working directory up one level exactly once per kernel session.
# The flag makes this cell safe to re-run: without it every re-execution
# would climb one directory higher.
if "CHDIR_FLAG" not in globals():
    os.chdir("../")
    CHDIR_FLAG = True
else:
    # The flag is only ever assigned True above; any other value means the
    # namespace was modified by hand.
    assert CHDIR_FLAG is True, CHDIR_FLAG

# 导入必要的包
import gc
import torch

from settings import DATA_DIR, LOG_DIR, MODEL_ROOT, DATA_SUMMARY, MODEL_SUMMARY

from src.datasets import RaceDataset, DreamDataset, SquadDataset, HotpotqaDataset, MusiqueDataset, TriviaqaDataset
from src.models import RobertaLargeFinetunedRace, LongformerLarge4096AnsweringRace, RobertaBaseSquad2, Chatglm6bInt4, Chatglm26bInt4
from src.pipelines import RacePipeline, DreamPipeline, SquadPipeline
from src.tools.easy import initialize_logger, terminate_logger

# Report the directory that all subsequent relative paths resolve against.
print(f"当前工作目录: {os.getcwd()}")

当前工作目录: D:\code\python\project\caoyang\project_019_llm_reasoning\easyqa


# 测试dataset.yield_batch

In [None]:
def test_yield_batch():
    """Smoke-test ``yield_batch`` of the dataset classes.

    One nested helper exists per dataset; which ones actually run is chosen
    by (un)commenting the calls at the bottom. Only the TriviaQA check is
    currently enabled. Log output goes to ``sanity.log`` under ``LOG_DIR``
    and the logger is always terminated, even when a check raises.
    """
    # Dataset locations come from the project-level summary table.
    data_dir_race = DATA_SUMMARY["RACE"]["path"]
    data_dir_dream = DATA_SUMMARY["DREAM"]["path"]
    data_dir_squad = DATA_SUMMARY["SQuAD"]["path"]
    data_dir_hotpotqa = DATA_SUMMARY["HotpotQA"]["path"]
    data_dir_musique = DATA_SUMMARY["Musique"]["path"]
    data_dir_triviaqa = DATA_SUMMARY["TriviaQA"]["path"]

    # RACE
    def _test_race():
        """Exhaust every high-difficulty train/dev batch without printing."""
        print(_test_race.__name__)
        dataset = RaceDataset(data_dir=data_dir_race)
        for batch in dataset.yield_batch(batch_size=2, types=["train", "dev"], difficulties=["high"]):
            pass

    # DREAM
    def _test_dream():
        """Exhaust every train/dev batch without printing."""
        print(_test_dream.__name__)
        dataset = DreamDataset(data_dir=data_dir_dream)
        for batch in dataset.yield_batch(batch_size=2, types=["train", "dev"]):
            pass

    # SQuAD
    def _test_squad():
        """Print the first few batches of each listed version/split."""
        print(_test_squad.__name__)
        dataset = SquadDataset(data_dir=data_dir_squad)
        versions = ["1.1"]
        types = ["train", "dev"]
        for version in versions:
            for type_ in types:
                for i, batch in enumerate(dataset.yield_batch(batch_size=2, type_=type_, version=version)):
                    if i > 5:
                        break
                    print(batch)

    # HotpotQA
    def _test_hotpotqa():
        """Print the first few batches of each HotpotQA file."""
        print(_test_hotpotqa.__name__)
        dataset = HotpotqaDataset(data_dir=data_dir_hotpotqa)
        filenames = ["hotpot_train_v1.1.json",
                     "hotpot_dev_distractor_v1.json",
                     "hotpot_dev_fullwiki_v1.json",
                     "hotpot_test_fullwiki_v1.json",
                     ]
        for filename in filenames:
            for i, batch in enumerate(dataset.yield_batch(batch_size=2, filename=filename)):
                if i > 5:
                    break
                print(batch)

    # Musique
    def _test_musique():
        """Print the first few batches for every split/category combination.

        The "full" category is additionally iterated once per answerable
        flag; the "ans" category is called without that fourth argument.
        """
        print(_test_musique.__name__)
        batch_size = 2
        dataset = MusiqueDataset(data_dir=data_dir_musique)
        types = ["train", "dev", "test"]
        categories = ["ans", "full"]
        answerables = [True, False]
        for type_ in types:
            for category in categories:
                if category == "full":
                    for answerable in answerables:
                        print(f"======== {type_} - {category} - {answerable} ========")
                        for i, batch in enumerate(dataset.yield_batch(batch_size, type_, category, answerable)):
                            if i > 5:
                                break
                            print(batch)
                else:
                    print(f"======== {type_} - {category} ========")
                    for i, batch in enumerate(dataset.yield_batch(batch_size, type_, category)):
                        if i > 5:
                            break
                        print(batch)

    # TriviaQA
    def _test_triviaqa():
        """Print question/answers of the first sample in the leading batches.

        The first pass walks every split/category pair with the fourth
        positional argument set to False; the second pass repeats "web" for
        all splits except "verified" with it set to True.
        """
        print(_test_triviaqa.__name__)
        n = 1
        batch_size = 2
        dataset = TriviaqaDataset(data_dir=data_dir_triviaqa)
        types = ["verified", "train", "dev", "test"]
        categories = ["web", "wikipedia"]
        for type_ in types:
            for category in categories:
                print(f"======== {type_} - {category} ========")
                for i, batch in enumerate(dataset.yield_batch(batch_size, type_, category, False)):
                    if i > n:
                        break
                    print(batch[0]["question"], batch[0]["answers"])
        # Drop whatever the first pass left behind before the second pass.
        gc.collect()
        # NOTE(review): the last positional flag presumably selects the
        # unfiltered variant (per the banner below) - confirm against
        # TriviaqaDataset.yield_batch.
        for type_ in types[1:]:
            print(f"======== {type_} - unfiltered ========")
            for i, batch in enumerate(dataset.yield_batch(batch_size, type_, "web", True)):
                if i > n:
                    break
                print(batch[0]["question"], batch[0]["answers"])

    # Test
    logger = initialize_logger(os.path.join(LOG_DIR, "sanity.log"), 'w')
    try:
        # Uncomment the checks that should run.
        # _test_race()
        # _test_dream()
        # _test_squad()
        # _test_hotpotqa()
        # _test_musique()
        _test_triviaqa()
    finally:
        # Always release the logger so a failing check does not leave it
        # attached when the cell is re-run.
        terminate_logger(logger)

test_yield_batch()

# 测试model.generate_model_inputs

In [4]:
def test_generate_model_inputs():
    """Compare the two ways of building model inputs for each dataset.

    For every dataset/model pair the first few batches are encoded twice:
    once through the dataset class's ``generate_model_inputs`` (given the
    model's tokenizer and name) and once through the model's own
    ``generate_model_inputs``. Both results are printed for visual
    comparison. Log output goes to ``sanity.log`` under ``LOG_DIR`` and the
    logger is always terminated, even when a check raises.
    """

    def _compare_model_inputs(dataset_class, model, batches, dataset_max_length=32, model_max_length=32):
        """Print both encodings for the leading batches of ``batches``.

        The break check sits after the prints, so batches with index 0..6
        are shown.
        """
        for i, batch in enumerate(batches):
            model_inputs = dataset_class.generate_model_inputs(
                batch, model.tokenizer, model.model_name, max_length=dataset_max_length)
            print(model_inputs)
            print('-' * 32)
            model_inputs = model.generate_model_inputs(batch, max_length=model_max_length)
            print(model_inputs)
            print('#' * 32)
            if i > 5:
                break

    def _test_race():
        """RACE batches (high difficulty, train/dev) on the RACE-finetuned RoBERTa."""
        print(_test_race.__name__)
        data_dir = DATA_SUMMARY[RaceDataset.dataset_name]["path"]
        model_path = MODEL_SUMMARY[RobertaLargeFinetunedRace.model_name]["path"]
        # model_path = MODEL_SUMMARY[LongformerLarge4096AnsweringRace.model_name]["path"]
        dataset = RaceDataset(data_dir)
        model = RobertaLargeFinetunedRace(model_path, device="cpu")
        # model = LongformerLarge4096AnsweringRace(model_path, device="cpu")
        batches = dataset.yield_batch(batch_size=2, types=["train", "dev"], difficulties=["high"])
        _compare_model_inputs(RaceDataset, model, batches)

    def _test_dream():
        """DREAM batches (train/dev) on the RACE-finetuned RoBERTa."""
        print(_test_dream.__name__)
        data_dir = DATA_SUMMARY[DreamDataset.dataset_name]["path"]
        model_path = MODEL_SUMMARY[RobertaLargeFinetunedRace.model_name]["path"]
        dataset = DreamDataset(data_dir)
        model = RobertaLargeFinetunedRace(model_path, device="cpu")
        batches = dataset.yield_batch(batch_size=2, types=["train", "dev"])
        _compare_model_inputs(DreamDataset, model, batches)

    def _test_squad():
        """SQuAD 1.1 dev batches on the SQuAD2 RoBERTa-base model."""
        print(_test_squad.__name__)
        data_dir = DATA_SUMMARY[SquadDataset.dataset_name]["path"]
        model_path = MODEL_SUMMARY[RobertaBaseSquad2.model_name]["path"]
        dataset = SquadDataset(data_dir)
        model = RobertaBaseSquad2(model_path, device="cpu")
        batches = dataset.yield_batch(batch_size=2, type_="dev", version="1.1")
        _compare_model_inputs(SquadDataset, model, batches)

    def _test_hotpotqa():
        """HotpotQA dev-distractor batches on ChatGLM2-6B-int4 (needs CUDA)."""
        print(_test_hotpotqa.__name__)
        data_dir = DATA_SUMMARY[HotpotqaDataset.dataset_name]["path"]
        model_path = MODEL_SUMMARY[Chatglm26bInt4.model_name]["path"]
        dataset = HotpotqaDataset(data_dir)
        # The checkpoint path above is looked up under Chatglm26bInt4, so the
        # wrapper class must be the same one; the previous code instantiated
        # Chatglm6bInt4 on that path.
        # NOTE(review): assumes Chatglm26bInt4 takes (model_path, device=...)
        # like the other wrappers - confirm.
        model = Chatglm26bInt4(model_path, device="cuda")
        # NOTE(review): the data-directory check logs this file as
        # "hotpot_dev_distractor_v1.json"; confirm yield_batch resolves the
        # short name used here.
        batches = dataset.yield_batch(batch_size=2, filename="dev_distractor_v1.json")
        # The dataset-side encoding uses a longer window than the model-side one.
        _compare_model_inputs(HotpotqaDataset, model, batches, dataset_max_length=512, model_max_length=32)

    logger = initialize_logger(os.path.join(LOG_DIR, "sanity.log"), 'w')
    try:
        _test_race()
        _test_dream()
        _test_squad()
        _test_hotpotqa()
    finally:
        # Always release the logger so a failing check does not leave it
        # attached when the cell is re-run.
        terminate_logger(logger)

test_generate_model_inputs()

2024-10-08 20:54:08,605 | base.py | INFO | Check data directory: D:\resource\data\RACE
2024-10-08 20:54:08,606 | base.py | INFO | √ ./train/high/
2024-10-08 20:54:08,607 | base.py | INFO | √ ./train/middle/
2024-10-08 20:54:08,607 | base.py | INFO | √ ./dev/high/
2024-10-08 20:54:08,607 | base.py | INFO | √ ./dev/middle/
2024-10-08 20:54:08,607 | base.py | INFO | √ ./test/high/
2024-10-08 20:54:08,607 | base.py | INFO | √ ./test/middle/
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the

_test_race
{'input_ids': tensor([[[    0,  2387,  1623,    16,    10,  2421, 14172,  5961,     4,    91,
           6138,     7,   356,    23,   383,     2,     2,   133,  1623,  3829,
           3482,   142,  1437,  1437,    37,    34,   203,   418,     4,  1437,
            479,     2],
         [    0,  2387,  1623,    16,    10,  2421, 14172,  5961,     4,    91,
           6138,     7,   356,    23,   383,     2,     2,   133,  1623,  3829,
           3482,   142,  1437,  1437,    37,  3829,     5,  6464,     4,  1437,
            479,     2],
         [    0,  2387,  1623,    16,    10,  2421, 14172,  5961,     4,    91,
           6138,     7,   356,    23,   383,     2,     2,   133,  1623,  3829,
           3482,   142,  1437,  1437,    37,  3829,     7,  8933,     5,   850,
            227,     2],
         [    0,  2387,  1623,    16,    10,  2421, 14172,  5961,     4,    91,
           6138,     7,   356,    23,   383,     2,     2,   133,  1623,  3829,
           3482,   1

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

{'input_ids': tensor([[[    0,  2387,  1623,    16,    10,  2421, 14172,  5961,     4,    91,
           6138,     7,   356,    23,   383,     2,     2, 35693,    64,    75,
            109,     5,  3482,   157,   142,  1437,  1437,    37,    16,   664,
           1437,     2],
         [    0,  2387,  1623,    16,    10,  2421, 14172,  5961,     4,    91,
           6138,     7,   356,    23,   383,     2,     2, 35693,    64,    75,
            109,     5,  3482,   157,   142,  1437,  1437,    37,    16, 11640,
             12,     2],
         [    0,  2387,  1623,    16,    10,  2421, 14172,  5961,     4,    91,
           6138,     7,   356,    23,   383,     2,     2, 35693,    64,    75,
            109,     5,  3482,   157,   142,  1437,  1437,    37,   747, 13585,
             39,     2],
         [    0,  2387,  1623,    16,    10,  2421, 14172,  5961,     4,    91,
           6138,     7,   356,    23,   383,     2,     2, 35693,    64,    75,
            109,     5,  3482, 

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.


{'input_ids': tensor([[[    0, 11475,  2115,    10,    86,     6,    89,    21,    10, 20875,
             54,   770,     7,  2364,    55,     2,     2,  2264,    18,  1437,
              5,  3184,     9,     5,  9078,   116, 42516,   111,   625, 36237,
             12,     2],
         [    0, 11475,  2115,    10,    86,     6,    89,    21,    10, 20875,
             54,   770,     7,  2364,    55,     2,     2,  2264,    18,  1437,
              5,  3184,     9,     5,  9078,   116, 42295,    12,  1246,   111,
            625,     2],
         [    0, 11475,  2115,    10,    86,     6,    89,    21,    10, 20875,
             54,   770,     7,  2364,    55,     2,     2,  2264,    18,  1437,
              5,  3184,     9,     5,  9078,   116,  3718,    12, 45260,   111,
          46781,     2],
         [    0, 11475,  2115,    10,    86,     6,    89,    21,    10, 20875,
             54,   770,     7,  2364,    55,     2,     2,  2264,    18,  1437,
              5,  3184,     9, 

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
2024-10-08 20:54:09,336 | base.py | INFO | Check data directory: D:\resource\data\SQuAD
2024-10-

{'input_ids': tensor([[[    0,   448,    35,    38,   524,  2811,  6614,   127,  7950,  1380,
              4,    38,   524,    45,   442,     2,     2,  2264,   473,     5,
            313,  3608,     5,   693,   109,   116, 15850,    69,  7950,  3254,
              4,     2],
         [    0,   448,    35,    38,   524,  2811,  6614,   127,  7950,  1380,
              4,    38,   524,    45,   442,     2,     2,  2264,   473,     5,
            313,  3608,     5,   693,   109,   116,  4624,    10,    55,  2679,
           1380,     2],
         [    0,   448,    35,    38,   524,  2811,  6614,   127,  7950,  1380,
              4,    38,   524,    45,   442,     2,     2,  2264,   473,     5,
            313,  3608,     5,   693,   109,   116,  5603,    69,  7950,  1380,
              4,     2]],

        [[    0,   771,    35,  2647,     6,    38,   437,  6023,   127,  6836,
            965,    75,     7,   110,  5840,     2,     2,  2264,   473,     5,
            313,   206,     9

2024-10-08 20:54:09,549 | base.py | INFO | Check data directory: D:\resource\data\HotpotQA
2024-10-08 20:54:09,549 | base.py | INFO | √ ./hotpot_dev_distractor_v1.json
2024-10-08 20:54:09,549 | base.py | INFO | √ ./hotpot_dev_fullwiki_v1.json
2024-10-08 20:54:09,549 | base.py | INFO | √ ./hotpot_test_fullwiki_v1.json
2024-10-08 20:54:09,555 | base.py | INFO | √ ./hotpot_train_v1.1.json


{'input_ids': tensor([[    0, 32251,  1485,   165,  4625,     5,  9601,    23,  1582,  2616,
           654,   116,     2,     2, 16713,  1215,   387, 20734,  1215,  1096,
         50118, 16713,  2616,   654,    21,    41,   470,  1037,   177,     7,
          3094,     2],
        [    0, 32251,  1485,   165,  4625,     5, 11119,    23,  1582,  2616,
           654,   116,     2,     2, 16713,  1215,   387, 20734,  1215,  1096,
         50118, 16713,  2616,   654,    21,    41,   470,  1037,   177,     7,
          3094,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]])}
--------------------------------
{'input_ids': tensor([[    0, 32251,  1485,   165,  4625,     5,  9601,    23,  1582,  2616,
           654,   116,     2,     2, 16713,  1215,   387, 20734,  1215,  1096,
       

{'input_ids': tensor([[64790, 64792,   809,  ...,   267,  3764, 30953],
        [64790, 64792,   809,  ...,    13,  1036,  1147]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'position_ids': tensor([[  0,   1,   2,  ..., 509, 510, 511],
        [  0,   1,   2,  ..., 509, 510, 511]])}
--------------------------------
{'input_ids': tensor([[64790, 64792,   809,   383,   260,  7486, 30932,   344,   720,   289,
           950,   267,  1845,  4177,  7724,   293,  7511,   267,  3238,   289,
           267,  2021,  3040, 30930,    13,   986,  8192, 30954,    13, 11355,
         30910, 30939],
        [64790, 64792,   809,   383,   260,  7486, 30932,   344,   720,   289,
           950,   267,  1845,  4177,  7724,   293,  7511,   267,  3238,   289,
           267,  2021,  3040, 30930,    13,   986,  8192, 30954,    13, 11355,
         30910, 30939]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,


{'input_ids': tensor([[64790, 64792,   809,   383,   260,  7486, 30932,   344,   720,   289,
           950,   267,  1845,  4177,  7724,   293,  7511,   267,  3238,   289,
           267,  2021,  3040, 30930,    13,   986,  8192, 30954,    13, 11355,
         30910, 30939],
        [64790, 64792,   809,   383,   260,  7486, 30932,   344,   720,   289,
           950,   267,  1845,  4177,  7724,   293,  7511,   267,  3238,   289,
           267,  2021,  3040, 30930,    13,   986,  8192, 30954,    13, 11355,
         30910, 30939]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]]), 'position_ids': tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 

# 测试pipeline.easy_inference_pipeline

In [None]:
def test_inference_pipeline():
    """Smoke-test ``easy_inference_pipeline`` of the pipeline classes.

    Only the RACE check is invoked; the SQuAD check and the file logger are
    left disabled (see the commented lines at the bottom).
    """

    def _test_race():
        """Run the RACE pipeline on the train split, both difficulties."""
        call_options = {
            "dataset_class_name": "RaceDataset",
            "model_class_name": "RobertaLargeFinetunedRace",
            "batch_size": 2,
            "dataset_kwargs": {"types": ["train"], "difficulties": ["high", "middle"]},
            "model_kwargs": {"max_length": 512},
        }
        runner = RacePipeline()
        outcome = runner.easy_inference_pipeline(**call_options)

    def _test_squad():
        """Run the SQuAD pipeline on the 2.0 train split."""
        call_options = {
            "dataset_class_name": "SquadDataset",
            "model_class_name": "RobertaBaseSquad2",
            "batch_size": 2,
            "dataset_kwargs": {"type_": "train", "version": "2.0"},
            "model_kwargs": {"max_length": 512},
        }
        runner = SquadPipeline()
        outcome = runner.easy_inference_pipeline(**call_options)

    # logger = initialize_logger(os.path.join(LOG_DIR, "sanity.log"), 'w')
    _test_race()
    # _test_squad()
    # terminate_logger(logger)


test_inference_pipeline()