In [1]:
import glob
import logging
import os
import sys
import time
from collections import OrderedDict
from datetime import datetime

import numpy as np
import timm
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torchsummary
import torchvision
import yaml
from munch import Munch
from timm.data import Dataset, create_loader, resolve_data_config,  FastCollateMixup, mixup_batch, AugMixDataset
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.models import create_model, resume_checkpoint, convert_splitbn_model, apply_test_time_pool
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import *
from torch.autograd import Variable

from resnet_generator import Generator
# Lower the root logger threshold to DEBUG so the logging.info() calls in the
# following cells are emitted; keep a handle to the root logger in `logger`.
logging.getLogger().setLevel(logging.DEBUG)
logger = logging.getLogger()

In [2]:
# NVIDIA Apex is intentionally not imported in this notebook. The names `amp`,
# `DDP` and `convert_syncbn_model` referenced in later cells are therefore only
# valid if these imports are restored AND `has_apex` is set to True.
# from apex import amp
# from apex.parallel import DistributedDataParallel as DDP
# from apex.parallel import convert_syncbn_model
has_apex = False

In [3]:
# GPU selection: expose only the listed device indices to this process, then
# report whether CUDA is usable.
gpu_id = '0,1,2'
os.environ.update(CUDA_VISIBLE_DEVICES=str(gpu_id))
use_cuda = torch.cuda.is_available()
print("GPU device " , use_cuda)

GPU device  True


In [4]:
# Rendezvous address/port read by init_method='env://' in the distributed cell.
# NOTE(review): the address is hard-coded; it is only consulted when
# args.distributed is enabled.
os.environ.update({
    'MASTER_ADDR': '14.49.45.144',
    'MASTER_PORT': '16022',
})

In [5]:
# Turn on the cuDNN benchmark flag for this process.
torch.backends.cudnn.benchmark = True

In [6]:
# Read the training hyper-parameters from YAML and expose them as attributes.
# NOTE(review): yaml.load with FullLoader is kept unchanged; prefer
# yaml.safe_load if the config file can ever come from an untrusted source.
with open('config/train.yaml') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
args = Munch(config)

# Notebook-level overrides: single process, DataParallel over the visible GPUs.
args.update(
    prefetcher=not args.no_prefetcher,
    distributed=False,
    device='cuda',
    world_size=3,
    rank=0,
)
logging.info('Training with a single process on %d GPUs.', args.num_gpu)

INFO:root:Training with a single process on 3 GPUs.


In [7]:
# Distributed setup (skipped while args.distributed is False): bind this
# process to its local GPU and join the NCCL process group using the
# MASTER_ADDR / MASTER_PORT environment variables.
if args.distributed:
    args.num_gpu = 1
    args.device = 'cuda:%d' % args.local_rank
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        rank=args.rank,
        world_size=args.world_size,
    )
    # Re-read the values from the initialised group.
    args.world_size = dist.get_world_size()
    args.rank = dist.get_rank()
    assert args.rank >= 0

In [8]:
# Seed torch's RNG; the rank offset gives each process a different seed.
torch.manual_seed(args.seed + args.rank)

<torch._C.Generator at 0x7fc8fb034170>

In [9]:
# Pretrained 'tf_efficientnet_b7_ns' network, moved to GPU and replicated
# across the visible devices with DataParallel.
model_ns = nn.DataParallel(
    timm.create_model('tf_efficientnet_b7_ns', pretrained=True).cuda()
)

In [10]:
# Pretrained 'tf_efficientnet_b7' network, moved to GPU and replicated across
# the visible devices with DataParallel.
model_raw = nn.DataParallel(
    timm.create_model('tf_efficientnet_b7', pretrained=True).cuda()
)

In [11]:
# The trainable generator (project-local resnet_generator.Generator), moved to
# GPU and wrapped in DataParallel like the two EfficientNets.
model = nn.DataParallel(
    Generator(args, img_size=600, max_conv_dim=512).cuda()
)

In [12]:
# Dataset locations and the preprocessing configuration (input size, mean/std,
# interpolation, crop) resolved by timm from the args and the model.
# NOTE(review): absolute, machine-specific paths.
train_dir, val_dir = '/home/data/imagenet/train', '/home/data/imagenet/val'
data_config = resolve_data_config(
    vars(args),
    model=model,
    verbose=(args.local_rank == 0),
)

INFO:root:Data processing configuration for current model + dataset:
INFO:root:	input_size: (3, 600, 600)
INFO:root:	interpolation: bicubic
INFO:root:	mean: (0.485, 0.456, 0.406)
INFO:root:	std: (0.229, 0.224, 0.225)
INFO:root:	crop_pct: 0.875


In [13]:
# Augmentation splits are disabled; this value is forwarded to create_loader
# and checked before FastCollateMixup is built.
num_aug_splits = 0

In [14]:
# Build the optimizer for the generator only; the two EfficientNets are not
# passed to it and are therefore never updated.
optimizer = create_optimizer(args, model)

In [15]:
# Optional Apex mixed precision. With has_apex == False this only logs the
# status; `amp` is not defined unless the Apex imports are restored.
use_amp = False
if has_apex and args.amp:
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    model_raw =  amp.initialize(model_raw)
    model_ns = amp.initialize(model_ns)
    use_amp = True

if args.local_rank == 0:
    apex_state = 'installed' if has_apex else 'not installed'
    amp_state = 'on' if use_amp else 'off'
    logging.info('NVIDIA APEX {}. AMP {}.'.format(apex_state, amp_state))

INFO:root:NVIDIA APEX not installed. AMP off.


In [16]:
# Distributed-only model wrapping: optional SyncBatchNorm conversion followed
# by a DistributedDataParallel wrapper. Skipped while args.distributed is False.
if args.distributed:
    if args.sync_bn:
        assert not args.split_bn
        try:
            if has_apex:
                # Requires the (currently commented-out) Apex imports.
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logging.info(
                    'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                    'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
        except Exception as e:
            # Training continues without SyncBN, but the cause is now reported
            # instead of being silently dropped.
            logging.error(
                'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1 (%s)', e)
    if has_apex:
        # `DDP` here is Apex's DistributedDataParallel; it only exists when the
        # Apex imports are restored.
        model = DDP(model, delay_allreduce=True)
    else:
        if args.local_rank == 0:
            logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
        # The bare name `DDP` is never imported without Apex, so reference the
        # native wrapper through the already-imported torch.nn.parallel module.
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
    # NOTE: EMA model does not need to be wrapped by DDP


In [17]:
# Build the LR scheduler from args; also yields the number of epochs to run.
lr_scheduler, num_epochs = create_scheduler(args, optimizer)

In [18]:
# No checkpoint is resumed in this notebook, so resume_epoch stays None and the
# starting epoch comes from args.start_epoch (or defaults to 0).
resume_state = {}
resume_epoch = None

if args.start_epoch is not None:
    # a specified start_epoch will always override the resume epoch
    start_epoch = args.start_epoch
else:
    start_epoch = resume_epoch if resume_epoch is not None else 0

# Fast-forward the schedule when training does not start from epoch 0.
if lr_scheduler is not None and start_epoch > 0:
    lr_scheduler.step(start_epoch)

In [19]:
# Report the epoch budget returned by create_scheduler (local rank 0 only).
if args.local_rank == 0:
    logging.info('Scheduled epochs: {}'.format(num_epochs))

INFO:root:Scheduled epochs: 200


In [20]:
# Training dataset rooted at train_dir; the validation dataset is disabled.
train_dataset = Dataset(train_dir)
# val_dataset = Dataset(val_dir, load_bytes=False, class_map='')

In [21]:
# Total number of elements across the generator's parameters; also read later
# by val_epoch when it builds its results dict.
param_count = sum(param.numel() for param in model.parameters())
logging.info('Model created, param count: %d', param_count)

INFO:root:Model created, param count: 4253205


In [22]:
# Use the mixup-aware collate function only when both the prefetcher and mixup
# are enabled; otherwise create_loader falls back to its own default.
collate_fn = None
if args.prefetcher:
    if args.mixup > 0:
        assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

In [23]:
# model_raw, test_time_pool = apply_test_time_pool(model_raw, data_config, args)

In [24]:
# Training data loader. The options are gathered first (grouped by concern) and
# then forwarded unchanged to timm's create_loader.
_train_loader_options = dict(
    # geometry / batching
    input_size=data_config['input_size'],
    batch_size=args.batch_size,
    is_training=True,
    use_prefetcher=args.prefetcher,
    # random erasing
    re_prob=args.reprob,
    re_mode=args.remode,
    re_count=args.recount,
    re_split=args.resplit,
    # augmentation
    color_jitter=args.color_jitter,
    auto_augment=args.aa,
    num_aug_splits=num_aug_splits,
    interpolation=args.train_interpolation,
    # normalisation
    mean=data_config['mean'],
    std=data_config['std'],
    # loader mechanics
    num_workers=args.workers,
    distributed=args.distributed,
    collate_fn=collate_fn,
    pin_memory=args.pin_mem,
    use_multi_epochs_loader=args.use_multi_epochs_loader,
)
train_loader = create_loader(train_dataset, **_train_loader_options)

In [25]:
# crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
# val_loader = create_loader(
#     val_dataset,
#     input_size=data_config['input_size'],
#     batch_size=args.batch_size,
#     is_training=False,
#     use_prefetcher=args.prefetcher,
#     interpolation=data_config['interpolation'],
#     mean=data_config['mean'],
#     std=data_config['std'],
#     num_workers=args.workers,
#     crop_pct=crop_pct,
#     pin_memory=args.pin_mem,
#     tf_preprocessing=args.tf_preprocessing)

In [26]:
# if args.jsd:
#     assert num_aug_splits > 1  # JSD only valid with aug splits set
#     train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
#     validate_loss_fn = nn.CrossEntropyLoss()
# elif args.mixup > 0.:
#     # smoothing is handled with mixup label transform
#     train_loss_fn = SoftTargetCrossEntropy()
#     validate_loss_fn = nn.CrossEntropyLoss()
# elif args.smoothing:
#     train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
#     validate_loss_fn = nn.CrossEntropyLoss()
# else:
#     train_loss_fn = nn.CrossEntropyLoss()
#     validate_loss_fn = train_loss_fn

# Both criteria handed to train_epoch are plain L1 losses: the epoch loop passes
# train_loss_fn as `loss_fn` and loss_l1_fn as `loss_traj_fn`.
loss_l1_fn = nn.L1Loss()
train_loss_fn = nn.L1Loss()

In [27]:
# Checkpoint bookkeeping. Only local rank 0 creates an output directory and a
# CheckpointSaver; other ranks keep saver=None and an empty output_dir.
eval_metric = args.eval_metric
best_metric = best_epoch = saver = None
output_dir = ''

if args.local_rank == 0:
    output_base = args.output or './output'
    # Experiment name: <timestamp>-<model name>-<input resolution>
    exp_name = '-'.join((
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        args.model,
        str(data_config['input_size'][-1]),
    ))
    output_dir = get_outdir(output_base, 'train', exp_name)
    # A 'loss' metric improves downwards; any other metric improves upwards.
    decreasing = (eval_metric == 'loss')
    saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)

In [28]:
# Optional exponential-moving-average copy of the generator weights.
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = (
    ModelEma(
        model,
        decay=args.model_ema_decay,
        device='cpu' if args.model_ema_force_cpu else '',
        resume=args.resume)
    if args.model_ema
    else None
)

In [29]:
def train_epoch(epoch, model, model_raw, model_ns, loader, optimizer, loss_fn, loss_traj_fn, args,
               lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
    """Train the generator ``model`` for one epoch against two fixed networks.

    For every batch the generator output is added to the input, giving
    ``inputs_z``. ``model_ns`` is evaluated on the clean batch and supplies the
    targets; ``model_raw`` is evaluated on ``inputs_z``. Both are expected to
    return ``(output, trajectory)``, where ``trajectory`` is an indexable
    sequence of tensors. The optimised loss is::

        args.lambda_l1 * loss_fn(output_raw, output_ns)
        + args.lambda_traj * sum_i loss_traj_fn(traj_raw[i], traj_ns[i])
        + args.lambda_recon * loss_fn(inputs_z, inputs)

    Args:
        epoch: Index of the current epoch (used for logging and update counts).
        model: Trainable generator; the only network stepped by ``optimizer``.
        model_raw: Fixed network fed with the perturbed batch. Its parameters
            are switched to ``requires_grad=False`` here (side effect) so the
            backward pass does not accumulate gradients on them.
        model_ns: Fixed network fed with the clean batch (target provider).
        loader: Iterable of ``(inputs, target)`` batches supporting ``len()``.
        optimizer: Optimizer over the generator parameters.
        loss_fn: Criterion for the output and reconstruction terms.
        loss_traj_fn: Criterion applied to each pair of trajectory entries.
        args: Config namespace; reads ``prefetcher``, ``distributed``,
            ``lambda_l1``, ``lambda_traj``, ``lambda_recon``, ``log_interval``,
            ``local_rank``, ``world_size``, ``save_images`` and
            ``recovery_interval``.
        lr_scheduler: Optional scheduler; ``step_update`` is called per batch.
        saver: Optional checkpoint saver used for recovery checkpoints.
        output_dir: Directory for sample images when ``args.save_images`` is set.
        use_amp: Only forwarded to ``saver.save_recovery``.
        model_ema: Optional EMA tracker, updated after every optimizer step.

    Returns:
        OrderedDict with the single key ``'loss'``: the running average of the
        total loss over the epoch.
    """
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    losses_l1 = AverageMeter()
    losses_traj = AverageMeter()
    losses_recon = AverageMeter()

    model.train()
    model_ns.eval()
    model_raw.eval()
    # model_raw must stay differentiable w.r.t. its INPUT (see below), but its
    # own weights are never optimised, so no gradients are needed on them.
    for fixed_param in model_raw.parameters():
        fixed_param.requires_grad_(False)

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            inputs, target = inputs.cuda(), target.cuda()

        inputs_out = model(inputs)
        inputs_z = inputs + inputs_out

        # Targets come from the clean batch and need no autograd graph.
        with torch.no_grad():
            out_ns, traj_ns = model_ns(inputs)
        # This forward pass has to run with autograd enabled: inside no_grad the
        # output/trajectory losses would be constants and only the
        # reconstruction term would ever reach the generator.
        output, traj_raw = model_raw(inputs_z)

        loss_recon = loss_fn(inputs_z, inputs)
        loss_l1 = loss_fn(output, out_ns)

        # Start from a tensor zero so `.item()` below also works when the
        # trajectory is empty.
        loss_traj = output.new_zeros(())
        for i, feat_raw in enumerate(traj_raw):
            loss_traj = loss_traj + loss_traj_fn(feat_raw, traj_ns[i])

        loss = args.lambda_l1 * loss_l1 + args.lambda_traj * loss_traj + args.lambda_recon * loss_recon

        if not args.distributed:
            losses_l1.update(loss_l1.item(), inputs.size(0))
            losses_traj.update(loss_traj.item(), inputs.size(0))
            losses_recon.update(loss_recon.item(), inputs.size(0))
            losses_m.update(loss.item(), inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))

            if args.local_rank == 0:
                # A loader with a single batch has last_idx == 0; report 100%
                # instead of dividing by zero.
                progress_pct = 100. * batch_idx / last_idx if last_idx > 0 else 100.
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                    'Loss_l1: {loss_l1.val:>9.6f} ({loss_l1.avg:>6.4f})  '
                    'Loss_traj: {loss_traj.val:>9.6f} ({loss_traj.avg:>6.4f})  '
                    'Loss_recon: {loss_recon.val:>9.6f} ({loss_recon.avg:>6.4f})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        progress_pct,
                        loss=losses_m,
                        loss_l1=losses_l1,
                        loss_traj=losses_traj,
                        loss_recon=losses_recon,
                        batch_time=batch_time_m,
                        rate=inputs.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    # Dump the perturbed batch for visual inspection.
                    torchvision.utils.save_image(
                        inputs_z,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(
                model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])

In [30]:
def val_epoch(model_raw, model, val_loader, criterion, args):
    """Evaluate ``model_raw`` on inputs perturbed by the generator ``model``.

    For each batch the generator output is added to the input and the sum is
    fed to ``model_raw``, which is expected to return ``(output, extra)``; only
    ``output`` is used. Loss and top-1/top-5 accuracy are averaged over the
    whole loader.

    Args:
        model_raw: Network being evaluated on the perturbed inputs.
        model: Generator producing the additive perturbation.
        val_loader: Iterable of ``(inputs, target)`` batches supporting ``len()``.
        criterion: Loss applied to ``(output, target)``.
        args: Config namespace; reads ``no_prefetcher`` and ``log_freq``.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``. The last four are read from the module-level
        ``param_count`` and ``data_config``.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model_raw.eval()
    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, target) in enumerate(val_loader):
            if args.no_prefetcher:
                target = target.cuda()
                inputs = inputs.cuda()

            # The perturbation must be computed for the CURRENT batch before it
            # is added to the input (it used to be read before assignment).
            model_out = model(inputs)
            # synthesizing input + generator
            inputs_out = inputs + model_out
            # compute output
            output, _ = model_raw(inputs_out)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(val_loader), batch_time=batch_time,
                        rate_avg=inputs.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    # `crop_pct` is not defined anywhere in the active notebook cells, so the
    # value is taken from data_config. The historical key spelling 'cropt_pct'
    # is kept so existing consumers of the results dict keep working.
    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=data_config['crop_pct'],
        interpolation=data_config['interpolation'])

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results

In [None]:
# Main training loop: one train_epoch call per epoch, followed by LR
# scheduling, a periodic summary row and a checkpoint.

# The CSV header is written once, with the first summary row. (Keying this on
# `best_metric is None` does not work here: the first summary is written at
# epoch 10, after checkpoints have already been saved.)
summary_header_written = False

for epoch in range(start_epoch, num_epochs):
    if args.distributed:
        # The loader created above is named train_loader.
        train_loader.sampler.set_epoch(epoch)

    # Keyword arguments keep model_raw / model_ns bound to the matching
    # train_epoch parameters (positionally they had been swapped).
    train_metrics = train_epoch(
        epoch, model,
        model_raw=model_raw, model_ns=model_ns,
        loader=train_loader, optimizer=optimizer,
        loss_fn=train_loss_fn, loss_traj_fn=loss_l1_fn, args=args,
        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
        use_amp=use_amp, model_ema=model_ema)

#     if (epoch+1)%10 == 0:
#         eval_metrics = val_epoch(model_raw, model, val_loader, validate_loss_fn, args)

    # Validation is disabled above, so eval metrics only exist when the EMA
    # branch below produces them.
    eval_metrics = OrderedDict()

    if model_ema is not None and not args.model_ema_force_cpu:
        if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
            distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

        # NOTE(review): `validate`, `loader_eval` and `validate_loss_fn` are not
        # defined in the active cells of this notebook; this branch raises
        # NameError when args.model_ema is enabled. Confirm the intended
        # evaluation routine and loader before turning EMA on.
        ema_eval_metrics = validate(
            model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
        eval_metrics = ema_eval_metrics

    # Metric driving the scheduler and the checkpoint ranking. Without eval
    # metrics the training loss stands in for it.
    # NOTE(review): the saver's ordering direction was derived from eval_metric,
    # so this stand-in is only ranked correctly when eval_metric == 'loss'.
    metrics_source = eval_metrics if eval_metrics else train_metrics
    tracked_metric = metrics_source.get(eval_metric, train_metrics['loss'])

    if lr_scheduler is not None:
        # step LR for next epoch
        lr_scheduler.step(epoch + 1, tracked_metric)

    if (epoch+1)%10 == 0:
        update_summary(
            epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
            write_header=not summary_header_written)
        summary_header_written = True

    if saver is not None:
        # save proper checkpoint with eval metric
        best_metric, best_epoch = saver.save_checkpoint(
            model, optimizer, args,
            epoch=epoch, model_ema=model_ema, metric=tracked_metric, use_amp=use_amp)



