Skip to content

Commit

Permalink
Merge pull request #263 from VesnaT/gpt_threads
Browse files Browse the repository at this point in the history
GPT: Offload work on separate thread
  • Loading branch information
JakaKokosar committed Nov 13, 2023
2 parents 5486901 + a5099ef commit 7469083
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 302 deletions.
169 changes: 22 additions & 147 deletions orangecontrib/prototypes/widgets/owchatgpt.py
Original file line number Diff line number Diff line change
@@ -1,177 +1,52 @@
from typing import Optional
from AnyQt.QtWidgets import QPlainTextEdit

from AnyQt.QtCore import Signal, Qt
from AnyQt.QtGui import QFocusEvent
from AnyQt.QtWidgets import QPlainTextEdit, QLineEdit, QTextEdit

import openai
import tiktoken

from Orange.data import Table, StringVariable
from Orange.data import Table
from Orange.widgets import gui
from Orange.widgets.credentials import CredentialManager
from Orange.widgets.utils.itemmodels import DomainModel
from Orange.widgets.settings import Setting, DomainContextHandler, \
ContextSetting
from Orange.widgets.widget import OWWidget, Input, Msg

MODELS = ["gpt-3.5-turbo", "gpt-4"]


def run_gpt(
api_key: str,
model: str,
text: str,
prompt_start: str,
prompt_end: str
) -> str:
openai.api_key = api_key
enc = tiktoken.encoding_for_model(model)

text = enc.decode(enc.encode(text)[:3500])
content = f"{prompt_start}\n{text}.\n{prompt_end}"
response = openai.ChatCompletion.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": content},
]
)
return response.choices[0].message.content

from Orange.widgets.settings import Setting
from orangecontrib.prototypes.widgets.owchatgptbase import OWChatGPTBase, \
run_gpt, MODELS

class TextEdit(QTextEdit):
sigEditFinished = Signal()

def focusOutEvent(self, ev: QFocusEvent):
self.sigEditFinished.emit()
super().focusOutEvent(ev)


class OWChatGPT(OWWidget):
class OWChatGPT(OWChatGPTBase):
name = "ChatGPT Summarize"
description = "Summarize content using a ChatGPT."
icon = "icons/chatgpt.svg"
priority = 10
keywords = ["text", "gpt"]

settingsHandler = DomainContextHandler()
access_key = ""
model_index = Setting(0)
text_var = ContextSetting(None)
prompt_start = Setting("")
prompt_end = Setting("")
auto_apply = Setting(True)

class Inputs:
data = Input("Data", Table)

class Warning(OWWidget.Warning):
missing_key = Msg("The Access key is missing.")
missing_str_var = Msg("Data has no text variables.")

class Error(OWWidget.Error):
unknown_error = Msg("An error occurred while creating an answer.\n{}")

def __init__(self):
super().__init__()
self.__data: Optional[Table] = None
self.__text_var_model = DomainModel(valid_types=(StringVariable,))
self.__start_text_edit: QTextEdit = None
self.__end_text_edit: QTextEdit = None
self.__answer_text_edit: QPlainTextEdit = None

self.__cm = CredentialManager("Ask")
self.access_key = self.__cm.access_key or ""

self.setup_gui()
super().__init__()

def setup_gui(self):
box = gui.vBox(self.controlArea, "Model")
edit: QLineEdit = gui.lineEdit(box, self, "access_key", "API Key:",
orientation=Qt.Horizontal,
callback=self.__on_access_key_changed)
edit.setEchoMode(QLineEdit.Password)
gui.comboBox(box, self, "model_index", label="Model:",
orientation=Qt.Horizontal,
items=MODELS, callback=self.commit.deferred)

gui.comboBox(self.controlArea, self, "text_var", "Data",
"Text variable:", model=self.__text_var_model,
orientation=Qt.Horizontal, callback=self.commit.deferred)

box = gui.vBox(self.controlArea, "Prompt")
gui.label(box, self, "Start:")
self.__start_text_edit = TextEdit(tabChangesFocus=True)
self.__start_text_edit.setText(self.prompt_start)
self.__start_text_edit.sigEditFinished.connect(
self.__on_start_text_edit_changed)
box.layout().addWidget(self.__start_text_edit)
gui.label(box, self, "End:")
self.__end_text_edit = TextEdit(tabChangesFocus=True)
self.__end_text_edit.setText(self.prompt_end)
self.__end_text_edit.sigEditFinished.connect(
self.__on_end_text_edit_changed)
box.layout().addWidget(self.__end_text_edit)

gui.rubber(self.controlArea)

gui.auto_apply(self.buttonsArea, self, "auto_apply")

super().setup_gui()
box = gui.vBox(self.mainArea, "Answer")
self.__answer_text_edit = QPlainTextEdit(readOnly=True)
box.layout().addWidget(self.__answer_text_edit)

def __on_access_key_changed(self):
self.__cm.access_key = self.access_key
self.commit.deferred()

def __on_start_text_edit_changed(self):
prompt_start = self.__start_text_edit.toPlainText()
if self.prompt_start != prompt_start:
self.prompt_start = prompt_start
self.commit.deferred()

def __on_end_text_edit_changed(self):
prompt_end = self.__end_text_edit.toPlainText()
if self.prompt_end != prompt_end:
self.prompt_end = prompt_end
self.commit.deferred()

@Inputs.data
def set_data(self, data: Table):
self.closeContext()
self.clear_messages()
self.__data = data
self.__text_var_model.set_domain(data.domain if data else None)
self.text_var = self.__text_var_model[0] if self.__text_var_model \
else None
if data and not self.__text_var_model:
self.Warning.missing_str_var()
self.openContext(data)
super().set_data(data)
self.commit.now()

@gui.deferred
def commit(self):
self.Warning.missing_key.clear()
if self.access_key == "":
self.Warning.missing_key()
self.__answer_text_edit.setPlainText(self._get_answer())
def on_done(self, answer: str):
self.__answer_text_edit.setPlainText(answer)

def _get_answer(self) -> str:
self.Error.unknown_error.clear()
if not self.__data or not self.text_var or not self.access_key:
def ask_gpt(self, state) -> str:
if not self._data or not self.text_var or not self.access_key:
return ""

texts = self.__data.get_column(self.text_var)
texts = self._data.get_column(self.text_var)
text = "\n".join(texts)
try:
answer = run_gpt(self.access_key, MODELS[self.model_index],
text, self.prompt_start, self.prompt_end)
except Exception as ex:
answer = ""
self.Error.unknown_error(ex)
return answer

state.set_progress_value(4)
state.set_status("Thinking...")
if state.is_interruption_requested():
raise Exception

return run_gpt(self.access_key, MODELS[self.model_index],
text, self.prompt_start, self.prompt_end)


if __name__ == "__main__":
Expand Down
163 changes: 163 additions & 0 deletions orangecontrib/prototypes/widgets/owchatgptbase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from typing import Optional

from AnyQt.QtCore import Signal, Qt
from AnyQt.QtGui import QFocusEvent
from AnyQt.QtWidgets import QLineEdit, QTextEdit

import openai
import tiktoken

from Orange.data import Table, StringVariable
from Orange.widgets import gui
from Orange.widgets.credentials import CredentialManager
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin
from Orange.widgets.utils.itemmodels import DomainModel
from Orange.widgets.settings import Setting, DomainContextHandler, \
ContextSetting
from Orange.widgets.widget import OWWidget, Input, Msg

MODELS = ["gpt-3.5-turbo", "gpt-4"]


def run_gpt(
api_key: str,
model: str,
text: str,
prompt_start: str,
prompt_end: str
) -> str:
openai.api_key = api_key
enc = tiktoken.encoding_for_model(model)

text = enc.decode(enc.encode(text)[:3500])
content = f"{prompt_start}\n{text}.\n{prompt_end}"
response = openai.ChatCompletion.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": content},
]
)
return response.choices[0].message.content


class TextEdit(QTextEdit):
sigEditFinished = Signal()

def focusOutEvent(self, ev: QFocusEvent):
self.sigEditFinished.emit()
super().focusOutEvent(ev)


class OWChatGPTBase(OWWidget, ConcurrentWidgetMixin, openclass=True):
settingsHandler = DomainContextHandler()
access_key = ""
model_index = Setting(0)
text_var = ContextSetting(None)
prompt_start = Setting("")
prompt_end = Setting("")
auto_apply = Setting(False)

class Inputs:
data = Input("Data", Table)

class Warning(OWWidget.Warning):
missing_key = Msg("The Access key is missing.")
missing_str_var = Msg("Data has no text variables.")

class Error(OWWidget.Error):
unknown_error = Msg("An error occurred while creating an answer.\n{}")

def __init__(self):
OWWidget.__init__(self)
ConcurrentWidgetMixin.__init__(self)
self._data: Optional[Table] = None
self.__text_var_model = DomainModel(valid_types=(StringVariable,))
self.__start_text_edit: QTextEdit = None
self.__end_text_edit: QTextEdit = None

self.__cm = CredentialManager("Ask")
self.access_key = self.__cm.access_key or ""

self.setup_gui()

def setup_gui(self):
box = gui.vBox(self.controlArea, "Model")
edit: QLineEdit = gui.lineEdit(box, self, "access_key", "API Key:",
orientation=Qt.Horizontal,
callback=self.__on_access_key_changed)
edit.setEchoMode(QLineEdit.Password)
gui.comboBox(box, self, "model_index", label="Model:",
orientation=Qt.Horizontal,
items=MODELS, callback=self.commit.deferred)

gui.comboBox(self.controlArea, self, "text_var", "Data",
"Text variable:", model=self.__text_var_model,
orientation=Qt.Horizontal, callback=self.commit.deferred)

box = gui.vBox(self.controlArea, "Prompt")
gui.label(box, self, "Start:")
self.__start_text_edit = TextEdit(tabChangesFocus=True)
self.__start_text_edit.setText(self.prompt_start)
self.__start_text_edit.sigEditFinished.connect(
self.__on_start_text_edit_changed)
box.layout().addWidget(self.__start_text_edit)
gui.label(box, self, "End:")
self.__end_text_edit = TextEdit(tabChangesFocus=True)
self.__end_text_edit.setText(self.prompt_end)
self.__end_text_edit.sigEditFinished.connect(
self.__on_end_text_edit_changed)
box.layout().addWidget(self.__end_text_edit)

gui.rubber(self.controlArea)

gui.auto_apply(self.buttonsArea, self, "auto_apply")

def __on_access_key_changed(self):
self.__cm.access_key = self.access_key
self.commit.deferred()

def __on_start_text_edit_changed(self):
prompt_start = self.__start_text_edit.toPlainText()
if self.prompt_start != prompt_start:
self.prompt_start = prompt_start
self.commit.deferred()

def __on_end_text_edit_changed(self):
prompt_end = self.__end_text_edit.toPlainText()
if self.prompt_end != prompt_end:
self.prompt_end = prompt_end
self.commit.deferred()

@Inputs.data
def set_data(self, data: Table):
self.closeContext()
self.clear_messages()
self._data = data
self.__text_var_model.set_domain(data.domain if data else None)
self.text_var = self.__text_var_model[0] if self.__text_var_model \
else None
if data and not self.__text_var_model:
self.Warning.missing_str_var()
self.openContext(data)

@gui.deferred
def commit(self):
self.Error.unknown_error.clear()
self.Warning.missing_key.clear()
if self.access_key == "":
self.Warning.missing_key()
self.start(self.ask_gpt)

def ask_gpt(self):
raise NotImplementedError()

def on_exception(self, ex: Exception):
self.Error.unknown_error(ex)

def on_partial_result(self, _):
pass

def onDeleteWidget(self):
self.shutdown()
super().onDeleteWidget()
Loading

0 comments on commit 7469083

Please sign in to comment.