In [1]:
# 参考代码: https://www.kaggle.com/code/phanisrikanth/generate-synthetic-essays-with-mistral-7b-instruct

In [2]:
import numpy as np
import pandas as pd

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)

import tqdm

In [3]:
class Mistral:
    """Thin wrapper around a local Mistral-Instruct checkpoint for chat-style generation.

    Loads the tokenizer and a bfloat16 causal LM from ``model_path`` and exposes
    :meth:`generate`, which applies the tokenizer's chat template and decodes the
    model's output.
    """

    def __init__(self, model_path):
        """Load the tokenizer and model.

        Args:
            model_path (str): directory (or hub id) passed to ``from_pretrained``
                for both the tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        # With device_map="auto" the loader decides where the weights live, so take
        # the input device from the model instead of guessing from CUDA availability.
        self.device = self.model.device

    @torch.no_grad()
    def generate(self, messages, max_new_tokens=1000, only_new_tokens=False):
        """Generate an LLM reply for a chat conversation.

        Args:
            messages (List[Dict]): messages in chat-template format, e.g.
                messages = [{"role": "user", "content": "What do you think of the Chinese New Year ?"},]
                See https://huggingface.co/docs/transformers/main/en/chat_templating#how-do-i-use-chat-templates
            max_new_tokens (int): maximum number of tokens to generate.
            only_new_tokens (bool): if False (default), decode the full output
                sequence (prompt, special tokens and reply) exactly as before.
                If True, decode only the generated continuation, with special
                tokens removed.

        Returns:
            str: the decoded text.
        """
        token_chat = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(self.device)
        # Single unpadded prompt: every position is a real token. Passing the mask
        # explicitly avoids ambiguity because pad_token_id is set to eos_token_id.
        attention_mask = torch.ones_like(token_chat)

        token_output = self.model.generate(
            token_chat,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
        )

        if only_new_tokens:
            # The output starts with the prompt tokens; slice them off.
            return self.tokenizer.decode(
                token_output[0, token_chat.shape[-1]:], skip_special_tokens=True
            )
        return self.tokenizer.decode(token_output[0])

In [4]:
# Load the local Mistral-7B-Instruct-v0.2 checkpoint (tokenizer + model).
mistral = Mistral(model_path="./Mistral-7B-Instruct-v0.2/")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


In [5]:
# Assignment prompts: one row per prompt (prompt_id, prompt_name, instructions, source_text).
prompt_df = pd.read_csv(filepath_or_buffer="train_prompts.csv")
prompt_df.head(n=5)

Unnamed: 0,prompt_id,prompt_name,instructions,source_text
0,0,Car-free cities,Write an explanatory essay to inform fellow ci...,"# In German Suburb, Life Goes On Without Cars ..."
1,1,Does the electoral college work?,Write a letter to your state senator in which ...,# What Is the Electoral College? by the Office...


In [6]:
class CFG:
    """Settings for the synthetic-essay generation loop."""

    # How many essays to generate for each prompt.
    count: int = 1
    # Probability of passing the prompt's own instructions to the model;
    # otherwise an open-ended hint is used instead.
    guidance_probability: float = 0.8
    # (min, max) target word count for each essay.
    words_range: tuple = (200, 400)

In [7]:
# Prompt template: {0} = target word count, {1} = essay topic, {2} = background/instructions.
prompt_text = """
You are a student working on the following assignment.

Write an essay based on the following topics and backgrounds with a word count of about {0}.

Topic: "{1}"

Backgrounds: "{2}"

"""
# Seeded generator so the guidance coin-flips and word counts are reproducible on re-run.
SEED = 42
rng = np.random.default_rng(SEED)
# Marker closing the [INST] ... [/INST] prompt section in the decoded text.
INST_END = "[/INST]"
# End-of-sequence token that may trail the decoded text.
EOS = "</s>"

llm_text = []
instruct = []
# Iterate rows positionally (itertuples) rather than by index label, and read each
# row's fields once instead of once per generated page.
for i, (prompt_name, base_instructions) in enumerate(
    prompt_df[["prompt_name", "instructions"]].itertuples(index=False, name=None)
):
    for _ in tqdm.tqdm(range(CFG.count), desc=f"Pages {i + 1}"):
        # With probability 1 - guidance_probability, replace the assignment
        # instructions with an open-ended hint.
        if rng.random() > CFG.guidance_probability:
            instructions = "Feel free to use your imagination."
        else:
            instructions = base_instructions
        # Integer word count in [min, max] inclusive; a float would leak into the
        # prompt as e.g. "about 287.5329".
        low, high = CFG.words_range
        words = int(rng.integers(low, high, endpoint=True))
        messages = [{
            "role": "user",
            "content": prompt_text.format(words, prompt_name, instructions)
        }]
        llm_generated = mistral.generate(messages=messages, max_new_tokens=2000)
        # Keep only the reply: everything after the FIRST [/INST] marker, so a marker
        # repeated inside the reply does not truncate it.
        llm_generated = llm_generated.split(INST_END, 1)[1]
        # Remove EOS as a suffix. str.rstrip("</s>") strips any trailing '<', '/', 's',
        # '>' characters and eats the last letter of e.g. "cars" whenever the output
        # was cut off at max_new_tokens without an EOS.
        if llm_generated.endswith(EOS):
            llm_generated = llm_generated[: -len(EOS)]
        llm_text.append(llm_generated.strip())
        instruct.append(instructions)

Pages 1: 100%|██████████| 1/1 [00:34<00:00, 34.64s/it]
Pages 2: 100%|██████████| 1/1 [00:30<00:00, 30.49s/it]


In [8]:
# One row per generated essay. The generation loop walks prompt_df row by row and
# appends CFG.count essays per row, so repeating each row's prompt_name CFG.count
# times (in row order) lines up with llm_text / instruct. Unlike .unique(), this stays
# correct when two prompts share a name; .to_numpy() avoids index alignment with the
# repeated Series labels.
llm_generated_df = pd.DataFrame({
    "prompt_name": prompt_df["prompt_name"].repeat(CFG.count).to_numpy(),
    "text": llm_text,
    "instructions": instruct,
    # Label: 1 = machine-generated essay.
    "generated": [1] * len(llm_text)
})
llm_generated_df.head()

Unnamed: 0,prompt_name,text,instructions,generated
0,Car-free cities,Title: Embracing Car-free Cities: A Sustainab...,Write an explanatory essay to inform fellow ci...,1
1,Does the electoral college work?,Title: An In-depth Analysis of the Electoral ...,Feel free to use your imagination.,1


In [9]:
# Persist the synthetic essays; the DataFrame index is not written.
llm_generated_df.to_csv(path_or_buf="llm_generated.csv", index=False)