In [1]:
import timm
import torchsummary
import glob
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_scriptable, set_no_jit
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
from munch import Munch
import yaml
import sys
from efficientnet_pytorch import EfficientNet
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

In [2]:
torch.backends.cudnn.benchmark = True

In [3]:
with open('config/val_ns.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
args = Munch(config)
args.prefetcher = not args.no_prefetcher

In [4]:
model_list = timm.list_models(pretrained=True)

In [5]:
model_list

['adv_inception_v3',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenetblur121d',
 'dla34',
 'dla46_c',
 'dla46x_c',
 'dla60',
 'dla60_res2net',
 'dla60_res2next',
 'dla60x',
 'dla60x_c',
 'dla102',
 'dla102x',
 'dla102x2',
 'dla169',
 'dpn68',
 'dpn68b',
 'dpn92',
 'dpn98',
 'dpn107',
 'dpn131',
 'ecaresnet50d',
 'ecaresnet50d_pruned',
 'ecaresnet101d',
 'ecaresnet101d_pruned',
 'ecaresnetlight',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b1_pruned',
 'efficientnet_b2',
 'efficientnet_b2_pruned',
 'efficientnet_b2a',
 'efficientnet_b3',
 'efficientnet_b3_pruned',
 'efficientnet_b3a',
 'efficientnet_es',
 'ens_adv_inception_resnet_v2',
 'ese_vovnet19b_dw',
 'ese_vovnet39b',
 'fbnetc_100',
 'gluon_inception_v3',
 'gluon_resnet18_v1b',
 'gluon_resnet34_v1b',
 'gluon_resnet50_v1b',
 'gluon_resnet50_v1c',
 'gluon_resnet50_v1d',
 'gluon_resnet50_v1s',
 'gluon_resnet101_v1b',
 'gluon_resnet101_v1c',
 'gluon_resnet101_v1d',
 'gluon_resnet101_v1s',
 'gluon_r

In [6]:
data_dir = '/home/data/imagenet/val'

In [7]:
# model_ns = timm.create_model('tf_efficientnet_b7_ns', pretrained=True)
# model = model_ns.cuda()
# model = torch.nn.DataParallel(model).cuda()

In [8]:
# model = EfficientNet.from_pretrained("efficientnet-b7", advprop=True)
# model_ap = model.cuda()
# model = torch.nn.DataParallel(model_ap).cuda()

In [9]:
model = timm.create_model('tf_efficientnet_b7', pretrained=True)
model = model.cuda()
model = torch.nn.DataParallel(model).cuda()

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth" to /home/cutz/.cache/torch/checkpoints/tf_efficientnet_b7_ra-6c08e654.pth


In [10]:
criterion = nn.CrossEntropyLoss().cuda()

In [11]:
dataset = Dataset(data_dir, load_bytes=False, class_map='')

In [12]:
param_count = sum([m.numel() for m in model.parameters()])
logging.info('Model created, param count: %d' % (param_count))

INFO:root:Model created, param count: 66347960


In [13]:
data_config = resolve_data_config(args, model=model)
model, test_time_pool = apply_test_time_pool(model, data_config, args)

INFO:root:Data processing configuration for current model + dataset:
INFO:root:	input_size: (3, 600, 600)
INFO:root:	interpolation: bicubic
INFO:root:	mean: (0.485, 0.456, 0.406)
INFO:root:	std: (0.229, 0.224, 0.225)
INFO:root:	crop_pct: 0.875


In [14]:
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
loader = create_loader(
    dataset,
    input_size=data_config['input_size'],
    batch_size=args.batch_size,
    is_training=False,
    use_prefetcher=args.prefetcher,
    interpolation=data_config['interpolation'],
    mean=data_config['mean'],
    std=data_config['std'],
    num_workers=args.workers,
    crop_pct=crop_pct,
    pin_memory=args.pin_mem,
    tf_preprocessing=args.tf_preprocessing)

In [15]:
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()

model.eval()
with torch.no_grad():
    # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
    input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
    model(input)
    end = time.time()
    for i, (input, target) in enumerate(loader):
        if args.no_prefetcher:
            target = target.cuda()
            input = input.cuda()

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(acc1.item(), input.size(0))
        top5.update(acc5.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.log_freq == 0:
            logging.info(
                'Test: [{0:>4d}/{1}]  '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    i, len(loader), batch_time=batch_time,
                    rate_avg=input.size(0) / batch_time.avg,
                    loss=losses, top1=top1, top5=top5))

results = OrderedDict(
    top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
    top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
    param_count=round(param_count / 1e6, 2),
    img_size=data_config['input_size'][-1],
    cropt_pct=crop_pct,
    interpolation=data_config['interpolation'])

logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
   results['top1'], results['top1_err'], results['top5'], results['top5_err']))


INFO:root:Test: [   0/196]  Time: 9.382s (9.382s,   27.29/s)  Loss:  0.3436 (0.3436)  Acc@1:  95.312 ( 95.312)  Acc@5:  98.828 ( 98.828)
INFO:root:Test: [  10/196]  Time: 2.188s (2.954s,   86.67/s)  Loss:  0.7012 (0.4905)  Acc@1:  83.594 ( 89.347)  Acc@5:  98.438 ( 98.651)
INFO:root:Test: [  20/196]  Time: 2.187s (2.599s,   98.50/s)  Loss:  0.3845 (0.5060)  Acc@1:  94.922 ( 89.118)  Acc@5:  98.047 ( 98.344)
INFO:root:Test: [  30/196]  Time: 2.466s (2.513s,  101.85/s)  Loss:  0.5904 (0.4790)  Acc@1:  88.281 ( 90.033)  Acc@5:  96.484 ( 98.438)
INFO:root:Test: [  40/196]  Time: 2.188s (2.456s,  104.25/s)  Loss:  0.5406 (0.5204)  Acc@1:  89.453 ( 89.167)  Acc@5:  97.656 ( 98.180)
INFO:root:Test: [  50/196]  Time: 2.231s (2.407s,  106.34/s)  Loss:  0.3706 (0.5288)  Acc@1:  94.141 ( 88.863)  Acc@5:  98.828 ( 98.200)
INFO:root:Test: [  60/196]  Time: 2.188s (2.372s,  107.91/s)  Loss:  0.6686 (0.5460)  Acc@1:  85.156 ( 88.320)  Acc@5:  96.484 ( 98.162)
INFO:root:Test: [  70/196]  Time: 2.184s 