In [1]:
import timm
import glob
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from torch.autograd import Variable
from collections import OrderedDict
from timm.data import Dataset, create_loader, resolve_data_config,  FastCollateMixup, mixup_batch, AugMixDataset
from timm.models import create_model, resume_checkpoint, convert_splitbn_model, apply_test_time_pool
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from munch import Munch
import yaml
import sys
from resnet_generator import Generator, Critic
from datetime import datetime
import numpy as np
import torch.distributed as dist
import torch.nn.functional as F
from transformer import Attention
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
from crd.criterion import CRDLoss
from dataset.imagenet import get_dataloader_sample
from timm.data import transforms_factory

In [2]:
# NVIDIA Apex is not used here: its imports are left commented out and
# has_apex is pinned to False.
# from apex import amp
# from apex.parallel import DistributedDataParallel as DDP
# from apex.parallel import convert_syncbn_model
has_apex = False

In [3]:
# GPU Device
# Imported explicitly: the visible import cell has no `import os`, so this cell
# would otherwise depend on the name leaking in from somewhere else.
import os

# Comma-separated CUDA device indices made visible to this process.
gpu_id = '0,1,2'
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
use_cuda = torch.cuda.is_available()
print("GPU device " , use_cuda)

GPU device  True


In [4]:
# Turn on cuDNN benchmark mode (cuDNN picks convolution algorithms by timing them).
torch.backends.cudnn.benchmark = True

In [5]:
# Load the training configuration from YAML and wrap it in a Munch so that
# values can be read and written as attributes (args.foo).
# NOTE(review): yaml.FullLoader can construct more than plain scalars/containers;
# prefer yaml.safe_load if config/train.yaml could ever come from an untrusted source.
with open('config/train.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
args = Munch(config)
args.prefetcher = not args.no_prefetcher
# Values below override/extend the YAML config for a single-process run.
args.distributed = False
args.device = 'cuda'
# NOTE(review): world_size is hard-coded here, while the log line below reports
# args.num_gpu from the config — confirm the two are meant to agree.
args.world_size = 3
args.rank = 0
logging.info('Training with a single process on %d GPUs.' % args.num_gpu)

INFO:root:Training with a single process on 3 GPUs.


In [6]:
# Distributed setup: one GPU per process, NCCL process group initialised from
# environment variables. args.distributed is set to False in the config cell,
# so this branch is skipped in the current notebook.
if args.distributed:
    args.num_gpu = 1
    args.device = 'cuda:%d' % args.local_rank
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', rank=args.rank, world_size=args.world_size)
    # Re-read world size and rank from the initialised process group.
    args.world_size = torch.distributed.get_world_size()
    args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

In [7]:
# Seed torch's RNG; the rank offset gives each process a different seed.
torch.manual_seed(args.seed + args.rank)

<torch._C.Generator at 0x7fa2f0efae90>

In [8]:
# Pretrained 'tf_efficientnet_b1_ns' network: moved to GPU, wrapped in
# DataParallel and switched to eval mode in a single expression.
model_ns = torch.nn.DataParallel(
    timm.create_model('tf_efficientnet_b1_ns', pretrained=True).cuda()
).eval()

In [9]:
# fc_name = 'module.classifier.'
# teacher_model_weights = {}
# for name, param in model_ns.named_parameters():
#     teacher_model_weights[name] = param.detach()

In [10]:
# Pretrained 'tf_efficientnet_b1' network on GPU, wrapped in DataParallel.
model_raw = torch.nn.DataParallel(
    timm.create_model('tf_efficientnet_b1', pretrained=True).cuda()
)

In [11]:
# Generator network, built from the config and placed on the GPU.
# (It is wrapped in DataParallel later, after the optional checkpoint resume.)
model_g = Generator(args, img_size=args.img_size, max_conv_dim=256).cuda()

In [12]:
# Critic network on GPU, wrapped in DataParallel.
model_critic = torch.nn.DataParallel(Critic(args).cuda())

In [13]:
# Optional NVIDIA Apex mixed precision. With has_apex pinned to False earlier in
# the notebook, the branch below is not taken and use_amp stays False.
use_amp = False
if has_apex and args.amp:
    # NOTE(review): `amp` is only imported in a commented-out line, and neither
    # `model` nor `optimizer` is defined in the visible cells before this point
    # (the optimizers are created further down). Enabling Apex would raise
    # NameError here until this is reworked.
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    model_raw =  amp.initialize(model_raw)
    model_ns = amp.initialize(model_ns)
    use_amp = True
if args.local_rank == 0:
    logging.info('NVIDIA APEX {}. AMP {}.'.format(
        'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))

INFO:root:NVIDIA APEX not installed. AMP off.


In [14]:
# optionally resume from a checkpoint
# resume_checkpoint() is called on the generator only; it returns the remaining
# checkpoint state and the epoch to resume from.
resume_state = {}
resume_epoch = None
if args.resume:
    resume_state, resume_epoch = resume_checkpoint(model_g, args.resume)
if resume_state and not args.no_resume_opt:
    if 'optimizer' in resume_state:
        if args.local_rank == 0:
            logging.info('Restoring Optimizer state from checkpoint')
        # NOTE(review): `optimizer` is created in a later cell; on a fresh
        # top-to-bottom run this line raises NameError when a checkpoint with
        # optimizer state is resumed.
        optimizer.load_state_dict(resume_state['optimizer'])
    # `amp` is only referenced when use_amp is True (short-circuit), so the
    # commented-out Apex import is not hit while AMP is off.
    if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
        if args.local_rank == 0:
            logging.info('Restoring NVIDIA AMP state from checkpoint')
        amp.load_state_dict(resume_state['amp'])
# Drop the (potentially large) checkpoint state once it has been consumed.
del resume_state

In [15]:
# Wrap the generator in DataParallel (done after the optional checkpoint resume above).
model_g = torch.nn.DataParallel(model_g)

In [16]:
# Dataset locations.
# NOTE(review): absolute, machine-specific paths — consider moving them into the
# YAML config. train_dir is only referenced by commented-out code in the visible cells.
train_dir = '/home/data/imagenet/train'
val_dir = '/home/data/imagenet/val'
# Resolve the data-processing configuration from args and the generator;
# the resolved values are logged only when local_rank is 0.
data_config = resolve_data_config(vars(args), model=model_g, verbose=args.local_rank == 0)

INFO:root:Data processing configuration for current model + dataset:
INFO:root:	input_size: (3, 240, 240)
INFO:root:	interpolation: bicubic
INFO:root:	mean: (0.485, 0.456, 0.406)
INFO:root:	std: (0.229, 0.224, 0.225)
INFO:root:	crop_pct: 0.882


In [17]:
# Number of augmentation splits; 0 here. Checked by the mixup-collate and JSD-loss asserts below.
num_aug_splits = 0

In [18]:
# Distributed-only model wrapping (SyncBatchNorm conversion + DDP).
# args.distributed is False in this notebook, so nothing below runs.
# NOTE(review): this branch operates on a variable named `model` and uses `DDP`,
# neither of which is defined in the visible cells (the Apex imports are
# commented out); it would need rewriting before distributed mode can be enabled.
if args.distributed:
    if args.sync_bn:
        assert not args.split_bn
        try:
            if has_apex:
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logging.info(
                    'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                    'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
        except Exception as e:
            # Any failure is logged and training continues without sync-bn.
            logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
    if has_apex:
        model = DDP(model, delay_allreduce=True)
    else:
        if args.local_rank == 0:
            logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
        model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
    # NOTE: EMA model does not need to be wrapped by DDP

In [19]:
# Validation dataset read from val_dir. The training dataset line is disabled;
# the training loader is built separately via get_dataloader_sample().
# train_dataset = Dataset(train_dir)
val_dataset = Dataset(val_dir, load_bytes=False, class_map='')

In [20]:
# Total number of elements across all generator parameters.
param_count = sum(p.numel() for p in model_g.parameters())
logging.info('Model created, param count: %d', param_count)

INFO:root:Model created, param count: 3662851


In [21]:
# Mixup collate function; only built when the prefetcher is enabled and mixup > 0.
# Otherwise collate_fn stays None.
collate_fn = None
if args.prefetcher and args.mixup > 0:
    assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
    collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

In [22]:
# Let timm optionally apply test-time pooling to model_raw; test_time_pool is
# read below when choosing the validation crop_pct.
model_raw, test_time_pool = apply_test_time_pool(model_raw, data_config, args)

In [23]:
# Training-time image transform (auto-augment disabled; jitter, interpolation
# and random-erasing settings come from the config).
transform = transforms_factory.transforms_imagenet_train(
    img_size=args.img_size,
    auto_augment=False,
    color_jitter=args.color_jitter,
    interpolation=args.interpolation,
    re_count=args.recount,
    re_mode=args.remode,
    re_prob=args.reprob,
)

In [24]:
# Training loader from the project's sampling dataloader. train_epoch unpacks each
# batch as (inputs, target, index, contrast_idx).
# NOTE(review): k=4096 is presumably the number of contrastive samples per item — confirm.
train_loader = get_dataloader_sample(transform, dataset='imagenet', batch_size=args.batch_size, is_sample=True, k=4096)

stage1 finished!
dataset initialized!
num_samples 1281167
num_class 1000


In [25]:
# With test-time pooling the full image is used; otherwise take the crop
# fraction from the resolved data config.
if test_time_pool:
    crop_pct = 1.0
else:
    crop_pct = data_config['crop_pct']

# Validation loader; preprocessing parameters come from data_config, runtime
# knobs from args.
val_loader = create_loader(
    val_dataset,
    is_training=False,
    batch_size=args.batch_size,
    num_workers=args.workers,
    pin_memory=args.pin_mem,
    use_prefetcher=args.prefetcher,
    tf_preprocessing=args.tf_preprocessing,
    input_size=data_config['input_size'],
    interpolation=data_config['interpolation'],
    mean=data_config['mean'],
    std=data_config['std'],
    crop_pct=crop_pct,
)

In [26]:
# Validation always uses plain cross-entropy; the training loss depends on the
# configured regularisation (JSD > mixup > label smoothing > plain CE).
validate_loss_fn = nn.CrossEntropyLoss()
if args.jsd:
    assert num_aug_splits > 1  # JSD only valid with aug splits set
    train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
elif args.mixup > 0.:
    # smoothing is handled with mixup label transform
    train_loss_fn = SoftTargetCrossEntropy()
elif args.smoothing:
    train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
    # No regularisation configured: training shares the validation criterion object.
    train_loss_fn = validate_loss_fn

# Auxiliary objectives: contrastive (CRD) loss on GPU and binary cross-entropy
# for the critic, together with the target values used for its two classes.
crd_loss_fn = CRDLoss(args).cuda()
critic_loss_fn = nn.BCELoss()
real_label, fake_label = 1, 0

In [27]:
# Every module whose parameters are handed to the main optimizer: the generator,
# model_raw and the two CRD embedding heads.
trainable_list = nn.ModuleList([
    model_g,
    model_raw,
    crd_loss_fn.embed_s,
    crd_loss_fn.embed_t,
])
# Bare expression so the cell still displays the assembled ModuleList.
trainable_list

ModuleList(
  (0): DataParallel(
    (module): Generator(
      (from_rgb): Conv2d(3, 68, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (encode): ModuleList(
        (0): ResBlk(
          (actv): LeakyReLU(negative_slope=0.2)
          (conv1): Conv2d(68, 68, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv2): Conv2d(68, 136, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(68, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (norm2): InstanceNorm2d(68, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (conv1x1): Conv2d(68, 136, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (1): ResBlk(
          (actv): LeakyReLU(negative_slope=0.2)
          (conv1): Conv2d(136, 136, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv2): Conv2d(136, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(136, eps=1e-05

In [28]:
# Main optimizer over everything in trainable_list, plus its LR scheduler.
# create_scheduler also returns the number of epochs to train for.
optimizer = create_optimizer(args, trainable_list)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)

In [29]:
# Separate optimizer and scheduler for the critic, built from the same args as
# the main optimizer; the epoch count returned here is discarded.
optimizer_critic = create_optimizer(args, model_critic)
lr_scheduler_critic, _ = create_scheduler(args, optimizer_critic)

In [30]:
# Report the scheduled epoch count once (only where local_rank is 0).
if args.local_rank == 0:
    logging.info('Scheduled epochs: {}'.format(num_epochs))

INFO:root:Scheduled epochs: 200


In [31]:
# Decide which epoch training starts from.
#
# resume_epoch is deliberately NOT reset here: it is set by the checkpoint-resume
# cell above, and resetting it to None made the `elif` branch below unreachable,
# so a resumed run always restarted from epoch 0 unless args.start_epoch was given.
resume_state = {}
start_epoch = 0
if args.start_epoch is not None:
    # a specified start_epoch will always override the resume epoch
    start_epoch = args.start_epoch
elif resume_epoch is not None:
    start_epoch = resume_epoch

# Fast-forward both schedulers so their learning rates match the starting epoch.
if start_epoch > 0:
    if lr_scheduler is not None:
        lr_scheduler.step(start_epoch)
    if lr_scheduler_critic is not None:
        lr_scheduler_critic.step(start_epoch)

In [32]:
# Checkpoint bookkeeping. The output directory and saver are only created where
# local_rank is 0; elsewhere saver stays None and output_dir stays ''.
eval_metric = args.eval_metric
best_metric = best_epoch = saver = None
output_dir = ''
if args.local_rank == 0:
    output_base = args.output or './output'
    # Experiment name: <timestamp>-<model name>-<input resolution>
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        args.model,
        str(data_config['input_size'][-1]),
    ])
    output_dir = get_outdir(output_base, 'train', exp_name)
    # For the 'loss' metric lower is better; for any other metric higher is better.
    decreasing = eval_metric == 'loss'
    saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)

In [33]:
# Optional exponential moving average of the model weights.
model_ema = None
if args.model_ema:
    # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    # The EMA tracks the generator: the original code referenced an undefined
    # name `model`, and model_g is the model the checkpoint saver stores together
    # with model_ema in the training loop.
    # NOTE(review): confirm model_g (rather than model_raw) is the intended EMA target.
    model_ema = ModelEma(
        model_g,
        decay=args.model_ema_decay,
        device='cpu' if args.model_ema_force_cpu else '',
        resume=args.resume)

In [34]:
def train_epoch(epoch, model_g, model_raw, model_ns, model_critic, loader, optimizer, optimizer_critic, loss_fn, crd_loss_fn, critic_loss_fn, args,
               lr_scheduler=None, lr_scheduler_critic=None, saver=None, output_dir='', use_amp=False, model_ema=None):
    """Run one training epoch of the generator / model_raw / critic setup.

    Per batch: the generator output is added to the input and fed to ``model_raw``;
    ``model_ns`` is run on the unmodified input without gradients. The critic is
    first updated to score ``model_ns`` features as real and (detached)
    ``model_raw`` features as fake. Then ``optimizer`` is stepped on
    ``lambda_kd * KD + lambda_ce * CE + lambda_crd * CRD + adversarial`` loss.

    Args:
        epoch: Current epoch index; used for logging and the update counter.
        model_g: Generator; its output is added to the input batch.
        model_raw: Trained network; must return an ``(output, features)`` pair.
        model_ns: Reference network kept in eval mode; must return an
            ``(output, features)`` pair.
        model_critic: Critic applied to the feature tensors.
        loader: Iterable yielding ``(inputs, target, index, contrast_idx)``.
        optimizer: Optimizer stepped on the combined loss.
        optimizer_critic: Optimizer stepped on the critic loss.
        loss_fn: Classification loss applied to ``(output, target)``.
        crd_loss_fn: Called as ``crd_loss_fn(f_raw, f_ns, index, contrast_idx)``.
        critic_loss_fn: Loss applied to ``(critic_output, label)``.
        args: Config namespace (reads T, lambda_kd, lambda_ce, lambda_crd,
            distributed, world_size, local_rank, log_interval, save_images,
            recovery_interval).
        lr_scheduler: Optional scheduler for ``optimizer`` (per-update step).
        lr_scheduler_critic: Optional scheduler for ``optimizer_critic``.
        saver: Optional checkpoint saver used for recovery checkpoints.
        output_dir: Directory for optional training-batch images.
        use_amp: Forwarded to ``saver.save_recovery``.
        model_ema: Optional EMA tracker, updated from ``model_g`` after each step.

    Returns:
        OrderedDict with one entry, ``'loss'``: the running average of the total loss.

    Note:
        Reads the notebook-level globals ``real_label`` and ``fake_label``.
    """
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    losses_crd = AverageMeter()
    losses_ce = AverageMeter()
    losses_kd = AverageMeter()
    losses_critic = AverageMeter()
    losses_g = AverageMeter()

    model_g.train()
    model_critic.train()
    model_raw.train()
    model_ns.eval()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)

    for batch_idx, data in enumerate(loader):
        inputs, target, index, contrast_idx = data
        inputs, target, index, contrast_idx = inputs.cuda(), target.cuda(), index.cuda(), contrast_idx.cuda()
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)

        g_out = model_g(inputs)

        # model_raw sees the input plus the generator output; model_ns sees the raw input.
        inputs_z = inputs + g_out
        output, f_raw = model_raw(inputs_z)
        with torch.no_grad():
            out_ns, f_ns = model_ns(inputs)
            out_ns = out_ns.detach()

        # Critic update: model_ns features are "real", model_raw features are "fake".
        model_critic.train()
        model_critic.zero_grad()
        # Explicit float dtype: recent PyTorch infers an integer tensor from an
        # int fill value, which BCELoss rejects as a target.
        label = torch.full((f_raw.size(0), ), real_label, dtype=torch.float, device='cuda')
        critic_out = model_critic(f_ns).view(-1)
        loss_critic_ns = critic_loss_fn(critic_out, label)
        loss_critic_ns.backward()
        loss_critic_ns_out = loss_critic_ns.item()

        label.fill_(fake_label)
        # detach(): the critic step must not backpropagate into model_raw / model_g.
        critic_out = model_critic(f_raw.detach()).view(-1)
        loss_critic_raw = critic_loss_fn(critic_out, label)
        loss_critic_raw.backward()
        loss_critic_raw_out = loss_critic_raw.item()
        loss_critic = loss_critic_ns_out + loss_critic_raw_out
        optimizer_critic.step()

        # KD: temperature-scaled KL divergence, summed and divided by batch size.
        p_s = F.log_softmax(output/args.T, dim=1)
        p_t = F.softmax(out_ns/args.T, dim=1)
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        loss_kd = F.kl_div(p_s, p_t, reduction='sum') * (args.T ** 2) / output.shape[0]

        # CE
        loss_ce = loss_fn(output, target)

        # CRD
        loss_crd = crd_loss_fn(f_raw, f_ns, index, contrast_idx)

        # G: adversarial term — model_raw features should be scored as real.
        model_critic.eval()
        label.fill_(real_label)
        critic_g_out = model_critic(f_raw).view(-1)
        loss_g = critic_loss_fn(critic_g_out, label)

        # overall loss
        loss = args.lambda_kd * loss_kd + args.lambda_ce * loss_ce + args.lambda_crd * loss_crd + loss_g
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if not args.distributed:
            losses_kd.update(loss_kd.item(), inputs.size(0))
            losses_ce.update(loss_ce.item(), inputs.size(0))
            losses_crd.update(loss_crd.item(), inputs.size(0))
            losses_critic.update(loss_critic, inputs.size(0))
            losses_g.update(loss_g.item(), inputs.size(0))
            losses_m.update(loss.item(), inputs.size(0))

        if model_ema is not None:
            # Was `model_ema.update(model)`: `model` is not defined in this
            # function. model_g is the model checkpointed together with the EMA.
            model_ema.update(model_g)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))

            if args.local_rank == 0:
                # Guard against a single-batch loader, where last_idx is 0.
                progress_pct = 100. * batch_idx / last_idx if last_idx > 0 else 100.
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                    'Loss_kd: {loss_kd.val:>9.6f} ({loss_kd.avg:>6.4f})  '
                    'Loss_ce: {loss_ce.val:>9.6f} ({loss_ce.avg:>6.4f})  '
                    'Loss_crd: {loss_crd.val:>9.6f} ({loss_crd.avg:>6.4f})  '
                    'Loss_g: {loss_g.val:>9.6f} ({loss_g.avg:>6.4f})  '
                    'Loss_critic: {loss_critic.val:>9.6f} ({loss_critic.avg:>6.4f})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.7f}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        progress_pct,
                        loss=losses_m,
                        loss_kd=losses_kd,
                        loss_ce=losses_ce,
                        loss_crd=losses_crd,
                        loss_g=losses_g,
                        loss_critic=losses_critic,
                        batch_time=batch_time_m,
                        rate=inputs.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    # NOTE(review): `torchvision` is not imported in the visible
                    # import cell; add `import torchvision` there before
                    # enabling args.save_images.
                    torchvision.utils.save_image(
                        inputs_z,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(
                model_raw, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)

        # Per-update scheduler hooks. Each scheduler is checked on its own:
        # lr_scheduler_critic defaults to None and was previously dereferenced
        # whenever lr_scheduler was set.
        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
        if lr_scheduler_critic is not None:
            lr_scheduler_critic.step_update(num_updates=num_updates, metric=losses_critic.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])

In [35]:
def val_epoch(model_raw, model, val_loader, criterion, args):
    """Evaluate the generator + model_raw pipeline on the validation loader.

    Each batch is passed through ``model`` (the generator); its output is added
    to the input and the sum is classified by ``model_raw``.

    Args:
        model_raw: Classifier; must return an ``(output, features)`` pair.
        model: Generator whose output is added to the input batch.
        val_loader: Iterable yielding ``(inputs, target)`` batches.
        criterion: Loss applied to ``(output, target)``.
        args: Config namespace (reads ``no_prefetcher`` and ``log_freq``).

    Returns:
        OrderedDict with top1/top5 accuracy and error, param_count (millions),
        img_size, cropt_pct, interpolation and the average validation ``loss``.
        ``loss`` is included so that ``eval_metric == 'loss'`` can be looked up
        in the result by the training loop.

    Note:
        Reads the notebook-level globals ``param_count``, ``data_config`` and
        ``crop_pct``. The ``cropt_pct`` key spelling is kept as-is because
        the result keys are written out by the caller.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model_raw.eval()
    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, target) in enumerate(val_loader):
            # Without the prefetcher the loader yields CPU tensors; move them here.
            if args.no_prefetcher:
                target = target.cuda()
                inputs = inputs.cuda()

            out = model(inputs)
            # synthesizing input + generator
            inputs_out = inputs + out
            # compute output
            output, _ = model_raw(inputs_out)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(val_loader), batch_time=batch_time,
                        rate_avg=inputs.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'],
        # Appended last so the existing keys keep their order.
        loss=round(losses.avg, 4))

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results

In [None]:
# Main training loop: train one epoch, validate, step the schedulers, append to
# the summary CSV and write checkpoints.
for epoch in range(start_epoch, num_epochs):
    if args.distributed:
        train_loader.sampler.set_epoch(epoch)

    train_metrics = train_epoch(
        epoch, model_g, model_raw, model_ns, model_critic, train_loader, optimizer, optimizer_critic, 
        train_loss_fn, crd_loss_fn, critic_loss_fn, args,
        lr_scheduler=lr_scheduler, lr_scheduler_critic=lr_scheduler_critic, 
        saver=saver, output_dir=output_dir,
        use_amp=use_amp, model_ema=model_ema)

    eval_metrics = val_epoch(model_raw, model_g, val_loader, validate_loss_fn, args)

    if model_ema is not None and not args.model_ema_force_cpu:
        if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
            distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

    # step LR for next epoch
    if lr_scheduler is not None:
        lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
    # The critic scheduler needs its per-epoch step too; previously only its
    # per-update hook was called, so an epoch-based schedule never advanced
    # the critic's learning rate.
    if lr_scheduler_critic is not None:
        lr_scheduler_critic.step(epoch + 1, eval_metrics[eval_metric])

    # The CSV header is written only until the first checkpoint has set best_metric.
    update_summary( 
        epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
        write_header=best_metric is None)

    if saver is not None:
        # save proper checkpoint with eval metric
        save_metric = eval_metrics[eval_metric]
        best_metric, best_epoch = saver.save_checkpoint(
            model_g, optimizer, args,
            epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
        # model_raw and the critic are not part of the saver checkpoint; they are
        # pickled whole into the working directory and overwritten every epoch.
        torch.save(model_raw, 'raw')
        torch.save(model_critic, 'critic')



normalization constant Z_v1 is set to 2831728.0
normalization constant Z_v2 is set to 2842864.0










DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401




INFO:root:Test: [   0/327]  Time: 7.410s (7.410s,   20.65/s)  Loss:  1.5958 (1.5958)  Acc@1:  71.242 ( 71.242)  Acc@5:  88.889 ( 88.889)
INFO:root:Test: [  10/327]  Time: 0.349s (0.985s,  155.27/s)  Loss:  2.2209 (1.7932)  Acc@1:  52.941 ( 65.597)  Acc@5:  79.085 ( 86.393)
INFO:root:Test: [  20/327]  Time: 0.333s (0.676s,  226.35/s)  Loss:  3.0016 (2.0914)  Acc@1:  37.908 ( 58.948)  Acc@5:  67.320 ( 83.287)
INFO:root:Test: [  30/327]  Time: 0.335s (0.565s,  270.69/s)  Loss:  1.2833 (2.0462)  Acc@1:  75.163 ( 60.046)  Acc@5:  90.850 ( 83.597)
INFO:root:Test: [  40/327]  Time: 0.336s (0.509s,  300.88/s)  Loss:  2.8338 (2.1539)  Acc@1:  42.484 ( 59.509)  Acc@5:  71.242 ( 82.767)
INFO:root:Test: [  50/327]  Time: 0.333s (0.474s,  322.80/s)  Loss:  1.7075 (2.0857)  Acc@1:  61.438 ( 60.887)  Acc@5:  89.542 ( 83.442)
INFO:root:Test: [  60/327]  Time: 0.333s (0.451s,  339.35/s)  Loss:  2.7078 (2.0920)  Acc@1:  39.869 ( 59.552)  Acc@5:  70.588 ( 83.199)
INFO:root:Test: [  70/327]  Time: 0.332s 

DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401






