In [1]:
import timm
import torchsummary
import glob
import os
import random
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
import torchvision
from torch.autograd import Variable
from collections import OrderedDict
from timm.data import Dataset, create_loader, resolve_data_config,  FastCollateMixup, mixup_batch, AugMixDataset
from timm.models import create_model, resume_checkpoint, convert_splitbn_model, apply_test_time_pool
# NOTE: keep this star import ahead of the imports below so it cannot shadow them.
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from munch import Munch
import yaml
import sys
from gd import Generator, Discriminator
from datetime import datetime
import numpy as np
import torch.distributed as dist
import torch.nn.functional as F
# Root logger at DEBUG so no record is filtered out by level at the root.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# PIL emits one DEBUG record per PNG chunk it parses (the
# "DEBUG:PIL.PngImagePlugin:STREAM ..." lines); keep those out of the training log.
logging.getLogger('PIL').setLevel(logging.INFO)

In [2]:
# NVIDIA Apex is deliberately disabled: with has_apex False, the AMP and the
# Apex SyncBN / DDP branches in later cells are skipped. To try Apex, restore
#   from apex import amp
#   from apex.parallel import DistributedDataParallel as DDP
#   from apex.parallel import convert_syncbn_model
# and set has_apex according to whether that import succeeds.
has_apex = False

In [3]:
# Restrict this kernel to GPUs 0, 1 and 2. CUDA_VISIBLE_DEVICES is only honoured
# if it is set before CUDA is first initialised in the process.
gpu_id = '0,1,2'
os.environ.update({'CUDA_VISIBLE_DEVICES': str(gpu_id)})
use_cuda = torch.cuda.is_available()
print("GPU device " , use_cuda)

GPU device  True


In [4]:
# Let cuDNN benchmark convolution algorithms per input shape and reuse the
# fastest; worthwhile because both loaders use one fixed input resolution.
setattr(torch.backends.cudnn, 'benchmark', True)

In [5]:
# Load the YAML training config into an attribute-style namespace and fix the
# runtime layout: ONE process driving every visible GPU through nn.DataParallel.
# (FullLoader is acceptable for this trusted local file; switch to yaml.safe_load
# if the config could ever come from an untrusted source.)
with open('config/train.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
args = Munch(config)
args.prefetcher = not args.no_prefetcher
args.distributed = False
args.device = 'cuda'
# A single process already sees the whole DataParallel batch, so world_size is 1.
# train_epoch multiplies its images/s figure by world_size; 3 here reported 3x
# the real throughput.
args.world_size = 1
args.rank = 0
logging.info('Training with a single process on %d GPUs.' % args.num_gpu)

INFO:root:Training with a single process on 3 GPUs.


In [6]:
# Per-process DDP bootstrap (one GPU per process, NCCL backend, env:// rendezvous).
# Not taken in this notebook: args.distributed is set to False in the config cell.
if args.distributed:
    args.num_gpu = 1
    args.device = 'cuda:%d' % args.local_rank
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
        rank=args.rank,
        world_size=args.world_size,
    )
    # Trust the process group's view of size/rank over the values passed in.
    args.world_size, args.rank = torch.distributed.get_world_size(), torch.distributed.get_rank()
    assert args.rank >= 0

In [7]:
# Seed the torch (all devices), Python `random` and NumPy RNGs; offset by rank so
# distributed processes would not share random streams. The NumPy call is last so
# the cell does not echo torch's Generator object.
seed = args.seed + args.rank
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

<torch._C.Generator at 0x7fbb74e01e90>

In [8]:
# Distillation teacher: pretrained Noisy-Student EfficientNet-B1, moved to the
# GPU and split across the visible devices with nn.DataParallel.
model_ns = torch.nn.DataParallel(
    timm.create_model('tf_efficientnet_b1_ns', pretrained=True).cuda()
)

In [9]:
# Classifier evaluated on generator-perturbed inputs. It is never handed to the
# optimizer (only model_g is), so freeze its weights: gradients still flow
# *through* it back to model_g, but no per-weight gradients are computed or
# accumulated in its .grad buffers on every step.
model_raw = timm.create_model('tf_efficientnet_b1', pretrained=True)
model_raw.requires_grad_(False)
model_raw = model_raw.cuda()
model_raw = torch.nn.DataParallel(model_raw)

In [10]:
# Generator: turns noise (z of shape (B, 1, 15, 15) in train/val) into a
# perturbation that is added to the 240px input images; placed on the GPU.
model_g = Generator(args, img_size=240, max_conv_dim=256).cuda()

In [11]:
# The optimizer is built over model_g's parameters only; model_raw and model_ns
# are never updated by it.
optimizer = create_optimizer(args, model_g)

In [12]:
# Optional mixed precision through NVIDIA Apex. has_apex is False in this
# notebook, so normally this only logs "AMP off".
use_amp = False
if has_apex and args.amp:
    # Apex expects a single amp.initialize() call covering every model and
    # optimizer. The trainable network here is model_g (there is no `model`).
    (model_g, model_raw, model_ns), optimizer = amp.initialize(
        [model_g, model_raw, model_ns], optimizer, opt_level='O1')
    use_amp = True
if args.local_rank == 0:
    logging.info('NVIDIA APEX {}. AMP {}.'.format(
        'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))

INFO:root:NVIDIA APEX not installed. AMP off.


In [13]:
# optionally resume from a checkpoint
# Optionally resume the generator weights (and, unless disabled, its optimizer /
# AMP state) from the checkpoint at args.resume. resume_epoch stays in scope.
resume_state, resume_epoch = {}, None
if args.resume:
    resume_state, resume_epoch = resume_checkpoint(model_g, args.resume)

if resume_state and not args.no_resume_opt and 'optimizer' in resume_state:
    if args.local_rank == 0:
        logging.info('Restoring Optimizer state from checkpoint')
    optimizer.load_state_dict(resume_state['optimizer'])

if (resume_state and not args.no_resume_opt and use_amp
        and 'amp' in resume_state and 'load_state_dict' in amp.__dict__):
    if args.local_rank == 0:
        logging.info('Restoring NVIDIA AMP state from checkpoint')
    amp.load_state_dict(resume_state['amp'])

# Drop the checkpoint dict now that its contents have been restored.
del resume_state

INFO:root:Loaded checkpoint './output/train/20200625-003932-tf_efficientnet_b1-240/checkpoint-0.pth.tar' (epoch 0)
INFO:root:Restoring Optimizer state from checkpoint


In [14]:
# Wrap the generator for multi-GPU execution only after the checkpoint has been
# loaded into the bare module.
model_g = nn.DataParallel(model_g)

In [15]:
# ImageNet split locations on the training host.
train_dir, val_dir = '/home/data/imagenet/train', '/home/data/imagenet/val'
# Resolve input size / interpolation / mean / std / crop_pct (logged on rank 0 only).
data_config = resolve_data_config(vars(args), model=model_g, verbose=(args.local_rank == 0))

INFO:root:Data processing configuration for current model + dataset:
INFO:root:	input_size: (3, 240, 240)
INFO:root:	interpolation: bicubic
INFO:root:	mean: (0.485, 0.456, 0.406)
INFO:root:	std: (0.229, 0.224, 0.225)
INFO:root:	crop_pct: 0.882


In [16]:
# Number of augmentation splits (AugMix/JSD-style). Fixed at 0: no split
# augmentation is used. The mixup-collate cell asserts it is 0, and the JSD loss
# branch would reject it (that branch requires > 1).
num_aug_splits = 0

In [17]:
# Distributed-only setup: optional SyncBatchNorm conversion, then DDP wrapping of
# the generator (the network being optimized). Skipped when args.distributed is False.
if args.distributed:
    # model_g may already be an nn.DataParallel wrapper from an earlier cell;
    # DDP must wrap the bare module instead.
    if isinstance(model_g, torch.nn.DataParallel):
        model_g = model_g.module
    if args.sync_bn:
        assert not args.split_bn
        try:
            if has_apex:
                model_g = convert_syncbn_model(model_g)
            else:
                model_g = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_g)
            if args.local_rank == 0:
                logging.info(
                    'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                    'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
        except Exception:
            # Best-effort: continue without SyncBN, but surface the actual cause.
            logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1', exc_info=True)
    if has_apex:
        model_g = DDP(model_g, delay_allreduce=True)
    else:
        if args.local_rank == 0:
            logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
        # Name torch's DDP explicitly: `DDP` only exists when the Apex import is enabled.
        model_g = torch.nn.parallel.DistributedDataParallel(model_g, device_ids=[args.local_rank])
    # NOTE: EMA model does not need to be wrapped by DDP


In [18]:
# Build the LR scheduler from args. It also returns the total number of epochs,
# which is the upper bound of the main training loop.
lr_scheduler, num_epochs = create_scheduler(args, optimizer)

In [19]:
# Choose the first epoch. An explicit args.start_epoch always wins. Otherwise,
# continue from resume_epoch, which the checkpoint-resume cell sets (None when not
# resuming). resume_epoch must NOT be re-initialised here, or a resumed run would
# silently restart from epoch 0.
start_epoch = 0
if args.start_epoch is not None:
    # a specified start_epoch will always override the resume epoch
    start_epoch = args.start_epoch
elif resume_epoch is not None:
    start_epoch = resume_epoch
# Fast-forward the LR schedule so it matches the epoch training starts from.
if lr_scheduler is not None and start_epoch > 0:
    lr_scheduler.step(start_epoch)

In [20]:
# Report the epoch budget once, from local rank 0 only.
if args.local_rank == 0:
    logging.info('Scheduled epochs: {}'.format(num_epochs))

INFO:root:Scheduled epochs: 200


In [21]:
# Image-folder datasets for both splits; the validation split is opened with
# load_bytes=False and an empty class_map.
train_dataset, val_dataset = (
    Dataset(train_dir),
    Dataset(val_dir, load_bytes=False, class_map=''),
)

In [22]:
# Total number of generator parameters; val_epoch also reports it (in millions).
param_count = sum(p.numel() for p in model_g.parameters())
logging.info('Model created, param count: %d' % (param_count))

INFO:root:Model created, param count: 4192780


In [23]:
# Mixup is applied inside the collate function, but only when the GPU prefetcher
# is enabled and args.mixup > 0; otherwise the loader uses its default collate.
collate_fn = None
if args.prefetcher and args.mixup > 0:
    assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
    collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

In [24]:
# Optionally wrap model_raw for test-time pooling (timm helper). The returned flag
# decides the validation crop_pct below.
# NOTE(review): model_raw is an nn.DataParallel wrapper at this point. If the helper
# looks up `default_cfg` on the model, it is likely a no-op here — confirm. If it
# ever does wrap, the wrapper would also be active in train_epoch, not only at test time.
model_raw, test_time_pool = apply_test_time_pool(model_raw, data_config, args)

In [25]:
# Training loader: sizing and normalisation come from data_config; random-erasing,
# colour-jitter and auto-augment settings come from args; collate_fn may apply mixup.
train_loader = create_loader(
    train_dataset,
    is_training=True,
    input_size=data_config['input_size'],
    batch_size=args.batch_size,
    interpolation=args.train_interpolation,
    mean=data_config['mean'],
    std=data_config['std'],
    re_prob=args.reprob,
    re_mode=args.remode,
    re_count=args.recount,
    re_split=args.resplit,
    color_jitter=args.color_jitter,
    auto_augment=args.aa,
    num_aug_splits=num_aug_splits,
    use_prefetcher=args.prefetcher,
    collate_fn=collate_fn,
    num_workers=args.workers,
    distributed=args.distributed,
    pin_memory=args.pin_mem,
    use_multi_epochs_loader=args.use_multi_epochs_loader,
)

In [26]:
# With test-time pooling enabled, no centre crop is applied (crop_pct=1.0);
# otherwise use the crop_pct resolved in data_config.
if test_time_pool:
    crop_pct = 1.0
else:
    crop_pct = data_config['crop_pct']
val_loader = create_loader(
    val_dataset,
    is_training=False,
    batch_size=args.batch_size,
    input_size=data_config['input_size'],
    interpolation=data_config['interpolation'],
    mean=data_config['mean'],
    std=data_config['std'],
    crop_pct=crop_pct,
    use_prefetcher=args.prefetcher,
    num_workers=args.workers,
    pin_memory=args.pin_mem,
    tf_preprocessing=args.tf_preprocessing,
)

In [27]:
# Pick the training criterion from the config. Validation always uses plain
# cross-entropy, sharing the training instance when no special loss is active.
validate_loss_fn = nn.CrossEntropyLoss()
if args.jsd:
    assert num_aug_splits > 1  # JSD only valid with aug splits set
    train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
elif args.mixup > 0.:
    # smoothing is handled with mixup label transform
    train_loss_fn = SoftTargetCrossEntropy()
elif args.smoothing:
    train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
    train_loss_fn = validate_loss_fn

# Adversarial criterion. It is passed to train_epoch, but its GAN branch is commented out there.
gan_loss_fn = nn.BCELoss()

In [28]:
# Checkpoint bookkeeping. Only local rank 0 creates a timestamped output directory
# and a CheckpointSaver; other ranks keep saver=None and output_dir=''.
eval_metric = args.eval_metric
best_metric = best_epoch = saver = None
output_dir = ''
if args.local_rank == 0:
    output_base = args.output or './output'
    exp_name = '-'.join((
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        args.model,
        str(data_config['input_size'][-1]),
    ))
    output_dir = get_outdir(output_base, 'train', exp_name)
    # Lower is better only when checkpointing on the loss.
    decreasing = eval_metric == 'loss'
    saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)

In [29]:
# Optional exponential moving average of the generator's weights.
model_ema = None
if args.model_ema:
    # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    # NOTE(review): the SyncBN/DDP cell currently runs before this one; reorder if
    # distributed training is ever enabled.
    # The generator is the network being trained; there is no `model` variable here.
    model_ema = ModelEma(
        model_g,
        decay=args.model_ema_decay,
        device='cpu' if args.model_ema_force_cpu else '',
        resume=args.resume)

In [30]:
def train_epoch(epoch, model_g, model_raw, model_ns, loader, optimizer, loss_fn, gan_loss_fn, args,
               lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
    """Train the input-perturbation generator for one epoch.

    Each batch is perturbed as ``inputs + model_g(z)`` (``z`` is Gaussian noise of
    shape (B, 1, 15, 15)) and classified by ``model_raw``, which is kept in eval mode.
    The loss is ``args.lambda_kd * KD + args.lambda_ce * CE``, where:
      * KD is the temperature-``args.T`` KL divergence between ``model_raw`` on the
        perturbed batch and the teacher ``model_ns`` on the clean batch;
      * CE is ``loss_fn(output, target)`` on ``model_raw``'s logits.
    Only parameters held by ``optimizer`` are updated (in this notebook, model_g's).

    Args:
        epoch: zero-based epoch index (logging and scheduler update count).
        model_g: generator producing the additive perturbation.
        model_raw: classifier returning a ``(logits, extra)`` tuple.
        model_ns: teacher classifier returning a ``(logits, extra)`` tuple.
        loader: training loader yielding ``(inputs, target)``.
        optimizer: optimizer stepped once per batch.
        loss_fn: classification loss applied to ``model_raw``'s logits.
        gan_loss_fn: kept for interface compatibility; the adversarial branch is
            disabled and this argument is unused.
        args: config namespace (T, lambda_kd, lambda_ce, no_prefetcher, distributed,
            world_size, local_rank, log_interval, save_images, recovery_interval).
        lr_scheduler, saver, output_dir, use_amp, model_ema: optional timm helpers.

    Returns:
        OrderedDict with the epoch-average total loss under ``'loss'``.
    """
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    losses_ce = AverageMeter()
    losses_kd = AverageMeter()

    model_g.train()
    model_ns.eval()
    model_raw.eval()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if args.no_prefetcher:
            # Without the prefetcher the loader yields CPU tensors (same handling as val_epoch).
            inputs = inputs.cuda()
            target = target.cuda()

        # Size the noise from the actual batch rather than args.batch_size, so the
        # generator output always matches `inputs`.
        z = torch.randn(inputs.size(0), 1, 15, 15, device='cuda')
        inputs_z = inputs + model_g(z)

        # Distillation: model_raw on the perturbed batch vs. model_ns on the clean batch.
        output, _ = model_raw(inputs_z)
        with torch.no_grad():
            out_ns, _ = model_ns(inputs)

        p_s = F.log_softmax(output / args.T, dim=1)
        p_t = F.softmax(out_ns / args.T, dim=1)
        # reduction='sum' is the non-deprecated spelling of size_average=False;
        # divide by the batch size for a per-sample KL and scale by T**2 (standard KD).
        loss_kd = F.kl_div(p_s, p_t, reduction='sum') * (args.T ** 2) / output.shape[0]

        loss_ce = loss_fn(output, target)
        loss = args.lambda_kd * loss_kd + args.lambda_ce * loss_ce

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if not args.distributed:
            losses_kd.update(loss_kd.item(), inputs.size(0))
            losses_ce.update(loss_ce.item(), inputs.size(0))
            losses_m.update(loss.item(), inputs.size(0))

        if model_ema is not None:
            model_ema.update(model_g)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))

            if args.local_rank == 0:
                # A single-batch loader has last_idx == 0; report it as 100% done.
                pct_done = 100. * batch_idx / last_idx if last_idx else 100.
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                    'Loss_kd: {loss_kd.val:>9.6f} ({loss_kd.avg:>6.4f})  '
                    'Loss_ce: {loss_ce.val:>9.6f} ({loss_ce.avg:>6.4f})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        pct_done,
                        loss=losses_m,
                        loss_kd=losses_kd,
                        loss_ce=losses_ce,
                        batch_time=batch_time_m,
                        rate=inputs.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs_z,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            # Save the generator being trained (there is no global `model`).
            saver.save_recovery(
                model_g, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])

In [31]:
def val_epoch(model_raw, model, val_loader, criterion, args):
    """Evaluate ``model_raw`` on validation images perturbed by the generator ``model``.

    For every batch, noise ``z`` of shape (B, 1, 15, 15) is passed through ``model``
    and the result is added to the images before ``model_raw`` classifies them.

    Args:
        model_raw: classifier returning a ``(logits, extra)`` tuple.
        model: generator producing the additive perturbation.
        val_loader: loader yielding ``(inputs, target)``.
        criterion: loss applied to ``model_raw``'s logits.
        args: config namespace (reads ``no_prefetcher`` and ``log_freq``).

    Returns:
        OrderedDict with top1/top5 accuracy and error, run metadata, and the mean
        validation ``loss`` (so ``eval_metric: loss`` does not raise KeyError).

    NOTE: ``param_count``, ``data_config`` and ``crop_pct`` are read from notebook
    globals, so the cells defining them must have run first.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model_raw.eval()
    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, target) in enumerate(val_loader):
            if args.no_prefetcher:
                # Without the prefetcher the loader yields CPU tensors.
                target = target.cuda()
                inputs = inputs.cuda()
            # Fresh noise per batch, sized to the actual (possibly short) batch.
            z = torch.randn(inputs.shape[0], 1, 15, 15, device='cuda')

            # Perturb the clean images with the generator output, then classify.
            inputs_out = inputs + model(z)
            output, _ = model_raw(inputs_out)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(val_loader), batch_time=batch_time,
                        rate_avg=inputs.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,  # key spelling kept: it is an existing summary.csv column
        interpolation=data_config['interpolation'],
        # Appended last so existing columns keep their order.
        loss=round(losses.avg, 4))

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results

In [None]:
# Main loop: train the generator for one epoch, evaluate, step the LR schedule,
# append to summary.csv and (when a saver exists) save checkpoints.
for epoch in range(start_epoch, num_epochs):
    if args.distributed:
        train_loader.sampler.set_epoch(epoch)

    train_metrics = train_epoch(
        epoch, model_g, model_raw, model_ns, train_loader, optimizer, train_loss_fn, gan_loss_fn, args,
        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
        use_amp=use_amp, model_ema=model_ema)

    eval_metrics = val_epoch(model_raw, model_g, val_loader, validate_loss_fn, args)

    if model_ema is not None and not args.model_ema_force_cpu:
        if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
            distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

    if lr_scheduler is not None:
        # step LR for next epoch
        lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

    # Write the CSV header once, on this run's first epoch. Keying it on
    # best_metric rewrote the header every epoch whenever saver is None.
    update_summary(
        epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
        write_header=(epoch == start_epoch))

    if saver is not None:
        # save proper checkpoint with eval metric
        save_metric = eval_metrics[eval_metric]
        best_metric, best_epoch = saver.save_checkpoint(
            model_g, optimizer, args,
            epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)

# Report the best checkpoint once training finishes.
if best_metric is not None:
    logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))



DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401




INFO:root:Test: [   0/239]  Time: 5.281s (5.281s,   39.76/s)  Loss:  0.5095 (0.5095)  Acc@1:  88.095 ( 88.095)  Acc@5:  97.619 ( 97.619)
INFO:root:Test: [  10/239]  Time: 0.326s (0.789s,  266.00/s)  Loss:  0.9716 (0.6208)  Acc@1:  77.619 ( 85.584)  Acc@5:  93.333 ( 97.143)
INFO:root:Test: [  20/239]  Time: 0.327s (0.570s,  368.49/s)  Loss:  0.4691 (0.7224)  Acc@1:  90.952 ( 82.608)  Acc@5:  97.143 ( 96.327)
INFO:root:Test: [  30/239]  Time: 0.325s (0.492s,  427.13/s)  Loss:  0.5344 (0.6752)  Acc@1:  88.571 ( 84.163)  Acc@5:  95.714 ( 96.375)
INFO:root:Test: [  40/239]  Time: 0.331s (0.452s,  464.68/s)  Loss:  0.7856 (0.6633)  Acc@1:  82.857 ( 84.506)  Acc@5:  94.762 ( 96.458)
INFO:root:Test: [  50/239]  Time: 0.323s (0.428s,  490.36/s)  Loss:  0.6504 (0.6785)  Acc@1:  83.333 ( 84.006)  Acc@5:  96.190 ( 96.536)
INFO:root:Test: [  60/239]  Time: 0.327s (0.412s,  510.12/s)  Loss:  0.3491 (0.6817)  Acc@1:  93.333 ( 83.692)  Acc@5:  99.048 ( 96.667)
INFO:root:Test: [  70/239]  Time: 0.327s 



DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401


INFO:root:Test: [   0/239]  Time: 5.175s (5.175s,   40.58/s)  Loss:  0.5108 (0.5108)  Acc@1:  88.095 ( 88.095)  Acc@5:  97.619 ( 97.619)
INFO:root:Test: [  10/239]  Time: 0.331s (0.775s,  270.99/s)  Loss:  0.9671 (0.6212)  Acc@1:  77.619 ( 85.455)  Acc@5:  93.810 ( 97.143)
INFO:root:Test: [  20/239]  Time: 0.349s (0.564s,  372.22/s)  Loss:  0.4680 (0.7232)  Acc@1:  90.952 ( 82.744)  Acc@5:  98.095 ( 96.372)
INFO:root:Test: [  30/239]  Time: 0.324s (0.488s,  430.25/s)  Loss:  0.5364 (0.6754)  Acc@1:  88.571 ( 84.286)  Acc@5:  95.714 ( 96.406)
INFO:root:Test: [  40/239]  Time: 0.327s (0.449s,  467.34/s)  Loss:  0.7853 (0.6639)  Acc@1:  82.857 ( 84.599)  Acc@5:  94.286 ( 96.446)
INFO:root:Test: [  50/239]  Time: 0.324s (0.426s,  493.31/s)  Loss:  0.6604 (0.6796)  Acc@1:  83.810 ( 84.118)  Acc@5:  96.190 ( 96.517)
INFO:root:Test: [  60/239]  Time: 0.325s (0.410s,  512.72/s)  Loss:  0.3534 (0.6828)  Acc@1:  93.333 ( 83.833)  Acc@5:  99.048 ( 96.628)
INFO:root:Test: [  70/239]  Time: 0.327s 

INFO:root:Current checkpoints:
 ('./output/train/20200625-133839-tf_efficientnet_b1-240/checkpoint-1.pth.tar', 78.926)
 ('./output/train/20200625-133839-tf_efficientnet_b1-240/checkpoint-2.pth.tar', 78.916)



DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401




INFO:root:Test: [   0/239]  Time: 5.742s (5.742s,   36.57/s)  Loss:  0.5127 (0.5127)  Acc@1:  88.095 ( 88.095)  Acc@5:  97.619 ( 97.619)
INFO:root:Test: [  10/239]  Time: 0.331s (0.833s,  252.10/s)  Loss:  0.9687 (0.6200)  Acc@1:  77.143 ( 85.541)  Acc@5:  93.810 ( 97.186)
INFO:root:Test: [  20/239]  Time: 0.327s (0.593s,  354.19/s)  Loss:  0.4722 (0.7219)  Acc@1:  90.952 ( 82.562)  Acc@5:  97.619 ( 96.417)
INFO:root:Test: [  30/239]  Time: 0.333s (0.508s,  413.41/s)  Loss:  0.5297 (0.6745)  Acc@1:  89.524 ( 84.178)  Acc@5:  96.190 ( 96.452)
INFO:root:Test: [  40/239]  Time: 0.331s (0.465s,  451.97/s)  Loss:  0.7802 (0.6627)  Acc@1:  82.857 ( 84.541)  Acc@5:  94.762 ( 96.469)
INFO:root:Test: [  50/239]  Time: 0.332s (0.438s,  479.18/s)  Loss:  0.6532 (0.6781)  Acc@1:  83.810 ( 84.052)  Acc@5:  96.190 ( 96.527)
INFO:root:Test: [  60/239]  Time: 0.328s (0.420s,  499.54/s)  Loss:  0.3479 (0.6813)  Acc@1:  93.810 ( 83.794)  Acc@5:  99.048 ( 96.651)
INFO:root:Test: [  70/239]  Time: 0.341s 





INFO:root:Test: [   0/239]  Time: 6.633s (6.633s,   31.66/s)  Loss:  0.5108 (0.5108)  Acc@1:  88.095 ( 88.095)  Acc@5:  97.619 ( 97.619)
INFO:root:Test: [  10/239]  Time: 0.332s (0.908s,  231.25/s)  Loss:  0.9686 (0.6200)  Acc@1:  77.619 ( 85.541)  Acc@5:  93.810 ( 97.143)
INFO:root:Test: [  20/239]  Time: 0.331s (0.633s,  331.93/s)  Loss:  0.4709 (0.7223)  Acc@1:  91.429 ( 82.676)  Acc@5:  98.095 ( 96.417)
INFO:root:Test: [  30/239]  Time: 0.324s (0.534s,  393.14/s)  Loss:  0.5314 (0.6750)  Acc@1:  89.048 ( 84.209)  Acc@5:  96.190 ( 96.452)
INFO:root:Test: [  40/239]  Time: 0.326s (0.484s,  433.69/s)  Loss:  0.7786 (0.6634)  Acc@1:  82.857 ( 84.576)  Acc@5:  94.762 ( 96.469)
INFO:root:Test: [  50/239]  Time: 0.327s (0.454s,  462.94/s)  Loss:  0.6569 (0.6786)  Acc@1:  83.810 ( 84.108)  Acc@5:  96.190 ( 96.536)
INFO:root:Test: [  60/239]  Time: 0.326s (0.433s,  484.89/s)  Loss:  0.3528 (0.6818)  Acc@1:  93.333 ( 83.841)  Acc@5:  99.048 ( 96.659)
INFO:root:Test: [  70/239]  Time: 0.329s 

INFO:root:Test: [ 230/239]  Time: 0.324s (0.356s,  589.12/s)  Loss:  1.8873 (0.8967)  Acc@1:  56.190 ( 78.969)  Acc@5:  85.714 ( 94.257)
INFO:root: * Acc@1 78.986 (21.014) Acc@5 94.306 (5.694)
INFO:root:Current checkpoints:
 ('./output/train/20200625-133839-tf_efficientnet_b1-240/checkpoint-4.pth.tar', 78.986)
 ('./output/train/20200625-133839-tf_efficientnet_b1-240/checkpoint-3.pth.tar', 78.962)
 ('./output/train/20200625-133839-tf_efficientnet_b1-240/checkpoint-1.pth.tar', 78.926)
 ('./output/train/20200625-133839-tf_efficientnet_b1-240/checkpoint-2.pth.tar', 78.916)



