Add table-transformer-detection ONNXRT example (#1314)
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
yuwenzho committed Oct 17, 2023
1 parent 2344905 commit 550cee2
Showing 8 changed files with 183 additions and 75 deletions.
15 changes: 11 additions & 4 deletions examples/.config/model_params_onnxrt.json
@@ -868,11 +868,18 @@
"main_script": "main.py",
"batch_size": 1
},
"table_transformer": {
"table_transformer_structure_recognition": {
"model_src_dir": "object_detection/table_transformer/quantization/ptq_static",
"dataset_location": "/tf_dataset/dataset/PubTables-1M-Structure",
"input_model": "/tf_dataset2/models/onnx/table-transformer/model.onnx",
"main_script": "table-transformer/src/main.py",
"dataset_location": "/tf_dataset/dataset/PubTables-1M",
"input_model": "/tf_dataset2/models/onnx/table-transformer/pubtables1m_structure_detr_r18.onnx",
"main_script": "patch",
"batch_size": 1
},
"table_transformer_detection": {
"model_src_dir": "object_detection/table_transformer/quantization/ptq_static",
"dataset_location": "/tf_dataset/dataset/PubTables-1M",
"input_model": "/tf_dataset2/models/onnx/table-transformer/pubtables1m_detection_detr_r18.onnx",
"main_script": "patch",
"batch_size": 1
},
"hf_codebert": {
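The two Table Transformer entries in this hunk share `model_src_dir` and `dataset_location` and differ only in `input_model`. As a rough illustration of how a test harness might read such an entry, here is a minimal sketch. The helper name `get_model_entry`, the set of required keys, and the flat name-to-entry mapping are assumptions made for the example; the diff does not show how the file is nested or consumed.

```python
import json

# Keys carried by every entry in the hunk above; treating them as mandatory is an assumption.
REQUIRED_KEYS = {"model_src_dir", "dataset_location", "input_model", "main_script", "batch_size"}


def get_model_entry(entries, name):
    """Return the config entry for `name`, failing early if an expected key is absent."""
    if name not in entries:
        raise KeyError(f"unknown model: {name}")
    entry = entries[name]
    missing = REQUIRED_KEYS - set(entry)
    if missing:
        raise ValueError(f"{name} is missing keys: {sorted(missing)}")
    return entry


# Values copied from the diff; the flat top-level layout is hypothetical.
SAMPLE = json.loads("""
{
  "table_transformer_structure_recognition": {
    "model_src_dir": "object_detection/table_transformer/quantization/ptq_static",
    "dataset_location": "/tf_dataset/dataset/PubTables-1M",
    "input_model": "/tf_dataset2/models/onnx/table-transformer/pubtables1m_structure_detr_r18.onnx",
    "main_script": "patch",
    "batch_size": 1
  },
  "table_transformer_detection": {
    "model_src_dir": "object_detection/table_transformer/quantization/ptq_static",
    "dataset_location": "/tf_dataset/dataset/PubTables-1M",
    "input_model": "/tf_dataset2/models/onnx/table-transformer/pubtables1m_detection_detr_r18.onnx",
    "main_script": "patch",
    "batch_size": 1
  }
}
""")
```

Calling `get_model_entry(SAMPLE, "table_transformer")` raises `KeyError`, which reflects the rename of the old key in this commit.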
8 changes: 7 additions & 1 deletion examples/README.md
@@ -1402,7 +1402,13 @@ Intel® Neural Compressor validated examples with multiple compression technique
<td><a href="./onnxrt/object_detection/ssd_mobilenet_v2/quantization/ptq_static">qlinearops</a> / <a href="./onnxrt/object_detection/ssd_mobilenet_v2/quantization/ptq_static">qdq</a></td>
</tr>
<tr>
<td>Table Transformer</td>
<td>Table Transformer Structure Recognition</td>
<td>Object Detection</td>
<td>Post-Training Static Quantization</td>
<td><a href="./onnxrt/object_detection/table_transformer/quantization/ptq_static">qlinearops</a></td>
</tr>
<tr>
<td>Table Transformer Detection</td>
<td>Object Detection</td>
<td>Post-Training Static Quantization</td>
<td><a href="./onnxrt/object_detection/table_transformer/quantization/ptq_static">qlinearops</a></td>
@@ -1,7 +1,7 @@
Step-by-Step
============

This example shows how to export, quantize and evaluate the DETR R18 model for the table structure recognition task on the PubTables-1M dataset.
This example shows how to export, quantize and evaluate two [DETR](https://huggingface.co/docs/transformers/model_doc/detr) R18 models on the [PubTables-1M](https://huggingface.co/datasets/bsmock/pubtables-1m) dataset, one for table detection and one for table structure recognition, dubbed Table Transformers.

# Prerequisite

@@ -16,16 +16,20 @@ bash prepare.sh
## 2. Prepare Dataset

Download the dataset according to this [doc](https://github.com/microsoft/table-transformer/tree/main#training-and-evaluation-data).
Download the PubTables-1M dataset according to this [doc](https://github.com/microsoft/table-transformer/tree/main#training-and-evaluation-data).
After downloading and extracting, the PubTables-1M dataset folder should contain the `PubTables-1M-Structure` and `PubTables-1M-Detection` folders.
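A quick way to confirm the layout before exporting or quantizing is to check for the expected folders. The sketch below is illustrative only: `missing_dataset_dirs` is a hypothetical helper, and the check for a `words` sub-folder in each task folder is inferred from the `--table_words_dir` argument that the run scripts pass.

```python
from pathlib import Path

# Folder names come from the README text; the mapping itself is an illustrative assumption.
TASK_DIRS = {"structure": "PubTables-1M-Structure", "detection": "PubTables-1M-Detection"}


def missing_dataset_dirs(dataset_location, tasks=("structure", "detection")):
    """Return the expected directories that do not exist under dataset_location."""
    root = Path(dataset_location)
    missing = []
    for task in tasks:
        task_dir = root / TASK_DIRS[task]
        # Each task folder is expected to hold a 'words' sub-folder as well.
        for path in (task_dir, task_dir / "words"):
            if not path.is_dir():
                missing.append(str(path))
    return missing
```

For example, `missing_dataset_dirs("/path/to/pubtables-1m", tasks=("detection",))` returns an empty list when the detection data is in place.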

## 3. Prepare Model

```shell
wget https://huggingface.co/bsmock/tatr-pubtables1m-v1.0/resolve/main/pubtables1m_structure_detr_r18.pth
Prepare DETR R18 model for table structure recognition.

bash export.sh --input_model=/path/to/pubtables1m_structure_detr_r18.pth \
--output_model=/path/to/export \ # model path as *.onnx
--dataset_location=/path/to/dataset_folder # dataset_folder should contain the 'words' sub-folder
```
python prepare_model.py --input_model=structure_detr --output_model=pubtables1m_structure_detr_r18.onnx --dataset_location=/path/to/pubtables-1m
```

Prepare DETR R18 model for table detection.
```
python prepare_model.py --input_model=detection_detr --output_model=pubtables1m_detection_detr_r18.onnx --dataset_location=/path/to/pubtables-1m
```

# Run
@@ -35,15 +39,15 @@ bash export.sh --input_model=/path/to/pubtables1m_structure_detr_r18.pth \
Static quantization with QOperator format:

```bash
bash run_tuning.sh --input_model=path/to/model \ # model path as *.onnx
bash run_quant.sh --input_model=path/to/model \ # model path as *.onnx
--output_model=path/to/save \ # model path as *.onnx
--dataset_location=/path/to/dataset_folder # dataset_folder should contain the 'words' sub-folder
--dataset_location=/path/to/pubtables-1m # dataset folder should contain the `PubTables-1M-Structure` and/or `PubTables-1M-Detection` folders
```

## 2. Benchmark

```bash
bash run_benchmark.sh --input_model=path/to/model \ # model path as *.onnx
--dataset_location=/path/to/dataset_folder # dataset_folder should contain the 'words' sub-folder
--dataset_location=/path/to/pubtables-1m # dataset folder should contain the `PubTables-1M-Structure` and/or `PubTables-1M-Detection` folders
--mode=performance # or accuracy
```

This file was deleted.

@@ -110,7 +110,7 @@ index 73ae39e..2049449 100644
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
diff --git a/src/eval.py b/src/eval.py
index e3a0565..5514db5 100644
index e3a0565..d66b318 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -4,6 +4,7 @@ Copyright (C) 2021 Microsoft Corporation
@@ -152,8 +152,14 @@ index e3a0565..5514db5 100644

if args.debug:
for target, pred_logits, pred_boxes in zip(targets, outputs['pred_logits'], outputs['pred_boxes']):
@@ -696,3 +703,4 @@ def eval_coco(args, model, criterion, postprocessors, data_loader_test, dataset_
print("COCO metrics summary: AP50: {:.3f}, AP75: {:.3f}, AP: {:.3f}, AR: {:.3f}".format(
pubmed_stats['coco_eval_bbox'][1], pubmed_stats['coco_eval_bbox'][2],
pubmed_stats['coco_eval_bbox'][0], pubmed_stats['coco_eval_bbox'][8]))
+ return pubmed_stats['coco_eval_bbox'][0]
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index 74cd13c..1e5e5e9 100644
index 74cd13c..c30377d 100644
--- a/src/main.py
+++ b/src/main.py
@@ -41,6 +41,7 @@ def get_args():
@@ -209,7 +215,7 @@ index 74cd13c..1e5e5e9 100644

dataset_test = PDFTablesDataset(os.path.join(args.data_root_dir,
"test"),
@@ -169,6 +180,29 @@ def get_data(args):
@@ -169,6 +180,28 @@ def get_data(args):
num_workers=args.num_workers)
return data_loader_test, dataset_test

@@ -234,12 +240,11 @@ index 74cd13c..1e5e5e9 100644
+ collate_fn=utils.collate_fn,
+ num_workers=args.num_workers)
+ return OXDataloader(data_loader_test, args.batch_size), dataset_test
+
+
elif args.mode == "grits" or args.mode == "grits-all":
dataset_test = PDFTablesDataset(os.path.join(args.data_root_dir,
"test"),
@@ -337,6 +371,20 @@ def train(args, model, criterion, postprocessors, device):
@@ -337,6 +370,20 @@ def train(args, model, criterion, postprocessors, device):

print('Total training time: ', datetime.now() - start_time)

@@ -260,7 +265,7 @@ index 74cd13c..1e5e5e9 100644

def main():
cmd_args = get_args().__dict__
@@ -350,7 +398,7 @@ def main():
@@ -350,7 +397,7 @@ def main():
print('-' * 100)

# Check for debug mode
@@ -269,7 +274,7 @@ index 74cd13c..1e5e5e9 100644
print("Running evaluation/inference in DEBUG mode, processing will take longer. Saving output to: {}.".format(args.debug_save_dir))
os.makedirs(args.debug_save_dir, exist_ok=True)

@@ -366,10 +414,33 @@ def main():
@@ -366,10 +413,35 @@ def main():

if args.mode == "train":
train(args, model, criterion, postprocessors, device)
@@ -278,7 +283,9 @@ index 74cd13c..1e5e5e9 100644
data_loader_test, dataset_test = get_data(args)
- eval_coco(args, model, criterion, postprocessors, data_loader_test, dataset_test, device)
-
+ eval_coco(args, args.input_onnx_model, criterion, postprocessors, data_loader_test, dataset_test, device)
+ ap_result = eval_coco(args, args.input_onnx_model, criterion, postprocessors, data_loader_test, dataset_test, device)
+ print("Batch size = %d" % args.batch_size)
+ print("Accuracy: %.5f" % ap_result)
+ elif args.mode == "export":
+ data_loader_test, dataset_test = get_data(args)
+ export(args, model, data_loader_test, device)
@@ -303,6 +310,6 @@ index 74cd13c..1e5e5e9 100644
+ from neural_compressor.config import BenchmarkConfig
+ config = BenchmarkConfig(warmup=10, iteration=100, cores_per_instance=4, num_of_instance=1)
+ fit(args.input_onnx_model, config, b_dataloader=data_loader_test)

if __name__ == "__main__":
main()
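With this patch, evaluation mode prints `Batch size = %d` and `Accuracy: %.5f` after the COCO summary. A wrapper that needs those numbers could parse them from captured output along the following lines. This is a hedged sketch: `parse_eval_log` and the sample log values are hypothetical, and only the two print formats come from the patch.

```python
import re

# Patterns mirror the two print() formats added to main.py in this patch.
BATCH_RE = re.compile(r"^Batch size = (\d+)$", re.MULTILINE)
ACC_RE = re.compile(r"^Accuracy: (-?\d+\.\d+)$", re.MULTILINE)


def parse_eval_log(text):
    """Extract (batch_size, accuracy) from captured evaluation output."""
    batch = BATCH_RE.search(text)
    acc = ACC_RE.search(text)
    if batch is None or acc is None:
        raise ValueError("log does not contain both 'Batch size' and 'Accuracy' lines")
    return int(batch.group(1)), float(acc.group(1))


# Made-up log text for illustration; the numbers are not real results.
SAMPLE_LOG = "COCO metrics summary: ...\nBatch size = 1\nAccuracy: 0.50000\n"
```

The accuracy value is the first entry of `coco_eval_bbox`, which is what the patched `eval_coco` returns.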
@@ -0,0 +1,100 @@
import argparse
import os
import subprocess
import sys
from urllib import request

MODEL_URLS = {"structure_detr": "https://huggingface.co/bsmock/tatr-pubtables1m-v1.0/resolve/main/pubtables1m_structure_detr_r18.pth",
              "detection_detr": "https://huggingface.co/bsmock/tatr-pubtables1m-v1.0/resolve/main/pubtables1m_detection_detr_r18.pth"}
MAX_TIMES_RETRY_DOWNLOAD = 5


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_model",
                        type=str,
                        required=False,
                        choices=["structure_detr", "detection_detr"],
                        default="structure_detr")
    parser.add_argument("--output_model", type=str, required=True)
    parser.add_argument("--dataset_location", type=str, required=True)
    return parser.parse_args()


def progressbar(cur, total=100):
    # Draw a 100-character bar followed by the percentage.
    percent = '{:.2%}'.format(cur / total)
    sys.stdout.write("\r[%-100s] %s" % ('#' * int(cur), percent))
    sys.stdout.flush()


def schedule(blocknum, blocksize, totalsize):
    # Report hook for urlretrieve: convert transferred blocks into a percentage.
    if totalsize == 0:
        percent = 0
    else:
        percent = min(1.0, blocknum * blocksize / totalsize) * 100
    progressbar(percent)


def download_model(url, retry_times=5):
    model_name = url.split("/")[-1]
    if os.path.isfile(model_name):
        print(f"{model_name} exists, skip download")
        return True

    print("download model...")
    retries = 0
    while retries < retry_times:
        try:
            request.urlretrieve(url, model_name, schedule)
            break
        except KeyboardInterrupt:
            return False
        except Exception:
            retries += 1
            print(f"Download failed{', Retry downloading...' if retries < retry_times else '!'}")
    return retries < retry_times


def export_model(input_model, output_model, dataset_location):
    print("\nexport model...")

    if not os.path.exists("./table-transformer"):
        subprocess.run("bash prepare.sh", shell=True)

    # The checkpoint was downloaded into the current directory under its original file name.
    model_load_path = os.path.abspath(MODEL_URLS[input_model].split("/")[-1])
    output_model = os.path.join(os.path.dirname(model_load_path), output_model)
    if input_model == "detection_detr":
        data_root_dir = os.path.join(dataset_location, "PubTables-1M-Detection")
        data_type = "detection"
        config_file = "detection_config.json"
    elif input_model == "structure_detr":
        data_root_dir = os.path.join(dataset_location, "PubTables-1M-Structure")
        data_type = "structure"
        config_file = "structure_config.json"
    table_words_dir = os.path.join(data_root_dir, "words")

    os.chdir("table-transformer/src")

    command = f"python main.py \
              --model_load_path {model_load_path} \
              --output_model {output_model} \
              --data_root_dir {data_root_dir} \
              --table_words_dir {table_words_dir} \
              --mode export \
              --data_type {data_type} \
              --device cpu \
              --config_file {config_file}"

    subprocess.run(command, shell=True)
    assert os.path.exists(output_model), f"Export failed! {output_model} doesn't exist!"


def prepare_model(input_model, output_model, dataset_location):
    # Use the function argument rather than the module-level `args`.
    is_download_successful = download_model(MODEL_URLS[input_model], MAX_TIMES_RETRY_DOWNLOAD)
    if is_download_successful:
        export_model(input_model, output_model, dataset_location)


if __name__ == "__main__":
    args = parse_arguments()
    prepare_model(args.input_model, args.output_model, args.dataset_location)
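The retry loop in `download_model` can be isolated into a small helper that is easy to test without network access. The following is a sketch under stated assumptions, not part of the commit: `retry` is a hypothetical name, and unlike the script it lets `KeyboardInterrupt` propagate instead of returning `False`.

```python
def retry(action, retry_times=5):
    """Call action() until it succeeds or retry_times attempts have failed.

    Returns True on success and False once every attempt has raised.
    """
    for attempt in range(1, retry_times + 1):
        try:
            action()
            return True
        except Exception as exc:
            # KeyboardInterrupt does not derive from Exception, so Ctrl-C still aborts.
            is_last = attempt == retry_times
            print(f"Attempt {attempt} failed ({exc}){'!' if is_last else ', retrying...'}")
    return False
```

A download could then be wrapped as `retry(lambda: request.urlretrieve(url, model_name, schedule), MAX_TIMES_RETRY_DOWNLOAD)`, reusing the names from the script.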
@@ -35,15 +35,28 @@ function run_benchmark {
bash prepare.sh
fi

if [[ "${input_model}" =~ "structure" ]]; then
task_data_dir="PubTables-1M-Structure"
data_type="structure"
config_file="structure_config.json"
fi
if [[ "${input_model}" =~ "detection" ]]; then
task_data_dir="PubTables-1M-Detection"
data_type="detection"
config_file="detection_config.json"
fi

input_model=$(realpath "$input_model")

cd table-transformer/src
python main.py \
--input_onnx_model ${input_model} \
--data_root_dir ${dataset_location} \
--table_words_dir ${dataset_location}/words \
--data_root_dir "${dataset_location}/${task_data_dir}" \
--table_words_dir "${dataset_location}/${task_data_dir}/words" \
--mode ${mode} \
--data_type structure \
--data_type ${data_type} \
--device cpu \
--config_file structure_config.json
--config_file ${config_file}
}

main "$@"
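The script picks the task from whether the model path contains `structure` or `detection`. If neither matches, `task_data_dir`, `data_type` and `config_file` stay unset. If both match, the later `detection` branch wins. The same mapping can be sketched in Python with an explicit error for the ambiguous cases. This is an illustration only: `resolve_task` is a hypothetical helper, and matching on the file name instead of the full path is a deliberate deviation from the script.

```python
import os

# Task name -> (dataset sub-folder, config file), as used by the shell scripts.
TASKS = {
    "structure": ("PubTables-1M-Structure", "structure_config.json"),
    "detection": ("PubTables-1M-Detection", "detection_config.json"),
}


def resolve_task(input_model, dataset_location):
    """Map a model path to (data_type, data_root_dir, table_words_dir, config_file)."""
    name = os.path.basename(input_model)
    matches = [task for task in TASKS if task in name]
    if len(matches) != 1:
        # Zero or two matches: refuse to guess, where the script would proceed silently.
        raise ValueError(f"cannot infer task from model name: {name!r}")
    data_type = matches[0]
    task_dir, config_file = TASKS[data_type]
    data_root_dir = os.path.join(dataset_location, task_dir)
    table_words_dir = os.path.join(data_root_dir, "words")
    return data_type, data_root_dir, table_words_dir, config_file
```

Matching on the base name avoids a false hit when a parent directory happens to contain one of the two keywords.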
@@ -35,16 +35,30 @@ function run_tuning {
bash prepare.sh
fi

if [[ "${input_model}" =~ "structure" ]]; then
task_data_dir="PubTables-1M-Structure"
data_type="structure"
config_file="structure_config.json"
fi
if [[ "${input_model}" =~ "detection" ]]; then
task_data_dir="PubTables-1M-Detection"
data_type="detection"
config_file="detection_config.json"
fi

input_model=$(realpath "$input_model")
output_model=$(realpath "$output_model")

cd table-transformer/src
python main.py \
--input_onnx_model ${input_model} \
--output_model ${output_model} \
--data_root_dir ${dataset_location} \
--table_words_dir ${dataset_location}/words \
--data_root_dir "${dataset_location}/${task_data_dir}" \
--table_words_dir "${dataset_location}/${task_data_dir}/words" \
--mode quantize \
--data_type structure \
--data_type ${data_type} \
--device cpu \
--config_file structure_config.json
--config_file ${config_file}
}

main "$@"
