Skip to content
This repository was archived by the owner on Jun 4, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ pandas
# extras --------------------------------------
thop # FLOPS computation
pycocotools>=2.0 # COCO mAP
sparseml~=0.2
104 changes: 86 additions & 18 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from sparseml.pytorch.nn import replace_activations
from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
from sparseml.pytorch.utils import PythonLogger, TensorBoardLogger, ModuleExporter
from sparseml.pytorch.utils.quantization import skip_onnx_input_quantize

import test # import test.py to get mAP after each epoch
from models.experimental import attempt_load
from models.yolo import Model
Expand Down Expand Up @@ -59,6 +64,7 @@ def train(hyp, opt, device, tb_writer=None):
# Configure
plots = not opt.evolve # create plots
cuda = device.type != 'cpu'
half_precision = cuda and not opt.disable_amp
init_seeds(2 + rank)
with open(opt.data) as f:
data_dict = yaml.safe_load(f) # data dict
Expand Down Expand Up @@ -87,7 +93,7 @@ def train(hyp, opt, device, tb_writer=None):
ckpt = torch.load(weights, map_location=device) # load checkpoint
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
state_dict = ckpt['model'].float().state_dict() # to FP32
state_dict = ckpt['model'].float().state_dict() if isinstance(ckpt['model'], nn.Module) else ckpt['model']
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
model.load_state_dict(state_dict, strict=False) # load
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
Expand Down Expand Up @@ -141,7 +147,7 @@ def train(hyp, opt, device, tb_writer=None):
# plot_lr_scheduler(optimizer, scheduler, epochs)

# EMA
ema = ModelEMA(model) if rank in [-1, 0] else None
ema = ModelEMA(model, enabled=not opt.disable_ema) if rank in [-1, 0] else None

# Resume
start_epoch, best_fitness = 0, 0.0
Expand All @@ -153,8 +159,7 @@ def train(hyp, opt, device, tb_writer=None):

# EMA
if ema and ckpt.get('ema'):
ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
ema.updates = ckpt['updates']
ema.load_state_dict(ckpt)

# Results
if ckpt.get('training_results') is not None:
Expand Down Expand Up @@ -214,7 +219,8 @@ def train(hyp, opt, device, tb_writer=None):
# Anchors
if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
model.half().float() # pre-reduce anchor precision
if half_precision:
model.half().float() # pre-reduce anchor precision

# DDP mode
if cuda and rank != -1:
Expand All @@ -233,14 +239,50 @@ def train(hyp, opt, device, tb_writer=None):
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
model.names = names

# SparseML Integration
if opt.use_leaky_relu: # use LeakyReLU activations
model = replace_activations(model, 'lrelu', inplace=True)

qat = False
if opt.sparseml_recipe:
manager = ScheduledModifierManager.from_yaml(opt.sparseml_recipe)
optimizer = ScheduledOptimizer(
optimizer,
model if not is_parallel(model) else model.module,
manager,
steps_per_epoch=len(dataloader),
loggers=[PythonLogger(), TensorBoardLogger(writer=tb_writer)]
)
# override lr scheduler if recipe makes any LR updates
if manager.learning_rate_modifiers:
logger.info('Disabling LR scheduler, managing LR using SparseML recipe')
scheduler = None
# override num epochs if recipe explicitly modifies epoch range
if manager.epoch_modifiers and manager.max_epochs:
epochs = manager.max_epochs or epochs # override num_epochs
logger.info(f'Overriding number of epochs from SparseML manager to {manager.max_epochs}')
# mark that QAT will be applied, pickled QAT exports currently not supported
if manager.quantization_modifiers:
logger.info('Disabling pickling for model exports, QAT scheduled to run')
if not opt.use_leaky_relu:
logger.warning(
'QAT detected in sparsification recipe, but --use-leaky-relu not set; '
'quantized model may not run well with default activations'
)
qat = True
# make sure that sparsity structure is held during EMA updates
if ema and manager.pruning_modifiers:
ema.pruning_manager = manager

# Start training
t0 = time.time()
nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
# nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = amp.GradScaler(enabled=cuda)
if scheduler:
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = amp.GradScaler(enabled=half_precision)
compute_loss = ComputeLoss(model) # init loss class
logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
f'Using {dataloader.num_workers} dataloader workers\n'
Expand Down Expand Up @@ -286,7 +328,8 @@ def train(hyp, opt, device, tb_writer=None):
accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
for j, x in enumerate(optimizer.param_groups):
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
if scheduler:
x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
if 'momentum' in x:
x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

Expand All @@ -299,7 +342,7 @@ def train(hyp, opt, device, tb_writer=None):
imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

# Forward
with amp.autocast(enabled=cuda):
with amp.autocast(enabled=half_precision):
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if rank != -1:
Expand Down Expand Up @@ -342,7 +385,8 @@ def train(hyp, opt, device, tb_writer=None):

# Scheduler
lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
scheduler.step()
if scheduler:
scheduler.step()

# DDP process 0 or single-GPU
if rank in [-1, 0]:
Expand All @@ -354,15 +398,16 @@ def train(hyp, opt, device, tb_writer=None):
results, maps, times = test.test(data_dict,
batch_size=batch_size * 2,
imgsz=imgsz_test,
model=ema.ema,
model=ema.ema if ema.enabled else model,
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=save_dir,
verbose=nc < 50 and final_epoch,
plots=plots and final_epoch,
wandb_logger=wandb_logger,
compute_loss=compute_loss,
is_coco=is_coco)
is_coco=is_coco,
half_precision=half_precision)

# Write
with open(results_file, 'a') as f:
Expand All @@ -389,14 +434,16 @@ def train(hyp, opt, device, tb_writer=None):

# Save model
if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
ckpt_model = deepcopy(model.module if is_parallel(model) else model)
if qat:
ckpt_model = model.state_dict() # pickled QAT exports not currently supported
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': results_file.read_text(),
'model': deepcopy(model.module if is_parallel(model) else model).half(),
'ema': deepcopy(ema.ema).half(),
'updates': ema.updates,
'model': ckpt_model.half() if half_precision else ckpt_model,
'optimizer': optimizer.state_dict(),
'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
ckpt.update(ema.state_dict(half_precision=half_precision)) # add EMA model and updates if enabled

# Save last, best and delete
torch.save(ckpt, last)
Expand All @@ -422,23 +469,39 @@ def train(hyp, opt, device, tb_writer=None):
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
for m in (last, best) if best.exists() else (last): # speed, mAP tests
test_model = attempt_load(m, device) if not qat else model
results, _, _ = test.test(opt.data,
batch_size=batch_size * 2,
imgsz=imgsz_test,
conf_thres=0.001,
iou_thres=0.7,
model=attempt_load(m, device).half(),
model=test_model.half() if half_precision else test_model,
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=save_dir,
save_json=True,
plots=False,
is_coco=is_coco)
is_coco=is_coco,
half_precision=half_precision)

# ONNX export
if opt.export_onnx:
try:
onnx_path = f'{save_dir}/model.onnx'
logger.info(f'training complete, exporting ONNX to {onnx_path}')
export_model = model.module if is_parallel_model(model) else model
export_model.model[-1].export = True # do not export grid post-processing
exporter = ModuleExporter(export_model, save_dir)
exporter.export_onnx(torch.randn(1, 3, imgsz, imgsz), convert_qat=True)
if qat:
skip_onnx_input_quantize(onnx_path, onnx_path)
except Exception as e:
logger.warning(f'exception occurred during ONNX export, model not exported to ONNX. error message {e}')

# Strip optimizers
final = best if best.exists() else last # final model
for f in last, best:
if f.exists():
if f.exists() and not qat: # qat state dict incompatible
strip_optimizer(f) # strip optimizers
if opt.bucket:
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
Expand Down Expand Up @@ -489,6 +552,11 @@ def train(hyp, opt, device, tb_writer=None):
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
parser.add_argument('--sparseml-recipe', type=str, default=None, help='Path to a SparseML sparsification recipe, see <TODO: Link Here> for more information')
parser.add_argument('--use-leaky-relu', action='store_true', help='Override default SiLU activation with LeakyReLU')
parser.add_argument('--export-onnx', action='store_true', help='export final model to ONNX')
parser.add_argument('--disable-amp', action='store_true', help='Disable FP16 half precision (enabled by default)')
parser.add_argument('--disable-ema', action='store_true', help='Disable EMA model updates (enabled by default)')
opt = parser.parse_args()

# Set DDP variables
Expand Down
35 changes: 34 additions & 1 deletion utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,17 +276,38 @@ class ModelEMA:
GPU assignment and distributed training wrappers.
"""

def __init__(self, model, decay=0.9999, updates=0, enabled=True):
    """Build the EMA shadow copy of *model*.

    Args:
        model: model to shadow; unwrapped first if it is a DP/DDP wrapper.
        decay: asymptotic EMA decay rate.
        updates: initial update counter (used to ramp the decay).
        enabled: when False, all EMA operations become no-ops.
    """
    # Keep an FP32 copy in eval mode; DP/DDP wrappers are stripped first.
    base_model = model.module if is_parallel(model) else model
    self.ema = deepcopy(base_model).eval()
    self.updates = updates  # number of EMA updates performed so far
    # Exponential ramp of the decay so early epochs contribute more.
    self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
    self.enabled = enabled
    # Optional SparseML manager used to re-apply pruning masks after updates.
    self.pruning_manager = None
    # The shadow copy is never trained directly.
    for p in self.ema.parameters():
        p.requires_grad_(False)

def state_dict(self, half_precision=True):
    """Return a checkpoint fragment holding the EMA model and update count.

    Args:
        half_precision: cast the copied EMA model to FP16 before returning.

    Returns:
        dict with 'ema' (deep copy of the shadow model) and 'updates',
        or an empty dict when EMA is disabled.
    """
    if not self.enabled:
        return {}

    # Deep-copy so the live EMA model is never cast in place.
    ema_copy = deepcopy(self.ema)
    if half_precision:
        ema_copy = ema_copy.half()
    return {'ema': ema_copy, 'updates': self.updates}

def load_state_dict(self, state_dict):
    """Restore EMA weights and the update counter from a checkpoint dict.

    Missing keys are silently skipped, so partial checkpoints are accepted.

    Args:
        state_dict: checkpoint dict; may contain 'ema' (a module whose FP32
            state dict is loaded into the shadow model) and/or 'updates'.
    """
    if 'ema' in state_dict:
        # Load in FP32 regardless of the precision the checkpoint was saved in.
        restored = state_dict['ema'].float().state_dict()
        self.ema.load_state_dict(restored)
    if 'updates' in state_dict:
        self.updates = state_dict['updates']

def update(self, model):
if not self.enabled:
return

# Update EMA parameters
with torch.no_grad():
self.updates += 1
Expand All @@ -299,5 +320,17 @@ def update(self, model):
v += (1. - d) * msd[k].detach()

def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
    """Copy non-parameter attributes from *model* onto the EMA model.

    When a SparseML pruning manager is attached, its mask state is snapshotted
    before the copy and restored afterwards so sparsity structure survives.

    Args:
        model: source model whose attributes are copied.
        include: attribute names to copy (empty tuple means all).
        exclude: attribute names to skip.
    """
    if not self.enabled:
        return

    # Snapshot pruning masks so the attribute copy cannot disturb them.
    manager = self.pruning_manager
    saved_masks = manager.state_dict() if manager is not None else None

    copy_attr(self.ema, model, include, exclude)

    # Re-apply the sparsity structure captured above.
    if manager is not None:
        manager.load_state_dict(saved_masks)