In [1]:
import datetime
import json
import logging
import os
import pathlib
import sys
import pydoc
import time

import datasets
import fire
import torch
import yaml

import transformers

In [2]:
def load_dataset(test_file: str):
    return datasets.load_dataset("json", data_files=test_file, split="train")

def generate_text_fn(model, tokenizer, args: dict = {}):
    def generate_text(prompt: str) -> str:
        input_tokens = tokenizer(prompt, return_tensors="pt").to(model.device)
        input_length = input_tokens.input_ids.shape[1]

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_tokens["input_ids"],
                attention_mask=input_tokens["attention_mask"],
                return_dict_in_generate=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                **args,
            )
            output_tokens = outputs.sequences[0, input_length:-1]

        return tokenizer.decode(output_tokens)

    return generate_text

In [3]:
# 設定ファイルの読み込み
config_file = "../config/config_generate_texts_llama_jplaw_v04.yaml"
with open(config_file, "r") as i_:
    config = yaml.safe_load(i_)
    
print(config)

{'model': {'pretrained_model_name_or_path': '/groups/4/gcb50389/pretrained/llama2-HF/Llama-2-7b-hf', 'device_map': 'auto', 'trust_remote_code': True}, 'generate': {'max_length': 4096, 'temperature': 0.7, 'top_p': 0.9, 'repetition_penalty': 1.05}, 'data': {'test_file1': '../data/R1_short_answer_exam.json', 'test_file2': '../data/R2_short_answer_exam.json', 'test_file3': '../data/R3_short_answer_exam.json'}, 'outputs': {'dirname': '../outputs', 'filename': 'generated_text_05.jsonl'}, 'prompt': '日本の法律に基づき、設問に対する答えを回答してください。\n\n設問：意思表示に関する次のアからオまでの各記述のうち，判例の趣旨に照らし正しいものを組み合わせたものは，後記１から５までのうちどれか。ア．土地の仮装譲受人が当該土地上に建物を建築してこれを他人に賃貸した場合，その建物賃借人は，民法第９４条第２項の「第三者」に当たらない。イ．強迫による意思表示の取消しが認められるためには，表意者が，畏怖の結果，完全に意思の自由を失ったことを要する。ウ．Ａを欺罔してその農地を買い受けたＢが，農地法上の許可を停止条件とする所有権移転の仮登記を得た上で，当該売買契約上の権利をＣに譲渡して当該仮登記移転の付記登記をした場合には，Ｃは民法第９６条第３項の「第三者」に当たる。エ．協議離婚に伴う財産分与契約において，分与者は，自己に譲渡所得税が課されることを知らず，課税されないとの理解を当然の前提とし，かつ，その旨を黙示的に表示していた場合であっても，財産分与契約について錯誤による無効を主張することはできない。オ．特定の意思表示が記載された内容証明郵便が受取人不在のために配達することができず，留置期間の経過に

In [4]:
model_name = config["model"]["pretrained_model_name_or_path"]
    
# logger.info(f"model_name: {model_name}")
print(f"model_name: {model_name}")

# 出力先ディレクトリの設定
output_dir = pathlib.Path(os.path.expandvars(config["outputs"]["dirname"]))
output_dir.mkdir(parents=True, exist_ok=True)

# 出力先ディレクトリに、最終的な設定値を保存しておく
with open(output_dir.joinpath("config.yaml"), "w") as o_:
    yaml.dump(config, o_)

# トークナイザのロード
# logger.info(f"load tokenizer")
print(f"load tokenizer")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    # logger.info(f"set pad_token to {tokenizer.pad_token}")
    

# モデルのロード
# logger.info(f"load model: {config['model']}")
print(f"load model: {config['model']}")
# torch_dtypeを文字列から型に変換しておく
if "torch_dtype" in config["model"]:
    config["model"]["torch_dtype"] = pydoc.locate(config["model"]["torch_dtype"])
model = transformers.AutoModelForCausalLM.from_pretrained(**config["model"])

generate_text = generate_text_fn(model, tokenizer, config["generate"])
# generate_text = generate_text_fn(model, tokenizer)
output_file = output_dir.joinpath(config["outputs"]["filename"])

model_name: /groups/4/gcb50389/pretrained/llama2-HF/Llama-2-7b-hf
load tokenizer


Using pad_token, but it is not set yet.


load model: {'pretrained_model_name_or_path': '/groups/4/gcb50389/pretrained/llama2-HF/Llama-2-7b-hf', 'device_map': 'auto', 'trust_remote_code': True}
[2023-11-09 02:09:57,470] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


2023-11-09 02:09:59.085990: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-09 02:09:59.086040: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-09 02:09:59.086076: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-09 02:09:59.100257: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
!nvidia-smi

Thu Nov  9 02:15:33 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000000:3D:00.0 Off |                    0 |
| N/A   32C    P0    58W / 300W |   6857MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000000:3E:00.0 Off |                    0 |
| N/A   30C    P0    57W / 300W |   7901MiB / 16384MiB |      0%      Default |
|       

In [6]:
start_time = time.time()
print(start_time)

for j in range(1, 4):
    # logger.info(f"processing R{j} data.")
    print(f"processing R{j} data.")
    dataset = load_dataset(test_file=config["data"][f"test_file{j}"])

    with open(output_file, "a") as o_:
        for i, data in enumerate(dataset):
            if data["subject"] != "民法":
                continue
            elif i >= 21:
                break
            else:
                # logger.info(f"processing {i}th data.")
                print(f"processing {i}th data.")
                year = data["year"]
                subject = data["subject"]
                Q_no = data["Q_no"]
                prompt = config["prompt"].format_map(data)
                generated = generate_text(prompt)
                json.dump(dict(year=year, subject=subject, Q_no=Q_no, prompt=prompt, complete=generated), o_, ensure_ascii=False)
                o_.write("\n")
                o_.flush()
                
duration = time.time() - start_time

1699463740.5782263
processing R1 data.


Found cached dataset json (/home/acf15802az/.cache/huggingface/datasets/json/default-9daf52417ac5f3f8/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


processing 20th data.
processing R2 data.
Downloading and preparing dataset json/default to /home/acf15802az/.cache/huggingface/datasets/json/default-e81a17a045232274/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /home/acf15802az/.cache/huggingface/datasets/json/default-e81a17a045232274/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.
processing 20th data.
processing R3 data.
Downloading and preparing dataset json/default to /home/acf15802az/.cache/huggingface/datasets/json/default-5930e0f5dc0720f3/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /home/acf15802az/.cache/huggingface/datasets/json/default-5930e0f5dc0720f3/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.
processing 20th data.


In [7]:
print(f"duration: {duration}")

duration: 735.6218545436859
