This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Conversation

@rahul-tuli (Member) commented Feb 1, 2023

Torchvision Integration Sparse Transfer Learn Bugfix

Current State

As of now, our torchvision integration has a bug where, during
Sparse Transfer Learning, the number of output classes can be
mismatched between the saved optimizer state_dict (from pretraining)
and the re-created optimizer (when starting finetuning).

This leads to broken flows and errors whenever the number of output
classes differs between the upstream and downstream datasets.

For example:

Sparse Transfer Learning a resnet50 model originally trained on ImageNet to
Imagenette

COMMAND:

sparseml.image_classification.train \
    --recipe "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned95-none?recipe_type=transfer-classification" \
    --checkpoint-path "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned95_quant-none" \
    --arch-key resnet50 \
    --dataset-path /home/XXXXX/datasets/imagenette-160

ERROR:

Not using distributed mode
2023-01-31 13:15:44 __main__     INFO     namespace(amp=False, arch_key='resnet50', augmix_severity=3, auto_augment=None, batch_size=32, bias_weight_decay=None, cache_dataset=False, checkpoint_path='zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned95_quant-none', clip_grad_norm=None, cutmix_alpha=0.0, dataset_path='/home/XXXXX/datasets/imagenette-160', device='cuda', dist_url='env://', distill_teacher=None, distributed=False, epochs=10, eval_steps=None, gradient_accum_steps=1, interpolation='bilinear', label_smoothing=0.0, logging_steps=100, lr=0.1, lr_gamma=0.1, lr_min=0.0, lr_scheduler='steplr', lr_step_size=30, lr_warmup_decay=0.01, lr_warmup_epochs=0, lr_warmup_method='constant', mixup_alpha=0.0, model_ema=False, model_ema_decay=0.99998, model_ema_steps=32, momentum=0.9, norm_weight_decay=None, opt='sgd', output_dir='.', pretrained='True', pretrained_dataset=None, pretrained_teacher_dataset=None, print_freq=None, ra_magnitude=9, ra_reps=3, ra_sampler=False, random_erase=0.0, recipe='zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned95-none?recipe_type=transfer-classification', recipe_args=None, resume=None, save_best_after=1, start_epoch=0, sync_bn=False, teacher_arch_key=None, test_only=False, train_crop_size=224, transformer_embedding_decay=None, use_deterministic_algorithms=False, val_crop_size=224, val_resize_size=256, weight_decay=0.0001, workers=16, world_size=1)
2023-01-31 13:15:44 __main__     INFO     Loading data
2023-01-31 13:15:44 __main__     INFO     Loading training data
2023-01-31 13:15:44 __main__     INFO     Took 0.024486064910888672
2023-01-31 13:15:44 __main__     INFO     Loading validation data
2023-01-31 13:15:44 __main__     INFO     Creating data loaders
2023-01-31 13:15:44 __main__     INFO     Creating model
/home/XXXXX/projects/sparseml/src/sparseml/pytorch/torchvision/train.py:568: UserWarning: Unable to import wandb for logging
  warnings.warn("Unable to import wandb for logging")
2023-01-31 13:15:56 __main__     INFO     Start training
INFO:__main__:Start training
Traceback (most recent call last):
  File "/home/XXXXX/projects/sparseml/src/sparseml/pytorch/torchvision/train.py", line 1175, in <module>
    cli()
  File "/home/XXXXX/projects/sparseml/src/sparseml/pytorch/torchvision/train.py", line 852, in new_func
    return f(*args, **kwargs)
  File "/home/XXXXX/virtual_environments/sparseml3.8/lib/python3.8/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/home/XXXXX/virtual_environments/sparseml3.8/lib/python3.8/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/home/XXXXX/virtual_environments/sparseml3.8/lib/python3.8/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/XXXXX/virtual_environments/sparseml3.8/lib/python3.8/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/home/XXXXX/virtual_environments/sparseml3.8/lib/python3.8/site-packages/click/decorators.py", line 26, in new_func
    return f(get_current_context(), *args, **kwargs)
  File "/home/XXXXX/projects/sparseml/src/sparseml/pytorch/torchvision/train.py", line 1171, in cli
    main(SimpleNamespace(**kwargs))
  File "/home/XXXXX/projects/sparseml/src/sparseml/pytorch/torchvision/train.py", line 634, in main
    train_metrics = train_one_epoch(
  File "/home/XXXXX/projects/sparseml/src/sparseml/pytorch/torchvision/train.py", line 129, in train_one_epoch
    optimizer.step()
  File "/home/XXXXX/projects/sparseml/src/sparseml/pytorch/optim/manager.py", line 172, in step
    return self._perform_wrapped_step(*args, **kwargs)
  File "/home/XXXXX/projects/sparseml/src/sparseml/pytorch/optim/manager.py", line 223, in _perform_wrapped_step
    ret = self._wrapped.step(*args, **kwargs)
  File "/home/XXXXX/virtual_environments/sparseml3.8/lib/python3.8/site-packages/torch/optim/optimizer.py", line 109, in wrapper
    return func(*args, **kwargs)
  File "/home/XXXXX/virtual_environments/sparseml3.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/XXXXX/virtual_environments/sparseml3.8/lib/python3.8/site-packages/torch/optim/sgd.py", line 146, in step
    sgd(params_with_grad,
  File "/home/XXXXX/virtual_environments/sparseml3.8/lib/python3.8/site-packages/torch/optim/sgd.py", line 197, in sgd
    func(params,
  File "/home/XXXXX/virtual_environments/sparseml3.8/lib/python3.8/site-packages/torch/optim/sgd.py", line 233, in _single_tensor_sgd
    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
RuntimeError: The size of tensor a (1000) must match the size of tensor b (10) at non-singleton dimension 0
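
The root cause can be reproduced with plain PyTorch, independent of SparseML: SGD momentum buffers keep the shape of the parameters they were created for, and `Optimizer.load_state_dict` does not validate tensor shapes, so restoring buffers saved for a 1000-class head into an optimizer wrapping a 10-class head only fails later, inside `step()`. A minimal sketch (a 2048-feature linear head standing in for the ResNet-50 classifier; all names here are illustrative):

```python
import torch

# "Upstream" classifier head: 1000 classes (ImageNet-like)
upstream_head = torch.nn.Linear(2048, 1000)
upstream_opt = torch.optim.SGD(upstream_head.parameters(), lr=0.1, momentum=0.9)
upstream_head(torch.randn(4, 2048)).sum().backward()
upstream_opt.step()  # creates momentum buffers shaped (1000, 2048) / (1000,)
saved_state = upstream_opt.state_dict()

# "Downstream" classifier head: 10 classes (Imagenette-like)
downstream_head = torch.nn.Linear(2048, 10)
downstream_opt = torch.optim.SGD(downstream_head.parameters(), lr=0.1, momentum=0.9)
downstream_opt.load_state_dict(saved_state)  # no shape check happens here

downstream_head(torch.randn(4, 2048)).sum().backward()
err = None
try:
    downstream_opt.step()  # stale (1000, ...) buffers meet (10, ...) grads
except RuntimeError as exc:
    err = exc
print(type(err).__name__, err)
```

The failure surfaces exactly where the traceback above points: the momentum update mixes a buffer sized for 1000 classes with a gradient sized for 10.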

Proposed Fix

The proposed fix is to load the optimizer state_dict only when a
previous training run is being resumed, i.e. when the --resume flag
is set. A finetuning run should be treated as a new run, where the
optimizer state does NOT need to be loaded.

We also raise a warning when the optimizer state_dict is not loaded.
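
The gist of the change can be sketched as a small guard around the optimizer restore (the helper name and signature below are illustrative, not the actual code in train.py):

```python
import warnings

import torch


def maybe_load_optimizer_state(optimizer, checkpoint, resume):
    """Restore optimizer state only when resuming an interrupted run.

    Finetuning/transfer runs keep a fresh optimizer, since the saved
    momentum buffers may be shaped for a different number of output classes.
    """
    if "optimizer" not in checkpoint:
        return
    if resume:
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        warnings.warn(
            "Checkpoint contains optimizer state, but it was not loaded; "
            "pass --resume to restore it when continuing a previous run."
        )


# Demo: build a checkpoint, then restore with and without resume
model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(1, 4)).sum().backward()
opt.step()
checkpoint = {"optimizer": opt.state_dict()}

finetune_opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    maybe_load_optimizer_state(finetune_opt, checkpoint, resume=False)

resume_opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
maybe_load_optimizer_state(resume_opt, checkpoint, resume=True)
```

With `resume=False` the optimizer keeps its freshly initialized (empty) state and a warning is emitted; with `resume=True` the saved momentum buffers are restored as before.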

After This Pull Request

The original command works as expected, even when the number of output
classes is mismatched between the upstream (ImageNet, 1000 classes) and
downstream (Imagenette, 10 classes) datasets.

OUTPUT:

Not using distributed mode
2023-02-01 12:23:40 __main__     INFO     namespace(amp=False, arch_key='resnet50', augmix_severity=3, auto_augment=None, batch_size=32, bias_weight_decay=None, cache_dataset=False, checkpoint_path='zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned95_quant-none', clip_grad_norm=None, cutmix_alpha=0.0, dataset_path='/home/XXXXX/datasets/imagenette-160', device='cuda', dist_url='env://', distill_teacher=None, distributed=False, epochs=10, eval_steps=None, gradient_accum_steps=1, interpolation='bilinear', label_smoothing=0.0, logging_steps=100, lr=0.1, lr_gamma=0.1, lr_min=0.0, lr_scheduler='steplr', lr_step_size=30, lr_warmup_decay=0.01, lr_warmup_epochs=0, lr_warmup_method='constant', mixup_alpha=0.0, model_ema=False, model_ema_decay=0.99998, model_ema_steps=32, momentum=0.9, norm_weight_decay=None, opt='sgd', output_dir='.', pretrained='True', pretrained_dataset=None, pretrained_teacher_dataset=None, print_freq=None, ra_magnitude=9, ra_reps=3, ra_sampler=False, random_erase=0.0, recipe='zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned95-none?recipe_type=transfer-classification', recipe_args=None, resume=None, save_best_after=1, start_epoch=0, sync_bn=False, teacher_arch_key=None, test_only=False, train_crop_size=224, transformer_embedding_decay=None, use_deterministic_algorithms=False, val_crop_size=224, val_resize_size=256, weight_decay=0.0001, workers=16, world_size=1)
2023-02-01 12:23:40 __main__     INFO     Loading data
2023-02-01 12:23:40 __main__     INFO     Loading training data
2023-02-01 12:23:40 __main__     INFO     Took 0.025422334671020508
2023-02-01 12:23:40 __main__     INFO     Loading validation data
2023-02-01 12:23:40 __main__     INFO     Creating data loaders
2023-02-01 12:23:40 __main__     INFO     Creating model
/home/XXXXX/projects/sparseml/src/sparseml/pytorch/torchvision/train.py:600: UserWarning: Unable to import wandb for logging
  warnings.warn("Unable to import wandb for logging")
2023-02-01 12:23:54 __main__     INFO     Start training
INFO:__main__:Start training
2023-02-01 12:23:56 __main__     INFO     Epoch: [0]  [  0/403]  eta: 0:15:02  lr: 0.0005  imgs_per_sec: 20.97494809307896  loss: 2.3581 (2.3581)  acc1: 6.2500 (6.2500)  acc5: 53.1250 (53.1250)  time: 2.2404  data: 0.7147  max mem: 8859
INFO:__main__:Epoch: [0]  [  0/403]  eta: 0:15:02  lr: 0.0005  imgs_per_sec: 20.97494809307896  loss: 2.3581 (2.3581)  acc1: 6.2500 (6.2500)  acc5: 53.1250 (53.1250)  time: 2.2404  data: 0.7147  max mem: 8859
2023-02-01 12:24:16 __main__     INFO     Epoch: [0]  [100/403]  eta: 0:01:05  lr: 0.0004992559437030161  imgs_per_sec: 164.01919577663585  loss: 0.2617 (0.7927)  acc1: 90.6250 (79.9505)  acc5: 96.8750 (94.9257)  time: 0.1959  data: 0.0001  max mem: 8859
INFO:__main__:Epoch: [0]  [100/403]  eta: 0:01:05  lr: 0.0004992559437030161  imgs_per_sec: 164.01919577663585  loss: 0.2617 (0.7927)  acc1: 90.6250 (79.9505)  acc5: 96.8750 (94.9257)  time: 0.1959  data: 0.0001  max mem: 8859

NOTE: The proposed fixes are in relation to this TICKET

@rahul-tuli rahul-tuli requested review from KSGulin, bfineran, corey-nm and dbogunowicz and removed request for bfineran February 1, 2023 17:45
@rahul-tuli rahul-tuli self-assigned this Feb 1, 2023
@rahul-tuli rahul-tuli added bug Something isn't working mle-team labels Feb 1, 2023
KSGulin previously approved these changes Feb 2, 2023
@rahul-tuli rahul-tuli merged commit a60a22e into main Feb 2, 2023
@rahul-tuli rahul-tuli deleted the torchvision-sparse-transfer-learn-bugfix branch February 2, 2023 16:15
rahul-tuli added a commit that referenced this pull request Feb 2, 2023
…rning (#1358)

* Add: an `_update_checkpoint_optimizer(...)` for deleting mismatching params from saved optimizer(s) state_dict

* Remove: _update_checkpoint_optimizer in favor of loading in the optim state_dict only when `args.resume` is set

* Remove: un-needed imports

* Address review comments

* Style
bfineran pushed a commit that referenced this pull request Feb 2, 2023
bfineran pushed a commit that referenced this pull request Feb 3, 2023