<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [1]:
import os

# Run-once guard: the first execution of this cell moves the working directory
# one level up and records that in CHDIR_FLAG; later re-runs only verify the
# flag instead of climbing another directory.
if "CHDIR_FLAG" in dir():
    assert CHDIR_FLAG is True, CHDIR_FLAG
else:
    os.chdir("../")
    CHDIR_FLAG = True

In [2]:
# -*- coding: utf-8 -*-
# @author : caoyang
# @email: caoyang@stu.sufe.edu.cn

import os
import gc
import torch

from settings import DATA_DIR, LOG_DIR, MODEL_ROOT, DATA_SUMMARY, MODEL_SUMMARY

from src.datasets import RaceDataset, DreamDataset, SquadDataset, HotpotqaDataset, MusiqueDataset, TriviaqaDataset
from src.models import RobertaLargeFinetunedRace
from src.tools.easy import initialize_logger, terminate_logger

def test_yield_batch():
	"""Smoke-test `yield_batch` on each dataset class.

	Each nested `_test_*` function builds one dataset from its path in
	`DATA_SUMMARY` and iterates its `yield_batch`.
	- RACE and DREAM are iterated to exhaustion without printing.
	- SQuAD, HotpotQA, Musique and TriviaQA print at most the first 6 batches
	  of every configuration they try.

	Only the sub-tests whose calls are uncommented at the bottom run. The logger
	is opened on `LOG_DIR/sanity.log` in mode 'w' and is terminated even if a
	sub-test raises.
	"""
	data_dir_race = DATA_SUMMARY["RACE"]["path"]
	data_dir_dream = DATA_SUMMARY["DREAM"]["path"]
	data_dir_squad = DATA_SUMMARY["SQuAD"]["path"]
	data_dir_hotpotqa = DATA_SUMMARY["HotpotQA"]["path"]
	data_dir_musique = DATA_SUMMARY["Musique"]["path"]
	data_dir_triviaqa = DATA_SUMMARY["TriviaQA"]["path"]

	def _print_first_batches(batches, n_batches=6):
		"""Print at most `n_batches` items of the iterable `batches`, then stop."""
		for i, batch in enumerate(batches):
			if i >= n_batches:
				break
			print(batch)

	# RACE: walk every high-difficulty batch of train + dev without printing
	def _test_race():
		print(_test_race.__name__)
		dataset = RaceDataset(data_dir=data_dir_race)
		for _ in dataset.yield_batch(batch_size=2, types=["train", "dev"], difficulties=["high"]):
			pass

	# DREAM: walk every batch of train + dev without printing
	def _test_dream():
		print(_test_dream.__name__)
		dataset = DreamDataset(data_dir=data_dir_dream)
		for _ in dataset.yield_batch(batch_size=2, types=["train", "dev"]):
			pass

	# SQuAD
	def _test_squad():
		print(_test_squad.__name__)
		dataset = SquadDataset(data_dir=data_dir_squad)
		for version in ["1.1"]:
			for type_ in ["train", "dev"]:
				_print_first_batches(dataset.yield_batch(batch_size=2, version=version, type_=type_))

	# HotpotQA
	def _test_hotpotqa():
		print(_test_hotpotqa.__name__)
		dataset = HotpotqaDataset(data_dir=data_dir_hotpotqa)
		filenames = [
			"hotpot_train_v1.1.json",
			"hotpot_dev_distractor_v1.json",
			"hotpot_dev_fullwiki_v1.json",
			"hotpot_test_fullwiki_v1.json",
		]
		for filename in filenames:
			_print_first_batches(dataset.yield_batch(batch_size=2, filename=filename))

	# Musique: the "full" category is additionally split by `answerable`
	def _test_musique():
		print(_test_musique.__name__)
		batch_size = 2
		dataset = MusiqueDataset(data_dir=data_dir_musique)
		for type_ in ["train", "dev", "test"]:
			for category in ["ans", "full"]:
				if category == "full":
					for answerable in [True, False]:
						print(f"======== {type_} - {category} - {answerable} ========")
						_print_first_batches(dataset.yield_batch(batch_size, type_, category, answerable))
				else:
					print(f"======== {type_} - {category} ========")
					_print_first_batches(dataset.yield_batch(batch_size, type_, category))

	# TriviaQA
	def _test_triviaqa():
		print(_test_triviaqa.__name__)
		batch_size = 2
		dataset = TriviaqaDataset(data_dir=data_dir_triviaqa)
		for type_ in ["verified", "train", "dev", "test"]:
			for category in ["web", "wikipedia"]:
				print(f"======== {type_} - {category} ========")
				_print_first_batches(dataset.yield_batch(batch_size, type_, category, False))
		# Explicit collection pass between the two loops (presumably to reclaim
		# memory held by the previous batches before the next pass -- confirm)
		gc.collect()
		# NOTE(review): the 4th positional argument is False above and True here;
		# the "unfiltered" header suggests it selects the unfiltered data -- confirm
		for type_ in ["train", "dev", "test"]:
			print(f"======== {type_} - unfiltered ========")
			_print_first_batches(dataset.yield_batch(batch_size, type_, "web", True))

	logger = initialize_logger(os.path.join(LOG_DIR, "sanity.log"), 'w')
	try:
		# _test_race()
		# _test_dream()
		# _test_squad()
		_test_hotpotqa()
		# _test_musique()
		# _test_triviaqa()
	finally:
		# Run terminate_logger even when a sub-test raises, so the logger set
		# up above is always torn down.
		terminate_logger(logger)


def test_generate_model_inputs():
	"""Print dataset-built vs. model-built model inputs for a few batches.

	For each enabled dataset, this loads `RobertaLargeFinetunedRace` on CPU. For
	each batch of size 2 from the train and dev splits, it prints two things:
	- the inputs from the dataset class's `generate_model_inputs` (given the
	  model's tokenizer and name);
	- the inputs from the model's own `generate_model_inputs`.
	Both use `max_length=32`. Iteration stops after the batch with index 6.

	Only the sub-tests whose calls are uncommented at the bottom run. The logger
	is opened on `LOG_DIR/sanity.log` in mode 'w' and is terminated even if a
	sub-test raises.
	"""

	def _compare_model_inputs(dataset_class, **yield_kwargs):
		"""Run the dataset-vs-model input comparison for `dataset_class`.

		`yield_kwargs` are forwarded to `dataset.yield_batch` together with
		`batch_size=2`.
		"""
		data_dir = DATA_SUMMARY[dataset_class.dataset_name]["path"]
		model_path = MODEL_SUMMARY[RobertaLargeFinetunedRace.model_name]["path"]
		dataset = dataset_class(data_dir=data_dir)
		# NOTE(review): the saved output of this cell ends with
		# "TypeError: __init__() takes 2 positional arguments but 3 were given",
		# raised after the RACE data-directory check had already been logged.
		# The failing constructor is outside this view; confirm the signature of
		# RobertaLargeFinetunedRace (and of anything its __init__ calls).
		model = RobertaLargeFinetunedRace(model_path, device="cpu")
		for i, batch in enumerate(dataset.yield_batch(batch_size=2, **yield_kwargs)):
			# Inputs built by the dataset class, using the model's tokenizer ...
			model_inputs = dataset_class.generate_model_inputs(batch, model.tokenizer, model.model_name, max_length=32)
			print(model_inputs)
			print('-' * 32)
			# ... versus inputs built by the model itself (presumably expected to match)
			model_inputs = model.generate_model_inputs(batch, max_length=32)
			print(model_inputs)
			print('#' * 32)
			# Checked after printing, so batches 0..6 (seven in total) are shown
			if i > 5:
				break

	def _test_race():
		print(_test_race.__name__)
		_compare_model_inputs(RaceDataset, types=["train", "dev"], difficulties=["high"])

	def _test_dream():
		print(_test_dream.__name__)
		_compare_model_inputs(DreamDataset, types=["train", "dev"])

	logger = initialize_logger(os.path.join(LOG_DIR, "sanity.log"), 'w')
	try:
		_test_race()
		# _test_dream()
	finally:
		# Run terminate_logger even when a sub-test raises (as in the saved
		# output), so the logger set up above is always torn down.
		terminate_logger(logger)


if "__main__" == __name__:
	# Sanity check currently enabled; test_yield_batch() is switched off.
	test_generate_model_inputs()


  from .autonotebook import tqdm as notebook_tqdm
2024-09-03 16:43:27,648 | base.py | INFO | Check data directory: D:\resource\data\RACE
2024-09-03 16:43:27,649 | base.py | INFO | √ ./train/high/
2024-09-03 16:43:27,650 | base.py | INFO | √ ./train/middle/
2024-09-03 16:43:27,650 | base.py | INFO | √ ./dev/high/
2024-09-03 16:43:27,651 | base.py | INFO | √ ./dev/middle/
2024-09-03 16:43:27,651 | base.py | INFO | √ ./test/high/
2024-09-03 16:43:27,652 | base.py | INFO | √ ./test/middle/


_test_race


TypeError: __init__() takes 2 positional arguments but 3 were given