Merge pull request #45 from deep-diver/chore/chat-module
move chat related modules to chats
deep-diver committed Apr 4, 2023
2 parents 1406fbf + 7d28804 commit 946dbf7
Showing 4 changed files with 134 additions and 2 deletions.
app.py (1 change: 0 additions & 1 deletion)
@@ -2,7 +2,6 @@
 import gradio as gr
 
 import global_vars
-import alpaca
 
 from args import parse_args
 from miscs.strings import TITLE, ABSTRACT, BOTTOM_LINE
alpaca.py → chats/alpaca.py: file renamed without changes.
chats/baize.py (130 changes: 130 additions & 0 deletions)
@@ -0,0 +1,130 @@
import global_vars

import gradio as gr

from gens.batch_gen import get_output_batch
from miscs.strings import SPECIAL_STRS
from miscs.constants import num_of_characters_to_keep
from miscs.utils import generate_prompt
from miscs.utils import common_post_process, post_processes_batch, post_process_stream

def chat_stream(
context,
instruction,
state_chatbot,
):
if len(context) > 1000 or len(instruction) > 300:
raise gr.Error("context or prompt is too long!")

bot_summarized_response = ''
# user input should be appropriately formatted (don't be confused by the function name)
instruction_display = common_post_process(instruction)
instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context)

# if conv_length > num_of_characters_to_keep:
# instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context)[0]

# state_chatbot = state_chatbot + [
# (
# None,
# "![](https://s2.gifyu.com/images/icons8-loading-circle.gif) too long conversations, so let's summarize..."
# )
# ]
# yield (state_chatbot, state_chatbot, context)

# bot_summarized_response = get_output_batch(
# global_vars.model, global_vars.tokenizer, [instruction_prompt], global_vars.generation_config
# )[0]
# bot_summarized_response = bot_summarized_response.split("### Response:")[-1].strip()

# state_chatbot[-1] = (
# None,
# "✅ summarization is done and set as context"
# )
# print(f"bot_summarized_response: {bot_summarized_response}")
# yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())

# instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0]

bot_response = global_vars.stream_model(
instruction_prompt,
max_tokens=256,
temperature=1,
top_p=0.9
)

instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
state_chatbot = state_chatbot + [(instruction_display, None)]
yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())

prev_index = 0
agg_tokens = ""
cutoff_idx = 0
for tokens in bot_response:
state_chatbot[-1] = (instruction_display, tokens)
yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
# tokens = tokens.strip()
# cur_token = tokens[prev_index:]

# if "#" in cur_token and agg_tokens == "":
# cutoff_idx = tokens.find("#")
# agg_tokens = tokens[cutoff_idx:]

# if agg_tokens != "":
# if len(agg_tokens) < len("### Instruction:") :
# agg_tokens = agg_tokens + cur_token
# elif len(agg_tokens) >= len("### Instruction:"):
# if tokens.find("### Instruction:") > -1:
# processed_response, _ = post_process_stream(tokens[:tokens.find("### Instruction:")].strip())

# state_chatbot[-1] = (
# instruction_display,
# processed_response
# )
# yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
# break
# else:
# agg_tokens = ""
# cutoff_idx = 0

# if agg_tokens == "":
# processed_response, to_exit = post_process_stream(tokens)
# state_chatbot[-1] = (instruction_display, processed_response)
# yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())

# if to_exit:
# break

# prev_index = len(tokens)

yield (
state_chatbot,
state_chatbot,
f"{context} {bot_summarized_response}".strip()
)


def chat_batch(
contexts,
instructions,
state_chatbots,
):
state_results = []
ctx_results = []

instruct_prompts = [
generate_prompt(instruct, histories, ctx)[0]
for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
]

bot_responses = get_output_batch(
global_vars.model, global_vars.tokenizer, instruct_prompts, global_vars.generation_config
)
bot_responses = post_processes_batch(bot_responses)

for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
state_results.append(new_state_chatbot)

return (state_results, state_results, ctx_results)
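For reference, chats/baize.py mirrors the handler interface used by chats/alpaca.py: chat_stream is a generator that yields (state_chatbot, state_chatbot, context) tuples so the UI can refresh as tokens arrive, while chat_batch returns the final results in one pass. Below is a minimal, hypothetical sketch of how the streaming handler could be wired to Gradio components; the component names and layout are assumptions for illustration, not the actual wiring in app.py.

import gradio as gr

from chats import baize

with gr.Blocks() as demo:
    state = gr.State([])                           # chat history as (user, bot) pairs
    context = gr.Textbox(label="Context")          # optional grounding context
    chatbot = gr.Chatbot()                         # renders the streamed conversation
    instruction = gr.Textbox(label="Instruction")  # user prompt

    # chat_stream yields intermediate states, so Gradio refreshes the chatbot
    # on every yield instead of waiting for the full response.
    instruction.submit(
        baize.chat_stream,
        inputs=[context, instruction, state],
        outputs=[state, chatbot, context],
    )

demo.queue().launch()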
utils.py (5 changes: 4 additions & 1 deletion)
@@ -1,8 +1,11 @@
-import alpaca
+from chats import alpaca
+from chats import baize
 
 def get_chat_interface(model_type, batch_enabled):
     match model_type:
         case 'alpaca':
             return alpaca.chat_batch if batch_enabled else alpaca.chat_stream
+        case 'baize':
+            return baize.chat_batch if batch_enabled else baize.chat_stream
         case other:
             return None
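
The dispatcher in utils.py now covers both model families. A short usage sketch, assuming the caller passes the model type parsed from the CLI args (illustrative only; the real call site lives in app.py):

from utils import get_chat_interface

# streaming handler for the Baize model family
chat_fn = get_chat_interface('baize', False)     # -> baize.chat_stream

# batch handler for Alpaca
batch_fn = get_chat_interface('alpaca', True)    # -> alpaca.chat_batch

# unknown model types fall through to `case other` and return None
assert get_chat_interface('gpt-neo', False) is None

Note that the match statement requires Python 3.10 or newer.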
