
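Hello, SeqExpand in your mmtrack/datasets/pipelines/transforms.py appears to be written incorrectly: each image in the resulting sequence gets a different augmentation, so temporal information is lost when training VID algorithms. #917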

Closed
moon6666 opened this issue Nov 30, 2023 · 2 comments

Comments

@moon6666

Thanks for your error report and we appreciate it a lot.

Checklist

from mmdet.datasets.pipelines import Expand  # import path as in mmdet 2.x


class SeqExpand(Expand):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, results):
        """Call function.

        For each dict in results, call the call function of `Expand` to
        expand the image.

        Args:
            results (list[dict]): List of dicts from
                :obj:`mmtrack.CocoVideoDataset`.

        Returns:
            list[dict]: List of dicts that contain the expanded results.
        """
        outs = []
        for _results in results:
            # Each call draws its own random parameters, so every frame in
            # the sequence is expanded independently of the others.
            _results = super().__call__(_results)
            outs.append(_results)
        return outs

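Suggested fix: add a random seed to control the random function, so that every frame in a sequence draws the same augmentation parameters.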
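One way to read this suggestion, sketched with the standard-library `random` module only: snapshot the generator state before the first frame and rewind to it before each later frame, so all frames see identical draws. `toy_expand` and `seq_apply_shared` are hypothetical names made up for this illustration; they are not part of mmtrack or mmdetection, and a real transform may use a different random generator (for example NumPy's), in which case that generator's state would have to be saved and restored instead.

```python
import random


def toy_expand(frame, min_ratio=1.0, max_ratio=4.0):
    """Hypothetical stand-in for an Expand-style transform.

    `frame` is a dict with 'width' and 'height'. The function only records
    the randomly drawn canvas size and offset; it does not touch pixels.
    """
    w, h = frame['width'], frame['height']
    ratio = random.uniform(min_ratio, max_ratio)
    left = int(random.uniform(0, w * ratio - w))
    top = int(random.uniform(0, h * ratio - h))
    return dict(
        frame,
        width=int(w * ratio),
        height=int(h * ratio),
        offset=(left, top))


def seq_apply_shared(transform, frames):
    """Apply `transform` to every frame with identical random draws."""
    state = random.getstate()  # snapshot taken before the first frame
    outs = []
    for frame in frames:
        random.setstate(state)  # rewind so this frame repeats the draws
        outs.append(transform(frame))
    # The generator is left in the state reached after one transform call,
    # so the next sequence gets fresh parameters.
    return outs


frames = [{'width': 100, 'height': 50} for _ in range(3)]
outs = seq_apply_shared(toy_expand, frames)
print(outs[0]['offset'] == outs[1]['offset'] == outs[2]['offset'])
```

Rewinding the state keeps the transform itself unchanged, at the cost of relying on the transform drawing from the global generator that is being saved.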

@moon6666
Author

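Alternatively, you can add a wrapper directly in mmdetection's mmdet/datasets/dataset_wrappers.py, modeled on the one I wrote; this makes it easy to use all the data augmentations implemented in mmdetection.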
# Imports assumed for mmdet 2.x's mmdet/datasets/dataset_wrappers.py.
import collections
import copy
import random

from mmcv.utils import build_from_cfg

from .builder import DATASETS, PIPELINES


@DATASETS.register_module()
class SeqMultiImageMixDataset:

    def __init__(self,
                 dataset,
                 pipeline,
                 dynamic_scale=None,
                 skip_type_keys=None,
                 max_refetch=15):
        if dynamic_scale is not None:
            raise RuntimeError(
                'dynamic_scale is deprecated. Please use Resize pipeline '
                'to achieve similar functions')
        assert isinstance(pipeline, collections.abc.Sequence)
        if skip_type_keys is not None:
            assert all([
                isinstance(skip_type_key, str)
                for skip_type_key in skip_type_keys
            ])
        self._skip_type_keys = skip_type_keys

        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if isinstance(transform, dict):
                self.pipeline_types.append(transform['type'])
                transform = build_from_cfg(transform, PIPELINES)
                self.pipeline.append(transform)
            else:
                raise TypeError('pipeline must be a dict')

        self.dataset = dataset
        self.CLASSES = dataset.CLASSES
        self.PALETTE = getattr(dataset, 'PALETTE', None)
        if hasattr(self.dataset, 'flag'):
            self.flag = dataset.flag
        self.num_samples = len(dataset)
        self.max_refetch = max_refetch

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Transforms that take the whole sequence (list of dicts) at once.
        seq_level_types = ('VideoCollect', 'ConcatVideoReferences',
                           'SeqDefaultFormatBundle')
        # Augmentations that are randomly skipped for the whole sequence.
        skippable_types = ('Mosaic', 'RandomAffine', 'MixUp', 'Expand',
                           'PhotoMetricDistortion', 'YOLOXHSVRandomAug')

        results = copy.deepcopy(self.dataset[idx])
        for (transform, transform_type) in zip(self.pipeline,
                                               self.pipeline_types):
            if self._skip_type_keys is not None and \
                    transform_type in self._skip_type_keys:
                continue

            # 1. For pipelines that stitch multiple frames together, fetch
            #    the frames to mix.
            if hasattr(transform, 'get_indexes'):
                indexes = transform.get_indexes(self.dataset)
                if not isinstance(indexes, collections.abc.Sequence):
                    indexes = [indexes]
                # Distribute the fetched sequences, in order, across the
                # frames: frame j is mixed with frame j of each sequence.
                for j in range(len(results)):
                    mix_results = [
                        copy.deepcopy(self.dataset[index][j])
                        for index in indexes
                    ]
                    if None not in mix_results:
                        results[j]['mix_results'] = mix_results

            # 2. Apply the transform to the prepared frames.
            if transform_type in seq_level_types:
                updated_results = transform(copy.deepcopy(results))
                if updated_results is not None:
                    results = updated_results
            else:
                # Skip the listed augmentations for this sequence with
                # probability 0.2.
                p = random.uniform(0, 1)
                if transform_type in skippable_types and p > 0.8:
                    continue
                # Store one seed per sequence in every frame; it only has an
                # effect if the transform reads results['seed'].
                seed = random.randint(1, 100)
                for j in range(len(results)):
                    results[j]['seed'] = seed
                    updated_results = transform(copy.deepcopy(results[j]))
                    if updated_results is not None:
                        results[j] = updated_results

        return results

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys. It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Sequence of transform type
                strings to skip in the pipeline.
        """
        assert all([
            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
        ])
        self._skip_type_keys = skip_type_keys
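The wrapper writes the same `seed` into every frame's dict, but that only synchronizes the frames if each transform actually reads the key. As far as I know, the stock mmdetection transforms do not look for a `seed` entry, so they would need to be adapted. A minimal sketch of what a seed-aware transform could look like; `SeededShift` and the `bbox` and `shift` keys are hypothetical and exist only for this illustration.

```python
import random


class SeededShift:
    """Hypothetical transform that shifts a box by a seed-derived offset."""

    def __init__(self, max_shift=10):
        self.max_shift = max_shift

    def __call__(self, results):
        # A private generator built from the shared seed gives every frame
        # the same draws without disturbing the global random state.
        rng = random.Random(results.get('seed'))
        dx = rng.randint(-self.max_shift, self.max_shift)
        dy = rng.randint(-self.max_shift, self.max_shift)
        x1, y1, x2, y2 = results['bbox']
        out = dict(results)
        out['bbox'] = (x1 + dx, y1 + dy, x2 + dx, y2 + dy)
        out['shift'] = (dx, dy)
        return out


transform = SeededShift(max_shift=10)
sequence = [{'bbox': (0, 0, 10, 10), 'seed': 42} for _ in range(3)]
shifted = [transform(frame) for frame in sequence]
print(shifted[0]['shift'] == shifted[1]['shift'] == shifted[2]['shift'])
```

Note that `random.randint(1, 100)` in the wrapper yields only 100 distinct seeds, so a transform driven purely by the seed would cycle through at most 100 parameter sets; a wider range may be preferable.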

@moon6666
Author

Sorry, I added this function myself; the mistake was mine.
