In [9]:
## Delete the commands below from the converted script (.py) file after conversion.
## They export this notebook to utils.py and then strip unneeded content from it:
## the "# In[...]" cell-marker comment lines, then line 1 and lines 3-14 of the file.
!jupyter nbconvert --to script utils.ipynb
!sed -i '/^#[ ]In\[/d' utils.py
!sed -i -e '1d;3,14d' utils.py

[NbConvertApp] Converting notebook utils.ipynb to script
[NbConvertApp] Writing 3388 bytes to utils.py


In [3]:
import os
import logging
import shutil
import torch
import torchvision.datasets as dset
import numpy as np
from tools import preproc

In [8]:
def get_data(dataset, data_path, cutout_length, validation):
    """Load a torchvision dataset and report its basic geometry.

    Args:
        dataset: Dataset name, matched case-insensitively against
            'cifar10', 'mnist' and 'fashionmnist'.
        data_path: Root directory handed to the torchvision dataset class
            (the data is downloaded there if needed, since download=True).
        cutout_length: Forwarded to ``preproc.data_transforms``.
        validation: When truthy, the test split is appended to the result.

    Returns:
        A list ``[input_size, input_channels, n_classes, trn_data]``, with the
        ``train=False`` dataset appended as a fifth element when
        ``validation`` is truthy.

    Raises:
        ValueError: If the (lower-cased) dataset name is not supported.
        AssertionError: If the training images are not square.
    """
    dataset = dataset.lower()

    # dataset name -> (torchvision.datasets attribute, number of classes)
    supported = {
        'cifar10': ('CIFAR10', 10),
        'mnist': ('MNIST', 10),
        'fashionmnist': ('FashionMNIST', 10),
    }
    if dataset not in supported:
        raise ValueError(dataset)
    cls_attr, n_classes = supported[dataset]
    dset_cls = getattr(dset, cls_attr)

    # data_transforms is unpacked as a (training, validation) transform pair.
    trn_transform, val_transform = preproc.data_transforms(dataset, cutout_length)
    trn_data = dset_cls(root=data_path, train=True, download=True, transform=trn_transform)

    shape = trn_data.data.shape
    # A 4-D data array is treated as colour (3 channels); otherwise the
    # images are treated as grayscale (1 channel).
    if len(shape) == 4:
        input_channels = 3
    else:
        input_channels = 1
    assert shape[1] == shape[2], "not expected shape = {}".format(shape)
    input_size = shape[1]

    result = [input_size, input_channels, n_classes, trn_data]
    if validation:
        val_data = dset_cls(root=data_path, train=False, download=True, transform=val_transform)
        result.append(val_data)

    return result

In [3]:
def get_logger(file_path):
    """Return the shared 'darts' logger, writing to ``file_path`` and a stream.

    ``logging.getLogger('darts')`` hands back the same logger object on every
    call, so handlers attached by an earlier call are still present. To avoid
    emitting every record several times (e.g. after re-running a notebook
    cell), a handler is only attached when no equivalent one exists yet:

    * a ``FileHandler`` is added unless one for the same absolute path is
      already attached (a different path still gets its own handler);
    * a plain ``StreamHandler`` is added unless one writing to the same stream
      is already attached.

    Args:
        file_path: Path of the log file to write to.

    Returns:
        The 'darts' ``logging.Logger`` with its level set to INFO.
    """
    logger = logging.getLogger('darts')
    log_format = '%(asctime)s | %(message)s'
    formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')

    # FileHandler stores its target as an absolute path in ``baseFilename``.
    # Check before constructing the handler, because construction opens the file.
    target_path = os.path.abspath(file_path)
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target_path
        for h in logger.handlers
    )
    if not has_file_handler:
        file_handler = logging.FileHandler(file_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    stream_handler = logging.StreamHandler()
    # FileHandler subclasses StreamHandler, so compare the exact type to
    # match only plain stream handlers.
    has_stream_handler = any(
        type(h) is logging.StreamHandler and h.stream is stream_handler.stream
        for h in logger.handlers
    )
    if not has_stream_handler:
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    logger.setLevel(logging.INFO)

    return logger

In [4]:
def param_size(model):
    """Return the model's parameter count divided by 1024 * 1024.

    Parameters whose name starts with 'aux_head' are excluded. Note that the
    result is a number of parameters in units of 2**20, not a size in bytes.
    """
    total = 0
    for name, tensor in model.named_parameters():
        if name.startswith('aux_head'):
            continue
        total += np.prod(tensor.size())
    return total / 1024. / 1024.

In [5]:
class AverageMeter:
    """Keeps the latest value plus a weighted running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, average, sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times.

        The running sum grows by ``val * n``, the count by ``n``, and the
        average is recomputed as sum / count.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [6]:
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy for each k in ``topk``.

    Args:
        output: 2-D tensor of scores; ``topk`` is taken along dimension 1.
        target: Either a 1-D tensor of class indices, or a tensor with more
            than one dimension (e.g. one-hot rows), which is reduced to class
            indices by taking the argmax along dimension 1.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the fraction of samples whose
        target is among the k highest-scoring predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk largest scores per sample, sorted descending,
    # then transposed so that row i holds every sample's i-th best guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()

    if target.ndimension() > 1:
        # Reduce one-hot style targets to class indices.
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (not view): the slice derives from a transposed tensor and
        # may be non-contiguous, which view(-1) cannot flatten.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(1.0 / batch_size))

    return res

In [7]:
def save_checkpoint(state, ckpt_dir, is_best=False):
    """Save ``state`` to ``ckpt_dir`` and optionally copy it as the best one.

    The state is always written with ``torch.save`` to
    ``<ckpt_dir>/checkpoint.pth.tar``. When ``is_best`` is truthy, that file
    is additionally copied to ``<ckpt_dir>/best.pth.tar``.
    """
    ckpt_path = os.path.join(ckpt_dir, 'checkpoint.pth.tar')
    torch.save(state, ckpt_path)
    if not is_best:
        return
    shutil.copyfile(ckpt_path, os.path.join(ckpt_dir, 'best.pth.tar'))