# Chat experiment using rinna/japanese-gpt-neox-3.6b-instruction-ppo

## Install packages

In [None]:
# Install the third-party packages this notebook relies on. transformers and
# langchain are imported in the cells below; accelerate and sentencepiece are
# presumably needed for device_map='auto' and the slow tokenizer — TODO confirm.
!pip install transformers accelerate sentencepiece langchain

## Load model

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hub identifier of the checkpoint; the tokenizer and the weights are both
# fetched from this repository.
model_repo = "rinna/japanese-gpt-neox-3.6b-instruction-ppo"

# use_fast=False requests the slow (pure-Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(
    model_repo,
    use_fast=False,
)

# device_map="auto" lets the library decide where to place the weights
# (presumably the GPU when one is available — confirm for your runtime).
model = AutoModelForCausalLM.from_pretrained(
    model_repo,
    device_map="auto",
)

In [None]:
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, AIMessage, HumanMessage, SystemMessage, ChatResult, ChatGeneration
from typing import Optional, List
import transformers

class RinnaChat(BaseChatModel):
  """LangChain chat model backed by a locally loaded rinna GPT-NeoX model.

  A list of chat messages is flattened into the plain-text dialogue format
  produced by ``get_prompt`` ("ユーザー: ...<NL>システム: ...<NL>"), the model
  samples a continuation, and the decoded continuation is returned as a
  single ``AIMessage``.
  """

  # Tokenizer used to encode the prompt and decode the generated token ids.
  tokenizer: transformers.models.t5.tokenization_t5.T5Tokenizer
  # Causal language model that generates the reply.
  model: transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM

  def get_prompt(self, messages: List[BaseMessage]) -> str:
    """Render ``messages`` as a single prompt string.

    Human messages are prefixed with "ユーザー: "; every other message type
    (AI, system, ...) is prefixed with "システム: ". Each turn is terminated
    by the literal "<NL>" marker, and the prompt ends with an open
    "システム: " cue for the model to complete.
    """
    turns = []
    for message in messages:
      # isinstance (not type ==) so subclasses of HumanMessage count as user turns.
      speaker = "ユーザー" if isinstance(message, HumanMessage) else "システム"
      turns.append(f"{speaker}: {message.content}<NL>")
    turns.append("システム: ")
    return "".join(turns)

  def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]]=None)->ChatResult:
    """Sample one reply for ``messages``.

    Args:
      messages: Conversation so far; rendered with ``get_prompt``.
      stop: Optional stop words. The reply is cut just before the first
        occurrence of each one, applied in list order. Empty strings are
        ignored.

    Returns:
      A ``ChatResult`` holding one ``ChatGeneration`` whose message is the
      stripped reply text.
    """
    prompt = self.get_prompt(messages)
    token_ids = self.tokenizer.encode(prompt, add_special_tokens=False, return_tensors='pt')
    # Use the instance's own model/tokenizer rather than notebook globals so
    # the class works regardless of what the surrounding names point to.
    with torch.no_grad():
      output_ids = self.model.generate(
          token_ids.to(self.model.device),
          do_sample=True,
          max_new_tokens=128,
          temperature=0.8,
          repetition_penalty=1.1,
          pad_token_id=self.tokenizer.pad_token_id,
          bos_token_id=self.tokenizer.bos_token_id,
          eos_token_id=self.tokenizer.eos_token_id,
      )
    # generate() returns prompt + continuation; keep only the new tokens.
    new_token_ids = output_ids.tolist()[0][token_ids.size(1):]
    output = self.tokenizer.decode(new_token_ids, skip_special_tokens=True)
    # The prompt format writes line breaks as the literal "<NL>" marker.
    output = output.replace("<NL>", "\n").strip()
    if stop is not None:
      for stop_word in stop:
        if not stop_word:
          # An empty stop word matches at index 0 and would erase the reply.
          continue
        cut = output.find(stop_word)
        if cut != -1:
          output = output[:cut]
    ai_message = AIMessage(content=output.strip())
    return ChatResult(generations=[ChatGeneration(message=ai_message)])

  async def _agenerate(self, messages: List[BaseMessage], stop: Optional[List[str]]=None)->ChatResult:
    """Async counterpart of ``_generate``.

    There is no native async generation path here, so this simply runs the
    synchronous ``_generate``; the call blocks until generation finishes.
    """
    return self._generate(messages, stop=stop)

  @property
  def _llm_type(self) -> str:
    """Short identifier for this chat model type."""
    return "rinna"

## Example of simple chat

In [None]:
# Wrap the tokenizer and model loaded earlier in the LangChain chat interface.
# (`chat` is used by the cells below and must exist before they run.)
chat = RinnaChat(tokenizer=tokenizer, model=model)

# Send a single user message and return the model's reply message.
chat([
    HumanMessage(content="こんにちは")
])


## Example of chat loop using ConversationChain

In [None]:
from langchain.prompts.prompt import PromptTemplate
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
# Prompt layout: the accumulated history, then the new user input, then an
# open "システム: " cue for the reply.
chat_prompt = PromptTemplate(
    input_variables=["history", "input"],
    template="{history}<NL>ユーザー: {input}<NL>システム: ",
)

# Conversation memory that labels stored turns with the same speaker
# prefixes the prompt uses.
chat_memory = ConversationBufferMemory(
    human_prefix="ユーザー",
    ai_prefix="システム",
)

conversation = ConversationChain(
    llm=chat,
    prompt=chat_prompt,
    memory=chat_memory,
    # verbose=True
)

# Interactive chat loop: read one user line, print the chain's reply, repeat.
# Interrupting (Ctrl-C / kernel interrupt) or end-of-input at the prompt ends
# the loop cleanly instead of surfacing a traceback.
while True:
  try:
    inp = input('ユーザー: ')
  except (EOFError, KeyboardInterrupt):
    print()
    break
  # Flush so the speaker prefix is visible while the reply is being generated.
  print("システム: ", end='', flush=True)
  print(conversation.run(inp))