The student model maintains an mAP of 0 throughout the distiller training process. #619

Open
Rax-Lie opened this issue Dec 22, 2023 · 0 comments


Checklist

  • I have searched related issues but cannot get the expected help.
  • I have read related documents and don't know what to do.

Describe the question you meet

I modified the MGD configuration from the mmrazor demo to use YOLOv5_X from MMYOLO as the teacher model and YOLOv5_S as the student model for knowledge distillation. During training, however, the student model's mAP stayed at 0 and had still not improved by the 12th epoch.

Post related information

  1. The output of pip list | grep "mmcv\|mmrazor\|^torch"
mmcv                   2.0.1
mmrazor                1.0.0
torch                  1.11.0
torchaudio             0.11.0
torchvision            0.12.0
  2. The config file of the distiller
_base_ = [
    'mmyolo::yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'
]
teacher_ckpt = 'download/yolov5_x-v61_syncbn_fast_8xb16-300e_coco_20230305_152943-00776a4b.pth'  # noqa: E501
# student_ckpt = 'download/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth'

student = _base_.model
student.neck.init_cfg = dict(
    type='Pretrained', prefix='neck.', checkpoint=teacher_ckpt)
student.bbox_head.init_cfg = dict(
    type='Pretrained', prefix='bbox_head.', checkpoint=teacher_ckpt)


max_epochs = 15
num_last_epochs = 15
interval = 1

batch_size = 64
num_workers = 32

custom_imports = dict(
    imports=['module.dataset.yolov5.coco_dataloader'],
    allow_failed_imports=False)

# custom_hooks = []

default_hooks = dict(
    checkpoint=dict(
        interval=interval,
        max_keep_ckpts=max_epochs
    ))

model = dict(
    _scope_='mmrazor',
    _delete_=True,
    type='SingleTeacherDistill',
    architecture=student,
    teacher=dict(
        cfg_path='mmyolo::yolov5/yolov5_x-v61_syncbn_fast_8xb16-300e_coco.py',
        pretrained=False),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.top_down_layers.0.1.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.bottom_up_layers.0.final_conv.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.bottom_up_layers.1.final_conv.conv')),
        teacher_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.top_down_layers.0.1.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.bottom_up_layers.0.final_conv.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.bottom_up_layers.1.final_conv.conv')),
        distill_losses=dict(
            loss_mgd_fpn0=dict(type='MGDLoss', alpha_mgd=0.00002),
            loss_mgd_fpn1=dict(type='MGDLoss', alpha_mgd=0.00002),
            loss_mgd_fpn2=dict(type='MGDLoss', alpha_mgd=0.00002)),
        connectors=dict(
            s_fpn0_connector=dict(
                type='MGDConnector',
                student_channels=128,
                teacher_channels=320,
                lambda_mgd=0.65),
            s_fpn1_connector=dict(
                type='MGDConnector',
                student_channels=256,
                teacher_channels=640,
                lambda_mgd=0.65),
            s_fpn2_connector=dict(
                type='MGDConnector',
                student_channels=512,
                teacher_channels=1280,
                lambda_mgd=0.65)),
        loss_forward_mappings=dict(
            loss_mgd_fpn0=dict(
                preds_S=dict(
                    from_student=True,
                    recorder='fpn0',
                    connector='s_fpn0_connector'),
                preds_T=dict(from_student=False, recorder='fpn0')),
            loss_mgd_fpn1=dict(
                preds_S=dict(
                    from_student=True,
                    recorder='fpn1',
                    connector='s_fpn1_connector'),
                preds_T=dict(from_student=False, recorder='fpn1')),
            loss_mgd_fpn2=dict(
                preds_S=dict(
                    from_student=True,
                    recorder='fpn2',
                    connector='s_fpn2_connector'),
                preds_T=dict(from_student=False, recorder='fpn2')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')

train_cfg = dict(max_epochs=max_epochs,
                 val_interval=interval)

train_dataloader = dict(
    batch_size=batch_size,
    collate_fn=dict(_delete_=True, type='yolov5_collate_local'),
    num_workers=num_workers)

val_dataloader = dict(
    batch_size=batch_size,
    num_workers=num_workers)
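
As a sanity check on the `student_channels` / `teacher_channels` pairs in the connectors above, here is a minimal sketch of the width arithmetic. It assumes the three neck outputs have base channels of 256, 512 and 1024, scaled by a widen factor of 0.5 for YOLOv5_S and 1.25 for YOLOv5_X and rounded up to a multiple of 8. Those values and the helper names `make_divisible` and `scaled_channels` are assumptions for illustration, not read from the configs.

```python
import math

# Assumed base channels of the three neck outputs before width scaling.
BASE_CHANNELS = (256, 512, 1024)


def make_divisible(value: float, divisor: int = 8) -> int:
    """Round `value` up to the nearest multiple of `divisor`."""
    return int(math.ceil(value / divisor) * divisor)


def scaled_channels(widen_factor: float) -> list:
    """Scale every base channel count by `widen_factor`."""
    return [make_divisible(c * widen_factor) for c in BASE_CHANNELS]


# Assumed widen factors: 0.5 for YOLOv5_S, 1.25 for YOLOv5_X.
student_channels = scaled_channels(0.5)
teacher_channels = scaled_channels(1.25)

for name, s, t in zip(('fpn0', 'fpn1', 'fpn2'), student_channels,
                      teacher_channels):
    print(f'{name}: student_channels={s}, teacher_channels={t}')
```

Under these assumptions the pairs come out as 128/320, 256/640 and 512/1280, which is what the connectors are configured with.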
  3. Log
2023/12/22 15:33:37 - mmengine - WARNING - "FileClient" will be deprecated in future. Please use io functions in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io
2023/12/22 15:33:37 - mmengine - WARNING - "HardDiskBackend" is the alias of "LocalBackend" and the former will be deprecated in future.
2023/12/22 15:33:37 - mmengine - INFO - Checkpoints will be saved to /mnt/share_disk/rax/KD_TIDL/work_dirs/mgd_distiller_T_yolov5_x_S_yolov5_s.
2023/12/22 15:43:16 - mmengine - INFO - Epoch(train)  [1][ 50/463]  base_lr: 1.0000e-02 lr: 3.5277e-04  eta: 22:08:37  time: 11.5617  data_time: 9.4693  memory: 59166  loss: 68.8866  student.loss_cls: 24.3198  student.loss_obj: 2.3839  student.loss_bbox: 38.9102  distill.loss_mgd_fpn0: 1.0803  distill.loss_mgd_fpn1: 1.1973  distill.loss_mgd_fpn2: 0.9951
2023/12/22 15:49:34 - mmengine - INFO - Epoch(train)  [1][100/463]  base_lr: 1.0000e-02 lr: 7.1274e-04  eta: 18:11:12  time: 7.5684  data_time: 6.0495  memory: 44745  loss: 65.5584  student.loss_cls: 23.7124  student.loss_obj: 0.0849  student.loss_bbox: 38.7415  distill.loss_mgd_fpn0: 1.0470  distill.loss_mgd_fpn1: 1.0443  distill.loss_mgd_fpn2: 0.9284
2023/12/22 15:56:24 - mmengine - INFO - Epoch(train)  [1][150/463]  base_lr: 1.0000e-02 lr: 1.0727e-03  eta: 17:11:56  time: 8.2059  data_time: 6.8705  memory: 44745  loss: 64.8745  student.loss_cls: 23.1582  student.loss_obj: 0.0614  student.loss_bbox: 38.7306  distill.loss_mgd_fpn0: 1.0254  distill.loss_mgd_fpn1: 1.0186  distill.loss_mgd_fpn2: 0.8803
2023/12/22 16:03:25 - mmengine - INFO - Epoch(train)  [1][200/463]  base_lr: 1.0000e-02 lr: 1.4327e-03  eta: 16:44:46  time: 8.4156  data_time: 7.0705  memory: 44745  loss: 64.2053  student.loss_cls: 22.5574  student.loss_obj: 0.0474  student.loss_bbox: 38.7227  distill.loss_mgd_fpn0: 1.0151  distill.loss_mgd_fpn1: 1.0058  distill.loss_mgd_fpn2: 0.8570
2023/12/22 16:04:52 - mmengine - INFO - Epoch(train)  [1][250/463]  base_lr: 1.0000e-02 lr: 1.7927e-03  eta: 13:56:32  time: 1.7334  data_time: 0.0007  memory: 44745  loss: 63.7751  student.loss_cls: 22.1237  student.loss_obj: 0.0377  student.loss_bbox: 38.7310  distill.loss_mgd_fpn0: 1.0083  distill.loss_mgd_fpn1: 0.9987  distill.loss_mgd_fpn2: 0.8756
2023/12/22 16:12:30 - mmengine - INFO - Epoch(train)  [1][300/463]  base_lr: 1.0000e-02 lr: 2.1526e-03  eta: 14:21:15  time: 9.1750  data_time: 7.1599  memory: 44745  loss: 63.1501  student.loss_cls: 21.5529  student.loss_obj: 0.0314  student.loss_bbox: 38.7225  distill.loss_mgd_fpn0: 1.0016  distill.loss_mgd_fpn1: 0.9895  distill.loss_mgd_fpn2: 0.8521
2023/12/22 16:19:13 - mmengine - INFO - Epoch(train)  [1][350/463]  base_lr: 1.0000e-02 lr: 2.5126e-03  eta: 14:19:05  time: 8.0510  data_time: 6.4783  memory: 44745  loss: 62.7898  student.loss_cls: 21.2086  student.loss_obj: 0.0273  student.loss_bbox: 38.7137  distill.loss_mgd_fpn0: 0.9911  distill.loss_mgd_fpn1: 0.9804  distill.loss_mgd_fpn2: 0.8686
2023/12/22 16:25:05 - mmengine - INFO - Epoch(train)  [1][400/463]  base_lr: 1.0000e-02 lr: 2.8726e-03  eta: 14:01:54  time: 7.0338  data_time: 5.6988  memory: 44745  loss: 62.6934  student.loss_cls: 21.1535  student.loss_obj: 0.0241  student.loss_bbox: 38.7119  distill.loss_mgd_fpn0: 0.9865  distill.loss_mgd_fpn1: 0.9775  distill.loss_mgd_fpn2: 0.8399
2023/12/22 16:26:12 - mmengine - INFO - Epoch(train)  [1][450/463]  base_lr: 1.0000e-02 lr: 3.2325e-03  eta: 12:38:46  time: 1.3408  data_time: 0.0006  memory: 44745  loss: 62.6353  student.loss_cls: 21.0822  student.loss_obj: 0.0215  student.loss_bbox: 38.7099  distill.loss_mgd_fpn0: 0.9860  distill.loss_mgd_fpn1: 0.9737  distill.loss_mgd_fpn2: 0.8620
2023/12/22 16:26:34 - mmengine - INFO - Exp name: mgd_distiller_T_yolov5_x_S_yolov5_s_20231222_153257
2023/12/22 16:26:34 - mmengine - INFO - Saving checkpoint at 1 epochs
2023/12/22 16:26:36 - mmengine - WARNING - `save_param_scheduler` is True but `self.param_schedulers` is None, so skip saving parameter schedulers
2023/12/22 16:27:20 - mmengine - INFO - Evaluating bbox...
2023/12/22 16:27:27 - mmengine - INFO - bbox_mAP_copypaste: 0.000 0.000 0.000 0.000 0.000 0.000
2023/12/22 16:29:28 - mmengine - INFO - Evaluating bbox...
2023/12/22 16:30:30 - mmengine - INFO - bbox_mAP_copypaste: 0.502 0.685 0.548 0.346 0.553 0.635
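
To compare the individual loss terms across iterations, here is a minimal sketch that pulls the loss fields out of two of the log lines above (iterations 50 and 450 of epoch 1, shortened to the loss fields). The regular expression and the helper `parse_losses` are hypothetical, written only for this log format.

```python
import re

# Matches "name: value" pairs whose name contains "loss",
# e.g. "loss", "student.loss_bbox", "distill.loss_mgd_fpn0".
LOSS_PATTERN = re.compile(r'([\w.]*loss\w*): ([\d.]+)')


def parse_losses(line: str) -> dict:
    """Return a {loss_name: value} mapping for one training log line."""
    return {name: float(value) for name, value in LOSS_PATTERN.findall(line)}


# Loss fields copied from the log above (epoch 1, iterations 50 and 450).
first_line = ('loss: 68.8866  student.loss_cls: 24.3198  '
              'student.loss_obj: 2.3839  student.loss_bbox: 38.9102  '
              'distill.loss_mgd_fpn0: 1.0803  distill.loss_mgd_fpn1: 1.1973  '
              'distill.loss_mgd_fpn2: 0.9951')
last_line = ('loss: 62.6353  student.loss_cls: 21.0822  '
             'student.loss_obj: 0.0215  student.loss_bbox: 38.7099  '
             'distill.loss_mgd_fpn0: 0.9860  distill.loss_mgd_fpn1: 0.9737  '
             'distill.loss_mgd_fpn2: 0.8620')

first = parse_losses(first_line)
last = parse_losses(last_line)

# Change of every loss term between the two iterations.
delta = {name: last[name] - first[name] for name in first}
for name, change in delta.items():
    print(f'{name}: {first[name]:.4f} -> {last[name]:.4f} ({change:+.4f})')
```

In these two lines `student.loss_obj` falls from 2.3839 to 0.0215, while `student.loss_bbox` only moves from 38.9102 to 38.7099.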
  4. Other code you modified in the mmrazor folder.
    Used in custom_imports: module/dataset/yolov5/coco_dataloader.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence

import numpy as np
import torch
from mmdet.datasets import BaseDetDataset, CocoDataset  # noqa: F401
from mmengine.dataset import COLLATE_FUNCTIONS
from mmengine.dist import get_dist_info
# TASK_UTILS is used by BatchShapePolicy below; assumed to be MMYOLO's registry.
from mmyolo.registry import TASK_UTILS

@COLLATE_FUNCTIONS.register_module()
def yolov5_collate_local(data_batch: Sequence,
                         use_ms_training: bool = False) -> dict:
    """Rewrite collate_fn to get faster training speed.

    Args:
       data_batch (Sequence): Batch of data.
       use_ms_training (bool): Whether to use multi-scale training.
    """
    batch_imgs = []
    batch_bboxes_labels = []
    batch_masks = []
    batch_keypoints = []
    batch_keypoints_visible = []
    for i in range(len(data_batch)):
        datasamples = data_batch[i]['data_samples']
        inputs = data_batch[i]['inputs']
        batch_imgs.append(inputs)

        gt_bboxes = datasamples.gt_instances.bboxes.tensor
        gt_labels = datasamples.gt_instances.labels
        if 'masks' in datasamples.gt_instances:
            masks = datasamples.gt_instances.masks
            batch_masks.append(masks)
        if 'gt_panoptic_seg' in datasamples:
            batch_masks.append(datasamples.gt_panoptic_seg.pan_seg)
        if 'keypoints' in datasamples.gt_instances:
            keypoints = datasamples.gt_instances.keypoints
            keypoints_visible = datasamples.gt_instances.keypoints_visible
            batch_keypoints.append(keypoints)
            batch_keypoints_visible.append(keypoints_visible)

        # Each row is [batch index, class label, box coordinates].
        batch_idx = gt_labels.new_full((len(gt_labels), 1), i)
        bboxes_labels = torch.cat((batch_idx, gt_labels[:, None], gt_bboxes),
                                  dim=1)
        batch_bboxes_labels.append(bboxes_labels)
    collated_results = {
        'data_samples': {
            'bboxes_labels': torch.cat(batch_bboxes_labels, 0)
        }
    }
    if len(batch_masks) > 0:
        collated_results['data_samples']['masks'] = torch.cat(batch_masks, 0)

    if len(batch_keypoints) > 0:
        collated_results['data_samples']['keypoints'] = torch.cat(
            batch_keypoints, 0)
        collated_results['data_samples']['keypoints_visible'] = torch.cat(
            batch_keypoints_visible, 0)

    if use_ms_training:
        collated_results['inputs'] = batch_imgs
    else:
        collated_results['inputs'] = torch.stack(batch_imgs, 0)
    return collated_results


@TASK_UTILS.register_module()
class BatchShapePolicy:
    """BatchShapePolicy is only used in the testing phase, which can reduce the
    number of pad pixels during batch inference.

    Args:
       batch_size (int): Single GPU batch size during batch inference.
           Defaults to 32.
       img_size (int): Expected output image size. Defaults to 640.
       size_divisor (int): The minimum size that is divisible
           by size_divisor. Defaults to 32.
       extra_pad_ratio (float):  Extra pad ratio. Defaults to 0.5.
    """

    def __init__(self,
                 batch_size: int = 32,
                 img_size: int = 640,
                 size_divisor: int = 32,
                 extra_pad_ratio: float = 0.5):
        self.img_size = img_size
        self.size_divisor = size_divisor
        self.extra_pad_ratio = extra_pad_ratio
        _, world_size = get_dist_info()
        # During multi-gpu testing, the batchsize should be multiplied by
        # worldsize, so that the number of batches can be calculated correctly.
        # The index of batches will affect the calculation of batch shape.
        self.batch_size = batch_size * world_size

    def __call__(self, data_list: List[dict]) -> List[dict]:
        image_shapes = []
        for data_info in data_list:
            image_shapes.append((data_info['width'], data_info['height']))

        image_shapes = np.array(image_shapes, dtype=np.float64)

        n = len(image_shapes)  # number of images
        batch_index = np.floor(np.arange(n) / self.batch_size).astype(
            np.int64)  # batch index
        number_of_batches = batch_index[-1] + 1  # number of batches

        aspect_ratio = image_shapes[:, 1] / image_shapes[:, 0]  # aspect ratio
        irect = aspect_ratio.argsort()

        data_list = [data_list[i] for i in irect]

        aspect_ratio = aspect_ratio[irect]
        # Set training image shapes
        shapes = [[1, 1]] * number_of_batches
        for i in range(number_of_batches):
            aspect_ratio_index = aspect_ratio[batch_index == i]
            min_index, max_index = aspect_ratio_index.min(
            ), aspect_ratio_index.max()
            if max_index < 1:
                shapes[i] = [max_index, 1]
            elif min_index > 1:
                shapes[i] = [1, 1 / min_index]

        batch_shapes = np.ceil(
            np.array(shapes) * self.img_size / self.size_divisor +
            self.extra_pad_ratio).astype(np.int64) * self.size_divisor

        for i, data_info in enumerate(data_list):
            data_info['batch_shape'] = batch_shapes[batch_index[i]]

        return data_list
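
For reference, here is a minimal pure-Python sketch of the per-batch shape arithmetic in `BatchShapePolicy.__call__` above, using the default `img_size=640`, `size_divisor=32` and `extra_pad_ratio=0.5`. The helpers `shape_ratio_for` and `batch_shape` are hypothetical names that mirror the loop body and the final `np.ceil` expression. They are not part of the posted module.

```python
import math


def shape_ratio_for(aspect_ratios: list) -> list:
    """Pick the relative shape for one batch from its height/width ratios."""
    lowest, highest = min(aspect_ratios), max(aspect_ratios)
    if highest < 1:
        return [highest, 1]
    if lowest > 1:
        return [1, 1 / lowest]
    return [1, 1]


def batch_shape(shape_ratio: list,
                img_size: int = 640,
                size_divisor: int = 32,
                extra_pad_ratio: float = 0.5) -> list:
    """Scale the relative shape and round up to a multiple of size_divisor."""
    return [
        math.ceil(r * img_size / size_divisor + extra_pad_ratio) * size_divisor
        for r in shape_ratio
    ]


# One batch where every image has height/width < 1, one where every image
# has height/width > 1, and one with both kinds mixed.
wide_batch = batch_shape(shape_ratio_for([0.5, 0.75]))
tall_batch = batch_shape(shape_ratio_for([2.0, 4.0]))
mixed_batch = batch_shape(shape_ratio_for([0.5, 2.0]))

print(wide_batch, tall_batch, mixed_batch)
```

With the defaults, a full-size side comes out as 672 rather than 640, because `extra_pad_ratio` adds half a stride before rounding up.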