Skip to content

Commit

Permalink
chore: prepare for AI querygen
Browse files Browse the repository at this point in the history
  • Loading branch information
morgante committed Apr 3, 2024
1 parent ea47471 commit 7212f3e
Show file tree
Hide file tree
Showing 22 changed files with 215 additions and 42 deletions.
1 change: 1 addition & 0 deletions crates/cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ grit_tracing = [
"marzano-core/grit_tracing",
]
external_functions = ["marzano-core/external_functions"]
ai_querygen = ["dep:ai_builtins"]
ai_builtins = ["dep:ai_builtins"]
embeddings = ["marzano-core/embeddings", "ai_builtins/embeddings"]
docgen = ["dep:clap-markdown"]
Expand Down
23 changes: 21 additions & 2 deletions crates/cli/src/commands/apply_pattern.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use anyhow::{bail, Result};
use clap::Args;

use dialoguer::Confirm;

use tracing::instrument;
Expand Down Expand Up @@ -117,6 +118,9 @@ pub struct ApplyPatternArgs {
/// Clear cache before running apply
#[clap(long = "refresh-cache", conflicts_with = "cache")]
pub refresh_cache: bool,
/// Interpret the request as a natural language request
#[clap(long)]
ai: bool,
#[clap(long = "language", alias = "lang")]
pub language: Option<PatternLanguage>,
}
Expand All @@ -136,6 +140,7 @@ impl Default for ApplyPatternArgs {
output_file: Default::default(),
cache: Default::default(),
refresh_cache: Default::default(),
ai: Default::default(),
language: Default::default(),
}
}
Expand All @@ -154,9 +159,9 @@ macro_rules! flushable_unwrap {
}

#[instrument]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_arguments, unused_mut)]
pub(crate) async fn run_apply_pattern(
pattern: String,
mut pattern: String,
paths: Vec<PathBuf>,
arg: ApplyPatternArgs,
multi: MultiProgress,
Expand All @@ -171,13 +176,27 @@ pub(crate) async fn run_apply_pattern(
.unwrap()
.get_context()
.unwrap();

if arg.ignore_limit {
context.ignore_limit_pattern = true;
}

let interactive = arg.interactive;
let min_level = &arg.visibility;

#[cfg(feature = "ai_querygen")]
if arg.ai {
log::info!("{}", style("Computing query...").bold());

pattern = ai_builtins::querygen::compute_pattern(&pattern, &context).await?;
log::info!("{}", style(&pattern).dim());
log::info!("{}", style("Executing query...").bold());
}
#[cfg(not(feature = "ai_querygen"))]
if arg.ai {
bail!("Natural language processing is not enabled in this build");
}

// Get the current directory
let cwd = std::env::current_dir().unwrap();

Expand Down
2 changes: 1 addition & 1 deletion crates/cli/src/commands/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ pub async fn run_command() -> Result<()> {
Commands::Parse(arg) => run_parse(arg, app.format_flags, None).await,
Commands::Patterns(arg) => match arg.patterns_commands {
PatternCommands::List(arg) => run_patterns_list(arg, app.format_flags).await,
PatternCommands::Test(arg) => run_patterns_test(arg).await,
PatternCommands::Test(arg) => run_patterns_test(arg, app.format_flags).await,
PatternCommands::Edit(arg) => run_patterns_edit(arg).await,
PatternCommands::Describe(arg) => run_patterns_describe(arg).await,
},
Expand Down
109 changes: 88 additions & 21 deletions crates/cli/src/commands/patterns_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ use marzano_gritmodule::testing::{
};

use marzano_language::target_language::PatternLanguage;

use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
use serde::Serialize;

use crate::flags::{GlobalFormatFlags, OutputFormat};
use crate::resolver::{get_grit_files_from_cwd, resolve_from_cwd, GritModuleResolver, Source};
use crate::result_formatting::FormattedResult;
use crate::updater::Updater;
Expand All @@ -29,6 +32,7 @@ pub async fn get_marzano_pattern_test_results(
patterns: Vec<GritPatternTestInfo>,
libs: &PatternsDirectory,
args: PatternsTestArgs,
output: OutputFormat,
) -> Result<()> {
let cwd = std::env::current_dir()?;
let resolver = GritModuleResolver::new(cwd.to_str().unwrap());
Expand All @@ -38,14 +42,14 @@ pub async fn get_marzano_pattern_test_results(

let runtime = Updater::from_current_bin().await?.get_context()?;

patterns
let test_reports = patterns
.par_iter()
.enumerate()
.map(|(index, pattern)| {
let lang = PatternLanguage::get_language(&pattern.body);
let chosen_lang = lang.unwrap_or_default();
if let PatternLanguage::Universal = chosen_lang {
return Ok(());
return Ok(None);
}
let libs = libs.get_language_directory_or_default(lang)?;
let rich_pattern = resolver
Expand Down Expand Up @@ -95,12 +99,20 @@ pub async fn get_marzano_pattern_test_results(
}
final_results.insert(pattern_name, results);
}
Ok(())
Ok(None)
}
Err(e) => {
if output == OutputFormat::Json {
let report = TestReport {
outcome: TestOutcome::CompilationFailure,
message: Some(e.to_string()),
samples: vec![],
};
return Ok(Some(report));
}
// TODO: this is super hacky, replace with thiserror! codes
if e.to_string().contains("No pattern found") {
Ok(())
Ok(None)
} else {
Err(anyhow!(format!(
"Failed to compile pattern {}: {}",
Expand All @@ -113,6 +125,9 @@ pub async fn get_marzano_pattern_test_results(
})
.collect::<Result<Vec<_>>>()?;

// Filter out the None values
let mut test_report = test_reports.into_iter().flatten().collect::<Vec<_>>();

// Now let's attempt formatting the results that need it
for (lang, lang_results) in unformatted_results.into_iter() {
let formatted_expected = format_rich_files(
Expand Down Expand Up @@ -146,7 +161,7 @@ pub async fn get_marzano_pattern_test_results(
Some(MismatchInfo::Content(outcome) | MismatchInfo::Path(outcome)) => {
SampleTestResult {
matches: wrapped.result.matches.clone(),
state: GritTestResultState::Fail,
state: GritTestResultState::FailedOutput,
message: Some(
"Actual output doesn't match expected output, even after formatting"
.to_string(),
Expand Down Expand Up @@ -177,25 +192,61 @@ pub async fn get_marzano_pattern_test_results(
let final_results = final_results.into_read_only();
log_test_results(&final_results, args.verbose)?;
let total = final_results.values().flatten().count();
if final_results
.values()
.any(|v| v.iter().any(|r| !r.result.is_pass()))
{
bail!(
"{} out of {} samples failed.",
final_results
match output {
OutputFormat::Standard => {
if final_results
.values()
.flatten()
.filter(|r| !r.result.is_pass())
.count(),
total
)
};
info!("✓ All {} samples passed.", total);
.any(|v| v.iter().any(|r| !r.result.is_pass()))
{
bail!(
"{} out of {} samples failed.",
final_results
.values()
.flatten()
.filter(|r| !r.result.is_pass())
.count(),
total
)
};
info!("✓ All {} samples passed.", total);
}
OutputFormat::Json => {
// Collect the test reports
let mut sample_results = final_results
.values()
.map(|r| {
let all_pass = r.iter().all(|r| r.result.is_pass());
TestReport {
outcome: if all_pass {
TestOutcome::Success
} else {
TestOutcome::SampleFailure
},
message: if all_pass {
None
} else {
Some("One or more samples failed".to_string())
},
samples: r.iter().map(|r| r.result.clone()).collect(),
}
})
.collect::<Vec<_>>();
test_report.append(&mut sample_results);

println!("{}", serde_json::to_string(&test_report)?);
}
_ => {
bail!("Output format not supported for this command");
}
}

Ok(())
}

pub(crate) async fn run_patterns_test(arg: PatternsTestArgs) -> Result<()> {
pub(crate) async fn run_patterns_test(
arg: PatternsTestArgs,
flags: GlobalFormatFlags,
) -> Result<()> {
let (mut patterns, _) = resolve_from_cwd(&Source::Local).await?;
let libs = get_grit_files_from_cwd().await?;

Expand Down Expand Up @@ -226,7 +277,23 @@ pub(crate) async fn run_patterns_test(arg: PatternsTestArgs) -> Result<()> {
bail!("No testable patterns found. To test a pattern, make sure it is defined in .grit/grit.yaml or a .md file in your .grit/patterns directory.");
}
info!("Found {} testable patterns.", testable_patterns.len(),);
get_marzano_pattern_test_results(testable_patterns, &libs, arg).await
get_marzano_pattern_test_results(testable_patterns, &libs, arg, flags.into()).await
}

#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
enum TestOutcome {
Success,
CompilationFailure,
SampleFailure,
}

#[derive(Debug, Serialize)]
struct TestReport {
outcome: TestOutcome,
message: Option<String>,
/// Sample test details
samples: Vec<SampleTestResult>,
}

#[derive(Debug)]
Expand Down
4 changes: 3 additions & 1 deletion crates/cli/src/commands/plumbing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ pub(crate) async fn run_plumbing(
)
})?;

let libs = get_grit_files_from(None).await?;
let cwd = std::env::current_dir()?;
let libs = get_grit_files_from(Some(cwd)).await?;
get_marzano_pattern_test_results(
patterns,
&libs,
Expand All @@ -266,6 +267,7 @@ pub(crate) async fn run_plumbing(
filter: None,
exclude: vec![],
},
parent.into(),
)
.await
}
Expand Down
7 changes: 7 additions & 0 deletions crates/cli_bin/fixtures/plumbing/test_comp.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[{
"body": "`foo`d => `bar`",
"config": {
"samples": []
}
}
]
12 changes: 12 additions & 0 deletions crates/cli_bin/fixtures/plumbing/test_no_match.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[{
"body": "`food` => `bar`",
"config": {
"samples": [
{
"input": "foo",
"output": "barf"
}
]
}
}
]
12 changes: 12 additions & 0 deletions crates/cli_bin/fixtures/plumbing/test_pattern_failure.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[{
"body": "`foo` => `bar`",
"config": {
"samples": [
{
"input": "foo",
"output": "barf"
}
]
}
}
]
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ expression: v
localName: HiddenUtility
body: "pattern HiddenUtility($oneArg, $twoArg) {\n `$foo.hidden` => cool_thing_wrapper(input=`$foo.$oneArg.$twoArg`)\n}"
kind: pattern
language: js
visibility: public
- config:
name: NamedPattern
Expand All @@ -47,6 +48,7 @@ expression: v
localName: NamedPattern
body: "pattern NamedPattern() {\n `$foo.named` => `$foo.replacement`\n}"
kind: pattern
language: js
visibility: public
- config:
name: OtherPattern
Expand All @@ -70,6 +72,7 @@ expression: v
localName: OtherPattern
body: "pattern OtherPattern() {\n `$foo.other` => `$foo.replacement`\n}"
kind: pattern
language: js
visibility: public
- config:
name: PatternWithArgs
Expand All @@ -93,6 +96,7 @@ expression: v
localName: PatternWithArgs
body: "pattern PatternWithArgs($arg) {\n `$foo.$arg` => `$foo.replacement`\n}"
kind: pattern
language: js
visibility: public
- config:
name: broken_pattern
Expand Down Expand Up @@ -216,6 +220,7 @@ expression: v
localName: broken_pattern
body: "engine marzano(0.1)\nlanguage js\n\n// We use the syntax-tree node binary_expression to capture all expressions where $a and $b are operated on by \"==\" or \"!=\".\n// This code takes advantage of Grit's allowing us to nest rewrites inside match conditions and to match syntax-tree fields on patterns.\nbinary_expression($operator, $left, $right) where {\n $operator <: or { \"==\" => `===` , \"!=\" => `!==` },\n or { $left <: `null`, $right <: `null`}\n}\n"
kind: pattern
language: js
visibility: public
- config:
name: cool_thing_wrapper
Expand All @@ -239,6 +244,7 @@ expression: v
localName: cool_thing_wrapper
body: "function cool_thing_wrapper($input) {\n return `bob.$input`\n}"
kind: function
language: js
visibility: public
- config:
name: multiple_broken_patterns
Expand Down Expand Up @@ -363,6 +369,7 @@ expression: v
localName: multiple_broken_patterns
body: "engine marzano(0.1)\nlanguage js\n\n// We use the syntax-tree node binary_expression to capture all expressions where $a and $b are operated on by \"==\" or \"!=\".\n// This code takes advantage of Grit's allowing us to nest rewrites inside match conditions and to match syntax-tree fields on patterns.\nbinary_expression($operator, $left, $right) where {\n $operator <: or { \"==\" => `===` , \"!=\" => `!==` },\n or { $left <: `null`, $right <: `null`}\n}\n"
kind: pattern
language: js
visibility: public
- config:
name: no_eq_null
Expand Down Expand Up @@ -487,6 +494,7 @@ expression: v
localName: no_eq_null
body: "engine marzano(0.1)\nlanguage js\n\n// We use the syntax-tree node binary_expression to capture all expressions where $a and $b are operated on by \"==\" or \"!=\".\n// This code takes advantage of Grit's allowing us to nest rewrites inside match conditions and to match syntax-tree fields on patterns.\nbinary_expression($operator, $left, $right) where {\n $operator <: or { \"==\" => `===` , \"!=\" => `!==` },\n or { $left <: `null`, $right <: `null`}\n}\n"
kind: pattern
language: js
visibility: public
- config:
name: remove_console_error
Expand All @@ -510,4 +518,5 @@ expression: v
localName: remove_console_error
body: "engine marzano(0.1)\nlanguage js\n\n`console.error($_)` => .\n"
kind: pattern
language: js
visibility: public
Loading

0 comments on commit 7212f3e

Please sign in to comment.