In [2]:
import os
import yaml
import argparse
from yolov7_utils.datasets import create_dataloader
from yolov7_utils.general import colorstr, check_img_size
from yolov7_utils.loss import ComputeLoss, ComputeLossOTA
# load supernet
from nas.supernet.supernet_yolov7 import YOLOSuperNet
# load zero-cost
from nas.predictors.zeorcost_predictor.zerocost import ZeroCost

In [3]:
# Run options, declared as (flag, add_argument keyword arguments) pairs.
# The declaration order is kept because it determines the order of the
# options in the parser's help text and of the attributes on the namespace.
_ARG_SPECS = (
    ('--cfg', dict(type=str, default='./yaml/yolov7_supernet.yml', help='model.yaml path')),
    ('--data', dict(type=str, default='./yaml/data/coco.yaml', help='data.yaml path')),  # coco
    ('--hyp', dict(type=str, default='./yaml/data/hyp.scratch.p5.yaml', help='hyperparameters path')),
    ('--single-cls', dict(action='store_true', help='train multi-class data as single-class')),
    ('--batch-size', dict(type=int, default=16, help='total batch size for all GPUs')),
    ('--img-size', dict(nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')),
    ('--cache-images', dict(action='store_true', help='cache images for faster training')),
    ('--rect', dict(action='store_true', help='rectangular training')),
    ('--workers', dict(type=int, default=0, help='maximum number of dataloader workers')),
    ('--image-weights', dict(action='store_true', help='use weighted image selection for training')),
    ('--quad', dict(action='store_true', help='quad dataloader')),
)

parser = argparse.ArgumentParser()
for _flag, _kwargs in _ARG_SPECS:
    parser.add_argument(_flag, **_kwargs)

# An explicit empty argument list is parsed (instead of sys.argv), so every
# option ends up with its declared default.
opt = parser.parse_args(args=[])

print(opt)

Namespace(batch_size=16, cache_images=False, cfg='./yaml/yolov7_supernet.yml', data='./yaml/data/coco.yaml', hyp='./yaml/data/hyp.scratch.p5.yaml', image_weights=False, img_size=[640, 640], quad=False, rect=False, single_cls=False, workers=0)


In [4]:
# Load the training hyperparameters and the dataset description from YAML.
# Both files are opened explicitly as UTF-8 so that decoding does not depend
# on the platform's default encoding.
with open(opt.hyp, encoding="UTF-8") as f:
    hyp = yaml.safe_load(f)  # load hyps
with open(opt.data, encoding="UTF-8") as f:
    data_dict = yaml.safe_load(f)  # data_dict

# NOTE(review): the device is hard-coded to the first CUDA GPU, so this cell
# requires a CUDA-capable machine.
device = 'cuda:0'
nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes

# Build the supernet from the model config and move it to the device.
# hyp.get('anchors') evaluates to None when the hyperparameter file has no
# 'anchors' entry.
supernet = YOLOSuperNet(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)
print(supernet)

YOLOSuperNet(
  (model): Sequential(
    (0): Conv(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (1): Conv(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (2): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (3): Conv(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (4): BBoneELAN(
      (layers): Sequential(
        (0

In [5]:
# Build the training dataloader and dataset for the supernet.
train_path = data_dict['train']

# Grid size: the model's largest stride, but never below 32.
gs = max(32, int(supernet.stride.max()))

# opt.img_size must hold exactly two entries ([train, test] image sizes);
# each one is passed through check_img_size together with the grid size.
imgsz, imgsz_test = [check_img_size(size, gs) for size in opt.img_size]
batch_size = opt.batch_size

# Distributed-run settings are read from the environment when present;
# otherwise they default to world size 1 and rank -1.
opt.world_size = int(os.environ.get('WORLD_SIZE', 1))
opt.global_rank = int(os.environ.get('RANK', -1))
rank = opt.global_rank

dataloader, dataset = create_dataloader(
    train_path, imgsz, batch_size, gs, opt,
    hyp=hyp,
    augment=True,
    cache=opt.cache_images,
    rect=opt.rect,
    rank=rank,
    world_size=opt.world_size,
    workers=opt.workers,
    image_weights=opt.image_weights,
    quad=opt.quad,
    prefix=colorstr('train: '),
)

[34m[1mtrain: [0mScanning 'coco/train2017' images and labels... 117266 found, 1021 missing, 0 empty, 0 corrupted: 100%|██████████| 118287/118287 [00:43<00:00, 2691.61it/s]


-----------------------------
118287
-----------------------------


In [5]:
# Attach training metadata to the supernet instance.
# NOTE(review): presumably these attributes are read by the loss object that
# is built from `supernet` afterwards — confirm against ComputeLossOTA.
supernet.gr = 1.0   # fixed ratio value; its exact meaning is not visible here — TODO confirm
supernet.hyp = hyp  # hyperparameters loaded from opt.hyp
supernet.nc = nc    # number of classes

In [6]:
# Loss object built from the supernet; it is handed to the predictor below.
compute_loss_ota = ComputeLossOTA(supernet)

# Zero-cost predictor configured with the 'l2_norm' method.
zc_pred = 'l2_norm'
zc_predictor = ZeroCost(method_type=zc_pred)

# The predictor is given the dataloader and the dataset together as one list.
train_loader = [dataloader, dataset]

# Query the predictor for the supernet's score.
score = zc_predictor.query(
    net=supernet,
    loss=compute_loss_ota,
    dataloader=train_loader,
)

9935
158957
