diff --git a/Cargo.lock b/Cargo.lock index dd29df1..5b16cd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,6 +39,12 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "anyhow" +version = "1.0.88" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e1496f8fb1fbf272686b8d37f523dab3e4a7443300055e74cdaa449f3114356" + [[package]] name = "async-trait" version = "0.1.82" @@ -142,6 +148,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "colored" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8" +dependencies = [ + "lazy_static", + "windows-sys 0.48.0", +] + [[package]] name = "config" version = "0.13.4" @@ -156,6 +172,19 @@ dependencies = [ "serde", ] +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys 0.52.0", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -275,6 +304,12 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "encoding_rs" version = "0.8.34" @@ -547,6 +582,19 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "indicatif" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "unicode-width", +] + [[package]] name = "inquire" version = "0.7.5" @@ -692,6 +740,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "object" version = "0.36.4" @@ -760,6 +814,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "portable-atomic" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -783,7 +843,10 @@ name = "pullrequest" version = "0.0.1" dependencies = [ "anthropic", + "anyhow", + "colored", "dotenv", + "indicatif", "inquire", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index a51b421..750171c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,10 @@ readme = "./README.md" [dependencies] anthropic = "0.0.8" +anyhow = "1.0.88" +colored = "2.1.0" dotenv = "0.15.0" +indicatif = "0.17.8" inquire = "0.7.5" tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] } diff --git a/src/main.rs b/src/main.rs index 26c803c..9e931b6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,35 +1,49 @@ use anthropic::{client::ClientBuilder, types::CompleteRequestBuilder, AI_PROMPT, HUMAN_PROMPT}; +use anyhow::anyhow; +use colored::*; use dotenv::dotenv; +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; +use std::future::Future; use std::os::unix::process::CommandExt; use std::process::Command; +use std::sync::Arc; -fn check_uncommitted_changes() -> Result<(), Box> { +fn check_uncommitted_changes() -> Result<(), anyhow::Error> { let output = Command::new("git") .args(&["status", "--porcelain"]) .output()?; if !output.stdout.is_empty() { - eprintln!("There are uncommitted changes. Please commit or stash them before proceeding."); + eprintln!( + "{}", + "There are uncommitted changes. Please commit or stash them before proceeding." + .bright_red() + ); std::process::exit(1); } Ok(()) } -fn push_to_remote(current_branch: &str) -> Result<(), Box> { +fn push_to_remote(current_branch: &str) -> Result<(), anyhow::Error> { let status = Command::new("git") .args(&["push", "origin", current_branch]) - .status()?; - - if !status.success() { - eprintln!("Failed to push to remote. Please ensure your branch is up to date with origin."); - std::process::exit(1); - } + .spawn() + .expect("Could not push"); + + // if !status.success() { + // eprintln!( + // "{}", + // "Failed to push to remote. Please ensure your branch is up to date with origin." + // .bright_red() + // ); + // std::process::exit(1); + // } Ok(()) } -fn check_for_remote() -> Result<(), Box> { +fn check_for_remote() -> Result<(), anyhow::Error> { // Get the current branch name let current_branch = get_current_branch()?; @@ -42,7 +56,7 @@ fn check_for_remote() -> Result<(), Box> { Ok(()) } -fn get_current_branch() -> Result> { +fn get_current_branch() -> Result { let output = Command::new("git") .args(&["rev-parse", "--abbrev-ref", "HEAD"]) .output()?; @@ -50,11 +64,11 @@ fn get_current_branch() -> Result> { if output.status.success() { Ok(String::from_utf8(output.stdout)?.trim().to_string()) } else { - Err("Failed to get current branch".into()) + Err(anyhow!("Failed to get current branch")) } } -fn has_remote(branch: &str) -> Result> { +fn has_remote(branch: &str) -> Result { let output = Command::new("git") .args(&["ls-remote", "--exit-code", "--heads", "origin", branch]) .output()?; @@ -62,43 +76,86 @@ fn has_remote(branch: &str) -> Result> { Ok(output.status.success()) } +fn create_progress_bar(multi_progress: &MultiProgress, message: &str) -> ProgressBar { + let pb = multi_progress.add(ProgressBar::new(1)); + pb.set_style( + ProgressStyle::default_bar() + .template("{spinner:.green} {msg}") + .unwrap() + .progress_chars("#>-"), + ); + pb.set_message(message.to_string()); + pb +} + +async fn run_with_progress_async(pb: Arc, f: F) -> Result +where + F: Future> + Send + 'static, + T: Send + 'static, +{ + let result = f.await; + match &result { + Ok(_) => pb.finish_with_message(format!("{} Done", pb.message()).green().to_string()), + Err(_) => pb.finish_with_message(format!("{} Failed", pb.message()).red().to_string()), + } + result +} + +async fn run_with_progress(pb: Arc, f: F) -> Result +where + F: FnOnce() -> Result + Send + 'static, + T: Send + 'static, +{ + let result = tokio::task::spawn_blocking(f).await?; + match &result { + Ok(_) => pb.finish_with_message(format!("{} Done", pb.message()).green().to_string()), + Err(_) => pb.finish_with_message(format!("{} Failed", pb.message()).red().to_string()), + } + result +} + #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> Result<(), anyhow::Error> { dotenv().ok(); - let github_token = std::env::var("GITHUB_TOKEN").expect("no gh key"); + // let github_token = std::env::var("GITHUB_TOKEN").expect("no gh key"); let anthropic_key = std::env::var("ANTHROPIC_KEY").expect("no anthropic key"); - // Check for uncommitted changes check_uncommitted_changes()?; - check_for_remote()?; + println!("{}", "Starting pullrequest process...".blue().bold()); + + let multi_progress = Arc::new(MultiProgress::new()); + let mp = Arc::clone(&multi_progress); + + let remote_pb = Arc::new(create_progress_bar(&mp, "Checking remote")); + let diff_pb = Arc::new(create_progress_bar(&mp, "Getting git diff")); + let commits_pb = Arc::new(create_progress_bar(&mp, "Getting commit messages")); + let issue_pb = Arc::new(create_progress_bar(&mp, "Checking linked issue")); + let description_pb = Arc::new(create_progress_bar(&mp, "Generating PR description")); + let pr_pb = Arc::new(create_progress_bar(&mp, "Creating pull request")); - // Push to remote - push_to_remote(&get_current_branch()?)?; + run_with_progress(remote_pb.clone(), || check_for_remote()).await?; - // Get the git diff - let diff = get_git_diff()?; + let diff = run_with_progress(diff_pb.clone(), || get_git_diff()).await?; + let commit_messages = run_with_progress(commits_pb.clone(), || get_commit_messages()).await?; + let issue = run_with_progress(issue_pb.clone(), || get_linked_issue()).await?; - // Get commit messages - let commit_messages = get_commit_messages()?; + let anthropic_key_clone = anthropic_key.clone(); + let pr_description = run_with_progress_async(description_pb.clone(), async move { + generate_pr_description(&diff, &commit_messages, issue, anthropic_key_clone).await + }) + .await?; - // Get linked issue (if any) - let issue = get_linked_issue()?; + run_with_progress(pr_pb.clone(), move || create_pull_request(&pr_description)).await?; - // Generate PR description using AI - println!("Generating AI description with diffs..."); - let pr_description = - generate_pr_description(&diff, &commit_messages, issue, anthropic_key).await?; - println!("Description: {}", pr_description); + multi_progress.clear()?; - // Create pull request - println!("Creating pull request..."); - create_pull_request(&pr_description, github_token).await?; + println!("{}", "pullrequest process completed.".green().bold()); Ok(()) } -fn get_git_diff() -> Result { +fn get_git_diff() -> Result { let output = Command::new("git") .args(&["diff", "origin/master"]) .output()?; @@ -106,7 +163,7 @@ fn get_git_diff() -> Result { Ok(String::from_utf8_lossy(&output.stdout).to_string()) } -fn get_commit_messages() -> Result, std::io::Error> { +fn get_commit_messages() -> Result, anyhow::Error> { let output = Command::new("git") .args(&["log", "origin/master..HEAD", "--pretty=format:%s"]) .output()?; @@ -119,7 +176,7 @@ fn get_commit_messages() -> Result, std::io::Error> { Ok(messages) } -fn get_linked_issue() -> Result, Box> { +fn get_linked_issue() -> Result, anyhow::Error> { // This function would need to be implemented to fetch the linked issue from GitHub // It might involve parsing commit messages or branch names for issue numbers // and then querying the GitHub API @@ -131,7 +188,7 @@ async fn generate_pr_description( commit_messages: &[String], issue: Option, anthropic_key: String, -) -> Result> { +) -> Result { dotenv().ok(); // let client = ApiClient::new()?; let prompt = format!( @@ -160,10 +217,7 @@ async fn generate_pr_description( Ok(chat.completion) } -async fn create_pull_request( - description: &str, - _github_token: String, -) -> Result<(), Box> { +fn create_pull_request(description: &str) -> Result<(), anyhow::Error> { Command::new("gh") .args(&[ "pr",