Result on HRSC2016 #8

Closed
ming71 opened this issue Sep 27, 2020 · 15 comments

ming71 commented Sep 27, 2020

Hi, I've obtained the same results reported in the paper on HRSC2016 with S2ANet.
But the detection results with RetinaNet are not good enough. What's wrong with my config?

# model settings
model = dict(
    type='RBoxRetinaNet',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs=True,
        num_outs=5),
    rbox_head=dict(
        type='RBoxRetinaHead',
        num_classes=2,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        octave_base_scale=4,
        scales_per_octave=3,
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[8, 16, 32, 64, 128],
        target_means=[.0, .0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0, 1.0],
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0)))
# training and testing settings
train_cfg = dict(
#     anchor_target_type='hbb_obb_rbox_overlap',
    assigner=dict(
        type='MaxIoUAssigner',
#         type='MaxIoUAssignerRbox',
        pos_iou_thr=0.5,
        neg_iou_thr=0.4,
        min_pos_iou=0,
        ignore_iof_thr=-1),
    allowed_border=-1,
    pos_weight=-1,
    debug=False)
test_cfg = dict(
    nms_pre=2000,
    min_bbox_size=0,
    score_thr=0.05,
    nms=dict(type='nms_rotated', iou_thr=0.1),#15fps
    max_per_img=2000)
# dataset settings
dataset_type = 'HRSC2016Dataset'
data_root = 'data/HRSC2016/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='RotatedResize', img_scale=(800, 800), keep_ratio=True),
    dict(type='RotatedRandomFlip', flip_ratio=0.5),
    dict(type='RandomRotate', rate=0.5, angles=[30, 60, 90, 120, 150], auto_bound=False),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(800, 800),
        flip=False,
        transforms=[
            dict(type='RotatedResize', img_scale=(800, 800), keep_ratio=True),
            dict(type='RotatedRandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'ImageSets/train.json',
        img_prefix=data_root + 'Train/images/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'ImageSets/val.json',
        img_prefix=data_root + 'Val/images/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'ImageSets/test.json',
        img_prefix=data_root + 'Test/images/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
# optimizer = dict(type='Adam', lr=1e-4)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[24, 33])
checkpoint_config = dict(interval=12)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 36
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/retinanet_obb_r50_fpn_1x_hrsc2016/'
load_from = None
resume_from = None
workflow = [('train', 1)]

ming71 commented Sep 27, 2020

S2ANet reaches an mAP higher than 89%, while RetinaNet obtains only 56%.


csuhan commented Sep 27, 2020

HRSC2016 is very sensitive to hyperparameters, e.g., lr, schedule, warmup_iters.
My experiments on HRSC2016 with RetinaNet get poor performance as well, but I think it will improve after carefully adjusting the hyperparameters.
You can validate the AP with a different lr (e.g., 0.05 for 4 GPUs), a different schedule (e.g., 24 epochs), fewer warmup_iters, and so on.
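
A minimal sketch of such overrides on top of the config above; the exact values are illustrative guesses along the lines suggested here, not verified settings:

# Illustrative overrides only: lr raised for 4 GPUs x 2 img/GPU, a shorter
# 24-epoch schedule, and fewer warmup iterations than the default 500.
optimizer = dict(type='SGD', lr=0.05, momentum=0.9, weight_decay=0.0001)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=100,      # fewer warmup iters than the 500 in the config above
    warmup_ratio=1.0 / 3,
    step=[16, 22])         # decay points for a 24-epoch run
total_epochs = 24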


ming71 commented Sep 27, 2020

> HRSC2016 is very sensitive to hyperparameters, e.g., lr, schedule, warmup_iters.
> My experiments on HRSC2016 with RetinaNet get poor performance as well, but I think it will improve after carefully adjusting the hyperparameters.
> You can validate the AP with a different lr (e.g., 0.05 for 4 GPUs), a different schedule (e.g., 24 epochs), fewer warmup_iters, and so on.

But it works well on DOTA... it's amazing that with the same config file it reached an mAP higher than 70%.

@Fly-dream12

Thanks for your code. I cannot reproduce the results on HRSC2016 reported in your paper; what may be wrong?
I have changed the lr to 0.001, and the other settings are the same.
@csuhan @ming71

@Fly-dream12

When I keep the learning rate at 0.01, the loss becomes too large, as shown here:
2020-10-26 10:45:04,367 - INFO - Epoch [7][200/219] lr: 0.01000, eta: 0:24:59, time: 0.249, data_time: 0.002, memory: 2323, loss_fam_cls: 0.8456, loss_fam_bbox: 0.9042, loss_odm_cls: 0.3194, loss_odm_bbox: 58701399372.6484, loss: 58701399374.7175


ming71 commented Oct 26, 2020

I don't know. I achieved an mAP of 89+% with the original settings for S2ANet.

@Fly-dream12

But the result in the original paper is 90.17. @ming71


ming71 commented Oct 27, 2020

> But the result in the original paper is 90.17. @ming71

Note that I only ran about 20 epochs to reach that performance, not the whole schedule.

@Fly-dream12

I have run 36 epochs. Is the lr in your config 0.01? When I trained the model with a 0.01 learning rate, the loss became too large to converge. @ming71


ming71 commented Oct 27, 2020

> I have run 36 epochs. Is the lr in your config 0.01? When I trained the model with a 0.01 learning rate, the loss became too large to converge. @ming71

I trained via dist_train.sh with 4 RTX 2080 Ti GPUs and everything was fine, so the problem is probably on your side.

@Fly-dream12

I trained with a single GPU. @ming71


csuhan commented Nov 4, 2020

Refer to https://github.com/csuhan/s2anet/blob/master/docs/GETTING_STARTED.md#train-a-model
4 GPUs * 2 img/GPU: lr = 0.01
1 GPU * 2 img/GPU: lr = 0.0025
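
In other words, the learning rate follows a linear scaling rule with the total batch size. A small sketch of the arithmetic, using the base values from the lines above:

# Linear scaling of lr with total batch size
# (reference point: 4 GPUs x 2 img/GPU -> lr 0.01).
base_lr = 0.01
base_batch_size = 8              # 4 GPUs * 2 img/GPU

gpus, imgs_per_gpu = 1, 2
lr = base_lr * (gpus * imgs_per_gpu) / base_batch_size
print(lr)                        # 0.0025, matching the single-GPU setting above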

@Sp2-Hybrid

> When I keep the learning rate at 0.01, the loss becomes too large, as shown here:
> 2020-10-26 10:45:04,367 - INFO - Epoch [7][200/219] lr: 0.01000, eta: 0:24:59, time: 0.249, data_time: 0.002, memory: 2323, loss_fam_cls: 0.8456, loss_fam_bbox: 0.9042, loss_odm_cls: 0.3194, loss_odm_bbox: 58701399372.6484, loss: 58701399374.7175

Hello, how did you solve this problem? I met the same problem: with lr equal to 0.01 the loss was too large to train, and even when I reduced the lr to 0.00001 it eventually became large, even NaN.

ming71 closed this as completed on Nov 29, 2020

csuhan commented Jan 2, 2021

Hi @ming71, I now know the reason why the mAP of RetinaNet in my codebase is so low. First, the 3x schedule is too short for RetinaNet. Second, horizontal anchors make the angle prediction hard to converge.

Therefore, I changed some settings:
(1) longer training schedule: with 6x, RetinaNet reaches 73% mAP.
(2) more anchor angles: when setting anchor_angles=[0., PI/3, PI/6, PI/2], the mAP is about 81%.
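
For reference, a sketch of how the head config might look with rotated anchors; the anchor_angles key is inferred from this comment, and the other fields are copied from the config earlier in the thread, so treat it as an assumption rather than a verified snippet from the repo:

import math

rbox_head = dict(
    type='RBoxRetinaHead',
    num_classes=2,
    in_channels=256,
    stacked_convs=4,
    feat_channels=256,
    octave_base_scale=4,
    scales_per_octave=3,
    anchor_ratios=[0.5, 1.0, 2.0],
    # assumed key: rotated anchors at 0, 60, 30, and 90 degrees, as suggested above
    anchor_angles=[0., math.pi / 3, math.pi / 6, math.pi / 2],
    anchor_strides=[8, 16, 32, 64, 128],
    target_means=[.0, .0, .0, .0, .0],
    target_stds=[1.0, 1.0, 1.0, 1.0, 1.0])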


ming71 commented Jan 3, 2021

Multiple versions of my RetinaNet also reached an mAP of about 80% on HRSC2016, which is consistent with your results. Thank you again for your experiments and nice work! 😆
