Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ readme = "./README.md"

[dependencies]
anthropic = "0.0.8"
anyhow = "1.0.88"
colored = "2.1.0"
dotenv = "0.15.0"
indicatif = "0.17.8"
inquire = "0.7.5"
tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] }

Expand Down
136 changes: 95 additions & 41 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,49 @@
use anthropic::{client::ClientBuilder, types::CompleteRequestBuilder, AI_PROMPT, HUMAN_PROMPT};
use anyhow::anyhow;
use colored::*;
use dotenv::dotenv;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::future::Future;
use std::os::unix::process::CommandExt;
use std::process::Command;
use std::sync::Arc;

fn check_uncommitted_changes() -> Result<(), Box<dyn std::error::Error>> {
fn check_uncommitted_changes() -> Result<(), anyhow::Error> {
let output = Command::new("git")
.args(&["status", "--porcelain"])
.output()?;

if !output.stdout.is_empty() {
eprintln!("There are uncommitted changes. Please commit or stash them before proceeding.");
eprintln!(
"{}",
"There are uncommitted changes. Please commit or stash them before proceeding."
.bright_red()
);
std::process::exit(1);
}

Ok(())
}

fn push_to_remote(current_branch: &str) -> Result<(), Box<dyn std::error::Error>> {
fn push_to_remote(current_branch: &str) -> Result<(), anyhow::Error> {
let status = Command::new("git")
.args(&["push", "origin", current_branch])
.status()?;

if !status.success() {
eprintln!("Failed to push to remote. Please ensure your branch is up to date with origin.");
std::process::exit(1);
}
.spawn()
.expect("Could not push");

// if !status.success() {
// eprintln!(
// "{}",
// "Failed to push to remote. Please ensure your branch is up to date with origin."
// .bright_red()
// );
// std::process::exit(1);
// }

Ok(())
}

fn check_for_remote() -> Result<(), Box<dyn std::error::Error>> {
fn check_for_remote() -> Result<(), anyhow::Error> {
// Get the current branch name
let current_branch = get_current_branch()?;

Expand All @@ -42,71 +56,114 @@ fn check_for_remote() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

fn get_current_branch() -> Result<String, Box<dyn std::error::Error>> {
fn get_current_branch() -> Result<String, anyhow::Error> {
let output = Command::new("git")
.args(&["rev-parse", "--abbrev-ref", "HEAD"])
.output()?;

if output.status.success() {
Ok(String::from_utf8(output.stdout)?.trim().to_string())
} else {
Err("Failed to get current branch".into())
Err(anyhow!("Failed to get current branch"))
}
}

fn has_remote(branch: &str) -> Result<bool, Box<dyn std::error::Error>> {
fn has_remote(branch: &str) -> Result<bool, anyhow::Error> {
let output = Command::new("git")
.args(&["ls-remote", "--exit-code", "--heads", "origin", branch])
.output()?;

Ok(output.status.success())
}

fn create_progress_bar(multi_progress: &MultiProgress, message: &str) -> ProgressBar {
let pb = multi_progress.add(ProgressBar::new(1));
pb.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} {msg}")
.unwrap()
.progress_chars("#>-"),
);
pb.set_message(message.to_string());
pb
}

async fn run_with_progress_async<F, T>(pb: Arc<ProgressBar>, f: F) -> Result<T, anyhow::Error>
where
F: Future<Output = Result<T, anyhow::Error>> + Send + 'static,
T: Send + 'static,
{
let result = f.await;
match &result {
Ok(_) => pb.finish_with_message(format!("{} Done", pb.message()).green().to_string()),
Err(_) => pb.finish_with_message(format!("{} Failed", pb.message()).red().to_string()),
}
result
}

async fn run_with_progress<F, T>(pb: Arc<ProgressBar>, f: F) -> Result<T, anyhow::Error>
where
F: FnOnce() -> Result<T, anyhow::Error> + Send + 'static,
T: Send + 'static,
{
let result = tokio::task::spawn_blocking(f).await?;
match &result {
Ok(_) => pb.finish_with_message(format!("{} Done", pb.message()).green().to_string()),
Err(_) => pb.finish_with_message(format!("{} Failed", pb.message()).red().to_string()),
}
result
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
async fn main() -> Result<(), anyhow::Error> {
dotenv().ok();
let github_token = std::env::var("GITHUB_TOKEN").expect("no gh key");
// let github_token = std::env::var("GITHUB_TOKEN").expect("no gh key");
let anthropic_key = std::env::var("ANTHROPIC_KEY").expect("no anthropic key");

// Check for uncommitted changes
check_uncommitted_changes()?;

check_for_remote()?;
println!("{}", "Starting pullrequest process...".blue().bold());

let multi_progress = Arc::new(MultiProgress::new());
let mp = Arc::clone(&multi_progress);

let remote_pb = Arc::new(create_progress_bar(&mp, "Checking remote"));
let diff_pb = Arc::new(create_progress_bar(&mp, "Getting git diff"));
let commits_pb = Arc::new(create_progress_bar(&mp, "Getting commit messages"));
let issue_pb = Arc::new(create_progress_bar(&mp, "Checking linked issue"));
let description_pb = Arc::new(create_progress_bar(&mp, "Generating PR description"));
let pr_pb = Arc::new(create_progress_bar(&mp, "Creating pull request"));

// Push to remote
push_to_remote(&get_current_branch()?)?;
run_with_progress(remote_pb.clone(), || check_for_remote()).await?;

// Get the git diff
let diff = get_git_diff()?;
let diff = run_with_progress(diff_pb.clone(), || get_git_diff()).await?;
let commit_messages = run_with_progress(commits_pb.clone(), || get_commit_messages()).await?;
let issue = run_with_progress(issue_pb.clone(), || get_linked_issue()).await?;

// Get commit messages
let commit_messages = get_commit_messages()?;
let anthropic_key_clone = anthropic_key.clone();
let pr_description = run_with_progress_async(description_pb.clone(), async move {
generate_pr_description(&diff, &commit_messages, issue, anthropic_key_clone).await
})
.await?;

// Get linked issue (if any)
let issue = get_linked_issue()?;
run_with_progress(pr_pb.clone(), move || create_pull_request(&pr_description)).await?;

// Generate PR description using AI
println!("Generating AI description with diffs...");
let pr_description =
generate_pr_description(&diff, &commit_messages, issue, anthropic_key).await?;
println!("Description: {}", pr_description);
multi_progress.clear()?;

// Create pull request
println!("Creating pull request...");
create_pull_request(&pr_description, github_token).await?;
println!("{}", "pullrequest process completed.".green().bold());

Ok(())
}

fn get_git_diff() -> Result<String, std::io::Error> {
fn get_git_diff() -> Result<String, anyhow::Error> {
let output = Command::new("git")
.args(&["diff", "origin/master"])
.output()?;

Ok(String::from_utf8_lossy(&output.stdout).to_string())
}

fn get_commit_messages() -> Result<Vec<String>, std::io::Error> {
fn get_commit_messages() -> Result<Vec<String>, anyhow::Error> {
let output = Command::new("git")
.args(&["log", "origin/master..HEAD", "--pretty=format:%s"])
.output()?;
Expand All @@ -119,7 +176,7 @@ fn get_commit_messages() -> Result<Vec<String>, std::io::Error> {
Ok(messages)
}

fn get_linked_issue() -> Result<Option<String>, Box<dyn std::error::Error>> {
fn get_linked_issue() -> Result<Option<String>, anyhow::Error> {
// This function would need to be implemented to fetch the linked issue from GitHub
// It might involve parsing commit messages or branch names for issue numbers
// and then querying the GitHub API
Expand All @@ -131,7 +188,7 @@ async fn generate_pr_description(
commit_messages: &[String],
issue: Option<String>,
anthropic_key: String,
) -> Result<String, Box<dyn std::error::Error>> {
) -> Result<String, anyhow::Error> {
dotenv().ok();
// let client = ApiClient::new()?;
let prompt = format!(
Expand Down Expand Up @@ -160,10 +217,7 @@ async fn generate_pr_description(
Ok(chat.completion)
}

async fn create_pull_request(
description: &str,
_github_token: String,
) -> Result<(), Box<dyn std::error::Error>> {
fn create_pull_request(description: &str) -> Result<(), anyhow::Error> {
Command::new("gh")
.args(&[
"pr",
Expand Down