Skip to content

Commit

Permalink
include recognition model param in cli (#25)
Browse files Browse the repository at this point in the history
* include recognition model param in cli

* also update readme
  • Loading branch information
jimexist committed Feb 19, 2024
1 parent 356e3af commit f6e38f9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 20 deletions.
14 changes: 9 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ Arguments:
<IMAGE> path to image
Options:
--batch-size <BATCH_SIZE>
--detection-batch-size <DETECTION_BATCH_SIZE>
detection batch size, if not supplied defaults to 2 on CPU and 16 on GPU
--model-repo <MODEL_REPO>
detection model's hugging face repo [default: vikp/line_detector]
--detection-model-repo <DETECTION_MODEL_REPO>
detection model's hugging face repo [default: vikp/surya_det]
--weights-file-name <WEIGHTS_FILE_NAME>
detection model's weights file name [default: model.safetensors]
--config-file-name <CONFIG_FILE_NAME>
Expand All @@ -78,6 +78,12 @@ Options:
a value between 0.0 and 1.0 to filter out bbox with low heatmap density [default: 0.6]
--bbox-area-threshold <BBOX_AREA_THRESHOLD>
a pixel threshold to filter out small area bbox [default: 10]
--recognition-batch-size <RECOGNITION_BATCH_SIZE>
recognition batch size, if not supplied defaults to 8 on CPU and 256 on GPU
--recognition-model-repo <RECOGNITION_MODEL_REPO>
recognition model's hugging face repo [default: vikp/surya_rec]
--output-dir <OUTPUT_DIR>
output directory, under which the input image will be generating a subdirectory [default: ./surya_output]
--polygons
whether to output polygons json file
--image
Expand All @@ -86,8 +92,6 @@ Options:
whether to generate heatmap
--affinity-map
whether to generate affinity map
--output-dir <OUTPUT_DIR>
output directory, under which the input image will be generating a subdirectory [default: ./surya_output]
--device <DEVICE_TYPE>
device type, if not specified will try to use GPU or Metal [possible values: cpu, gpu, metal]
--verbose
Expand Down
46 changes: 31 additions & 15 deletions src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ struct Cli {
long,
help = "detection batch size, if not supplied defaults to 2 on CPU and 16 on GPU"
)]
batch_size: Option<usize>,
detection_batch_size: Option<usize>,

#[arg(
long,
default_value = "vikp/line_detector",
default_value = "vikp/surya_det",
help = "detection model's hugging face repo"
)]
model_repo: String,
detection_model_repo: String,

#[arg(
long,
Expand Down Expand Up @@ -95,6 +95,26 @@ struct Cli {
)]
bbox_area_threshold: usize,

#[arg(
long,
help = "recognition batch size, if not supplied defaults to 8 on CPU and 256 on GPU"
)]
recognition_batch_size: Option<usize>,

#[arg(
long,
default_value = "vikp/surya_rec",
help = "recognition model's hugging face repo"
)]
recognition_model_repo: String,

#[arg(
long,
default_value = "./surya_output",
help = "output directory, under which the input image will be generating a subdirectory"
)]
output_dir: PathBuf,

#[arg(
long = "polygons",
default_value_t = true,
Expand Down Expand Up @@ -123,13 +143,6 @@ struct Cli {
)]
generate_affinity_map: bool,

#[arg(
long,
default_value = "./surya_output",
help = "output directory, under which the input image will be generating a subdirectory"
)]
output_dir: PathBuf,

#[arg(
long = "device",
value_enum,
Expand All @@ -142,14 +155,17 @@ struct Cli {
}

impl Cli {
fn get_model(
fn get_detection_model(
&self,
device: &Device,
num_labels: usize,
) -> surya::Result<SemanticSegmentationModel> {
let api = ApiBuilder::new().with_progress(true).build()?;
let repo = api.model(self.model_repo.clone());
debug!("using model from HuggingFace repo {0}", self.model_repo);
let repo = api.model(self.detection_model_repo.clone());
debug!(
"using model from HuggingFace repo {0}",
self.detection_model_repo
);
let model_file = repo.get(&self.weights_file_name)?;
debug!("using weights file '{0}'", self.weights_file_name);
let vb = unsafe {
Expand Down Expand Up @@ -202,9 +218,9 @@ fn main() -> surya::Result<()> {
.create(output_dir.clone())?;
info!("generating output to {:?}", output_dir);

let model = args.get_model(&device, NUM_LABELS)?;
let model = args.get_detection_model(&device, NUM_LABELS)?;

let batch_size = args.batch_size.unwrap_or(match device {
let batch_size = args.detection_batch_size.unwrap_or(match device {
Device::Cpu => 2,
Device::Cuda(_) | Device::Metal(_) => 16,
});
Expand Down

0 comments on commit f6e38f9

Please sign in to comment.