# EfficientNetV2: Smaller Models and Faster Training <br>
In this project, we study *EfficientNetV2*, the family of convolutional networks presented in the paper: <br>

> Tan, Mingxing, and Quoc Le. "EfficientNetV2: Smaller Models and Faster Training." In *International Conference on Machine Learning*, pp. 10096–10106. PMLR, 2021.

and explore its features and its advantages over existing methods in the literature.


**Setup:** <br>
Package versions required to run the code:<br>
python 3.7.7, cuda 10.1, pytorch 1.11.0, timm 0.5.4


In [None]:
!pip install timm



In [None]:
import timm

# Instantiate the two EfficientNetV2 variants shipped with timm (with
# pretrained weights) and report, for each one, the total number of
# parameters and the classification head returned by get_classifier().
netv2_s = timm.create_model('efficientnetv2_rw_s', pretrained=True)
netv2_s_params = sum(p.numel() for p in netv2_s.parameters())
print("EfficientNetV2-S Params: ", netv2_s_params)
print("EfficientNetV2-S Classifier: ", netv2_s.get_classifier())

netv2_m = timm.create_model('efficientnetv2_rw_m', pretrained=True)
netv2_m_params = sum(p.numel() for p in netv2_m.parameters())
print("EfficientNetV2-M Params: ", netv2_m_params)
print("EfficientNetV2-M Classifier: ", netv2_m.get_classifier())


EfficientNetV2-S Params:  23941296
EfficientNetV2-S Classifier:  Linear(in_features=1792, out_features=1000, bias=True)
EfficientNetV2-M Params:  53236442
EfficientNetV2-M Classifier:  Linear(in_features=2152, out_features=1000, bias=True)


In [None]:
%%writefile train.py 

################################################################################
### IMPORTS AND INITIAL SETUP
################################################################################

import argparse, time, os, logging, matplotlib.pyplot as plt
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
import torch, torch.nn as nn, torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import create_dataset, create_loader, resolve_data_config,\
      Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, convert_splitbn_model,\
      model_parameters
from timm.utils import *
from timm.loss import *
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler

_logger = logging.getLogger('train')  # shared logger for every helper below
torch.backends.cudnn.benchmark = True  # let cuDNN auto-select conv algorithms


################################################################################
### PARSE INPUT ARGUMENTS
################################################################################

# Command-line interface of the training script. Settings that are not exposed
# here (momentum, weight decay, log interval, device, ...) are fixed in main().
parser = argparse.ArgumentParser()
# Setup of the execution
parser.add_argument('--seed', '-s', type=int, default=42, metavar='S',
                    help='random seed')
# Passed as `num_workers` to the train/test data loaders.
parser.add_argument('--workers', '-w', type=int, default=2, metavar='WORKERS',
                    help='number of data-loading worker processes')
parser.add_argument('--eval-metric', default='top1', type=str, 
                    metavar='EVAL_METRIC', help='Best evaluation metric')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: current dir)')

# Dataset Arguments:
parser.add_argument('data_dir', metavar='DIR', help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='DATA', default='torch/cifar100',
                    help='dataset (default: torch/cifar100)')
parser.add_argument('--train-split', metavar='TRAIN', default='train',
                    help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='VAL', default='validation',
                    help='dataset validation split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
                    help='Allow download of dataset for torch/* .')
parser.add_argument('--input-size', '-i', default=None, nargs=3, type=int,
                    metavar='SIZE', help='''image dimensions (d h w, e.g. 
                    --input-size 3 224 224)''')
parser.add_argument('--num-classes', '-n', type=int, default=None, metavar='NUM',
                    help='number of label classes')
parser.add_argument('--batch-size', '-b', type=int, default=32, metavar='BATCH',
                    help='Training batch size')

# Model Arguments
parser.add_argument('--model', '-m', default='efficientnetv2_rw_s', type=str, 
                    metavar='MODEL', help='Name of model to train')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Load pretrained model')
# '--global-pool' is accepted as a hyphenated alias, consistent with the other
# flags; the original '--global_pool' spelling keeps working (dest is the same).
parser.add_argument('--global_pool', '--global-pool', '-g', default=None,
                    type=str, 
                    metavar='POOL', help='''Global pool type, one of: (fast, 
                    avg, max, avgmax, avgmaxc); Model default if None.''')
parser.add_argument('--opt', '-o', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimization method')

# Augmentation Arguments
parser.add_argument('--no-aug', action='store_true', default=False,
                    help='Disable all training augmentation')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], 
                    metavar='PCT', help='Random resize scale')
parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], 
                    metavar='RATIO', help='Random resize aspect ratio')
parser.add_argument('--hflip', type=float, default=0.5,
                    help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
                    help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor')
parser.add_argument('--aa', type=str, default=None, metavar='AUTOAUG',
                    help='Use AutoAugment policy: "v0" or "original"')
parser.add_argument('--aug-splits', type=int, default=0,
                    help='Number of augmentation splits: 0 or >=2')

# Regularization Arguments
parser.add_argument('--bce-loss', action='store_true', default=False,
                    help='Enable BCE loss w/ Mixup/CutMix use.')
parser.add_argument('--mixup', type=float, default=0.0,
                    help='mixup alpha, mixup enabled if > 0.0')
parser.add_argument('--cutmix', type=float, default=0.0,
                    help='cutmix alpha, cutmix enabled if > 0.0')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='''cutmix min/max ratio, overrides alpha and enables 
                    cutmix if set''')
parser.add_argument('--mixup-prob', type=float, default=1.0,
                    help='''Probability of performing mixup or cutmix when 
                    either/both is enabled''')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='''Probability of switching to cutmix when both mixup 
                    and cutmix enabled''')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='''How to apply mixup/cutmix params. Per "batch", 
                    "pair", or "elem"''')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                    help='Dropout rate')


# Learning rate Arguments (consumed by timm's create_scheduler / optimizer)
parser.add_argument('--lr', '-l', type=float, default=0.05, metavar='LR',
                    help='learning rate')
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='Learning rate scheduler')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, 
                    metavar='NOISE', help='learning rate noise')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, 
                    metavar='PERCENT', help='learning rate noise limit percent')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier')
parser.add_argument('--lr-cycle-decay', type=float, default=0.5,
                    metavar='DECAY',
                    help='amount to decay each learning rate cycle')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='LIMIT',
                    help='learning rate cycle limit, cycles enabled if > 1')
parser.add_argument('--lr-k-decay', type=float, default=1.0,
                    help='learning rate k-decay for cosine/poly')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='WARM',
                    help='warmup learning rate')
parser.add_argument('--min-lr', type=float, default=1e-6, metavar='MIN',
                    help='lower lr bound for cyclic schedulers that hit 0')

# Epoch Arguments
parser.add_argument('--epochs', '-e', type=int, default=100, metavar='EPOCHS',
                    help='number of epochs to train')
parser.add_argument('--epoch-repeats', type=float, default=0.0,
                    metavar='REPEAT',
                    help='epoch repeat multiplier.')
parser.add_argument('--decay-epochs', type=float, default=100, metavar='DLRE',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='WARMEPOCH',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='COOL',
                    help='epochs to cooldown LR at min_lr, after cyclic ends')
parser.add_argument('--patience-epochs', type=int, default=10, 
                    metavar='PATIENCE', help='patience epochs for LR scheduler')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, 
                    metavar='RATE', help='LR decay rate')


################################################################################
### GENERATE DATASETS
################################################################################

def get_datasets(args):
    """Create the training and validation datasets described by ``args``.

    Both splits are built with timm's ``create_dataset`` from ``args.dataset``
    rooted at ``args.data_dir``, sharing the download flag and batch size.
    Only the training split receives ``args.epoch_repeats``.

    Returns:
        A ``(train_data, test_data)`` tuple.
    """
    shared = dict(root=args.data_dir, download=args.dataset_download,
                  batch_size=args.batch_size)
    train_data = create_dataset(args.dataset, split=args.train_split,
                                is_training=True, repeats=args.epoch_repeats,
                                **shared)
    test_data = create_dataset(args.dataset, split=args.val_split,
                               is_training=False, **shared)
    return train_data, test_data


################################################################################
### CREATE DATA LOADERS
################################################################################

def get_loaders(train_data, test_data, args, config, num_aug, collate_fn):
    """Wrap the datasets in timm data loaders.

    Args:
        train_data: training dataset (possibly wrapped in ``AugMixDataset``).
        test_data: validation dataset.
        args: parsed CLI namespace (augmentation, batch size, workers,
            prefetcher flag; optionally ``train_interpolation``).
        config: resolved data config with ``input_size``, ``interpolation``,
            ``mean``, ``std`` and ``crop_pct``.
        num_aug: number of augmentation splits forwarded to the train loader.
        collate_fn: optional collate function (e.g. ``FastCollateMixup``).

    Returns:
        A ``(train_loader, test_loader)`` tuple. Random erasing is disabled
        (``re_prob=0.0``) and both loaders are single-process
        (``distributed=False``).
    """
    # Train-time interpolation defaults to 'random' and may be overridden by an
    # optional ``args.train_interpolation``. With augmentation disabled, or when
    # no mode is given, fall back to the model's configured interpolation.
    # (Previously the mode was a hard-coded constant, so the emptiness check
    # could never trigger.)
    train_interp = getattr(args, 'train_interpolation', 'random')
    if args.no_aug or not train_interp:
        train_interp = config['interpolation']

    train_loader = create_loader(train_data, 
        input_size = config['input_size'], batch_size = args.batch_size,
        is_training = True, use_prefetcher = args.prefetcher,
        no_aug = args.no_aug, re_prob = 0.0, re_mode = 'pixel', re_count = 1,
        re_split = False, scale = args.scale, ratio = args.ratio,
        hflip = args.hflip, vflip = args.vflip,
        color_jitter = args.color_jitter,
        auto_augment = args.aa, num_aug_repeats = 0,
        num_aug_splits = num_aug, interpolation = train_interp,
        mean = config['mean'],
        std = config['std'], num_workers = args.workers, distributed = False, 
        collate_fn = collate_fn, pin_memory = False, 
        use_multi_epochs_loader = False, worker_seeding = 'all')

    # Evaluation always uses the model's interpolation and center-crop ratio.
    test_loader = create_loader(test_data, input_size = config['input_size'],
        batch_size = args.batch_size, is_training = False,
        use_prefetcher = args.prefetcher,
        interpolation = config['interpolation'], mean = config['mean'],
        std = config['std'], num_workers = args.workers, distributed = False, 
        crop_pct = config['crop_pct'], pin_memory = False)
    return train_loader, test_loader


################################################################################
### MAIN METHOD
################################################################################

def main():
    """Train and evaluate a timm model on a single GPU, driven by the CLI.

    Builds the datasets, model, optimizer, LR scheduler, augmentation
    (mixup/cutmix/AugMix) and loss functions from the parsed arguments. It then
    runs ``num_epochs`` epochs of training followed by validation. After each
    epoch it steps the scheduler on ``args.eval_metric``, appends a row to
    ``<output_dir>/summary.csv`` and saves a checkpoint. At the end it plots
    the loss / top-1 / top-5 curves. A ``KeyboardInterrupt`` stops training
    early (logged, not re-raised).
    """
    # Initialization of parameters. Settings not exposed on the CLI are fixed
    # here because timm helpers (optimizer_kwargs, create_scheduler,
    # resolve_data_config) read them from ``args``.
    setup_default_logging()
    args = parser.parse_args()
    args.mean, args.std, args.crop_pct, args.interpolation = None, None, None,''
    args.opt_eps, args.opt_betas, args.layer_decay = None, None, None
    args.momentum, args.weight_decay, args.log_interval = 0.9, 2e-5, 50
    args.prefetcher, args.device, args.world_size, args.rank = True,'cuda:0',1,0 
    _logger.info('Training with a single process on 1 GPU.')
    random_seed(args.seed, args.rank)
    train_data, test_data = get_datasets(args)
    
    # Generating the model
    model = create_model(args.model, pretrained = args.pretrained,
        num_classes = args.num_classes, drop_rate = args.drop,
        drop_connect_rate = None,  drop_path_rate = None,
        drop_block_rate = None, global_pool = args.global_pool,
        bn_momentum = None, bn_eps = None, scriptable = None)
    model.cuda()
    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
    if args.num_classes is None:
        # Needed below for mixup label construction.
        assert hasattr(model, 'num_classes')
        args.num_classes = model.num_classes 
    _logger.info(f'Model name: {safe_model_name(args.model)}')
    num_params = sum(m.numel() for m in model.parameters())
    _logger.info(f'Number of Parameters: {num_params}')
    
    # Setup learning rate scheduler. The returned epoch count may exceed
    # args.epochs (it is what the loop below actually runs).
    lr_sched, num_epochs = create_scheduler(args, optimizer)
    _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # Setup data augmentation 
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits
    collate_fn, mixup_fn = None, None
    mixup_flag = args.mixup > 0 or args.cutmix > 0.0
    mixup_active = mixup_flag or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha = args.mixup, cutmix_alpha = args.cutmix, 
            cutmix_minmax = args.cutmix_minmax, prob = args.mixup_prob, 
            switch_prob = args.mixup_switch_prob, mode = args.mixup_mode,
            label_smoothing = args.smoothing, num_classes = args.num_classes)
        if args.prefetcher:
            # Mixup is applied inside the collate function with the prefetcher.
            assert not num_aug_splits  
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)
    if num_aug_splits > 1:
        train_data = AugMixDataset(train_data, num_splits = num_aug_splits)

    # Generate data loaders (w/ augmentation)
    data_config = resolve_data_config(vars(args), model = model, verbose = True)   
    train_loader, test_loader = get_loaders(train_data, test_data, args, 
                                data_config, num_aug_splits, collate_fn)
    
    # Setup loss functions: soft-target losses when mixup produces soft labels.
    if mixup_active:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(target_threshold = None)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(smoothing = args.smoothing, 
                                                  target_threshold = None)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    test_loss_fn = nn.CrossEntropyLoss().cuda()

    # Setup train and test metrics
    eval_metric = args.eval_metric
    saver, output_dir = None, None
    if args.rank == 0:
        exp_name = '-'.join([datetime.now().strftime("%Y%m%d-%H%M%S"),
            safe_model_name(args.model), str(data_config['input_size'][-1])])
        output_dir =get_outdir(args.output if args.output else './output/train', 
                               exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(model = model, optimizer = optimizer,args =args,
                model_ema = None, amp_scaler = None, checkpoint_dir =output_dir,
                recovery_dir = output_dir, decreasing = decreasing, 
                max_history = 10)

    best_metric, best_epoch = None, None
    try:
        loss, top_1, top_5 = [], [], []
        t_init = time.time()
        for epoch in range(num_epochs):
            train_metrics = train_one_epoch(epoch, model, train_loader, 
                            optimizer, train_loss_fn, args, lr_sched = lr_sched,
                            saver = saver, output_dir = output_dir,
                            loss_scaler = None, model_ema = None, 
                            mixup_fn = mixup_fn)
            loss.append(train_metrics['loss'])
            
            test_metrics = validate(model, test_loader, test_loss_fn, args)
            top_1.append(test_metrics['top1'])
            top_5.append(test_metrics['top5'])
            if lr_sched is not None:
                lr_sched.step(epoch + 1, test_metrics[eval_metric])
            if output_dir is not None:
                # timm's update_summary appends a row to the CSV. Emit the
                # header only with the first row; writing it every epoch
                # interleaves header lines with the data.
                update_summary(epoch, train_metrics, test_metrics, 
                    os.path.join(output_dir, 'summary.csv'),
                    write_header = epoch == 0, log_wandb = False)
            if saver is not None:
                # Persist the latest checkpoint and keep the best ones ranked
                # by eval_metric (previously the saver was never used).
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric = test_metrics[eval_metric])
        t_end = time.time()
        print("Total execution time: ", t_end - t_init)
        plot_figures(loss, top_1, top_5)
    except KeyboardInterrupt:
        # Deliberate early stop: report it instead of silently swallowing it.
        _logger.info('Training interrupted by user.')
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))


################################################################################
### TRAINING METHOD
################################################################################

def train_one_epoch(epoch, model, loader, optimizer, loss_fn, args,
        lr_sched = None, saver = None, output_dir = None,
        loss_scaler = None, model_ema = None, mixup_fn = None):
    """Train ``model`` for one pass over ``loader``.

    Args:
        epoch: zero-based epoch index (used for logging and to derive the
            global update counter passed to ``lr_sched.step_update``).
        model, loader, optimizer, loss_fn: standard training components.
        args: namespace providing ``prefetcher``, ``log_interval`` and
            ``world_size``.
        lr_sched: optional timm scheduler, stepped after every batch.
        mixup_fn: optional mixup callable, applied only when the prefetcher is
            off (with the prefetcher, mixup happens in the collate function).
        saver, output_dir, loss_scaler, model_ema: accepted for signature
            compatibility; not used by this function.

    Returns:
        ``OrderedDict`` with the sample-weighted average training ``loss``.
    """
    batch_time_m, data_time_m = AverageMeter(), AverageMeter()
    losses_m = AverageMeter()
    model.train()
    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)

    # Load the data batch
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            inputs, target = inputs.cuda(), target.cuda()
            if mixup_fn is not None:
                inputs, target = mixup_fn(inputs, target)

        # Forward propagate the network. suppress() is a no-op context that
        # stands in for an autocast context (AMP is not used here).
        with suppress():
            output = model(inputs)
            loss = loss_fn(output, target)
        losses_m.update(loss.item(), inputs.size(0))

        # Backward propagate the network
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Wait for GPU work so the batch timing below is meaningful.
        torch.cuda.synchronize()
        num_updates += 1
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
            # A single-batch loader has last_idx == 0; report 100% instead of
            # dividing by zero.
            progress = 100. * batch_idx / last_idx if last_idx > 0 else 100.

            # Record Results      
            _logger.info('Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                    epoch, batch_idx, len(loader), progress,
                    loss = losses_m, batch_time = batch_time_m,
                    rate = inputs.size(0) * args.world_size / batch_time_m.val,
                    rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                    lr = lr, data_time = data_time_m))
        if lr_sched is not None:
            lr_sched.step_update(num_updates = num_updates, metric=losses_m.avg)
        end = time.time()

    # Lookahead-wrapped optimizers need their slow weights synced at epoch end.
    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()
    return OrderedDict([('loss', losses_m.avg)])


################################################################################
### EVALUATION METHOD
################################################################################

def validate(model, loader, loss_fn, args, log_suffix=''):
    """Evaluate ``model`` over ``loader`` with gradients disabled.

    Progress is logged on the last batch and every ``args.log_interval``
    batches, tagged ``'Test' + log_suffix``.

    Returns:
        ``OrderedDict`` with the sample-weighted averages of ``loss``,
        ``top1`` and ``top5`` accuracy.
    """
    batch_time_m, losses_m = AverageMeter(), AverageMeter()
    top1_m, top5_m = AverageMeter(), AverageMeter()
    model.eval()
    tick = time.time()
    final_idx = len(loader) - 1
    log_name = 'Test' + log_suffix
    log_fmt = ('{0}: [{1:>4d}/{2}]  '
               'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
               'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
               'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
               'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})')

    with torch.no_grad():
        for step, (images, target) in enumerate(loader):
            is_last = step == final_idx
            if not args.prefetcher:
                images, target = images.cuda(), target.cuda()

            # suppress() is a no-op placeholder for an autocast context.
            with suppress():
                output = model(images)
            # Some models return auxiliary outputs; score the primary one.
            if isinstance(output, (tuple, list)):
                output = output[0]

            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            torch.cuda.synchronize()
            losses_m.update(loss.data.item(), images.size(0))
            n_out = output.size(0)
            top1_m.update(acc1.item(), n_out)
            top5_m.update(acc5.item(), n_out)

            batch_time_m.update(time.time() - tick)
            tick = time.time()
            if is_last or step % args.log_interval == 0:
                _logger.info(log_fmt.format(
                    log_name, step, final_idx, batch_time=batch_time_m,
                    loss=losses_m, top1=top1_m, top5=top5_m))

    return OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg),
                        ('top5', top5_m.avg)])


################################################################################
### PLOTTING RESULTS
################################################################################

def plot_figures(loss, top_1, top_5, out_dir='output/train'):
    """Plot per-epoch training loss and test top-1 / top-5 accuracy.

    Each series is drawn in its own figure and saved as a PNG in ``out_dir``
    (``loss.png``, ``top1.png``, ``top5.png``). All figures are then shown.

    Args:
        loss: training loss per epoch.
        top_1: test top-1 accuracy per epoch.
        top_5: test top-5 accuracy per epoch.
        out_dir: destination directory, created if missing. Defaults to
            ``output/train``, the location used so far.
    """
    # The directory is not guaranteed to exist (e.g. when training results
    # were redirected with --output), and savefig does not create it.
    os.makedirs(out_dir, exist_ok=True)
    num_epochs = len(loss)
    curves = (
        (loss, 'Training loss', 'loss.png'),
        (top_1, 'Test top_1 accuracy', 'top1.png'),
        (top_5, 'Test top_5 accuracy', 'top5.png'),
    )
    for values, ylabel, filename in curves:
        fig = plt.figure()
        plt.plot(list(range(len(values))), values)
        plt.xlim((0, num_epochs))
        plt.grid(True)
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        fig.savefig(os.path.join(out_dir, filename), format='png')

    plt.show()
    

################################################################################
### GLOBAL METHODS AND VARIABLES
################################################################################

if __name__ == "__main__":
    # Run training when executed as a script (e.g. via `%run train.py`).
    main()


Overwriting train.py


In [None]:
%run train.py -d torch/cifar10 --dataset-download datasets --model efficientnet_b1 --epochs 2 -b 8

Training with a single process on 1 GPU.


Files already downloaded and verified
Files already downloaded and verified


Model name: efficientnet_b1
Number of Parameters: 7794184
Scheduled epochs: 12
Data processing configuration for current model + dataset:
	input_size: (3, 224, 224)
	interpolation: bicubic
	mean: (0.485, 0.456, 0.406)
	std: (0.229, 0.224, 0.225)
	crop_pct: 1.0
