Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
Merge pull request #3 from cosmonaut-nz/local_provider
Browse files Browse the repository at this point in the history
Refactor: Moved prompt instructions to use a json schema to increase output accuracy.
  • Loading branch information
avastmick authored Nov 26, 2023
2 parents 1a48954 + a4402ec commit 94199b5
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 68 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ linguist-rs = { git = "https://github.com/cosmonaut-nz/linguist-rs.git" }
tempfile = "3.8.1"

[build-dependencies]
linguist-rs-build = { git = "https://github.com/cosmonaut-nz/linguist-rs.git" }
linguist-rs-build = { git = "https://github.com/cosmonaut-nz/linguist-rs.git" }
3 changes: 2 additions & 1 deletion build.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! builds the pre-requisite definitions for the usage of the GitHub Linguist data
//!
//!
//! builds the pre-requisite definitions for the usage of the GitHub Linguist data
//!
use linguist_build::{
Definition, Kind, Location, GITHUB_LINGUIST_DOCUMENTATION_URL, GITHUB_LINGUIST_HEURISTICS_URL,
Expand Down
16 changes: 11 additions & 5 deletions src/provider/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub enum ProviderMessageRole {
// Outbound data structures - i.e. for requests to the provider LLM
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ProviderCompletionMessage {
// pub id: String,
pub role: ProviderMessageRole,
pub content: String,
}
Expand Down Expand Up @@ -79,12 +80,11 @@ impl ProviderMessageConverter for OpenAIMessageConverter {
_ => MessageRole::user,
};

// Create a ChatCompletionMessage with the converted role and the content
ChatCompletionMessage {
role,
content: message.content.clone(),
name: None, // Set to None or as required
function_call: None, // Set to None or as required
name: None,
function_call: None,
}
}
fn convert_messages(
Expand All @@ -102,14 +102,20 @@ impl ProviderMessageConverter for OpenAIMessageConverter {

// Response conversions
pub trait ProviderResponseConverter {
fn convert_response(&self, response: &ChatCompletionResponse) -> ProviderCompletionResponse;
fn to_generic_provider_response(
&self,
response: &ChatCompletionResponse,
) -> ProviderCompletionResponse;
}

/// OpenAI converter
pub struct OpenAIResponseConverter;
/// converts an openai_api_rs [`ChatCompletionResponse`] to a [`ProviderCompletionResponse`]
impl ProviderResponseConverter for OpenAIResponseConverter {
fn convert_response(&self, response: &ChatCompletionResponse) -> ProviderCompletionResponse {
fn to_generic_provider_response(
&self,
response: &ChatCompletionResponse,
) -> ProviderCompletionResponse {
ProviderCompletionResponse {
id: response.id.clone(),
model: response.model.clone(),
Expand Down
8 changes: 3 additions & 5 deletions src/provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::provider::api::{
};
use crate::provider::prompts::PromptData;
use crate::settings::{ProviderSettings, Settings};
use log::debug;
use openai_api_rs::v1::api::Client;
use openai_api_rs::v1::chat_completion::{self, ChatCompletionMessage, ChatCompletionRequest};

Expand Down Expand Up @@ -89,22 +90,19 @@ impl APIProvider for OpenAIProvider {
let openai_converter: OpenAIMessageConverter = OpenAIMessageConverter;
let completion_msgs: Vec<ChatCompletionMessage> =
openai_converter.convert_messages(&prompt_data.messages);

let req: ChatCompletionRequest =
ChatCompletionRequest::new(self.model.to_string(), completion_msgs);

let response: Result<
chat_completion::ChatCompletionResponse,
openai_api_rs::v1::error::APIError,
> = client.chat_completion(req);

match response {
Ok(openai_res) => {
// debug!("Response status: {}", res.status());
debug!("ChatCompletionResponse ID: {}", openai_res.id);
// Now marshal the OpenAI specific result into the ProviderCompletionResponse
let converter: OpenAIResponseConverter = OpenAIResponseConverter;
let provider_completion_response: ProviderCompletionResponse =
converter.convert_response(&openai_res);
OpenAIResponseConverter.to_generic_provider_response(&openai_res);
Ok(provider_completion_response)
}
Err(openai_err) => Err(format!("OpenAI API request failed: {}", openai_err).into()),
Expand Down
57 changes: 16 additions & 41 deletions src/provider/prompts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ use log::debug;
// TODO: assess how well these prompts are engineered and evaluate the need to alter prompts between providers for the optimal outcome for each.
use serde::{Deserialize, Serialize};

const FILE_REVIEW_SCHEMA: &str = include_str!("../provider/specification/file_review.schema.json");

#[derive(Serialize, Deserialize, Debug)]
pub struct PromptData {
pub id: Option<String>,
pub messages: Vec<ProviderCompletionMessage>,
}
impl PromptData {
Expand All @@ -27,6 +30,7 @@ impl PromptData {
pub fn get_code_review_prompt(for_provider: &ProviderSettings) -> Self {
debug!("Provider: {}", for_provider);
Self {
id: None,
messages: vec![
ProviderCompletionMessage {
role: ProviderMessageRole::System,
Expand All @@ -36,44 +40,24 @@ impl PromptData {
role: ProviderMessageRole::System,
content: r#"Focus on identifying critical errors, best practice violations, and security vulnerabilities.
Exclude trivial issues like formatting errors or TODO comments. Use your expertise to provide insightful and actionable feedback.
If no errors or security issues are found, and less than ten (10) improvements found, the file_rag_status should be 'Green'.
"#.to_string(),
},
ProviderCompletionMessage {
role: ProviderMessageRole::System,
content: r#"Provide your analysis in a valid JSON format, in which any invalid characters are correctly escaped, exactly following this structure:
{
filename: String, // The name of the file
summary: String, // A summary of the findings of the review
file_rag_status: String, // In {'Red', 'Amber', 'Green'}
errors: [Vec<Error>], // A list of errors found in the code giving:
[{
code: String, // The code affected, including line number
issue: String, // A description of the error
resolution: String, // The potential resolution
}]
improvements: Vec<Improvement>, // A list of improvements to the code, if any, giving:
[{
code: String, // The code affected, including line number
suggestion: String, // A suggestion to improve the code
example: String, // An example improvement
}]
security_issues: Vec<SecurityIssue>, // A list of security issues found in the code, if any, giving:
[{
code: String, // The code affected, including line number
threat: String, // A description of the threat
mitigation: String, // The potential mitigation
}]
statistics: String // A list of statistics (e.g., code type, lines of code, number of functions, number of methods, etc.)
}
content: r#"Provide your analysis in a valid JSON format, in which any invalid characters are correctly escaped, please use the following JSON Schema.
"#.to_string(),
},
ProviderCompletionMessage {
role: ProviderMessageRole::System,
content: FILE_REVIEW_SCHEMA.to_string(),
},
],
}
}
pub fn get_security_review_prompt(for_provider: &ProviderSettings) -> Self {
debug!("Provider: {}", for_provider);
Self {
id: None,
messages: vec![
ProviderCompletionMessage {
role: ProviderMessageRole::System,
Expand All @@ -83,27 +67,18 @@ impl PromptData {
role: ProviderMessageRole::System,
content: r#"Focus exclusively on identifying security vulnerabilities and potential security flaws in the code.
Provide actionable feedback and mitigation strategies for each identified issue.
You do not have to offer improvement recommendations for the code, focus solely on security.
If no errors or security issues are found, the file_rag_status should be 'Green'"#.to_string(),
},
ProviderCompletionMessage {
role: ProviderMessageRole::System,
content: r#"Provide your analysis in a valid JSON format, in which any invalid characters are correctly escaped, exactly following this structure:
{
filename: String, // The name of the file
summary: String, // A summary of the findings of the review
file_rag_status: String, // In {'Red', 'Amber', 'Green'}
security_issues: Vec<SecurityIssue>, // A list of security issues found in the code, if any, giving:
[{
code: String, // The code affected, including line number
threat: String, // A description of the threat
mitigation: String, // The potential mitigation
}]
statistics: String // A list of statistics (e.g., code type, lines of code, number of functions, number of methods, etc.)
}
content: r#"Provide your analysis in a valid JSON format, in which any invalid characters are correctly escaped, please use the following JSON Schema.
"#.to_string(),
},
// Add user messages if needed
// Message { role: "user".to_string(), content: "..." },
ProviderCompletionMessage {
role: ProviderMessageRole::System,
content: FILE_REVIEW_SCHEMA.to_string(),
},
],
}
}
Expand Down
160 changes: 160 additions & 0 deletions src/provider/specification/file_review.schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"title": "FileReview",
"description": "A file review object that captures the elements of the review.",
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "The name of the file"
},
"summary": {
"type": "string",
"description": "A summary of the findings of the code review"
},
"file_rag_status": {
"type": "string",
"description": "In {Red, Amber, Green}. If no errors or security issues are found, and less than ten (10) improvements found, the file_rag_status should be 'Green'",
"enum": [
"Red",
"Amber",
"Green"
]
},
"errors": {
"type": "array",
"items": {
"$ref": "#/$defs/error"
},
"description": "A list of syntatic or idiomatic errors found in the code giving the issue and potential resolution for each"
},
"improvements": {
"type": "array",
"items": {
"$ref": "#/$defs/improvement"
},
"description": "A list of code improvements, giving a suggestion and example for each"
},
"security_issues": {
"type": "array",
"items": {
"$ref": "#/$defs/securityIssue"
},
"description": "A list of security issues, giving the threat and mitigation for each"
},
"statistics": {
"$ref": "#/$defs/languageFileType",
"description": "A list of statistics (e.g., lines of code, etc.)"
}
},
"required": [
"filename",
"summary",
"file_rag_status"
],
"$defs": {
"error": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "Where in the code the error was found. Include line of code"
},
"issue": {
"type": "string",
"description": "A description of the error"
},
"resolution": {
"type": "string",
"description": "A description of how the error can be resolved the error"
}
},
"required": [
"code",
"issue",
"resolution"
]
},
"improvement": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "Where in the code the improvement can be made. Include line of code"
},
"suggestion": {
"type": "string",
"description": "A description of the suggested improvement to the code"
},
"example": {
"type": "string",
"description": "An example of how to make the improvement"
}
},
"required": [
"code",
"suggestion",
"example"
]
},
"securityIssue": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "Where in the code the issue was found. Include line of code"
},
"threat": {
"type": "string",
"description": "A description of the threat, including implications"
},
"mitigation": {
"type": "string",
"description": "A description of how the threat can be mitigated"
}
},
"required": [
"code",
"threat",
"mitigation"
]
},
"languageFileType": {
"type": "object",
"properties": {
"language": {
"type": "string",
"description": "The language the code is in, e.g., 'Rust', 'C#', 'Java', etc."
},
"extension": {
"type": "string",
"description": "The file extension of the file, e.g., '.rs' or '.cs', etc."
},
"percentage": {
"type": "number",
"description": "A roll-up percentage of all LanguageFileTypes of this language. Leave as zero if not known"
},
"size": {
"type": "integer",
"description": "The size of the file, in bytes. Leave as zero if not known"
},
"loc": {
"type": "integer",
"description": "The number of lines of code found in the file. Excluding comments"
},
"total_size": {
"type": "integer",
"description": "A roll-up of all LanguageFileTypes size values, in bytes. Leave as zero if not known"
},
"file_count": {
"type": "integer",
"description": "A roll-up count of all LanguageFileTypes of this language"
}
},
"required": [
"language",
"extension"
]
}
}
}
8 changes: 4 additions & 4 deletions src/review/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ pub struct FileReview {
pub improvements: Option<Vec<Improvement>>, // A list of improvements, giving a suggestion and example for each
#[serde(skip_serializing_if = "Option::is_none")]
pub security_issues: Option<Vec<SecurityIssue>>, // A list of security issues, giving the threat and mitigation for each
pub statistics: String, // A list of statistics (e.g., lines of code, functions, methods, etc.)
#[serde(skip_serializing_if = "Option::is_none")]
pub statistics: Option<LanguageFileType>, // A list of statistics (e.g., lines of code, functions, methods, etc.)
}

impl FileReview {
Expand Down Expand Up @@ -237,8 +238,7 @@ mod tests {
"example": "Implement a function to load the timeout from an environment variable or a configuration file."
}
],
"security_issues": [],
"statistics": "Lines of code: 6, Constants: 1, Imports: 1, Comments: 2"
"security_issues": []
}"#;

let improvement = Improvement {
Expand All @@ -254,7 +254,7 @@ mod tests {
errors: Some(vec![]),
improvements: Some(vec![improvement]),
security_issues: Some(vec![]),
statistics: "Lines of code: 6, Constants: 1, Imports: 1, Comments: 2".to_string(),
statistics: None,
};
match deserialize_file_review(str_json) {
Ok(filereview_from_json) => assert_eq!(expected_filereview, filereview_from_json),
Expand Down
Loading

0 comments on commit 94199b5

Please sign in to comment.