## 0. Prepare

In [None]:
import os
# from google.colab import drive
# drive.mount('/content/drive')
# Switch to the project directory so the relative paths used later in this
# file (the Excel workbook and the output/ folder) resolve against it.
# 'XXXXX' is a placeholder to be replaced with the real directory.
target_directory = 'XXXXX'
os.chdir(target_directory)

print(f"Current working directory: {os.getcwd()}")

import sys
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import sqlite3
import traceback
import pandas as pd
import os
import random

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import InferenceClient
from huggingface_hub import login
# Authenticate with the Hugging Face Hub. The token string is a placeholder.
# NOTE(review): prefer reading the real token from an environment variable or
# a secrets store instead of pasting it into the notebook.
login(token="hf_YOURAPIXXXXXXXXXXXXXXXXXXXXXXXXXX")

import warnings
# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# Functions

## Prepare and Load Database

In [3]:
# Initialize the database
def init_db(model_name, save_path):
    """Open the SQLite file for *model_name* under *save_path*, creating it if needed.

    The directory is created when it does not exist, and the ``outputs``
    table is created on first use.

    Returns:
        A ``(connection, db_path)`` tuple. The connection is left open; the
        caller is responsible for closing it.
    """
    target_dir = os.path.abspath(save_path)
    os.makedirs(target_dir, exist_ok=True)
    db_path = f"{target_dir}/{model_name}.db"

    # One row per generated output; `group` is back-quoted because GROUP is
    # an SQL keyword.
    schema = """
        CREATE TABLE IF NOT EXISTS outputs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            `group` TEXT,
            template TEXT,
            timestamp INTEGER,
            output TEXT
        )
    """
    conn = sqlite3.connect(db_path)
    conn.cursor().execute(schema)
    conn.commit()
    return conn, db_path

def save_output_batch(conn, batch):
    """Insert every result in *batch* into the ``outputs`` table and commit.

    Each item must be a mapping with ``group``, ``template`` and ``output``
    keys. All rows written by one call share a single timestamp taken when
    the call starts.
    """
    cur = conn.cursor()
    now = int(time.time())
    rows = []
    for entry in batch:
        rows.append((entry['group'], entry['template'], now, entry['output']))
    insert_sql = "INSERT INTO outputs (`group`, template, timestamp, output) VALUES (?, ?, ?, ?)"
    cur.executemany(insert_sql, rows)
    conn.commit()

def save_error_as_output(conn, group, template, error_message):
    """Record a failure as an ordinary ``outputs`` row and commit.

    The stored ``output`` text is *error_message* prefixed with ``"Error: "``,
    so failed tasks sit in the same table as successful ones.
    """
    cur = conn.cursor()
    row = (group, template, int(time.time()), f"Error: {error_message}")
    cur.execute(
        "INSERT INTO outputs (`group`, template, timestamp, output) VALUES (?, ?, ?, ?)",
        row,
    )
    conn.commit()

# read from excel
def load_groups_and_templates(format, excel_file_path="Japanese_group_template.xlsx"):
    """Load the group names and the prompt templates from the Excel workbook.

    Args:
        format: Name of the worksheet holding the templates; its ``JP``
            column is read.
        excel_file_path: Path to the workbook. Defaults to the file name the
            notebook has always used, resolved against the current working
            directory.

    Returns:
        A ``(groups, templates)`` tuple of lists with empty cells dropped.
        Groups come from the ``Group`` column of the ``Group`` sheet.
    """
    groups = pd.read_excel(excel_file_path, sheet_name="Group")["Group"].dropna().tolist()
    templates = pd.read_excel(excel_file_path, sheet_name=format)["JP"].dropna().tolist()

    print(f"group: {len(groups)}, template: {len(templates)}")
    return groups, templates

def restart(conn, db_path, groups, model_name, tasks_per_group=6):
    """Work out which groups still need processing after an interrupted run.

    The group of the most recently inserted ``outputs`` row is taken as the
    last finished group; every group after its first occurrence in *groups*
    is returned as remaining. If the table is empty, or that group is not in
    *groups*, the full list is returned.

    Args:
        conn: Open SQLite connection holding the ``outputs`` table.
        db_path: Unused; kept so existing callers keep working.
        groups: Ordered list of all groups.
        model_name: Unused; kept so existing callers keep working.
        tasks_per_group: Number of tasks counted for each finished group when
            computing the progress figure. Defaults to 6, the value that was
            previously hard-coded.

    Returns:
        A ``(remaining_groups, tasks_completed)`` tuple.
    """
    # The group written by the newest row in the DB.
    row = conn.cursor().execute(
        "SELECT `group` FROM outputs ORDER BY rowid DESC LIMIT 1"
    ).fetchone()
    last_processed_group = row[0] if row else None  # None when the DB is empty

    if last_processed_group is not None and last_processed_group in groups:
        resume_at = groups.index(last_processed_group) + 1
        remaining_groups = groups[resume_at:]
    else:
        remaining_groups = groups

    tasks_completed = (len(groups) - len(remaining_groups)) * tasks_per_group

    print(f"remain group: {len(remaining_groups)}, start after {last_processed_group}")
    return remaining_groups, tasks_completed

## Collect the outputs

In [4]:
# Handles to the inference backends. They start out unset; run() reads
# `client` for the Gemma branch and `model`/`tokenizer` for every other model.
client, model, tokenizer = None, None, None


def _build_messages(group, template, format):
    """Return the single-message chat prompt for one group/template pair."""
    if format == 'completion':
        if "なぜ" not in template:
            content = f'あなたは文補完アシスタントです。"{group}{template}"に続けて10個の可能な補完オプションを生成して。出力形式に必ず従い、それ以外の部分は出力しないこと。 ### 出力形式: 1.{group}は...。 2.{group}は...。 ... 10.{group}は...。'
        else:
            # Templates containing "なぜ" use the question-shaped output format.
            content = f'あなたは文補完アシスタントです。"{group}{template}"に続けて10個の可能な補完オプションを生成して。出力形式に必ず従い、それ以外の部分は出力しないこと。 ### 出力形式: 1.{group}はなぜ...？ 2.{group}はなぜ...？ ... 10.{group}はなぜ...？'
    elif format == 'opinion':
        content = f'あなたは質問応答アシスタントです。"{group}{template}"という質問に対して10個の可能な応答オプションを生成して。出力形式に必ず従い、それ以外の部分は出力しないこと。 ### 出力形式: 1.{group}は...。 2.{group}は...。 ... 10.{group}は...。'
    else:
        raise ValueError(f"Unknown format: {format!r} (expected 'completion' or 'opinion')")
    return [{"role": "user", "content": content}]


def run(group, template, format, temperature, top_p, model_name, max_tokens):
    """Generate completions for one group/template pair.

    Args:
        group: Group name inserted into the prompt.
        template: Sentence template appended to the group in the prompt.
        format: ``'completion'`` or ``'opinion'``; selects the prompt wording.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        model_name: Model identifier. ``"google/gemma-2-27b-it"`` is served
            through the module-level ``client``; any other name uses the
            module-level ``model`` and ``tokenizer``.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        ``{"group": ..., "template": ..., "output": ...}`` with the model's
        text in ``output``.

    Raises:
        SystemExit: On any error outside the API retry loop; the exit
            message carries the traceback.
    """
    try:
        start_time = time.time()
        messages = _build_messages(group, template, format)

        # Only Gemma: use Inference API
        if model_name == "google/gemma-2-27b-it":
            # top_p == 1.0 is sent as None (i.e. left unset).
            # NOTE(review): presumably the endpoint rejects top_p=1.0 - confirm.
            api_top_p = None if top_p == 1.0 else top_p
            # API failures are retried indefinitely with a one-second pause.
            while True:
                try:
                    completion = client.chat.completions.create(
                        model=model_name,
                        messages=messages,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_p=api_top_p,
                        stream=False
                    )
                    response = completion.choices[0].message["content"]
                    # Leave the retry loop on success. This used to be the
                    # typo `breakl`, whose NameError was swallowed by the
                    # except below, so the request was re-sent forever.
                    break
                except Exception as e:
                    print(f"Error occurred: {e}. Retrying in 1 seconds...")
                    time.sleep(1)  # Wait before retrying

        # Other models
        else:
            input_ids = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt"
            ).to(model.device)

            with torch.no_grad():
                output_ids = model.generate(
                    input_ids,
                    do_sample=True,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                )[0]

            # The whole sequence returned by generate() is decoded; nothing
            # is sliced off before decoding.
            response = tokenizer.decode(output_ids, skip_special_tokens=True)
            del input_ids, output_ids

        print(response)
        print(f"[DECODE] {time.time() - start_time:.2f} seconds")
        return {"group": group, "template": template, "output": response}

    except Exception:
        error_message = traceback.format_exc()
        sys.exit(f"An error occurred during execution: {error_message}")

# Main

In [None]:
# LLM-jp: load the instruct checkpoint and its tokenizer.
model_name = "llm-jp/llm-jp-3-13b-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sampling settings used for this model.
temperature, top_p = 1.0, 1.0

# Keyword arguments that are unpacked into run() for every task.
params = dict(
    model_name=model_name,
    max_tokens=400,
    temperature=temperature,
    top_p=top_p,
)


########## When using Qwen, please comment out the above and uncomment the following. ##########
# model_name = "Qwen/Qwen2.5-14B-Instruct"
# model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# temperature = 0.7
# top_p = 0.8

# params = {
#     "model_name": model_name,
#     "max_tokens": 400,
#     "temperature": temperature,
#     "top_p" : top_p
# }


########## When using Gemma, please comment out the above and uncomment the following. ##########
# model_name = "google/gemma-2-27b-it"
# client = InferenceClient(api_key="hf_YOURAPIXXXXXXXXXXXXXXXXXXXXXXXXXX")
# temperature = 1.0
# top_p = 1.0

# params = {
#     "model_name": model_name,
#     "max_tokens": 400,
#     "temperature": temperature,
#     "top_p" : top_p
# }

In [None]:
def f(x):
    """Format a sampling value as a zero-padded percentage, e.g. 0.7 -> "070"."""
    # round() instead of int(): x * 100 can land just below the intended
    # integer (0.29 * 100 == 28.999999999999996), which int() would truncate
    # to the wrong directory label.
    return str(round(x * 100)).zfill(3)


# Run every group/template pair once per prompt format, storing the outputs
# in one SQLite file per (model, format).
for format in ["completion", "opinion"]:
    print(f"=== format={format}, temp={temperature}, topp={top_p} ==========================")
    params["temperature"], params["top_p"], params["format"] = temperature, top_p, format

    # 1. Prepare the save location
    # [-1] keeps the part after the organisation prefix and also tolerates
    # model names that contain no "/".
    short_name = model_name.split("/")[-1]
    save_path = f"output/{short_name}/temp_{f(temperature)}_topp_{f(top_p)}"
    print(f"save_path: {save_path}")

    # 2. Initialize the database
    conn, db_path = init_db(f'{short_name}_{format}', save_path)

    # Everything after the connection is opened runs inside try/finally so
    # the connection is closed even if loading the workbook or resuming fails.
    try:
        all_groups, templates = load_groups_and_templates(format)
        total_tasks = len(all_groups) * len(templates)

        # 3. Robustly restart the process
        groups, _ = restart(conn, db_path, all_groups, model_name)
        # Count finished tasks with the real number of templates so the
        # progress figure is on the same scale as total_tasks.
        tasks_completed = (len(all_groups) - len(groups)) * len(templates)

        # 4. Run the process
        for group in groups:
            batch, batch_start_time = [], time.time()  # start batch
            for template in templates:
                params["group"], params["template"] = group, template
                print(group, template)
                result = run(**params)
                if result:
                    batch.append(result)

            # 5. Save the output (one commit per group, which is what makes
            #    resuming at group granularity safe)
            save_output_batch(conn, batch)
            tasks_completed += len(batch)
            torch.cuda.empty_cache()
            print(f"Progress: {tasks_completed}/{total_tasks} tasks completed. Batch took {round(time.time() - batch_start_time, 2)} seconds.")

    # 6. Finish the process
    finally:
        conn.close()