In [None]:
import datetime
import json
import logging
import os
import pathlib
import sys
import pydoc
import time

import datasets
import fire
import torch
import yaml

import transformers

In [None]:
def load_dataset(test_file: str):
    """Load one JSON test file through the Hugging Face ``datasets`` library.

    Args:
        test_file: Path (or URL) of the JSON / JSON Lines file, forwarded
            unchanged as ``data_files``.

    Returns:
        Whatever ``datasets.load_dataset`` returns for the ``"train"`` split
        of the ``"json"`` builder.
    """
    loader_options = {"data_files": test_file, "split": "train"}
    return datasets.load_dataset("json", **loader_options)

def generate_text_fn(model, tokenizer, args: dict = None):
    """Build a ``prompt -> completion`` function bound to a model and tokenizer.

    Args:
        model: Object exposing ``device`` and ``generate(**kwargs)``; the
            result of ``generate`` must have a ``sequences`` tensor.
        tokenizer: Callable tokenizer exposing ``eos_token_id``,
            ``pad_token_id`` and ``decode``.
        args: Extra keyword arguments for ``model.generate``. ``None`` (for
            example an empty YAML section) is treated as "no extra arguments".
            Entries here override the tokenizer-derived ``eos_token_id`` /
            ``pad_token_id`` defaults instead of raising a duplicate-keyword
            ``TypeError``.

    Returns:
        A function ``generate_text(prompt: str) -> str`` that returns only the
        newly generated text (the prompt tokens are cut off), with a single
        trailing EOS token removed when one is present.
    """
    # Tokenizer defaults first, caller-supplied values second, so the caller wins.
    generate_kwargs = {
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    generate_kwargs.update(args or {})
    # `.sequences` is read below, so the dict-style output is mandatory.
    generate_kwargs["return_dict_in_generate"] = True

    # Normalise the effective EOS setting (None, a single id, or several ids)
    # into a set of plain ints for the trailing-token check.
    eos_spec = generate_kwargs.get("eos_token_id")
    if eos_spec is None:
        eos_ids = frozenset()
    elif isinstance(eos_spec, int):
        eos_ids = frozenset((eos_spec,))
    else:
        eos_ids = frozenset(int(token_id) for token_id in eos_spec)

    def generate_text(prompt: str) -> str:
        """Generate a completion for ``prompt`` and return it as text."""
        input_tokens = tokenizer(prompt, return_tensors="pt").to(model.device)
        input_length = input_tokens.input_ids.shape[1]

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_tokens["input_ids"],
                attention_mask=input_tokens["attention_mask"],
                **generate_kwargs,
            )

        # Keep only the continuation (everything after the prompt tokens).
        output_tokens = outputs.sequences[0, input_length:]
        # Drop the final token only if it really is EOS. When generation stops
        # on a length limit the last token is real text and must be kept.
        if output_tokens.numel() > 0 and int(output_tokens[-1]) in eos_ids:
            output_tokens = output_tokens[:-1]

        return tokenizer.decode(output_tokens)

    return generate_text

In [None]:
# Load the run configuration.
config_file = "../config/config_generate_texts_llama_jplaw_v04.yaml"
# Decode explicitly as UTF-8: without `encoding`, open() falls back to the
# locale's preferred encoding, which breaks on non-ASCII text in the YAML
# (e.g. a Japanese prompt template) on non-UTF-8 systems.
with open(config_file, "r", encoding="utf-8") as i_:
    config = yaml.safe_load(i_)

print(config)

In [None]:
model_name = config["model"]["pretrained_model_name_or_path"]
print(f"model_name: {model_name}")

# Prepare the output directory (environment variables in the path are expanded).
output_dir = pathlib.Path(os.path.expandvars(config["outputs"]["dirname"]))
output_dir.mkdir(parents=True, exist_ok=True)

# Keep a copy of the effective configuration next to the generated outputs.
with open(output_dir.joinpath("config.yaml"), "w", encoding="utf-8") as o_:
    yaml.dump(config, o_)

# Load the tokenizer; fall back to EOS as the padding token when none is defined.
print("load tokenizer")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the model.
print(f"load model: {config['model']}")
# Work on a copy so `config` stays plain YAML data: mutating it in place made
# this cell fail on re-run (pydoc.locate() on an already-converted dtype) and
# left a non-serialisable object in the config that is dumped above.
model_kwargs = dict(config["model"])
dtype_spec = model_kwargs.get("torch_dtype")
# "auto" is understood by transformers itself; any other string is expected to
# be a dotted path such as "torch.float16" and is resolved to the real object.
if isinstance(dtype_spec, str) and dtype_spec != "auto":
    resolved_dtype = pydoc.locate(dtype_spec)
    if resolved_dtype is None:
        # Fail loudly instead of silently loading with the default dtype.
        raise ValueError(f"cannot resolve torch_dtype: {dtype_spec!r}")
    model_kwargs["torch_dtype"] = resolved_dtype
model = transformers.AutoModelForCausalLM.from_pretrained(**model_kwargs)

# Bind model, tokenizer and the "generate" section of the config into a
# prompt -> completion function, and resolve the output file path.
generate_text = generate_text_fn(model, tokenizer, config["generate"])
output_file = output_dir.joinpath(config["outputs"]["filename"])

In [None]:
# Show the current GPU state via the NVIDIA CLI (IPython shell escape);
# in notebook order this runs right after the model-loading cell.
!nvidia-smi

In [None]:
start_time = time.time()

# The config provides three test files under data.test_file1 .. data.test_file3.
for j in range(1, 4):
    print(f"processing R{j} data.")
    dataset = load_dataset(test_file=config["data"][f"test_file{j}"])

    # Append mode: results of all three files accumulate in one JSON Lines file
    # (re-running this cell therefore appends again rather than overwriting).
    # UTF-8 is set explicitly because ensure_ascii=False writes raw Japanese
    # text, which would otherwise depend on the locale's default encoding.
    with open(output_file, "a", encoding="utf-8") as o_:
        for i, data in enumerate(dataset):
            # Only "民法" (Civil Code) questions are processed; `i` is the
            # index within the unfiltered dataset.
            if data["subject"] != "民法":
                continue

            print(f"processing {i}th data.")
            year = data["year"]
            subject = data["subject"]
            Q_no = data["Q_no"]
            # The prompt template is filled from the record's own fields.
            prompt = config["prompt"].format_map(data)
            generated = generate_text(prompt)
            record = dict(
                year=year,
                subject=subject,
                Q_no=Q_no,
                prompt=prompt,
                complete=generated,
            )
            json.dump(record, o_, ensure_ascii=False)
            o_.write("\n")
            # Flush per record so partial results survive an interrupted run.
            o_.flush()

duration = time.time() - start_time

In [None]:
# Report the elapsed wall-clock time of the generation cell, in seconds
# (difference of two time.time() readings).
print(f"duration: {duration}")