In [1]:
import timm
import glob
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from torch.autograd import Variable
from collections import OrderedDict
from timm.data import Dataset, create_loader, resolve_data_config,  FastCollateMixup, mixup_batch, AugMixDataset
from timm.models import create_model, resume_checkpoint, convert_splitbn_model, apply_test_time_pool
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from munch import Munch
import yaml
import sys
from resnet_generator import Generator, Critic
from datetime import datetime
import numpy as np
import torch.distributed as dist
import torch.nn.functional as F
from transformer import Attention
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
from crd.criterion import CRDLoss
from dataset.imagenet import get_dataloader_sample
from timm.data import transforms_factory

In [2]:
# from apex import amp
# from apex.parallel import DistributedDataParallel as DDP
# from apex.parallel import convert_syncbn_model
has_apex = False

In [3]:
# GPU Device
gpu_id = '0,1,2'
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
use_cuda = torch.cuda.is_available()
print("GPU device " , use_cuda)

GPU device  True


In [4]:
torch.backends.cudnn.benchmark = True

In [5]:
with open('config/train.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
args = Munch(config)
args.prefetcher = not args.no_prefetcher
args.distributed = False
args.device = 'cuda'
args.world_size = 3
args.rank = 0
logging.info('Training with a single process on %d GPUs.' % args.num_gpu)

INFO:root:Training with a single process on 3 GPUs.


In [6]:
if args.distributed:
    args.num_gpu = 1
    args.device = 'cuda:%d' % args.local_rank
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', rank=args.rank, world_size=args.world_size)
    args.world_size = torch.distributed.get_world_size()
    args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

In [7]:
torch.manual_seed(args.seed + args.rank)

<torch._C.Generator at 0x7f9d4825be90>

In [8]:
model_ns = timm.create_model('tf_efficientnet_b1_ns', pretrained=True)
model_ns = model_ns.cuda()
model_ns = torch.nn.DataParallel(model_ns)
model_ns = model_ns.eval()

In [9]:
# fc_name = 'module.classifier.'
# teacher_model_weights = {}
# for name, param in model_ns.named_parameters():
#     teacher_model_weights[name] = param.detach()

In [10]:
model_raw = timm.create_model('tf_efficientnet_b1', pretrained=True)
model_raw = model_raw.cuda()
model_raw = torch.nn.DataParallel(model_raw)
if args.resume:
    print('raw load')
    model_raw.load_state_dict(torch.load('raw').state_dict())

In [11]:
model_g = Generator(args, img_size=args.img_size, max_conv_dim=256)
model_g = model_g.cuda()

In [12]:
model_critic = Critic(args)
model_critic = model_critic.cuda()
model_critic = torch.nn.DataParallel(model_critic)
if args.resume:
    print('critic load')
    model_critic.load_state_dict(torch.load('critic').state_dict())

In [13]:
use_amp = False
if has_apex and args.amp:
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    model_raw =  amp.initialize(model_raw)
    model_ns = amp.initialize(model_ns)
    use_amp = True
if args.local_rank == 0:
    logging.info('NVIDIA APEX {}. AMP {}.'.format(
        'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))

INFO:root:NVIDIA APEX not installed. AMP off.


In [14]:
train_dir = '/home/data/imagenet/train'
val_dir = '/home/data/imagenet/val'
data_config = resolve_data_config(vars(args), model=model_g, verbose=args.local_rank == 0)

INFO:root:Data processing configuration for current model + dataset:
INFO:root:	input_size: (3, 240, 240)
INFO:root:	interpolation: bicubic
INFO:root:	mean: (0.485, 0.456, 0.406)
INFO:root:	std: (0.229, 0.224, 0.225)
INFO:root:	crop_pct: 0.882


In [15]:
num_aug_splits = 0

In [16]:
if args.distributed:
    if args.sync_bn:
        assert not args.split_bn
        try:
            if has_apex:
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logging.info(
                    'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                    'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
        except Exception as e:
            logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
    if has_apex:
        model = DDP(model, delay_allreduce=True)
    else:
        if args.local_rank == 0:
            logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
        model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
    # NOTE: EMA model does not need to be wrapped by DDP

In [17]:
# train_dataset = Dataset(train_dir)
val_dataset = Dataset(val_dir, load_bytes=False, class_map='')

In [18]:
param_count = sum([m.numel() for m in model_g.parameters()])
logging.info('Model created, param count: %d' % (param_count))

INFO:root:Model created, param count: 3662851


In [19]:
collate_fn = None
if args.prefetcher and args.mixup > 0:
    assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
    collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

In [20]:
model_raw, test_time_pool = apply_test_time_pool(model_raw, data_config, args)

In [21]:
transform = transforms_factory.transforms_imagenet_train(img_size=args.img_size, auto_augment=False, color_jitter=args.color_jitter, interpolation=args.interpolation, re_count=args.recount, re_mode=args.remode, re_prob=args.reprob)

In [22]:
train_loader = get_dataloader_sample(transform, dataset='imagenet', batch_size=args.batch_size, is_sample=True, k=4096)

stage1 finished!
dataset initialized!
num_samples 1281167
num_class 1000


In [23]:
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
val_loader = create_loader(
    val_dataset,
    input_size=data_config['input_size'],
    batch_size=args.batch_size,
    is_training=False,
    use_prefetcher=args.prefetcher,
    interpolation=data_config['interpolation'],
    mean=data_config['mean'],
    std=data_config['std'],
    num_workers=args.workers,
    crop_pct=crop_pct,
    pin_memory=args.pin_mem,
    tf_preprocessing=args.tf_preprocessing)

In [24]:
if args.jsd:
    assert num_aug_splits > 1  # JSD only valid with aug splits set
    train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    validate_loss_fn = nn.CrossEntropyLoss()
elif args.mixup > 0.:
    # smoothing is handled with mixup label transform
    train_loss_fn = SoftTargetCrossEntropy()
    validate_loss_fn = nn.CrossEntropyLoss()
elif args.smoothing:
    train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    validate_loss_fn = nn.CrossEntropyLoss()
else:
    train_loss_fn = nn.CrossEntropyLoss()
    validate_loss_fn = train_loss_fn
    
crd_loss_fn = CRDLoss(args).cuda()
critic_loss_fn = nn.BCELoss()
real_label = 1
fake_label = 0

In [25]:
trainable_list = nn.ModuleList([])
trainable_list.append(model_g)
# trainable_list.append(model_raw)
trainable_list.append(crd_loss_fn.embed_s)
trainable_list.append(crd_loss_fn.embed_t)

ModuleList(
  (0): Generator(
    (from_rgb): Conv2d(3, 68, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (encode): ModuleList(
      (0): ResBlk(
        (actv): LeakyReLU(negative_slope=0.2)
        (conv1): Conv2d(68, 68, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(68, 136, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm1): InstanceNorm2d(68, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (norm2): InstanceNorm2d(68, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (conv1x1): Conv2d(68, 136, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): ResBlk(
        (actv): LeakyReLU(negative_slope=0.2)
        (conv1): Conv2d(136, 136, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(136, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm1): InstanceNorm2d(136, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  

In [26]:
optimizer = create_optimizer(args, trainable_list)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)

In [27]:
optimizer_critic = create_optimizer(args, model_critic)
lr_scheduler_critic, _ = create_scheduler(args, optimizer_critic)

In [28]:
# optionally resume from a checkpoint
resume_state = {}
resume_epoch = None
if args.resume:
    resume_state, resume_epoch = resume_checkpoint(model_g, args.resume)
if resume_state and not args.no_resume_opt:
    if 'optimizer' in resume_state:
        if args.local_rank == 0:
            logging.info('Restoring Optimizer state from checkpoint')
        optimizer.load_state_dict(resume_state['optimizer'])
    if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
        if args.local_rank == 0:
            logging.info('Restoring NVIDIA AMP state from checkpoint')
        amp.load_state_dict(resume_state['amp'])
del resume_state

In [29]:
model_g = torch.nn.DataParallel(model_g)

In [30]:
if args.local_rank == 0:
    logging.info('Scheduled epochs: {}'.format(num_epochs))

INFO:root:Scheduled epochs: 200


In [31]:
resume_state = {}
resume_epoch = None
start_epoch = 0
if args.start_epoch is not None:
    # a specified start_epoch will always override the resume epoch
    start_epoch = args.start_epoch
elif resume_epoch is not None:
    start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0:
    lr_scheduler.step(start_epoch)

In [32]:
eval_metric = args.eval_metric
best_metric = None
best_epoch = None
saver = None
output_dir = ''
if args.local_rank == 0:
    output_base = args.output if args.output else './output'
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        args.model,
        str(data_config['input_size'][-1])
    ])
    output_dir = get_outdir(output_base, 'train', exp_name)
    decreasing = True if eval_metric == 'loss' else False
    saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)

In [33]:
model_ema = None
if args.model_ema:
    # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    model_ema = ModelEma(
        model,
        decay=args.model_ema_decay,
        device='cpu' if args.model_ema_force_cpu else '',
        resume=args.resume)

In [34]:
def train_epoch(epoch, model_g, model_raw, model_ns, model_critic, loader, optimizer, optimizer_critic, loss_fn, crd_loss_fn, critic_loss_fn, args,
               lr_scheduler=None, lr_scheduler_critic=None, saver=None, output_dir='', use_amp=False, model_ema=None):

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    losses_crd = AverageMeter()
    losses_ce = AverageMeter()
    losses_kd = AverageMeter()
    losses_critic = AverageMeter()
    losses_g = AverageMeter()
    
    model_g.train()
    model_critic.train()
    model_raw.eval()
    model_ns.eval()
    
    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
        
    for batch_idx, data in enumerate(loader):
        inputs, target, index, contrast_idx = data
        inputs, target, index, contrast_idx = inputs.cuda(), target.cuda(), index.cuda(), contrast_idx.cuda()
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        
        g_out = model_g(inputs)
        
        inputs_z = inputs + g_out
        output, f_raw = model_raw(inputs_z)
        with torch.no_grad():
#             inputs_ns = F.interpolate(inputs, size=300, mode='bicubic')
            out_ns, f_ns = model_ns(inputs)
            out_ns = out_ns.detach()
        
        # Critic
        model_critic.train()
        model_critic.zero_grad()
        label = torch.full((f_raw.size(0), ), real_label, device='cuda')
        critic_out = model_critic(f_ns).view(-1)
        loss_critic_ns = args.lambda_critic * critic_loss_fn(critic_out, label)
        loss_critic_ns.backward()
        loss_critic_ns_out = loss_critic_ns.item()
        
        label.fill_(fake_label)
        critic_out = model_critic(f_raw.detach()).view(-1)
        loss_critic_raw = critic_loss_fn(critic_out, label)
        loss_critic_raw.backward()
        loss_critic_raw_out = args.lambda_critic * loss_critic_raw.item()
        loss_critic = loss_critic_ns_out + loss_critic_raw_out
        optimizer_critic.step()
        
        # KD
        p_s = F.log_softmax(output/args.T, dim=1)
        p_t = F.softmax(out_ns/args.T, dim=1)
        loss_kd = F.kl_div(p_s, p_t, size_average=False) * (args.T ** 2) / output.shape[0]
        
       # CE
        loss_ce = loss_fn(output, target)
        
        # CRD
        loss_crd = crd_loss_fn(f_raw, f_ns, index, contrast_idx)
        
        # G
        model_critic.eval()
        label.fill_(real_label)
        critic_g_out = model_critic(f_raw).view(-1)
        loss_g = critic_loss_fn(critic_g_out, label)
        
        # overall loss
        loss = args.lambda_kd * loss_kd + args.lambda_ce * loss_ce + args.lambda_crd * loss_crd + args.lambda_g * loss_g
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if not args.distributed:
            losses_kd.update(loss_kd.item(), inputs.size(0))
            losses_ce.update(loss_ce.item(), inputs.size(0))
            losses_crd.update(loss_crd.item(), inputs.size(0))
            losses_critic.update(loss_critic, inputs.size(0))
            losses_g.update(loss_g.item(), inputs.size(0))
            losses_m.update(loss.item(), inputs.size(0))

        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))

            if args.local_rank == 0:
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                    'Loss_kd: {loss_kd.val:>9.6f} ({loss_kd.avg:>6.4f})  '
                    'Loss_ce: {loss_ce.val:>9.6f} ({loss_ce.avg:>6.4f})  '
                    'Loss_crd: {loss_crd.val:>9.6f} ({loss_crd.avg:>6.4f})  '
                    'Loss_g: {loss_g.val:>9.6f} ({loss_g.avg:>6.4f})  '
                    'Loss_critic: {loss_critic.val:>9.6f} ({loss_critic.avg:>6.4f})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.7f}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        loss_kd=losses_kd,
                        loss_ce=losses_ce,
                        loss_crd=losses_crd,
                        loss_g=losses_g,
                        loss_critic=losses_critic,
                        batch_time=batch_time_m,
                        rate=inputs.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs_z,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(
                model_raw, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
            
        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
            lr_scheduler_critic.step_update(num_updates=num_updates, metric=losses_critic.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()
        

    return OrderedDict([('loss', losses_m.avg)])

In [35]:
def val_epoch(model_raw, model, val_loader, criterion, args):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model_raw.eval()
    model.eval()
#     model_att.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        end = time.time()
        for i, (inputs, target) in enumerate(val_loader):
            if args.no_prefetcher:
                target = target.cuda()
                inputs = inputs.cuda()
            
            out = model(inputs)
            # synthesizing input + generator
            inputs_out = inputs + out
            # compute output
            output, _ = model_raw(inputs_out)
#             output = model_att(inputs_out, output)
            loss = criterion(output, target)
            
            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(val_loader), batch_time=batch_time,
                        rate_avg=inputs.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results

In [None]:
for epoch in range(start_epoch, num_epochs):
    if args.distributed:
        train_loader.sampler.set_epoch(epoch)

    train_metrics = train_epoch(
        epoch, model_g, model_raw, model_ns, model_critic, train_loader, optimizer, optimizer_critic, 
        train_loss_fn, crd_loss_fn, critic_loss_fn, args,
        lr_scheduler=lr_scheduler, lr_scheduler_critic=lr_scheduler_critic, 
        saver=saver, output_dir=output_dir,
        use_amp=use_amp, model_ema=model_ema)

    eval_metrics = val_epoch(model_raw, model_g, val_loader, validate_loss_fn, args)

    if model_ema is not None and not args.model_ema_force_cpu:
        if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
            distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

    if lr_scheduler is not None:
        # step LR for next epoch
        lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
        
    update_summary( 
        epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
        write_header=best_metric is None)

    if saver is not None:
    # save proper checkpoint with eval metric
        save_metric = eval_metrics[eval_metric]
        best_metric, best_epoch = saver.save_checkpoint(
            model_g, optimizer, args,
            epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
        torch.save(model_raw, 'raw')
        torch.save(model_critic, 'critic')
        torch.save(crd_loss_fn.embed_s, 'embed_s')
        torch.save(crd_loss_fn.embed_t, 'embed_t')



normalization constant Z_v1 is set to 2838143.5
normalization constant Z_v2 is set to 2847929.0










DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401




INFO:root:Test: [   0/327]  Time: 6.535s (6.535s,   23.41/s)  Loss:  0.5028 (0.5028)  Acc@1:  88.889 ( 88.889)  Acc@5:  96.732 ( 96.732)
INFO:root:Test: [  10/327]  Time: 0.347s (0.902s,  169.57/s)  Loss:  1.0811 (0.5789)  Acc@1:  75.817 ( 87.879)  Acc@5:  93.464 ( 96.791)
INFO:root:Test: [  20/327]  Time: 0.333s (0.632s,  242.05/s)  Loss:  1.1346 (0.7680)  Acc@1:  69.281 ( 81.295)  Acc@5:  95.425 ( 96.047)
INFO:root:Test: [  30/327]  Time: 0.336s (0.536s,  285.61/s)  Loss:  0.3137 (0.7547)  Acc@1:  94.118 ( 81.805)  Acc@5:  98.693 ( 95.952)
INFO:root:Test: [  40/327]  Time: 0.332s (0.486s,  314.56/s)  Loss:  0.8972 (0.7416)  Acc@1:  77.124 ( 82.401)  Acc@5:  95.425 ( 95.967)
INFO:root:Test: [  50/327]  Time: 0.331s (0.456s,  335.41/s)  Loss:  0.8830 (0.6973)  Acc@1:  78.431 ( 83.686)  Acc@5:  94.118 ( 96.181)
INFO:root:Test: [  60/327]  Time: 0.331s (0.436s,  350.96/s)  Loss:  1.0438 (0.7279)  Acc@1:  68.627 ( 82.760)  Acc@5:  96.078 ( 96.068)
INFO:root:Test: [  70/327]  Time: 0.331s 

DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401










INFO:root:Test: [   0/327]  Time: 7.871s (7.871s,   19.44/s)  Loss:  0.4724 (0.4724)  Acc@1:  89.542 ( 89.542)  Acc@5:  98.039 ( 98.039)
INFO:root:Test: [  10/327]  Time: 0.332s (1.027s,  148.99/s)  Loss:  1.1647 (0.6070)  Acc@1:  69.935 ( 86.512)  Acc@5:  93.464 ( 96.316)
INFO:root:Test: [  20/327]  Time: 0.331s (0.696s,  219.90/s)  Loss:  1.2116 (0.8206)  Acc@1:  66.667 ( 79.894)  Acc@5:  93.464 ( 95.238)
INFO:root:Test: [  30/327]  Time: 0.339s (0.579s,  264.38/s)  Loss:  0.3302 (0.8063)  Acc@1:  94.118 ( 80.434)  Acc@5:  98.039 ( 95.109)
INFO:root:Test: [  40/327]  Time: 0.331s (0.518s,  295.12/s)  Loss:  0.9247 (0.7965)  Acc@1:  77.778 ( 81.125)  Acc@5:  96.078 ( 95.010)
INFO:root:Test: [  50/327]  Time: 0.332s (0.482s,  317.32/s)  Loss:  0.8569 (0.7490)  Acc@1:  79.085 ( 82.443)  Acc@5:  96.732 ( 95.374)
INFO:root:Test: [  60/327]  Time: 0.331s (0.458s,  334.25/s)  Loss:  1.0357 (0.7823)  Acc@1:  70.588 ( 81.378)  Acc@5:  96.732 ( 95.371)
INFO:root:Test: [  70/327]  Time: 0.332s 

INFO:root:Test: [ 130/327]  Time: 0.332s (0.391s,  391.52/s)  Loss:  1.3236 (0.7840)  Acc@1:  65.359 ( 81.240)  Acc@5:  89.542 ( 95.799)
INFO:root:Test: [ 140/327]  Time: 0.331s (0.387s,  395.71/s)  Loss:  0.7915 (0.8095)  Acc@1:  82.353 ( 80.629)  Acc@5:  94.118 ( 95.513)
INFO:root:Test: [ 150/327]  Time: 0.332s (0.383s,  399.43/s)  Loss:  1.6408 (0.8403)  Acc@1:  56.209 ( 79.851)  Acc@5:  86.928 ( 95.152)
INFO:root:Test: [ 160/327]  Time: 0.331s (0.380s,  402.63/s)  Loss:  1.1153 (0.8739)  Acc@1:  74.510 ( 79.061)  Acc@5:  90.850 ( 94.747)
INFO:root:Test: [ 170/327]  Time: 0.332s (0.377s,  405.51/s)  Loss:  1.0085 (0.8995)  Acc@1:  74.510 ( 78.500)  Acc@5:  90.196 ( 94.400)
INFO:root:Test: [ 180/327]  Time: 0.332s (0.375s,  408.11/s)  Loss:  1.0961 (0.9163)  Acc@1:  75.817 ( 78.128)  Acc@5:  93.464 ( 94.237)
INFO:root:Test: [ 190/327]  Time: 0.332s (0.373s,  410.42/s)  Loss:  1.3854 (0.9189)  Acc@1:  69.281 ( 78.154)  Acc@5:  89.542 ( 94.179)
INFO:root:Test: [ 200/327]  Time: 0.332s 











INFO:root:Test: [   0/327]  Time: 7.260s (7.260s,   21.07/s)  Loss:  0.5036 (0.5036)  Acc@1:  86.275 ( 86.275)  Acc@5:  98.693 ( 98.693)
INFO:root:Test: [  10/327]  Time: 0.331s (0.965s,  158.58/s)  Loss:  1.1340 (0.6313)  Acc@1:  73.203 ( 85.977)  Acc@5:  95.425 ( 96.732)
INFO:root:Test: [  20/327]  Time: 0.331s (0.664s,  230.43/s)  Loss:  1.1891 (0.8505)  Acc@1:  66.667 ( 79.739)  Acc@5:  92.810 ( 95.051)
INFO:root:Test: [  30/327]  Time: 0.335s (0.557s,  274.58/s)  Loss:  0.3977 (0.8396)  Acc@1:  92.810 ( 79.928)  Acc@5:  98.039 ( 95.024)
INFO:root:Test: [  40/327]  Time: 0.331s (0.502s,  304.66/s)  Loss:  1.0272 (0.8271)  Acc@1:  73.856 ( 80.472)  Acc@5:  93.464 ( 94.931)
INFO:root:Test: [  50/327]  Time: 0.332s (0.469s,  326.38/s)  Loss:  0.8971 (0.7811)  Acc@1:  77.778 ( 81.712)  Acc@5:  96.078 ( 95.207)
INFO:root:Test: [  60/327]  Time: 0.342s (0.447s,  342.65/s)  Loss:  1.2089 (0.8177)  Acc@1:  66.013 ( 80.467)  Acc@5:  94.771 ( 95.082)
INFO:root:Test: [  70/327]  Time: 0.335s 

INFO:root: * Acc@1 74.512 (25.488) Acc@5 92.060 (7.940)
INFO:root:Current checkpoints:
 ('./output/train/20200708-184950-tf_efficientnet_b1-240/checkpoint-0.pth.tar', 77.342)
 ('./output/train/20200708-184950-tf_efficientnet_b1-240/checkpoint-1.pth.tar', 75.54)
 ('./output/train/20200708-184950-tf_efficientnet_b1-240/checkpoint-2.pth.tar', 74.512)











DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401


INFO:root:Test: [   0/327]  Time: 8.037s (8.037s,   19.04/s)  Loss:  0.4955 (0.4955)  Acc@1:  86.928 ( 86.928)  Acc@5:  98.039 ( 98.039)
INFO:root:Test: [  10/327]  Time: 0.332s (1.037s,  147.53/s)  Loss:  1.1831 (0.6218)  Acc@1:  71.242 ( 86.393)  Acc@5:  94.118 ( 96.613)
INFO:root:Test: [  20/327]  Time: 0.331s (0.701s,  218.24/s)  Loss:  1.2623 (0.8313)  Acc@1:  66.013 ( 80.050)  Acc@5:  90.850 ( 95.082)
INFO:root:Test: [  30/327]  Time: 0.331s (0.582s,  262.94/s)  Loss:  0.4084 (0.8135)  Acc@1:  92.157 ( 80.582)  Acc@5:  97.386 ( 95.087)
INFO:root:Test: [  40/327]  Time: 0.331s (0.521s,  293.74/s)  Loss:  0.9643 (0.8028)  Acc@1:  77.124 ( 81.157)  Acc@5:  92.810 ( 95.010)
INFO:root:Test: [  50/327]  Time: 0.332s (0.484s,  316.28/s)  Loss:  0.9377 (0.7580)  Acc@1:  77.124 ( 82.379)  Acc@5:  95.425 ( 95.374)
INFO:root:Test: [  60/327]  Time: 0.331s (0.459s,  333.37/s)  Loss:  1.1417 (0.7928)  Acc@1:  65.359 ( 81.217)  Acc@5:  96.078 ( 95.328)
INFO:root:Test: [  70/327]  Time: 0.331s 







DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401




INFO:root:Test: [   0/327]  Time: 6.482s (6.482s,   23.60/s)  Loss:  0.5532 (0.5532)  Acc@1:  85.621 ( 85.621)  Acc@5:  97.386 ( 97.386)
INFO:root:Test: [  10/327]  Time: 0.341s (0.897s,  170.58/s)  Loss:  1.2459 (0.6583)  Acc@1:  68.627 ( 85.264)  Acc@5:  91.503 ( 95.781)
INFO:root:Test: [  20/327]  Time: 0.332s (0.628s,  243.75/s)  Loss:  1.2583 (0.8676)  Acc@1:  66.013 ( 79.116)  Acc@5:  89.542 ( 94.273)
INFO:root:Test: [  30/327]  Time: 0.342s (0.533s,  287.24/s)  Loss:  0.4230 (0.8551)  Acc@1:  92.810 ( 79.422)  Acc@5:  98.039 ( 94.497)
INFO:root:Test: [  40/327]  Time: 0.334s (0.484s,  316.20/s)  Loss:  1.0723 (0.8513)  Acc@1:  72.549 ( 79.882)  Acc@5:  93.464 ( 94.468)
INFO:root:Test: [  50/327]  Time: 0.332s (0.454s,  336.73/s)  Loss:  0.9447 (0.8034)  Acc@1:  77.124 ( 81.328)  Acc@5:  95.425 ( 94.784)
INFO:root:Test: [  60/327]  Time: 0.332s (0.434s,  352.17/s)  Loss:  1.2111 (0.8425)  Acc@1:  62.745 ( 80.017)  Acc@5:  94.118 ( 94.653)
INFO:root:Test: [  70/327]  Time: 0.332s 

INFO:root:Test: [ 130/327]  Time: 0.332s (0.380s,  402.63/s)  Loss:  1.3787 (0.8535)  Acc@1:  66.667 ( 79.808)  Acc@5:  90.850 ( 95.081)
INFO:root:Test: [ 140/327]  Time: 0.331s (0.377s,  406.03/s)  Loss:  0.9493 (0.8818)  Acc@1:  75.817 ( 79.113)  Acc@5:  93.464 ( 94.753)
INFO:root:Test: [ 150/327]  Time: 0.332s (0.374s,  409.22/s)  Loss:  1.8540 (0.9162)  Acc@1:  55.556 ( 78.315)  Acc@5:  84.967 ( 94.343)
INFO:root:Test: [ 160/327]  Time: 0.332s (0.371s,  412.03/s)  Loss:  1.1716 (0.9504)  Acc@1:  73.856 ( 77.550)  Acc@5:  88.889 ( 93.907)
INFO:root:Test: [ 170/327]  Time: 0.333s (0.369s,  414.53/s)  Loss:  1.1041 (0.9778)  Acc@1:  75.163 ( 76.941)  Acc@5:  92.157 ( 93.521)
INFO:root:Test: [ 180/327]  Time: 0.333s (0.367s,  416.69/s)  Loss:  1.2873 (0.9970)  Acc@1:  73.856 ( 76.528)  Acc@5:  88.889 ( 93.298)
INFO:root:Test: [ 190/327]  Time: 0.332s (0.365s,  418.71/s)  Loss:  1.3111 (0.9994)  Acc@1:  73.856 ( 76.601)  Acc@5:  91.503 ( 93.262)
INFO:root:Test: [ 200/327]  Time: 0.335s 









DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401


INFO:root:Test: [   0/327]  Time: 7.160s (7.160s,   21.37/s)  Loss:  0.5871 (0.5871)  Acc@1:  86.275 ( 86.275)  Acc@5:  97.386 ( 97.386)
INFO:root:Test: [  10/327]  Time: 0.335s (0.957s,  159.90/s)  Loss:  1.2135 (0.6617)  Acc@1:  67.320 ( 85.443)  Acc@5:  92.810 ( 95.900)
INFO:root:Test: [  20/327]  Time: 0.332s (0.660s,  231.95/s)  Loss:  1.2558 (0.8716)  Acc@1:  67.320 ( 79.054)  Acc@5:  91.503 ( 94.553)
INFO:root:Test: [  30/327]  Time: 0.337s (0.555s,  275.88/s)  Loss:  0.4519 (0.8534)  Acc@1:  91.503 ( 79.380)  Acc@5:  97.386 ( 94.708)
INFO:root:Test: [  40/327]  Time: 0.333s (0.501s,  305.26/s)  Loss:  1.0592 (0.8508)  Acc@1:  75.817 ( 79.914)  Acc@5:  92.810 ( 94.596)
INFO:root:Test: [  50/327]  Time: 0.332s (0.468s,  326.96/s)  Loss:  0.9241 (0.8050)  Acc@1:  79.085 ( 81.341)  Acc@5:  95.425 ( 94.874)
INFO:root:Test: [  60/327]  Time: 0.334s (0.446s,  343.17/s)  Loss:  1.1594 (0.8415)  Acc@1:  68.627 ( 80.156)  Acc@5:  96.078 ( 94.761)
INFO:root:Test: [  70/327]  Time: 0.340s 

INFO:root:Test: [ 320/327]  Time: 0.331s (0.354s,  431.64/s)  Loss:  1.3266 (1.1805)  Acc@1:  71.895 ( 72.973)  Acc@5:  88.889 ( 91.023)
INFO:root: * Acc@1 73.158 (26.842) Acc@5 91.090 (8.910)
INFO:root:Current checkpoints:
 ('./output/train/20200708-184950-tf_efficientnet_b1-240/checkpoint-0.pth.tar', 77.342)
 ('./output/train/20200708-184950-tf_efficientnet_b1-240/checkpoint-1.pth.tar', 75.54)
 ('./output/train/20200708-184950-tf_efficientnet_b1-240/checkpoint-3.pth.tar', 75.03)
 ('./output/train/20200708-184950-tf_efficientnet_b1-240/checkpoint-2.pth.tar', 74.512)
 ('./output/train/20200708-184950-tf_efficientnet_b1-240/checkpoint-4.pth.tar', 73.576)
 ('./output/train/20200708-184950-tf_efficientnet_b1-240/checkpoint-5.pth.tar', 73.158)









DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401




INFO:root:Test: [   0/327]  Time: 6.875s (6.875s,   22.26/s)  Loss:  0.5754 (0.5754)  Acc@1:  86.275 ( 86.275)  Acc@5:  96.732 ( 96.732)
INFO:root:Test: [  10/327]  Time: 0.331s (0.932s,  164.23/s)  Loss:  1.2544 (0.6755)  Acc@1:  69.935 ( 85.502)  Acc@5:  90.196 ( 95.484)
INFO:root:Test: [  20/327]  Time: 0.332s (0.647s,  236.53/s)  Loss:  1.2821 (0.8860)  Acc@1:  66.667 ( 78.929)  Acc@5:  90.196 ( 94.024)
INFO:root:Test: [  30/327]  Time: 0.334s (0.546s,  280.39/s)  Loss:  0.4233 (0.8631)  Acc@1:  91.503 ( 79.254)  Acc@5:  98.693 ( 94.265)
INFO:root:Test: [  40/327]  Time: 0.336s (0.494s,  309.93/s)  Loss:  1.1353 (0.8592)  Acc@1:  73.203 ( 79.930)  Acc@5:  91.503 ( 94.261)
INFO:root:Test: [  50/327]  Time: 0.332s (0.462s,  331.20/s)  Loss:  0.9764 (0.8124)  Acc@1:  73.856 ( 81.200)  Acc@5:  96.078 ( 94.643)
INFO:root:Test: [  60/327]  Time: 0.333s (0.441s,  346.81/s)  Loss:  1.1581 (0.8557)  Acc@1:  67.320 ( 79.771)  Acc@5:  96.078 ( 94.557)
INFO:root:Test: [  70/327]  Time: 0.339s 









DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401


INFO:root:Test: [   0/327]  Time: 7.988s (7.988s,   19.15/s)  Loss:  0.5764 (0.5764)  Acc@1:  86.928 ( 86.928)  Acc@5:  97.386 ( 97.386)
INFO:root:Test: [  10/327]  Time: 0.334s (1.033s,  148.08/s)  Loss:  1.2457 (0.6595)  Acc@1:  69.935 ( 85.502)  Acc@5:  92.810 ( 96.494)
INFO:root:Test: [  20/327]  Time: 0.330s (0.700s,  218.73/s)  Loss:  1.2998 (0.8621)  Acc@1:  65.359 ( 79.676)  Acc@5:  90.850 ( 94.771)
INFO:root:Test: [  30/327]  Time: 0.332s (0.581s,  263.39/s)  Loss:  0.3993 (0.8443)  Acc@1:  92.157 ( 79.844)  Acc@5:  97.386 ( 94.856)
INFO:root:Test: [  40/327]  Time: 0.331s (0.520s,  294.12/s)  Loss:  1.1534 (0.8456)  Acc@1:  73.203 ( 80.376)  Acc@5:  90.196 ( 94.564)
INFO:root:Test: [  50/327]  Time: 0.332s (0.483s,  316.52/s)  Loss:  0.9117 (0.7990)  Acc@1:  79.085 ( 81.687)  Acc@5:  96.078 ( 94.861)
INFO:root:Test: [  60/327]  Time: 0.331s (0.459s,  333.51/s)  Loss:  1.1075 (0.8382)  Acc@1:  67.320 ( 80.328)  Acc@5:  96.078 ( 94.782)
INFO:root:Test: [  70/327]  Time: 0.331s 

INFO:root:Test: [  90/327]  Time: 0.335s (0.417s,  366.94/s)  Loss:  0.8182 (0.8619)  Acc@1:  79.085 ( 79.516)  Acc@5:  96.732 ( 95.037)
INFO:root:Test: [ 100/327]  Time: 0.333s (0.409s,  374.49/s)  Loss:  0.4980 (0.8553)  Acc@1:  89.542 ( 79.641)  Acc@5:  96.732 ( 95.166)
INFO:root:Test: [ 110/327]  Time: 0.331s (0.402s,  380.83/s)  Loss:  0.7202 (0.8430)  Acc@1:  84.967 ( 80.045)  Acc@5:  96.732 ( 95.225)
INFO:root:Test: [ 120/327]  Time: 0.333s (0.396s,  386.05/s)  Loss:  0.8917 (0.8408)  Acc@1:  81.699 ( 80.046)  Acc@5:  95.425 ( 95.257)
INFO:root:Test: [ 130/327]  Time: 0.332s (0.391s,  390.82/s)  Loss:  1.3716 (0.8519)  Acc@1:  66.667 ( 79.734)  Acc@5:  90.196 ( 95.101)
INFO:root:Test: [ 140/327]  Time: 0.335s (0.387s,  394.92/s)  Loss:  0.9549 (0.8812)  Acc@1:  78.431 ( 79.080)  Acc@5:  92.157 ( 94.785)
INFO:root:Test: [ 150/327]  Time: 0.331s (0.384s,  398.65/s)  Loss:  1.8785 (0.9167)  Acc@1:  52.941 ( 78.280)  Acc@5:  84.967 ( 94.408)
INFO:root:Test: [ 160/327]  Time: 0.337s 







