Skip to content

Commit

Permalink
both implementations unified
Browse files Browse the repository at this point in the history
  • Loading branch information
Jamie authored and Jamie committed Apr 1, 2023
1 parent 2f8f784 commit 26325f7
Showing 1 changed file with 51 additions and 72 deletions.
123 changes: 51 additions & 72 deletions gpt/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,12 @@ def __init__(self, shell, api_key=None):
self.model = MODEL_STRING
self.temperature = 0

#Keep track of usage
self.total_usage = {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}

#This variable is here to store the full chat log.
self.current_query = None

## A chat initialisation if none exists.
def init_chat(self, query: str) -> list:
"""
Initializes a chat if none exists.
Args:
query (str): The user's query to start the chat.
Returns:
list: A list of dictionaries representing the initial chat messages.
"""
current_chat = [{"role": "system", "content": f"{self.prefix_system}"},\
{"role": "user", "content": f"{self.prefix_user} {query}"}]
return current_chat

#This method initialises a chat (using function above), if no previous chat exists, otherwise it continues the chat.
def cont_chat(self, query: str, current_chat: Optional[list] = None) -> list:
"""
Expand All @@ -86,9 +74,11 @@ def cont_chat(self, query: str, current_chat: Optional[list] = None) -> list:
current_chat.append(message)
return current_chat
else:
return self.init_chat(query)
new_chat = [{"role": "system", "content": f"{self.prefix_system}"},\
{"role": "user", "content": f"{self.prefix_user} {query}"}]
return new_chat

def call_openai(self, data: dict) -> dict:
def call_openai(self) -> dict:
"""
Sends a request to the OpenAI API with the given data.
Expand All @@ -105,29 +95,26 @@ def call_openai(self, data: dict) -> dict:
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
resp=requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=data)

if resp.status_code == 200:
return json.loads(resp.text)
else:
raise Exception(f"Error: {resp.status_code}, {resp.text}")

def prepare_payload(self, current_query: list) -> dict:
"""
Prepares the payload to be sent to the OpenAI API.
Args:
current_query (list): The current chat log.
Returns:
dict: A dictionary containing the prepared payload data.
"""
payload_data = {
'model': self.model,
'temperature': self.temperature,
'messages': current_query
'messages': self.current_query
}
return payload_data

resp=requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload_data)

if resp.status_code == 200:
response = json.loads(resp.text)

#Update usage
self.last_usage=response['usage']
for key in self.total_usage.keys():
self.total_usage[key] += self.last_usage[key]

return response
else:
raise Exception(f"Error: {resp.status_code}, {resp.text}")

def run(self, query: str, chat_memory: bool = False) -> None:
"""
Expand All @@ -145,23 +132,21 @@ def run(self, query: str, chat_memory: bool = False) -> None:
self.current_query = self.cont_chat(query=query, current_chat=self.current_query)

#Prepare data payload
data = self.prepare_payload(self.current_query)

response = self.call_openai(data)
response = self.call_openai()
feedback = response['choices'][0]['message']

self.current_query.append(feedback)

extract = extract_code_and_text(feedback['content'])

return jupyterlab(extract)
"""
if environment() == 'jupyter':
return ipynotebook(extract[::-1])
return ipynotebook(extract)
elif environment() == 'jupyter-lab':
return jupyterlab(extract)
else:
raise ValueError(f"Unsupported environment: {program}. Must be run in Jupyter-lab or Jupyter notebook")


"""

@line_cell_magic
def gpt(self, line: str, cell: Optional[str] = None) -> None:
Expand Down Expand Up @@ -196,8 +181,25 @@ def gpt(self, line: str, cell: Optional[str] = None) -> None:
self.run(query, chat_memory)


#Insert cells for notebook.
#Create new cell (jupyter lab)
def create_new_cell(contents, cell_type='Code'):
if cell_type not in ['code', 'markdown']:
raise ValueError("Invalid cell_type. Choose 'code' or 'markdown'.")

from IPython.core.getipython import get_ipython
shell = get_ipython()
payload = dict(
source='set_next_input',
text=contents,
replace=False,
cell_type=cell_type,
)
shell.payload_manager.write_payload(payload, single=False)


#Insert cells for notebook (notebook)
def insert_cells_ahead(cells_data, n=0):
cells_data=cells_data[::-1]
for i, (cell_type, content) in reversed(list(enumerate(cells_data))):
content_b64 = base64.b64encode(content.encode('utf-8')).decode('utf-8')
display(Javascript('''
Expand All @@ -215,20 +217,10 @@ def insert_cells_ahead(cells_data, n=0):
setTimeout(insert_cells, 100);
'''))

#Create new cell
def create_new_cell(contents, cell_type='Code'):
if cell_type not in ['code', 'markdown']:
raise ValueError("Invalid cell_type. Choose 'code' or 'markdown'.")

from IPython.core.getipython import get_ipython
shell = get_ipython()
payload = dict(
source='set_next_input',
text=contents,
replace=False,
cell_type=cell_type,
)
shell.payload_manager.write_payload(payload, single=False)
#Function to run for ipynotebook
def ipynotebook(extract):
return insert_cells_ahead(extract)

#Function to run for jupyterlab
def jupyterlab(extract):
Expand All @@ -238,16 +230,14 @@ def jupyterlab(extract):
else:
create_new_cell(j[1],j[0])

#Function to run for ipynotebook
def ipynotebook(extract):
return insert_cells_ahead(extract)

"""
#Get the environment (ipynotebook or jupyter)
def environment():
env = os.environ
shell = 'shell'
program = os.path.basename(env['_'])
return program
"""

#Extract code from text.
def extract_code_and_text(response):
Expand Down Expand Up @@ -290,15 +280,4 @@ def extract_code_and_text(response):
if buffer:
result.append(("markdown", buffer))

return result

def process_cells(parsed_output):
cells = []

for i,j in parsed_output:
if i == 'text':
cells.append(('markdown',j))
if i == 'code':
cells.append(('code',j))
cells = cells[::-1]
return
return result

0 comments on commit 26325f7

Please sign in to comment.