Skip to content

Commit

Permalink
[Fix] Fix potential array overrun risk (#872)
Browse files Browse the repository at this point in the history
  • Loading branch information
kv-chiu committed Jul 2, 2023
1 parent 61d6af6 commit 9ea1aee
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions mmrotate/datasets/dota.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import tempfile
import time
import warnings
import zipfile
from collections import defaultdict
from functools import partial
Expand Down Expand Up @@ -218,17 +219,37 @@ def merge_det(self, results, nproc=4):
Args:
results (list): Testing results of the dataset.
nproc (int): number of process. Default: 4.
Returns:
list: merged results.
"""

def extract_xy(img_id):
"""Extract x and y coordinates from image ID.
Args:
img_id (str): ID of the image.
Returns:
Tuple of two integers, the x and y coordinates.
"""
pattern = re.compile(r'__(\d+)___(\d+)')
match = pattern.search(img_id)
if match:
x, y = int(match.group(1)), int(match.group(2))
return x, y
else:
warnings.warn(
"Can't find coordinates in filename, "
'the coordinates will be set to (0,0) by default.',
category=Warning)
return 0, 0

collector = defaultdict(list)
for idx in range(len(self)):
for idx, img_id in enumerate(self.img_ids):
result = results[idx]
img_id = self.img_ids[idx]
splitname = img_id.split('__')
oriname = splitname[0]
pattern1 = re.compile(r'__\d+___\d+')
x_y = re.findall(pattern1, img_id)
x_y_2 = re.findall(r'\d+', x_y[0])
x, y = int(x_y_2[0]), int(x_y_2[1])
oriname = img_id.split('__', maxsplit=1)[0]
x, y = extract_xy(img_id)
new_result = []
for i, dets in enumerate(result):
bboxes, scores = dets[:, :-1], dets[:, [-1]]
Expand All @@ -238,20 +259,20 @@ def merge_det(self, results, nproc=4):
labels = np.zeros((bboxes.shape[0], 1)) + i
new_result.append(
np.concatenate([labels, ori_bboxes, scores], axis=1))

new_result = np.concatenate(new_result, axis=0)
collector[oriname].append(new_result)

merge_func = partial(_merge_func, CLASSES=self.CLASSES, iou_thr=0.1)
if nproc <= 1:
print('Single processing')
print('Executing on Single Processor')
merged_results = mmcv.track_iter_progress(
(map(merge_func, collector.items()), len(collector)))
else:
print('Multiple processing')
print(f'Executing on {nproc} processors')
merged_results = mmcv.track_parallel_progress(
merge_func, list(collector.items()), nproc)

# Return a zipped list of merged results
return zip(*merged_results)

def _results2submission(self, id_list, dets_list, out_folder=None):
Expand Down

0 comments on commit 9ea1aee

Please sign in to comment.