-
Notifications
You must be signed in to change notification settings - Fork 0
/
ui.py
79 lines (55 loc) · 2.18 KB
/
ui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from dataclasses import dataclass
import streamlit as st
from audio_recorder_streamlit import audio_recorder
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import (ChatPromptTemplate, HumanMessagePromptTemplate,
MessagesPlaceholder)
from langchain.schema import SystemMessage
from model import transcribe
# Must run before any other Streamlit call: sets the browser tab title and icon.
st.set_page_config(page_title="VoiceGPT", page_icon="🎙️")
@dataclass
class ChatRecord:
    """One chat transcript entry: the text of a turn and who produced it.

    Constructed positionally as ChatRecord(message, role) by the script below.
    """

    # Text content of the turn (a transcription or an LLM reply).
    message: str
    # Streamlit chat role label: "human" or "ai".
    role: str
def display_chat_record(record):
    """Render one ChatRecord inside a Streamlit chat bubble for its role."""
    bubble = st.chat_message(record.role)
    with bubble:
        st.write(record.message)
def display_history():
    """Replay every stored transcript entry, oldest first."""
    for entry in st.session_state.history:
        display_chat_record(entry)
def record():
    """Capture microphone input via the recorder widget.

    When audio was captured, echo it back as an inline player and return the
    raw WAV bytes; otherwise return the widget's falsy result unchanged.
    """
    wav_bytes = audio_recorder(text="", pause_threshold=5.0, icon_size="2x")
    if wav_bytes:
        st.audio(wav_bytes, format="audio/wav")
    return wav_bytes
st.header("VoiceGPT", divider=True)

# One-time session setup: the transcript and the LLM chain must survive
# Streamlit's script reruns, so both live in st.session_state.
if "history" not in st.session_state:
    st.session_state.history = []

if "llm" not in st.session_state:
    # Prompt layout: fixed system persona, then the rolling conversation
    # memory, then the latest transcribed question.
    chat_prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessage(
                content="You are a helpful chatbot taking transcribed audio input from a human. Your name is VoiceGPT."
            ),
            MessagesPlaceholder(variable_name="memory"),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
    )
    # Keep the last 50 exchanges as message objects under the "memory" key,
    # matching the MessagesPlaceholder above.
    window_memory = ConversationBufferWindowMemory(
        memory_key="memory", k=50, return_messages=True
    )
    st.session_state.llm = LLMChain(
        llm=ChatOpenAI(temperature=0),
        prompt=chat_prompt,
        memory=window_memory,
        verbose=True,
    )
question_audio = record()
display_history()

if question_audio:
    # Speech-to-text; transcribe() returns a Whisper-style dict whose "text"
    # key holds the transcription (see model.py -- TODO confirm exact shape).
    transcription = transcribe(question_audio)["text"].strip()
    # Guard: silence or noise can yield an empty/whitespace-only
    # transcription -- don't append an empty turn or query the LLM with it.
    if transcription:
        question_record = ChatRecord(transcription, "human")
        st.session_state.history.append(question_record)
        display_chat_record(question_record)

        # The chain's internal memory supplies conversational context, so only
        # the latest question is passed explicitly.
        response = st.session_state.llm.predict(question=transcription)
        response_record = ChatRecord(response, "ai")
        st.session_state.history.append(response_record)
        display_chat_record(response_record)