diff --git a/src/cli.rs b/src/cli.rs index 4ad713b..a1ee433 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -6,4 +6,7 @@ use clap::Parser; pub struct Opts { #[clap(default_value = "./")] pub path: String, + + #[clap(short, long)] + pub token: bool, } diff --git a/src/fs.rs b/src/fs.rs index 2f5491a..e454296 100644 --- a/src/fs.rs +++ b/src/fs.rs @@ -2,13 +2,26 @@ use std::fs; use std::io::Read; use std::path::Path; -pub fn read_directory_contents(dir: &str) -> anyhow::Result { +pub fn read_directory_contents(dir: &Path) -> anyhow::Result { let mut combined_content = String::new(); - let base_path = Path::new(dir); - read_directory_contents_recursive(base_path, base_path, &mut combined_content)?; + + if dir.is_file() { + read_file(dir, &mut combined_content)?; + } else { + read_directory_contents_recursive(dir, dir, &mut combined_content)?; + } + Ok(combined_content) } +pub fn read_file(path: &Path, content: &mut String) -> anyhow::Result<()> { + let file_content = std::fs::read_to_string(path)?; + + *content = file_content; + + Ok(()) +} + fn read_directory_contents_recursive( base_path: &Path, current_path: &Path, diff --git a/src/run.rs b/src/run.rs index aa8ca73..e37cf0c 100644 --- a/src/run.rs +++ b/src/run.rs @@ -1,13 +1,21 @@ use clap::Parser; +use std::path::Path; -use crate::{cli, clip::copy_to_clipboard, fs::read_directory_contents}; +use crate::{cli, clip::copy_to_clipboard, fs::read_directory_contents, tiktoken::count_tokens}; pub fn run() -> anyhow::Result<()> { let opts = cli::Opts::parse(); - let path = opts.path; + let (path, token) = (opts.path, opts.token); + + let path = Path::new(&path); + let contents = read_directory_contents(path)?; - let contents = read_directory_contents(&path)?; copy_to_clipboard(&contents)?; + if token { + let tokens = count_tokens(&contents)?; + println!("{} GPT-4 tokens.", tokens); + } + Ok(()) } diff --git a/src/tiktoken.rs b/src/tiktoken.rs index b2476e5..881a67c 100644 --- a/src/tiktoken.rs +++ b/src/tiktoken.rs @@ -1,7 +1,5 @@ use tiktoken_rs::o200k_base; -// TODO: remove this dead code -#[allow(dead_code)] pub fn count_tokens(string: &str) -> anyhow::Result { let bpe = o200k_base()?; let tokens = bpe.encode_with_special_tokens(string);