In [1]:
import timm
import torchsummary
import glob
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_scriptable, set_no_jit
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
from munch import Munch
import yaml
import sys
from efficientnet_pytorch import EfficientNet
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

In [2]:
torch.backends.cudnn.benchmark = True

In [3]:
with open('config/val_ns.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
args = Munch(config)
args.prefetcher = not args.no_prefetcher

In [4]:
# model_list = timm.list_models(pretrained=True)

In [5]:
# model_list

In [6]:
data_dir = '/home/data/imagenet/val'

In [7]:
model_ns = timm.create_model('tf_efficientnet_b7_ns', pretrained=True)
model_ns = model_ns.cuda(0)
# model_ns = torch.nn.DataParallel(model_ns).cuda()

In [8]:
# model = EfficientNet.from_pretrained("efficientnet-b7", advprop=True)
# model_ap = model.cuda()
# model = torch.nn.DataParallel(model_ap).cuda()

In [9]:
model = timm.create_model('tf_efficientnet_b7', pretrained=True)
model = model.cuda(1)
# model = torch.nn.DataParallel(model).cuda()

In [10]:
criterion = nn.CrossEntropyLoss().cuda()

In [11]:
dataset = Dataset(data_dir, load_bytes=False, class_map='')

In [12]:
param_count = sum([m.numel() for m in model.parameters()])
logging.info('Model created, param count: %d' % (param_count))

INFO:root:Model created, param count: 66347960


In [13]:
data_config = resolve_data_config(args, model=model)
model, test_time_pool = apply_test_time_pool(model, data_config, args)

INFO:root:Data processing configuration for current model + dataset:
INFO:root:	input_size: (3, 600, 600)
INFO:root:	interpolation: bicubic
INFO:root:	mean: (0.485, 0.456, 0.406)
INFO:root:	std: (0.229, 0.224, 0.225)
INFO:root:	crop_pct: 0.949


In [14]:
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
loader = create_loader(
    dataset,
    input_size=data_config['input_size'],
    batch_size=args.batch_size,
    is_training=False,
    use_prefetcher=args.prefetcher,
    interpolation=data_config['interpolation'],
    mean=data_config['mean'],
    std=data_config['std'],
    num_workers=args.workers,
    crop_pct=crop_pct,
    pin_memory=args.pin_mem,
    tf_preprocessing=args.tf_preprocessing)

In [None]:
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()

model.eval()
with torch.no_grad():
    # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
    inputs = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
    model(inputs)
    end = time.time()
    for i, (inputs, target) in enumerate(loader):
        if args.no_prefetcher:
            target = target.cuda()
            inputs = inputs.cuda()

        # compute output
        output, foward_list = model(inputs)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.log_freq == 0:
            logging.info(
                'Test: [{0:>4d}/{1}]  '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    i, len(loader), batch_time=batch_time,
                    rate_avg=inputs.size(0) / batch_time.avg,
                    loss=losses, top1=top1, top5=top5))

results = OrderedDict(
    top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
    top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
    param_count=round(param_count / 1e6, 2),
    img_size=data_config['input_size'][-1],
    cropt_pct=crop_pct,
    interpolation=data_config['interpolation'])

logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
   results['top1'], results['top1_err'], results['top5'], results['top5_err']))


In [15]:
inputs = torch.randn((args.batch_size,) + data_config['input_size']).cuda()


In [15]:
for i, (inputs, target) in enumerate(loader):
    break

In [16]:
target = target.cuda()
inputs1 = inputs[:4,:].cuda(0)
inputs2 = inputs[:4,:].cuda(1)

In [17]:
o, f = model(inputs2)
on, fn = model_ns(inputs1)

In [20]:
o

tensor([[-1.8492,  1.1344, -0.2752,  ...,  1.0009,  1.1704,  0.7831],
        [ 0.3711,  1.1172,  1.3410,  ...,  0.4195, -0.7132, -0.5410],
        [-0.2846, -1.1747, -0.2825,  ..., -1.0386,  0.6629,  0.0532],
        [ 1.3348, -0.6319, -0.4744,  ...,  0.0430, -0.2146,  0.7475]],
       device='cuda:1', grad_fn=<AddmmBackward>)

In [21]:
on

tensor([[-2.6963, -0.5890, -0.7212,  ...,  0.7375,  0.4632,  1.6150],
        [-0.3165,  0.5732,  0.3430,  ..., -1.5601, -0.2313,  0.1448],
        [-1.3905, -1.3874,  0.1128,  ..., -0.8446, -0.7689, -0.1030],
        [ 2.4641,  0.8442, -0.2377,  ...,  0.9964,  1.2405, -0.8994]],
       device='cuda:0', grad_fn=<AddmmBackward>)

In [18]:
f

[tensor([[[[-1.8072e-01, -2.7845e-01, -2.3851e-01,  ..., -1.6028e-01,
            -1.5388e-01, -1.7731e-01],
           [-1.9664e-01, -2.5365e-01, -2.5874e-01,  ..., -1.5876e-01,
            -1.2922e-01, -2.0947e-01],
           [-2.7682e-01, -2.4239e-01, -2.7774e-01,  ..., -1.6911e-01,
            -1.6033e-01, -2.4927e-01],
           ...,
           [-1.9967e-01, -1.9759e-01, -2.4018e-01,  ..., -8.4723e-02,
            -2.1109e-01, -2.7607e-01],
           [-1.9164e-01, -2.1717e-01, -2.4689e-01,  ..., -2.2867e-01,
            -2.7108e-01, -2.6598e-01],
           [-2.4064e-01, -2.6160e-01, -2.7260e-01,  ..., -2.7627e-01,
            -2.7826e-01, -2.7675e-01]],
 
          [[-1.4843e-01, -2.5836e-01, -2.7673e-01,  ...,  1.1259e+00,
             3.5869e+00,  1.4454e+00],
           [-9.6783e-02, -9.9397e-02, -2.4847e-01,  ...,  7.1397e-01,
             2.9204e+00,  1.8880e+00],
           [-2.3112e-01, -2.5798e-01,  9.4891e-01,  ...,  8.7345e-01,
             3.2878e+00,  2.6539e+00],


In [19]:
fn

[tensor([[[[ 9.0378e-01,  3.9503e-01,  2.9614e-01,  ...,  8.0788e-01,
             1.0411e+00,  7.9025e-01],
           [ 5.4136e-01,  6.2745e-01,  3.0182e-01,  ...,  7.8725e-01,
             1.0023e+00,  6.0522e-01],
           [ 1.1092e-02,  8.0312e-02,  1.0142e-01,  ...,  4.6305e-01,
             7.9664e-01,  5.8861e-01],
           ...,
           [ 6.4856e-01,  6.0449e-01,  5.1282e-01,  ...,  7.6382e-01,
             4.0308e-01,  3.1252e-01],
           [ 6.1856e-01,  4.9918e-01,  4.4523e-01,  ...,  4.4014e-01,
             3.0136e-01,  3.2415e-01],
           [ 4.1480e-01,  3.7716e-01,  3.7557e-01,  ...,  4.5631e-01,
             4.4001e-01,  3.5067e-01]],
 
          [[-4.2117e-04, -5.1673e-04, -5.4898e-04,  ..., -3.6242e-04,
            -3.4613e-04, -4.1150e-04],
           [-4.5998e-04, -4.4726e-04, -5.0405e-04,  ..., -3.6301e-04,
            -3.5550e-04, -4.3605e-04],
           [-5.9450e-04, -5.7222e-04, -5.8756e-04,  ..., -4.1667e-04,
            -3.9299e-04, -4.5002e-04],
