In [1]:
from __future__ import print_function
import argparse
import numpy as np
import os
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from models import *
from tqdm import tqdm
from models.slimmableops import bn_calibration_init
import USconfig as FLAGS
import random
torch.cuda.is_available()

True

In [2]:
from dotmap import DotMap

args = DotMap()

args.dataset = 'imagenet'
args.batch_size=32
args.workers=8
args.test_batch_size=4
args.epochs = 40
args.start_epoch=0
args.lr = 0.2
args.momentum=0.9
args.weight_decay=1e-4
args.resume=''
args.no_cuda=False
args.seed=1
args.save='checkpoints_mobilenetv2_imagenet'
args.arch='MobileNetV2'
args.sr=True
args.s=0.0001
args.test=True


args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)


savepath = os.path.join(args.save, args.arch, 'sr' if args.sr else 'nosr')
if not os.path.exists(savepath):
    os.makedirs(savepath)
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

args.data = '/home/hongky/datasets/imagenet'

In [3]:


# Data loading code
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True, sampler=None)

test_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=args.test_batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)

In [4]:
import torchvision
from models import *

torch_mobilenetv2 = torchvision.models.mobilenet.mobilenet_v2(pretrained=True)

mobilenet2 = eval('MobileNetV2')(n_class=1000, input_size=224)


# classifier
mobilenet2.classifier = torch_mobilenetv2.classifier 

# features 0
mobilenet2.features[0].convbn.conv = torch_mobilenetv2.features[0][0]
mobilenet2.features[0].convbn.bn = torch_mobilenetv2.features[0][1]
mobilenet2.features[0].convbn.relu = torch_mobilenetv2.features[0][2]

# features 18
mobilenet2.features[18].convbn.conv = torch_mobilenetv2.features[18][0]
mobilenet2.features[18].convbn.bn = torch_mobilenetv2.features[18][1]
mobilenet2.features[18].convbn.relu = torch_mobilenetv2.features[18][2]

# feature 1
mobilenet2.features[1].conv.dw_conv = torch_mobilenetv2.features[1].conv[0][0]
mobilenet2.features[1].conv.dw_bn = torch_mobilenetv2.features[1].conv[0][1]
mobilenet2.features[1].conv.dw_relu = torch_mobilenetv2.features[1].conv[0][2]
mobilenet2.features[1].conv.project_conv = torch_mobilenetv2.features[1].conv[1]
mobilenet2.features[1].conv.project_bn = torch_mobilenetv2.features[1].conv[2]


        
for i in range(2, 18):
    mobilenet2.features[i].conv.expand_conv = torch_mobilenetv2.features[i].conv[0][0]
    mobilenet2.features[i].conv.expand_bn = torch_mobilenetv2.features[i].conv[0][1]
    mobilenet2.features[i].conv.expand_relu = torch_mobilenetv2.features[i].conv[0][2]
    
    mobilenet2.features[i].conv.dw_conv = torch_mobilenetv2.features[i].conv[1][0]
    mobilenet2.features[i].conv.dw_bn = torch_mobilenetv2.features[i].conv[1][1]
    mobilenet2.features[i].conv.dw_relu = torch_mobilenetv2.features[i].conv[1][2]
    
    mobilenet2.features[i].conv.project_conv = torch_mobilenetv2.features[i].conv[2]
    mobilenet2.features[i].conv.project_bn = torch_mobilenetv2.features[i].conv[3]


model = mobilenet2

print('Done eval model:', model)


# state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
#                                       progress=True)
# model.load_state_dict(state_dict)


Done eval model: MobileNetV2(
  (features): Sequential(
    (0): conv_bn_relu(
      (convbn): Sequential(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU6(inplace=True)
      )
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (dw_conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (dw_bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (dw_relu): ReLU6(inplace=True)
        (project_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (project_bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (expand_conv): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (expand_bn)

In [5]:
# if args.cuda:
#     model.cuda()
model = nn.DataParallel(model).cuda()
best_prec1 = -1
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)


In [6]:

def updateBN():
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(args.s * torch.sign(m.weight.data))  # L1


def train():
    model.train()
    avg_loss = 0.
    train_acc = 0.
    for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
        if args.cuda:
            target = target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        avg_loss += loss.item()
        pred = output.data.max(1, keepdim=True)[1]
        train_acc += pred.eq(target.data.view_as(pred)).cpu().sum()
        loss.backward()
        if args.sr:
            updateBN()
        optimizer.step()
        


def test(epoch,test_width=1.0,recal=False):
    model.eval()
    test_loss = 0
    correct = 0
    model.apply(lambda m: setattr(m, 'width_mult',test_width))
    if recal:
        model.apply(bn_calibration_init)
        model.train()
        for idx,(data, target) in enumerate(tqdm(train_loader, total=len(train_loader))):
            if idx==FLAGS.recal_batch:
                break
            if args.cuda:
                target = target.cuda()
            with torch.no_grad():
                output = model(data)
            del output
            
    model.eval()
    for data, target in tqdm(test_loader, total=len(test_loader)):
        if args.cuda:
            target = target.cuda()
        with torch.no_grad():
            output = model(data)
        test_loss += F.cross_entropy(output, target, size_average=False).item()  # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nEpoch: {} Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(epoch,
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct.item() / float(len(test_loader.dataset))

def export2normal():
    newmodel=MobileNetV2()
    from collections import OrderedDict
    statedic=[]
    for k2,v in model.state_dict().items():
        if 'running' in k2 or 'num_batches_tracked' in k2:
            continue
        statedic.append(v)
    names=[]
    for k1,v1 in newmodel.state_dict().items():
        if 'running' in k1 or 'num_batches_tracked' in k1:
            continue
        names.append(k1)
    newdic=OrderedDict(zip(names,statedic))
    newmodel.load_state_dict(newdic,strict=False)
    torch.save(newmodel.state_dict(),os.path.join(savepath,'trans.pth'))
    print("save transferred ckpt at {}".format(os.path.join(savepath,'trans.pth')))

In [7]:
print("Test accuracy {}".format(test(0)))

100%|██████████| 12500/12500 [20:17<00:00, 10.26it/s]


Epoch: 0 Test set: Average loss: 1.1477, Accuracy: 35939/50000 (71.9%)

Test accuracy 0.71878





In [7]:
best_prec1 = 0. if best_prec1 == -1 else best_prec1
scheduler=optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=args.epochs,eta_min=0)

for epoch in range(args.start_epoch, args.epochs):
    train()
    prec1 = test(epoch=epoch)
    scheduler.step(epoch)
    lr_current = optimizer.param_groups[0]['lr']
    print("currnt lr:{}".format(lr_current))
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    if is_best:
        ckptfile = os.path.join(savepath, 'model_best.pth.tar')
    else:
        ckptfile = os.path.join(savepath, 'checkpoint.pth.tar')
        
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict(),
    }, ckptfile)

100%|██████████| 40037/40037 [4:17:24<00:00,  2.59it/s]  
100%|██████████| 12500/12500 [19:10<00:00, 10.87it/s]



Epoch: 0 Test set: Average loss: 6.9245, Accuracy: 50/50000 (0.1%)

currnt lr:0.2


 18%|█▊        | 7194/40037 [45:58<3:29:55,  2.61it/s]


KeyboardInterrupt: 

In [13]:

if args.arch=='USMobileNetV2':
    export2normal()
    res_acc=[1.0]*len(FLAGS.width_mult_list)
    for idx,width in enumerate(FLAGS.width_mult_list):
        acc=test(width,recal=True)
        res_acc[idx]=acc
        print("Test accuracy for width {} is {}".format(width,acc))
else:
    print("Test accuracy {}".format(test(0)))


100%|██████████| 40/40 [00:06<00:00,  6.00it/s]


Epoch: 0 Test set: Average loss: 0.2790, Accuracy: 9188/10000 (91.9%)

Test accuracy 0.9188





In [14]:
print(ckptfile)
torch.save({
    'epoch': epoch + 1,
    'state_dict': model.state_dict(),
    'best_prec1': best_prec1,
    'optimizer': optimizer.state_dict(),
}, ckptfile)

checkpoints/MobileNetV2/sr/checkpoint.pth.tar


In [16]:
!ls checkpoints/MobileNetV2/sr

0.2.json  0.4.json  0.6.json  checkpoint.pth.tar  model_best.pth.tar
