<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#测试dataset.yield_batch" data-toc-modified-id="测试dataset.yield_batch-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>测试dataset.yield_batch</a></span></li><li><span><a href="#测试model.generate_model_inputs" data-toc-modified-id="测试model.generate_model_inputs-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>测试model.generate_model_inputs</a></span></li><li><span><a href="#测试pipeline.easy_inference_pipeline" data-toc-modified-id="测试pipeline.easy_inference_pipeline-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>测试pipeline.easy_inference_pipeline</a></span></li><li><span><a href="#测试modules" data-toc-modified-id="测试modules-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>测试modules</a></span></li></ul></div>

In [1]:
import os

# Change the working directory to the parent directory (presumably the project root,
# from which `configs`, `settings` and `src` are imported below). The CHDIR_FLAG guard
# makes this cell idempotent: re-running it does not keep climbing the directory tree.
if "CHDIR_FLAG" not in globals():
    os.chdir("../")
    CHDIR_FLAG = True
else:
    # The flag must only ever be set by the branch above.
    assert CHDIR_FLAG is True, CHDIR_FLAG

# 导入必要的包
import gc
import torch

from configs import ModuleConfig
from settings import DATA_DIR, LOG_DIR, MODEL_ROOT, DATA_SUMMARY, MODEL_SUMMARY

from src.datasets import RaceDataset, DreamDataset, SquadDataset, HotpotqaDataset, MusiqueDataset, TriviaqaDataset
from src.models import RobertaLargeFinetunedRace, LongformerLarge4096AnsweringRace, RobertaBaseSquad2, Chatglm6bInt4, Chatglm26bInt4
from src.pipelines import RacePipeline, DreamPipeline, SquadPipeline
from src.tools.easy import initialize_logger, terminate_logger, load_args
from src.modules import CoMatch

from src.tests import comatch_testscript, dcmn_testscript, duma_testscript, hrca_testscript, attention_testscript
# Report the working directory ("当前工作目录" = "current working directory") to confirm the chdir.
print("当前工作目录: {}".format(os.getcwd()))

  from .autonotebook import tqdm as notebook_tqdm


当前工作目录: D:\code\python\project\caoyang\project_019_llm_reasoning\easyqa


# 测试dataset.yield_batch

In [None]:
def test_yield_batch(tests=("triviaqa",)):
    """Smoke-test ``yield_batch`` of each dataset class by iterating over a few batches.

    Every ``_test_*`` helper builds one dataset from the path registered in
    ``DATA_SUMMARY`` and either exhausts it or prints its first few batches.
    All selected sub-tests run between ``initialize_logger`` and
    ``terminate_logger`` on ``LOG_DIR/sanity.log`` (opened with mode 'w').

    Args:
        tests: Names of the sub-tests to run, in order. Valid names are "race",
            "dream", "squad", "hotpotqa", "musique" and "triviaqa". A single name
            may be passed as a plain string. Defaults to ``("triviaqa",)``, the
            only sub-test enabled in the original cell.

    Raises:
        ValueError: If ``tests`` contains an unknown sub-test name.
    """

    def _data_dir(dataset_name):
        # Looked up lazily so a missing DATA_SUMMARY entry only breaks the sub-test that needs it.
        return DATA_SUMMARY[dataset_name]["path"]

    def _preview(batches, max_index=5, show=print):
        # Show batches 0..max_index, then stop (the original "if i > max_index: break" loops).
        for i, batch in enumerate(batches):
            if i > max_index:
                break
            show(batch)

    def _show_triviaqa(batch):
        # Only print the first sample's question and answers instead of the whole batch.
        print(batch[0]["question"], batch[0]["answers"])

    # RACE: iterate through every batch without printing.
    def _test_race():
        print(_test_race.__name__)
        dataset = RaceDataset(data_dir=_data_dir("RACE"))
        for _ in dataset.yield_batch(batch_size=2, types=["train", "dev"], difficulties=["high"]):
            pass

    # DREAM: iterate through every batch without printing.
    def _test_dream():
        print(_test_dream.__name__)
        dataset = DreamDataset(data_dir=_data_dir("DREAM"))
        for _ in dataset.yield_batch(batch_size=2, types=["train", "dev"]):
            pass

    # SQuAD
    def _test_squad():
        print(_test_squad.__name__)
        dataset = SquadDataset(data_dir=_data_dir("SQuAD"))
        for version in ["1.1"]:
            for type_ in ["train", "dev"]:
                _preview(dataset.yield_batch(batch_size=2, type_=type_, version=version))

    # HotpotQA
    def _test_hotpotqa():
        print(_test_hotpotqa.__name__)
        dataset = HotpotqaDataset(data_dir=_data_dir("HotpotQA"))
        filenames = [
            "hotpot_train_v1.1.json",
            "hotpot_dev_distractor_v1.json",
            "hotpot_dev_fullwiki_v1.json",
            "hotpot_test_fullwiki_v1.json",
        ]
        for filename in filenames:
            _preview(dataset.yield_batch(batch_size=2, filename=filename))

    # Musique: the "full" category is additionally split by the answerable flag.
    def _test_musique():
        print(_test_musique.__name__)
        batch_size = 2
        dataset = MusiqueDataset(data_dir=_data_dir("Musique"))
        for type_ in ["train", "dev", "test"]:
            for category in ["ans", "full"]:
                if category == "full":
                    for answerable in [True, False]:
                        print(f"======== {type_} - {category} - {answerable} ========")
                        _preview(dataset.yield_batch(batch_size, type_, category, answerable))
                else:
                    print(f"======== {type_} - {category} ========")
                    _preview(dataset.yield_batch(batch_size, type_, category))

    # TriviaQA
    def _test_triviaqa():
        print(_test_triviaqa.__name__)
        n = 1
        batch_size = 2
        dataset = TriviaqaDataset(data_dir=_data_dir("TriviaQA"))
        types = ["verified", "train", "dev", "test"]
        for type_ in types:
            for category in ["web", "wikipedia"]:
                print(f"======== {type_} - {category} ========")
                # NOTE(review): the 4th positional argument presumably selects the
                # "unfiltered" variant (True in the loop below) — confirm against yield_batch.
                _preview(dataset.yield_batch(batch_size, type_, category, False),
                         max_index=n, show=_show_triviaqa)
        gc.collect()
        # "verified" is skipped here — presumably it has no unfiltered variant; TODO confirm.
        for type_ in types[1:]:
            print(f"======== {type_} - unfiltered ========")
            _preview(dataset.yield_batch(batch_size, type_, "web", True),
                     max_index=n, show=_show_triviaqa)

    sub_tests = {
        "race": _test_race,
        "dream": _test_dream,
        "squad": _test_squad,
        "hotpotqa": _test_hotpotqa,
        "musique": _test_musique,
        "triviaqa": _test_triviaqa,
    }
    if isinstance(tests, str):
        tests = (tests,)
    # Validate before opening the log file so a typo does not truncate sanity.log.
    unknown = [name for name in tests if name not in sub_tests]
    if unknown:
        raise ValueError(f"Unknown sub-test(s) {unknown}; expected any of {list(sub_tests)}")

    logger = initialize_logger(os.path.join(LOG_DIR, "sanity.log"), 'w')
    try:
        for name in tests:
            sub_tests[name]()
    finally:
        # Always call terminate_logger, even if a sub-test raises.
        terminate_logger(logger)

test_yield_batch()

# 测试model.generate_model_inputs

In [None]:
def test_generate_model_inputs(tests=("race", "dream", "squad", "hotpotqa")):
    """Compare dataset-level and model-level ``generate_model_inputs`` on a few batches.

    For each dataset/model pair, the same batch is encoded twice and both results
    are printed for manual inspection:
    1. with the dataset class's ``generate_model_inputs``;
    2. with the model instance's own ``generate_model_inputs``.

    All selected sub-tests run between ``initialize_logger`` and
    ``terminate_logger`` on ``LOG_DIR/sanity.log`` (opened with mode 'w').

    Args:
        tests: Names of the sub-tests to run, in order ("race", "dream", "squad",
            "hotpotqa"). A single name may be passed as a plain string. Defaults
            to all four, as in the original cell.

    Raises:
        ValueError: If ``tests`` contains an unknown sub-test name.
    """

    def _compare(dataset_class, model_class, device, batch_kwargs,
                 dataset_max_length=32, model_max_length=32, max_index=5):
        # Build the dataset and model from their registered paths, then print both encodings
        # of each batch. The path is always taken from model_class itself, so the loaded
        # checkpoint cannot disagree with the class that is instantiated.
        data_dir = DATA_SUMMARY[dataset_class.dataset_name]["path"]
        model_path = MODEL_SUMMARY[model_class.model_name]["path"]
        dataset = dataset_class(data_dir)
        model = model_class(model_path, device=device)
        for i, batch in enumerate(dataset.yield_batch(**batch_kwargs)):
            print(dataset_class.generate_model_inputs(
                batch, model.tokenizer, model.model_name, max_length=dataset_max_length))
            print('-' * 32)
            print(model.generate_model_inputs(batch, max_length=model_max_length))
            print('#' * 32)
            # Checked after printing, so batches 0..max_index+1 are shown (7 by default).
            if i > max_index:
                break

    def _test_race():
        print(_test_race.__name__)
        # Alternative model to try: LongformerLarge4096AnsweringRace (same device and kwargs).
        _compare(RaceDataset, RobertaLargeFinetunedRace, "cpu",
                 dict(batch_size=2, types=["train", "dev"], difficulties=["high"]))

    def _test_dream():
        print(_test_dream.__name__)
        _compare(DreamDataset, RobertaLargeFinetunedRace, "cpu",
                 dict(batch_size=2, types=["train", "dev"]))

    def _test_squad():
        print(_test_squad.__name__)
        _compare(SquadDataset, RobertaBaseSquad2, "cpu",
                 dict(batch_size=2, type_="dev", version="1.1"))

    def _test_hotpotqa():
        print(_test_hotpotqa.__name__)
        # Model path and model class now come from the same class (Chatglm6bInt4). Looking up
        # Chatglm26bInt4's path while instantiating Chatglm6bInt4 mixed two models. To test
        # ChatGLM2 instead, pass Chatglm26bInt4 here.
        # The filename carries the "hotpot_" prefix, like the HotpotQA files used in test_yield_batch.
        # NOTE(review): the dataset-side encoding uses max_length=512 but the model-side one
        # uses 32, unlike the other sub-tests — confirm this asymmetry is intended.
        _compare(HotpotqaDataset, Chatglm6bInt4, "cuda",
                 dict(batch_size=2, filename="hotpot_dev_distractor_v1.json"),
                 dataset_max_length=512, model_max_length=32)

    sub_tests = {
        "race": _test_race,
        "dream": _test_dream,
        "squad": _test_squad,
        "hotpotqa": _test_hotpotqa,
    }
    if isinstance(tests, str):
        tests = (tests,)
    # Validate before opening the log file so a typo does not truncate sanity.log.
    unknown = [name for name in tests if name not in sub_tests]
    if unknown:
        raise ValueError(f"Unknown sub-test(s) {unknown}; expected any of {list(sub_tests)}")

    logger = initialize_logger(os.path.join(LOG_DIR, "sanity.log"), 'w')
    try:
        for name in tests:
            sub_tests[name]()
    finally:
        # Always call terminate_logger, even if a sub-test raises (e.g. CUDA unavailable).
        terminate_logger(logger)

test_generate_model_inputs()

# 测试pipeline.easy_inference_pipeline

In [None]:
def test_inference_pipeline():
    """Smoke-test ``easy_inference_pipeline`` on RACE; a SQuAD variant is defined but not run.

    Logging is intentionally left disabled for this test.

    NOTE(review): the object returned by ``easy_inference_pipeline`` is discarded. If it
    is lazy (for example a generator), nothing is actually executed — confirm.
    """

    def _test_race():
        # RACE train split, both difficulty levels, RoBERTa-large fine-tuned on RACE.
        race_settings = dict(
            dataset_class_name="RaceDataset",
            model_class_name="RobertaLargeFinetunedRace",
            batch_size=2,
            dataset_kwargs={"types": ["train"], "difficulties": ["high", "middle"]},
            model_kwargs={"max_length": 512},
        )
        return RacePipeline().easy_inference_pipeline(**race_settings)

    def _test_squad():
        # SQuAD 2.0 train split with RoBERTa-base fine-tuned on SQuAD2.
        squad_settings = dict(
            dataset_class_name="SquadDataset",
            model_class_name="RobertaBaseSquad2",
            batch_size=2,
            dataset_kwargs={"type_": "train", "version": "2.0"},
            model_kwargs={"max_length": 512},
        )
        return SquadPipeline().easy_inference_pipeline(**squad_settings)

    # To enable logging, wrap the call(s) below with:
    #     logger = initialize_logger(os.path.join(LOG_DIR, "sanity.log"), 'w')
    #     ...
    #     terminate_logger(logger)
    _test_race()
    # _test_squad()


test_inference_pipeline()

# 测试modules

In [None]:
# Run the CoMatch module test script (imported from src.tests in the first cell).
comatch_testscript()

In [None]:
# Run the DCMN module test script (imported from src.tests in the first cell).
dcmn_testscript()