In [2]:
import argparse
import os
import random
import shutil
import datetime
import warnings
import time
import sys
import numpy as np

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets


In [3]:
### parameter setting ###

data_path = '/dataset/ILSVRC2012/'
GPU_num = 4
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'

start_epoch = 0
epochs = 120

batch_in = 3
workers = 8 * GPU_num
batch_size = 256 * GPU_num * batch_in #baseline = resnet
learning_rate = 0.1 * GPU_num
momentum = 0.9
weight_decay = 1e-4
now = datetime.datetime.now

In [4]:
### model setting ###

### create model ###
## costom models
from models import custom_resnet_quant_2C_pact as network

## torchvision models
# import torchvision.models as models 


print("=> creating model")
## costom models
model = network.ResNet18()

## torchvision
# model = models.resnet18(pretrained=False) 

=> creating model


In [5]:
### GPU setting ###

device_ids=[i for i in range(GPU_num)]

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print("Using", torch.cuda.device_count(), "GPUs")
# DataParallel will divide and allocate batch_size to all available GPUs
model = nn.DataParallel(model, device_ids=device_ids)

model.cuda()

cudnn.benchmark = True

Using 4 GPUs


In [6]:
### loss function & optimizer setting ###

# define loss function (criterion) and optimizer
criterion = nn.NLLLoss()
softmax = nn.Softmax(1)
optimizer = torch.optim.SGD(model.parameters(), learning_rate,
                            momentum=momentum,
                            weight_decay=weight_decay)

In [7]:
### Dataset setting ###

# Data loading code
traindir = os.path.join(data_path, 'train')
valdir = os.path.join(data_path, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))

train_sampler = None

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
    num_workers=workers, pin_memory=True, sampler=train_sampler)

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=batch_size, shuffle=False,
    num_workers=workers, pin_memory=True)

In [8]:
# def lr_schedule(optimizer, epoch):
#     if epoch >= 30 and epoch < 60:
#         lr = 0.01 * GPU_num * batch_in
#     elif epoch >= 60 and epoch < 90:
#         lr = 0.001 * GPU_num * batch_in
#     elif epoch >= 90 and epoch < 120:
#         lr = 0.0001 * GPU_num * batch_in
#     else:
#         lr = 0.1 * GPU_num * batch_in
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = lr
        
def lr_schedule(optimizer, epoch):
    if epoch >= 30 and epoch < 60:
        lr = 0.001 * GPU_num * batch_in
    elif epoch >= 60 and epoch < 90:
        lr = 0.0001 * GPU_num * batch_in
    elif epoch >= 90 and epoch < 120:
        lr = 0.00001 * GPU_num * batch_in
    else:
        lr = 0.01 * GPU_num * batch_in
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [9]:
def train(optimizer, epoch, epochs):
    
    train_loss_temp = 0
    train_loss = 0
    correct = 0
    train_acc = 0
    total = 0

    # switch to train mode
    model.train()

    t = now()
    for batch_idx, (images, target) in enumerate(train_loader):
        # measure data loading time
        optimizer.zero_grad()
        images = images.to(device)
        target = target.to(device)
        
        out = model(images)
        out_sm = softmax(out)
        log = torch.log(out_sm+1e-7)
        loss = criterion(log, target)
        

        loss.backward()
        
        optimizer.step()
        
        train_loss_temp += loss.item()
        _, predicted = out_sm.max(1)
        correct += predicted.eq(target).sum().item()
        total += target.size(0)
        train_loss = train_loss_temp/(batch_idx+1)
        train_acc = 100*correct/total 
        
        sys.stdout.write("\x1b[2K\rTrain.. Epoch: {0:3}/{1:3} | Iter: {2:4}/{3:4} | Loss: {4:.4f} | Acc: {5:.4f}% | Time: {now:}".format(epoch, epochs, batch_idx, len(train_loader), train_loss, train_acc, now=(now()-t)))
        
        sys.stdout.flush()
    
    return train_loss, train_acc

In [10]:
def validate():
    global best_acc
    model.eval()
    correct = 0
    val_loss_temp = 0
    val_loss = 0
    val_acc = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            output = model(inputs)
            out_sm = softmax(output)
            log = torch.log(out_sm+1e-7)
            loss = criterion(log, targets)
            
            val_loss_temp += loss.item()
            _, predicted = out_sm.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)

    val_loss = val_loss_temp/(batch_idx+1)
    val_acc = 100*correct/total
    
#     sys.stdout.write("\x1b[2K\rValidation.. Epoch: {0:3}/120 | Iter: {1:4}/196 | Loss: {2:.4f} | Acc: {3:.4f}%".format(epoch, batch_idx, train_loss, train_acc))
#     sys.stdout.flush()
        
    return val_loss, val_acc

In [11]:
checkpoint = torch.load('./checkpoint/resnet18_PACT_32_32_current__.pth')
model.load_state_dict(checkpoint['state_dict'])
epoch = checkpoint['epoch']
acc = checkpoint['acc']

start_epoch = epoch

print("Current | Epoch: %d | Accuracy: %.4f" %(epoch, acc))

Current | Epoch: 120 | Accuracy: 69.5340


In [None]:
best_acc_max = 0
best_acc_sum = 0

acc_history = torch.zeros([epochs,2], dtype=torch.float32, device=device)
loss_history = torch.zeros([epochs,2], dtype=torch.float32, device=device)

for epoch in range(0, 120):

    #adjust_learning_rate(optimizer, epoch)
    lr_schedule(optimizer, epoch)
    
    current_time = now()
    # train for one epoch
    train_loss, train_acc = train(optimizer, epoch, epochs)
    print("\n1 Epoch Time : %s" % (now()-current_time))
    print('------------------------------Train---------------------------------')
    print('Loss: ', train_loss)
    print('Acc: ', train_acc)
    
    # evaluate on validation set
    val_loss, val_acc = validate()
    print('------------------------------Validation---------------------------------')
    print("Loss: ", val_loss)
    print("Acc: ", val_acc)
    print('-------------------------------------------------------------------------')


    acc_history[epoch-start_epoch][0] = train_acc
    acc_history[epoch-start_epoch][1] = val_acc
    
    loss_history[epoch-start_epoch][0] = train_loss
    loss_history[epoch-start_epoch][1] = val_loss
    
    state = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'acc': val_acc,
        'epoch': epoch+1,
    }
    torch.save(state, './checkpoint/resnet18_PACT_2_8_current.pth')
    
    if val_acc > best_acc_max:
        best_acc_max = val_acc
        state = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'acc': val_acc,
            'epoch': epoch+1,
        }
        torch.save(state, './checkpoint/resnet18_PACT_2_8_max.pth')
        
    # Save checkpoint.
    if epoch % 10 == 0:
        print('Saving..')
        state = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'acc': val_acc,
            'epoch': epoch+1,
        }
        torch.save(state, './checkpoint/resnet18_PACT_2_8_epoch_%d.pth' % epoch)

Train.. Epoch:   0/120 | Iter:  417/ 418 | Loss: 2.2084 | Acc: 50.6932% | Time: 0:13:02.342184[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[2K[

In [None]:
checkpoint = torch.load('./checkpoint/name_max.pth')
model.load_state_dict(checkpoint['state_dict'])
epoch = checkpoint['epoch']
acc = checkpoint['acc']

print("Max | Epoch: %d | Accuracy: %.4f" %(epoch, acc))