# FlexMatch PyTorch Implementation

## 필요한 패키지 중 이미 다운 받아진 패키지 부르기

In [1]:
import sys, os, copy, random, argparse, math
import numpy as np

import PIL
import PIL.ImageOps
import PIL.ImageEnhance
import PIL.ImageDraw
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter

from torchvision import datasets
from torchvision import transforms
from colorama import Fore
from tqdm import tqdm

  warn(f"Failed to load image Python extension: {e}")


## Global variable 정의하기
#### PARAMETER_MAX, cifar10의 mean, std

In [2]:
# PARAMETER_MAX is the top of the integer magnitude scale used by the augmentation
# helpers below: _float_parameter / _int_parameter map a magnitude v to
# v * max_v / PARAMETER_MAX, so v == PARAMETER_MAX yields an op's full max_v.
# RandAugmentMC's strength m is asserted to lie in 1..10 to stay on this scale.
PARAMETER_MAX = 10

# Per-channel mean and standard deviation passed to transforms.Normalize for CIFAR-10 images
mean_cifar10 = (0.4914, 0.4822, 0.4465)
std_cifar10 = (0.2471, 0.2345, 0.2616)

In [3]:
def _float_parameter(v, max_v):
    """Map magnitude ``v`` (on a 0..PARAMETER_MAX scale) linearly onto ``[0, max_v]`` as a float."""
    numerator = float(v) * max_v
    return numerator / PARAMETER_MAX


def _int_parameter(v, max_v):
    """Integer variant of the magnitude mapping: ``int(v * max_v / PARAMETER_MAX)``."""
    scaled = v * max_v / PARAMETER_MAX
    return int(scaled)

## PIL 패키지 내 각종 Data Augmentation 함수 정의

In [4]:
# Augmentation 함수들을 정의

def AutoContrast(img, **kwargs):
    """Apply PIL auto-contrast to ``img``; magnitude keyword arguments are ignored."""
    contrasted = PIL.ImageOps.autocontrast(img)
    return contrasted


def Brightness(img, v, max_v, bias=0):
    """Change brightness with enhancement factor ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)


def Color(img, v, max_v, bias=0):
    """Change color saturation with enhancement factor ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Color(img)
    return enhancer.enhance(factor)


def Contrast(img, v, max_v, bias=0):
    """Change contrast with enhancement factor ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)


def CutoutAbs(img, v, **kwargs):
    """Return a copy of ``img`` with a gray square (side up to ``v`` pixels) painted over it.

    A point is drawn uniformly inside the image; the square's top-left corner is that
    point shifted by ``v / 2`` (clamped at 0), and the square is clipped to the image.
    The fill colour is a 3-tuple, so this presumably expects an RGB image.
    """
    width, height = img.size
    cx = np.random.uniform(0, width)
    cy = np.random.uniform(0, height)

    left = int(max(0, cx - v / 2.))
    top = int(max(0, cy - v / 2.))
    right = int(min(width, left + v))
    bottom = int(min(height, top + v))

    gray = (127, 127, 127)
    patched = img.copy()  # never modify the caller's image
    PIL.ImageDraw.Draw(patched).rectangle((left, top, right, bottom), gray)
    return patched


def Cutout(img, v, max_v, bias=0):
    """Cut out a square whose side is a fraction of the shorter image side.

    The fraction is ``_float_parameter(v, max_v) + bias``; a magnitude of 0 returns
    ``img`` unchanged without touching it.
    """
    if v != 0:
        fraction = _float_parameter(v, max_v) + bias
        side = int(fraction * min(img.size))
        return CutoutAbs(img, side)
    return img


def Equalize(img, **kwargs):
    """Histogram-equalize ``img``; magnitude keyword arguments are ignored."""
    equalized = PIL.ImageOps.equalize(img)
    return equalized


def Identity(img, **kwargs):
    """No-op augmentation: return ``img`` itself.

    Keyword arguments are accepted only so this op can be called like every other
    op in the pool (``op(img, v=..., max_v=..., bias=...)``).
    """
    del kwargs
    return img


def Invert(img, **kwargs):
    """Invert the colours of ``img``; magnitude keyword arguments are ignored."""
    inverted = PIL.ImageOps.invert(img)
    return inverted


def Posterize(img, v, max_v, bias=0):
    """Posterize ``img``, keeping ``_int_parameter(v, max_v) + bias`` bits per channel."""
    bits = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.posterize(img, bits)


def Rotate(img, v, max_v, bias=0):
    """Rotate by ``_int_parameter(v, max_v) + bias`` degrees, negated with probability 0.5."""
    angle = _int_parameter(v, max_v) + bias
    flip_sign = random.random() < 0.5
    return img.rotate(-angle if flip_sign else angle)


def Sharpness(img, v, max_v, bias=0):
    """Change sharpness with enhancement factor ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)


def ShearX(img, v, max_v, bias=0):
    """Shear along x with coefficient ``_float_parameter(v, max_v) + bias``, sign chosen at random."""
    magnitude = _float_parameter(v, max_v) + bias
    shear = -magnitude if random.random() < 0.5 else magnitude
    coeffs = (1, shear, 0, 0, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coeffs)


def ShearY(img, v, max_v, bias=0):
    """Shear along y with coefficient ``_float_parameter(v, max_v) + bias``, sign chosen at random."""
    magnitude = _float_parameter(v, max_v) + bias
    shear = -magnitude if random.random() < 0.5 else magnitude
    coeffs = (1, 0, 0, shear, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coeffs)


def Solarize(img, v, max_v, bias=0):
    """Solarize with threshold ``256 - (_int_parameter(v, max_v) + bias)``.

    Larger magnitudes lower the threshold, so more pixels are affected.
    """
    strength = _int_parameter(v, max_v) + bias
    threshold = 256 - strength
    return PIL.ImageOps.solarize(img, threshold)


def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
    """Add a random-signed offset to every pixel, clip to [0, 255], then solarize.

    The offset is ``±(_int_parameter(v, max_v) + bias)``, with the sign chosen with
    probability 0.5; solarization uses ``threshold``.
    """
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    # ``np.int`` was removed in NumPy 1.24. Use an explicit signed 64-bit dtype so
    # that adding a (possibly negative) offset to uint8 pixels cannot wrap around
    # before clipping.
    img_np = np.array(img).astype(np.int64)
    img_np = np.clip(img_np + v, 0, 255).astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)


def TranslateX(img, v, max_v, bias=0):
    """Affine x-translation by a random-signed fraction of the image width.

    The fraction is ``_float_parameter(v, max_v) + bias``.
    """
    fraction = _float_parameter(v, max_v) + bias
    fraction = -fraction if random.random() < 0.5 else fraction
    dx = int(fraction * img.size[0])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, dx, 0, 1, 0))


def TranslateY(img, v, max_v, bias=0):
    """Affine y-translation by a random-signed fraction of the image height.

    The fraction is ``_float_parameter(v, max_v) + bias``.
    """
    fraction = _float_parameter(v, max_v) + bias
    fraction = -fraction if random.random() < 0.5 else fraction
    dy = int(fraction * img.size[1])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, dy))

In [5]:
# RandAugment를 사용하기 위한 전체 Augmentation List를 정의

def flexmatch_augment_pool():
    '''
    Return the full set of candidate augmentations for RandAugmentMC.

    Each entry is ``(op, max_v, bias)``: ``op`` is called as
    ``op(img, v=..., max_v=max_v, bias=bias)``. Entries with ``None`` belong to ops
    that ignore the magnitude arguments.
    '''
    return [
        (AutoContrast, None, None),
        (Brightness, 0.9, 0.05),
        (Color, 0.9, 0.05),
        (Contrast, 0.9, 0.05),
        (Equalize, None, None),
        (Identity, None, None),
        (Posterize, 4, 4),
        (Rotate, 30, 0),
        (Sharpness, 0.9, 0.05),
        (ShearX, 0.3, 0),
        (ShearY, 0.3, 0),
        (Solarize, 256, 0),
        (TranslateX, 0.3, 0),
        (TranslateY, 0.3, 0),
    ]

## RandAugment를 위한 class 정의

In [6]:
# 위에서 구현된 Augmentpool에서 랜덤으로 선정하여 실제 Augmentation을 구현

class RandAugmentMC(object):
    '''
    RandAugment-style strong augmentation built on flexmatch_augment_pool().

    Each call:
    1. draws ``n`` ops from the pool (with replacement),
    2. draws an integer magnitude for each op in ``[1, max(m - 1, 1)]``,
    3. applies each drawn op independently with probability 0.5,
    4. always finishes with a Cutout whose side is half the shorter image side
       (16 px for 32x32 CIFAR images).
    '''

    def __init__(self, n, m):
        '''
        n: number of ops drawn per call (>= 1)
        m: magnitude ceiling on the 1..10 (PARAMETER_MAX) scale
        augment_pool: list of (op, max_v, bias) candidates
        '''
        assert n >= 1
        assert 1 <= m <= 10

        self.n = n
        self.m = m
        self.augment_pool = flexmatch_augment_pool()

    def __call__(self, img):
        ops = random.choices(self.augment_pool, k=self.n)

        for op, max_v, bias in ops:
            # randint's upper bound is exclusive. Keep the reference range [1, m - 1]
            # for m >= 2, but avoid randint(1, 1) raising ValueError when m == 1
            # (a value the assertion in __init__ accepts).
            v = np.random.randint(1, max(self.m, 2))
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)

        # Half the shorter side: identical to the previous fixed int(32 * 0.5) for
        # 32x32 inputs, but also correct for other resolutions.
        img = CutoutAbs(img, int(min(img.size) * 0.5))

        return img

In [7]:
# train_data를 생성하는 함수

class CIFAR10_SSL(datasets.CIFAR10):
    '''
    CIFAR-10 dataset optionally restricted to a subset of sample indices.

    ``self.data`` holds the images and ``self.targets`` the labels; when ``indexs``
    is not None both are reduced to those positions (targets become a NumPy array).
    '''

    def __init__(self, root, indexs, train=True,
                transform=None, target_transform=None,
                download=False):
        super().__init__(
            root, train=train, transform=transform,
            target_transform=target_transform, download=download
        )
        if indexs is None:
            return
        self.data = self.data[indexs]
        self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        '''
        Return ``(image, target)`` for position ``index``.

        The stored array is wrapped as a PIL image, then ``transform`` and
        ``target_transform`` are applied when set.
        '''
        img = Image.fromarray(self.data[index])
        target = self.targets[index]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

In [8]:
# weak_augmentation과 strong_augmentation된 객체를 반환

class TransformFlexMatch(object):
    '''
    Produce a (weak, strong) pair of normalized views from one PIL image.

    weak_transform: random horizontal flip + reflect-padded random 32x32 crop
    strong_transform: the same flip/crop followed by RandAugmentMC(n=2, m=10)
    normalize: ToTensor (PIL image -> C x H x W tensor) then Normalize(mean, std)
    '''

    def __init__(self, mean=mean_cifar10, std=std_cifar10):
        def _flip_and_crop():
            return [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(size=32,
                                      padding=int(32*0.125),
                                      padding_mode='reflect'),
            ]

        self.weak_transform = transforms.Compose(_flip_and_crop())
        self.strong_transform = transforms.Compose(
            _flip_and_crop() + [RandAugmentMC(n=2, m=10)]
        )
        self.normalize = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
        )

    def __call__(self, x):
        '''Return ``(normalize(weak(x)), normalize(strong(x)))``; weak is drawn first.'''
        weak_view = self.weak_transform(x)
        strong_view = self.strong_transform(x)
        return self.normalize(weak_view), self.normalize(strong_view)

In [9]:
# Labeled data와 Unlabeled data를 분리

def split_labeled_unlabeled(args, labels):
    '''
    Split sample indices into labeled / unlabeled / validation sets, per class.

    For each class c in ``range(args.n_classes)``:
    - the first ``args.n_labeled // args.n_classes`` indices become labeled,
    - the last ``n_val`` indices become validation
      (``n_val = args.n_val_per_class`` if present, else 500),
    - everything in between becomes unlabeled.
    Each of the three index lists is then shuffled with ``np.random.shuffle``.

    Returns (labeled_idx, unlabeled_idx, val_idx) as NumPy arrays.
    Raises ValueError if a class has too few samples for the labeled and
    validation portions not to overlap.
    '''
    label_per_class = args.n_labeled // args.n_classes
    n_val = getattr(args, 'n_val_per_class', 500)
    labels = np.array(labels, dtype=int)
    indice_labeled, indice_unlabeled, indice_val = [], [], []

    # Iterate over the configured classes (previously hard-coded to range(10)).
    for c in range(args.n_classes):
        indice_tmp = np.where(labels == c)[0]
        n_total = len(indice_tmp)
        if label_per_class + n_val > n_total:
            raise ValueError(
                f"class {c} has {n_total} samples, fewer than "
                f"{label_per_class} labeled + {n_val} validation"
            )
        # Explicit split point instead of a negative slice, so n_val == 0 works too.
        val_start = n_total - n_val

        indice_labeled.extend(indice_tmp[:label_per_class])
        indice_unlabeled.extend(indice_tmp[label_per_class:val_start])
        indice_val.extend(indice_tmp[val_start:])

    # Same shuffle order as before so seeded runs draw the same random numbers.
    for indices in (indice_labeled, indice_unlabeled, indice_val):
        np.random.shuffle(indices)

    return np.array(indice_labeled), np.array(indice_unlabeled), np.array(indice_val)

In [10]:
def get_cifar10(args, data_dir):
    '''
    Build the CIFAR-10 datasets used for semi-supervised training.

    - labeled: random flip + reflect-padded random crop, then ToTensor + Normalize
    - unlabeled: TransformFlexMatch, which yields a (weak, strong) pair of views
    - validation: held-out indices of the training split, ToTensor + Normalize only
    - test: the official test split, ToTensor + Normalize only
    Indices come from split_labeled_unlabeled(args, targets).

    Returns (labeled_dataset, unlabeled_dataset, val_dataset, test_dataset).
    '''
    def _tensor_and_normalize():
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=mean_cifar10, std=std_cifar10),
        ]

    transform_labeled = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32, padding=int(32*0.125), padding_mode='reflect'),
        ]
        + _tensor_and_normalize()
    )
    transform_val = transforms.Compose(_tensor_and_normalize())

    # Only the training targets are needed here (download=True fetches the data if missing).
    base_dataset = datasets.CIFAR10(data_dir, train=True, download=True)
    labeled_idx, unlabeled_idx, val_idx = split_labeled_unlabeled(args, base_dataset.targets)

    labeled_dataset = CIFAR10_SSL(
        data_dir, labeled_idx, train=True,
        transform=transform_labeled
    )
    unlabeled_dataset = CIFAR10_SSL(
        data_dir, unlabeled_idx, train=True,
        transform=TransformFlexMatch(mean=mean_cifar10, std=std_cifar10)
    )
    val_dataset = CIFAR10_SSL(
        data_dir, val_idx, train=True, transform=transform_val, download=False
    )
    test_dataset = datasets.CIFAR10(
        data_dir, train=False, transform=transform_val, download=False
    )

    return labeled_dataset, unlabeled_dataset, val_dataset, test_dataset

## WideResNet (MixMatch 와 동일)
 - WideResNet Model 정의

In [11]:
# BasicBlock을 정의
class BasicBlock(nn.Module):
    '''
    Pre-activation residual block: BN -> LeakyReLU -> 3x3 conv -> BN -> LeakyReLU
    -> (optional dropout) -> 3x3 conv, added to a shortcut.

    When in_planes != out_planes, the shortcut is a strided 1x1 conv
    (convShortcut); otherwise it is the identity.
    '''
    
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False):
        super(BasicBlock, self).__init__()
        # BatchNorm momentum 0.001 makes running statistics update slowly.
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001)
        self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001)
        self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # 1x1 projection shortcut only when the channel count changes, else None.
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                               padding=0, bias=False) or None
        self.activate_before_residual = activate_before_residual
        
    def forward(self, x):
        # Channel change + activate_before_residual: x itself is replaced by its
        # activation, so both conv1 and the projection shortcut see the activated input.
        if not self.equalInOut and self.activate_before_residual == True:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        # conv1 reads the activation only when channels match; otherwise it reads x.
        # NOTE(review): when channels differ and activate_before_residual is False,
        # the `out` computed above is discarded (bn1 still runs) — confirm intended.
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        # Residual sum: identity shortcut if channels match, else the 1x1 projection.
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)

    
# Network Block을 정의
class NetworkBlock(nn.Module):
    '''
    Sequential stack of ``int(nb_layers)`` residual blocks.

    Only the first block receives ``in_planes`` and ``stride``; the remaining blocks
    map ``out_planes -> out_planes`` with stride 1.
    '''

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, activate_before_residual=False):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual):
        blocks = [
            block(
                in_planes if idx == 0 else out_planes,
                out_planes,
                stride if idx == 0 else 1,
                dropRate,
                activate_before_residual,
            )
            for idx in range(int(nb_layers))
        ]
        return nn.Sequential(*blocks)

    def forward(self, x):
        return self.layer(x)


# WideResNet 모델 정의
class WideResNet(nn.Module):
    
    '''
    Wide ResNet classifier built from BasicBlock / NetworkBlock.

    Layout: 3x3 stem conv -> three NetworkBlocks (strides 1, 2, 2) -> BN ->
    LeakyReLU -> global average pooling -> linear classifier.
    ``depth`` must satisfy (depth - 4) % 6 == 0; each stage then has
    (depth - 4) // 6 blocks.
    '''
    
    def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert((depth - 4) % 6 == 0)
        blocks_per_stage = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(blocks_per_stage, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
        # 2nd block
        self.block2 = NetworkBlock(blocks_per_stage, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(blocks_per_stage, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001)
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-normal init using fan-out.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # Pool to 1x1 regardless of input resolution. For 32x32 inputs (8x8 feature
        # maps) this equals the old avg_pool2d(out, 8); for other sizes the old code
        # left several spatial cells that view() then silently folded into the batch.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(-1, self.nChannels)
        return self.fc(out)

In [12]:
# WeightEMA로 Parameter를 Update하는 함수를 정의 (EMA=Exponential Moving Average)

class WeightEMA(object):
    
    '''
    Maintain an exponential-moving-average (EMA) copy of a model.

    After every step():
        ema_param = decay * ema_param + (1 - decay) * model_param
    and every buffer (e.g. BatchNorm running statistics) is copied verbatim.
    The EMA copy is kept in eval mode and its parameters do not require grad.
    If ``model`` is wrapped (has ``.module``) but the EMA copy is not, state-dict
    keys are looked up with a ``'module.'`` prefix.
    '''
    
    def __init__(self, model, decay):
        self.ema = copy.deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.ema_has_module = hasattr(self.ema, 'module')
        self.param_keys = [name for name, _ in self.ema.named_parameters()]
        self.buffer_keys = [name for name, _ in self.ema.named_buffers()]
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def step(self, model):
        prefix = 'module.' if (hasattr(model, 'module') and not self.ema_has_module) else ''
        with torch.no_grad():
            model_state = model.state_dict()
            ema_state = self.ema.state_dict()
            for key in self.param_keys:
                source = model_state[prefix + key].detach()
                target = ema_state[key]
                target.copy_(target * self.decay + (1. - self.decay) * source)
            for key in self.buffer_keys:
                ema_state[key].copy_(model_state[prefix + key])

In [13]:
# TopK Accuracy를 구하는 함수를 정의
def accuracy(output, target, topk=(1, )):
    
    '''
    Top-k accuracy (in percent) of ``output`` logits against integer ``target``.

    A sample counts as correct for k if its target is among the k highest-scoring
    classes. k is capped at the number of classes (top-k with k >= #classes is
    100%), so e.g. top-5 works with fewer than 5 classes instead of making
    ``topk`` raise.
    Returns a list with one 0-d NumPy array per entry of ``topk``.
    '''
    
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (not view) because `correct` is a non-contiguous transpose slice;
        # this single path replaces the separate k == 1 / k > 1 branches.
        correct_k = correct[:min(k, maxk)].reshape(-1).float().sum(0)
        acc = correct_k.mul_(100.0 / batch_size)
        res.append(acc.detach().cpu().numpy())
        
    return res

In [14]:
# tqdm config 함수 정의
def get_tqdm_config(total, leave=True, color='white'):
    """Build tqdm keyword arguments for a coloured progress bar written to stdout.

    ``color`` must be one of the palette keys below; an unknown name raises KeyError.
    """
    palette = dict(
        red=Fore.LIGHTRED_EX,
        green=Fore.LIGHTGREEN_EX,
        yellow=Fore.LIGHTYELLOW_EX,
        blue=Fore.LIGHTBLUE_EX,
        magenta=Fore.LIGHTMAGENTA_EX,
        cyan=Fore.LIGHTCYAN_EX,
        white=Fore.LIGHTWHITE_EX,
    )
    bar_format = "{l_bar}%s{bar}%s| [{elapsed}<{remaining}, {rate_fmt}{postfix}]" % (palette[color], Fore.RESET)
    return dict(
        file=sys.stdout,
        total=total,
        desc=" ",
        dynamic_ncols=True,
        bar_format=bar_format,
        leave=leave,
    )

In [15]:
# Warmup을 적용한 Learning rate Scheduler 적용
def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps,
    num_cycles=7.0/16.0, last_epoch=-1
    ):
    """LambdaLR with linear warmup followed by cosine decay.

    The multiplier rises linearly from 0 to 1 over ``num_warmup_steps``, then
    follows ``cos(pi * num_cycles * progress)`` (floored at 0), where progress
    runs from 0 to 1 over the remaining steps. The scheduler is meant to be
    stepped once per optimization step.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def _lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_span
            return max(0.0, math.cos(math.pi * num_cycles * progress))
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

In [16]:
# trainer를 정의
class FlexMatchTrainer():
    
    '''
    Semi-supervised trainer.

    Each step combines a supervised cross-entropy loss on a labeled batch with a
    confidence-masked pseudo-label loss on a (weak, strong) augmented unlabeled batch.

    NOTE(review): despite the class name, pseudo-labels are filtered with a single
    fixed threshold (args.threshold) for every class, i.e. a FixMatch-style loss;
    class-wise curriculum thresholds are not implemented here. TODO confirm intent.
    NOTE(review): args.mu is never read, so the unlabeled batch size equals
    args.batch_size.

    About DataLoader num_workers:
    On Windows 10, multiple worker processes start sequentially, while Linux
    (Ubuntu, CentOS) can start them concurrently. With Windows 10 + PyTorch,
    num_workers=0 is recommended; on Linux + PyTorch, tune num_workers so that
    CPU and GPU utilization are maximized.
    '''
    
    def __init__(self, args):
        
        '''
        Set up everything needed for training:
        1. arguments
        2. output directories (project dir defaults to '/content/FlexMatch';
           an optional ``args.root_dir`` overrides it)
        3. datasets
        4. data loaders
        5. model (and EMA model), optimizer with parameter groups, LR scheduler, loss
        6. TensorBoard writer
        '''
        
        self.args = args
        
        # Directories
        root_dir = getattr(self.args, 'root_dir', '/content/FlexMatch') ### Project directory
        data_dir = os.path.join(root_dir, 'data') ### Data directory
        
        self.experiment_dir = os.path.join(root_dir, 'results') ### Parent folder for trained models
        os.makedirs(self.experiment_dir, exist_ok=True)

        name_exp = "_".join([str(self.args.n_labeled), str(self.args.T)]) ### Per-experiment sub-folder named by hyperparameters
        self.experiment_dir = os.path.join(self.experiment_dir, name_exp)
        os.makedirs(self.experiment_dir, exist_ok=True)
        
        # Datasets (labeled, unlabeled, validation, test)
        print("==> Preparing CIFAR10 dataset")
        labeled_set, unlabeled_set, val_set, test_set = get_cifar10(self.args, data_dir=data_dir)
        
        # Training loaders drop the last partial batch so every step has a full batch.
        self.labeled_loader = DataLoader(
            labeled_set,
            sampler=RandomSampler(labeled_set), ### RandomSampler: same role as DataLoader(shuffle=True)
            batch_size=self.args.batch_size,
            num_workers=0,
            drop_last=True
        )

        self.unlabeled_loader = DataLoader(
            unlabeled_set,
            sampler=RandomSampler(unlabeled_set),
            batch_size=self.args.batch_size,
            num_workers=0,
            drop_last=True
        )

        # Evaluation loaders keep the final partial batch so every sample is scored.
        self.val_loader = DataLoader(
            val_set,
            sampler=SequentialSampler(val_set), ### SequentialSampler: same role as DataLoader(shuffle=False)
            batch_size=self.args.batch_size,
            num_workers=0
        )

        self.test_loader = DataLoader(
            test_set,
            sampler=SequentialSampler(test_set),
            batch_size=self.args.batch_size,
            num_workers=0
        )

        print("==> Preparing WideResNet")
        self.model = WideResNet(self.args.n_classes).to(self.args.cuda)
        
        self.model.zero_grad()
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.cuda)

        # Weight decay only for parameters whose name contains neither 'bias' nor 'bn'.
        no_decay = ['bias', 'bn']
        grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(
                nd in n for nd in no_decay)], 'weight_decay': self.args.weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(
                nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ] 
        self.optimizer = torch.optim.SGD(grouped_parameters, lr=self.args.lr,
                            momentum=0.9, nesterov=self.args.nesterov)
        
        # LR scheduler. Some PyTorch schedulers are stepped every iteration and
        # others once per epoch; this one is stepped after every optimizer step in
        # train(). See https://pytorch.org/docs/stable/optim.html
        self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
                                                    self.args.warmup,
                                                    self.args.total_steps)
        
        if self.args.use_ema:  
            self.ema_model = WeightEMA(self.model, self.args.ema_decay)
        
        self.writer = SummaryWriter(self.experiment_dir)

    def train(self, epoch):
        '''
        Run one "epoch" of ``args.eval_step`` optimization steps.

        Loaders are restarted whenever they are exhausted, because eval_step is
        independent of the loader lengths.
        Returns the mean (total, labeled, unlabeled) loss over the steps.
        '''
        # Running sums: total / labeled / unlabeled loss and fraction of unlabeled
        # samples whose confidence passed the threshold.
        losses_t, losses_x, losses_u, mask_probs = 0.0, 0.0, 0.0, 0.0
        n_steps = 0
        
        self.model.train()
        
        iter_labeled = iter(self.labeled_loader)
        iter_unlabeled = iter(self.unlabeled_loader)

        with tqdm(**get_tqdm_config(total=self.args.eval_step,
                leave=True, color='blue')) as pbar:
            
            for batch_idx in range(self.args.eval_step):
                # Only exhaustion restarts a loader; any other error propagates.
                try:
                    inputs_x, targets_x = next(iter_labeled)
                except StopIteration:
                    iter_labeled = iter(self.labeled_loader)
                    inputs_x, targets_x = next(iter_labeled)
                real_B = inputs_x.size(0)
                
                # Unlabeled batch: (weak view, strong view); labels are ignored.
                try:
                    (inputs_u_w, inputs_u_s), _ = next(iter_unlabeled)
                except StopIteration:
                    iter_unlabeled = iter(self.unlabeled_loader)
                    (inputs_u_w, inputs_u_s), _ = next(iter_unlabeled)
                
                # One forward pass over [labeled | weak | strong].
                inputs = torch.cat((inputs_x, inputs_u_w, inputs_u_s), dim=0).to(self.args.cuda)
                targets_x = targets_x.type(torch.LongTensor)
                targets_x = targets_x.to(self.args.cuda)
                
                logits = self.model(inputs)
                
                # First real_B rows are labeled; the rest split evenly into weak/strong.
                logits_x = logits[:real_B]
                logits_u_w, logits_u_s = logits[real_B:].chunk(2)
                del logits

                loss_x = F.cross_entropy(logits_x, targets_x, reduction='mean')

                # Pseudo-labels: temperature-scaled softmax of the (detached) weak-view
                # logits; the argmax is the label, and only samples whose max
                # probability reaches args.threshold contribute (mask = 1.0 / 0.0).
                pseudo_labels = torch.softmax(logits_u_w.detach()/self.args.T, dim=-1) 
                max_prob, targets_u = torch.max(pseudo_labels, dim=-1)
                mask = max_prob.ge(self.args.threshold).float()

                # Strong-view cross-entropy against the pseudo-labels, zeroed by the mask.
                loss_u = (F.cross_entropy(logits_u_s, targets_u, reduction='none')*mask).mean()

                loss = loss_x + self.args.lambda_u * loss_u
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                if self.args.use_ema:
                    self.ema_model.step(self.model)
                
                self.model.zero_grad()
                
                n_steps += 1
                losses_x += loss_x.item()
                losses_u += loss_u.item()
                losses_t += loss.item()
                # Fraction of unlabeled samples that passed the threshold
                # (previously the mean max-probability was logged here).
                mask_probs += mask.mean().item()
                
                # Global step advances by eval_step per epoch so epochs do not overlap.
                self.writer.add_scalars(
                    'Training steps', {
                        'Total_loss': losses_t/n_steps,
                        'Labeled_loss':losses_x/n_steps,
                        'Unlabeled_loss':losses_u/n_steps,
                        'Mask probs': mask_probs/n_steps
                    }, global_step=epoch*self.args.eval_step+batch_idx
                )

                pbar.set_description(
                    '[Train(%4d/ %4d)-Total: %.3f|Labeled: %.3f|Unlabeled: %.3f]'%(
                        n_steps, self.args.eval_step,
                        losses_t/n_steps, losses_x/n_steps, losses_u/n_steps
                    )
                )
                pbar.update(1)

            denom = max(n_steps, 1)
            pbar.set_description(
                '[Train(%4d/ %4d)-Total: %.3f|Labeled: %.3f|Unlabeled: %.3f]'%(
                    epoch, self.args.epochs,
                    losses_t/denom, losses_x/denom, losses_u/denom
                )
            )
        return losses_t/denom, losses_x/denom, losses_u/denom

    @torch.no_grad()
    def validate(self, epoch, phase):
        '''
        Evaluate on the loader selected by ``phase``.

        phase: 'Train' (labeled loader), 'Valid' or 'Test ' (note the trailing space).
        Uses the EMA model when one exists, otherwise self.model; the chosen model is
        put in eval mode and its previous mode restored afterwards.
        Returns (loss, top-1 accuracy, top-5 accuracy), averaged per sample so a
        smaller final batch is not over-weighted.
        Raises ValueError for an unknown phase.
        '''
        phases = {
            'Train': (self.labeled_loader, 'blue'),
            'Valid': (self.val_loader, 'green'),
            'Test ': (self.test_loader, 'red'),
        }
        if phase not in phases:
            raise ValueError(f"Unknown phase {phase!r}; expected one of {list(phases)}")
        data_loader, c = phases[phase]

        ema = getattr(self, 'ema_model', None)
        eval_model = ema.ema if ema is not None else self.model
        was_training = eval_model.training
        eval_model.eval()

        sum_loss, sum_top1, sum_top5, n_seen = 0.0, 0.0, 0.0, 0
        avg_loss, avg_top1, avg_top5 = 0.0, 0.0, 0.0
        try:
            with tqdm(**get_tqdm_config(total=len(data_loader),
                    leave=True, color=c)) as pbar:
                for batch_idx, (inputs, targets) in enumerate(data_loader):
                    targets = targets.type(torch.LongTensor)
                    inputs, targets = inputs.to(self.args.cuda), targets.to(self.args.cuda)

                    outputs = eval_model(inputs)
                    loss = self.criterion(outputs, targets)
                    prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))

                    batch_n = targets.size(0)
                    n_seen += batch_n
                    sum_loss += loss.item() * batch_n
                    sum_top1 += float(prec1) * batch_n
                    sum_top5 += float(prec5) * batch_n
                    avg_loss, avg_top1, avg_top5 = sum_loss/n_seen, sum_top1/n_seen, sum_top5/n_seen

                    self.writer.add_scalars(
                        f'{phase} steps', {
                            'Total_loss': avg_loss,
                            'Top1 Acc': avg_top1,
                            'Top5 Acc': avg_top5
                        }, global_step=epoch*len(data_loader)+batch_idx
                    )

                    pbar.set_description(
                        '[%s-Loss: %.3f|Top1 Acc: %.3f|Top5 Acc: %.3f]'%(
                            phase,
                            avg_loss, avg_top1, avg_top5
                        )
                    )
                    pbar.update(1)

                pbar.set_description(
                    '[%s(%4d/ %4d)-Loss: %.3f|Top1 Acc: %.3f|Top5 Acc: %.3f]'%(
                        phase,
                        epoch, self.args.epochs,
                        avg_loss, avg_top1, avg_top5
                    )
                )
        finally:
            eval_model.train(was_training)

        return avg_loss, avg_top1, avg_top5

In [17]:
# Argument 정의
def FlexMatch_parser():
    """Build the command-line parser for training.

    ``--nesterov`` and ``--use-ema`` default to True; since ``store_true`` alone can
    never turn them off, matching ``--no-nesterov`` / ``--no-use-ema`` switches are
    provided. ``--root-dir`` sets the project directory.
    """
    parser = argparse.ArgumentParser(description="FlexMatch PyTorch Implementation for LG Electornics education")
    
    # method arguments
    parser.add_argument('--n-labeled', type=int, default=4000) # number of labeled samples
    parser.add_argument('--n-classes', type=int, default=10) # number of classes
    parser.add_argument("--expand-labels", action="store_true", 
                        help="expand labels to fit eval steps")
    parser.add_argument('--root-dir', type=str, default='/content/FlexMatch',
                        help="project directory holding data/ and results/")

    # training hyperparameters
    parser.add_argument('--batch-size', type=int, default=64) # batch size
    parser.add_argument('--total-steps', default=2**20, type=int) # total optimization steps (the scheduler steps per iteration, not per epoch)
    parser.add_argument('--eval-step', type=int, default=1024) # optimization steps between evaluations
    parser.add_argument('--lr', type=float, default=0.03) # learning rate
    parser.add_argument('--weight-decay', type=float, default=5e-4) # weight decay
    parser.add_argument('--nesterov', action='store_true', default=True) # Nesterov momentum
    parser.add_argument('--no-nesterov', dest='nesterov', action='store_false',
                        help="disable Nesterov momentum")
    parser.add_argument('--warmup', type=float, default=0.0) # LR warmup steps

    parser.add_argument('--use-ema', action='store_true', default=True) # keep an EMA model
    parser.add_argument('--no-use-ema', dest='use_ema', action='store_false',
                        help="disable the EMA model")
    parser.add_argument('--ema-decay', type=float, default=0.999) # EMA decay

    parser.add_argument('--mu', type=int, default=7) # unlabeled-to-labeled batch ratio (not read by the trainer)
    parser.add_argument('--T', type=float, default=1.0) # softmax temperature for pseudo-labels

    parser.add_argument('--threshold', type=float, default=0.95) # confidence threshold for pseudo-labels
    parser.add_argument('--lambda-u', type=float, default=1.0) # weight of the unlabeled loss
    return parser

In [18]:
# main function definition

def main():
    """Train FlexMatch end to end, evaluating and checkpointing every epoch.

    Steps:
      1. Parse the arguments (defaults only), pick the device and derive the
         number of epochs as ``ceil(total_steps / eval_step)``.
      2. Build the trainer.
      3. For every epoch: run one training pass (total / labeled /
         unlabeled loss), then evaluate on the train, valid and test splits,
         recording loss and top-1 / top-5 accuracy.
      4. Whenever the validation loss beats the best seen so far, save
         ``model.pth`` / ``ema_model.pth``; after every epoch also save a
         rolling checkpoint (``checkpoint_model.pth`` /
         ``checkpoint_ema_model.pth``).

    Returns:
        dict: per-epoch metric histories (each a list indexed by epoch - 1).
        Returning them makes the collected metrics usable for inspection or
        plotting; they were previously gathered and then discarded.
    """
    parser = FlexMatch_parser()
    args = parser.parse_args([])  # notebook: ignore sys.argv and use the defaults
    # Fall back to CPU when CUDA is unavailable instead of failing later.
    # NOTE(review): whether the Trainer routes everything through args.cuda is
    # not visible here — confirm it does not call .cuda() directly.
    args.cuda = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.epochs = math.ceil(args.total_steps / args.eval_step)

    trainer = FlexMatchTrainer(args)

    best_loss = np.inf
    losses, losses_x, losses_u = [], [], []

    train_losses, train_top1s, train_top5s = [], [], []
    val_losses, val_top1s, val_top5s = [], [], []
    test_losses, test_top1s, test_top5s = [], [], []

    # Record metrics after every epoch of training.
    for epoch in range(1, args.epochs + 1):
        loss, loss_x, loss_u = trainer.train(epoch)
        losses.append(loss)
        losses_x.append(loss_x)
        losses_u.append(loss_u)

        loss, top1, top5 = trainer.validate(epoch, 'Train')
        train_losses.append(loss)
        train_top1s.append(top1)
        train_top5s.append(top5)

        loss, top1, top5 = trainer.validate(epoch, 'Valid')
        val_losses.append(loss)
        val_top1s.append(top1)
        val_top5s.append(top5)

        # Model selection uses the validation loss (never the test split).
        if loss < best_loss:
            best_loss = loss
            torch.save(trainer.model, os.path.join(trainer.experiment_dir, 'model.pth'))
            torch.save(trainer.ema_model, os.path.join(trainer.experiment_dir, 'ema_model.pth'))

        # 'Test ' keeps its trailing space so the progress-bar labels line up
        # with 'Train' / 'Valid'.
        loss, top1, top5 = trainer.validate(epoch, 'Test ')
        test_losses.append(loss)
        test_top1s.append(top1)
        test_top5s.append(top5)

        # Rolling checkpoint, overwritten every epoch. The file name was
        # previously misspelled as 'checkpooint_model.pth'.
        torch.save(trainer.model, os.path.join(trainer.experiment_dir, 'checkpoint_model.pth'))
        torch.save(trainer.ema_model, os.path.join(trainer.experiment_dir, 'checkpoint_ema_model.pth'))

    return {
        'loss': losses, 'loss_x': losses_x, 'loss_u': losses_u,
        'train_loss': train_losses, 'train_top1': train_top1s, 'train_top5': train_top5s,
        'val_loss': val_losses, 'val_top1': val_top1s, 'val_top5': val_top5s,
        'test_loss': test_losses, 'test_top1': test_top1s, 'test_top5': test_top5s,
    }

In [19]:
# Entry point. In a notebook kernel __name__ is "__main__", so running this
# cell starts training; importing the code as a module would not.
if __name__ == "__main__":
    main()

==> Preparing CIFAR10 dataset
Files already downloaded and verified
==> Preparing WideResNet
[Train(   1/ 1024)-Total: 1.262|Labeled: 1.216|Unlabeled: 0.046]: 100%|[94m██████████████████████[39m| [03:35<00:00,  4.75it/s][0m
[Train(   1/ 1024)-Loss: 1.703|Top1 Acc: 39.919|Top5 Acc: 85.912]: 100%|[94m█████████████████████[39m| [00:02<00:00, 23.07it/s][0m
[Valid(   1/ 1024)-Loss: 1.708|Top1 Acc: 39.323|Top5 Acc: 86.218]: 100%|[92m█████████████████████[39m| [00:02<00:00, 34.75it/s][0m
[Test (   1/ 1024)-Loss: 1.712|Top1 Acc: 39.023|Top5 Acc: 86.186]: 100%|[91m█████████████████████[39m| [00:04<00:00, 35.32it/s][0m
[Train(   2/ 1024)-Total: 0.838|Labeled: 0.689|Unlabeled: 0.148]: 100%|[94m██████████████████████[39m| [03:31<00:00,  4.83it/s][0m
[Train(   2/ 1024)-Loss: 0.877|Top1 Acc: 70.363|Top5 Acc: 97.833]: 100%|[94m█████████████████████[39m| [00:02<00:00, 24.09it/s][0m
[Valid(   2/ 1024)-Loss: 0.970|Top1 Acc: 66.346|Top5 Acc: 96.615]: 100%|[92m█████████████████████[39m

[Train(  31/ 1024)-Loss: 0.001|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 23.43it/s][0m
[Valid(  31/ 1024)-Loss: 0.474|Top1 Acc: 88.822|Top5 Acc: 99.559]: 100%|[92m█████████████████████[39m| [00:02<00:00, 35.38it/s][0m
[Test (  31/ 1024)-Loss: 0.501|Top1 Acc: 87.998|Top5 Acc: 99.482]: 100%|[91m█████████████████████[39m| [00:04<00:00, 35.61it/s][0m
[Train(  32/ 1024)-Total: 0.381|Labeled: 0.062|Unlabeled: 0.319]: 100%|[94m██████████████████████[39m| [03:32<00:00,  4.83it/s][0m
[Train(  32/ 1024)-Loss: 0.001|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 23.52it/s][0m
[Valid(  32/ 1024)-Loss: 0.473|Top1 Acc: 89.183|Top5 Acc: 99.559]: 100%|[92m█████████████████████[39m| [00:02<00:00, 35.26it/s][0m
[Test (  32/ 1024)-Loss: 0.494|Top1 Acc: 88.376|Top5 Acc: 99.473]: 100%|[91m█████████████████████[39m| [00:04<00:00, 35.56it/s][0m
[Train(  33/ 1024)-Total: 0.377|Labeled: 0.057|Unlabeled: 0.32

[Test (  61/ 1024)-Loss: 0.444|Top1 Acc: 89.640|Top5 Acc: 99.612]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.29it/s][0m
[Train(  62/ 1024)-Total: 0.365|Labeled: 0.048|Unlabeled: 0.317]: 100%|[94m██████████████████████[39m| [03:22<00:00,  5.05it/s][0m
[Train(  62/ 1024)-Loss: 0.001|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.85it/s][0m
[Valid(  62/ 1024)-Loss: 0.409|Top1 Acc: 90.325|Top5 Acc: 99.700]: 100%|[92m█████████████████████[39m| [00:02<00:00, 38.12it/s][0m
[Test (  62/ 1024)-Loss: 0.438|Top1 Acc: 89.729|Top5 Acc: 99.572]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.26it/s][0m
[Train(  63/ 1024)-Total: 0.358|Labeled: 0.049|Unlabeled: 0.309]: 100%|[94m██████████████████████[39m| [03:22<00:00,  5.05it/s][0m
[Train(  63/ 1024)-Loss: 0.000|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.83it/s][0m
[Valid(  63/ 1024)-Loss: 0.411|Top1 Acc: 90.325|Top5 Acc: 99.7

[Train(  92/ 1024)-Loss: 0.000|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.94it/s][0m
[Valid(  92/ 1024)-Loss: 0.389|Top1 Acc: 90.705|Top5 Acc: 99.740]: 100%|[92m█████████████████████[39m| [00:02<00:00, 38.05it/s][0m
[Test (  92/ 1024)-Loss: 0.418|Top1 Acc: 90.197|Top5 Acc: 99.642]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.24it/s][0m
[Train(  93/ 1024)-Total: 0.350|Labeled: 0.043|Unlabeled: 0.307]: 100%|[94m██████████████████████[39m| [03:22<00:00,  5.05it/s][0m
[Train(  93/ 1024)-Loss: 0.001|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.96it/s][0m
[Valid(  93/ 1024)-Loss: 0.385|Top1 Acc: 90.805|Top5 Acc: 99.820]: 100%|[92m█████████████████████[39m| [00:02<00:00, 38.05it/s][0m
[Test (  93/ 1024)-Loss: 0.418|Top1 Acc: 90.098|Top5 Acc: 99.612]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.24it/s][0m
[Train(  94/ 1024)-Total: 0.349|Labeled: 0.042|Unlabeled: 0.30

[Test ( 122/ 1024)-Loss: 0.392|Top1 Acc: 90.764|Top5 Acc: 99.691]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.23it/s][0m
[Train( 123/ 1024)-Total: 0.347|Labeled: 0.042|Unlabeled: 0.305]: 100%|[94m██████████████████████[39m| [03:22<00:00,  5.04it/s][0m
[Train( 123/ 1024)-Loss: 0.000|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.95it/s][0m
[Valid( 123/ 1024)-Loss: 0.374|Top1 Acc: 90.865|Top5 Acc: 99.780]: 100%|[92m█████████████████████[39m| [00:02<00:00, 38.06it/s][0m
[Test ( 123/ 1024)-Loss: 0.399|Top1 Acc: 90.516|Top5 Acc: 99.622]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.20it/s][0m
[Train( 124/ 1024)-Total: 0.344|Labeled: 0.040|Unlabeled: 0.304]: 100%|[94m██████████████████████[39m| [03:22<00:00,  5.06it/s][0m
[Train( 124/ 1024)-Loss: 0.000|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.87it/s][0m
[Valid( 124/ 1024)-Loss: 0.371|Top1 Acc: 90.946|Top5 Acc: 99.7

[Train( 153/ 1024)-Loss: 0.001|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.93it/s][0m
[Valid( 153/ 1024)-Loss: 0.352|Top1 Acc: 91.226|Top5 Acc: 99.780]: 100%|[92m█████████████████████[39m| [00:02<00:00, 38.19it/s][0m
[Test ( 153/ 1024)-Loss: 0.386|Top1 Acc: 90.864|Top5 Acc: 99.652]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.32it/s][0m
[Train( 154/ 1024)-Total: 0.340|Labeled: 0.039|Unlabeled: 0.301]: 100%|[94m██████████████████████[39m| [03:22<00:00,  5.05it/s][0m
[Train( 154/ 1024)-Loss: 0.001|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.33it/s][0m
[Valid( 154/ 1024)-Loss: 0.350|Top1 Acc: 91.366|Top5 Acc: 99.780]: 100%|[92m█████████████████████[39m| [00:02<00:00, 38.19it/s][0m
[Test ( 154/ 1024)-Loss: 0.388|Top1 Acc: 90.844|Top5 Acc: 99.642]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.25it/s][0m
[Train( 155/ 1024)-Total: 0.336|Labeled: 0.039|Unlabeled: 0.29

[Test ( 183/ 1024)-Loss: 0.385|Top1 Acc: 91.083|Top5 Acc: 99.642]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.27it/s][0m
[Train( 184/ 1024)-Total: 0.340|Labeled: 0.037|Unlabeled: 0.303]: 100%|[94m██████████████████████[39m| [03:22<00:00,  5.06it/s][0m
[Train( 184/ 1024)-Loss: 0.000|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.81it/s][0m
[Valid( 184/ 1024)-Loss: 0.344|Top1 Acc: 91.246|Top5 Acc: 99.800]: 100%|[92m█████████████████████[39m| [00:02<00:00, 38.14it/s][0m
[Test ( 184/ 1024)-Loss: 0.391|Top1 Acc: 91.003|Top5 Acc: 99.652]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.23it/s][0m
[Train( 185/ 1024)-Total: 0.344|Labeled: 0.039|Unlabeled: 0.304]: 100%|[94m██████████████████████[39m| [03:22<00:00,  5.05it/s][0m
[Train( 185/ 1024)-Loss: 0.001|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.89it/s][0m
[Valid( 185/ 1024)-Loss: 0.349|Top1 Acc: 91.186|Top5 Acc: 99.7

[Train( 214/ 1024)-Loss: 0.000|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.79it/s][0m
[Valid( 214/ 1024)-Loss: 0.343|Top1 Acc: 91.767|Top5 Acc: 99.780]: 100%|[92m█████████████████████[39m| [00:02<00:00, 38.06it/s][0m
[Test ( 214/ 1024)-Loss: 0.375|Top1 Acc: 91.212|Top5 Acc: 99.672]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.30it/s][0m
[Train( 215/ 1024)-Total: 0.333|Labeled: 0.037|Unlabeled: 0.295]: 100%|[94m██████████████████████[39m| [03:22<00:00,  5.06it/s][0m
[Train( 215/ 1024)-Loss: 0.000|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.99it/s][0m
[Valid( 215/ 1024)-Loss: 0.337|Top1 Acc: 91.567|Top5 Acc: 99.800]: 100%|[92m█████████████████████[39m| [00:02<00:00, 38.08it/s][0m
[Test ( 215/ 1024)-Loss: 0.373|Top1 Acc: 90.943|Top5 Acc: 99.662]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.30it/s][0m
[Train( 216/ 1024)-Total: 0.330|Labeled: 0.038|Unlabeled: 0.29

[Test ( 244/ 1024)-Loss: 0.371|Top1 Acc: 91.302|Top5 Acc: 99.592]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.23it/s][0m
[Train( 245/ 1024)-Total: 0.333|Labeled: 0.038|Unlabeled: 0.295]: 100%|[94m██████████████████████[39m| [03:22<00:00,  5.06it/s][0m
[Train( 245/ 1024)-Loss: 0.000|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.93it/s][0m
[Valid( 245/ 1024)-Loss: 0.331|Top1 Acc: 91.647|Top5 Acc: 99.780]: 100%|[92m█████████████████████[39m| [00:02<00:00, 38.08it/s][0m
[Test ( 245/ 1024)-Loss: 0.371|Top1 Acc: 91.182|Top5 Acc: 99.652]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.32it/s][0m
[Train( 246/ 1024)-Total: 0.332|Labeled: 0.035|Unlabeled: 0.297]: 100%|[94m██████████████████████[39m| [03:22<00:00,  5.06it/s][0m
[Train( 246/ 1024)-Loss: 0.000|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.91it/s][0m
[Valid( 246/ 1024)-Loss: 0.327|Top1 Acc: 91.847|Top5 Acc: 99.8

[Train( 275/ 1024)-Loss: 0.000|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.86it/s][0m
[Valid( 275/ 1024)-Loss: 0.342|Top1 Acc: 91.667|Top5 Acc: 99.840]: 100%|[92m█████████████████████[39m| [00:02<00:00, 38.05it/s][0m
[Test ( 275/ 1024)-Loss: 0.369|Top1 Acc: 91.093|Top5 Acc: 99.652]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.23it/s][0m
[Train( 276/ 1024)-Total: 0.328|Labeled: 0.034|Unlabeled: 0.294]: 100%|[94m██████████████████████[39m| [03:25<00:00,  4.99it/s][0m
[Train( 276/ 1024)-Loss: 0.000|Top1 Acc: 100.000|Top5 Acc: 100.000]: 100%|[94m███████████████████[39m| [00:02<00:00, 24.91it/s][0m
[Valid( 276/ 1024)-Loss: 0.338|Top1 Acc: 91.567|Top5 Acc: 99.780]: 100%|[92m█████████████████████[39m| [00:02<00:00, 38.05it/s][0m
[Test ( 276/ 1024)-Loss: 0.368|Top1 Acc: 91.152|Top5 Acc: 99.632]: 100%|[91m█████████████████████[39m| [00:04<00:00, 38.17it/s][0m
[Train( 277/ 1024)-Total: 0.328|Labeled: 0.034|Unlabeled: 0.29


KeyboardInterrupt

