<a href="https://colab.research.google.com/github/haku-noir/werewolf/blob/develop/colab/werewolf_generate_talk.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [49]:
# Number of conversation turns the talk-generation loop iterates over.
TALK_LEN = 30
# When True, the next speaker is drawn at random; otherwise speakers cycle through USER_LIST in order.
IS_RANDOM_USER = True

In [46]:
GEN_NUM = 10 # number of candidate messages generated per turn (num_return_sequences)
GEN_LEN = 128 # maximum length of a generated sequence (passed as max_length to generate())
MAX_LENGTH = 64 # input length used when tokenizing messages for BERT
HIDDEN_SIZE = 64 # width of the hidden linear layer in the classification model

In [2]:
# Speaker names; a speaker's position in this list is used as its user_id throughout the notebook.
USER_LIST = ["楽天家 ゲルト", "ならず者 ディーター", "パン屋 オットー", "少年 ペーター", "羊飼い カタリナ", "村長 ヴァルター", "旅人 ニコラス", "青年 ヨアヒム", "神父 ジムゾン", "少女 リーザ", "村娘 パメラ", "宿屋の女主人 レジーナ", "老人 モーリッツ", "農夫 ヤコブ", "行商人 アルビン", "木こり トーマス"]
# Column names of the header-less CSV files that hold generated messages.
CSV_HEADER = ["user_id", "name", "message"]

## ファイルパスの設定


In [3]:
# Mount Google Drive so that the paths under /content/drive used below are available.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import os

# Root folder of the project on the mounted Google Drive.
DATA_DIR = "/content/drive/MyDrive/werewolf"

# Sub-folders for model artifacts and for the logs written by this notebook.
OUTPUT_DIR = os.path.join(DATA_DIR, "output")
LOG_DIR = os.path.join(DATA_DIR, "log")

# Directory the causal-LM generator is loaded from.
GPT_MODEL_DIR = os.path.join(OUTPUT_DIR, "model_generator_including_user")
# State-dict file for the speaker-classification (filter) model.
FILTER_MODEL_PATH = os.path.join(OUTPUT_DIR, "model_filter.bin")
# Text file the generated conversation is written to.
TALK_LOG_PATH = os.path.join(LOG_DIR, "talk_log.txt")

## ライブラリのインストール

In [None]:
# Install transformers from the GitHub repository (unpinned), plus sentencepiece, datasets and evaluate.
!pip install git+https://github.com/huggingface/transformers
!pip install sentencepiece datasets evaluate

In [None]:
# Install modelzoo-client with its transformers extra, and fugashi/ipadic
# (presumably required by the Japanese BERT tokenizer used later — confirm).
!pip install modelzoo-client[transformers]
!pip install fugashi ipadic

## 文章の生成

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Device for the generator: GPU when one is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The tokenizer comes from the rinna base model; the generator weights come from GPT_MODEL_DIR.
tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-medium")
# NOTE(review): presumably follows the rinna model card's tokenizer setup — confirm it is still needed.
tokenizer.do_lower_case = True
model_generator = AutoModelForCausalLM.from_pretrained(GPT_MODEL_DIR)
model_generator.to(device)
# Switch the generator to evaluation mode; it is only used for inference here.
model_generator.eval()

In [47]:
def generate_message_including_user(input_user_id, input_text, output_user_id):
    """Sample GEN_NUM candidate replies from the generator.

    The prompt is "<s>{input speaker}[SEP]{input_text}[SEP]{output speaker}[SEP]",
    with speaker names looked up in USER_LIST.

    Args:
        input_user_id: index in USER_LIST of the speaker of ``input_text``.
        input_text: the message being replied to.
        output_user_id: index in USER_LIST of the speaker who should reply.

    Returns:
        A list of cleaned reply strings, one per decoded sequence.
    """
    # Joining with a trailing empty item yields the closing "[SEP]".
    prompt = "[SEP]".join(
        ["<s>" + USER_LIST[input_user_id], input_text, USER_LIST[output_user_id], ""]
    )
    prompt_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    sampling_options = dict(
        do_sample=True,
        top_p=0.95,
        top_k=40,
        num_return_sequences=GEN_NUM,
        max_length=GEN_LEN,
        bad_words_ids=[[1], [5]],
    )
    generated = model_generator.generate(prompt_ids, **sampling_options)

    # Applied in this order to every decoded reply.
    replacements = (
        ('</s>', ''),
        ('"', ''),
        ("'", ""),
        (" ", ""),
        ("C<unk>", "CO"),
    )

    def _clean(decoded):
        # Keep only what follows the first '[SEP]</s>' marker, i.e. the text after the prompt.
        reply = decoded.split('[SEP]</s>')[1]
        for old, new in replacements:
            reply = reply.replace(old, new)
        return reply

    return [_clean(decoded) for decoded in tokenizer.batch_decode(generated)]

In [27]:
import pandas as pd

def save_generated_text_list(file_path, generated_text_list, user_id, pred_user_ids=None, pred_probs=None):
  """Write generated messages to a header-less CSV, overwriting ``file_path``.

  Each row holds the CSV_HEADER columns (user_id, name, message). When both
  ``pred_user_ids`` and ``pred_probs`` are given, three more columns follow:
  the predicted user id, the predicted user's name and the prediction probability.

  Args:
    file_path: destination CSV path.
    generated_text_list: messages, one per row.
    user_id: index in USER_LIST of the speaker the messages were generated for.
    pred_user_ids: optional predicted speaker index per message.
    pred_probs: optional prediction probability per message.
  """
  id_col, name_col, message_col = CSV_HEADER[0], CSV_HEADER[1], CSV_HEADER[2]

  table = pd.DataFrame(generated_text_list, columns=[message_col])
  table = table.assign(**{id_col: user_id, name_col: USER_LIST[user_id]})
  table = table[CSV_HEADER]  # put the columns in CSV_HEADER order

  has_predictions = not (pred_user_ids is None or pred_probs is None)
  if has_predictions:
    table = table.assign(
      pred_user_id=pred_user_ids,
      pred_name=[USER_LIST[predicted] for predicted in pred_user_ids],
      prob=pred_probs,
    )

  table.to_csv(file_path, mode='w', header=False, index=False)

## ユーザの分類

In [10]:
# Pre-trained BERT checkpoint used for both the classifier's tokenizer and its encoder.
MODEL_NAME = "cl-tohoku/bert-base-japanese"

### GPUデバイスの検出

In [None]:
import torch

# Number of visible CUDA devices; later used to decide whether to wrap the model in DataParallel.
N_GPU = torch.cuda.device_count()
# Device string the filter model and its input batches are moved to.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"DEVICE: {DEVICE}, N_GPU:{N_GPU}")

### データセット作成

In [12]:
from torch.utils.data import Dataset
from transformers import AutoTokenizer

class ClassificationDataset(Dataset):
  """Dataset that tokenizes the "message" column of a table for the BERT classifier.

  Each item is the tokenizer output with tensor-valued "input_ids",
  "token_type_ids" and "attention_mask", plus "labels" taken from the row's
  "user_id" column.
  """

  def __init__(self, data, user_list):
    # The tokenizer is loaded from the checkpoint named by MODEL_NAME.
    self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    self.data = data
    self.user_list = user_list
    self.num_labels = len(user_list)

  def __len__(self):
    return len(self.data)

  def __getitem__(self, i):
    record = self.data.iloc[i]

    # Truncate / pad so that every encoded message is exactly MAX_LENGTH long.
    encoded = self.tokenizer(
      record["message"],
      max_length=MAX_LENGTH,
      truncation=True,
      padding="max_length"
    )

    # Convert the plain lists to tensors: int64 for ids, bool for the mask.
    tensor_types = (
      ("input_ids", torch.LongTensor),
      ("token_type_ids", torch.LongTensor),
      ("attention_mask", torch.BoolTensor),
    )
    for key, to_tensor in tensor_types:
      encoded[key] = to_tensor(encoded[key])

    encoded["labels"] = record["user_id"]
    return encoded

In [13]:
import random
import pandas as pd

def load_dataset(file_path, user_list=None):
  """Read a header-less message CSV and wrap it in a ClassificationDataset.

  The first columns of the file are named after CSV_HEADER. Messages are always
  read back as text: an empty message stays "" and text such as "NA" or "null"
  is kept verbatim instead of becoming NaN, which the tokenizer cannot handle.

  Args:
    file_path: path of the CSV file to read.
    user_list: list of user names handed to the dataset; defaults to an empty list.

  Returns:
    A ClassificationDataset over the rows of the file.
  """
  if user_list is None:
    # Fresh list per call, so datasets never share one default list object.
    user_list = []

  data = pd.read_csv(
    file_path,
    header=None,
    names=CSV_HEADER,
    dtype={CSV_HEADER[2]: str},  # never let an all-numeric message column be parsed as numbers
    keep_default_na=False,       # do not turn "", "NA", "null", ... into NaN
  )

  return ClassificationDataset(data, user_list)

### 分類モデルの構築

In [None]:
from torch import nn

from transformers import AutoModel, AutoConfig 

class ClassificationModel(nn.Module):
  """Pre-trained BERT encoder followed by two linear layers that score each class."""

  def __init__(self, num_labels=1):
    super().__init__()
    # Configuration and weights of the pre-trained BERT named by MODEL_NAME.
    self.config = AutoConfig.from_pretrained(MODEL_NAME)
    self.bert = AutoModel.from_pretrained(MODEL_NAME, config=self.config)
    # Head: BERT hidden size -> HIDDEN_SIZE -> num_labels.
    self.hidden_linear = nn.Linear(self.config.hidden_size, HIDDEN_SIZE)
    self.linear = nn.Linear(HIDDEN_SIZE, num_labels)

  def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
    """Return class logits, or ``(logits, loss)`` when ``labels`` is given."""
    encoded = self.bert(
      input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids
    )

    sequence_output = encoded[0]          # first element of the encoder output
    cls_vector = sequence_output[:, 0]    # vector at position 0 (the [CLS] token)

    hidden = self.hidden_linear(cls_vector)
    logits = self.linear(hidden)          # one score per class

    if labels is None:
      return logits

    # Labels were supplied: also report the cross-entropy loss.
    loss = nn.CrossEntropyLoss()(logits, labels)
    return logits, loss

# Filter model with one output class per entry in USER_LIST.
model_filter = ClassificationModel(num_labels=len(USER_LIST))

In [None]:
# Move the filter model to the device detected above.
model_filter.to(DEVICE)

In [16]:
from torch.nn import DataParallel # only used when several GPUs are present

if N_GPU > 1: # more than one GPU is available
  model_filter = DataParallel(model_filter) # wrap the model for multi-GPU parallel computation

### モデルの読み込み

In [17]:
# Restore the trained filter weights.
# map_location=DEVICE lets a checkpoint that was saved on a GPU load on a CPU-only runtime.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(FILTER_MODEL_PATH, map_location=DEVICE)

# A DataParallel wrapper exposes the real model as `.module`; load into that model when present
# (presumably the checkpoint was saved from the unwrapped model — confirm).
target_model = getattr(model_filter, "module", model_filter)
target_model.load_state_dict(state_dict)

# The filter is only used for inference in this notebook, so put it in evaluation mode explicitly.
model_filter.eval();

## 会話の生成

In [19]:
def save_talk_log(file_path, user_id, message, mode="a"):
  """Write one "<speaker name>: <message>" line to the talk log.

  Args:
    file_path: path of the log file.
    user_id: index in USER_LIST of the speaker.
    message: the text the speaker said.
    mode: text mode passed to ``open``; "a" appends (default), "w" starts a new log.
  """
  # The log holds Japanese text, so pin the encoding instead of relying on the platform default.
  with open(file_path, mode=mode, encoding="utf-8") as f:
    f.write(USER_LIST[user_id] + ": " + message+"\n")

In [None]:
import random
from torch.utils.data import DataLoader
import tqdm.notebook  # load the submodule explicitly; a bare `import tqdm` does not guarantee `tqdm.notebook` exists

# Seed the conversation with a fixed opening line from user 0 and start a fresh log file.
pred_user_id = 0
pred_text = "ふぁーあ……ねむいな……寝てていい？"
save_talk_log(TALK_LOG_PATH, pred_user_id, pred_text, mode="w")

# A turn is retried when the filter attributes none of the candidates to the intended speaker.
# (Decrementing the index of a `for ... in range(...)` loop has no effect in Python, so a plain
# for-loop would silently drop such turns.) The attempt budget keeps the loop finite if the
# filter keeps rejecting every candidate.
max_attempts = TALK_LEN * 10
turn = 0
attempt = 0
while turn < TALK_LEN and attempt < max_attempts:
  attempt += 1
  print(turn)

  # Pick the next speaker: round-robin by default, random when IS_RANDOM_USER is set.
  next_user_id = turn % len(USER_LIST)
  if IS_RANDOM_USER:
    # Draws from 1..len(USER_LIST)-1, so user 0 is never picked at random.
    next_user_id = random.randrange(len(USER_LIST)-1)+1

  GENERATED_MESSAGES_PATH = os.path.join(LOG_DIR, "werewolf_generated_messages_"+str(turn+1)+".csv")
  generated_text_list = generate_message_including_user(pred_user_id, pred_text, next_user_id)
  # First write: candidates only, so that load_dataset can read them back for the filter.
  save_generated_text_list(GENERATED_MESSAGES_PATH, generated_text_list, next_user_id)

  pred_dataset = load_dataset(GENERATED_MESSAGES_PATH, user_list=USER_LIST)
  pred_dataloader = DataLoader(pred_dataset, batch_size=GEN_NUM)

  # Class probabilities for every candidate, collected across all batches.
  batch_probs = []
  with torch.no_grad():
    for batch in tqdm.notebook.tqdm(pred_dataloader):
      logits = model_filter(
        input_ids=batch["input_ids"].to(DEVICE),
        token_type_ids=batch["token_type_ids"].to(DEVICE),
        attention_mask=batch["attention_mask"].to(DEVICE)
      )
      # Convert logits to probabilities and bring them back to the CPU.
      batch_probs.append(torch.softmax(logits, dim=-1).cpu())
  outputs = torch.cat(batch_probs)

  # Most likely speaker and its probability for each candidate.
  output_probs, output_user_ids = outputs.max(dim=-1)
  output_probs = output_probs.tolist()
  output_user_ids = output_user_ids.tolist()
  # Second write: same file, now including the filter's predictions.
  save_generated_text_list(GENERATED_MESSAGES_PATH, generated_text_list, next_user_id, output_user_ids, output_probs)

  # Keep only the candidates the filter attributes to the intended speaker.
  next_user_outputs = [
    {"prob": prob, "text": text}
    for prob, predicted_user_id, text in zip(output_probs, output_user_ids, generated_text_list)
    if predicted_user_id == next_user_id
  ]
  if len(next_user_outputs) == 0:
    continue  # no acceptable candidate: retry this turn

  # Highest-probability candidate; on ties the earliest one wins.
  best_message = max(next_user_outputs, key=lambda output: output["prob"])["text"]
  print(USER_LIST[next_user_id], best_message)
  save_talk_log(TALK_LOG_PATH, next_user_id, best_message)

  # The accepted message becomes the context for the next turn.
  pred_user_id = next_user_id
  pred_text = best_message
  turn += 1

if turn < TALK_LEN:
  print(f"Stopped after {attempt} attempts with {turn} of {TALK_LEN} turns generated")