Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions cli/magic_commit/magic_commit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ class OpenAIKeyError(Exception):
"""Custom exception for OpenAI API key errors."""



class Llama2ServerError(Exception):
    """Custom exception raised when a request to the Llama2 server fails.

    Raised by ``call_llama2_server`` (wrapping the underlying
    ``requests.exceptions.RequestException``) so callers can catch a
    single project-specific error type.
    """
    # No extra state or behavior needed; the docstring alone forms the
    # class body, so the redundant (and duplicated) `pass` is removed.


def is_git_repository(directory: str) -> bool:
Expand Down Expand Up @@ -190,27 +189,31 @@ def generate_commit_message(
# Call the Llama2 server
response = call_llama2_server(llama2_url, messages)
print(response)
response = response['choices'][0]['message']['content'].strip()
response = response["choices"][0]["message"]["content"].strip()
else:
# Use OpenAI's service
openai.api_key = api_key
response = openai.ChatCompletion.create(model=model, messages=messages)
response = response.choices[0].message.content.strip()

# Strip the first line of response
# Assign it to start if it is empty
# Otherwise, remove the first line from the response
if start:
response = response.split("\n", 1)[1]
# Split the response by newline and store the result
split_response = response.split("\n", 1)

# Check if split_response contains at least 2 elements
if len(split_response) > 1:
response = split_response[1] if start else split_response[0]
else:
start = response.split("\n", 1)[0]
response = response.split("\n", 1)[1]
# If there is no newline, the whole response is either the start or the generated message
if not start:
start = response
# If start is already set, we leave response as is, or set it to an empty string
else:
response = ""

# Render and return the template
return render_final_template(start, response, ticket).strip()



def call_llama2_server(url: str, messages: list) -> dict:
"""
Call the Llama2 server.
Expand All @@ -237,7 +240,9 @@ def call_llama2_server(url: str, messages: list) -> dict:
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
raise Llama2ServerError(f"An error occurred while connecting to the Llama2 server: {e}")
raise Llama2ServerError(
f"An error occurred while connecting to the Llama2 server: {e}"
)


def render_template(message: str, template_name: str) -> str:
Expand Down Expand Up @@ -322,7 +327,7 @@ def run_magic_commit(
api_key: str,
model: str,
show_loading_message: bool,
llama2_url: str = None
llama2_url: str = None,
) -> str:
"""
Generate a commit message and return it.
Expand Down Expand Up @@ -359,7 +364,9 @@ def run_magic_commit(
diff = run_git_diff(directory)
if not check_git_status(directory): # Check if there are staged changes
return "⛔ Warning: No staged changes detected. Please stage some changes before running magic-commit."
commit_message = generate_commit_message(diff, start, ticket, api_key, model, llama2_url)
commit_message = generate_commit_message(
diff, start, ticket, api_key, model, llama2_url
)
finally:
# Ensure the loading animation stops
if show_loading_message:
Expand Down
2 changes: 1 addition & 1 deletion cli/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="magic-commit",
version="0.6.1",
version="0.6.2",
packages=find_packages(),
include_package_data=True, # This line is needed to include non-code files
package_data={
Expand Down