In [1]:
import timm
import torchsummary
import glob
import os  # explicit: previously only available through the timm.utils star import
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from torch.autograd import Variable
from collections import OrderedDict
from timm.data import Dataset, create_loader, resolve_data_config,  FastCollateMixup, mixup_batch, AugMixDataset
from timm.models import create_model, resume_checkpoint, convert_splitbn_model, apply_test_time_pool
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from munch import Munch
import yaml
import sys
from generator import Generator
from datetime import datetime
# Root logger at DEBUG so every module's debug output is visible in the notebook.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# With the root at DEBUG, PIL logs every PNG chunk it parses ("STREAM b'IHDR' ..."),
# flooding the training output; keep PIL at INFO.
logging.getLogger('PIL').setLevel(logging.INFO)

In [2]:
# Prefer NVIDIA apex (AMP, apex DDP, SyncBN conversion) but keep the notebook usable without it.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as DDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    # The non-apex branch of the distributed setup calls DDP(model, device_ids=[...]),
    # so `DDP` must resolve to torch's native implementation here.
    from torch.nn.parallel import DistributedDataParallel as DDP
    has_apex = False

In [3]:
# Expose only GPUs 0, 1 and 2 to this process. CUDA_VISIBLE_DEVICES only takes effect
# if it is set before CUDA is initialized in the process.
gpu_id = '0,1,2'
os.environ.update(CUDA_VISIBLE_DEVICES=str(gpu_id))
use_cuda = torch.cuda.is_available()
print("GPU device ", use_cuda)

GPU device  True


In [4]:
cudnn = torch.backends.cudnn
cudnn.benchmark = True  # let cuDNN autotune conv algorithms; input size is fixed at 3x600x600

In [5]:
# Load training hyper-parameters from YAML and expose them with attribute access.
with open('config/train.yaml') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
args = Munch(config)
# Single-process run: multi-GPU use comes from DataParallel, not distributed training.
args.update(
    prefetcher=not args.no_prefetcher,
    distributed=False,
    device='cuda',
    world_size=1,
    rank=0,
)
logging.info('Training with a single process on %d GPUs.', args.num_gpu)

INFO:root:Training with a single process on 3 GPUs.


In [6]:
process_seed = args.seed + args.rank  # per-rank offset (rank is 0 in this single-process run)
torch.manual_seed(process_seed);  # trailing ';' hides the returned Generator repr

<torch._C.Generator at 0x7f5bbc854170>

In [7]:
# Noisy-Student EfficientNet-B7: classifies the clean images as the reference in train_epoch.
# (DataParallel wrapping happens after test-time pooling is applied, in a later cell.)
model_ns = timm.create_model('tf_efficientnet_b7_ns', pretrained=True).cuda()

In [8]:
# Standard EfficientNet-B7: classifies the generator-perturbed images; the generator is
# optimized so that its losses/intermediate outputs match those of model_ns.
model_raw = timm.create_model('tf_efficientnet_b7', pretrained=True).cuda()

In [9]:
# Generator whose output is added to the input images; built for 3x600x600 inputs.
model = Generator((3, 600, 600), args).cuda()

In [10]:
# Dataset roots. Set `train_dir` / `val_dir` in config/train.yaml to override; the original
# machine-specific paths remain the defaults so existing configs keep working.
train_dir = args.get('train_dir') or '/home/data/imagenet/train'
val_dir = args.get('val_dir') or '/home/data/imagenet/val'
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

INFO:root:Data processing configuration for current model + dataset:
INFO:root:	input_size: (3, 600, 600)
INFO:root:	interpolation: bicubic
INFO:root:	mean: (0.485, 0.456, 0.406)
INFO:root:	std: (0.229, 0.224, 0.225)
INFO:root:	crop_pct: 0.875


In [11]:
num_aug_splits = 0  # no AugMix-style splits; consumed by the train loader, mixup and JSD-loss checks

In [12]:
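# Split BatchNorm is disabled: the assert below requires num_aug_splits > 1 (or args.resplit),
# and num_aug_splits is 0 in this notebook. Re-enable together with augmentation splits.
# if args.split_bn:
#     assert num_aug_splits > 1 or args.resplit
#     model = convert_splitbn_model(model, max(num_aug_splits, 2))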

In [13]:
# Only the generator (`model`) is handed to the optimizer; the two EfficientNets are not.
optimizer = create_optimizer(args, model)

In [14]:
# Mixed precision through apex AMP; O1 inserts automatic casts around torch functions.
use_amp = bool(has_apex and args.amp)
if use_amp:
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
if args.local_rank == 0:
    apex_state = 'installed' if has_apex else 'not installed'
    amp_state = 'on' if use_amp else 'off'
    logging.info('NVIDIA APEX {}. AMP {}.'.format(apex_state, amp_state))

INFO:root:NVIDIA APEX installed. AMP on.


Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [15]:
# Distributed setup: optional SyncBN conversion, then DDP wrapping.
# Not executed in this notebook because args.distributed is set to False when loading the config.
if args.distributed:
    if args.sync_bn:
        assert not args.split_bn
        try:
            if has_apex:
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logging.info(
                    'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                    'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
        except Exception:
            # Best effort: continue with regular BN, but keep the real cause (previously the
            # exception was bound to an unused name and its traceback was lost).
            logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1',
                          exc_info=True)
    if has_apex:
        model = DDP(model, delay_allreduce=True)
    else:
        if args.local_rank == 0:
            logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
        model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
    # NOTE: EMA model does not need to be wrapped by DDP


In [16]:
# Build the LR schedule from the config; num_epochs is the epoch count the training loop runs to.
lr_scheduler, num_epochs = create_scheduler(args, optimizer)

In [17]:
resume_state = {}
resume_epoch = None  # never set here: resume_checkpoint is imported but not called in this notebook

# An explicit args.start_epoch wins over a resumed epoch; otherwise start from 0.
if args.start_epoch is not None:
    start_epoch = args.start_epoch
elif resume_epoch is not None:
    start_epoch = resume_epoch
else:
    start_epoch = 0

# Fast-forward the schedule when not starting from scratch.
if lr_scheduler is not None and start_epoch > 0:
    lr_scheduler.step(start_epoch)

In [18]:
# Only rank 0 reports the schedule length.
if args.local_rank == 0:
    logging.info('Scheduled epochs: %s', num_epochs)

INFO:root:Scheduled epochs: 200


In [19]:
# Validation images are decoded on load (load_bytes=False) and use no class-map file.
train_dataset = Dataset(train_dir)
val_dataset = Dataset(val_dir, class_map='', load_bytes=False)

In [20]:
# `model` is the generator at this point, so this counts generator parameters only.
param_count = sum(p.numel() for p in model.parameters())
logging.info('Model created, param count: %d', param_count)

INFO:root:Model created, param count: 70208448


In [21]:
# Mixup is applied in the collate function, which is only used on the prefetcher path.
use_collate_mixup = args.prefetcher and args.mixup > 0
collate_fn = None
if use_collate_mixup:
    assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
    collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

In [22]:
# Optionally add test-time pooling to model_raw, then wrap both classifiers in DataParallel.
model_raw, test_time_pool = apply_test_time_pool(model_raw, data_config, args)
model_raw, model_ns = [torch.nn.DataParallel(net).cuda() for net in (model_raw, model_ns)]

In [23]:
# Training loader: augmentation settings from the config plus the optional mixup collate.
train_loader = create_loader(
    train_dataset,
    input_size=data_config['input_size'],
    batch_size=args.batch_size,
    is_training=True,
    # augmentation
    auto_augment=args.aa,
    color_jitter=args.color_jitter,
    re_prob=args.reprob,
    re_mode=args.remode,
    re_count=args.recount,
    re_split=args.resplit,
    num_aug_splits=num_aug_splits,
    interpolation=args.train_interpolation,
    # normalization
    mean=data_config['mean'],
    std=data_config['std'],
    # loading / batching
    use_prefetcher=args.prefetcher,
    num_workers=args.workers,
    distributed=args.distributed,
    collate_fn=collate_fn,
    pin_memory=args.pin_mem,
    use_multi_epochs_loader=args.use_multi_epochs_loader,
)

In [24]:
# With test-time pooling the full image is used; otherwise the configured crop_pct applies.
if test_time_pool:
    crop_pct = 1.0
else:
    crop_pct = data_config['crop_pct']

val_loader = create_loader(
    val_dataset,
    input_size=data_config['input_size'],
    batch_size=args.batch_size,
    is_training=False,
    interpolation=data_config['interpolation'],
    crop_pct=crop_pct,
    mean=data_config['mean'],
    std=data_config['std'],
    use_prefetcher=args.prefetcher,
    num_workers=args.workers,
    pin_memory=args.pin_mem,
    tf_preprocessing=args.tf_preprocessing,
)

In [25]:
# Pick the training criterion from the config. Validation always reports plain
# cross-entropy, which is the very same object as the training loss in the default case.
needs_separate_val_loss = True
if args.jsd:
    assert num_aug_splits > 1  # JSD only valid with aug splits set
    train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
elif args.mixup > 0.:
    # smoothing is handled with mixup label transform
    train_loss_fn = SoftTargetCrossEntropy().cuda()
elif args.smoothing:
    train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
else:
    train_loss_fn = nn.CrossEntropyLoss().cuda()
    needs_separate_val_loss = False
validate_loss_fn = nn.CrossEntropyLoss().cuda() if needs_separate_val_loss else train_loss_fn

# Passed to train_epoch as `loss_traj_fn`: L1 distance between intermediate outputs.
loss_l1_fn = nn.L1Loss().cuda()

In [26]:
eval_metric = args.eval_metric
best_metric = best_epoch = None
saver = None
output_dir = ''
# Only rank 0 creates the timestamped output directory and the checkpoint saver.
if args.local_rank == 0:
    output_base = args.output or './output'
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = '-'.join([timestamp, args.model, str(data_config['input_size'][-1])])
    output_dir = get_outdir(output_base, 'train', exp_name)
    decreasing = eval_metric == 'loss'  # lower is better only for the loss metric
    saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)

In [27]:
# Optional EMA copy of the generator's weights.
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = (
    ModelEma(
        model,
        decay=args.model_ema_decay,
        device='cpu' if args.model_ema_force_cpu else '',
        resume=args.resume)
    if args.model_ema
    else None
)

In [28]:
def train_epoch(epoch, model, model_ns, model_raw, loader, optimizer, loss_fn, loss_traj_fn, args,
               lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
    """Train the perturbation generator ``model`` for one epoch.

    Per batch:
    - Sample a latent ``z`` of shape (batch, args.latent_dim).
    - Add the generator output ``model(z)`` to the images.
    - Classify the perturbed images with ``model_raw``.
    - Classify the clean images with ``model_ns`` under ``torch.no_grad``.

    The generator is optimized on::

        loss = args.lambda_l1 * |loss_fn(raw) - loss_fn(ns)|
             + args.lambda_traj * sum_i loss_traj_fn(traj_raw[i], traj_ns[i])

    where ``traj_*`` is the second element returned by each classifier's forward.

    Args:
        epoch: zero-based epoch index (used for logging and the scheduler update counter).
        model: the generator; its parameters are the ones held by ``optimizer``.
        model_ns: reference classifier, run in eval mode without gradients.
        model_raw: classifier fed the perturbed images, run in eval mode. Its parameters are
            set to ``requires_grad=False`` here. Otherwise backward would accumulate weight
            gradients for it, and nothing ever steps or zeroes them.
        loader: yields ``(inputs, target)``; moved to CUDA unless ``args.prefetcher``.
        optimizer: optimizer over the generator parameters.
        loss_fn: classification loss applied to both classifiers' outputs.
        loss_traj_fn: loss applied to each pair of intermediate outputs.
        args: config. Reads prefetcher, latent_dim, lambda_l1, lambda_traj, distributed,
            world_size, log_interval, local_rank, save_images and recovery_interval.
        lr_scheduler: optional; ``step_update`` is called after every batch.
        saver: optional checkpoint saver used for periodic recovery checkpoints.
        output_dir: where sample images go when ``args.save_images`` is set.
        use_amp: backpropagate through apex ``amp.scale_loss`` (dynamic loss scaling).
        model_ema: optional EMA tracker, updated after every optimizer step.

    Returns:
        OrderedDict with the running-average total loss under ``'loss'``.
    """
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    losses_l1 = AverageMeter()
    losses_traj = AverageMeter()

    model.train()
    model_ns.eval()
    model_raw.eval()
    # model_raw is not optimized: gradients must only flow *through* it to the generator.
    for param in model_raw.parameters():
        param.requires_grad_(False)

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            inputs, target = inputs.cuda(), target.cuda()

        # Fresh latent code per batch (`z` was previously never defined in this function).
        z = torch.randn(inputs.size(0), args.latent_dim, device=inputs.device)
        # Keep the autograd graph through the generator. The former
        # `Variable(inputs_z.data, requires_grad=True)` detached it, so no gradient ever
        # reached the generator's parameters and optimizer.step() was a no-op.
        inputs_z = inputs + model(z)

        output, traj_raw = model_raw(inputs_z)
        with torch.no_grad():
            output_ns, traj_ns = model_ns(inputs)

        loss_ns = loss_fn(output_ns, target)
        loss_raw = loss_fn(output, target)
        loss_l1 = torch.abs(loss_raw - loss_ns)

        if len(traj_raw) != len(traj_ns):
            raise ValueError('model_raw and model_ns returned {} vs {} intermediate outputs'.format(
                len(traj_raw), len(traj_ns)))
        # Tensor accumulator so `.item()` below also works when no intermediates are returned.
        loss_traj = loss_l1.new_zeros(())
        for feat_raw, feat_ns in zip(traj_raw, traj_ns):
            loss_traj = loss_traj + loss_traj_fn(feat_raw, feat_ns)

        loss = args.lambda_l1 * loss_l1 + args.lambda_traj * loss_traj

        if not args.distributed:
            losses_l1.update(loss_l1.item(), inputs.size(0))
            losses_traj.update(loss_traj.item(), inputs.size(0))
            losses_m.update(loss.item(), inputs.size(0))

        optimizer.zero_grad()
        if use_amp:
            # apex AMP only applies its dynamic loss scaling when backward goes through scale_loss.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), inputs.size(0))

            if args.local_rank == 0:
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                    'Loss_l1: {loss_l1.val:>9.6f} ({loss_l1.avg:>6.4f})  '
                    'Loss_traj: {loss_traj.val:>9.6f} ({loss_traj.avg:>6.4f})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        # max(..., 1): a single-batch loader has last_idx == 0
                        100. * batch_idx / max(last_idx, 1),
                        loss=losses_m,
                        loss_l1=losses_l1,
                        loss_traj=losses_traj,
                        batch_time=batch_time_m,
                        rate=inputs.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    import torchvision.utils  # only needed for this optional debug output
                    torchvision.utils.save_image(
                        inputs_z,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(
                model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])

In [29]:
def val_epoch(model_raw, model, val_loader, criterion, args):
    """Evaluate ``model_raw`` on validation images perturbed by the generator ``model``.

    Each batch is perturbed with ``model(z)``, where ``z`` is a freshly sampled latent of
    shape (batch, args.latent_dim), before being classified.

    Args:
        model_raw: classifier whose forward returns ``(output, intermediates)``.
        model: generator whose output is added to the images.
        val_loader: yields ``(inputs, target)``; moved to CUDA unless ``args.prefetcher``.
        criterion: loss reported alongside accuracy.
        args: config. Reads batch_size, latent_dim and prefetcher. The logging cadence
            comes from ``log_freq``, falling back to ``log_interval`` if it is absent or 0.

    Also reads the notebook globals ``data_config``, ``param_count`` and ``crop_pct``.

    Returns:
        OrderedDict with the average ``loss``, top-1/top-5 accuracy and error (percent),
        ``param_count`` in millions, and the data settings used.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    log_freq = getattr(args, 'log_freq', None) or args.log_interval

    model_raw.eval()
    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        warm_inputs = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
        warm_z = torch.randn(warm_inputs.size(0), args.latent_dim, device=warm_inputs.device)
        model_raw(warm_inputs + model(warm_z))
        end = time.time()
        for i, (inputs, target) in enumerate(val_loader):
            # Same condition as train_epoch (args.prefetcher == not args.no_prefetcher).
            if not args.prefetcher:
                target = target.cuda()
                inputs = inputs.cuda()

            # Perturb the *real* batch with a fresh generator sample. Previously the
            # warm-up noise tensor was classified on every iteration instead.
            z = torch.randn(inputs.size(0), args.latent_dim, device=inputs.device)
            inputs_z = inputs + model(z)

            # compute output (intermediate outputs are not needed for validation)
            output, _ = model_raw(inputs_z)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(val_loader), batch_time=batch_time,
                        rate_avg=inputs.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    # 'loss' is included so that eval_metric == 'loss' can be looked up by the caller.
    results = OrderedDict(
        loss=round(losses.avg, 4),
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results

In [30]:
# Epoch loop: train every epoch; validate, summarize and checkpoint on a fixed cadence.
EVAL_INTERVAL = 5      # validate (and checkpoint) every N epochs
SUMMARY_INTERVAL = 10  # append to summary.csv every N epochs; keep it a multiple of EVAL_INTERVAL
# Tracked explicitly: the saver sets best_metric at the first eval epoch, before the first
# summary is written, so `best_metric is None` never produced a CSV header.
summary_needs_header = True

for epoch in range(start_epoch, num_epochs):
    if args.distributed:
        train_loader.sampler.set_epoch(epoch)

    train_metrics = train_epoch(
        epoch, model, model_ns, model_raw, train_loader, optimizer, train_loss_fn, loss_l1_fn, args,
        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
        use_amp=use_amp, model_ema=model_ema)

    # Reset every epoch so nothing below reads stale (or undefined) metrics on non-eval epochs.
    eval_metrics = None
    if (epoch + 1) % EVAL_INTERVAL == 0:
        eval_metrics = val_epoch(model_raw, model, val_loader, validate_loss_fn, args)

        if model_ema is not None and not args.model_ema_force_cpu:
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
            # model_ema tracks the generator, so evaluate it the same way and report its metrics.
            eval_metrics = val_epoch(model_raw, model_ema.ema, val_loader, validate_loss_fn, args)

    if lr_scheduler is not None:
        # step LR for next epoch; a metric is only available on eval epochs.
        # NOTE(review): a plateau-style scheduler needs a metric every epoch - confirm config.
        epoch_metric = eval_metrics[eval_metric] if eval_metrics is not None else None
        lr_scheduler.step(epoch + 1, epoch_metric)

    if eval_metrics is not None and (epoch + 1) % SUMMARY_INTERVAL == 0:
        update_summary(
            epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
            write_header=summary_needs_header)
        summary_needs_header = False

    if saver is not None and eval_metrics is not None:
        # save proper checkpoint with eval metric
        save_metric = eval_metrics[eval_metric]
        best_metric, best_epoch = saver.save_checkpoint(
            model, optimizer, args,
            epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)







DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'sRGB' 41 1
DEBUG:PIL.PngImagePlugin:STREAM b'gAMA' 54 4
DEBUG:PIL.PngImagePlugin:STREAM b'cHRM' 70 32
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 114 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 135 65401




























KeyboardInterrupt: 

In [39]:
# Debug peek: print the input tensor of the first training batch (no-op for an empty loader).
first_batch = next(enumerate(train_loader), None)
if first_batch is not None:
    batch_idx, (inputs, target) = first_batch
    print(inputs)

tensor([[[[ 0.0056,  0.0056,  0.0056,  ...,  0.4851,  0.4337,  0.4508],
          [ 0.0056,  0.0056,  0.0056,  ...,  0.4851,  0.4337,  0.4337],
          [ 0.0056,  0.0056,  0.0056,  ...,  0.4679,  0.4679,  0.4679],
          ...,
          [ 1.4783,  1.4783,  1.4783,  ...,  2.0263,  2.0434,  2.0434],
          [ 1.4440,  1.4440,  1.4440,  ...,  2.0092,  2.0092,  2.0092],
          [ 1.4440,  1.4440,  1.4440,  ...,  1.9578,  1.9578,  1.9578]],

         [[ 1.5182,  1.5182,  1.5182,  ...,  1.7458,  1.7458,  1.7108],
          [ 1.5182,  1.5182,  1.5182,  ...,  1.7458,  1.7458,  1.7458],
          [ 1.5182,  1.5182,  1.5182,  ...,  1.7633,  1.7633,  1.7633],
          ...,
          [-1.8782, -1.8782, -1.8782,  ..., -2.0357, -2.0357, -2.0357],
          [-1.8782, -1.8782, -1.8782,  ..., -2.0357, -2.0357, -2.0357],
          [-1.8782, -1.8782, -1.8782,  ..., -2.0357, -2.0357, -2.0357]],

         [[ 2.6400,  2.6400,  2.6400,  ...,  2.6400,  2.6400,  2.6400],
          [ 2.6400,  2.6400,  

In [40]:
# Peek again: each new pass over train_loader starts a fresh iteration, so the printed
# batch differs from the previous cell's output.
for batch_idx, batch in enumerate(train_loader):
    inputs, target = batch
    print(inputs)
    break  # one batch is enough

tensor([[[[ 2.0263,  2.0263,  2.0263,  ...,  1.9064,  1.8379,  1.7865],
          [ 2.0263,  2.0263,  2.0263,  ...,  1.7523,  1.7180,  1.7009],
          [ 2.0263,  2.0263,  2.0263,  ...,  1.4098,  1.5125,  1.5639],
          ...,
          [-0.1486, -0.1486, -0.1314,  ..., -0.4226, -0.4568, -0.4739],
          [-0.1657, -0.1657, -0.1657,  ..., -0.3712, -0.4397, -0.4739],
          [-0.1657, -0.1657, -0.1999,  ..., -0.3541, -0.4397, -0.4739]],

         [[ 2.2010,  2.2010,  2.2010,  ...,  0.0476,  0.1001,  0.1176],
          [ 2.2010,  2.2010,  2.2010,  ...,  0.0126,  0.0826,  0.1001],
          [ 2.2010,  2.2010,  2.2010,  ..., -0.1099, -0.0049,  0.0126],
          ...,
          [ 0.0126,  0.0126,  0.0301,  ..., -0.3025, -0.3025, -0.2675],
          [ 0.0126,  0.0126,  0.0126,  ..., -0.2675, -0.3025, -0.3025],
          [ 0.0126,  0.0126, -0.0049,  ..., -0.2675, -0.3025, -0.3200]],

         [[ 2.4134,  2.4134,  2.4134,  ...,  1.4374,  1.3851,  1.3851],
          [ 2.4134,  2.4134,  

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/cutz/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/cutz/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/cutz/thesis/timm/models/efficientnet.py", line 403, in forward
    x, forward_list = self.forward_features(x)
  File "/home/cutz/thesis/timm/models/efficientnet.py", line 394, in forward_features
    x = block(x)
  File "/home/cutz/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/cutz/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/home/cutz/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/cutz/thesis/timm/models/efficientnet_blocks.py", line 267, in forward
    x = self.act2(x)
  File "/home/cutz/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/cutz/thesis/timm/models/layers/activations_me.py", line 54, in forward
    return SwishJitAutoFn.apply(x)
  File "/home/cutz/thesis/timm/models/layers/activations_me.py", line 37, in forward
    return swish_jit_fwd(x)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/home/cutz/thesis/timm/models/layers/activations_me.py", line 19, in swish_jit_fwd
@torch.jit.script
def swish_jit_fwd(x):
    return x.mul(torch.sigmoid(x))
           ~~~~~ <--- HERE
RuntimeError: CUDA out of memory. Tried to allocate 136.00 MiB (GPU 0; 31.72 GiB total capacity; 30.20 GiB already allocated; 15.56 MiB free; 30.72 GiB reserved in total by PyTorch)

