src/sparseml/pytorch/torchvision/train.py (210 changes: 175 additions, 35 deletions)
@@ -23,16 +23,18 @@
import warnings
from functools import update_wrapper
from types import SimpleNamespace
-from typing import Callable
+from typing import Callable, Optional

import torch
import torch.utils.data
import torchvision
+from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader, default_collate
from torchvision.transforms.functional import InterpolationMode

import click
+from sparseml.optim.helpers import load_recipe_yaml_str
from sparseml.pytorch.models.registry import ModelRegistry
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.torchvision import presets, transforms, utils
@@ -64,6 +66,7 @@ def train_one_epoch(
epoch: int,
args,
log_metrics_fn: Callable[[str, utils.MetricLogger, int, int], None],
+manager=None,
model_ema=None,
scaler=None,
) -> utils.MetricLogger:
@@ -92,13 +95,24 @@
start_time = time.time()
image, target = image.to(device), target.to(device)
with torch.cuda.amp.autocast(enabled=scaler is not None):
-output = model(image)
+outputs = output = model(image)
if isinstance(output, tuple):
# NOTE: sparseml models return two things (logits & probs)
output = output[0]
loss = criterion(output, target)

if steps_accumulated % accum_steps == 0:
+if manager is not None:
+loss = manager.loss_update(
+loss=loss,
+module=model,
+optimizer=optimizer,
+epoch=epoch,
+steps_per_epoch=len(data_loader) / accum_steps,
+student_outputs=outputs,
+student_inputs=image,
+)

# first: do training to consume gradients
if scaler is not None:
scaler.scale(loss).backward()
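
The `manager.loss_update` call above gives an active recipe the chance to rewrite the task loss before backprop, which is how a distillation term gets blended in once a teacher is attached. As a rough sketch of what such a hook can compute (a hypothetical simplification, not SparseML's actual implementation):

```python
import torch
import torch.nn.functional as F


def distillation_loss(
    task_loss: torch.Tensor,
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    temperature: float = 2.0,
    alpha: float = 0.5,
) -> torch.Tensor:
    """Blend the hard-label loss with a KL distillation term (hypothetical)."""
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # batchmean matches the mathematical definition of KL divergence;
    # the T^2 factor restores gradient magnitude after temperature scaling.
    kd = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * temperature**2
    return alpha * kd + (1.0 - alpha) * task_loss
```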
@@ -126,11 +140,17 @@
# Reset ema buffer to keep copying weights during warmup period
model_ema.n_averaged.fill_(0)

-acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+acc1, num_correct_1, acc5, num_correct_5 = utils.accuracy(
+output, target, topk=(1, 5)
+)
batch_size = image.shape[0]
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
metric_logger.meters["acc1"].update(
acc1.item(), n=batch_size, total=num_correct_1
)
metric_logger.meters["acc5"].update(
acc5.item(), n=batch_size, total=num_correct_5
)
metric_logger.meters["imgs_per_sec"].update(
batch_size / (time.time() - start_time)
)
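
Both loops now unpack four values from `utils.accuracy`, so the helper has to return raw correct counts alongside the top-k percentages (the counts let the meters aggregate exactly across uneven batch sizes). A plausible shape for the modified helper, assuming it mirrors torchvision's reference implementation (a sketch, not the verbatim SparseML code):

```python
import torch


def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)):
    """Return (acc_k, num_correct_k) pairs for each requested k."""
    with torch.inference_mode():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        correct = pred.t().eq(target.view(1, -1).expand_as(pred))
        results = []
        for k in topk:
            num_correct = correct[:k].reshape(-1).sum()
            results.append(num_correct.float().mul(100.0 / batch_size))  # percent
            results.append(num_correct)
        return results
```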
@@ -168,13 +188,19 @@ def evaluate(
output = output[0]
loss = criterion(output, target)

-acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+acc1, num_correct_1, acc5, num_correct_5 = utils.accuracy(
+output, target, topk=(1, 5)
+)
# FIXME need to take into account that the datasets
# could have been padded in distributed setup
batch_size = image.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
metric_logger.meters["acc1"].update(
acc1.item(), n=batch_size, total=num_correct_1
)
metric_logger.meters["acc5"].update(
acc5.item(), n=batch_size, total=num_correct_5
)
num_processed_samples += batch_size
# gather the stats from all processes
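
Each rank only sees its shard of the validation set, so per-process meters must be summed before computing global averages. A minimal sketch of the reduction this step relies on (hypothetical helper; the real logic lives in `utils.MetricLogger`):

```python
import torch
import torch.distributed as dist


def reduce_across_processes(value: float, device: torch.device) -> torch.Tensor:
    """Sum a scalar meter across all ranks so averages use global counts."""
    tensor = torch.tensor(value, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor
```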

@@ -355,32 +381,45 @@ def collate_fn(batch):
)

_LOGGER.info("Creating model")
-if args.arch_key in ModelRegistry.available_keys():
-with torch_distributed_zero_first(args.rank if args.distributed else None):
-model = ModelRegistry.create(
-key=args.arch_key,
-pretrained=args.pretrained,
-pretrained_path=args.checkpoint_path,
-pretrained_dataset=args.pretrained_dataset,
-num_classes=num_classes,
-)
-elif args.arch_key in torchvision.models.__dict__:
-# fall back to torchvision
-model = torchvision.models.__dict__[args.arch_key](
-pretrained=args.pretrained, num_classes=num_classes
+local_rank = args.rank if args.distributed else None
+model, arch_key = _create_model(
+arch_key=args.arch_key,
+local_rank=local_rank,
+pretrained=args.pretrained,
+checkpoint_path=args.checkpoint_path,
+pretrained_dataset=args.pretrained_dataset,
+device=device,
+num_classes=num_classes,
)

+if args.distill_teacher not in ["self", "disable", None]:
+_LOGGER.info("Instantiating teacher")
+distill_teacher, _ = _create_model(
+arch_key=args.teacher_arch_key,
+local_rank=local_rank,
+pretrained=True,  # teacher is always pretrained
+pretrained_dataset=args.pretrained_teacher_dataset,
+checkpoint_path=args.distill_teacher,
+device=device,
+num_classes=num_classes,
+)
-if args.checkpoint_path is not None:
-load_model(args.checkpoint_path, model, strict=True)
else:
-raise ValueError(
-f"Unable to find {args.arch_key} in ModelRegistry or in torchvision.models"
-)
-model.to(device)
+distill_teacher = args.distill_teacher

if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

-criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+if version.parse(torch.__version__) >= version.parse("1.10"):
+criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+elif args.label_smoothing > 0:
+raise ValueError(
+f"`label_smoothing` is not supported for torch {torch.__version__}; "
+"try upgrading to at least 1.10"
+)
+else:
+criterion = nn.CrossEntropyLoss()

custom_keys_weight_decay = []
if args.bias_weight_decay is not None:
@@ -467,7 +506,7 @@ def collate_fn(batch):
)
checkpoint_manager = (
ScheduledModifierManager.from_yaml(checkpoint["recipe"])
if "recipe" in checkpoint
if "recipe" in checkpoint and checkpoint["recipe"] is not None
else None
)
elif args.resume:
@@ -495,8 +534,15 @@ def collate_fn(batch):

# load params
if checkpoint is not None:
if "optimizer" in checkpoint:
optimizer.load_state_dict(checkpoint["optimizer"])
if "optimizer" in checkpoint and not args.test_only:
if args.resume:
optimizer.load_state_dict(checkpoint["optimizer"])
else:
warnings.warn(
"Optimizer state dict not loaded from checkpoint. Unless run is "
"resumed with the --resume arg, the optimizer will start from a "
"fresh state"
)
if model_ema and "model_ema" in checkpoint:
model_ema.load_state_dict(checkpoint["model_ema"])
if scaler and "scaler" in checkpoint:
@@ -532,13 +578,26 @@ def collate_fn(batch):
TensorBoardLogger(log_path=args.output_dir),
]
try:
-loggers.append(WANDBLogger())
+config = vars(args)
+if manager is not None:
+config["manager"] = str(manager)
+loggers.append(WANDBLogger(init_kwargs=dict(config=config)))
except ImportError:
warnings.warn("Unable to import wandb for logging")
logger = LoggerManager(loggers)
else:
logger = LoggerManager(log_python=False)

+if args.recipe is not None:
+base_path = os.path.join(args.output_dir, "original_recipe.yaml")
+with open(base_path, "w") as fp:
+fp.write(load_recipe_yaml_str(args.recipe))
+logger.save(base_path)

+full_path = os.path.join(args.output_dir, "final_recipe.yaml")
+manager.save(full_path)
+logger.save(full_path)
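
`load_recipe_yaml_str` is what lets the block above snapshot the recipe regardless of its source: it resolves its argument (a local file path, or a SparseZoo stub as far as the sparseml helpers allow) to the raw YAML text. A hypothetical call:

```python
# hypothetical path; the helper returns the recipe's YAML as a plain string
yaml_str = load_recipe_yaml_str("recipes/pruning_quantization.yaml")
print(yaml_str.splitlines()[0])
```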

steps_per_epoch = len(data_loader) / args.gradient_accum_steps

def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: int):
@@ -549,10 +608,23 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: int):
)

if manager is not None:
-manager.initialize(model, epoch=args.start_epoch, loggers=logger)
-optimizer = manager.modify(
-model, optimizer, steps_per_epoch=steps_per_epoch, epoch=args.start_epoch
+manager.initialize(
+model,
+epoch=args.start_epoch,
+loggers=logger,
+distillation_teacher=distill_teacher,
)
+step_wrapper = manager.modify(
+model,
+optimizer,
+steps_per_epoch=steps_per_epoch,
+epoch=args.start_epoch,
+wrap_optim=scaler,
+)
+if scaler is None:
+optimizer = step_wrapper
+else:
+scaler = step_wrapper
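
`manager.modify` returns a wrapper around whichever object actually drives stepping: the optimizer in full precision, or the `GradScaler` when AMP is on (`wrap_optim=scaler`). That way recipe modifiers fire exactly once per real optimizer step, even when gradient scaling can skip steps. A bare-bones sketch of the wrapping idea (hypothetical class, not SparseML's implementation):

```python
class StepWrapper:
    """Invoke a callback on every step of the wrapped optimizer/scaler."""

    def __init__(self, inner, on_step):
        self._inner = inner
        self._on_step = on_step  # e.g. advance pruning/LR modifiers

    def step(self, *args, **kwargs):
        self._on_step()
        return self._inner.step(*args, **kwargs)

    def __getattr__(self, name):
        # delegate everything else (zero_grad, update, state_dict, ...)
        return getattr(self._inner, name)
```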

lr_scheduler = _get_lr_scheduler(
args, optimizer, checkpoint=checkpoint, manager=manager
@@ -573,7 +645,8 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: int):
if args.distributed:
train_sampler.set_epoch(epoch)
if manager is not None and manager.qat_active(epoch=epoch):
-scaler = None
+if scaler is not None:
+scaler._enabled = False
model_ema = None

train_metrics = train_one_epoch(
@@ -586,6 +659,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: int):
epoch,
args,
log_metrics,
+manager=manager,
model_ema=model_ema,
scaler=scaler,
)
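
The `qat_active` branch above switches AMP off in place once quantization-aware training starts, since fake-quantization observers generally expect float32 tensors. Disabling the existing scaler, rather than dropping it, keeps any manager-held reference to it valid; note that `_enabled` is a private `torch.cuda.amp.GradScaler` attribute, so this mirrors the diff rather than a public API. A sketch of the pattern:

```python
import torch


def maybe_disable_amp(scaler: "torch.cuda.amp.GradScaler", qat_active: bool) -> None:
    # Flip the scaler off in place so wrapped references stay intact.
    if qat_active and scaler is not None and scaler.is_enabled():
        scaler._enabled = False  # private attribute, as in the change above
```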
@@ -616,6 +690,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: int):
"state_dict": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"args": args,
"arch_key": arch_key,
}
if lr_scheduler:
checkpoint["lr_scheduler"] = lr_scheduler.state_dict()
@@ -635,7 +710,8 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: int):
)
else:
checkpoint["epoch"] = -1 if epoch == max_epochs - 1 else epoch
checkpoint["recipe"] = str(manager)
if str(manager) is not None:
checkpoint["recipe"] = str(manager)

file_names = ["checkpoint.pth"]
if is_new_best:
@@ -657,6 +733,42 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: int):
_LOGGER.info(f"Training time {total_time_str}")


+def _create_model(
+arch_key: Optional[str] = None,
+local_rank=None,
+pretrained: Optional[bool] = False,
+checkpoint_path: Optional[str] = None,
+pretrained_dataset: Optional[str] = None,
+device=None,
+num_classes=None,
+):
+if not arch_key or arch_key in ModelRegistry.available_keys():
+with torch_distributed_zero_first(local_rank):
+model = ModelRegistry.create(
+key=arch_key,
+pretrained=pretrained,
+pretrained_path=checkpoint_path,
+pretrained_dataset=pretrained_dataset,
+num_classes=num_classes,
+)

+if isinstance(model, tuple):
+model, arch_key = model
+elif arch_key in torchvision.models.__dict__:
+# fall back to torchvision
+model = torchvision.models.__dict__[arch_key](
+pretrained=pretrained, num_classes=num_classes
+)
+if checkpoint_path is not None:
+load_model(checkpoint_path, model, strict=True)
+else:
+raise ValueError(
+f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
+)
+model.to(device)
+return model, arch_key
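
`_create_model` now centralizes the registry-versus-torchvision fallback for both student and teacher, and `torch_distributed_zero_first` makes rank 0 materialize any pretrained weights before the other ranks proceed. A hypothetical call exercising the torchvision fallback path:

```python
# illustrative only; assumes "wide_resnet50_2" is absent from ModelRegistry
# but present in torchvision.models
model, resolved_key = _create_model(
    arch_key="wide_resnet50_2",
    local_rank=None,  # single-process run
    pretrained=False,
    checkpoint_path=None,
    device="cpu",
    num_classes=10,
)
```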


def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
lr_scheduler = None

@@ -1039,6 +1151,34 @@ def new_func(*args, **kwargs):
help="Save the best validation result after the given "
"epoch completes until the end of training",
)
+@click.option(
+"--distill-teacher",
+default=None,
+type=str,
+help="Teacher model for distillation: a trained image classification model "
+"or a SparseZoo stub. Set to 'self' for self-distillation or 'disable' "
+"to turn distillation off",
+)
+@click.option(
+"--pretrained-teacher-dataset",
+default=None,
+type=str,
+help=(
+"The dataset to load pretrained teacher weights for; loads the "
+"architecture's default dataset if set to None. "
+"Examples: `imagenet`, `cifar10`"
+),
+)
+@click.option(
+"--teacher-arch-key",
+default=None,
+type=str,
+help=(
+"The architecture key for the teacher image classification model; "
+"examples: `resnet50`, `mobilenet`. "
+"Note: will be read from the checkpoint if not specified"
+),
+)
@click.pass_context
def cli(ctx, **kwargs):
"""
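
Taken together, the new options attach a distillation teacher to the training entry point. A hypothetical invocation (the console-script name and all paths are assumptions, not part of this diff):

```bash
sparseml.image_classification.train \
  --recipe ./recipe.yaml \
  --arch-key resnet50 \
  --distill-teacher ./teacher_checkpoint.pth \
  --teacher-arch-key resnet101 \
  --pretrained-teacher-dataset imagenet
```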