This notebook aims to reproduce Budgeted Training of Vision Transformers from Gao Huang's group. We start from the existing train.py and modify it.

First come the required imports and the argument handling. This part can be collapsed.

In [2]:
#!/usr/bin/env python3
""" ImageNet Training Script

This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.

This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)

NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""

"""
ADDITIONAL NOTES BY fish:
    - This script is a modified version of the original train.py script from Ross Wightman's timm library.
    - for the current version, the 
"""

import argparse
import logging
import os
import time
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from functools import partial

import torch
import torch.nn as nn
import torchvision.utils
import yaml
from torch.nn.parallel import DistributedDataParallel as NativeDDP

from timm import utils
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler_v2, scheduler_kwargs
from timm.utils import ApexScaler, NativeScaler

try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

try:
    import wandb
    has_wandb = True
except ImportError:
    has_wandb = False

try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError as e:
    has_functorch = False

# fish: second-order preconditioner (Eva)
from eva import KFAC as Eva
from eva import KFACParamScheduler

has_compile = hasattr(torch, 'compile')


_logger = logging.getLogger('train')

# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')


parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# Dataset parameters
group = parser.add_argument_group('Dataset parameters')
# Keep this argument outside the dataset group because it is positional.
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
                    help='path to dataset (positional is *deprecated*, use --data-dir)')
parser.add_argument('--data-dir', metavar='DIR',
                    help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train',
                   help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
                   help='dataset validation split (default: validation)')
group.add_argument('--dataset-download', action='store_true', default=False,
                   help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                   help='path to class to idx mapping file (default: "")')

# Model parameters
group = parser.add_argument_group('Model parameters')
group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
                   help='Name of model to train (default: "resnet50")')
group.add_argument('--pretrained', action='store_true', default=False,
                   help='Start with pretrained version of specified network (if avail)')
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                   help='Initialize model from this checkpoint (default: none)')
group.add_argument('--resume', default='', type=str, metavar='PATH',
                   help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument('--no-resume-opt', action='store_true', default=False,
                   help='prevent resume of optimizer state when resuming model')
group.add_argument('--num-classes', type=int, default=None, metavar='N',
                   help='number of label classes (Model default if None)')
group.add_argument('--gp', default=None, type=str, metavar='POOL',
                   help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument('--img-size', type=int, default=None, metavar='N',
                   help='Image size (default: None => model default)')
group.add_argument('--in-chans', type=int, default=None, metavar='N',
                   help='Image input channels (default: None => 3)')
group.add_argument('--input-size', default=None, nargs=3, type=int,
                   metavar='N N N',
                   help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
group.add_argument('--crop-pct', default=None, type=float,
                   metavar='N', help='Input image center crop percent (for validation only)')
group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                   help='Override mean pixel value of dataset')
group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                   help='Override std deviation of dataset')
group.add_argument('--interpolation', default='', type=str, metavar='NAME',
                   help='Image resize interpolation type (overrides model)')
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                   help='Input batch size for training (default: 128)')
group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                   help='Validation batch size override (default: None)')
group.add_argument('--channels-last', action='store_true', default=False,
                   help='Use channels_last memory layout')
group.add_argument('--fuser', default='', type=str,
                   help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N',
                   help='The number of steps to accumulate gradients (default: 1)')
group.add_argument('--grad-checkpointing', action='store_true', default=False,
                   help='Enable gradient checkpointing through model blocks/stages')
group.add_argument('--fast-norm', default=False, action='store_true',
                   help='enable experimental fast-norm')
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
group.add_argument('--head-init-scale', default=None, type=float,
                   help='Head initialization scale')
group.add_argument('--head-init-bias', default=None, type=float,
                   help='Head initialization bias value')

# scripting / codegen
scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                             help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                             help="Enable compilation w/ specified backend (default: inductor).")

# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                   help='Optimizer (default: "sgd")')
group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                   help='Optimizer Epsilon (default: None, use opt default)')
group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                   help='Optimizer Betas (default: None, use opt default)')
group.add_argument('--momentum', type=float, default=0.9, metavar='M',
                   help='Optimizer momentum (default: 0.9)')
group.add_argument('--weight-decay', type=float, default=2e-5,
                   help='weight decay (default: 2e-5)')
group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                   help='Clip gradient norm (default: None, no clipping)')
group.add_argument('--clip-mode', type=str, default='norm',
                   help='Gradient clipping mode. One of ("norm", "value", "agc")')
group.add_argument('--layer-decay', type=float, default=None,
                   help='layer-wise learning rate decay (default: None)')
group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)

# fish: Eva preconditioner parameters
group = parser.add_argument_group('Preconditioner parameters')

# Eva second-order preconditioner arguments
group.add_argument('--use-eva', action='store_true', default=False,
                   help='use Eva preconditioner')

# KFAC Parameters
# group.add_argument('--kfac-name', type=str, default='inverse',
#         help='choices: %s' % kfac.kfac_mappers.keys() + ', default: '+'inverse')
group.add_argument('--exclude-parts', type=str, default='',
                   help='choices: CommunicateInverse,ComputeInverse,CommunicateFactor,ComputeFactor')
group.add_argument('--kfac-update-freq', type=int, default=1,
                   help='iters between kfac inv ops (0 = no kfac) (default: 1)')
group.add_argument('--kfac-cov-update-freq', type=int, default=1,
                   help='iters between kfac cov ops (default: 1)')
group.add_argument('--kfac-update-freq-alpha', type=float, default=10,
                   help='KFAC update freq multiplier (default: 10)')
group.add_argument('--kfac-update-freq-decay', nargs='+', type=int, default=None,
                   help='KFAC update freq schedule (default None)')
group.add_argument('--stat-decay', type=float, default=0.95,
                   help='Alpha value for covariance accumulation (default: 0.95)')
group.add_argument('--damping', type=float, default=0.001,
                   help='KFAC damping factor (default 0.001)')
group.add_argument('--damping-alpha', type=float, default=0.5,
                   help='KFAC damping decay factor (default: 0.5)')
group.add_argument('--damping-decay', nargs='+', type=int, default=None,
                   help='KFAC damping decay schedule (default None)')
group.add_argument('--kl-clip', type=float, default=0.001,
                   help='KL clip (default: 0.001)')
group.add_argument('--diag-blocks', type=int, default=1,
                   help='Number of blocks to approx layer factor with (default: 1)')
group.add_argument('--diag-warmup', type=int, default=0,
                   help='Epoch to start diag block approximation at (default: 0)')
group.add_argument('--distribute-layer-factors', action='store_true', default=None,
                   help='Compute A and G for a single layer on different workers. '
                        'None to determine automatically based on worker and '
                        'layer count.')


# Learning rate schedule parameters
group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
                   help='LR scheduler (default: "cosine")')
group.add_argument('--sched-on-updates', action='store_true', default=False,
                   help='Apply LR scheduler step on update instead of epoch end.')
group.add_argument('--lr', type=float, default=None, metavar='LR',
                   help='learning rate, overrides lr-base if set (default: None)')
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
                   help='base learning rate: lr = lr_base * global_batch_size / base_size')
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
                   help='base learning rate batch size (divisor, default: 256).')
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
                   help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                   help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                   help='learning rate noise limit percent (default: 0.67)')
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                   help='learning rate noise std-dev (default: 1.0)')
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                   help='learning rate cycle len multiplier (default: 1.0)')
group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
                   help='amount to decay each learning rate cycle (default: 0.5)')
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                   help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0,
                   help='learning rate k-decay for cosine/poly (default: 1.0)')
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
                   help='warmup learning rate (default: 1e-5)')
group.add_argument('--min-lr', type=float, default=0, metavar='LR',
                   help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
group.add_argument('--epochs', type=int, default=300, metavar='N',
                   help='number of epochs to train (default: 300)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
                   help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
                   help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
                   help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
                   help='epoch interval to decay LR')
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                   help='epochs to warmup LR, if scheduler supports')
group.add_argument('--warmup-prefix', action='store_true', default=False,
                   help='Exclude warmup period from decay schedule.')
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
                   help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                   help='patience epochs for Plateau LR scheduler (default: 10)')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                   help='LR decay rate (default: 0.1)')

# Augmentation & regularization parameters
group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False,
                   help='Disable all training augmentation, override other train aug args')
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                   help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                   help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5,
                   help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0.,
                   help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                   help='Color jitter factor (default: 0.4)')
group.add_argument('--aa', type=str, default=None, metavar='NAME',
                   help='Use AutoAugment policy. "v0" or "original". (default: None)')
group.add_argument('--aug-repeats', type=float, default=0,
                   help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0,
                   help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
group.add_argument('--jsd-loss', action='store_true', default=False,
                   help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
group.add_argument('--bce-loss', action='store_true', default=False,
                   help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-target-thresh', type=float, default=None,
                   help='Threshold for binarizing softened BCE targets (default: None, disabled)')
group.add_argument('--reprob', type=float, default=0., metavar='PCT',
                   help='Random erase prob (default: 0.)')
group.add_argument('--remode', type=str, default='pixel',
                   help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1,
                   help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False,
                   help='Do not random erase first (clean) augmentation split')
group.add_argument('--mixup', type=float, default=0.0,
                   help='mixup alpha, mixup enabled if > 0. (default: 0.)')
group.add_argument('--cutmix', type=float, default=0.0,
                   help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                   help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0,
                   help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5,
                   help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch',
                   help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                   help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
group.add_argument('--smoothing', type=float, default=0.1,
                   help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random',
                   help='Training interpolation (random, bilinear, bicubic default: "random")')
group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                   help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                   help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                   help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                   help='Drop block rate (default: None)')

# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
group.add_argument('--bn-momentum', type=float, default=None,
                   help='BatchNorm momentum override (if not None)')
group.add_argument('--bn-eps', type=float, default=None,
                   help='BatchNorm epsilon override (if not None)')
group.add_argument('--sync-bn', action='store_true',
                   help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
group.add_argument('--dist-bn', type=str, default='reduce',
                   help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
group.add_argument('--split-bn', action='store_true',
                   help='Enable separate BN layers per augmentation split.')

# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,
                   help='Enable tracking moving average of model weights')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                   help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,
                   help='decay factor for model weights moving average (default: 0.9998)')

# Misc
group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S',
                   help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all',
                   help='worker seed mode (default: all)')
group.add_argument('--log-interval', type=int, default=50, metavar='N',
                   help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                   help='how many batches to wait before writing recovery checkpoint')
group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
                   help='number of checkpoints to keep (default: 10)')
group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                   help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,
                   help='save images of input batches every log interval for debugging')
group.add_argument('--amp', action='store_true', default=False,
                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
                   help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
                   help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
                   help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,
                   help='torch.cuda.synchronize() end of each step')
group.add_argument('--pin-mem', action='store_true', default=False,
                   help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
                   help='disable fast prefetcher')
group.add_argument('--output', default='', type=str, metavar='PATH',
                   help='path to output folder (default: none, current dir)')
group.add_argument('--experiment', default='', type=str, metavar='NAME',
                   help='name of train experiment, name of sub-folder for output')
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                   help='Best metric (default: "top1")')
group.add_argument('--tta', type=int, default=0, metavar='N',
                   help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument("--local_rank", default=0, type=int)
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                   help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
                   help='log training and validation metrics to wandb')


def _parse_args():
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
            parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text
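
The two-stage parsing in `_parse_args` can be tried in isolation. The sketch below is a minimal illustration, not the script's actual code: it swaps YAML for JSON so it needs only the standard library, and the names `demo_config_parser`, `demo_parser`, `parse_with_config`, and `demo_args` are made up for this example. It shows the precedence the pattern produces: values from the config file replace the parser defaults, and explicit command-line flags override both.

```python
import argparse
import json
import os
import tempfile

# Stage 1: a tiny parser that only knows about --config.
demo_config_parser = argparse.ArgumentParser(add_help=False)
demo_config_parser.add_argument('-c', '--config', default='', type=str)

# Stage 2: the main parser with its usual defaults (two toy options here).
demo_parser = argparse.ArgumentParser()
demo_parser.add_argument('--epochs', type=int, default=300)
demo_parser.add_argument('--opt', type=str, default='sgd')


def parse_with_config(argv):
    """Parse argv; config-file values replace defaults, CLI flags override both."""
    args_config, remaining = demo_config_parser.parse_known_args(argv)
    if args_config.config:
        with open(args_config.config, 'r') as f:
            # JSON stands in for the YAML file used by the real script.
            demo_parser.set_defaults(**json.load(f))
    return demo_parser.parse_args(remaining)


# Write a throwaway config file, then parse a command line that overrides one key.
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump({'epochs': 15, 'opt': 'lion'}, f)
    config_path = f.name

demo_args = parse_with_config(['--config', config_path, '--epochs', '30'])
os.remove(config_path)
print(demo_args.epochs, demo_args.opt)  # 30 lion
```

Here `--epochs 30` on the command line wins over the config value of 15, while `opt` falls back to the config value because no flag was given.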

After the rewrite, we try to create a budgeted-training model and set it to the initial stage; in the config we have already set this to be the baby model.

In [3]:
cmd_line_args = [
    "--model", "bud_small_patch16_224_baby",
    "--batch-size", "128",
    "--epochs", "15",
    "--warmup-epochs", "0",
    "--data-dir", "/home/fish/Documents/imagenet/",
    "--output", "./output",
    "--experiment", "budgeted",
    # "--num-stages", "3",
    "--amp",
    "--opt", "lion",
    "--lr-base", "1e-4",
    "--weight-decay", "1e-3",
    "--log-wandb",
]

args = parser.parse_args(cmd_line_args)
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
args

Namespace(data=None, data_dir='/home/fish/Documents/imagenet/', dataset='', train_split='train', val_split='validation', dataset_download=False, class_map='', model='bud_small_patch16_224_baby', pretrained=False, initial_checkpoint='', resume='', no_resume_opt=False, num_classes=None, gp=None, img_size=None, in_chans=None, input_size=None, crop_pct=None, mean=None, std=None, interpolation='', batch_size=128, validation_batch_size=None, channels_last=False, fuser='', grad_accum_steps=1, grad_checkpointing=False, fast_norm=False, model_kwargs={}, head_init_scale=None, head_init_bias=None, torchscript=False, torchcompile=None, opt='lion', opt_eps=None, opt_betas=None, momentum=0.9, weight_decay=0.001, clip_grad=None, clip_mode='norm', layer_decay=None, opt_kwargs={}, use_eva=False, exclude_parts='', kfac_update_freq=1, kfac_cov_update_freq=1, kfac_update_freq_alpha=10, kfac_update_freq_decay=None, stat_decay=0.95, damping=0.001, damping_alpha=0.5, damping_decay=None, kl_clip=0.001, diag_b
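
Because `--lr` is not passed above, the setup code later in this notebook derives the learning rate from `--lr-base`, the global batch size, and `--lr-base-size`. The sketch below restates that rule as a standalone function so the resulting value can be checked by hand. The function name `scaled_lr` is hypothetical, and the call assumes a single process (`world_size=1`) with no gradient accumulation.

```python
def scaled_lr(lr_base, batch_size, world_size=1, grad_accum_steps=1,
              lr_base_size=256, opt='sgd', lr_base_scale=''):
    """Return (lr, scaling mode) using the same rule as the setup cell."""
    global_batch_size = batch_size * world_size * grad_accum_steps
    batch_ratio = global_batch_size / lr_base_size
    if not lr_base_scale:
        # Optimizers whose name contains 'ada' or 'lamb' get sqrt scaling.
        name = opt.lower()
        lr_base_scale = 'sqrt' if any(o in name for o in ('ada', 'lamb')) else 'linear'
    if lr_base_scale == 'sqrt':
        batch_ratio = batch_ratio ** 0.5
    return lr_base * batch_ratio, lr_base_scale


# Values taken from cmd_line_args above; world_size=1 is an assumption.
lr, scale = scaled_lr(lr_base=1e-4, batch_size=128, opt='lion')
print(lr, scale)  # 5e-05 linear
```

With a batch size of 128 against a base size of 256, the ratio is 0.5, so under these assumptions the effective learning rate is half of `--lr-base`.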

Set up logging. Run this cell only once; otherwise every log message will be printed multiple times.

In [4]:
utils.setup_default_logging()
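
The duplicated output presumably comes from each call attaching another console handler to the same logger. The sketch below reproduces that effect with the standard `logging` module only and shows one possible guard; `add_console_handler` and `demo_logger` are names made up for this illustration and are not part of timm.

```python
import logging


def add_console_handler(logger, once=True):
    """Attach a StreamHandler; with once=True, skip if one is already attached."""
    if once and any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
        return
    logger.addHandler(logging.StreamHandler())


demo_logger = logging.getLogger('logging_demo')
demo_logger.propagate = False  # keep the demo isolated from the root logger

add_console_handler(demo_logger, once=False)
add_console_handler(demo_logger, once=False)  # simulates re-running the cell
unguarded_count = len(demo_logger.handlers)   # two handlers: each message prints twice

demo_logger.handlers.clear()
add_console_handler(demo_logger)
add_console_handler(demo_logger)              # re-running is now a no-op
guarded_count = len(demo_logger.handlers)

print(unguarded_count, guarded_count)  # 2 1
```

A guard like this makes the logging cell safe to re-run; without it, restarting the kernel is the simplest way to get back to a single handler.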

In [None]:
# import wandb

# wandb.init(project="jupyter-projo",
#            config={
#                "batch_size": 128,
#                "learning_rate": 0.01,
#                "dataset": "CIFAR-100",
#            })

In [None]:
# WANDB_NOTEBOOK_NAME = 'budgeted_train.ipynb'

Run the code below to perform the initial training setup.

In [5]:
    # utils.setup_default_logging()
    # args, args_text = _parse_args()  # args were already built separately above

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    args.prefetcher = not args.no_prefetcher
    args.grad_accum_steps = max(1, args.grad_accum_steps)
    device = utils.init_distributed_device(args)
    if args.distributed:
        _logger.info(
            'Training in distributed mode with multiple processes, 1 device per process.'
            f'Process {args.rank}, total {args.world_size}, device {args.device}.')
    else:
        _logger.info(f'Training with a single process on 1 device ({args.device}).')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    amp_dtype = torch.float16
    if args.amp:
        if args.amp_impl == 'apex':
            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
            use_amp = 'apex'
            assert args.amp_dtype == 'float16'
        else:
            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
            use_amp = 'native'
            assert args.amp_dtype in ('float16', 'bfloat16')
        if args.amp_dtype == 'bfloat16':
            amp_dtype = torch.bfloat16

    utils.random_seed(args.seed, args.rank)

    if args.fuser:
        utils.set_jit_fuser(args.fuser)
    if args.fast_norm:
        set_fast_norm()

    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        **args.model_kwargs,
    )
    if args.head_init_scale is not None:
        with torch.no_grad():
            model.get_classifier().weight.mul_(args.head_init_scale)
            model.get_classifier().bias.mul_(args.head_init_scale)
    if args.head_init_bias is not None:
        nn.init.constant_(model.get_classifier().bias, args.head_init_bias)

    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    if utils.is_primary(args):
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.to(device=device)
    if args.channels_last:
        model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        args.dist_bn = ''  # disable dist_bn when sync BN active
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN used with Apex AMP
            # WARNING this won't currently work with models using BatchNormAct2d
            model = convert_syncbn_model(model)
        else:
            model = convert_sync_batchnorm(model)
        if utils.is_primary(args):
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not args.torchcompile
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    if not args.lr:
        global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
        batch_ratio = global_batch_size / args.lr_base_size
        if not args.lr_base_scale:
            on = args.opt.lower()
            args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
        if args.lr_base_scale == 'sqrt':
            batch_ratio = batch_ratio ** 0.5
        args.lr = args.lr_base * batch_ratio
        if utils.is_primary(args):
            _logger.info(
                f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
                f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')

    optimizer = create_optimizer_v2(
        model,
        **optimizer_kwargs(cfg=args),
        **args.opt_kwargs,
    )

    # fish: add the Eva preconditioner. It is built on the unwrapped model here
    # (the DDP wrapper is applied further down); unclear whether that is the right choice.
    if args.use_eva:
        
        # preconditioner = Eva(model_without_ddp)
        preconditioner = Eva(
                model, lr=args.lr, factor_decay=args.stat_decay,
                damping=args.damping, kl_clip=args.kl_clip,
                fac_update_freq=args.kfac_cov_update_freq,
                kfac_update_freq=args.kfac_update_freq,
                #diag_blocks=args.diag_blocks,
                #diag_warmup=args.diag_warmup,
                #distribute_layer_factors=args.distribute_layer_factors, 
                exclude_parts=args.exclude_parts)

        kfac_param_scheduler = KFACParamScheduler(
               preconditioner,
               damping_alpha=args.damping_alpha,
               damping_schedule=args.damping_decay,
               update_freq_alpha=args.kfac_update_freq_alpha,
               update_freq_schedule=args.kfac_update_freq_decay,
               start_epoch=args.start_epoch)

        print("preconditioner eva is adopted")

    else:
        preconditioner = None

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        assert device.type == 'cuda'
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if utils.is_primary(args):
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        try:
            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        except (AttributeError, TypeError):
            # fallback to CUDA only AMP for PyTorch < 1.10
            assert device.type == 'cuda'
            amp_autocast = torch.cuda.amp.autocast
        if device.type == 'cuda' and amp_dtype == torch.float16:
            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
            loss_scaler = NativeScaler()
        if utils.is_primary(args):
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if utils.is_primary(args):
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=utils.is_primary(args),
        )

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = utils.ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if utils.is_primary(args):
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if utils.is_primary(args):
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    if args.torchcompile:
        # torch compile should be done after DDP
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        model = torch.compile(model, backend=args.torchcompile)

    # create the train and eval datasets
    if args.data and not args.data_dir:
        args.data_dir = args.data
    dataset_train = create_dataset(
        args.dataset,
        root=args.data_dir,
        split=args.train_split,
        is_training=True,
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size,
        seed=args.seed,
        repeats=args.epoch_repeats,
    )

    dataset_eval = create_dataset(
        args.dataset,
        root=args.data_dir,
        split=args.val_split,
        is_training=False,
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size,
    )

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob,
            switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing,
            num_classes=args.num_classes
        )
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        device=device,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )

    eval_workers = args.workers
    if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
        # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
        eval_workers = min(2, args.workers)
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=eval_workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
        device=device,
    )

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.to(device=device)
    validate_loss_fn = nn.CrossEntropyLoss().to(device=device)

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if utils.is_primary(args):
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = utils.CheckpointSaver(
            model=model,
            optimizer=optimizer,
            args=args,
            model_ema=model_ema,
            amp_scaler=loss_scaler,
            checkpoint_dir=output_dir,
            recovery_dir=output_dir,
            decreasing=decreasing,
            max_history=args.checkpoint_hist
        )
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    if utils.is_primary(args) and args.log_wandb:
        if has_wandb:
            wandb.init(project='timm', name=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    # setup learning rate schedule and starting epoch
    updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
    lr_scheduler, num_epochs = create_scheduler_v2(
        optimizer,
        **scheduler_kwargs(args),
        updates_per_epoch=updates_per_epoch,
    )
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        if args.sched_on_updates:
            lr_scheduler.step_update(start_epoch * updates_per_epoch)
        else:
            lr_scheduler.step(start_epoch)

    if utils.is_primary(args):
        _logger.info(
            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')


Training with a single process on 1 device (cuda:0).
Model bud_small_patch16_224_baby created, param count:6101224
Data processing configuration for current model + dataset:
	input_size: (3, 224, 224)
	interpolation: bicubic
	mean: (0.485, 0.456, 0.406)
	std: (0.229, 0.224, 0.225)
	crop_pct: 0.9
	crop_mode: center
Learning rate (5e-05) calculated from base learning rate (0.0001) and effective global batch size (128) with linear scaling.
Using native Torch AMP. Training in mixed precision.
ERROR: Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: fabfish. Use `wandb login --relogin` to force relogin


Scheduled epochs: 15. LR stepped per epoch.
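
The learning rate reported in the log above is derived from the base learning rate rather than set directly. A minimal sketch of that arithmetic follows. `scaled_lr` is a hypothetical helper, and `lr_base_size=256` is an assumption inferred from the logged numbers (0.0001 × 128 / 256 = 5e-05); the value itself is not printed anywhere in this notebook.

```python
def scaled_lr(lr_base, global_batch_size, lr_base_size=256, scale="linear"):
    """Scale a base learning rate by the ratio of batch sizes.

    Hypothetical helper mirroring the `if not args.lr:` block above.
    lr_base_size=256 is an assumption inferred from the log output.
    """
    batch_ratio = global_batch_size / lr_base_size
    if scale == "sqrt":
        # square-root scaling, used for adaptive optimizers in the block above
        batch_ratio = batch_ratio ** 0.5
    return lr_base * batch_ratio


# global batch size = batch_size * world_size * grad_accum_steps
# 128 * 1 * 1 matches the "effective global batch size (128)" in the log
lr = scaled_lr(1e-4, 128 * 1 * 1)
print(f"{lr:.0e}")
```

With linear scaling, halving the global batch size relative to the reference size halves the learning rate. With `scale="sqrt"` the same ratio would shrink it only by a factor of about 0.71.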


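Train for 5 epochs and inspect the results. The first cell below holds the training code; the second cell runs the training loop and shows its output.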

In [7]:
import time
def train_one_epoch(
        epoch,
        model,
        loader,
        optimizer,
        loss_fn,
        args,
        device=torch.device('cuda'),
        lr_scheduler=None,
        saver=None,
        output_dir=None,
        amp_autocast=suppress,
        loss_scaler=None,
        model_ema=None,
        mixup_fn=None,
        # fish: add preconditioner support
        preconditioner=None,
):
    start_time = time.time()
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    has_no_sync = hasattr(model, "no_sync")
    update_time_m = utils.AverageMeter()
    data_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()

    model.train()

    accum_steps = args.grad_accum_steps
    last_accum_steps = len(loader) % accum_steps
    updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps
    num_updates = epoch * updates_per_epoch
    last_batch_idx = len(loader) - 1
    last_batch_idx_to_accum = len(loader) - last_accum_steps

    data_start_time = update_start_time = time.time()
    optimizer.zero_grad()
    update_sample_count = 0
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_batch_idx
        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
        update_idx = batch_idx // accum_steps
        if batch_idx >= last_batch_idx_to_accum:
            accum_steps = last_accum_steps

        if not args.prefetcher:
            input, target = input.to(device), target.to(device)
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        # multiply by accum steps to get equivalent for full update
        data_time_m.update(accum_steps * (time.time() - data_start_time))

        def _forward():
            with amp_autocast():
                output = model(input)
                loss = loss_fn(output, target)
            if accum_steps > 1:
                loss /= accum_steps
            return loss

        def _backward(_loss):
            if loss_scaler is not None:
                loss_scaler(
                    _loss,
                    optimizer,
                    clip_grad=args.clip_grad,
                    clip_mode=args.clip_mode,
                    parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
                    create_graph=second_order,
                    need_update=need_update,
                )
            else:
                _loss.backward(create_graph=second_order)
                if need_update:
                    if args.clip_grad is not None:
                        utils.dispatch_clip_grad(
                            model_parameters(model, exclude_head='agc' in args.clip_mode),
                            value=args.clip_grad,
                            mode=args.clip_mode,
                        )
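                    # fish: step the preconditioner (if any), then the optimizer.
                    # NOTE: this branch only runs when loss_scaler is None. With a loss scaler
                    # (e.g. native AMP in float16) the preconditioner is never stepped.
                    if preconditioner is not None:
                        # self._scaler.step(preconditioner)
                        preconditioner.step()
                    optimizer.step()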

        if has_no_sync and not need_update:
            with model.no_sync():
                loss = _forward()
                _backward(loss)
        else:
            loss = _forward()
            _backward(loss)

        if not args.distributed:
            losses_m.update(loss.item() * accum_steps, input.size(0))
        update_sample_count += input.size(0)

        if not need_update:
            data_start_time = time.time()
            continue

        num_updates += 1
        optimizer.zero_grad()
        if model_ema is not None:
            model_ema.update(model)

        if args.synchronize_step and device.type == 'cuda':
            torch.cuda.synchronize()
        time_now = time.time()
        update_time_m.update(time_now - update_start_time)
        update_start_time = time_now

        if update_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
                update_sample_count *= args.world_size

            if utils.is_primary(args):
                _logger.info(
                    f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                    f'({100. * update_idx / (updates_per_epoch - 1):>3.0f}%)]  '
                    f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g})  '
                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s  '
                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s)  '
                    f'LR: {lr:.3e}  '
                    f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})  '
                    f'Elapsed: {time.time() - start_time:.3f}s'
                )

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True
                    )

        if saver is not None and args.recovery_interval and (
                (update_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=update_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        update_sample_count = 0
        data_start_time = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])


def validate(
        model,
        loader,
        loss_fn,
        args,
        device=torch.device('cuda'),
        amp_autocast=suppress,
        log_suffix=''
):
    batch_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()
    top1_m = utils.AverageMeter()
    top5_m = utils.AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                input = input.to(device)
                target = target.to(device)
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(input)
                if isinstance(output, (tuple, list)):
                    output = output[0]

                # augmentation reduction
                reduce_factor = args.tta
                if reduce_factor > 1:
                    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    target = target[0:target.size(0):reduce_factor]

                loss = loss_fn(output, target)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                acc1 = utils.reduce_tensor(acc1, args.world_size)
                acc5 = utils.reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            if device.type == 'cuda':
                torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
                    f'{log_name}: [{batch_idx:>4d}/{last_idx}]  '
                    f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f})  '
                    f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f})  '
                    f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f})  '
                    f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
                )

    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

    return metrics
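
The gradient-accumulation bookkeeping at the top of `train_one_epoch` is easier to follow when replayed on small numbers. The sketch below uses a hypothetical helper, `accumulation_plan`, that reproduces only the index arithmetic (no model, optimizer, or loader). It shows which batches trigger an optimizer update and which divisor the loss would be scaled by.

```python
def accumulation_plan(num_batches, accum_steps):
    """Return (batch_idx, loss_divisor, need_update) for each batch of one epoch.

    Hypothetical helper that replays the bookkeeping in train_one_epoch
    without touching a model or optimizer.
    """
    last_accum_steps = num_batches % accum_steps
    last_batch_idx = num_batches - 1
    last_batch_idx_to_accum = num_batches - last_accum_steps
    divisor = accum_steps
    plan = []
    for batch_idx in range(num_batches):
        need_update = batch_idx == last_batch_idx or (batch_idx + 1) % accum_steps == 0
        if batch_idx >= last_batch_idx_to_accum:
            # trailing partial group: divide by its actual size instead
            divisor = last_accum_steps
        plan.append((batch_idx, divisor, need_update))
    return plan


plan = accumulation_plan(num_batches=10, accum_steps=4)
updates = [idx for idx, _, need_update in plan if need_update]
print(updates)                 # [3, 7, 9]
print(plan[7][1], plan[8][1])  # 4 2
```

Ten batches with four accumulation steps give three updates, which matches `(10 + 4 - 1) // 4`. The last two batches form a short group, so their losses are divided by 2 rather than 4. When the batch count is an exact multiple of `accum_steps`, the partial-group branch never fires.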

In [8]:
    try:
        # run only the first 5 epochs, although the scheduler above was built for num_epochs
        for epoch in range(start_epoch, 5):
            if hasattr(dataset_train, 'set_epoch'):
                dataset_train.set_epoch(epoch)
            elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch,
                model,
                loader_train,
                optimizer,
                train_loss_fn,
                args,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,
                amp_autocast=amp_autocast,
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                # fish: add preconditioner
                preconditioner=preconditioner,
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if utils.is_primary(args):
                    _logger.info("Distributing BatchNorm running means and vars")
                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(
                model,
                loader_eval,
                validate_loss_fn,
                args,
                amp_autocast=amp_autocast,
            )

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

                ema_eval_metrics = validate(
                    model_ema.module,
                    loader_eval,
                    validate_loss_fn,
                    args,
                    amp_autocast=amp_autocast,
                    log_suffix=' (EMA)',
                )
                eval_metrics = ema_eval_metrics

            if output_dir is not None:
                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
                utils.update_summary(
                    epoch,
                    train_metrics,
                    eval_metrics,
                    filename=os.path.join(output_dir, 'summary.csv'),
                    lr=sum(lrs) / len(lrs),
                    write_header=best_metric is None,
                    log_wandb=args.log_wandb and has_wandb,
                )

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Test: [   0/390]  Time: 0.911 (0.911)  Loss:   3.472 ( 3.472)  Acc@1:  37.500 ( 37.500)  Acc@5:  57.031 ( 57.031)
Test: [  50/390]  Time: 0.029 (0.156)  Loss:   4.439 ( 4.491)  Acc@1:  17.969 ( 16.008)  Acc@5:  39.062 ( 34.344)
Test: [ 100/390]  Time: 0.047 (0.142)  Loss:   5.039 ( 4.855)  Acc@1:   1.562 ( 11.626)  Acc@5:  21.094 ( 27.413)
Test: [ 150/390]  Time: 0.030 (0.142)  Loss:   4.976 ( 4.827)  Acc@1:   6.250 ( 11.455)  Acc@5:  26.562 ( 27.633)
Test: [ 200/390]  Time: 0.029 (0.140)  Loss:   5.667 ( 4.936)  Acc@1:   0.781 ( 10.821)  Acc@5:   8.594 ( 26.154)
Test: [ 250/390]  Time: 0.029 (0.141)  Loss:   4.774 ( 5.017)  Acc@1:  17.969 ( 10.327)  Acc@5:  33.594 ( 24.813)
Test: [ 300/390]  Time: 0.033 (0.139)  Loss:   5.208 ( 5.094)  Acc@1:   7.812 (  9.648)  Acc@5:  27.344 ( 23.539)
Test: [ 350/390]  Time: 0.029 (0.138)  Loss:   5.303 ( 5.147)  Acc@1:   2.344 (  9.255)  Acc@5:  12.500 ( 22.799)
Test: [ 390/390] 
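
Each `Test:` line above prints the latest batch value followed by the running average in parentheses. A minimal sketch of such a meter follows, written from scratch for illustration. `RunningMeter` is a hypothetical name and not the `utils.AverageMeter` class itself; the assumption is that the average is weighted by the number of samples passed to `update`, as the calls in `validate` suggest.

```python
class RunningMeter:
    """Tracks the latest value and a sample-weighted running average."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is a per-batch mean, n is the number of samples it covers
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = RunningMeter()
meter.update(37.5, n=128)  # first batch of 128 samples
meter.update(12.5, n=128)  # second batch of 128 samples
print(f"{meter.val:>7.3f} ({meter.avg:>7.3f})")
```

This is why a single poor batch (for example an Acc@1 of 0.781) moves the value in parentheses only slightly once many batches have been accumulated.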

In [12]:
    # keep references to the current ('baby') training objects before rebuilding
    old_model = model
    old_optimizer = optimizer
    old_loss_scaler = loss_scaler
    old_lr_scheduler = lr_scheduler
    old_preconditioner = preconditioner  # None when args.use_eva is false

    # switch to the next model variant by name
    args.model = args.model.replace('baby', 'child')
    # args.model = args.model.replace('child', 'man')
    # args.model = args.model.replace('man', 'baby')
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        **args.model_kwargs,
    )
    if args.head_init_scale is not None:
        with torch.no_grad():
            model.get_classifier().weight.mul_(args.head_init_scale)
            model.get_classifier().bias.mul_(args.head_init_scale)
    if args.head_init_bias is not None:
        nn.init.constant_(model.get_classifier().bias, args.head_init_bias)

    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    if utils.is_primary(args):
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.to(device=device)
    if args.channels_last:
        model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        args.dist_bn = ''  # disable dist_bn when sync BN active
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN used with Apex AMP
            # WARNING this won't currently work with models using BatchNormAct2d
            model = convert_syncbn_model(model)
        else:
            model = convert_sync_batchnorm(model)
        if utils.is_primary(args):
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not args.torchcompile
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    if not args.lr:
        global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
        batch_ratio = global_batch_size / args.lr_base_size
        if not args.lr_base_scale:
            on = args.opt.lower()
            args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
        if args.lr_base_scale == 'sqrt':
            batch_ratio = batch_ratio ** 0.5
        args.lr = args.lr_base * batch_ratio
        if utils.is_primary(args):
            _logger.info(
                f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
                f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')

    optimizer = create_optimizer_v2(
        model,
        **optimizer_kwargs(cfg=args),
        **args.opt_kwargs,
    )

    # fish: add the Eva preconditioner. It is built on the unwrapped model here
    # (the DDP wrapper is applied further down); unclear whether that is the right choice.
    if args.use_eva:
        
        # preconditioner = Eva(model_without_ddp)
        preconditioner = Eva(
                model, lr=args.lr, factor_decay=args.stat_decay,
                damping=args.damping, kl_clip=args.kl_clip,
                fac_update_freq=args.kfac_cov_update_freq,
                kfac_update_freq=args.kfac_update_freq,
                #diag_blocks=args.diag_blocks,
                #diag_warmup=args.diag_warmup,
                #distribute_layer_factors=args.distribute_layer_factors, 
                exclude_parts=args.exclude_parts)

        kfac_param_scheduler = KFACParamScheduler(
               preconditioner,
               damping_alpha=args.damping_alpha,
               damping_schedule=args.damping_decay,
               update_freq_alpha=args.kfac_update_freq_alpha,
               update_freq_schedule=args.kfac_update_freq_decay,
               start_epoch=args.start_epoch)

        print("preconditioner eva is adopted")

    else:
        preconditioner = None

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        assert device.type == 'cuda'
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if utils.is_primary(args):
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        try:
            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        except (AttributeError, TypeError):
            # fallback to CUDA only AMP for PyTorch < 1.10
            assert device.type == 'cuda'
            amp_autocast = torch.cuda.amp.autocast
        if device.type == 'cuda' and amp_dtype == torch.float16:
            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
            loss_scaler = NativeScaler()
        if utils.is_primary(args):
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if utils.is_primary(args):
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=utils.is_primary(args),
        )

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = utils.ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if utils.is_primary(args):
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if utils.is_primary(args):
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    if args.torchcompile:
        # torch compile should be done after DDP
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        model = torch.compile(model, backend=args.torchcompile)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob,
            switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing,
            num_classes=args.num_classes
        )
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.to(device=device)
    validate_loss_fn = nn.CrossEntropyLoss().to(device=device)

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if utils.is_primary(args):
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = utils.CheckpointSaver(
            model=model,
            optimizer=optimizer,
            args=args,
            model_ema=model_ema,
            amp_scaler=loss_scaler,
            checkpoint_dir=output_dir,
            recovery_dir=output_dir,
            decreasing=decreasing,
            max_history=args.checkpoint_hist
        )
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    if utils.is_primary(args) and args.log_wandb:
        if has_wandb:
            wandb.init(project='timm',name=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    # setup learning rate schedule and starting epoch
    updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
    lr_scheduler, num_epochs = create_scheduler_v2(
        optimizer,
        **scheduler_kwargs(args),
        updates_per_epoch=updates_per_epoch,
    )
    start_epoch = 5
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        if args.sched_on_updates:
            lr_scheduler.step_update(start_epoch * updates_per_epoch)
        else:
            lr_scheduler.step(start_epoch)

    if utils.is_primary(args):
        _logger.info(
            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

Model bud_small_patch16_224_child created, param count:11417704
Data processing configuration for current model + dataset:
	input_size: (3, 224, 224)
	interpolation: bicubic
	mean: (0.485, 0.456, 0.406)
	std: (0.229, 0.224, 0.225)
	crop_pct: 0.9
	crop_mode: center
Using native Torch AMP. Training in mixed precision.


wandb run history:
epoch       ▁▃▅▆█
eval_loss   █▅▃▂▁
eval_top1   ▁▄▅▇█
eval_top5   ▁▄▆▇█
lr          ██▆▄▁
train_loss  █▄▃▂▁

wandb run summary:
epoch       4.0
eval_loss   3.60661
eval_top1   28.652
eval_top5   51.896
lr          4e-05
train_loss  4.08336

Scheduled epochs: 15. LR stepped per epoch.
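
The `updates_per_epoch` expression in the cell above is a ceiling division: with gradient accumulation, a partial final group of batches still counts as one optimizer update. The sketch below spells out that arithmetic, together with the number of updates the scheduler is fast-forwarded by when training starts at a later epoch and the schedule is stepped per update. The function names are hypothetical, and the loader length of 1251 and accumulation factor of 4 are made-up illustration values, not numbers from this run.

```python
def updates_per_epoch(num_batches: int, grad_accum_steps: int) -> int:
    """Ceiling division: optimizer updates performed in one epoch."""
    return (num_batches + grad_accum_steps - 1) // grad_accum_steps


def fast_forward_updates(start_epoch: int, num_batches: int, grad_accum_steps: int) -> int:
    """Updates already 'consumed' when the schedule is stepped per update."""
    return start_epoch * updates_per_epoch(num_batches, grad_accum_steps)


# Hypothetical numbers, for illustration only.
print(updates_per_epoch(1251, 4))        # 313
print(fast_forward_updates(5, 1251, 4))  # 1565
```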


In [14]:
print(model.state_dict()['blocks.0.attn.qkv.weight'])
print(old_model.state_dict()['blocks.0.attn.qkv.weight'])

with torch.no_grad():
    tmp = old_model.state_dict()
    for param_tensor in old_model.state_dict(): # iterating over a dict yields its keys, so param_tensor is the parameter name
        # print(param_tensor,'\t',model.state_dict()[param_tensor].size())
        # skip anything that is not an attn or mlp layer
        if 'attn' not in param_tensor and 'mlp' not in param_tensor:
            continue
        old_size = old_model.state_dict()[param_tensor].size()
        new_size = model.state_dict()[param_tensor].size()
        flag = False
        if len(old_size) != len(new_size):
            continue
        elif len(old_size) == 1:
            # if the new size is an integer multiple of the old size
            if new_size[0] % old_size[0] == 0:
                times = new_size[0] // old_size[0]
                tmp[param_tensor] = torch.cat((tmp[param_tensor],)*times, 0)
            # if old_size[0]*2 == new_size[0]:
            #     tmp[param_tensor] = torch.cat((tmp[param_tensor], tmp[param_tensor]), 0)
                flag = True
            # if old:new is 2:3
            elif old_size[0]*3 == new_size[0]*2:
                # concatenate the tensor with its own first half
                if old_size[0] % 2 == 0:
                    tmp[param_tensor] = torch.cat((tmp[param_tensor],tmp[param_tensor][:old_size[0]//2]), 0)
                    # tmp[param_tensor] = torch.cat((tmp[param_tensor],tmp[param_tensor][]), 0)
                    flag = True
                else:
                    # not implemented
                    print('error')
        elif len(old_size) == 2:
            if old_size[0] == new_size[0]:
                # if old_size[1]*2 == new_size[1]:
                if new_size[1] % old_size[1] == 0:
                    times = new_size[1] // old_size[1]
                    tmp[param_tensor] = torch.cat((tmp[param_tensor],)*times, 1)
                    flag = True
                elif old_size[1]*3 == new_size[1]*2:
                    if old_size[1] % 2 == 0:
                        # append the first half of the columns (a slice, so the result stays 2-D)
                        tmp[param_tensor] = torch.cat((tmp[param_tensor],tmp[param_tensor][:, :old_size[1]//2]), 1)
                        flag = True
                    else:
                        # not implemented
                        print('error')
            elif old_size[1] == new_size[1]:
                # if old_size[0]*2 == new_size[0]:
                if new_size[0] % old_size[0] == 0:
                    times = new_size[0] // old_size[0]
                    tmp[param_tensor] = torch.cat((tmp[param_tensor],)*times, 0)
                    flag = True
                elif old_size[0]*3 == new_size[0]*2:    
                    if old_size[0] % 2 == 0:
                        tmp[param_tensor] = torch.cat((tmp[param_tensor],tmp[param_tensor][:old_size[0]//2]), 0)
                        flag = True
                    else:
                        # not implemented
                        print('error')
        if not flag:
            if param_tensor in model.state_dict():
                tmp[param_tensor] = model.state_dict()[param_tensor]
            else:
                continue
    model.load_state_dict(tmp, strict=False)
    print(model.state_dict()['blocks.0.attn.qkv.weight'])

tensor([[ 0.0315, -0.0346,  0.0174,  ...,  0.0112,  0.0255,  0.0137],
        [ 0.0115,  0.0173,  0.0114,  ...,  0.0149,  0.0192, -0.0242],
        [ 0.0148,  0.0067, -0.0030,  ...,  0.0030, -0.0055, -0.0211],
        ...,
        [-0.0092,  0.0030, -0.0129,  ..., -0.0232,  0.0294, -0.0003],
        [-0.0041,  0.0124, -0.0091,  ..., -0.0248, -0.0320, -0.0066],
        [-0.0025, -0.0027,  0.0012,  ..., -0.0146,  0.0144,  0.0328]],
       device='cuda:0')
tensor([[ 0.0359, -0.0161, -0.0002,  ...,  0.0428, -0.0494, -0.0629],
        [-0.0269,  0.0833,  0.0226,  ...,  0.0066, -0.0962, -0.1213],
        [-0.0033,  0.0057, -0.1324,  ...,  0.0825, -0.0265, -0.0474],
        ...,
        [-0.0475,  0.0690,  0.0118,  ...,  0.0751, -0.0250, -0.0040],
        [ 0.0761,  0.0436,  0.0120,  ...,  0.0375, -0.0297,  0.0170],
        [ 0.0940,  0.0508, -0.0287,  ...,  0.0445, -0.0137,  0.0418]],
       device='cuda:0')
tensor([[ 0.0359, -0.0161, -0.0002,  ...,  0.0428, -0.0494, -0.0629],
        [-0.02
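
The shape rules applied by the cell above can be summarised separately from the tensor operations. The sketch below is a pure-Python restatement of those rules, not code from the original notebook: `growth_plan` is a hypothetical helper that only decides which rule applies (tile along one axis, or tile plus the first half along one axis) and returns `None` where the cell would fall back to the new model's own initialization.

```python
from typing import Optional, Tuple


def growth_plan(old_size: Tuple[int, ...], new_size: Tuple[int, ...]) -> Optional[Tuple[str, int, int]]:
    """Return (rule, axis, count) describing how old_size grows to new_size, or None.

    rule == "tile":      repeat the tensor `count` times along `axis`
    rule == "tile+half": the tensor followed by its first `count` slices along `axis`
    """
    if len(old_size) != len(new_size) or len(old_size) not in (1, 2):
        return None
    # Axes whose length changes; at most one axis may differ.
    diff = [a for a in range(len(old_size)) if old_size[a] != new_size[a]]
    if not diff:
        return ("tile", 0, 1)  # same shape: a plain copy
    if len(diff) > 1:
        return None
    axis = diff[0]
    old_n, new_n = old_size[axis], new_size[axis]
    if new_n % old_n == 0:
        return ("tile", axis, new_n // old_n)
    if old_n * 3 == new_n * 2 and old_n % 2 == 0:
        return ("tile+half", axis, old_n // 2)
    return None


print(growth_plan((288,), (576,)))             # ('tile', 0, 2)
print(growth_plan((384,), (576,)))             # ('tile+half', 0, 192)
print(growth_plan((384, 768), (384, 1152)))    # ('tile+half', 1, 384)
```

Keeping the decision apart from the `torch.cat` calls makes it easy to check every parameter shape before any weight is modified.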

In [15]:
len(old_optimizer.param_groups[0]['params'])
# Next, iterate over the original model's parameters and carry them over to the new optimizer
for param_group in range(len(old_optimizer.param_groups)):
    for i, (old_p, new_p) in enumerate(zip(old_optimizer.param_groups[param_group]['params'], optimizer.param_groups[param_group]['params'])):
        # inspect the shapes of old_p and new_p
        print(old_p.shape, new_p.shape)
        with torch.no_grad():
            if old_p.shape == new_p.shape:
                # copy in place; `new_p = old_p` would only rebind the loop variable
                new_p.copy_(old_p)
            else:
                # otherwise copy old_p into the leading block of new_p;
                # the remaining entries keep their current values (they are not zero-filled)
                # if new_p is 1-D
                if len(new_p.shape) == 1:
                    new_p[:old_p.shape[0]] = old_p
                # if new_p is 2-D
                elif len(new_p.shape) == 2:
                    new_p[:old_p.shape[0], :old_p.shape[1]] = old_p

torch.Size([1, 1, 384]) torch.Size([1, 1, 384])
torch.Size([1, 197, 384]) torch.Size([1, 197, 384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([288]) torch.Size([576])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([768])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([288]) torch.Size([576])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([768])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([288]) torch.Size([576])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([768])
torch.Size([384]) torch.Size([384])


In [16]:
    try:
        for epoch in range(start_epoch, 10):
            if hasattr(dataset_train, 'set_epoch'):
                dataset_train.set_epoch(epoch)
            elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch,
                model,
                loader_train,
                optimizer,
                train_loss_fn,
                args,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,
                amp_autocast=amp_autocast,
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                # fish: add preconditioner
                preconditioner=preconditioner,
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if utils.is_primary(args):
                    _logger.info("Distributing BatchNorm running means and vars")
                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(
                model,
                loader_eval,
                validate_loss_fn,
                args,
                amp_autocast=amp_autocast,
            )

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

                ema_eval_metrics = validate(
                    model_ema.module,
                    loader_eval,
                    validate_loss_fn,
                    args,
                    amp_autocast=amp_autocast,
                    log_suffix=' (EMA)',
                )
                eval_metrics = ema_eval_metrics

            if output_dir is not None:
                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
                utils.update_summary(
                    epoch,
                    train_metrics,
                    eval_metrics,
                    filename=os.path.join(output_dir, 'summary.csv'),
                    lr=sum(lrs) / len(lrs),
                    write_header=best_metric is None,
                    log_wandb=args.log_wandb and has_wandb,
                )

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))

Test: [   0/390]  Time: 0.674 (0.674)  Loss:   2.138 ( 2.138)  Acc@1:  57.031 ( 57.031)  Acc@5:  77.344 ( 77.344)
Test: [  50/390]  Time: 0.038 (0.138)  Loss:   1.760 ( 2.790)  Acc@1:  55.469 ( 41.774)  Acc@5:  86.719 ( 66.238)
Test: [ 100/390]  Time: 0.253 (0.132)  Loss:   3.437 ( 2.829)  Acc@1:  16.406 ( 38.405)  Acc@5:  53.906 ( 65.664)
Test: [ 150/390]  Time: 0.041 (0.137)  Loss:   2.224 ( 2.783)  Acc@1:  46.094 ( 39.373)  Acc@5:  78.125 ( 66.603)
Test: [ 200/390]  Time: 0.042 (0.135)  Loss:   3.927 ( 2.960)  Acc@1:  15.625 ( 36.983)  Acc@5:  42.969 ( 63.483)
Test: [ 250/390]  Time: 0.043 (0.135)  Loss:   3.155 ( 3.072)  Acc@1:  41.406 ( 35.636)  Acc@5:  60.156 ( 61.330)
Test: [ 300/390]  Time: 0.404 (0.135)  Loss:   3.101 ( 3.152)  Acc@1:  41.406 ( 34.518)  Acc@5:  60.938 ( 59.738)
Test: [ 350/390]  Time: 0.041 (0.133)  Loss:   3.414 ( 3.224)  Acc@1:  26.562 ( 33.578)  Acc@5:  60.938 ( 58.549)
Test: [ 390/390]  Time: 0.026 (0.133)  Loss:   4.316 ( 3.189)  Acc@1:  12.500 ( 34.288) 
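
Each log line above prints the value for the current batch followed, in parentheses, by a running average over the batches seen so far, which is why both numbers coincide on the first line. A minimal sketch of such a meter follows, modelled on the common `AverageMeter` pattern. The class is a stand-in written for illustration, not timm's code, and the loss values and batch sizes are made up.

```python
class RunningAverage:
    """Tracks the latest value and the sample-weighted average of all values."""

    def __init__(self) -> None:
        self.val = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1) -> None:
        # `n` is the number of samples the value was averaged over (the batch size).
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self) -> float:
        return self.sum / self.count if self.count else 0.0


loss = RunningAverage()
for batch_loss, batch_size in [(2.0, 128), (4.0, 128), (1.0, 64)]:
    loss.update(batch_loss, batch_size)
    print(f"Loss: {loss.val:>7.3f} ({loss.avg:>6.3f})")
```

Weighting by batch size keeps the average correct when the last batch is smaller than the others.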

In [17]:
    old_model = model
    old_optimizer = optimizer
    old_loss_scaler = loss_scaler
    old_lr_scheduler = lr_scheduler
    if args.use_eva:
        old_preconditioner = preconditioner

    # args.model = args.model.replace('baby', 'child')
    args.model = args.model.replace('child', 'man')
    # args.model = args.model.replace('man', 'baby')
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        **args.model_kwargs,
    )
    if args.head_init_scale is not None:
        with torch.no_grad():
            model.get_classifier().weight.mul_(args.head_init_scale)
            model.get_classifier().bias.mul_(args.head_init_scale)
    if args.head_init_bias is not None:
        nn.init.constant_(model.get_classifier().bias, args.head_init_bias)

    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    if utils.is_primary(args):
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.to(device=device)
    if args.channels_last:
        model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        args.dist_bn = ''  # disable dist_bn when sync BN active
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN used with Apex AMP
            # WARNING this won't currently work with models using BatchNormAct2d
            model = convert_syncbn_model(model)
        else:
            model = convert_sync_batchnorm(model)
        if utils.is_primary(args):
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not args.torchcompile
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    if not args.lr:
        global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
        batch_ratio = global_batch_size / args.lr_base_size
        if not args.lr_base_scale:
            on = args.opt.lower()
            args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
        if args.lr_base_scale == 'sqrt':
            batch_ratio = batch_ratio ** 0.5
        args.lr = args.lr_base * batch_ratio
        if utils.is_primary(args):
            _logger.info(
                f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
                f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')

    optimizer = create_optimizer_v2(
        model,
        **optimizer_kwargs(cfg=args),
        **args.opt_kwargs,
    )

    # fish: add the Eva preconditioner; unclear whether it should wrap the model without DDP
    if args.use_eva:
        
        # preconditioner = Eva(model_without_ddp)
        preconditioner = Eva(
                model, lr=args.lr, factor_decay=args.stat_decay,
                damping=args.damping, kl_clip=args.kl_clip,
                fac_update_freq=args.kfac_cov_update_freq,
                kfac_update_freq=args.kfac_update_freq,
                #diag_blocks=args.diag_blocks,
                #diag_warmup=args.diag_warmup,
                #distribute_layer_factors=args.distribute_layer_factors, 
                exclude_parts=args.exclude_parts)

        kfac_param_scheduler = KFACParamScheduler(
               preconditioner,
               damping_alpha=args.damping_alpha,
               damping_schedule=args.damping_decay,
               update_freq_alpha=args.kfac_update_freq_alpha,
               update_freq_schedule=args.kfac_update_freq_decay,
               start_epoch=args.start_epoch)

        print("preconditioner eva is adopted")

    else:
        preconditioner = None

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        assert device.type == 'cuda'
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if utils.is_primary(args):
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        try:
            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        except (AttributeError, TypeError):
            # fallback to CUDA only AMP for PyTorch < 1.10
            assert device.type == 'cuda'
            amp_autocast = torch.cuda.amp.autocast
        if device.type == 'cuda' and amp_dtype == torch.float16:
            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
            loss_scaler = NativeScaler()
        if utils.is_primary(args):
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if utils.is_primary(args):
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=utils.is_primary(args),
        )

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = utils.ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if utils.is_primary(args):
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if utils.is_primary(args):
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    if args.torchcompile:
        # torch compile should be done after DDP
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        model = torch.compile(model, backend=args.torchcompile)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob,
            switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing,
            num_classes=args.num_classes
        )
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.to(device=device)
    validate_loss_fn = nn.CrossEntropyLoss().to(device=device)

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if utils.is_primary(args):
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = utils.CheckpointSaver(
            model=model,
            optimizer=optimizer,
            args=args,
            model_ema=model_ema,
            amp_scaler=loss_scaler,
            checkpoint_dir=output_dir,
            recovery_dir=output_dir,
            decreasing=decreasing,
            max_history=args.checkpoint_hist
        )
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    if utils.is_primary(args) and args.log_wandb:
        if has_wandb:
            wandb.init(project='timm',name=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    # setup learning rate schedule and starting epoch
    updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
    lr_scheduler, num_epochs = create_scheduler_v2(
        optimizer,
        **scheduler_kwargs(args),
        updates_per_epoch=updates_per_epoch,
    )
    start_epoch = 10
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        if args.sched_on_updates:
            lr_scheduler.step_update(start_epoch * updates_per_epoch)
        else:
            lr_scheduler.step(start_epoch)

    if utils.is_primary(args):
        _logger.info(
            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

Model bud_small_patch16_224_man created, param count:22050664
Data processing configuration for current model + dataset:
	input_size: (3, 224, 224)
	interpolation: bicubic
	mean: (0.485, 0.456, 0.406)
	std: (0.229, 0.224, 0.225)
	crop_pct: 0.9
	crop_mode: center
Using native Torch AMP. Training in mixed precision.


wandb run history:
epoch       ▁▃▅▆█
eval_loss   █▆▄▃▁
eval_top1   ▁▃▅▆█
eval_top5   ▁▃▅▇█
lr          █▆▅▃▁
train_loss  █▅▄▂▁

wandb run summary:
epoch       9.0
eval_loss   2.67956
eval_top1   43.538
eval_top5   68.332
lr          2e-05
train_loss  3.45738

Scheduled epochs: 15. LR stepped per epoch.
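
When no explicit learning rate is given, the cell above derives one from the base learning rate and the effective global batch size: square-root scaling for optimizers whose name contains `ada` or `lamb`, linear scaling otherwise. The sketch below restates that rule as a standalone function. `scaled_lr` is a hypothetical name, the default `lr_base_size=256` is an assumption, and the numbers in the usage lines are illustrative rather than taken from this run.

```python
def scaled_lr(
    lr_base: float,
    batch_size: int,
    world_size: int = 1,
    grad_accum_steps: int = 1,
    lr_base_size: int = 256,   # assumed reference batch size
    opt: str = "sgd",
    lr_base_scale: str = "",   # "", "linear" or "sqrt"
) -> float:
    """Scale lr_base by the ratio of the global batch size to lr_base_size."""
    global_batch_size = batch_size * world_size * grad_accum_steps
    batch_ratio = global_batch_size / lr_base_size
    if not lr_base_scale:
        name = opt.lower()
        # Adaptive optimizers (adam*, adagrad, lamb, ...) default to sqrt scaling.
        lr_base_scale = "sqrt" if any(o in name for o in ("ada", "lamb")) else "linear"
    if lr_base_scale == "sqrt":
        batch_ratio = batch_ratio ** 0.5
    return lr_base * batch_ratio


# Global batch 1024 is 4x the reference size: linear gives 4x, sqrt gives 2x.
print(scaled_lr(0.1, 128, world_size=8, opt="sgd"))
print(scaled_lr(1e-3, 128, world_size=8, opt="adamw"))
```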


In [18]:
print(model.state_dict()['blocks.0.attn.qkv.weight'])
print(old_model.state_dict()['blocks.0.attn.qkv.weight'])

with torch.no_grad():
    tmp = old_model.state_dict()
    for param_tensor in old_model.state_dict(): # iterating over a dict yields its keys, so param_tensor is the parameter name
        # print(param_tensor,'\t',model.state_dict()[param_tensor].size())
        # skip anything that is not an attn or mlp layer
        if 'attn' not in param_tensor and 'mlp' not in param_tensor:
            continue
        old_size = old_model.state_dict()[param_tensor].size()
        new_size = model.state_dict()[param_tensor].size()
        flag = False
        if len(old_size) != len(new_size):
            continue
        elif len(old_size) == 1:
            # if the new size is an integer multiple of the old size
            if new_size[0] % old_size[0] == 0:
                times = new_size[0] // old_size[0]
                tmp[param_tensor] = torch.cat((tmp[param_tensor],)*times, 0)
            # if old_size[0]*2 == new_size[0]:
            #     tmp[param_tensor] = torch.cat((tmp[param_tensor], tmp[param_tensor]), 0)
                flag = True
            # if old:new is 2:3
            elif old_size[0]*3 == new_size[0]*2:
                # concatenate the tensor with its own first half
                if old_size[0] % 2 == 0:
                    tmp[param_tensor] = torch.cat((tmp[param_tensor],tmp[param_tensor][:old_size[0]//2]), 0)
                    # tmp[param_tensor] = torch.cat((tmp[param_tensor],tmp[param_tensor][]), 0)
                    flag = True
                else:
                    # not implemented
                    print('error')
        elif len(old_size) == 2:
            if old_size[0] == new_size[0]:
                # if old_size[1]*2 == new_size[1]:
                if new_size[1] % old_size[1] == 0:
                    times = new_size[1] // old_size[1]
                    tmp[param_tensor] = torch.cat((tmp[param_tensor],)*times, 1)
                    flag = True
                elif old_size[1]*3 == new_size[1]*2:
                    if old_size[1] % 2 == 0:
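                        # take the first half of the columns (a 2-D slice) so it can be concatenated along dim 1
                        tmp[param_tensor] = torch.cat((tmp[param_tensor],tmp[param_tensor][:,:old_size[1]//2]), 1)
                        flag = True
                    else:
                        # not implemented for odd sizes
                        print('error')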
            elif old_size[1] == new_size[1]:
                # if old_size[0]*2 == new_size[0]:
                if new_size[0] % old_size[0] == 0:
                    times = new_size[0] // old_size[0]
                    tmp[param_tensor] = torch.cat((tmp[param_tensor],)*times, 0)
                    flag = True
                elif old_size[0]*3 == new_size[0]*2:    
                    if old_size[0] % 2 == 0:
                        tmp[param_tensor] = torch.cat((tmp[param_tensor],tmp[param_tensor][:old_size[0]//2]), 0)
                        flag = True
                    else:
                        # not implemented for odd sizes
                        print('error')
        if not flag:
            if param_tensor in model.state_dict():
                tmp[param_tensor] = model.state_dict()[param_tensor]
            else:
                continue
    model.load_state_dict(tmp, strict=False)
    print(model.state_dict()['blocks.0.attn.qkv.weight'])

tensor([[-0.0047,  0.0291,  0.0084,  ...,  0.0061,  0.0207,  0.0245],
        [-0.0064, -0.0107, -0.0039,  ..., -0.0381, -0.0248, -0.0018],
        [ 0.0030,  0.0182, -0.0012,  ...,  0.0175, -0.0210, -0.0017],
        ...,
        [-0.0133,  0.0130,  0.0257,  ..., -0.0031,  0.0145,  0.0067],
        [ 0.0034,  0.0407,  0.0003,  ..., -0.0187,  0.0027,  0.0183],
        [ 0.0138, -0.0381, -0.0353,  ...,  0.0069, -0.0030,  0.0216]],
       device='cuda:0')
tensor([[ 0.0384,  0.0526,  0.0231,  ...,  0.0080, -0.0127, -0.0318],
        [ 0.0123,  0.0325, -0.0157,  ..., -0.0021, -0.1333, -0.1377],
        [-0.0235, -0.0248, -0.1257,  ...,  0.0762, -0.0377, -0.0869],
        ...,
        [-0.0528,  0.0271,  0.0363,  ...,  0.0818,  0.0323, -0.0372],
        [ 0.0458,  0.0301, -0.0118,  ...,  0.0388, -0.0197,  0.0305],
        [ 0.1512,  0.0830, -0.0156,  ...,  0.0436,  0.0078,  0.0451]],
       device='cuda:0')
tensor([[ 0.0384,  0.0526,  0.0231,  ...,  0.0080, -0.0127, -0.0318],
        [ 0.01
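
The branch structure of the cell above can be summarised without touching any tensors. The following is a minimal sketch, not the notebook's code: `expansion_plan` is a hypothetical helper that only inspects shapes and reports which rule the cell would apply to a 1-D or 2-D parameter. The 576 → 1152 shape pair appears in this notebook's printed output; the other shapes are illustrative.

```python
def expansion_plan(old_size, new_size):
    """Describe how a parameter of shape old_size would be grown to new_size.

    Returns a tuple (dim, rule, arg), or None when no rule applies:
      (dim, "tile", k)         -> repeat the tensor k times along dim
      (dim, "append_half", h)  -> append the first h entries along dim (2:3 growth)
    Hypothetical helper: it only looks at shapes, never at tensor data.
    """
    if len(old_size) != len(new_size) or len(old_size) not in (1, 2):
        return None
    # Dimensions whose length changed between the two models.
    changed = [d for d, (o, n) in enumerate(zip(old_size, new_size)) if o != n]
    if len(changed) > 1:
        return None  # both dims differ: no expansion rule in the cell covers this
    dim = changed[0] if changed else 0
    old, new = old_size[dim], new_size[dim]
    if new % old == 0:
        return (dim, "tile", new // old)
    if old * 3 == new * 2 and old % 2 == 0:
        return (dim, "append_half", old // 2)
    return None


print(expansion_plan((576,), (1152,)))          # 1-D, doubled
print(expansion_plan((384,), (576,)))           # 1-D, 2:3 growth
print(expansion_plan((384, 768), (384, 1536)))  # 2-D, second dim doubled
```

When the helper returns None, the corresponding case in the cell keeps the freshly initialised weights of the new model.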

In [19]:
len(old_optimizer.param_groups[0]['params'])
# Next, walk the original model's parameters and copy them into the new optimizer's parameters
for param_group in range(len(old_optimizer.param_groups)):
    for i, (old_p, new_p) in enumerate(zip(old_optimizer.param_groups[param_group]['params'], optimizer.param_groups[param_group]['params'])):
        # inspect the shapes of old_p and new_p
        print(old_p.shape, new_p.shape)
        with torch.no_grad():
            if old_p.shape == new_p.shape:
                # copy in place; `new_p = old_p` would only rebind the loop variable
                new_p.copy_(old_p)
            else:
                # otherwise copy old_p into the leading slice of new_p
                # (the original note says the remainder is zero-padded, but this code leaves it unchanged)
                # new_p is 1-D
                if len(new_p.shape) == 1:
                    new_p[:old_p.shape[0]] = old_p
                # new_p is 2-D
                elif len(new_p.shape) == 2:
                    new_p[:old_p.shape[0], :old_p.shape[1]] = old_p

torch.Size([1, 1, 384]) torch.Size([1, 1, 384])
torch.Size([1, 197, 384]) torch.Size([1, 197, 384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([576]) torch.Size([1152])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([768]) torch.Size([1536])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([576]) torch.Size([1152])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([768]) torch.Size([1536])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([576]) torch.Size([1152])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([384]) torch.Size([384])
torch.Size([768]) torch.Size([1536])
torch.Size([384]) torch.Size([
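
A loop like the one above only has an effect if it writes into the existing parameter objects. As a plain-Python sketch (lists stand in for tensors here, an assumption made to keep the example dependency-free), assigning to the loop variable merely rebinds a local name, while slice assignment modifies the object in place. The same distinction holds for PyTorch tensors, where `Tensor.copy_` and slice assignment are the in-place forms.

```python
old_params = [[1.0, 2.0], [3.0, 4.0]]
new_params = [[0.0, 0.0], [0.0, 0.0, 0.0]]

# Rebinding the loop variable: new_params is left untouched.
for old_p, new_p in zip(old_params, new_params):
    new_p = old_p
after_rebind = [list(p) for p in new_params]

# Slice assignment: writes old_p into the leading part of the existing list.
for old_p, new_p in zip(old_params, new_params):
    new_p[:len(old_p)] = old_p

print(after_rebind)
print(new_params)
```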

In [20]:
    try:
        for epoch in range(start_epoch, 15):
            if hasattr(dataset_train, 'set_epoch'):
                dataset_train.set_epoch(epoch)
            elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch,
                model,
                loader_train,
                optimizer,
                train_loss_fn,
                args,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,
                amp_autocast=amp_autocast,
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                # fish: add preconditioner
                preconditioner=preconditioner,
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if utils.is_primary(args):
                    _logger.info("Distributing BatchNorm running means and vars")
                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(
                model,
                loader_eval,
                validate_loss_fn,
                args,
                amp_autocast=amp_autocast,
            )

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

                ema_eval_metrics = validate(
                    model_ema.module,
                    loader_eval,
                    validate_loss_fn,
                    args,
                    amp_autocast=amp_autocast,
                    log_suffix=' (EMA)',
                )
                eval_metrics = ema_eval_metrics

            if output_dir is not None:
                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
                utils.update_summary(
                    epoch,
                    train_metrics,
                    eval_metrics,
                    filename=os.path.join(output_dir, 'summary.csv'),
                    lr=sum(lrs) / len(lrs),
                    write_header=best_metric is None,
                    log_wandb=args.log_wandb and has_wandb,
                )

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))

Test: [   0/390]  Time: 0.674 (0.674)  Loss:   1.293 ( 1.293)  Acc@1:  77.344 ( 77.344)  Acc@5:  89.844 ( 89.844)
Test: [  50/390]  Time: 0.053 (0.140)  Loss:   1.117 ( 2.156)  Acc@1:  72.656 ( 54.810)  Acc@5:  89.844 ( 76.210)
Test: [ 100/390]  Time: 0.120 (0.135)  Loss:   2.358 ( 2.177)  Acc@1:  43.750 ( 51.323)  Acc@5:  77.344 ( 76.709)
Test: [ 150/390]  Time: 0.051 (0.136)  Loss:   1.972 ( 2.167)  Acc@1:  48.438 ( 51.630)  Acc@5:  82.031 ( 76.857)
Test: [ 200/390]  Time: 0.051 (0.134)  Loss:   3.377 ( 2.353)  Acc@1:  23.438 ( 48.453)  Acc@5:  57.031 ( 73.791)
Test: [ 250/390]  Time: 0.051 (0.134)  Loss:   2.423 ( 2.468)  Acc@1:  55.469 ( 46.791)  Acc@5:  69.531 ( 71.704)
Test: [ 300/390]  Time: 0.391 (0.134)  Loss:   2.856 ( 2.561)  Acc@1:  47.656 ( 45.315)  Acc@5:  63.281 ( 69.944)
Test: [ 350/390]  Time: 0.051 (0.133)  Loss:   2.860 ( 2.633)  Acc@1:  42.188 ( 44.017)  Acc@5:  64.844 ( 68.739)
Test: [ 390/390]  Time: 0.034 (0.134)  Loss:   3.787 ( 2.593)  Acc@1:  22.500 ( 44.822) 

small   1310s

mid     1310s

large   1650s

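When running the small model the GPU is not saturated and its utilization is low, so the batch size can be increased somewhat for the small and mid sizes. Comparing by time, 22 min : 27.5 min = 1 : 1.25, so 2 more epochs can be trained.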
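
The time comparison can be checked with a few lines of arithmetic. This is a sketch under two assumptions that the note does not state explicitly: the timings above are per-epoch wall-clock times, and the "extra epochs" come from the time saved by running some epochs at the cheaper cost. `extra_full_epochs` and the value `n_cheap_epochs=10` are hypothetical, chosen only for illustration.

```python
# Timings in seconds as recorded in the note above
# (assumed to be per-epoch wall-clock times).
epoch_time_s = {"small": 1310, "mid": 1310, "large": 1650}

ratio = epoch_time_s["large"] / epoch_time_s["small"]
print(f"small : large = 1 : {ratio:.2f}")


def extra_full_epochs(cheap_s, full_s, n_cheap_epochs):
    """Number of full-size epochs that fit into the time saved by running
    n_cheap_epochs at the cheaper per-epoch cost instead of the full one."""
    return (full_s - cheap_s) * n_cheap_epochs / full_s


# n_cheap_epochs=10 is a hypothetical value, not taken from the note.
print(round(extra_full_epochs(1310, 1650, n_cheap_epochs=10), 2))
```

The unrounded ratio comes out slightly above the 1 : 1.25 quoted in the note, which uses 22 min in place of 1310 s.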



In [21]:
    try:
        for epoch in range(15, 21):
            if hasattr(dataset_train, 'set_epoch'):
                dataset_train.set_epoch(epoch)
            elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch,
                model,
                loader_train,
                optimizer,
                train_loss_fn,
                args,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,
                amp_autocast=amp_autocast,
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                # fish: add preconditioner
                preconditioner=preconditioner,
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if utils.is_primary(args):
                    _logger.info("Distributing BatchNorm running means and vars")
                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(
                model,
                loader_eval,
                validate_loss_fn,
                args,
                amp_autocast=amp_autocast,
            )

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

                ema_eval_metrics = validate(
                    model_ema.module,
                    loader_eval,
                    validate_loss_fn,
                    args,
                    amp_autocast=amp_autocast,
                    log_suffix=' (EMA)',
                )
                eval_metrics = ema_eval_metrics

            if output_dir is not None:
                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
                utils.update_summary(
                    epoch,
                    train_metrics,
                    eval_metrics,
                    filename=os.path.join(output_dir, 'summary.csv'),
                    lr=sum(lrs) / len(lrs),
                    write_header=best_metric is None,
                    log_wandb=args.log_wandb and has_wandb,
                )

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))

Test: [   0/390]  Time: 0.704 (0.704)  Loss:   1.227 ( 1.227)  Acc@1:  75.781 ( 75.781)  Acc@5:  90.625 ( 90.625)
Test: [  50/390]  Time: 0.050 (0.154)  Loss:   1.175 ( 1.959)  Acc@1:  74.219 ( 58.594)  Acc@5:  92.969 ( 79.427)
Test: [ 100/390]  Time: 0.051 (0.141)  Loss:   1.928 ( 1.983)  Acc@1:  57.031 ( 55.523)  Acc@5:  84.375 ( 80.067)
Test: [ 150/390]  Time: 0.052 (0.145)  Loss:   1.791 ( 1.957)  Acc@1:  54.688 ( 56.152)  Acc@5:  82.812 ( 80.381)
Test: [ 200/390]  Time: 0.052 (0.142)  Loss:   3.058 ( 2.135)  Acc@1:  29.688 ( 53.016)  Acc@5:  64.062 ( 77.437)
Test: [ 250/390]  Time: 0.052 (0.141)  Loss:   2.338 ( 2.248)  Acc@1:  56.250 ( 51.264)  Acc@5:  71.094 ( 75.557)
Test: [ 300/390]  Time: 0.806 (0.147)  Loss:   2.625 ( 2.342)  Acc@1:  51.562 ( 49.642)  Acc@5:  67.188 ( 73.845)
Test: [ 350/390]  Time: 0.053 (0.150)  Loss:   2.639 ( 2.415)  Acc@1:  45.312 ( 48.279)  Acc@5:  71.094 ( 72.630)
Test: [ 390/390]  Time: 0.034 (0.150)  Loss:   3.558 ( 2.381)  Acc@1:  22.500 ( 48.840) 

In [31]:
# lr_scheduler
args.epochs = 21

In [49]:
    # old_model = model
    # old_optimizer = optimizer
    # old_loss_scaler = loss_scaler
    # old_lr_scheduler = lr_scheduler
    # if args.use_eva:
    #     old_preconditioner = preconditioner

    # args.model = args.model.replace('baby', 'child')
    # args.model = args.model.replace('child', 'man')
    # args.model = args.model.replace('man', 'baby')
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        **args.model_kwargs,
    )
    if args.head_init_scale is not None:
        with torch.no_grad():
            model.get_classifier().weight.mul_(args.head_init_scale)
            model.get_classifier().bias.mul_(args.head_init_scale)
    if args.head_init_bias is not None:
        nn.init.constant_(model.get_classifier().bias, args.head_init_bias)

    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    if utils.is_primary(args):
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.to(device=device)
    if args.channels_last:
        model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        args.dist_bn = ''  # disable dist_bn when sync BN active
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN used with Apex AMP
            # WARNING this won't currently work with models using BatchNormAct2d
            model = convert_syncbn_model(model)
        else:
            model = convert_sync_batchnorm(model)
        if utils.is_primary(args):
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not args.torchcompile
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    if not args.lr:
        global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
        batch_ratio = global_batch_size / args.lr_base_size
        if not args.lr_base_scale:
            on = args.opt.lower()
            args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
        if args.lr_base_scale == 'sqrt':
            batch_ratio = batch_ratio ** 0.5
        args.lr = args.lr_base * batch_ratio
        if utils.is_primary(args):
            _logger.info(
                f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
                f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')

    optimizer = create_optimizer_v2(
        model,
        **optimizer_kwargs(cfg=args),
        **args.opt_kwargs,
    )

    # fish: add the Eva preconditioner; not sure whether it should be given the model without the DDP wrapper
    if args.use_eva:
        
        # preconditioner = Eva(model_without_ddp)
        preconditioner = Eva(
                model, lr=args.lr, factor_decay=args.stat_decay,
                damping=args.damping, kl_clip=args.kl_clip,
                fac_update_freq=args.kfac_cov_update_freq,
                kfac_update_freq=args.kfac_update_freq,
                #diag_blocks=args.diag_blocks,
                #diag_warmup=args.diag_warmup,
                #distribute_layer_factors=args.distribute_layer_factors, 
                exclude_parts=args.exclude_parts)

        kfac_param_scheduler = KFACParamScheduler(
               preconditioner,
               damping_alpha=args.damping_alpha,
               damping_schedule=args.damping_decay,
               update_freq_alpha=args.kfac_update_freq_alpha,
               update_freq_schedule=args.kfac_update_freq_decay,
               start_epoch=args.start_epoch)

        print(f"preconditioner eva is adapted")

    else:
        preconditioner = None

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        assert device.type == 'cuda'
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if utils.is_primary(args):
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        try:
            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        except (AttributeError, TypeError):
            # fallback to CUDA only AMP for PyTorch < 1.10
            assert device.type == 'cuda'
            amp_autocast = torch.cuda.amp.autocast
        if device.type == 'cuda' and amp_dtype == torch.float16:
            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
            loss_scaler = NativeScaler()
        if utils.is_primary(args):
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if utils.is_primary(args):
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=utils.is_primary(args),
        )

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = utils.ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if utils.is_primary(args):
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if utils.is_primary(args):
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    if args.torchcompile:
        # torch compile should be done after DDP
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        model = torch.compile(model, backend=args.torchcompile)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob,
            switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing,
            num_classes=args.num_classes
        )
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.to(device=device)
    validate_loss_fn = nn.CrossEntropyLoss().to(device=device)

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if utils.is_primary(args):
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = utils.CheckpointSaver(
            model=model,
            optimizer=optimizer,
            args=args,
            model_ema=model_ema,
            amp_scaler=loss_scaler,
            checkpoint_dir=output_dir,
            recovery_dir=output_dir,
            decreasing=decreasing,
            max_history=args.checkpoint_hist
        )
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    if utils.is_primary(args) and args.log_wandb:
        if has_wandb:
            wandb.init(project='timm',name=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    # setup learning rate schedule and starting epoch
    updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
    lr_scheduler, num_epochs = create_scheduler_v2(
        optimizer,
        **scheduler_kwargs(args),
        updates_per_epoch=updates_per_epoch,
    )
    start_epoch = 10
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        if args.sched_on_updates:
            lr_scheduler.step_update(start_epoch * updates_per_epoch)
        else:
            lr_scheduler.step(start_epoch)

    if utils.is_primary(args):
        _logger.info(
            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

Model bud_small_patch16_224_man created, param count:22050664
Data processing configuration for current model + dataset:
	input_size: (3, 224, 224)
	interpolation: bicubic
	mean: (0.485, 0.456, 0.406)
	std: (0.229, 0.224, 0.225)
	crop_pct: 0.9
	crop_mode: center
Using native Torch AMP. Training in mixed precision.



0,1
epoch,▁
eval_loss,▁
eval_top1,▁
eval_top5,▁
lr,▁
train_loss,▁

0,1
epoch,10.0
eval_loss,2.65839
eval_top1,43.356
eval_top5,68.516
lr,3e-05
train_loss,3.59273



Scheduled epochs: 21. LR stepped per epoch.
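
The cell above rebuilds the scheduler after `args.epochs` was set to 21 and then steps it to `start_epoch`, so the learning rate resumes at that point of the longer schedule rather than of the earlier 15-epoch one. A minimal sketch of the effect, assuming a plain cosine schedule without warmup (the notebook's actual schedule is whatever `scheduler_kwargs(args)` configures; `cosine_lr` is a hypothetical helper and `base_lr=1.0` is a placeholder so the values read as fractions of the base rate):

```python
import math


def cosine_lr(epoch, total_epochs, base_lr=1.0, min_lr=0.0):
    """Plain cosine annealing from base_lr at epoch 0 to min_lr at total_epochs."""
    progress = epoch / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Fraction of the base LR at epoch 10 under a 15-epoch and a 21-epoch schedule.
print(round(cosine_lr(10, 15), 3))
print(round(cosine_lr(10, 21), 3))
```

Under this assumption, extending the total length raises the learning rate seen at a given epoch, which is worth keeping in mind when comparing curves across the two runs.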


In [50]:
print(model.state_dict()['blocks.0.attn.qkv.weight'])
print(old_model.state_dict()['blocks.0.attn.qkv.weight'])

with torch.no_grad():
    tmp = old_model.state_dict()
    for param_tensor in old_model.state_dict():  # iterating a dict yields its keys, so param_tensor is the parameter name
        # print(param_tensor,'\t',model.state_dict()[param_tensor].size())
        # skip parameters that do not belong to an attn or mlp layer
        if 'attn' not in param_tensor and 'mlp' not in param_tensor:
            continue
        old_size = old_model.state_dict()[param_tensor].size()
        new_size = model.state_dict()[param_tensor].size()
        flag = False
        if len(old_size) != len(new_size):
            continue
        elif len(old_size) == 1:
            # if the new size is an integer multiple of the old one, tile the tensor
            if new_size[0] % old_size[0] == 0:
                times = new_size[0] // old_size[0]
                tmp[param_tensor] = torch.cat((tmp[param_tensor],)*times, 0)
            # if old_size[0]*2 == new_size[0]:
            #     tmp[param_tensor] = torch.cat((tmp[param_tensor], tmp[param_tensor]), 0)
                flag = True
            # if old:new is 2:3
            elif old_size[0]*3 == new_size[0]*2:
                # concatenate the tensor with its own first half
                if old_size[0] % 2 == 0:
                    tmp[param_tensor] = torch.cat((tmp[param_tensor],tmp[param_tensor][:old_size[0]//2]), 0)
                    # tmp[param_tensor] = torch.cat((tmp[param_tensor],tmp[param_tensor][]), 0)
                    flag = True
                else:
                    # not implemented for odd sizes
                    print('error')
        elif len(old_size) == 2:
            if old_size[0] == new_size[0]:
                # if old_size[1]*2 == new_size[1]:
                if new_size[1] % old_size[1] == 0:
                    times = new_size[1] // old_size[1]
                    tmp[param_tensor] = torch.cat((tmp[param_tensor],)*times, 1)
                    flag = True
                elif old_size[1]*3 == new_size[1]*2:
                    if old_size[1] % 2 == 0:
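                        # take the first half of the columns (a 2-D slice) so it can be concatenated along dim 1
                        tmp[param_tensor] = torch.cat((tmp[param_tensor],tmp[param_tensor][:,:old_size[1]//2]), 1)
                        flag = True
                    else:
                        # not implemented for odd sizes
                        print('error')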
            elif old_size[1] == new_size[1]:
                # if old_size[0]*2 == new_size[0]:
                if new_size[0] % old_size[0] == 0:
                    times = new_size[0] // old_size[0]
                    tmp[param_tensor] = torch.cat((tmp[param_tensor],)*times, 0)
                    flag = True
                elif old_size[0]*3 == new_size[0]*2:    
                    if old_size[0] % 2 == 0:
                        tmp[param_tensor] = torch.cat((tmp[param_tensor],tmp[param_tensor][:old_size[0]//2]), 0)
                        flag = True
                    else:
                        # not implemented for odd sizes
                        print('error')
        if not flag:
            if param_tensor in model.state_dict():
                tmp[param_tensor] = model.state_dict()[param_tensor]
            else:
                continue
    model.load_state_dict(tmp, strict=False)
    print(model.state_dict()['blocks.0.attn.qkv.weight'])

tensor([[ 0.0117,  0.0170, -0.0500,  ..., -0.0457, -0.0069, -0.0012],
        [-0.0065,  0.0319,  0.0395,  ...,  0.0188,  0.0278, -0.0423],
        [-0.0169, -0.0090,  0.0215,  ..., -0.0005,  0.0468, -0.0011],
        ...,
        [-0.0097,  0.0373,  0.0293,  ..., -0.0073,  0.0088, -0.0163],
        [-0.0138, -0.0104,  0.0100,  ...,  0.0055,  0.0143, -0.0381],
        [-0.0040, -0.0056,  0.0030,  ..., -0.0136,  0.0052, -0.0294]],
       device='cuda:0')
tensor([[ 0.0384,  0.0526,  0.0231,  ...,  0.0080, -0.0127, -0.0318],
        [ 0.0123,  0.0325, -0.0157,  ..., -0.0021, -0.1333, -0.1377],
        [-0.0235, -0.0248, -0.1257,  ...,  0.0762, -0.0377, -0.0869],
        ...,
        [-0.0528,  0.0271,  0.0363,  ...,  0.0818,  0.0323, -0.0372],
        [ 0.0458,  0.0301, -0.0118,  ...,  0.0388, -0.0197,  0.0305],
        [ 0.1512,  0.0830, -0.0156,  ...,  0.0436,  0.0078,  0.0451]],
       device='cuda:0')
tensor([[ 0.0384,  0.0526,  0.0231,  ...,  0.0080, -0.0127, -0.0318],
        [ 0.01

In [43]:
# print(old_optimizer.param_groups[0]['params'][0][0][0][0])
# print(optimizer.param_groups[0]['params'][0][0][0][0])

tensor(0.0206, device='cuda:0', grad_fn=<SelectBackward0>)
tensor(0.0206, device='cuda:0', grad_fn=<SelectBackward0>)


In [51]:
len(old_optimizer.param_groups[0]['params'])
# Next, walk the original model's parameters and copy them into the new optimizer's parameters
for param_group in range(len(old_optimizer.param_groups)):
    for i, (old_p, new_p) in enumerate(zip(old_optimizer.param_groups[param_group]['params'], optimizer.param_groups[param_group]['params'])):
        # inspect the shapes of old_p and new_p
        # print(old_p.shape, new_p.shape)
        with torch.no_grad():
            if old_p.shape == new_p.shape:
                # copy in place; `new_p = old_p` would only rebind the loop variable
                new_p.copy_(old_p)
            else:
                # otherwise copy old_p into the leading slice of new_p
                # (the original note says the remainder is zero-padded, but this code leaves it unchanged)
                # new_p is 1-D
                if len(new_p.shape) == 1:
                    new_p[:old_p.shape[0]] = old_p
                # new_p is 2-D
                elif len(new_p.shape) == 2:
                    new_p[:old_p.shape[0], :old_p.shape[1]] = old_p

In [52]:
    try:
        for epoch in range(10, 21):
            if hasattr(dataset_train, 'set_epoch'):
                dataset_train.set_epoch(epoch)
            elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch,
                model,
                loader_train,
                optimizer,
                train_loss_fn,
                args,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,
                amp_autocast=amp_autocast,
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                # fish: add preconditioner
                preconditioner=preconditioner,
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if utils.is_primary(args):
                    _logger.info("Distributing BatchNorm running means and vars")
                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(
                model,
                loader_eval,
                validate_loss_fn,
                args,
                amp_autocast=amp_autocast,
            )

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

                ema_eval_metrics = validate(
                    model_ema.module,
                    loader_eval,
                    validate_loss_fn,
                    args,
                    amp_autocast=amp_autocast,
                    log_suffix=' (EMA)',
                )
                eval_metrics = ema_eval_metrics

            if output_dir is not None:
                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
                utils.update_summary(
                    epoch,
                    train_metrics,
                    eval_metrics,
                    filename=os.path.join(output_dir, 'summary.csv'),
                    lr=sum(lrs) / len(lrs),
                    write_header=best_metric is None,
                    log_wandb=args.log_wandb and has_wandb,
                )

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))

Test: [   0/390]  Time: 0.701 (0.701)  Loss:   1.182 ( 1.182)  Acc@1:  75.781 ( 75.781)  Acc@5:  88.281 ( 88.281)
Test: [  50/390]  Time: 0.054 (0.152)  Loss:   1.412 ( 2.198)  Acc@1:  67.188 ( 53.079)  Acc@5:  85.938 ( 75.337)
Test: [ 100/390]  Time: 0.120 (0.142)  Loss:   2.373 ( 2.235)  Acc@1:  46.094 ( 49.822)  Acc@5:  79.688 ( 75.975)
Test: [ 150/390]  Time: 0.053 (0.146)  Loss:   1.895 ( 2.208)  Acc@1:  53.906 ( 50.440)  Acc@5:  81.250 ( 76.304)
Test: [ 200/390]  Time: 0.055 (0.141)  Loss:   3.564 ( 2.391)  Acc@1:  21.094 ( 47.579)  Acc@5:  54.688 ( 73.115)
Test: [ 250/390]  Time: 0.053 (0.140)  Loss:   2.589 ( 2.516)  Acc@1:  47.656 ( 45.708)  Acc@5:  67.969 ( 70.975)
Test: [ 300/390]  Time: 0.406 (0.140)  Loss:   2.971 ( 2.615)  Acc@1:  43.750 ( 44.145)  Acc@5:  64.844 ( 69.085)
Test: [ 350/390]  Time: 0.053 (0.138)  Loss:   2.834 ( 2.687)  Acc@1:  42.188 ( 42.995)  Acc@5:  66.406 ( 67.853)
Test: [ 390/390]  Time: 0.035 (0.138)  Loss:   3.869 ( 2.647)  Acc@1:  17.500 ( 43.770) 

In [53]:
optimizer

Lion (
Parameter Group 0
    betas: (0.9, 0.99)
    foreach: True
    initial_lr: 5e-05
    lr: 0
    maximize: False
    weight_decay: 0.0

Parameter Group 1
    betas: (0.9, 0.99)
    foreach: True
    initial_lr: 5e-05
    lr: 0
    maximize: False
    weight_decay: 0.001
)