In [33]:
import math
import os
import random
import time
import shutil

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

import PIL
import PIL.ImageOps
import PIL.ImageEnhance
import PIL.ImageDraw
from PIL import Image

from tqdm import tqdm

from copy import deepcopy

In [2]:
# Prefer the GPU whenever PyTorch can see one; later cells compare against the string 'cuda'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [3]:
# Project root (relative paths such as ./data and ./checkpoint resolve against it).
# Set the SSL_ROOT environment variable to run elsewhere instead of editing this
# machine-specific absolute path; the default keeps the original location.
ROOT_DIR = os.environ.get("SSL_ROOT", "D:/2021/2학기 수업/CV/SSL/")
os.chdir(ROOT_DIR)
print(os.getcwd())

D:\2021\2학기 수업\CV\SSL


In [4]:
# Path of a checkpoint to resume from; an empty string means train from scratch.
resume = ''

# Schedule: each "epoch" is `eval_steps` optimizer steps; `total_steps` overall.
eval_steps = 2**10
total_steps = 2**20

# Optimization.
batch_size = 64          # labeled batch size; the unlabeled loader uses batch_size * mu
lr = 0.03
weight_decay = 0.0005
exp_mov_avg_decay = 0.999

# FixMatch hyper-parameters.
mu = 7                   # unlabeled-to-labeled batch-size ratio
lambda_u = 1             # weight of the unlabeled (pseudo-label) loss
threshold = 0.95         # minimum confidence for a pseudo-label to count

num_class = 10
num_labeled_data = 40

# Per-channel normalization statistics.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

# Seed every RNG the notebook uses (Python, NumPy, PyTorch CPU, and CUDA if present).
SEED = 5
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == 'cuda':
    torch.cuda.manual_seed_all(SEED)

In [5]:
PARAMETER_MAX = 10

def _float_parameter(v, max_v):
    return float(v) * max_v / PARAMETER_MAX

def _int_parameter(v, max_v):
    return int(v * max_v / PARAMETER_MAX)

def AutoContrast(img, **kwarg):
    """Maximize image contrast; any magnitude keyword arguments are ignored."""
    return PIL.ImageOps.autocontrast(img)


def _enhance(enhancer_cls, img, v, max_v, bias):
    """Run a ``PIL.ImageEnhance`` class on ``img`` with factor ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    return enhancer_cls(img).enhance(factor)


def Brightness(img, v, max_v, bias = 0):
    """Adjust brightness by an enhancement factor derived from level ``v``."""
    return _enhance(PIL.ImageEnhance.Brightness, img, v, max_v, bias)

def Color(img, v, max_v, bias=0):
    """Adjust color saturation by an enhancement factor derived from level ``v``."""
    return _enhance(PIL.ImageEnhance.Color, img, v, max_v, bias)


def Contrast(img, v, max_v, bias=0):
    """Adjust contrast by an enhancement factor derived from level ``v``."""
    return _enhance(PIL.ImageEnhance.Contrast, img, v, max_v, bias)


def Cutout(img, v, max_v, bias=0):
    """Cut out a square whose side is a level-dependent fraction of the shorter image side.

    Level 0 returns the image unchanged.
    """
    if v == 0:
        return img
    fraction = _float_parameter(v, max_v) + bias
    side = int(fraction * min(img.size))
    return CutoutAbs(img, side)


def CutoutAbs(img, v, **kwarg):
    """Paint a gray square of side ``v`` pixels at a uniformly random centre.

    The square is clipped to the image bounds. A copy is modified and returned;
    the input image is not changed. The fill is an RGB triple, so this presumably
    expects an RGB image (TODO confirm for other modes).
    """
    width, height = img.size
    # The centre is drawn x first, then y, from NumPy's global RNG.
    centre_x = np.random.uniform(0, width)
    centre_y = np.random.uniform(0, height)
    left = int(max(0, centre_x - v / 2.))
    top = int(max(0, centre_y - v / 2.))
    right = int(min(width, left + v))
    bottom = int(min(height, top + v))
    gray = (127, 127, 127)
    canvas = img.copy()
    PIL.ImageDraw.Draw(canvas).rectangle((left, top, right, bottom), gray)
    return canvas


def Equalize(img, **kwarg):
    """Histogram-equalize the image; magnitude keyword arguments are ignored."""
    equalized = PIL.ImageOps.equalize(img)
    return equalized


def Identity(img, **kwarg):
    """No-op augmentation: hand back the very same image object."""
    return img


def Invert(img, **kwarg):
    """Invert pixel values; magnitude keyword arguments are ignored."""
    inverted = PIL.ImageOps.invert(img)
    return inverted


def Posterize(img, v, max_v, bias=0):
    """Reduce each channel to ``_int_parameter(v, max_v) + bias`` bits."""
    bits = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.posterize(img, bits)


def Rotate(img, v, max_v, bias=0):
    """Rotate by a level-dependent whole number of degrees with a random sign."""
    degrees = _int_parameter(v, max_v) + bias
    degrees = -degrees if random.random() < 0.5 else degrees
    return img.rotate(degrees)


def Sharpness(img, v, max_v, bias=0):
    """Adjust sharpness by an enhancement factor derived from level ``v``."""
    factor = _float_parameter(v, max_v) + bias
    sharpener = PIL.ImageEnhance.Sharpness(img)
    return sharpener.enhance(factor)


def ShearX(img, v, max_v, bias=0):
    """Shear along x by a level-dependent factor with a random sign (affine transform)."""
    shear = _float_parameter(v, max_v) + bias
    shear = -shear if random.random() < 0.5 else shear
    coefficients = (1, shear, 0, 0, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coefficients)


def ShearY(img, v, max_v, bias=0):
    """Shear along y by a level-dependent factor with a random sign (affine transform)."""
    shear = _float_parameter(v, max_v) + bias
    shear = -shear if random.random() < 0.5 else shear
    coefficients = (1, 0, 0, shear, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coefficients)


def Solarize(img, v, max_v, bias=0):
    """Solarize with threshold ``256 - level``, where level = ``_int_parameter(v, max_v) + bias``.

    Higher levels lower the threshold, so more pixels get inverted
    (see ``PIL.ImageOps.solarize``).
    """
    v = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.solarize(img, 256 - v)


def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
    """Add a random-signed, level-dependent offset to every pixel, clip, then solarize.

    Args:
        img: PIL image.
        v, max_v, bias: level, scale, and offset used to compute the pixel shift.
        threshold: solarize threshold applied after the shift (default 128).

    Returns:
        A new PIL image.
    """
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    # `np.int` (a deprecated alias of the builtin int) was removed in NumPy 1.24.
    # An explicit wide integer dtype keeps `+ v` from wrapping around in uint8
    # before the clip.
    img_np = np.array(img).astype(np.int64)
    img_np = np.clip(img_np + v, 0, 255).astype(np.uint8)
    shifted = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(shifted, threshold)


def TranslateX(img, v, max_v, bias=0):
    """Translate horizontally by a random-signed fraction of the width (truncated to pixels)."""
    fraction = _float_parameter(v, max_v) + bias
    fraction = -fraction if random.random() < 0.5 else fraction
    offset_px = int(fraction * img.size[0])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, offset_px, 0, 1, 0))


def TranslateY(img, v, max_v, bias=0):
    """Translate vertically by a random-signed fraction of the height (truncated to pixels)."""
    fraction = _float_parameter(v, max_v) + bias
    fraction = -fraction if random.random() < 0.5 else fraction
    offset_px = int(fraction * img.size[1])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, offset_px))

def fixMatchAugPool():
    """Return the RandAugment pool used by FixMatch as ``(op, max_v, bias)`` triples.

    ``max_v`` and ``bias`` are passed to each op (``None`` for ops that take no
    magnitude). The list order is significant: ``random.choices`` samples from it,
    so reordering would change which ops a given seed selects.
    """
    return [
        (AutoContrast, None, None),
        (Brightness, 0.9, 0.05),
        (Color, 0.9, 0.05),
        (Contrast, 0.9, 0.05),
        (Equalize, None, None),
        (Identity, None, None),
        (Posterize, 4, 4),
        (Rotate, 30, 0),
        (Sharpness, 0.9, 0.05),
        (ShearX, 0.3, 0),
        (ShearY, 0.3, 0),
        (Solarize, 256, 0),
        (TranslateX, 0.3, 0),
        (TranslateY, 0.3, 0),
    ]



In [6]:
class RandAugmentMC(object):
    """Strong augmentation: ``n`` random ops from the FixMatch pool, then a fixed Cutout.

    Each chosen op gets a random level and is applied with probability 0.5.
    A 16-pixel gray Cutout is always applied at the end.

    Args:
        n: number of ops sampled (with replacement) per call; must be >= 1.
        m: upper bound for the random level; must satisfy 1 <= m <= 10.
    """
    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = fixMatchAugPool()

    def __call__(self, img):
        chosen_ops = random.choices(self.augment_pool, k = self.n)
        for op, max_v, bias in chosen_ops:
            # NumPy's upper bound is exclusive, so the level is in [1, m - 1].
            # NOTE(review): m == 1 passes the assert but makes randint raise.
            # The level is drawn even when the op ends up skipped, so the RNG
            # stream does not depend on the coin flip.
            level = np.random.randint(1, self.m)
            apply_op = random.random() < 0.5
            if apply_op:
                img = op(img, v = level, max_v = max_v, bias = bias)
        return CutoutAbs(img, int(32*0.5))

In [7]:
class TransformFixMatch(object):
    """Turn one PIL image into a normalized ``(weak_view, strong_view)`` tensor pair.

    Weak view: random horizontal flip + reflect-padded random 32x32 crop.
    Strong view: the same flip/crop followed by ``RandAugmentMC(n=2, m=10)``.
    Both views are converted to tensors and normalized with ``mean``/``std``.
    """
    def __init__(self, mean, std):
        pad = int(32 * 0.125)
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size = 32, padding = pad, padding_mode = 'reflect'),
        ])
        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size = 32, padding = pad, padding_mode = 'reflect'),
            RandAugmentMC(n = 2, m = 10),
        ])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean = mean, std = std),
        ])

    def __call__(self, x):
        # The weak view is drawn before the strong one (they share global RNGs).
        weak_view = self.weak(x)
        strong_view = self.strong(x)
        return self.normalize(weak_view), self.normalize(strong_view)

In [8]:
root = './data'

# Labeled images get the same flip-and-crop augmentation as FixMatch's weak view,
# followed by tensor conversion and CIFAR-10 normalization.
labeled_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(size = 32, padding = int(32 * 0.125), padding_mode = 'reflect'),
    transforms.ToTensor(),
    transforms.Normalize(mean = cifar10_mean, std = cifar10_std),
])

# Evaluation: no augmentation, only tensor conversion and normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = cifar10_mean, std = cifar10_std),
])

# Untransformed training split; in this notebook it is read only for its `targets`.
full_dataset = datasets.CIFAR10(root, train = True, download = False)

In [9]:
# Pick the same number of labeled examples from every class.
label_per_class = num_labeled_data // num_class
labels = np.array(full_dataset.targets)
# Every training image (labeled ones included) is also used as unlabeled data.
train_unlabeled_idxs = np.arange(len(labels))

per_class_picks = []
for cls in range(num_class):
    class_members = np.where(labels == cls)[0]
    per_class_picks.append(np.random.choice(class_members, label_per_class, replace=False))
train_labeled_idxs = np.concatenate(per_class_picks)
print(len(train_labeled_idxs))

# Repeat the tiny labeled set so one pass of the labeled loader covers
# at least `eval_steps` batches of `batch_size`.
num_expand_x = math.ceil(batch_size * eval_steps / num_labeled_data)
train_labeled_idxs = np.tile(train_labeled_idxs, num_expand_x)
print(len(train_labeled_idxs), len(train_unlabeled_idxs))

40
65560 50000


In [10]:
class Cifar10SSL(datasets.CIFAR10):
    """CIFAR-10 restricted to (and possibly repeating) the samples at ``idxs``.

    ``idxs=None`` keeps the full split. Items are returned as
    ``(transform(PIL image), target_transform(target))``.
    """
    def __init__(self, root, idxs, train = True, transform = None, target_transform = None, download = False):
        super().__init__(root, train = train, transform = transform,
                         target_transform = target_transform, download = download)
        if idxs is None:
            return
        self.data = self.data[idxs]
        self.targets = np.array(self.targets)[idxs]

    def __getitem__(self, idx):
        image = Image.fromarray(self.data[idx])
        label = self.targets[idx]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

In [11]:
# Labeled subset (repeated indices) with the weak-style augmentation.
train_labeled_dataset = Cifar10SSL(root, train_labeled_idxs, train=True, transform=labeled_transform)

# Unlabeled data yields ((weak, strong), target) pairs via TransformFixMatch.
fixmatch_transform = TransformFixMatch(mean = cifar10_mean, std = cifar10_std)
train_unlabeled_dataset = Cifar10SSL(root, train_unlabeled_idxs, train=True, transform=fixmatch_transform)

test_dataset = datasets.CIFAR10(root, train = False, transform = test_transform, download = False)

In [12]:
labeled_sampler = RandomSampler(train_labeled_dataset)
unlabeled_sampler = RandomSampler(train_unlabeled_dataset)
test_sampler = SequentialSampler(test_dataset)

# drop_last keeps every training batch at its full size
# (batch_size labeled items, batch_size * mu unlabeled items).
labeled_trainloader = DataLoader(train_labeled_dataset, sampler = labeled_sampler,
                                 batch_size = batch_size, drop_last = True)
unlabeled_trainloader = DataLoader(train_unlabeled_dataset, sampler = unlabeled_sampler,
                                   batch_size = batch_size * mu, drop_last = True)
test_loader = DataLoader(test_dataset, sampler = test_sampler, batch_size = batch_size)
                         

In [13]:
class WRNBasicBlock(torch.nn.Module):
    """Pre-activation residual block: (BN -> LeakyReLU(0.1) -> 3x3 conv) twice, plus a shortcut.

    Args:
        in_planes: input channel count.
        out_planes: output channel count.
        stride: stride of the first 3x3 conv (and of the 1x1 shortcut conv, if any).
        drop_rate: dropout probability applied between the two convs (only while training).
        activate_before_residual: when the channel count changes, feed the
            BN + LeakyReLU activated input to both conv1 and the shortcut instead
            of the raw input.
    """
    def __init__(self, in_planes, out_planes, stride, drop_rate = 0.0, activate_before_residual = False):
        super(WRNBasicBlock, self).__init__()
        self.bn1 = torch.nn.BatchNorm2d(in_planes, momentum = 0.001)
        self.relu1 = torch.nn.LeakyReLU(negative_slope = 0.1, inplace = True)
        self.conv1 = torch.nn.Conv2d(in_planes, out_planes, kernel_size = 3, stride = stride, padding = 1, bias = False)
        
        self.bn2 = torch.nn.BatchNorm2d(out_planes, momentum = 0.001)
        self.relu2 = torch.nn.LeakyReLU(negative_slope = 0.1, inplace = True)
        self.conv2 = torch.nn.Conv2d(out_planes, out_planes, kernel_size = 3, stride = 1, padding = 1, bias = False)
        
        self.drop_rate = drop_rate
        self.equalInOut = (in_planes == out_planes)
        # 1x1 projection shortcut only when the channel count changes; None otherwise (identity shortcut).
        self.convShortcut = (not self.equalInOut) and torch.nn.Conv2d(in_planes, out_planes, kernel_size = 1, stride = stride,
                                                                padding = 0, bias = False) or None
        self.activate_before_residual = activate_before_residual
    
    def forward(self, x):
        if not self.equalInOut and self.activate_before_residual == True:
            # Pre-activate x itself, so both conv1 and the projection shortcut see the activated input.
            x = self.relu1(self.bn1(x))
        else:
            # Pre-activated copy used by conv1 when channels match. When channels differ
            # (and activate_before_residual is False) this value is computed but unused:
            # conv1 and the shortcut below both take the raw x.
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        
        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate, training = self.training)
        out = self.conv2(out)
        
        # Identity shortcut when channels match, otherwise the 1x1 projection.
        # NOTE(review): the identity path assumes stride == 1 whenever in_planes == out_planes;
        # otherwise the spatial sizes would not line up.
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)

In [14]:
class WRNNetworkBlock(torch.nn.Module):
    """One WRN stage: ``nb_layers`` residual blocks of type ``block`` in sequence.

    Only the first block maps ``in_planes -> out_planes`` and uses ``stride``; the
    remaining blocks map ``out_planes -> out_planes`` with stride 1. ``drop_rate``
    and ``activate_before_residual`` are forwarded unchanged to every block.
    """
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate = 0.0, activate_before_residual = False):
        super(WRNNetworkBlock, self).__init__()
        self.layer = self._make_layer(
            block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual)
        
    def _make_layer(
            self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual):
        """Build the stage as an ``nn.Sequential``; ``nb_layers`` may be a float and is truncated with ``int``."""
        layers = []
        for i in range(int(nb_layers)):
            is_first = i == 0
            # Conditional expressions instead of `cond and a or b`, which silently
            # falls through to `b` whenever `a` is falsy (e.g. 0).
            layers.append(block(in_planes if is_first else out_planes, out_planes,
                                stride if is_first else 1, drop_rate, activate_before_residual))
        return torch.nn.Sequential(*layers)
    
    def forward(self, x):
        return self.layer(x)

In [15]:
class WRN(torch.nn.Module):
    """Wide ResNet classifier (WRN-``depth``-``widen_factor``; WRN-28-2 with the defaults).

    Layout: 3x3 stem conv -> three WRNNetworkBlock stages (strides 1, 2, 2) ->
    BatchNorm + LeakyReLU(0.1) -> global average pooling -> linear layer.

    Args:
        num_classes: number of output logits.
        depth: total network depth; ``depth - 4`` must be divisible by 6.
        widen_factor: channel multiplier applied to the three stages.
        drop_rate: dropout probability passed to every basic block.
    """
    def __init__(self, num_classes, depth = 28, widen_factor = 2, drop_rate = 0.0):
        super(WRN, self).__init__()
        # Stem width followed by the widths of the three stages.
        channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        
        # The basic blocks contribute a depth of 6 per unit of n (3 stages x n blocks x 2 convs);
        # the layers outside the basic blocks contribute a depth of 4.
        # Therefore depth - 4 must be a multiple of 6.
        assert((depth - 4) % 6 == 0)
        
        # Basic blocks per stage (a float here; WRNNetworkBlock truncates it with int()).
        n = (depth - 4) / 6
        block = WRNBasicBlock
        
        self.conv1 = torch.nn.Conv2d(3, channels[0], kernel_size = 3, stride = 1, padding = 1, bias = False)
        
        # Only the first stage pre-activates its input before the residual shortcut.
        self.block1 = WRNNetworkBlock(
            n, channels[0], channels[1], block, 1, drop_rate, activate_before_residual = True)
        
        self.block2 = WRNNetworkBlock(
            n, channels[1], channels[2], block, 2, drop_rate)
        
        self.block3 = WRNNetworkBlock(
            n, channels[2], channels[3], block, 2, drop_rate)
        
        self.bn = torch.nn.BatchNorm2d(channels[3], momentum = 0.001)
        self.relu = torch.nn.LeakyReLU(negative_slope = 0.1, inplace = True)
        self.fc = torch.nn.Linear(channels[3], num_classes)
        self.channels = channels[3]
        
        # Re-initialize every submodule: Kaiming-normal (fan_out) for convs,
        # weight 1 / bias 0 for BatchNorm, Xavier-normal weights and zero bias for the linear layer.
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode = 'fan_out', nonlinearity = 'leaky_relu')
            elif isinstance(m, torch.nn.BatchNorm2d):
                torch.nn.init.constant_(m.weight, 1.0)
                torch.nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_normal_(m.weight)
                torch.nn.init.constant_(m.bias, 0.0)
        
    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn(out))
        # Global average pooling to 1x1, then flatten to (batch, channels) for the classifier.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(-1, self.channels)
        return self.fc(out)

In [16]:
# WRN built with the class defaults (depth=28, widen_factor=2, no dropout), moved to `device`.
model = WRN(num_classes=num_class)
model.to(device)

WRN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): WRNNetworkBlock(
    (layer): Sequential(
      (0): WRNBasicBlock(
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
        (relu1): LeakyReLU(negative_slope=0.1, inplace=True)
        (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
        (relu2): LeakyReLU(negative_slope=0.1, inplace=True)
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (convShortcut): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): WRNBasicBlock(
        (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
        (relu1): LeakyReLU(negative_slope=0.1, inplace=True)
        (conv1): Conv2d(32, 32, kernel_size

In [17]:
# Parameters whose name contains any of these substrings (biases, BatchNorm
# weights) are excluded from weight decay.
no_decay = ['bias', 'bn']
decay_params = []
no_decay_params = []
for param_name, param in model.named_parameters():
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

grouped_parameters = [
    {'params': decay_params, 'weight_decay': weight_decay},
    {'params': no_decay_params, 'weight_decay': 0.0},
]
optimizer = optim.SGD(grouped_parameters, lr=lr, momentum = 0.9, nesterov = True)

In [18]:
def getCosScheduleWithWarmup(optimizer,
                             num_warmup_steps,
                             num_training_steps,
                             num_cycles = 7./16.,
                             last_epoch = -1):
    """Build a ``LambdaLR`` with linear warmup followed by a truncated cosine decay.

    The LR multiplier rises linearly from 0 to 1 over ``num_warmup_steps``. It then
    follows ``cos(pi * num_cycles * progress)`` (floored at 0), where ``progress``
    goes from 0 to 1 over the remaining ``num_training_steps - num_warmup_steps`` steps.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_multiplier(step):
        if step >= num_warmup_steps:
            progress = float(step - num_warmup_steps) / decay_span
            return max(0., math.cos(math.pi * num_cycles * progress))
        return float(step) / warmup_span

    return LambdaLR(optimizer, lr_multiplier, last_epoch)

In [19]:
# Number of train/evaluate/checkpoint rounds needed to cover `total_steps`.
epochs = math.ceil(total_steps / eval_steps)
# Cosine decay over the whole run, with no warmup.
scheduler = getCosScheduleWithWarmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

In [20]:
# exponential moving avg
class ModelEMA(object):
    """Frozen copy of a model whose parameters track an exponential moving average.

    After each ``update(model)``:
    ``ema_param = decay * ema_param + (1 - decay) * model_param``.
    Buffers (e.g. BatchNorm running stats) are copied verbatim. The copy lives on
    ``device``, is in eval mode, and has ``requires_grad`` disabled.
    """
    def __init__(self, model, decay):
        self.ema = deepcopy(model)
        self.ema.to(device)
        self.ema.eval()
        self.decay = decay
        # True when the EMA copy is itself wrapped (e.g. has a `.module` attribute).
        self.ema_has_module = hasattr(self.ema, 'module')
        self.param_keys = [name for name, _ in self.ema.named_parameters()]
        self.buffer_keys = [name for name, _ in self.ema.named_buffers()]
        for frozen in self.ema.parameters():
            frozen.requires_grad_(False)

    def update(self, model):
        # A wrapped source model prefixes its state-dict keys with 'module.'.
        if hasattr(model, 'module') and not self.ema_has_module:
            prefix = 'module.'
        else:
            prefix = ''
        with torch.no_grad():
            src_state = model.state_dict()
            ema_state = self.ema.state_dict()
            for key in self.param_keys:
                src_value = src_state[prefix + key].detach()
                ema_value = ema_state[key]
                ema_state[key].copy_(ema_value * self.decay + (1. - self.decay) * src_value)
            for key in self.buffer_keys:
                ema_state[key].copy_(src_state[prefix + key])

In [21]:
# EMA of the model weights, evaluated instead of the raw model; the decay comes from the config cell.
ema_model = ModelEMA(model, exp_mov_avg_decay)

In [22]:
start_epoch = 0
# Default when not resuming; overwritten from the checkpoint below.
best_acc = 0

if resume:
    # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine
    # (and vice versa). NOTE: torch.load unpickles the file, so only resume from
    # checkpoints you trust.
    checkpoint = torch.load(resume, map_location=device)
    best_acc = checkpoint['best_acc']
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    ema_model.ema.load_state_dict(checkpoint['ema_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])

In [23]:
class AvgMeter(object):
    """Tracks the most recent value and a count-weighted running mean."""
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the running sum, the count, and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

In [24]:
def interleave(x, size):
    """Reorder the leading dim by viewing it as ``(-1, size)`` and reading column-major.

    ``deInterleave(interleave(x, size), size)`` restores ``x``.
    """
    tail = list(x.shape[1:])
    grid = x.reshape([-1, size] + tail)
    return grid.transpose(0, 1).reshape([-1] + tail)

def deInterleave(x, size):
    """Inverse of ``interleave``: view the leading dim as ``(size, -1)`` and read column-major."""
    tail = list(x.shape[1:])
    grid = x.reshape([size, -1] + tail)
    return grid.transpose(0, 1).reshape([-1] + tail)

In [25]:
def calAccuracy(output, target, topk = (1,)):
    """Top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: (batch, classes) scores.
        target: (batch,) class indices.
        topk: the k values to evaluate.

    Returns:
        A list of 0-d tensors, one per k, in the order of ``topk``.
    """
    n_samples = target.size(0)
    # (maxk, batch): row r holds every sample's (r+1)-th best class.
    ranked = output.topk(max(topk), 1, True, True)[1].t()
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))
    return [hits[:k].reshape(-1).float().sum(0).mul_(100.0 / n_samples) for k in topk]

In [30]:
def test(test_loader, model, epoch):
    """Evaluate ``model`` on ``test_loader`` and show running loss / top-1 / top-5.

    Args:
        test_loader: iterable of ``(inputs, targets)`` batches.
        model: network to evaluate; it is switched to eval mode and left there.
        epoch: unused; kept so existing call sites keep working.

    Returns:
        ``(mean cross-entropy loss, mean top-1 accuracy in percent)``, both
        averaged over samples (weighted by batch size).
    """
    batch_time = AvgMeter()
    data_time = AvgMeter()
    losses = AvgMeter()
    top1 = AvgMeter()
    top5 = AvgMeter()

    # Eval mode is loop-invariant: set it once instead of on every batch.
    model.eval()
    end = time.time()

    test_loader = tqdm(test_loader, disable = False)

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            data_time.update(time.time() - end)

            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)

            prec1, prec5 = calAccuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.shape[0])
            top1.update(prec1.item(), inputs.shape[0])
            top5.update(prec5.item(), inputs.shape[0])
            batch_time.update(time.time() - end)
            end = time.time()
            test_loader.set_description("Test Iter: {batch:4}/{iter:4}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. top1: {top1:.2f}. top5: {top5:.2f}. ".format(
                batch=batch_idx + 1,
                iter=len(test_loader),
                data=data_time.avg,
                bt=batch_time.avg,
                loss=losses.avg,
                top1=top1.avg,
                top5=top5.avg,
            ))

        test_loader.close()

    return losses.avg, top1.avg

In [34]:
# FixMatch training: each "epoch" runs `eval_steps` optimizer steps on paired
# labeled / unlabeled batches, then evaluates the EMA model and saves a checkpoint.


def _next_batch(batch_iter, loader):
    """Return ``(batch, batch_iter)``, restarting the iterator from ``loader`` once it is exhausted."""
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


model.zero_grad()
end = time.time()
# Keep the best accuracy restored by the resume cell instead of resetting it to 0.
best_acc = checkpoint['best_acc'] if resume else 0

# torch.save below fails if the directory does not exist yet.
os.makedirs('./checkpoint', exist_ok=True)

labeled_iter = iter(labeled_trainloader)
unlabeled_iter = iter(unlabeled_trainloader)

model.train()
for epoch in range(start_epoch, epochs):
    batch_time = AvgMeter()
    data_time = AvgMeter()
    losses = AvgMeter()
    losses_x = AvgMeter()
    losses_u = AvgMeter()
    mask_probs = AvgMeter()

    p_bar = tqdm(total = eval_steps, disable = False)

    for batch_idx in range(eval_steps):
        # Each loader is restarted independently whenever it runs out.
        (inputs_x, targets_x), labeled_iter = _next_batch(labeled_iter, labeled_trainloader)
        ((inputs_u_w, inputs_u_s), _), unlabeled_iter = _next_batch(unlabeled_iter, unlabeled_trainloader)

        data_time.update(time.time() - end)
        num_x = inputs_x.shape[0]
        # One forward pass over labeled + weak + strong views, interleaved before
        # the model and de-interleaved afterwards.
        inputs = interleave(torch.cat((inputs_x, inputs_u_w, inputs_u_s)), 2*mu + 1).to(device)
        targets_x = targets_x.to(device)
        logits = model(inputs)
        logits = deInterleave(logits, 2*mu + 1)
        logits_x = logits[:num_x]
        logits_u_w, logits_u_s = logits[num_x:].chunk(2)
        del logits

        targets_x = targets_x.long()
        Lx = F.cross_entropy(logits_x, targets_x, reduction = 'mean')

        # Pseudo-labels come from the weak view; only predictions with confidence
        # >= threshold contribute to the strong-view loss.
        pseudo_label = torch.softmax(logits_u_w.detach(), dim = -1)
        max_probs, targets_u = torch.max(pseudo_label, dim = -1)
        mask = max_probs.ge(threshold).float()
        Lu = (F.cross_entropy(logits_u_s, targets_u, reduction='none') * mask).mean()

        loss = Lx + lambda_u * Lu

        loss.backward()
        losses.update(loss.item())
        losses_x.update(Lx.item())
        losses_u.update(Lu.item())
        optimizer.step()
        scheduler.step()
        ema_model.update(model)
        model.zero_grad()

        batch_time.update(time.time() - end)
        end = time.time()
        mask_probs.update(mask.mean().item())
        p_bar.set_description("Epoch: {epoch}/{epochs:4}. Iter: {batch:4}/{iter:4}. Loss: {loss:.4f}. Loss_x: {loss_x:.4f}. Loss_u: {loss_u:.4f}. Mask: {mask:.2f}. ".format(
            epoch=epoch + 1,
            epochs=epochs,
            batch=batch_idx + 1,
            iter=eval_steps,
            loss=losses.avg,
            loss_x=losses_x.avg,
            loss_u=losses_u.avg,
            mask=mask_probs.avg))
        p_bar.update()

    p_bar.close()

    # Evaluate the EMA weights rather than the raw model.
    test_model = ema_model.ema

    test_loss, test_acc = test(test_loader, test_model, epoch)

    is_best = test_acc > best_acc
    best_acc = max(test_acc, best_acc)

    model_to_save = model.module if hasattr(model, 'module') else model
    ema_to_save = ema_model.ema.module if hasattr(ema_model.ema, 'module') else ema_model.ema

    save_checkpoint_dir = './checkpoint/checkpoint.pth.tar'
    state = {'epoch': epoch + 1,
             'state_dict': model_to_save.state_dict(),
             'ema_state_dict': ema_to_save.state_dict(),
             'acc': test_acc,
             'best_acc': best_acc,
             'optimizer': optimizer.state_dict(),
             'scheduler': scheduler.state_dict(),
            }

    torch.save(state, save_checkpoint_dir)
    if is_best:
        shutil.copyfile(save_checkpoint_dir, './checkpoint/model_best.pth.tar')
    


  0%|                                                                                         | 0/1024 [00:54<?, ?it/s][A
Test Iter:  157/ 157. Data: 0.012s. Batch: 0.021s. Loss: 2.9075. top1: 21.43. top5: 72.11. : 100%|█| 157/157 [00:03<00:
  0%|                                                                                         | 0/1024 [00:03<?, ?it/s]

  0%|                                                                                          | 0/157 [00:00<?, ?it/s][A
Test Iter:    1/ 157. Data: 0.014s. Batch: 0.023s. Loss: 2.6418. top1: 21.88. top5: 81.25. :   0%| | 0/157 [00:00<?, ?i[A
Test Iter:    2/ 157. Data: 0.014s. Batch: 0.023s. Loss: 2.7397. top1: 23.44. top5: 77.34. :   0%| | 0/157 [00:00<?, ?i[A
Test Iter:    3/ 157. Data: 0.014s. Batch: 0.023s. Loss: 2.7790. top1: 20.83. top5: 77.60. :   0%| | 0/157 [00:00<?, ?i[A
Test Iter:    4/ 157. Data: 0.014s. Batch: 0.023s. Loss: 2.8472. top1: 21.48. top5: 74.61. :   0%| | 0/157 [00:00<?, ?i[A
Test Iter:    5/ 157

KeyboardInterrupt: 