From f2e6136ee9f9cd2a3ba57f33ad7e92187a9ab783 Mon Sep 17 00:00:00 2001 From: eugenium Date: Tue, 29 Jan 2019 00:56:26 -0500 Subject: [PATCH] init --- README.md | 92 ++++ cifar.py | 309 +++++++++++++ imagenet.py | 487 ++++++++++++++++++++ imagenet_refactored/imagenet.py | 311 +++++++++++++ imagenet_refactored/imagenet_greedy.py | 381 ++++++++++++++++ imagenet_refactored/utils.py | 110 +++++ imagenet_single_layer.py | 601 +++++++++++++++++++++++++ model_greedy.py | 244 ++++++++++ utils.py | 100 ++++ 9 files changed, 2635 insertions(+) create mode 100644 README.md create mode 100644 cifar.py create mode 100644 imagenet.py create mode 100644 imagenet_refactored/imagenet.py create mode 100755 imagenet_refactored/imagenet_greedy.py create mode 100755 imagenet_refactored/utils.py create mode 100644 imagenet_single_layer.py create mode 100644 model_greedy.py create mode 100644 utils.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..c219210 --- /dev/null +++ b/README.md @@ -0,0 +1,92 @@ +# Greedy Layerwise CNN + +This is preliminary research code. 
+ +Code for experiments on greedy supervised layerwise CNNs. + +## Imagenet +Imagenet experiments for 1-hidden layer use the standalone imagenet_single_layer.py + +Imagenet experiments for k=2+ can be run with imagenet.py + +Note that k in the paper corresponds to nlin + 1 in the code (e.g. k=3 is run with --nlin 2) + + +To obtain the results for Imagenet + +k=3 +``` +python imagenet.py IMAGENET_DIR -j THREADS --ncnn 8 --nlin 2 + +``` + +k=2 + +``` +python imagenet.py IMAGENET_DIR -j THREADS --ncnn 8 --nlin 1 + +``` + +k=1 model +``` +python imagenet_single_layer.py IMAGENET_DIR -j THREADS --ncnn 8 + +``` +### VGG-11 + +The VGG-11 model was trained with a new, refactored and more modular codebase that differs from the one used for the above models, and is thus run from the standalone directory +imagenet_refactored/ + +To train the VGG-11 with k=3 + +``` +python imagenet_greedy.py IMAGENET_DIR -j THREADS --arch vgg11_bn --half --dynamic-loss-scale + +``` +To train the baseline: + +``` +python imagenet.py IMAGENET_DIR -j THREADS --arch vgg11_bn --half --dynamic-loss-scale + +``` + +### Linear Separability +Linear separability experiments are in the linear_separability folder. A notebook is included that produces the plots. To run different settings: + + +This will create and train a model with K non-linearities and F features; the model is stored in a checkpoint. 
+ +``` +python cifar.py --ncnn 5 --nlin K --feature_size F +``` + +This will use the model "filename", to train probes on top of these at layer "j" +``` +python train_lr.py filename j + +``` + +### CIFAR experiments +CIFAR experiments can be reproduced using cifar.py + +The CIFAR-10 models can be trained: + +k=3 (~91.7) +``` +python cifar.py --ncnn 4 --nlin 2 --feature_size 128 --down [1] --bn 1 + +``` + +k=2 (~90.4) + +``` +python cifar.py --ncnn 4 --nlin 1 --feature_size 128 --down [1] --bn 1 + +``` + +k=1 (~88.3) +``` +python cifar.py --ncnn 5 --nlin 0 --feature_size 256 + +``` + diff --git a/cifar.py b/cifar.py new file mode 100644 index 0000000..b1652f3 --- /dev/null +++ b/cifar.py @@ -0,0 +1,309 @@ +"Greedy layerwise cifar training" +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.backends.cudnn as cudnn +import numpy as np + +import torchvision +import torchvision.transforms as transforms + +import os +import argparse + +from model_greedy import * +from torch.autograd import Variable + +from utils import progress_bar + +from random import randint +import datetime +import json + + + +parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') +parser.add_argument('--lr', default=0.1, type=float, help='learning rate') +parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') +parser.add_argument('--ncnn', default=5,type=int, help='depth of the CNN') +parser.add_argument('--nepochs', default=50,type=int, help='number of epochs') +parser.add_argument('--epochdecay', default=15,type=int, help='number of epochs') +parser.add_argument('--avg_size', default=16,type=int, help='size of averaging ') +parser.add_argument('--feature_size', default=128,type=int, help='feature size') +parser.add_argument('--ds-type', default=None, help="type of downsampling. Defaults to old block_conv with psi. 
Options 'psi', 'stride', 'avgpool', 'maxpool'") +parser.add_argument('--nlin', default=2,type=int, help='nlin') +parser.add_argument('--ensemble', default=1,type=int,help='compute ensemble') +parser.add_argument('--name', default='',type=str,help='name') +parser.add_argument('--batch_size', default=128,type=int,help='batch size') +parser.add_argument('--bn', default=0,type=int,help='use batchnorm') +parser.add_argument('--debug', default=0,type=int,help='debug') +parser.add_argument('--debug_parameters', default=0,type=int,help='verification that layers frozen') +parser.add_argument('-j', '--workers', default=6, type=int, metavar='N', + help='number of data loading workers (default: 4)') +parser.add_argument('--width_aux', default=128,type=int,help='auxillary width') +parser.add_argument('--down', default='[1,2]', type=str, + help='layer at which to downsample') +parser.add_argument('--seed', default=None, help="Fixes the CPU and GPU random seeds to a specified number") + +args = parser.parse_args() +opts = vars(args) +args.ensemble = args.ensemble>0 +args.bn = args.bn > 0 +args.debug_parameters = args.debug_parameters > 0 + +if args.debug: + args.nepochs = 1 # we run just one epoch per greedy layer training in debug mode + +downsample = list(np.array(json.loads(args.down))) +in_size=32 +mode=0 + +if args.seed is not None: + seed = int(args.seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +save_name = 'layersize_'+str(args.feature_size)+'width_' \ + + str(args.width_aux) + 'depth_' + str(args.nlin) + 'ds_type_' + str(args.ds_type) +'down_' + args.down +#logging +time_stamp = str(datetime.datetime.now().isoformat()) + +name_log_dir = ''.join('{}{}-'.format(key, val) for key, val in sorted(opts.items()))+time_stamp +name_log_dir = 'runs/'+name_log_dir + +name_log_txt = time_stamp + save_name + str(randint(0, 1000)) + args.name +debug_log_txt = name_log_txt + '_debug.log' +name_save_model = name_log_txt + '.t7' +name_log_txt=name_log_txt +'.log' 
+ +with open(name_log_txt, "a") as text_file: + print(opts, file=text_file) + + +use_cuda = torch.cuda.is_available() +start_epoch = 0 # start from epoch 0 or last checkpoint epoch + +# Data +print('==> Preparing data..') +transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), +]) + +transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), +]) + +trainset_class = torchvision.datasets.CIFAR10(root='.', train=True, download=True,transform=transform_train) +trainloader_classifier = torch.utils.data.DataLoader(trainset_class, batch_size=args.batch_size, shuffle=True, num_workers=args.workers) +testset = torchvision.datasets.CIFAR10(root='.', train=False, download=True, transform=transform_test) +testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=2) + +# Model + +print('==> Building model..') +n_cnn=args.ncnn +#if args.ds_type is None: + # block_conv_ = block_conv +#else: + # from functools import partial + # block_conv_ = partial(ds_conv, ds_type=args.ds_type) +net = greedyNet(block_conv, 1, args.feature_size, downsample=downsample, batchnorm=args.bn) + + +if args.width_aux: + num_feat = args.width_aux +else: + num_feat = args.feature_size + +net_c = auxillary_classifier(avg_size=args.avg_size, in_size=in_size, + n_lin=args.nlin, feature_size=num_feat, + input_features=args.feature_size, batchn=args.bn) + + +with open(name_log_txt, "a") as text_file: + print(net, file=text_file) + print(net_c, file=text_file) + +net = torch.nn.DataParallel(nn.Sequential(net,net_c)).cuda() +cudnn.benchmark = True + +criterion_classifier = nn.CrossEntropyLoss() + +net.module[0].unfreezeGradient(0) +optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr, momentum=0.9, 
weight_decay=5e-4) +criterion = nn.CrossEntropyLoss() + + +def train_classifier(epoch,n): + print('\nSubepoch: %d' % epoch) + net.train() + for k in range(n): + net.module[0].blocks[k].eval() + + + if args.debug_parameters: + #This is used to verify that early layers arent updated + import copy + #store all parameters on cpu as numpy array + net_cpu = copy.deepcopy(net).cpu() + net_cpu_dict = net_cpu.module[0].state_dict() + with open(debug_log_txt, "a") as text_file: + print('n: %d'%n) + for param in net_cpu_dict.keys(): + net_cpu_dict[param]=net_cpu_dict[param].numpy() + print("parameter stored on cpu as numpy: %s "%(param),file=text_file) + + train_loss = 0 + correct = 0 + total = 0 + for batch_idx, (inputs, targets) in enumerate(trainloader_classifier): + if use_cuda: + inputs, targets = inputs.cuda(), targets.cuda() + + optimizer.zero_grad() + inputs, targets = Variable(inputs), Variable(targets) + outputs = net.forward([inputs,n]) + + loss = criterion_classifier(outputs, targets) + loss.backward() + optimizer.step() + loss_pers=0 + + if args.debug_parameters: + + s_dict = net.module[0].state_dict() + with open(debug_log_txt, "a") as text_file: + print("iteration %d" % (batch_idx), file=text_file) + for param in s_dict.keys(): + diff = np.sum((s_dict[param].cpu().numpy()-net_cpu_dict[param])**2) + print("n: %d parameter: %s size: %s changed by %.5f" % (n,param,net_cpu_dict[param].shape,diff),file=text_file) + + train_loss += loss.data[0] + _, predicted = torch.max(outputs.data, 1) + total += targets.size(0) + correct += predicted.eq(targets.data).cpu().sum() + + progress_bar(batch_idx, len(trainloader_classifier), 'Loss: %.3f | Acc: %.3f%% (%d/%d) | losspers: %.3f' + % (train_loss/(batch_idx+1), 100.*float(correct)/float(total), correct, total,loss_pers)) + + acc = 100.*float(correct)/float(total) + return acc + +all_outs = [[] for i in range(args.ncnn)] + +def test(epoch,n,ensemble=False): + global acc_test_ensemble + all_targs = [] + net.eval() + test_loss = 
0 + correct = 0 + total = 0 + + all_outs[n] = [] + for batch_idx, (inputs, targets) in enumerate(testloader): + if use_cuda: + inputs, targets = inputs.cuda(), targets.cuda() + + inputs, targets = Variable(inputs, volatile=True), Variable(targets) + outputs = net([inputs,n]) + + loss = criterion_classifier(outputs, targets) + + test_loss += loss.data[0] + _, predicted = torch.max(outputs.data, 1) + total += targets.size(0) + correct += predicted.eq(targets.data).cpu().sum() + + progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' + % (test_loss/(batch_idx+1), 100.*float(correct)/float(total), correct, total)) + + if args.ensemble: + all_outs[n].append(outputs.data.cpu()) + all_targs.append(targets.data.cpu()) + acc = 100. * float(correct) / float(total) + + if ensemble: + all_outs[n] = torch.cat(all_outs[n]) + all_targs = torch.cat(all_targs) + #This is all on cpu so we dont care + weight = 2 ** (np.arange(n + 1)) / sum(2 ** np.arange(n + 1)) + total_out = torch.zeros((total,10)) + + #very lazy + for i in range(n+1): + total_out += float(weight[i])*all_outs[i] + + + _, predicted = torch.max(total_out, 1) + correct = predicted.eq(all_targs).sum() + acc_ensemble = 100*float(correct)/float(total) + print('Acc_ensemble: %.2f'%acc_ensemble) + if ensemble: + return acc,acc_ensemble + else: + return acc + +i=0 +num_ep = args.nepochs + +for n in range(n_cnn): + net.module[0].unfreezeGradient(n) + lr = args.lr*5.0# we run at epoch 0 the lr reset to remove non learnable param + + for epoch in range(0, num_ep): + i=i+1 + print('n: ',n) + if epoch % args.epochdecay == 0: + lr=lr/5.0 + to_train = list(filter(lambda p: p.requires_grad, net.parameters())) + optimizer = optim.SGD(to_train, lr=lr, momentum=0.9, weight_decay=5e-4) + print('new lr:',lr) + + acc_train = train_classifier(epoch,n) + if args.ensemble: + acc_test,acc_test_ensemble = test(epoch,n,args.ensemble) + + with open(name_log_txt, "a") as text_file: + print("n: {}, epoch {}, train {}, test 
{},ense {} " + .format(n,epoch,acc_train,acc_test,acc_test_ensemble), file=text_file) + else: + acc_test = test(epoch, n) + with open(name_log_txt, "a") as text_file: + print("n: {}, epoch {}, train {}, test {}, ".format(n,epoch,acc_train,acc_test), file=text_file) + + if args.debug: + break + + + if args.down and n in downsample: + args.avg_size = int(args.avg_size/2) + in_size = int(in_size/2) + args.feature_size = int(args.feature_size*2) + args.width_aux = args.width_aux * 2 + + if args.width_aux: + num_feat = args.width_aux + else: + num_feat = args.feature_size + + net_c = None + if n < n_cnn-1: + net_c = auxillary_classifier(avg_size=args.avg_size, in_size=in_size, + n_lin=args.nlin, feature_size=args.width_aux, + input_features=args.feature_size, batchn=args.bn).cuda() + net.module[0].add_block(n in downsample) + net = torch.nn.DataParallel(nn.Sequential(net.module[0], net_c)).cuda() + +state_final = { + 'net': net, + 'acc_test': acc_test, + 'acc_train': acc_train, + } +torch.save(state_final,save_name) diff --git a/imagenet.py b/imagenet.py new file mode 100644 index 0000000..7f6350f --- /dev/null +++ b/imagenet.py @@ -0,0 +1,487 @@ +import argparse +import os +import shutil +import time +from collections import OrderedDict +import torch +import torch._utils +try: + torch._utils._rebuild_tensor_v2 +except AttributeError: + def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): + tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) + tensor.requires_grad = requires_grad + tensor._backward_hooks = backward_hooks + return tensor + torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 + + +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models 
as models +from model_greedy import * + + +import numpy as np + + + + +from random import randint +import datetime + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('data', metavar='DIR', + help='path to dataset') + +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 4)') +parser.add_argument('--epochs', default=90, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('--world-size', default=1, type=int, + help='number of distributed processes') +parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='gloo', type=str, + help='distributed backend') + + +parser.add_argument('--ncnn', default=8,type=int, help='depth of the auxiliary CNN') +parser.add_argument('--bn', default=1,type=int, help='turn off/on batchnorm') +parser.add_argument('--nepochs', default=45,type=int, 
help='number of epochs') +parser.add_argument('--epochdecay', default=20,type=int, help='interval between lr decay') +parser.add_argument('--avg_size', default=7,type=int, help='size of the averaging') +parser.add_argument('--feature_size', default=128,type=int, help='width') +parser.add_argument('--nlin', default=2,type=int, help='number of non linearities in the auxillary') +parser.add_argument('--ds', default=2,type=int, help='initial downsampling') +parser.add_argument('--ensemble', default=1,type=int,help='ensemble') # not implemented yet +parser.add_argument('--name', default='',type=str,help='name') +parser.add_argument('--prog', default=0,type=int, help='increase width of auxillary at downsampling') +parser.add_argument('--debug', default=0,type=int,help='debugging') +parser.add_argument('--large_size_images', default=2,type=int,help='use small images for faster testing') +parser.add_argument('--n_resume', default=0,type=int,help='which layer we resume') +parser.add_argument('--resume_epoch', default=0,type=int,help='which epoch we resume') +parser.add_argument('--fixed_feat', default=512,type=int,help='auxillary width ') +parser.add_argument('--down', default=1,type=int,help='use downsampling') +parser.add_argument('--save_folder', default='.',type=str,help='down') + +args = parser.parse_args() +best_prec1 = 0 + +time_stamp = str(datetime.datetime.now().isoformat()) + +name_log_txt = time_stamp + str(randint(0, 1000)) + args.name + +name_log_txt=name_log_txt +'.log' + +args.ensemble = args.ensemble>0 +args.prog = args.prog >0 +args.debug = args.debug > 0 +args.bn = args.bn > 0 + +downsample = [1,3,5,7] + +args.down = args.down > 0 +def main(): + global args, best_prec1 + args = parser.parse_args() + + if args.large_size_images==0: + N_img = 112 + N_img_scale = 128 + print('using 112') + elif args.large_size_images==1: + N_img = 160 + N_img_scale = 182 + print('using 160') + elif args.large_size_images ==2: + N_img = 224 + N_img_scale= 256 + + in_size = 
N_img // args.ds + + with open(name_log_txt, "a") as text_file: + print(args, file=text_file) + + n_cnn = args.ncnn + + model = greedyNet(block_conv, 1, feature_size=args.feature_size, downsampling=args.ds, + downsample=downsample, batchnorm=args.bn) + + + if args.fixed_feat: + num_feat = args.fixed_feat + else: + num_feat = args.feature_size + model_c = auxillary_classifier(avg_size=args.avg_size, in_size=N_img // args.ds, + n_lin=args.nlin, feature_size=num_feat, + input_features=args.feature_size, batchn=args.bn, num_classes=1000) + + with open(name_log_txt, "a") as text_file: + print(model, file=text_file) + print(model_c, file=text_file) + + + model = torch.nn.DataParallel(nn.Sequential(model,model_c)).cuda() + model.module[0].unfreezeGradient(0) + + model_c = None + + + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda() + to_train = list(filter(lambda p: p.requires_grad, model.parameters())) #+ list(model_c.parameters()) + optimizer = torch.optim.SGD(to_train, args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(N_img), + + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + + + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True, + num_workers=args.workers, pin_memory=True, sampler=None) + + + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(N_img_scale), + transforms.CenterCrop(N_img), + transforms.ToTensor(), + normalize, + ])), + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + 
if args.evaluate: + validate(val_loader, model, criterion) + return + + + if (args.resume): + def load_model(): + model_dict = torch.load(args.save_folder+'/'+args.resume + '_model.t7') + + for key in list(model_dict.keys()): + if key[0:8]=='module.1': + model_dict.pop(key,None) + else: + model_dict = OrderedDict((key[9:] if k == key else k, v) for k, v in model_dict.items()) + + # 1. filter out unnecessary keys + sub_dict = {k: v for k, v in model.module[0].items() if k in model_dict} + # 2. overwrite entries in the existing state dict + model_dict.update(sub_dict) + model.module[0].load_state_dict(sub_dict) + load_model() + + correct_all = np.zeros(len(train_dataset)) + + + num_ep = args.nepochs + + + + for n in range(n_cnn): + # torch.save(model_c.state_dict(), name_log_txt + '_model_c.t7') + model.module[0].unfreezeGradient(n) + lr = args.lr * 10.0 + + for epoch in range(0, num_ep): + if n > 0 and not args.debug and epoch % 3==0: + torch.save(model.state_dict(), args.save_folder+'/'+name_log_txt + '_current_model.t7') + if epoch % args.epochdecay == 0: + lr = lr/10.0 + to_train = list(filter(lambda p: p.requires_grad, model.parameters())) + optimizer = torch.optim.SGD(to_train, lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + if (args.resume and args.resume_epoch>0 and n==args.n_resume): + if epoch < args.resume_epoch: + continue + if epoch == args.resume_epoch: + name = args.resume + '_current_model.t7' + model_dict = torch.load(name) + model.load_state_dict(model_dict) + if (args.resume and n best_prec1 + best_prec1 = max(prec1, best_prec1) + + with open(name_log_txt, "a") as text_file: + print("n: {}, epoch {}, train top1:{}(top5:{}), test top1:{} (top5:{}), top1ens:{} top5ens:{}" + .format(n, epoch, top1train, top5train, top1test,top5test,top1ens,top5ens), file=text_file) + if (args.resume and n500: + break + + if i % args.print_freq == 0: + print('Epoch: [{0}][{1}/{2}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data 
{data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + epoch, i, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses, top1=top1, top5=top5)) + + return top1.avg,top5.avg + + +all_outs = [[] for i in range(args.ncnn)] +def validate(val_loader, model, criterion, n): + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + all_targs = [] + # switch to evaluate mode + model.eval() + # model_c.eval() + + end = time.time() + all_outs[n] = [] + + with torch.no_grad(): + total = 0 + for i, (input, target) in enumerate(val_loader): + target = target.cuda(non_blocking=True) + input = input.cuda(non_blocking=True) + input_var = torch.autograd.Variable(input) + target_var = torch.autograd.Variable(target) + + # compute output + output = model([input_var, n]) + + + loss = criterion(output, target_var) + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + losses.update(loss.data[0], input.size(0)) + top1.update(prec1[0], input.size(0)) + top5.update(prec5[0], input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(val_loader), batch_time=batch_time, loss=losses, + top1=top1, top5=top5)) + if args.ensemble: + all_outs[n].append(F.softmax(output).data.cpu()) + all_targs.append(target_var.data.cpu()) + total += input_var.size(0) + print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + + if args.ensemble: + all_outs[n] = torch.cat(all_outs[n]) + all_targs = torch.cat(all_targs) + #This is all on cpu 
so we dont care + + weight = 2 ** (np.arange(n + 1)) / sum(2 ** np.arange(n + 1)) + total_out = torch.zeros([total, 1000]) + + # very lazy + for i in range(n + 1): + total_out += float(weight[i]) * all_outs[i] + + prec1, prec5 = accuracy(total_out, all_targs, topk=(1, 5)) + + print(' * Ensemble Prec@1 {top1:.3f} Prec@5 {top5:.3f}' + .format(top1=prec1[0], top5=prec5[0])) + return top1.avg,top5.avg,prec1[0],prec5[0] + return top1.avg, top5.avg,-1,-1 + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def adjust_learning_rate(optimizer, epoch): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = args.lr * (0.1 ** (epoch // 30)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + main() diff --git a/imagenet_refactored/imagenet.py b/imagenet_refactored/imagenet.py new file mode 100644 index 0000000..efcafca --- /dev/null +++ b/imagenet_refactored/imagenet.py @@ -0,0 +1,311 @@ +import argparse +import os +import shutil +import time +from collections import OrderedDict +import torch +import torch.optim as 
optimizer +import torch._utils +from functools import partial +import itertools + + +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +from torchvision import models +from utils import AverageMeter, accuracy, convnet_half_precision,DataParallelSpecial +import json +import numpy as np +model_names = sorted(name for name in models.__dict__ + if name.islower() and not name.startswith("__") + and callable(models.__dict__[name])) + + + +from random import randint +import datetime +import torch.nn.functional as F +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--arch', '-a', metavar='ARCH', default='vgg11_bn', + choices=model_names, + help='model architecture: ' + + ' | '.join(model_names) + + ' (default: resnet18)') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 4)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--ds-type', default='maxpool', help="type of downsampling" + "Options 'psi', 'stride', 
'avgpool', 'maxpool'") +parser.add_argument('--nepochs', default=45,type=int, help='number of epochs') +parser.add_argument('--epochdecay', default=20,type=int, help='number of epochs') +parser.add_argument('--name', default='',type=str,help='name') +parser.add_argument('--debug', default=0,type=int,help='debugging') +parser.add_argument('--start_epoch', default=1,type=int,help='which n we resume') +parser.add_argument('--save_folder', default='.',type=str,help='folder to save') +#related to mixed precision +parser.add_argument('--half', dest='half', action='store_true', + help='use half-precision(16-bit) ') +parser.add_argument('--static-loss-scale', type=float, default=1, + help='Static loss scale, positive power of 2 values can improve fp16 convergence.') +parser.add_argument('--dynamic-loss-scale', action='store_true', + help='Use dynamic loss scaling. If supplied, this argument supersedes ' + + '--static-loss-scale.') +args = parser.parse_args() +best_prec1 = 0 + + +################# Setup arguments +args.debug = args.debug > 0 + + +if args.half: + from fp16 import FP16_Optimizer + from fp16.fp16util import BN_convert_float + if args.half: + assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." 
+##################### Logs +time_stamp = str(datetime.datetime.now().isoformat()) +name_log_txt = time_stamp + str(randint(0, 1000)) + args.name +name_log_txt=name_log_txt +'.log' + +with open(name_log_txt, "a") as text_file: + print(args, file=text_file) + +def main(): + global args, best_prec1 + args = parser.parse_args() + + N_img = 224 + N_img_scale= 256 + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(N_img), + + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True, + num_workers=args.workers, pin_memory=True, sampler=None) + + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(N_img_scale), + transforms.CenterCrop(N_img), + transforms.ToTensor(), + normalize, + ])), + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + +#### #### To simplify data parallelism we make an nn module with multiple outs + print("=> creating model '{}'".format(args.arch)) + model = models.__dict__[args.arch]() + + with open(name_log_txt, "a") as text_file: + print(model, file=text_file) + model = nn.DataParallel(model) + model = model.cuda() + if args.half: + model = model.half() + model = BN_convert_float(model) + ############### Initialize all + num_ep = args.nepochs + + +############## Resume if we need to resume + if (args.resume): + name = args.resume + model_dict = torch.load(name) + model.load_state_dict(model_dict) + print('model loaded') +######################### Lets do the training + criterion = nn.CrossEntropyLoss().cuda() + + lr = args.lr + to_train = 
itertools.chain(model.parameters()) + optim = optimizer.SGD(to_train, lr=lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + if args.half: + optim = FP16_Optimizer(optim, + static_loss_scale=args.static_loss_scale, + dynamic_loss_scale=args.dynamic_loss_scale, + dynamic_loss_args={'scale_window': 1000}) + + for epoch in range(args.start_epoch,num_ep+1): + # Make sure we set the bn right + model.train() + + #For each epoch let's store each layer individually + batch_time_total = AverageMeter() + data_time = AverageMeter() + lossm = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + + if epoch % args.epochdecay == 0: + lr = lr / 10.0 + to_train = itertools.chain(model.parameters()) + optim = optimizer.SGD(to_train, lr=lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + if args.half: + optim = FP16_Optimizer(optim, + static_loss_scale=args.static_loss_scale, + dynamic_loss_scale=args.dynamic_loss_scale, + dynamic_loss_args={'scale_window': 1000}) + end = time.time() + + for i, (inputs, targets) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + targets = targets.cuda(non_blocking = True) + inputs = inputs.cuda(non_blocking = True) + inputs = torch.autograd.Variable(inputs) + targets = torch.autograd.Variable(targets) + if args.half: + inputs = inputs.half() + + end = time.time() + + # Forward + optim.zero_grad() + outputs = model(inputs) + + loss = criterion(outputs, targets) + # update + if args.half: + optim.backward(loss) + else: + loss.backward() + + optim.step() + + # measure accuracy and record loss + # measure elapsed time + batch_time_total.update(time.time() - end) + end = time.time() + prec1, prec5 = accuracy(outputs.data, targets, topk=(1, 5)) + lossm.update(float(loss.data[0]), float(inputs.size(0))) + top1.update(float(prec1[0]), float(inputs.size(0))) + top5.update(float(prec5[0]), float(inputs.size(0))) + + + if i % args.print_freq == 0: + print('Epoch: 
[{0}][{1}/{2}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + epoch, i, len(train_loader), batch_time=batch_time_total, + data_time=data_time, loss=lossm, top1=top1, top5=top5)) + + + if args.debug and i > 50: + break + + + top1test, top5test = validate(val_loader, model, criterion, epoch) + with open(name_log_txt, "a") as text_file: + print("lr: {}, epoch {}, train top1:{}(top5:{}), " + "test top1:{} (top5:{})" + .format(lr, epoch, top1.avg, top5.avg, + top1test, top5test), file=text_file) + + #####Checkpoint + if not args.debug: + torch.save(model.state_dict(), args.save_folder + '/' + \ + name_log_txt + '_current_model.t7') + + + ############Save the final model + torch.save(model.state_dict(), args.save_folder + '/' + name_log_txt + '_model.t7') + + +def validate(val_loader, model, criterion, epoch): + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + model.eval() + + end = time.time() + + with torch.no_grad(): + total = 0 + for i, (input, target) in enumerate(val_loader): + target = target.cuda(non_blocking=True) + input = input.cuda(non_blocking=True) + input = torch.autograd.Variable(input) + target = torch.autograd.Variable(target) + if args.half: + input = input.half() + + # compute output + output = model(input) + + + loss = criterion(output, target) + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + losses.update(float(loss.data[0]), float(input.size(0))) + top1.update(float(prec1[0]), float(input.size(0))) + top5.update(float(prec5[0]), float(input.size(0))) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} 
({batch_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(val_loader), batch_time=batch_time, loss=losses, + top1=top1, top5=top5)) + + total += input.size(0) + print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + + return top1.avg, top5.avg + + +if __name__ == '__main__': + main() diff --git a/imagenet_refactored/imagenet_greedy.py b/imagenet_refactored/imagenet_greedy.py new file mode 100755 index 0000000..1aeca86 --- /dev/null +++ b/imagenet_refactored/imagenet_greedy.py @@ -0,0 +1,381 @@ +import argparse +import os +import shutil +import time +from collections import OrderedDict +import torch +import torch.optim as optim +import torch._utils +from functools import partial +import itertools + +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import models +from utils import AverageMeter, accuracy, convnet_half_precision,DataParallelSpecial +import json +import numpy as np +model_names = sorted(name for name in models.__dict__ + if name.islower() and not name.startswith("__") + and callable(models.__dict__[name])) + + + +from random import randint +import datetime +import torch.nn.functional as F +parser = argparse.ArgumentParser(description='PyTorch ImageNet Greedy Training') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--arch', '-a', metavar='ARCH', default='vgg_11bn', + choices=model_names, + help='model architecture: ' + + ' | '.join(model_names) + + ' (default: resnet18)') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 4)') +parser.add_argument('-b', 
'--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--nepochs', default=45,type=int, help='number of epochs') +parser.add_argument('--epochdecay', default=20,type=int, help='number of epochs') +parser.add_argument('--nlin', default=2,type=int, help='nlin') +parser.add_argument('--ensemble', default=1,type=int,help='ensemble') # not implemented yet +parser.add_argument('--name', default='',type=str,help='name') +parser.add_argument('--debug', default=0,type=int,help='debugging') +parser.add_argument('--large_size_images', default=1,type=int,help='use small image for dev') +parser.add_argument('--start_epoch', default=1,type=int,help='which n we resume') +parser.add_argument('--resume_epoch', default=0,type=int,help='which epoch we resume') +parser.add_argument('--resume_feat', default=0,type=int,help='dilate') +parser.add_argument('--save_folder', default='.',type=str,help='folder to save') +#related to mixed precision +parser.add_argument('--half', dest='half', action='store_true', + help='use half-precision(16-bit) ') +parser.add_argument('--static-loss-scale', type=float, default=1, + help='Static loss scale, positive power of 2 values can improve fp16 convergence.') +parser.add_argument('--dynamic-loss-scale', action='store_true', + help='Use dynamic loss scaling. 
If supplied, this argument supersedes ' + + '--static-loss-scale.') +args = parser.parse_args() +best_prec1 = 0 + + +################# Setup arguments +args.ensemble = args.ensemble>0 +args.debug = args.debug > 0 + +device_ids = [i for i in range(torch.cuda.device_count())] + + +if args.half: + from fp16 import FP16_Optimizer + from fp16.fp16util import BN_convert_float + if args.half: + assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." +##################### Logs +time_stamp = str(datetime.datetime.now().isoformat()) +name_log_txt = time_stamp + str(randint(0, 1000)) + args.name +name_log_txt=name_log_txt +'.log' + +with open(name_log_txt, "a") as text_file: + print(args, file=text_file) + +def main(): + global args, best_prec1 + args = parser.parse_args() + + +#### setup sizes and dataloaders + if args.large_size_images==0: + N_img = 112 + N_img_scale = 128 + print('using 112') + elif args.large_size_images ==1: + N_img = 224 + N_img_scale= 256 + + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(N_img), + + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True, + num_workers=args.workers, pin_memory=True, sampler=None) + + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(N_img_scale), + transforms.CenterCrop(N_img), + transforms.ToTensor(), + normalize, + ])), + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + +#### #### To simplify data parallelism we make an nn module with multiple outs + model = 
models.__dict__[args.arch](nlin=args.nlin).cuda() + + args.ncnn = len(model.main_cnn.blocks) + n_cnn = len(model.main_cnn.blocks) + with open(name_log_txt, "a") as text_file: + print(model, file=text_file) + if len(device_ids) == 1: + model = nn.DataParallel(model) #single gpu mode, we do the DataParallle so we can still do .module later + else: + model = DataParallelSpecial(model) + + if args.half: + model = model.half() + model = BN_convert_float(model) + ############### Initialize all + num_ep = args.nepochs + layer_epoch = [0] * n_cnn + layer_lr = [args.lr] * n_cnn + layer_optim = [None] * n_cnn + + +############## Resume if we need to resume + if (args.resume): + name = args.resume + model_dict = torch.load(name) + model.load_state_dict(model_dict) + print('model loaded') + for n in range(args.ncnn): + to_train = itertools.chain(model.module.main_cnn.blocks[n].parameters(), + model.module.auxillary_nets[n].parameters()) + layer_optim[n] = optim.SGD(to_train, lr=layer_lr[n], + momentum=args.momentum, + weight_decay=args.weight_decay) + if args.half: + layer_optim[n] = FP16_Optimizer(layer_optim[n], + static_loss_scale=args.static_loss_scale, + dynamic_loss_scale=args.dynamic_loss_scale, + dynamic_loss_args={'scale_window': 1000}) +######################### Lets do the training + criterion = nn.CrossEntropyLoss().cuda() + for n in range(args.ncnn): + for epoch in range(args.start_epoch,num_ep): + + # Make sure we set the batchnorm right + model.train() + for k in range(n): + model.module.main_cnn.blocks[k].eval() + + #For each epoch let's store each layer individually + batch_time = AverageMeter() + batch_time_total = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + + if epoch % args.epochdecay == 0: + layer_lr[n] = layer_lr[n] / 10.0 + to_train = itertools.chain(model.module.main_cnn.blocks[n].parameters(), + model.module.auxillary_nets[n].parameters()) + layer_optim[n] = 
optim.SGD(to_train, lr=layer_lr[n], + momentum=args.momentum, + weight_decay=args.weight_decay) + if args.half: + layer_optim[n] = FP16_Optimizer(layer_optim[n], + static_loss_scale=args.static_loss_scale, + dynamic_loss_scale=args.dynamic_loss_scale, + dynamic_loss_args={'scale_window': 1000}) + end = time.time() + + for i, (inputs, targets) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + targets = targets.cuda(non_blocking = True) + inputs = inputs.cuda(non_blocking = True) + inputs = torch.autograd.Variable(inputs) + targets = torch.autograd.Variable(targets) + if args.half: + inputs = inputs.half() + + #Main loop + if torch.cuda.device_count() > 1: + _,representation = model(inputs,init=True) #This only initializes the multi-gpu + else: + representation = inputs + + + + for k in range(n): + #forward only + outputs, representation = model(representation, n=k) + + if n>0: + if torch.cuda.device_count() > 1: + representation = [rep.detach() for rep in representation] + else: + representation = representation.detach() + + #update current layer + layer_optim[n].zero_grad() + outputs, representation = model(representation, n=n) + loss = criterion(outputs, targets) + + # update + if args.half: + layer_optim[n].backward(loss) + else: + loss.backward() + + layer_optim[n].step() + + + # measure accuracy and record loss + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + prec1, prec5 = accuracy(outputs.data, targets, topk=(1, 5)) + losses.update(float(loss.data[0]), float(inputs.size(0))) + top1.update(float(prec1[0]), float(inputs.size(0))) + top5.update(float(prec5[0]), float(inputs.size(0))) + + + if i % args.print_freq == 0: + print('n:{0} Epoch: [{1}][{2}/{3}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} 
({top5.avg:.3f})'.format( + n, epoch, i, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses, top1=top1, top5=top5)) + + + if args.debug and i > 50: + break + + + + ##### evaluate on validation set + top1test, top5test, top1ens, top5ens = validate(val_loader, model, criterion, epoch, n) + with open(name_log_txt, "a") as text_file: + print("n: {}, epoch {}, train top1:{}(top5:{}), " + "test top1:{} (top5:{}), top1ens:{} top5ens:{}" + .format(n, epoch, top1.avg, top5.avg, + top1test, top5test, top1ens, top5ens), file=text_file) + + #####Checkpoint + if not args.debug: + torch.save(model.state_dict(), args.save_folder + '/' + \ + name_log_txt + '_current_model.t7') + + + ############Save the final model + torch.save(model.state_dict(), args.save_folder + '/' + name_log_txt + '_model.t7') + + +all_outs = [[] for n in range(50)] +def validate(val_loader, model, criterion, epoch, n): + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + all_targs = [] + model.eval() + + end = time.time() + all_outs[n] = [] + + with torch.no_grad(): + total = 0 + for i, (input, target) in enumerate(val_loader): + target = target.cuda(non_blocking=True) + input = input.cuda(non_blocking=True) + input = torch.autograd.Variable(input) + target = torch.autograd.Variable(target) + if args.half: + input = input.half() + + # compute output + if len(device_ids)>1: + _, representation = model(input,init=True) + else: + representation = input + output, _ = model(representation, n=n, upto=True) + + + loss = criterion(output, target) + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + losses.update(float(loss.data[0]), float(input.size(0))) + top1.update(float(prec1[0]), float(input.size(0))) + top5.update(float(prec5[0]), float(input.size(0))) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: 
[{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(val_loader), batch_time=batch_time, loss=losses, + top1=top1, top5=top5)) + if args.ensemble: + all_outs[n].append(F.softmax(output).data.float().cpu()) + all_targs.append(target.data.cpu()) + total += input.size(0) + print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + + if args.ensemble: + all_outs[n] = torch.cat(all_outs[n]) + all_targs = torch.cat(all_targs) + #This is all on cpu + + weight = 2 ** (np.arange(n + 1)) / sum(2 ** np.arange(n + 1)) + total_out = torch.zeros([total, 1000]) + + for i in range(n + 1): + total_out += float(weight[i]) * all_outs[i] + + prec1, prec5 = accuracy(total_out, all_targs, topk=(1, 5)) + + print(' * Ensemble Prec@1 {top1:.3f} Prec@5 {top5:.3f}' + .format(top1=prec1[0], top5=prec5[0])) + return top1.avg,top5.avg,prec1[0],prec5[0] + return top1.avg, top5.avg,-1,-1 + + +if __name__ == '__main__': + main() diff --git a/imagenet_refactored/utils.py b/imagenet_refactored/utils.py new file mode 100755 index 0000000..ab85fad --- /dev/null +++ b/imagenet_refactored/utils.py @@ -0,0 +1,110 @@ +'''Some helper functions for PyTorch, including: + - get_mean_and_std: calculate the mean and std value of dataset. + - msr_init: net parameter initialization. + - progress_bar: progress bar mimic xlua.progress. 
+''' +import os +import sys +import time +import math +import torch.nn as nn +import numpy as np +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from torchvision.datasets.cifar import CIFAR10 + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +from torch.nn.parallel import DataParallel +from torch.nn.parallel.scatter_gather import scatter_kwargs,scatter +class DataParallelSpecial(DataParallel): + def __init__(self, module, device_ids=None, output_device=None, dim=0): + super(DataParallelSpecial,self).__init__(module, device_ids=None, output_device=None, dim=0) + print('Initialized with GPUs:') + print(self.device_ids) + + def forward(self, *inputs, init=False, **kwargs): + if init: + if self.device_ids: + # -------- Here, we split the input tensor across GPUs + inputs_ = inputs + if not isinstance(inputs_, tuple): + inputs_ = (inputs_,) + + representation, _ = scatter_kwargs(inputs_, None, self.device_ids, 0) + self.replicas = self.replicate(self.module, self.device_ids[:len(representation)]) + # ---- + else: + representation = inputs + return None , representation + + if not self.device_ids: + return self.module(*inputs, **kwargs) + # inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) + if len(self.device_ids) == 1: + import ipdb; ipdb.set_trace() + return self.module(*inputs[0][0], **kwargs) + + kwargs = scatter(kwargs, self.device_ids) if kwargs else [] + kwargs = tuple(kwargs) + outputs = self.parallel_apply(self.replicas, *inputs, kwargs) + + out1 = [] + out2 = [] + for i, tensor in enumerate(outputs): + with torch.cuda.device(tensor[0].get_device()): + # out_1[i] = torch.autograd.Variable(tensors[i]) + 
out1.append(outputs[i][0]) + out2.append(outputs[i][1]) + outputs = self.gather(out1, self.output_device) + representation = out2 + return outputs, representation + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + +def convnet_half_precision(model): + model.half() # convert to half precision + for layer in model.modules(): + if isinstance(layer, nn.BatchNorm2d): + layer.float() + if isinstance(layer, nn.BatchNorm1d): + layer.float() + return model + +def reset(m): + if isinstance(m, nn.Linear): + size = m.weight.size() + fan_out = size[0] # number of rows + fan_in = size[1] + variance = np.sqrt(2.0 / (fan_in + fan_out)) + m.weight.data.normal_(0.0, variance) + m.bias.data.zero_() diff --git a/imagenet_single_layer.py b/imagenet_single_layer.py new file mode 100644 index 0000000..a8c3b05 --- /dev/null +++ b/imagenet_single_layer.py @@ -0,0 +1,601 @@ +import argparse +import os +import shutil +import time +from collections import OrderedDict +import torch +import torch._utils +import torch.nn.functional as F +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import numpy as np +from random import randint +import datetime + +try: + torch._utils._rebuild_tensor_v2 +except AttributeError: + def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): + tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) + tensor.requires_grad = 
requires_grad + tensor._backward_hooks = backward_hooks + return tensor + torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 + + +class identity(nn.Module): + def __init__(self): + super(identity, self).__init__() + def forward(self, input): + return input + +class psi(nn.Module): + def __init__(self, block_size): + super(psi, self).__init__() + self.block_size = block_size + self.block_size_sq = block_size*block_size + + def forward(self, input): + output = input.permute(0, 2, 3, 1) + (batch_size, s_height, s_width, s_depth) = output.size() + d_depth = s_depth * self.block_size_sq + d_height = int(s_height / self.block_size) + t_1 = output.split(self.block_size, 2) + stack = [t_t.contiguous().view(batch_size, d_height, d_depth) for t_t in t_1] + output = torch.stack(stack, 1) + output = output.permute(0, 2, 1, 3) + output = output.permute(0, 3, 1, 2) + return output.contiguous() + +class block_conv(nn.Module): + expansion = 1 + def __init__(self, in_planes, planes, stride=1,downsample=False,batchn=True): + super(block_conv, self).__init__() + self.downsample = downsample + if downsample: + self.down = psi(2) + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,stride=stride, padding=1,bias=not batchn) + if batchn: + self.bn1 = nn.BatchNorm2d(planes) + else: + self.bn1 = identity() #Identity + + def forward(self, x): + if self.downsample: + x = self.down(x) + out = F.relu(self.bn1(self.conv1(x))) + return out + +class greedyNet(nn.Module): + def __init__(self, block, num_blocks, feature_size=256,downsampling=1, downsample=[],batchnorm=True,strd=False): + super(greedyNet, self).__init__() + self.in_planes = feature_size + self.down_sampling = psi(downsampling) + self.downsample_init = downsampling + self.conv1 = nn.Conv2d(3*downsampling*downsampling, self.in_planes, kernel_size=3, stride=1, padding=1, bias=not batchnorm) + + if batchnorm: + self.bn1 = nn.BatchNorm2d(self.in_planes) + else: + self.bn1 = identity() #Identity + self.RELU = nn.ReLU() + 
self.blocks = [] + self.blocks.append(nn.Sequential(self.conv1,self.bn1,self.RELU)) #n=0 + self.batchn = batchnorm + for n in range(num_blocks-1): + inc = 2 + if n in downsample: + if self.strd: + dsample = False + strd=2 + init_sz=1 + else: + dsample = True + strd=1 + init_sz=4 + self.blocks.append(block(self.in_planes * init_sz, self.in_planes * inc, strd, downsample=dsample,batchn=batchnorm)) + self.in_planes = self.in_planes * inc + else: + self.blocks.append(block(self.in_planes,self.in_planes,1)) + + self.blocks = nn.ModuleList(self.blocks) + for n in range(num_blocks): + for p in self.blocks[n].parameters(): + p.requires_grad = False + + def unfreezeGradient(self,n): + for k in range(len(self.blocks)): + for p in self.blocks[k].parameters(): + p.requires_grad = False + + for p in self.blocks[n].parameters(): + p.requires_grad = True + + def add_block(self,block,downsample=False): + if downsample: + inc = 2 + if self.strd: + dsample = False + strd=2 + init_sz = 1 + else: + dsample = True + strd=1 + init_sz = 4 + + self.blocks.append( + block(self.in_planes * init_sz, self.in_planes * inc, strd, downsample=dsample,batchn=self.batchn)) + self.in_planes = self.in_planes * inc + else: + self.blocks.append(block(self.in_planes,self.in_planes,1,batchn=self.batchn)) + + def forward(self, a): + x=a[0] + N=a[1] + out = x + if self.downsample_init>1: + out = self.down_sampling(x) + for n in range(N+1): + out=self.blocks[n](out) + return out + + +class auxillary_classifier(nn.Module): + def __init__(self,avg_size=4,feature_size=256, in_size=32,num_classes=10,batchn=True): + super(auxillary_classifier, self).__init__() + if batchn: + self.bn = nn.BatchNorm2d(feature_size) + else: + self.bn = identity() #Identity + self.avg_size=avg_size + + self.classifier = nn.Linear(feature_size*(in_size//avg_size)*(in_size//avg_size), num_classes) + def forward(self, x): + out = F.avg_pool2d(x, self.avg_size) + out = self.bn(out) + out = out.view(out.size(0), -1) + out = 
self.classifier(out) + return out + + + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', + help='number of data loading workers (default: 4)') +parser.add_argument('--epochs', default=90, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--world-size', default=1, type=int, + help='number of distributed processes') +parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='gloo', type=str, + help='distributed backend') + +parser.add_argument('--bn', default=0,type=int, help='depth of the CNN') +parser.add_argument('--ncnn', default=8,type=int, help='depth of the CNN') +parser.add_argument('--nepochs', 
default=45,type=int, help='number of epochs') +parser.add_argument('--epochdecay', default=20,type=int, help='number of epochs') +parser.add_argument('--avg_size', default=56,type=int, help='size of the averaging') +parser.add_argument('--feature_size', default=256,type=int, help='size of the averaging') +parser.add_argument('--ds', default=2,type=int, help='initial downsampling') +parser.add_argument('--ensemble', default=1,type=int,help='ensemble') # not implemented yet +parser.add_argument('--name', default='',type=str,help='name') +parser.add_argument('--debug', default=0,type=int,help='debug') +parser.add_argument('--large_size_images', default=2,type=int,help='small images for debugging') +parser.add_argument('--n_resume', default=0,type=int,help='which n we resume') +parser.add_argument('--resume_epoch', default=0,type=int,help='which n we resume') +parser.add_argument('--dilate', default=0,type=int,help='dilate') +parser.add_argument('--resume_feat', default=0,type=int,help='deprecated') +parser.add_argument('--down', default=1,type=int,help='perform downsampling ops') +parser.add_argument('--save_folder', default='.',type=str,help='folder saving') + +args = parser.parse_args() +best_prec1 = 0 + +time_stamp = str(datetime.datetime.now().isoformat()) + +name_log_txt = time_stamp + str(randint(0, 1000)) + args.name + +name_log_txt=name_log_txt +'.log' + +args.ensemble = args.ensemble>0 +args.debug = args.debug > 0 +args.bn = args.bn > 0 +args.dilate = args.dilate > 0 #toremove +args.nlin=0 # We only do k=1 here as its a special case +downsample = [1,2,3,5] +args.down = args.down > 0 + + +def main(): + global args, best_prec1 + args = parser.parse_args() + + if args.large_size_images==0: + N_img = 112 + N_img_scale = 128 + elif args.large_size_images==1: + N_img = 160 + N_img_scale = 182 + elif args.large_size_images ==2: + N_img = 224 + N_img_scale= 256 + + in_size = N_img // args.ds + + with open(name_log_txt, "a") as text_file: + print(args, file=text_file) 
+ + n_cnn = args.ncnn + + + model = greedyNet(block_conv,1,feature_size=args.feature_size,downsampling=args.ds, + downsample=downsample,batchnorm=args.bn) + num_feat = args.feature_size + model_c = auxillary_classifier(avg_size=args.avg_size, in_size=in_size, feature_size=num_feat, num_classes=1000, batchn=args.bn) + + with open(name_log_txt, "a") as text_file: + print(model, file=text_file) + print(model_c, file=text_file) + + model = torch.nn.DataParallel(nn.Sequential(model,model_c)).cuda() + model.module[0].unfreezeGradient(0) + + model_c = None + + + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda() + to_train = list(filter(lambda p: p.requires_grad, model.parameters())) #+ list(model_c.parameters()) + optimizer = torch.optim.SGD(to_train, args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(N_img), + + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + + + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True, + num_workers=args.workers, pin_memory=True, sampler=None) + + + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(N_img_scale), + transforms.CenterCrop(N_img), + transforms.ToTensor(), + normalize, + ])), + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + if args.evaluate: + validate(val_loader, model, criterion) + return + + + + if (args.resume): + model_dict = torch.load(args.save_folder+'/'+args.resume + '_model.t7') + + + for key in list(model_dict.keys()): + if 
key[0:8]=='module.1': + model_dict.pop(key,None) + else: + model_dict = OrderedDict((key[9:] if k == key else k, v) for k, v in model_dict.items()) + + model.module[0].load_state_dict(model_dict) + + + num_ep = args.nepochs + + + + for n in range(n_cnn): + model.module[0].unfreezeGradient(n) + lr = args.lr * 10.0 + + for epoch in range(0, num_ep): + if n > 0 and not args.debug and epoch % 3==0: + torch.save(model.state_dict(), args.save_folder+'/'+name_log_txt + '_current_model.t7') + if epoch % args.epochdecay == 0: + lr = lr/10.0 + to_train = list(filter(lambda p: p.requires_grad, model.parameters())) + optimizer = torch.optim.SGD(to_train, lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + if (args.resume and args.resume_epoch>0 and n==args.n_resume): + if epoch < args.resume_epoch: + continue + if epoch == args.resume_epoch: + name = args.resume + '_current_model.t7' + model_dict = torch.load(name) + # import ipdb; ipdb.set_trace() + model.load_state_dict(model_dict) + if (args.resume and n best_prec1 + best_prec1 = max(prec1, best_prec1) + + with open(name_log_txt, "a") as text_file: + print("n: {}, epoch {}, train top1:{}(top5:{}), test top1:{} (top5:{}), top1ens:{} top5ens:{}" + .format(n, epoch, top1train, top5train, top1test,top5test,top1ens,top5ens), file=text_file) + if (args.resume and n50: + break + + if i % args.print_freq == 0: + print('Epoch: [{0}][{1}/{2}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + epoch, i, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses, top1=top1, top5=top5)) + + return top1.avg,top5.avg + + +all_outs = [[] for i in range(args.ncnn)] +def validate(val_loader, model, criterion, n): + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = 
AverageMeter() + all_targs = [] + # switch to evaluate mode + model.eval() + + + end = time.time() + all_outs[n] = [] + + total = 0 + for i, (input, target) in enumerate(val_loader): + target = target.cuda(async=True) + input_var = torch.autograd.Variable(input, volatile=True) + target_var = torch.autograd.Variable(target, volatile=True) + + # compute output + output = model([input_var, n]) + # output = model_c.forward(output) + + + loss = criterion(output, target_var) + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + losses.update(loss.data[0], input.size(0)) + top1.update(prec1[0], input.size(0)) + top5.update(prec5[0], input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(val_loader), batch_time=batch_time, loss=losses, + top1=top1, top5=top5)) + if args.ensemble: + all_outs[n].append(F.softmax(output).data.cpu()) + all_targs.append(target_var.data.cpu()) + total += input_var.size(0) + print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + + if args.ensemble: + all_outs[n] = torch.cat(all_outs[n]) + all_targs = torch.cat(all_targs) + #This is all on cpu so we dont care + + weight = 2 ** (np.arange(n + 1)) / sum(2 ** np.arange(n + 1)) + total_out = torch.zeros([total, 1000]) + + # very lazy + for i in range(n + 1): + total_out += float(weight[i]) * all_outs[i] + + prec1, prec5 = accuracy(total_out, all_targs, topk=(1, 5)) + + print(' * Ensemble Prec@1 {top1:.3f} Prec@5 {top5:.3f}' + .format(top1=prec1[0], top5=prec5[0])) + return top1.avg,top5.avg,prec1[0],prec5[0] + return top1.avg, top5.avg,-1,-1 + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + 
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D score tensor; ``topk`` is taken along dimension 1.
        target: 1-D tensor of class indices, one per row of ``output``.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one 1-element tensor per entry of ``topk``, holding the
        percentage of rows whose target is among the ``k`` highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` inherits the transposed layout of ``pred`` and is not
        # contiguous, so ``view(-1)`` raises for k > 1; ``reshape`` copies
        # when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
implement other downsamplings. They all have the same shape behavior""" + + def __init__(self, in_planes, planes, downsample=False, ds_type='psi', batchn=True): + super(ds_conv, self).__init__() + self.downsample = downsample + self.ds_type = ds_type + self.in_planes = in_planes + self.planes = planes + self.batchn = batchn + + self.build() + + def build(self): + """Builds the forward model depending on the downsampler""" + planes = self.planes + in_planes = self.in_planes + if self.batchn: + self.bn1 = nn.BatchNorm2d(planes) + else: + self.bn1 = identity() + + if self.downsample is False: + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=not self.batchn) + self.conv_op = self.conv1 + elif self.ds_type == 'psi': + self.down = psi(2) + self.conv1 = nn.Conv2d(4 * in_planes, planes, kernel_size=3, stride=1, padding=1, bias=not self.batchn) + self.conv_op = nn.Sequential(self.down, self.conv1) + elif self.ds_type == 'stride': + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=2, padding=1, bias=not self.batchn) + self.conv_op = self.conv1 + elif self.ds_type == 'maxpool': + self.down = nn.MaxPool2d(2) + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=not self.batchn) + self.conv_op = nn.Sequential(self.down, self.conv1) + elif self.ds_type == 'avgpool': + self.down = nn.AvgPool2d(2) + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=not self.batchn) + self.conv_op = nn.Sequential(self.down, self.conv1) + else: + raise ValueError("I don't get {self.ds_type}. 
class psi(nn.Module):
    """Space-to-depth downsampling without learnable parameters.

    Maps a ``(batch, C, H, W)`` tensor to
    ``(batch, C * block_size ** 2, H // block_size, W // block_size)`` by
    moving every ``block_size x block_size`` spatial patch into the channel
    dimension.
    """

    def __init__(self, block_size):
        super(psi, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size * block_size

    def forward(self, input):
        """Rearrange ``input`` from space to depth.

        Raises:
            ValueError: if height or width is not a multiple of
                ``block_size`` (previously this surfaced as an opaque
                ``view`` size error).
        """
        bs = self.block_size
        if input.size(2) % bs != 0 or input.size(3) % bs != 0:
            raise ValueError("height and width must be divisible by block_size")

        # Channels last: (batch, H, W, C).
        output = input.permute(0, 2, 3, 1)
        batch_size, s_height, _, s_depth = output.size()
        d_depth = s_depth * self.block_size_sq
        d_height = s_height // bs  # integer division; sizes are exact here

        # One chunk per output column, each holding ``bs`` input columns.
        columns = output.split(bs, 2)
        # Flattening a chunk row-major groups ``bs`` consecutive rows
        # together, which yields one ``d_depth`` vector per output row.
        stack = [col.contiguous().view(batch_size, d_height, d_depth)
                 for col in columns]
        # (batch, W / bs, H / bs, d_depth) -> (batch, d_depth, H / bs, W / bs)
        output = torch.stack(stack, 1).permute(0, 3, 2, 1)
        return output.contiguous()
class auxillary_classifier(nn.Module):
    """Auxiliary head: optional 3x3 conv stages, pooling, then a linear layer.

    The head applies ``n_lin`` stages of (3x3 convolution, normalisation,
    ReLU), average-pools with kernel ``avg_size`` when that is larger than
    one, normalises once more, flattens and classifies.

    Args:
        avg_size: kernel size of the average pooling.
        feature_size: channel width of the conv stages; replaced by
            ``input_features`` when ``n_lin`` is 0.
        input_features: number of channels of the incoming feature map.
        in_size: spatial side used, together with ``avg_size``, to size the
            linear layer as ``feature_size * (in_size // avg_size) ** 2``.
        num_classes: number of outputs of the linear layer.
        n_lin: number of conv stages in front of the pooling.
        batchn: use ``nn.BatchNorm2d`` when true, ``identity`` otherwise.
    """

    def __init__(self,avg_size=16,feature_size=256,input_features=256, in_size=32,num_classes=10,n_lin=0,batchn=True):
        super(auxillary_classifier, self).__init__()
        self.n_lin = n_lin
        self.avg_size = avg_size

        # Without conv stages the pooled features keep the input width.
        if n_lin == 0:
            feature_size = input_features

        def make_norm():
            return nn.BatchNorm2d(feature_size) if batchn else identity()

        stages = []
        width_in = input_features
        for _ in range(n_lin):
            norm = make_norm()
            conv = nn.Conv2d(width_in, feature_size,
                             kernel_size=3, stride=1, padding=1, bias=False)
            stages.append(nn.Sequential(conv, norm))
            # Every stage after the first consumes feature_size channels.
            width_in = feature_size

        self.blocks = nn.ModuleList(stages)
        self.bn = make_norm()

        pooled_side = in_size // avg_size
        self.classifier = nn.Linear(feature_size * pooled_side * pooled_side, num_classes)

    def forward(self, x):
        """Return the class scores for the feature map ``x``."""
        out = x
        for idx in range(self.n_lin):
            out = F.relu(self.blocks[idx](out))
        if self.avg_size > 1:
            out = F.avg_pool2d(out, self.avg_size)
        out = self.bn(out)
        flat = out.view(out.size(0), -1)
        return self.classifier(flat)
class greedyNet(nn.Module):
    """Stack of blocks in which only one block is trainable at a time.

    Block 0 is ``conv1 -> bn1 -> ReLU``; the remaining ``num_blocks - 1``
    blocks are built with the ``block`` factory. All parameters start frozen;
    use ``unfreezeGradient`` / ``unfreezeAll`` to enable training.

    Args:
        block: factory called as ``block(in_planes, planes, batchn=...)`` or
            with an additional ``downsample=True``.
        num_blocks: total number of blocks, including the initial conv block.
        feature_size: number of output channels of the initial convolution.
        downsampling: block size of the initial ``psi`` downsampling, which
            is applied in ``forward`` only when larger than 1.
        downsample: indices ``n`` (0-based over the blocks after the first)
            at which the block is built with ``downsample=True`` and the
            width doubles. ``None`` means no downsampling blocks.
        batchnorm: use ``nn.BatchNorm2d`` after the first convolution (and
            drop its bias) when true, ``identity`` otherwise.
    """

    def __init__(self, block, num_blocks, feature_size=256, downsampling=1, downsample=None, batchnorm=True):
        super(greedyNet, self).__init__()
        # ``None`` instead of a shared mutable ``[]`` default.
        if downsample is None:
            downsample = ()

        self.in_planes = feature_size
        self.down_sampling = psi(downsampling)
        self.downsample_init = downsampling
        self.conv1 = nn.Conv2d(3 * downsampling * downsampling, self.in_planes, kernel_size=3, stride=1, padding=1,
                               bias=not batchnorm)

        if batchnorm:
            self.bn1 = nn.BatchNorm2d(self.in_planes)
        else:
            self.bn1 = identity()  # Identity
        self.RELU = nn.ReLU()
        self.block = block
        self.batchn = batchnorm

        blocks = [nn.Sequential(self.conv1, self.bn1, self.RELU)]  # n=0
        for n in range(num_blocks - 1):
            if n in downsample:
                # The downsampling block receives 4x the channels.
                pre_factor = 4
                blocks.append(block(self.in_planes * pre_factor, self.in_planes * 2,
                                    downsample=True, batchn=batchnorm))
                self.in_planes = self.in_planes * 2
            else:
                blocks.append(block(self.in_planes, self.in_planes, batchn=batchnorm))

        self.blocks = nn.ModuleList(blocks)
        # Everything starts frozen.
        for n in range(num_blocks):
            for p in self.blocks[n].parameters():
                p.requires_grad = False

    def unfreezeGradient(self, n):
        """Freeze every block, then make only block ``n`` trainable."""
        for blk in self.blocks:
            for p in blk.parameters():
                p.requires_grad = False

        for p in self.blocks[n].parameters():
            p.requires_grad = True

    def unfreezeAll(self):
        """Make the parameters of every block trainable."""
        for blk in self.blocks:
            for p in blk.parameters():
                p.requires_grad = True

    def add_block(self, downsample=False):
        """Append one more block, doubling the width when ``downsample``."""
        if downsample:
            pre_factor = 4  # the old block needs this factor 4
            self.blocks.append(
                self.block(self.in_planes * pre_factor, self.in_planes * 2, downsample=True, batchn=self.batchn))
            self.in_planes = self.in_planes * 2
        else:
            self.blocks.append(self.block(self.in_planes, self.in_planes, batchn=self.batchn))

    def forward(self, a):
        """Run blocks ``0..N`` on ``x``, where ``a`` is the pair ``[x, N]``."""
        x = a[0]
        N = a[1]
        out = x
        if self.downsample_init > 1:
            out = self.down_sampling(x)
        for n in range(N + 1):
            out = self.blocks[n](out)
        return out
def format_time(seconds):
    """Render a duration in seconds using its two largest non-zero units.

    Units are days (``D``), hours (``h``), minutes (``m``), seconds (``s``)
    and milliseconds (``ms``), written largest first without separators.
    Returns ``'0ms'`` when no unit has a positive value.
    """
    remaining = seconds
    days = int(remaining / 3600/24)
    remaining = remaining - days*3600*24
    hours = int(remaining / 3600)
    remaining = remaining - hours*3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes*60
    whole_seconds = int(remaining)
    remaining = remaining - whole_seconds
    millis = int(remaining*1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_seconds, 's'),
        (millis, 'ms'),
    )

    pieces = []
    for amount, suffix in units:
        # Stop adding once two units have been written.
        if amount > 0 and len(pieces) < 2:
            pieces.append(str(amount) + suffix)

    if not pieces:
        return '0ms'
    return ''.join(pieces)