Skip to content

Commit

Permalink
feat: init gritql python (#337)
Browse files Browse the repository at this point in the history
  • Loading branch information
morgante authored May 23, 2024
1 parent f23a8be commit c879bde
Show file tree
Hide file tree
Showing 19 changed files with 517 additions and 57 deletions.
24 changes: 17 additions & 7 deletions crates/cli/src/commands/apply_pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use marzano_messenger::{
output_mode::OutputMode,
};

use crate::resolver::{get_grit_files_from_cwd, GritModuleResolver};
use crate::resolver::{get_grit_files_from_flags_or_cwd, GritModuleResolver};
use crate::utils::has_uncommitted_changes;

use super::filters::SharedFilterArgs;
Expand Down Expand Up @@ -207,7 +207,7 @@ pub(crate) async fn run_apply_pattern(
details: &mut ApplyDetails,
pattern_libs: Option<BTreeMap<String, String>>,
default_lang: Option<PatternLanguage>,
format: &GlobalFormatFlags,
format_flags: &GlobalFormatFlags,
root_path: Option<PathBuf>,
) -> Result<()> {
let mut context = Updater::from_current_bin()
Expand All @@ -217,7 +217,7 @@ pub(crate) async fn run_apply_pattern(
.unwrap();

let format = OutputFormat::from_flags(
format,
format_flags,
if arg.stdin {
OutputFormat::Transformed
} else {
Expand Down Expand Up @@ -283,7 +283,7 @@ pub(crate) async fn run_apply_pattern(
let module_resolution = span!(tracing::Level::INFO, "module_resolution",).entered();

// Construct a resolver
let resolver = GritModuleResolver::new(cwd.to_str().unwrap());
let resolver = GritModuleResolver::new();
let current_repo_root = marzano_gritmodule::fetcher::LocalRepo::from_dir(&cwd)
.await
.map(|repo| repo.root())
Expand Down Expand Up @@ -312,7 +312,14 @@ pub(crate) async fn run_apply_pattern(
#[cfg(feature = "grit_tracing")]
let stdlib_download_span = span!(tracing::Level::INFO, "stdlib_download",).entered();

let mod_dir = find_grit_modules_dir(cwd.clone()).await;
let target_grit_dir = format_flags
.grit_dir
.as_ref()
.and_then(|c| c.parent())
.unwrap_or_else(|| &cwd)
.to_path_buf();
let mod_dir = find_grit_modules_dir(target_grit_dir.clone()).await;

if !env::var("GRIT_DOWNLOADS_DISABLED")
.unwrap_or_else(|_| "false".to_owned())
.parse::<bool>()
Expand All @@ -321,7 +328,7 @@ pub(crate) async fn run_apply_pattern(
{
flushable_unwrap!(
emitter,
init_config_from_cwd::<KeepFetcherKind>(cwd.clone(), false).await
init_config_from_cwd::<KeepFetcherKind>(target_grit_dir, false).await
);
}

Expand Down Expand Up @@ -349,7 +356,10 @@ pub(crate) async fn run_apply_pattern(
#[cfg(feature = "grit_tracing")]
let grit_file_discovery = span!(tracing::Level::INFO, "grit_file_discovery",).entered();

let pattern_libs = flushable_unwrap!(emitter, get_grit_files_from_cwd().await);
let pattern_libs = flushable_unwrap!(
emitter,
get_grit_files_from_flags_or_cwd(format_flags).await
);

let (mut lang, pattern_body) = if pattern.ends_with(".grit") || pattern.ends_with(".md") {
match fs::read_to_string(pattern.clone()).await {
Expand Down
9 changes: 6 additions & 3 deletions crates/cli/src/commands/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::{
github::{log_check_annotations, write_check_summary},
messenger_variant::create_emitter,
resolver::{
get_grit_files_from, get_grit_files_from_cwd, resolve_from, resolve_from_cwd,
get_grit_files_from, get_grit_files_from_flags_or_cwd, resolve_from, resolve_from_cwd,
GritModuleResolver, Source,
},
scan::log_check_json,
Expand Down Expand Up @@ -102,7 +102,10 @@ pub(crate) async fn run_check(
grit_files.merge(global_files);
(resolved, grit_files)
} else {
try_join![resolve_from_cwd(&Source::All), get_grit_files_from_cwd()]?
try_join![
resolve_from_cwd(&Source::All),
get_grit_files_from_flags_or_cwd(format)
]?
};

let enforced = resolved_patterns
Expand All @@ -122,7 +125,7 @@ pub(crate) async fn run_check(
let filter_range = extract_filter_ranges(&arg.shared_filters, Some(&current_dir))?;

// Construct a resolver
let resolver = GritModuleResolver::new(current_dir.to_str().unwrap());
let resolver = GritModuleResolver::new();

let mut body_to_pattern: HashMap<String, &ResolvedGritDefinition> = HashMap::new();
let compile_tasks: Result<HashMap<String, Problem>, _> = enforced
Expand Down
12 changes: 4 additions & 8 deletions crates/cli/src/commands/parse.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use crate::{
flags::GlobalFormatFlags,
jsonl::JSONLineMessenger,
resolver::{get_grit_files_from_cwd, GritModuleResolver},
};
use crate::{flags::GlobalFormatFlags, jsonl::JSONLineMessenger, resolver::GritModuleResolver};
use anyhow::{bail, Result};
use clap::Args;
use grit_util::Position;
Expand Down Expand Up @@ -85,12 +81,12 @@ pub(crate) async fn run_parse(
Ok(())
}

#[allow(deprecated)]
async fn parse_one_pattern(body: String, path: Option<&PathBuf>) -> Result<MatchResult> {
let current_dir = std::env::current_dir()?;
let resolver = GritModuleResolver::new(current_dir.to_str().unwrap());
let resolver = GritModuleResolver::new();
let lang = PatternLanguage::get_language(&body);
let pattern = resolver.make_pattern(&body, None)?;
let pattern_libs = get_grit_files_from_cwd().await?;
let pattern_libs = crate::resolver::get_grit_files_from_cwd().await?;
let pattern_libs = pattern_libs.get_language_directory_or_default(lang)?;
let problem = match pattern.compile(&pattern_libs, None, None, None) {
Ok(problem) => problem,
Expand Down
4 changes: 2 additions & 2 deletions crates/cli/src/commands/patterns_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use marzano_gritmodule::config::{DefinitionSource, ResolvedGritDefinition};
use crate::{
flags::GlobalFormatFlags,
lister::{list_applyables, Listable},
resolver::resolve_from_cwd,
resolver::{resolve_from_flags_or_cwd},
};

use super::list::ListArgs;
Expand Down Expand Up @@ -36,6 +36,6 @@ impl Listable for ResolvedGritDefinition {
}

pub(crate) async fn run_patterns_list(arg: ListArgs, parent: GlobalFormatFlags) -> Result<()> {
let (resolved, curr_repo) = resolve_from_cwd(&arg.source).await?;
let (resolved, curr_repo) = resolve_from_flags_or_cwd(&parent, &arg.source).await?;
list_applyables(false, false, resolved, arg.level, &parent, curr_repo).await
}
9 changes: 5 additions & 4 deletions crates/cli/src/commands/patterns_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIter
use serde::Serialize;

use crate::flags::{GlobalFormatFlags, OutputFormat};
use crate::resolver::{get_grit_files_from_cwd, resolve_from_cwd, GritModuleResolver, Source};
use crate::resolver::{
get_grit_files_from_flags_or_cwd, resolve_from_cwd, GritModuleResolver, Source,
};
use crate::result_formatting::FormattedResult;
use crate::updater::Updater;
use crate::ux::{indent, log_test_diff};
Expand All @@ -34,8 +36,7 @@ pub async fn get_marzano_pattern_test_results(
args: PatternsTestArgs,
output: OutputFormat,
) -> Result<()> {
let cwd = std::env::current_dir()?;
let resolver = GritModuleResolver::new(cwd.to_str().unwrap());
let resolver = GritModuleResolver::new();

let final_results: DashMap<String, Vec<WrappedResult>> = DashMap::new();
let unformatted_results: DashMap<PatternLanguage, Vec<WrappedResult>> = DashMap::new();
Expand Down Expand Up @@ -250,7 +251,7 @@ pub(crate) async fn run_patterns_test(
flags: GlobalFormatFlags,
) -> Result<()> {
let (mut patterns, _) = resolve_from_cwd(&Source::Local).await?;
let libs = get_grit_files_from_cwd().await?;
let libs = get_grit_files_from_flags_or_cwd(&flags).await?;

if arg.filter.is_some() {
let filter = arg.filter.as_ref().unwrap();
Expand Down
3 changes: 3 additions & 0 deletions crates/cli/src/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ pub struct GlobalFormatFlags {
/// Override the default log level (info)
#[arg(long, global = true)]
pub log_level: Option<log::LevelFilter>,
/// Override the default .grit directory location
#[arg(long, global = true)]
pub grit_dir: Option<std::path::PathBuf>,
}

#[derive(Debug, PartialEq, Clone)]
Expand Down
94 changes: 61 additions & 33 deletions crates/cli/src/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ use colored::Colorize;
use core::fmt;
use log::{info, warn};
use serde::Serialize;
use std::{collections::HashMap, path::PathBuf, str::FromStr};
use std::{
collections::HashMap,
path::{Path, PathBuf},
str::FromStr,
};

use anyhow::{Context, Result};
use marzano_gritmodule::{
Expand All @@ -14,7 +18,7 @@ use marzano_gritmodule::{
searcher::find_grit_dir_from,
};

use crate::updater::Updater;
use crate::{flags::GlobalFormatFlags, updater::Updater};

#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Serialize, Debug)]
#[serde(rename_all = "lowercase")]
Expand All @@ -28,8 +32,24 @@ pub enum Source {
}

// Equivalent to our PatternResolver in zesty, but more minimal
pub struct GritModuleResolver<'a> {
_root_directory: &'a str,
pub struct GritModuleResolver {}

impl GritModuleResolver {
pub fn new() -> Self {
Self {}
}

pub fn make_pattern<'b>(
&self,
pattern_input: &'b str,
name: Option<String>,
) -> Result<RichPattern<'b>> {
let pattern = RichPattern {
body: pattern_input,
name,
};
Ok(pattern)
}
}

#[derive(Debug)]
Expand All @@ -44,24 +64,16 @@ impl<'b> fmt::Display for RichPattern<'b> {
}
}

impl<'a> GritModuleResolver<'a> {
pub fn new(root_directory: &'a str) -> Self {
Self {
_root_directory: root_directory,
}
}
async fn from_known_grit_dir(config_path: &Path) -> Result<PatternsDirectory> {
let stdlib_modules = get_stdlib_modules();

pub fn make_pattern<'b>(
&self,
pattern_input: &'b str,
name: Option<String>,
) -> Result<RichPattern<'b>> {
let pattern = RichPattern {
body: pattern_input,
name,
};
Ok(pattern)
}
let grit_parent = PathBuf::from(config_path.parent().context(format!(
"Unable to find parent of .grit directory at {}",
config_path.to_string_lossy()
))?);
let parent_str = &grit_parent.to_string_lossy().to_string();
let repo = ModuleRepo::from_dir(config_path).await;
get_grit_files(&repo, parent_str, Some(stdlib_modules)).await
}

pub async fn get_grit_files_from(cwd: Option<PathBuf>) -> Result<PatternsDirectory> {
Expand All @@ -70,20 +82,12 @@ pub async fn get_grit_files_from(cwd: Option<PathBuf>) -> Result<PatternsDirecto
} else {
None
};
let stdlib_modules = get_stdlib_modules();

match existing_config {
Some(config) => {
let config_path = PathBuf::from_str(&config).unwrap();
let grit_parent = PathBuf::from(config_path.parent().context(format!(
"Unable to find parent of .grit directory at {}",
config
))?);
let parent_str = &grit_parent.to_string_lossy().to_string();
let repo = ModuleRepo::from_dir(&config_path).await;
get_grit_files(&repo, parent_str, Some(stdlib_modules)).await
}
Some(config) => from_known_grit_dir(&PathBuf::from(config)).await,
None => {
let stdlib_modules = get_stdlib_modules();

let updater = Updater::from_current_bin().await?;
let install_path = updater.install_path;
let repo = ModuleRepo::from_dir(&install_path).await;
Expand All @@ -92,12 +96,36 @@ pub async fn get_grit_files_from(cwd: Option<PathBuf>) -> Result<PatternsDirecto
}
}

#[tracing::instrument]
/// Get the grit files from the current working directory
#[deprecated = "Use get_grit_files_from_flags_or_cwd instead"]
pub async fn get_grit_files_from_cwd() -> Result<PatternsDirectory> {
let cwd = std::env::current_dir()?;
get_grit_files_from(Some(cwd)).await
}

#[tracing::instrument]
pub async fn get_grit_files_from_flags_or_cwd(
flags: &GlobalFormatFlags,
) -> Result<PatternsDirectory> {
if let Some(grit_dir) = &flags.grit_dir {
from_known_grit_dir(grit_dir).await
} else {
let cwd = std::env::current_dir()?;
get_grit_files_from(Some(cwd)).await
}
}

pub async fn resolve_from_flags_or_cwd(
flags: &GlobalFormatFlags,
source: &Source,
) -> Result<(Vec<ResolvedGritDefinition>, ModuleRepo)> {
if let Some(grit_dir) = &flags.grit_dir {
resolve_from(grit_dir.clone(), source).await
} else {
resolve_from_cwd(source).await
}
}

pub async fn resolve_from(
cwd: PathBuf,
source: &Source,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
language python

pattern special_pattern() {
`os.getenv` => `dotenv.mygoodness`
}
40 changes: 40 additions & 0 deletions crates/cli_bin/tests/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2138,6 +2138,46 @@ fn ignores_file_in_grit_dir() -> Result<()> {
Ok(())
}

#[test]
fn override_grit_modules_at_apply() -> Result<()> {
// Grab the other grit directory
let (_temp_dir, other_dir) = get_fixture("override_custom_grit_dir", true)?;

// Keep _temp_dir around so that the tempdir is not deleted
let (_temp_dir, dir) = get_fixture("simple_python", false)?;
let origin_content = std::fs::read_to_string(dir.join("main.py"))?;

// from the tempdir as cwd, run marzano apply
let mut apply_cmd = get_test_cmd()?;
apply_cmd.current_dir(dir.as_path());
apply_cmd
.arg("apply")
.arg("--force")
.arg("special_pattern")
.arg("--grit-dir")
.arg(other_dir.join(".grit"));
let output = apply_cmd.output()?;

let stdout = String::from_utf8(output.stdout)?;
println!("stdout: {:?}", stdout);
let stderr = String::from_utf8(output.stderr)?;
println!("stderr: {:?}", stderr);

// Assert that the command failed
assert!(output.status.success(),);

// Read back the main.py file
let target_file = dir.join("main.py");
let content: String = std::fs::read_to_string(target_file)?;

assert_ne!(origin_content, content);

// Make sure it now has dotenv.mygoodness
assert!(content.contains("dotenv.mygoodness"));

Ok(())
}

#[test]
fn language_option_file_pattern_apply() -> Result<()> {
// Keep _temp_dir around so that the tempdir is not deleted
Expand Down
Loading

0 comments on commit c879bde

Please sign in to comment.