In [3]:
from __future__ import absolute_import, print_function
import time
import argparse
import os
import sys
import torch.utils.data
from torch.backends import cudnn
from torch.autograd import Variable
import models
import losses
from utils import FastRandomIdentitySampler, mkdir_if_missing, logging, display
from utils.serialization import save_checkpoint, load_checkpoint
from trainer import train
from utils import orth_reg

import DataSet
import numpy as np
import os.path as osp

# True when PyTorch reports an available CUDA device; main() consults this
# flag before moving the model to the GPU.
use_gpu = torch.cuda.is_available()

In [4]:
# Batch Norm Freezer : bring 2% improvement on CUB
def set_bn_eval(m):
    """Put ``m`` into eval mode when its class name contains 'BatchNorm'.

    Written to be used as a callback (main() passes it to ``model.apply``),
    so it inspects one module at a time and leaves every other module
    untouched.

    Args:
        m: object exposing ``eval()``; only its class name is examined.
    """
    layer_type = type(m).__name__
    if 'BatchNorm' not in layer_type:
        return
    m.eval()

In [5]:
def main(args):
    """Build the model, loss, optimizer and data loader, then run training.

    Standard output is redirected to ``<save_dir>/log.txt`` through
    ``logging.Logger``. A checkpoint holding the unwrapped model weights and
    the next epoch index is written after the first epoch and then every
    ``args.save_step`` epochs.

    Args:
        args: namespace-like object. Attributes read here: ``save_dir``,
            ``net``, ``dim``, ``resume``, ``freeze_BN``, ``lr``,
            ``weight_decay``, ``loss``, ``margin``, ``alpha``, ``loss_base``,
            ``data``, ``ratio``, ``width``, ``origin_width``, ``data_root``,
            ``batch_size``, ``num_instances``, ``nThreads``, ``epochs`` and
            ``save_step``. The whole object is also forwarded to ``train``.
    """
    save_dir = args.save_dir
    mkdir_if_missing(save_dir)

    # NOTE(review): sys.stdout stays replaced after main() returns; in a
    # notebook this affects every later cell -- confirm this is intended.
    sys.stdout = logging.Logger(os.path.join(save_dir, 'log.txt'))
    display(args)
    start = 0

    model = models.create(args.net, pretrained=True, dim=args.dim)

    if args.resume is not None:
        # Resume: restore the weights and continue from the stored epoch.
        print('load model from {}'.format(args.resume))
        chk_pt = load_checkpoint(args.resume)
        weight = chk_pt['state_dict']
        start = chk_pt['epoch']
        # Loaded before the DataParallel wrap, so the checkpoint must hold
        # the keys of the bare model (see the save step below).
        model.load_state_dict(weight)

    model = torch.nn.DataParallel(model)
    if use_gpu:
        model = model.cuda()

    # freeze BN
    # NOTE(review): this only switches the mode once; if train() calls
    # model.train() the BatchNorm layers leave eval mode again -- confirm.
    if args.freeze_BN is True:
        print(40 * '#', '\n BatchNorm frozen')
        model.apply(set_bn_eval)
    else:
        print(40 * '#', 'BatchNorm NOT frozen')

    # Split parameters into the new classifier head and everything else so
    # each group can carry its own 'lr_mult'.
    # NOTE(review): 'lr_mult' is a custom per-group entry; presumably
    # trainer.train reads it to scale the learning rate -- confirm.
    new_param_ids = set(map(id, model.module.classifier.parameters()))

    new_params = [p for p in model.module.parameters() if
                  id(p) in new_param_ids]

    base_params = [p for p in model.module.parameters() if
                   id(p) not in new_param_ids]

    param_groups = [
        {'params': base_params, 'lr_mult': 0.0},
        {'params': new_params, 'lr_mult': 1.0}]

    print('initial model is save at %s' % save_dir)

    optimizer = torch.optim.Adam(param_groups, lr=args.lr,
                                 weight_decay=args.weight_decay)

    criterion = losses.create(args.loss, margin=args.margin, alpha=args.alpha, base=args.loss_base)
    # Only move the loss to the GPU when one exists, matching the model.
    if use_gpu:
        criterion = criterion.cuda()

    data = DataSet.create(args.data, ratio=args.ratio, width=args.width, origin_width=args.origin_width,
                          root=args.data_root)

    train_loader = torch.utils.data.DataLoader(
        data.train, batch_size=args.batch_size,
        sampler=FastRandomIdentitySampler(data.train, num_instances=args.num_instances),
        drop_last=True, pin_memory=use_gpu, num_workers=args.nThreads)

    for epoch in range(start, args.epochs):

        train(epoch=epoch, model=model, criterion=criterion,
              optimizer=optimizer, train_loader=train_loader, args=args)

        if epoch == 1:
            # After the first two epochs, raise the base group's multiplier
            # from 0.0 to 0.1. The key must match the 'lr_mult' used when the
            # groups were created, otherwise a new, unused entry is added.
            # NOTE(review): a run resumed at epoch >= 2 never reaches this
            # branch, so its base group keeps lr_mult 0.0 -- confirm.
            optimizer.param_groups[0]['lr_mult'] = 0.1

        if (epoch + 1) % args.save_step == 0 or epoch == 0:
            # The model is always wrapped in DataParallel, so always save the
            # inner module's weights; the resume path loads them into the
            # unwrapped model.
            state_dict = model.module.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'epoch': (epoch + 1),
            }, is_best=False, fpath=osp.join(args.save_dir, 'ckp_ep' + str(epoch + 1) + '.pth.tar'))

In [23]:
from types import SimpleNamespace

# Training configuration consumed by main(); stands in for the command-line
# arguments of the script version.
args = SimpleNamespace(
    lr=1e-05,
    batch_size=80,
    num_instances=5,
    dim=512,
    width=227,
    origin_width=256,
    ratio=0.16,

    alpha=40,
    beta=0.1,
    orth_reg=0,
    k=16,
    margin=0.5,
    init='random',

    freeze_BN=True,
    data='cub',
    # NOTE(review): machine-specific absolute path; adjust per environment.
    data_root='/data/81ac6083fc2442619b54e05eab96c148/',
    net='BN-Inception',
    loss='LiftedStructure',
    epochs=600,
    save_step=50,
    resume=None,

    print_freq=20,
    # Built with os.path.join instead of a backslash literal: '\L', '\c' and
    # '\B' are invalid escape sequences, and on POSIX the old string named a
    # single directory containing backslashes rather than a nested path.
    save_dir=os.path.join(
        'ckps', 'LiftedStructure', 'cub',
        'BN-Inception-dim512-lr0.000010-ratio0.160000-batch80'),
    nThreads=8,
    momentum=0.9,
    weight_decay=5e-4,
    loss_base=0.75)
print(args)

In [None]:
# Start training with the `args` namespace defined in the configuration cell.
main(args)