In [1]:
# Make the project's `src` package importable from this notebook.
# NOTE(review): this is a hard-coded absolute path; it has to be adjusted
# when the notebook runs from a different location.
import sys
sys.path.append('/home/ma-user/work/Yolov5_for_MindSpore_1.1_code')

In [2]:
import os
import time
import argparse
import datetime
import mindspore as ms
from mindspore.context import ParallelMode
from mindspore.nn.optim.momentum import Momentum
from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig

from src.yolo import YOLOV5s, YoloWithLossCell, TrainingWrapper
from src.logger import get_logger
from src.util import AverageMeter, get_param_groups
from src.lr_scheduler import get_lr
from src.yolo_dataset import create_yolo_dataset
from src.initializer import default_recurisive_init, load_yolov5_params
from src.config import ConfigYOLOV5
# Fix the MindSpore global random seed before any network or dataset is built.
ms.set_seed(1)

In [3]:
def parse_args(cloud_args=None):
    """Parse train arguments and prepare the runtime environment.

    Besides building the argument namespace this function has side effects:
    it configures the MindSpore context, initialises the communication
    backend when ``is_distributed`` is set, and creates the run logger.

    Args:
        cloud_args: Optional dict of overrides that is merged into the parsed
            arguments via ``merge_args``. Anything that is not a dict is
            ignored by ``merge_args``.

    Returns:
        The argument namespace, extended with the derived fields
        ``data_root``, ``annFile``, ``rank_save_ckpt_flag``, ``outputs_dir``
        and ``logger``. ``lr_epochs`` is converted to a list of ints, and
        ``rank`` / ``group_size`` are refreshed from the communication group
        when running distributed.
    """
    parser = argparse.ArgumentParser('mindspore coco training')

    # device related
    parser.add_argument('--device_target', type=str, default='Ascend',
                        help='device where the code will be implemented.')

    # dataset related
    parser.add_argument('--data_dir', default='antigen', type=str, help='Train dataset directory.')
    parser.add_argument('--per_batch_size', default=64, type=int, help='Batch size for Training. Default: 64')

    # network related
    parser.add_argument('--pretrained_backbone', default='', type=str,
                        help='The backbone file of YOLOv5. Default: "".')
    parser.add_argument('--resume_yolov5', default='', type=str,
                        help='The ckpt file of YOLOv5, which used to fine tune. Default: ""')

    # optimizer and lr related
    parser.add_argument('--lr_scheduler', default='cosine_annealing', type=str,
                        help='Learning rate scheduler, options: exponential, cosine_annealing. '
                             'Default: cosine_annealing')
    parser.add_argument('--lr', default=0.001, type=float, help='Learning rate. Default: 0.001')
    parser.add_argument('--lr_epochs', type=str, default='120,150',
                        help='Epoch of changing of lr changing, split with ",". Default: 120,150')
    parser.add_argument('--lr_gamma', type=float, default=0.1,
                        help='Decrease lr by a factor of exponential lr_scheduler. Default: 0.1')
    parser.add_argument('--eta_min', type=float, default=0., help='Eta_min in cosine_annealing scheduler. Default: 0')
    parser.add_argument('--T_max', type=int, default=200, help='T-max in cosine_annealing scheduler. Default: 200')
    parser.add_argument('--max_epoch', type=int, default=200, help='Max epoch num to train the model. Default: 200')
    parser.add_argument('--warmup_epochs', default=1, type=float, help='Warmup epochs. Default: 1')
    parser.add_argument('--weight_decay', type=float, default=0.0005, help='Weight decay factor. Default: 0.0005')
    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum. Default: 0.9')

    # loss related
    parser.add_argument('--loss_scale', type=int, default=1024, help='Static loss scale. Default: 1024')
    parser.add_argument('--label_smooth', type=int, default=0, help='Whether to use label smooth in CE. Default:0')
    parser.add_argument('--label_smooth_factor', type=float, default=0.1,
                        help='Smooth strength of original one-hot. Default: 0.1')

    # logging related
    parser.add_argument('--log_interval', type=int, default=100, help='Logging interval steps. Default: 100')
    parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
    parser.add_argument('--ckpt_interval', type=int, default=100, help='Save checkpoint interval. Default: 100')

    parser.add_argument('--is_save_on_master', type=int, default=1,
                        help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 1')

    # distributed related
    parser.add_argument('--is_distributed', type=int, default=1,
                        help='Distribute train or not, 1 for yes, 0 for no. Default: 1')
    parser.add_argument('--rank', type=int, default=0, help='Local rank of distributed. Default: 0')
    parser.add_argument('--group_size', type=int, default=1, help='World size of device. Default: 1')

    # roma obs
    parser.add_argument('--train_url', type=str, default="", help='train url')
    # profiler init
    parser.add_argument('--need_profiler', type=int, default=0,
                        help='Whether use profiler. 0 for no, 1 for yes. Default: 0')

    # reset default config
    parser.add_argument('--training_shape', type=str, default="", help='Fix training shape. Default: ""')
    parser.add_argument('--resize_rate', type=int, default=10,
                        help='Resize rate for multi-scale training. Default: 10')

    # parse_known_args tolerates unrelated entries in sys.argv (e.g. the
    # arguments the notebook kernel itself was launched with).
    args, _ = parser.parse_known_args()
    args = merge_args(args, cloud_args)

    # Stretch the cosine period so it is never shorter than the training run.
    if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.T_max:
        args.T_max = args.max_epoch

    args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
    args.data_root = os.path.join(args.data_dir, 'train')
    args.annFile = os.path.join(args.data_dir, 'annotations/train.json')

    # device_id is intentionally not passed to set_context here, so no
    # DEVICE_ID lookup is needed in this function.
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.device_target, save_graphs=False)

    # init distributed
    if args.is_distributed:
        if args.device_target == "Ascend":
            init()
        else:
            init("nccl")
        # The communication group is the source of truth for rank / world size.
        args.rank = get_rank()
        args.group_size = get_group_size()

    # select for master rank save ckpt or all rank save, compatible for model parallel:
    # with is_save_on_master only rank 0 writes checkpoints, otherwise every rank does.
    if args.is_save_on_master:
        args.rank_save_ckpt_flag = 1 if args.rank == 0 else 0
    else:
        args.rank_save_ckpt_flag = 1

    # logger: each run gets its own timestamped directory under ckpt_path
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    return args

def merge_args(args, cloud_args):
    """Overlay ``cloud_args`` onto an argparse-style namespace, in place.

    Only keys that already exist on ``args`` are considered, and an override
    is applied only when its value is truthy. The override is cast to the type
    of the value it replaces, unless that value is ``None`` (then it is stored
    as given).

    Args:
        args: Namespace object whose attributes are updated.
        cloud_args: Dict of overrides; any non-dict value leaves ``args`` untouched.

    Returns:
        The same ``args`` object that was passed in.
    """
    if not isinstance(cloud_args, dict):
        return args

    # vars() exposes the namespace's own __dict__, so writes land on `args`.
    namespace = vars(args)
    for name, override in cloud_args.items():
        if name not in namespace or not override:
            continue
        current = namespace[name]
        namespace[name] = override if current is None else type(current)(override)
    return args


def convert_training_shape(args_training_shape):
    """Turn a single size value into a two-element ``[size, size]`` list of ints.

    Args:
        args_training_shape: Anything ``int()`` accepts, e.g. ``"640"`` or ``640``.

    Returns:
        A list holding the integer size twice.
    """
    side = int(args_training_shape)
    return [side, side]




In [None]:
# Training driver: parse the run configuration, build the network, dataset,
# optimizer and checkpoint callback, then drive the training steps by hand.
cloud_args=None
args = parse_args(cloud_args)
loss_meter = AverageMeter('loss')

# Reset the auto-parallel context, then choose stand-alone or data-parallel
# mode depending on args.is_distributed.
context.reset_auto_parallel_context()
parallel_mode = ParallelMode.STAND_ALONE
degree = 1
if args.is_distributed:
    parallel_mode = ParallelMode.DATA_PARALLEL
    degree = get_group_size()
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)

network = YOLOV5s(is_training=True)
# default is kaiming-normal
default_recurisive_init(network)
# presumably loads pretrained / resume weights selected through args — confirm in src.initializer
load_yolov5_params(args, network)

# Wrap the detector in its loss cell; `network` now refers to the wrapped cell.
network = YoloWithLossCell(network)
config = ConfigYOLOV5()

# Copy the command-line overrides onto the dataset/loss configuration.
config.label_smooth = args.label_smooth
config.label_smooth_factor = args.label_smooth_factor

# A non-empty training_shape replaces config.multi_scale with a single [size, size] entry.
if args.training_shape:
    config.multi_scale = [convert_training_shape(args.training_shape)]
# A resize_rate of 0 leaves the value from ConfigYOLOV5 in place.
if args.resize_rate:
    config.resize_rate = args.resize_rate

ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True,
                                    batch_size=args.per_batch_size, max_epoch=args.max_epoch,
                                    device_num=args.group_size, rank=args.rank, config=config)
args.logger.info('Finish loading dataset')

# Steps per epoch for one rank, truncated to an integer.
args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size)

# A ckpt_interval of 0 falls back to one checkpoint per epoch.
if not args.ckpt_interval:
    args.ckpt_interval = args.steps_per_epoch

# `lr` is indexed by global step in the logging code below.
lr = get_lr(args)

opt = Momentum(params=get_param_groups(network),
               learning_rate=Tensor(lr),
               momentum=args.momentum,
               weight_decay=args.weight_decay,
               loss_scale=args.loss_scale)

# NOTE(review): the wrapper is given half of the loss scale passed to the
# optimizer; the rationale is not visible in this file — confirm in src.yolo.
network = TrainingWrapper(network, opt, args.loss_scale // 2)
network.set_train()

if args.rank_save_ckpt_flag:
    # checkpoint save
    # Keep as many checkpoints as the whole run can produce at this interval.
    ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                   keep_checkpoint_max=ckpt_max_num)
    save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
    ckpt_cb = ModelCheckpoint(config=ckpt_config,
                              directory=save_ckpt_path,
                              prefix='{}'.format(args.rank))
    # The callback is driven manually, so its run context is built by hand here.
    cb_params = _InternalCallbackParam()
    cb_params.train_network = network
    cb_params.epoch_num = ckpt_max_num
    cb_params.cur_epoch_num = 1
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)

# Index of the last logged step; -1 so the first log line counts step 0.
old_progress = -1
t_end = time.time()
# NOTE(review): a single pass over the iterator is made while max_epoch was
# given to create_yolo_dataset — presumably the dataset repeats internally; confirm.
data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)

for i, data in enumerate(data_loader):
    images = data["image"]
    # Dimensions 2 and 3 of the image batch; reversed below before being passed on.
    # presumably (H, W) -> (W, H) — TODO confirm the image layout.
    input_shape = images.shape[2:4]
    images = Tensor.from_numpy(images)
    batch_y_true_0 = Tensor.from_numpy(data['bbox1'])
    batch_y_true_1 = Tensor.from_numpy(data['bbox2'])
    batch_y_true_2 = Tensor.from_numpy(data['bbox3'])
    batch_gt_box0 = Tensor.from_numpy(data['gt_box1'])
    batch_gt_box1 = Tensor.from_numpy(data['gt_box2'])
    batch_gt_box2 = Tensor.from_numpy(data['gt_box3'])
    input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
    # One call to the wrapped network performs a training step and returns the loss.
    loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
                   batch_gt_box2, input_shape)
    loss_meter.update(loss.asnumpy())

    if args.rank_save_ckpt_flag:
        # ckpt progress
        cb_params.cur_step_num = i + 1  # current step number
        # NOTE(review): batch_num is set one past cur_step_num; the reason is
        # not visible here — confirm what ModelCheckpoint expects.
        cb_params.batch_num = i + 2
        ckpt_cb.step_end(run_context)

    if i % args.log_interval == 0:
        time_used = time.time() - t_end
        epoch = int(i / args.steps_per_epoch)
        # Images processed across all ranks since the previous log line, per second.
        fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
        if args.rank == 0:
            args.logger.info(
                'epoch[{}], iter[{}], {}, fps:{:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
        # Restart the timing window and the loss meter for the next interval.
        t_end = time.time()
        loss_meter.reset()
        old_progress = i

    # Advance the callback's epoch counter at every epoch boundary.
    if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
        cb_params.cur_epoch_num += 1

args.logger.info('==========end training===============')



2022-09-20 10:02:05,810:INFO:Args:
2022-09-20 10:02:05,812:INFO:--> device_target: Ascend
2022-09-20 10:02:05,812:INFO:--> data_dir: antigen
2022-09-20 10:02:05,813:INFO:--> per_batch_size: 64
2022-09-20 10:02:05,814:INFO:--> pretrained_backbone: 
2022-09-20 10:02:05,815:INFO:--> resume_yolov5: 
2022-09-20 10:02:05,815:INFO:--> lr_scheduler: cosine_annealing
2022-09-20 10:02:05,816:INFO:--> lr: 0.001
2022-09-20 10:02:05,817:INFO:--> lr_epochs: [120, 150]
2022-09-20 10:02:05,818:INFO:--> lr_gamma: 0.1
2022-09-20 10:02:05,819:INFO:--> eta_min: 0.0
2022-09-20 10:02:05,820:INFO:--> T_max: 200
2022-09-20 10:02:05,821:INFO:--> max_epoch: 200
2022-09-20 10:02:05,821:INFO:--> warmup_epochs: 1
2022-09-20 10:02:05,822:INFO:--> weight_decay: 0.0005
2022-09-20 10:02:05,823:INFO:--> momentum: 0.9
2022-09-20 10:02:05,824:INFO:--> loss_scale: 1024
2022-09-20 10:02:05,825:INFO:--> label_smooth: 0
2022-09-20 10:02:05,825:INFO:--> label_smooth_factor: 0.1
2022-09-20 10:02:05,826:INFO:--> log_interval: 1



loading annotations into memory...
Done (t=0.57s)
creating index...
index created!
2022-09-20 10:02:10,413:INFO:Finish loading dataset
2022-09-20 10:04:53,896:INFO:epoch[0], iter[0], loss:6831.011719, fps:0.39 imgs/sec, lr:3.2258064948109677e-06
2022-09-20 10:06:27,349:INFO:epoch[0], iter[100], loss:827.520046, fps:68.49 imgs/sec, lr:0.0003258064389228821
2022-09-20 10:07:56,619:INFO:epoch[0], iter[200], loss:88.232236, fps:71.72 imgs/sec, lr:0.0006483871256932616




2022-09-20 10:09:41,072:INFO:epoch[0], iter[300], loss:56.008349, fps:61.28 imgs/sec, lr:0.0009709677542559803
2022-09-20 10:11:29,957:INFO:epoch[1], iter[400], loss:44.020527, fps:58.78 imgs/sec, lr:0.000999938347376883
2022-09-20 10:13:16,620:INFO:epoch[1], iter[500], loss:36.825419, fps:60.03 imgs/sec, lr:0.000999938347376883
2022-09-20 10:15:07,191:INFO:epoch[1], iter[600], loss:33.036590, fps:57.89 imgs/sec, lr:0.000999938347376883
2022-09-20 10:17:00,968:INFO:epoch[2], iter[700], loss:30.637895, fps:56.26 imgs/sec, lr:0.0009997532470151782




2022-09-20 10:18:53,260:INFO:epoch[2], iter[800], loss:28.768294, fps:57.01 imgs/sec, lr:0.0009997532470151782
2022-09-20 10:20:45,438:INFO:epoch[2], iter[900], loss:26.855416, fps:57.07 imgs/sec, lr:0.0009997532470151782
2022-09-20 10:22:28,523:INFO:epoch[3], iter[1000], loss:26.020338, fps:62.10 imgs/sec, lr:0.0009994449792429805
2022-09-20 10:24:11,982:INFO:epoch[3], iter[1100], loss:25.173995, fps:61.88 imgs/sec, lr:0.0009994449792429805
2022-09-20 10:25:59,787:INFO:epoch[3], iter[1200], loss:23.334516, fps:59.37 imgs/sec, lr:0.0009994449792429805
2022-09-20 10:27:45,361:INFO:epoch[4], iter[1300], loss:22.583913, fps:60.64 imgs/sec, lr:0.0009990133112296462
2022-09-20 10:29:28,910:INFO:epoch[4], iter[1400], loss:22.355498, fps:61.82 imgs/sec, lr:0.0009990133112296462
2022-09-20 10:31:10,366:INFO:epoch[4], iter[1500], loss:21.755259, fps:63.10 imgs/sec, lr:0.0009990133112296462




2022-09-20 10:32:54,378:INFO:epoch[5], iter[1600], loss:21.342742, fps:61.56 imgs/sec, lr:0.0009984587086364627
2022-09-20 10:34:30,928:INFO:epoch[5], iter[1700], loss:19.874972, fps:66.30 imgs/sec, lr:0.0009984587086364627
2022-09-20 10:36:12,403:INFO:epoch[5], iter[1800], loss:19.054879, fps:63.08 imgs/sec, lr:0.0009984587086364627
2022-09-20 10:37:54,580:INFO:epoch[6], iter[1900], loss:18.105585, fps:62.64 imgs/sec, lr:0.0009977809386327863
2022-09-20 10:39:35,616:INFO:epoch[6], iter[2000], loss:17.835368, fps:63.35 imgs/sec, lr:0.0009977809386327863
2022-09-20 10:41:18,196:INFO:epoch[6], iter[2100], loss:18.378264, fps:62.40 imgs/sec, lr:0.0009977809386327863
2022-09-20 10:43:01,362:INFO:epoch[7], iter[2200], loss:17.577214, fps:62.07 imgs/sec, lr:0.0009969804668799043
2022-09-20 10:44:46,546:INFO:epoch[7], iter[2300], loss:16.519074, fps:60.86 imgs/sec, lr:0.0009969804668799043
2022-09-20 10:46:29,218:INFO:epoch[7], iter[2400], loss:15.912904, fps:62.34 imgs/sec, lr:0.000996980466



2022-09-20 10:53:06,908:INFO:epoch[9], iter[2800], loss:14.825684, fps:63.08 imgs/sec, lr:0.0009950118837878108
2022-09-20 10:54:49,199:INFO:epoch[9], iter[2900], loss:15.175616, fps:62.58 imgs/sec, lr:0.0009950118837878108
2022-09-20 10:56:34,367:INFO:epoch[9], iter[3000], loss:14.273780, fps:60.86 imgs/sec, lr:0.0009950118837878108
2022-09-20 10:58:11,406:INFO:epoch[10], iter[3100], loss:14.472479, fps:65.97 imgs/sec, lr:0.0009938441216945648
2022-09-20 10:59:49,929:INFO:epoch[10], iter[3200], loss:14.289655, fps:64.97 imgs/sec, lr:0.0009938441216945648




2022-09-20 11:01:31,644:INFO:epoch[10], iter[3300], loss:13.883628, fps:62.93 imgs/sec, lr:0.0009938441216945648
2022-09-20 11:03:12,175:INFO:epoch[10], iter[3400], loss:13.722390, fps:63.67 imgs/sec, lr:0.0009938441216945648
2022-09-20 11:04:51,569:INFO:epoch[11], iter[3500], loss:13.255777, fps:64.40 imgs/sec, lr:0.0009925547055900097
2022-09-20 11:06:33,272:INFO:epoch[11], iter[3600], loss:12.776893, fps:62.95 imgs/sec, lr:0.0009925547055900097




2022-09-20 11:08:19,945:INFO:epoch[11], iter[3700], loss:12.947203, fps:60.00 imgs/sec, lr:0.0009925547055900097
2022-09-20 11:10:06,181:INFO:epoch[12], iter[3800], loss:13.198394, fps:60.25 imgs/sec, lr:0.0009911436354741454
2022-09-20 11:11:50,626:INFO:epoch[12], iter[3900], loss:12.565134, fps:61.28 imgs/sec, lr:0.0009911436354741454
2022-09-20 11:13:32,452:INFO:epoch[12], iter[4000], loss:13.112863, fps:62.86 imgs/sec, lr:0.0009911436354741454




2022-09-20 11:15:10,506:INFO:epoch[13], iter[4100], loss:12.020436, fps:65.28 imgs/sec, lr:0.0009896113770082593




2022-09-20 11:16:52,193:INFO:epoch[13], iter[4200], loss:12.215127, fps:62.94 imgs/sec, lr:0.0009896113770082593
2022-09-20 11:18:38,428:INFO:epoch[13], iter[4300], loss:11.903435, fps:60.26 imgs/sec, lr:0.0009896113770082593
2022-09-20 11:20:21,811:INFO:epoch[14], iter[4400], loss:12.418098, fps:61.92 imgs/sec, lr:0.0009879583958536386
2022-09-20 11:22:02,791:INFO:epoch[14], iter[4500], loss:11.866977, fps:63.38 imgs/sec, lr:0.0009879583958536386
2022-09-20 11:23:43,557:INFO:epoch[14], iter[4600], loss:11.626442, fps:63.54 imgs/sec, lr:0.0009879583958536386
2022-09-20 11:25:25,499:INFO:epoch[15], iter[4700], loss:11.355380, fps:62.79 imgs/sec, lr:0.0009861849248409271
2022-09-20 11:27:09,487:INFO:epoch[15], iter[4800], loss:11.550692, fps:61.55 imgs/sec, lr:0.0009861849248409271
2022-09-20 11:28:52,947:INFO:epoch[15], iter[4900], loss:10.987218, fps:61.87 imgs/sec, lr:0.0009861849248409271
2022-09-20 11:30:29,842:INFO:epoch[16], iter[5000], loss:11.131780, fps:66.06 imgs/sec, lr:0.000



2022-09-20 11:54:18,786:INFO:epoch[20], iter[6400], loss:9.550386, fps:61.43 imgs/sec, lr:0.0009755282662808895
2022-09-20 11:55:56,710:INFO:epoch[20], iter[6500], loss:9.275749, fps:65.36 imgs/sec, lr:0.0009755282662808895
2022-09-20 11:57:38,028:INFO:epoch[21], iter[6600], loss:9.616993, fps:63.18 imgs/sec, lr:0.0009730426827445626
2022-09-20 11:59:24,064:INFO:epoch[21], iter[6700], loss:9.467373, fps:60.36 imgs/sec, lr:0.0009730426827445626
2022-09-20 12:01:04,875:INFO:epoch[21], iter[6800], loss:9.421648, fps:63.49 imgs/sec, lr:0.0009730426827445626
2022-09-20 12:02:42,376:INFO:epoch[22], iter[6900], loss:9.296675, fps:65.65 imgs/sec, lr:0.0009704403928481042




2022-09-20 12:04:28,076:INFO:epoch[22], iter[7000], loss:9.301019, fps:60.57 imgs/sec, lr:0.0009704403928481042
2022-09-20 12:06:16,576:INFO:epoch[22], iter[7100], loss:9.441888, fps:59.00 imgs/sec, lr:0.0009704403928481042
2022-09-20 12:08:03,558:INFO:epoch[23], iter[7200], loss:9.159651, fps:59.84 imgs/sec, lr:0.0009677220368757844
2022-09-20 12:09:40,567:INFO:epoch[23], iter[7300], loss:9.599886, fps:65.99 imgs/sec, lr:0.0009677220368757844
2022-09-20 12:11:23,471:INFO:epoch[23], iter[7400], loss:9.491762, fps:62.21 imgs/sec, lr:0.0009677220368757844
2022-09-20 12:13:06,860:INFO:epoch[24], iter[7500], loss:9.091850, fps:61.92 imgs/sec, lr:0.0009648882551118731
2022-09-20 12:14:47,589:INFO:epoch[24], iter[7600], loss:8.817103, fps:63.55 imgs/sec, lr:0.0009648882551118731
2022-09-20 12:16:31,785:INFO:epoch[24], iter[7700], loss:8.830323, fps:61.43 imgs/sec, lr:0.0009648882551118731
2022-09-20 12:18:15,460:INFO:epoch[25], iter[7800], loss:8.967200, fps:61.74 imgs/sec, lr:0.000961939746



2022-09-20 12:35:28,450:INFO:epoch[28], iter[8800], loss:8.693801, fps:61.99 imgs/sec, lr:0.0009524135384708643
2022-09-20 12:37:13,638:INFO:epoch[28], iter[8900], loss:8.807833, fps:60.85 imgs/sec, lr:0.0009524135384708643
2022-09-20 12:38:54,670:INFO:epoch[29], iter[9000], loss:8.261140, fps:63.36 imgs/sec, lr:0.0009490138036198914
2022-09-20 12:40:38,190:INFO:epoch[29], iter[9100], loss:8.150254, fps:61.84 imgs/sec, lr:0.0009490138036198914
2022-09-20 12:42:21,159:INFO:epoch[29], iter[9200], loss:8.197369, fps:62.16 imgs/sec, lr:0.0009490138036198914
2022-09-20 12:44:04,960:INFO:epoch[30], iter[9300], loss:8.088545, fps:61.66 imgs/sec, lr:0.0009455032413825393
2022-09-20 12:45:47,999:INFO:epoch[30], iter[9400], loss:7.831754, fps:62.13 imgs/sec, lr:0.0009455032413825393
2022-09-20 12:47:32,501:INFO:epoch[30], iter[9500], loss:8.109892, fps:61.26 imgs/sec, lr:0.0009455032413825393
2022-09-20 12:49:17,410:INFO:epoch[30], iter[9600], loss:8.026987, fps:61.02 imgs/sec, lr:0.000945503241