# FixMatch PyTorch Implementation

## 필요한 패키지 다운로드

In [1]:
!pip install colorama

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
You should consider upgrading via the '/usr/bin/python -m pip install --upgrade pip' command.[0m


## 필요한 패키지 중 이미 설치된 패키지 불러오기

In [19]:
import sys, os, copy, random, argparse, math
import numpy as np

import PIL
import PIL.ImageOps
import PIL.ImageEnhance
import PIL.ImageDraw
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter

from torchvision import datasets
from torchvision import transforms
from colorama import Fore
from tqdm import tqdm


## Global variable 정의하기
#### PARAMETER_MAX, cifar10의 mean, std

In [20]:
# Divisor used by _float_parameter / _int_parameter when rescaling a magnitude `v`
PARAMETER_MAX = 10

# Per-channel (3-tuple) mean and std handed to transforms.Normalize for CIFAR-10 inputs
mean_cifar10 = (0.4914, 0.4822, 0.4465)
std_cifar10 = (0.2471, 0.2345, 0.2616)

In [21]:
def _float_parameter(v, max_v):
    """Return ``float(v) * max_v / PARAMETER_MAX`` (magnitude rescaled as a float)."""
    scaled = float(v) * max_v
    return scaled / PARAMETER_MAX


def _int_parameter(v, max_v):
    """Return ``v * max_v / PARAMETER_MAX`` truncated to an int."""
    ratio = v * max_v / PARAMETER_MAX
    return int(ratio)

## PIL 패키지 내 각종 Data Augmentation 함수 정의
#### https://pillow.readthedocs.io/en/stable/reference/ImageOps.html
#### https://pillow.readthedocs.io/en/stable/reference/ImageEnhance.html


In [22]:
def AutoContrast(img, **kwargs):
    """Apply ``PIL.ImageOps.autocontrast`` to ``img``; extra keyword arguments are ignored."""
    result = PIL.ImageOps.autocontrast(img)
    return result


def Brightness(img, v, max_v, bias=0):
    """Adjust brightness with an enhancement factor of ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)


def Color(img, v, max_v, bias=0):
    """Adjust color balance with an enhancement factor of ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Color(img)
    return enhancer.enhance(factor)


def Contrast(img, v, max_v, bias=0):
    """Adjust contrast with an enhancement factor of ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)


def CutoutAbs(img, v, **kwargs):
    """Return a copy of ``img`` with a gray rectangle drawn at a random position.

    A point is sampled uniformly inside the image, shifted up/left by ``v / 2``
    (clamped at 0) to get the top-left corner, and the rectangle extends ``v``
    pixels from there (clamped to the image size). Extra keyword arguments are
    ignored. The input image is not modified.
    """
    width, height = img.size
    # Sample x before y so the random stream is consumed in a fixed order.
    sampled_x = np.random.uniform(0, width)
    sampled_y = np.random.uniform(0, height)

    left = int(max(0, sampled_x - v / 2.))
    top = int(max(0, sampled_y - v / 2.))
    right = int(min(width, left + v))
    bottom = int(min(height, top + v))

    gray = (127, 127, 127)
    patched = img.copy()
    PIL.ImageDraw.Draw(patched).rectangle((left, top, right, bottom), gray)
    return patched


def Cutout(img, v, max_v, bias=0):
    """Cut out a square whose side is a fraction of the shorter image side.

    The fraction is ``_float_parameter(v, max_v) + bias``. A magnitude of 0
    returns ``img`` unchanged.
    """
    if v != 0:
        fraction = _float_parameter(v, max_v) + bias
        side = int(fraction * min(img.size))
        return CutoutAbs(img, side)
    return img


def Equalize(img, **kwargs):
    """Apply ``PIL.ImageOps.equalize`` to ``img``; extra keyword arguments are ignored."""
    equalized = PIL.ImageOps.equalize(img)
    return equalized


def Identity(img, **kwargs):
    """No-op augmentation: return ``img`` itself and ignore any keyword arguments."""
    unchanged = img
    return unchanged


def Invert(img, **kwargs):
    """Apply ``PIL.ImageOps.invert`` to ``img``; extra keyword arguments are ignored."""
    inverted = PIL.ImageOps.invert(img)
    return inverted

def Posterize(img, v, max_v, bias=0):
    """Posterize ``img``, keeping ``_int_parameter(v, max_v) + bias`` bits per channel."""
    bits = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.posterize(img, bits)


def Rotate(img, v, max_v, bias=0):
    """Rotate ``img`` by ``_int_parameter(v, max_v) + bias``; the sign is flipped with probability 0.5."""
    angle = _int_parameter(v, max_v) + bias
    angle = -angle if random.random() < 0.5 else angle
    return img.rotate(angle)


def Sharpness(img, v, max_v, bias=0):
    """Adjust sharpness with an enhancement factor of ``_float_parameter(v, max_v) + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)


def ShearX(img, v, max_v, bias=0):
    """Shear ``img`` along x via an affine transform; the shear sign is flipped with probability 0.5."""
    shear = _float_parameter(v, max_v) + bias
    shear = -shear if random.random() < 0.5 else shear
    affine_coeffs = (1, shear, 0, 0, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, affine_coeffs)


def ShearY(img, v, max_v, bias=0):
    """Shear ``img`` along y via an affine transform; the shear sign is flipped with probability 0.5."""
    shear = _float_parameter(v, max_v) + bias
    shear = -shear if random.random() < 0.5 else shear
    affine_coeffs = (1, 0, 0, shear, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, affine_coeffs)


def Solarize(img, v, max_v, bias=0):
    """Solarize ``img`` with threshold ``256 - (_int_parameter(v, max_v) + bias)``."""
    strength = _int_parameter(v, max_v) + bias
    threshold = 256 - strength
    return PIL.ImageOps.solarize(img, threshold)


def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
    """Add a random-signed offset to every pixel, clip to [0, 255], then solarize.

    Args:
        img: PIL image to augment.
        v: magnitude, rescaled by ``_int_parameter(v, max_v)``.
        max_v: scale passed to ``_int_parameter``.
        bias: constant added to the rescaled magnitude.
        threshold: threshold forwarded to ``PIL.ImageOps.solarize``.

    Returns:
        A new PIL image.
    """
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    # `np.int` was only an alias of the builtin `int` and was removed in
    # NumPy 1.24; the builtin keeps the same dtype and works on every version.
    # The widening cast is needed so the addition cannot wrap around in uint8.
    img_np = np.array(img).astype(int)
    img_np = img_np + v
    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)


def TranslateX(img, v, max_v, bias=0):
    """Translate ``img`` horizontally by a random-signed fraction of its width."""
    fraction = _float_parameter(v, max_v) + bias
    fraction = -fraction if random.random() < 0.5 else fraction
    shift = int(fraction * img.size[0])
    affine_coeffs = (1, 0, shift, 0, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, affine_coeffs)


def TranslateY(img, v, max_v, bias=0):
    """Translate ``img`` vertically by a random-signed fraction of its height."""
    fraction = _float_parameter(v, max_v) + bias
    fraction = -fraction if random.random() < 0.5 else fraction
    shift = int(fraction * img.size[1])
    affine_coeffs = (1, 0, 0, 0, 1, shift)
    return img.transform(img.size, PIL.Image.AFFINE, affine_coeffs)

## FixMatch 내 RandAugment를 사용하기 위한 PIL 기반 각종 Augmentation 정의

In [23]:
def fixmatch_augment_pool():
    """Return the augmentation pool as a list of ``(op, max_v, bias)`` tuples.

    ``max_v`` and ``bias`` are forwarded to ``op`` as keyword arguments by the
    caller; ops that take no magnitude carry ``None`` for both.
    """
    # FixMatch paper
    return [
        (AutoContrast, None, None),
        (Brightness, 0.9, 0.05),
        (Color, 0.9, 0.05),
        (Contrast, 0.9, 0.05),
        (Equalize, None, None),
        (Identity, None, None),
        (Posterize, 4, 4),
        (Rotate, 30, 0),
        (Sharpness, 0.9, 0.05),
        (ShearX, 0.3, 0),
        (ShearY, 0.3, 0),
        (Solarize, 256, 0),
        (TranslateX, 0.3, 0),
        (TranslateY, 0.3, 0),
    ]

## RandAugment를 위한 class 정의

In [24]:
class RandAugmentMC(object):
    """Strong augmentation: random ops from the FixMatch pool followed by Cutout.

    Args:
        n: number of ops drawn (with replacement) from the pool per call; >= 1.
        m: magnitude bound, 1 <= m <= 10. Each op's magnitude is drawn from
           ``np.random.randint(1, m)``, i.e. from 1 to m - 1 inclusive.
    """
    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10

        self.n = n
        self.m = m
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        """Apply the randomly selected ops to ``img`` and return the augmented image."""
        # Draw n ops from the pool (with replacement) and apply them in order.
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            # np.random.randint(1, 1) raises ValueError (empty range), although
            # m == 1 passes the constructor assertion; use magnitude 1 in that
            # case. Sampling for m >= 2 is unchanged.
            v = np.random.randint(1, self.m) if self.m > 1 else 1
            # Each selected op is applied with probability 0.5.
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)

        # Always finish with a Cutout of int(32*0.5) = 16 pixels.
        img = CutoutAbs(img, int(32*0.5))
        return img

In [25]:
class CIFAR10_SSL(datasets.CIFAR10):
    """``datasets.CIFAR10`` restricted to the samples selected by ``indexs``.

    When ``indexs`` is None the full dataset is kept. Otherwise ``data`` and
    ``targets`` are replaced by the indexed subset (``targets`` becomes a
    NumPy array).
    """
    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

        if indexs is None:
            return
        self.data = self.data[indexs]
        self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index`` after applying the configured transforms."""
        raw_image = self.data[index]
        target = self.targets[index]
        img = Image.fromarray(raw_image)

        transform = self.transform
        if transform is not None:
            img = transform(img)

        target_transform = self.target_transform
        if target_transform is not None:
            target = target_transform(target)

        return img, target

In [26]:
# FixMatch counterpart of MixMatch's "transform twice": one weak and one strong view per image
class TransformFixMatch(object):
    """Callable returning ``(weak, strong)`` normalized tensor views of one image.

    Args:
        mean: per-channel mean passed to ``transforms.Normalize``.
        std: per-channel std passed to ``transforms.Normalize``.
    """
    def __init__(self, mean=mean_cifar10, std=std_cifar10):
        def flip_and_crop():
            # Fresh transform objects on every call, so the two pipelines share nothing.
            return [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(size=32,
                                      padding=int(32*0.125),
                                      padding_mode='reflect'),
            ]

        self.weak_transform = transforms.Compose(flip_and_crop())

        # Strong view = the weak augmentation followed by RandAugmentMC.
        self.strong_transform = transforms.Compose(
            flip_and_crop() + [RandAugmentMC(n=2, m=10)]
        )

        # ToTensor followed by normalization with the given mean / std.
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, x):
        """Return the normalized weak view and the normalized strong view of ``x``."""
        weak_view = self.weak_transform(x)
        strong_view = self.strong_transform(x)
        return self.normalize(weak_view), self.normalize(strong_view)

In [27]:
# Same as the split function in the MixMatch code
def split_labeled_unlabeled(args, labels):
    """Split sample indices into labeled / unlabeled / validation subsets per class.

    For every class ``0 .. args.n_classes - 1`` the first
    ``args.n_labeled // args.n_classes`` occurrences become labeled, the last
    500 become validation, and everything in between becomes unlabeled. Each
    of the three index lists is then shuffled with ``np.random.shuffle``.

    Args:
        args: object exposing integer attributes ``n_labeled`` and ``n_classes``.
        labels: sequence of integer class labels, one per sample.

    Returns:
        Tuple ``(labeled, unlabeled, val)`` of int64 NumPy index arrays.
    """
    label_per_class = args.n_labeled // args.n_classes
    labels = np.array(labels, dtype=int)
    indice_labeled, indice_unlabeled, indice_val = [], [], []

    # Iterate over the configured number of classes rather than a hard-coded 10,
    # so that classes with an id >= 10 are not silently dropped.
    for i in range(args.n_classes):
        indice_tmp = np.where(labels == i)[0]

        indice_labeled.extend(indice_tmp[: label_per_class])
        indice_unlabeled.extend(indice_tmp[label_per_class: -500])
        indice_val.extend(indice_tmp[-500:])

    for indices in (indice_labeled, indice_unlabeled, indice_val):
        np.random.shuffle(indices)

    # An explicit integer dtype keeps an empty split usable as an index array
    # (np.array([]) would otherwise default to float64).
    return (np.array(indice_labeled, dtype=np.int64),
            np.array(indice_unlabeled, dtype=np.int64),
            np.array(indice_val, dtype=np.int64))

In [28]:
def get_cifar10(args, data_dir):
    """Build the labeled, unlabeled, validation and test CIFAR-10 datasets.

    Downloads the CIFAR-10 training set into ``data_dir`` if needed, splits its
    indices with ``split_labeled_unlabeled`` and wraps each subset in the
    matching transform.

    Returns:
        ``(labeled_dataset, unlabeled_dataset, val_dataset, test_dataset)``
    """
    def tensor_and_normalize():
        # Fresh transform objects for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=mean_cifar10, std=std_cifar10),
        ]

    transform_labeled = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32, padding=int(32*0.125), padding_mode='reflect'),
    ] + tensor_and_normalize())
    transform_val = transforms.Compose(tensor_and_normalize())

    # The download happens here; the datasets below reuse the files on disk.
    base_dataset = datasets.CIFAR10(data_dir, train=True, download=True)
    labeled_idx, unlabeled_idx, val_idx = split_labeled_unlabeled(args, base_dataset.targets)

    # Labeled samples: weak augmentation (flip + crop) only.
    labeled_dataset = CIFAR10_SSL(
        data_dir, labeled_idx, train=True,
        transform=transform_labeled,
    )

    # Unlabeled samples: TransformFixMatch yields a weak and a strong view.
    unlabeled_dataset = CIFAR10_SSL(
        data_dir, unlabeled_idx, train=True,
        transform=TransformFixMatch(mean=mean_cifar10, std=std_cifar10),
    )

    # Validation and test samples: ToTensor + Normalize only.
    val_dataset = CIFAR10_SSL(
        data_dir, val_idx, train=True, transform=transform_val, download=False,
    )
    test_dataset = datasets.CIFAR10(
        data_dir, train=False, transform=transform_val, download=False,
    )

    return labeled_dataset, unlabeled_dataset, val_dataset, test_dataset

## WideResNet (MixMatch 와 동일)

In [29]:
class BasicBlock(nn.Module):
    """Pre-activation residual block: BN -> LeakyReLU -> 3x3 conv, twice, plus a shortcut.

    When ``in_planes != out_planes`` the shortcut is a 1x1 convolution with the
    given stride; otherwise it is the identity and ``convShortcut`` is None.
    """
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001)
        self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001)
        self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                          stride=stride, padding=0, bias=False)
        self.activate_before_residual = activate_before_residual

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``.

        ``bn1`` / ``relu1`` always run. Their output feeds the first conv when
        the channel counts match. With a channel change it replaces ``x`` (for
        both the conv path and the shortcut) only if
        ``activate_before_residual`` is set; otherwise the raw ``x`` is used.
        """
        activated = self.relu1(self.bn1(x))
        if self.equalInOut:
            conv_input = activated
        elif self.activate_before_residual == True:
            x = activated
            conv_input = x
        else:
            conv_input = x

        out = self.relu2(self.bn2(self.conv1(conv_input)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)

        shortcut = x if self.equalInOut else self.convShortcut(x)
        return torch.add(shortcut, out)

class NetworkBlock(nn.Module):
    """A sequential stack of ``nb_layers`` blocks built from the ``block`` factory.

    The first block receives ``in_planes`` and ``stride``; every later block
    receives ``out_planes`` as its input width and a stride of 1.
    """
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, activate_before_residual=False):
        super().__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers,
                                      stride, dropRate, activate_before_residual)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual):
        """Instantiate ``int(nb_layers)`` blocks and wrap them in ``nn.Sequential``."""
        stacked = []
        for idx in range(int(nb_layers)):
            is_first = (idx == 0)
            # `or` keeps the original and/or fallback for falsy in_planes / stride.
            block_in = (in_planes or out_planes) if is_first else out_planes
            block_stride = (stride or 1) if is_first else 1
            stacked.append(block(block_in, out_planes, block_stride, dropRate, activate_before_residual))
        return nn.Sequential(*stacked)

    def forward(self, x):
        """Run ``x`` through the stacked blocks."""
        return self.layer(x)

class WideResNet(nn.Module):
    """Wide ResNet classifier: a stem conv, three ``NetworkBlock`` stages, BN, pooling, linear head.

    Args:
        num_classes: output size of the final linear layer.
        depth: total depth; ``(depth - 4)`` must be divisible by 6.
        widen_factor: multiplier applied to the stage widths 16 / 32 / 64.
        dropRate: dropout rate forwarded to every block.
    """
    def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
        super().__init__()
        widths = [16] + [base * widen_factor for base in (16, 32, 64)]
        assert((depth - 4) % 6 == 0)
        # Passed on as a float; NetworkBlock converts it with int().
        blocks_per_stage = (depth - 4) / 6

        # Stem convolution that precedes the residual stages.
        self.conv1 = nn.Conv2d(3, widths[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # Stage 1 (stride 1), stage 2 (stride 2), stage 3 (stride 2).
        self.block1 = NetworkBlock(blocks_per_stage, widths[0], widths[1], BasicBlock, 1,
                                   dropRate, activate_before_residual=True)
        self.block2 = NetworkBlock(blocks_per_stage, widths[1], widths[2], BasicBlock, 2, dropRate)
        self.block3 = NetworkBlock(blocks_per_stage, widths[2], widths[3], BasicBlock, 2, dropRate)
        # Final BN / activation, then the classifier applied after average pooling.
        self.bn1 = nn.BatchNorm2d(widths[3], momentum=0.001)
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.fc = nn.Linear(widths[3], num_classes)
        self.nChannels = widths[3]

        # Re-initialise every submodule's parameters in registration order.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight.data)
                module.bias.data.zero_()

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        features = self.block3(self.block2(self.block1(self.conv1(x))))
        features = self.relu(self.bn1(features))
        # Fixed 8x8 average pooling, then flatten to (-1, nChannels).
        pooled = F.avg_pool2d(features, 8).view(-1, self.nChannels)
        return self.fc(pooled)

In [30]:
class WeightEMA(object):  # EMA = Exponential Moving Average
    """Keeps an exponential-moving-average copy of a model's parameters.

    ``self.ema`` is a deep copy of ``model`` in eval mode with gradients
    disabled. ``step`` blends the live parameters into it and copies the
    buffers over unchanged.
    """
    def __init__(self, model, decay):
        shadow = copy.deepcopy(model)
        shadow.eval()
        self.ema = shadow

        self.decay = decay
        self.ema_has_module = hasattr(shadow, 'module')

        self.param_keys = [name for name, _ in shadow.named_parameters()]
        self.buffer_keys = [name for name, _ in shadow.named_buffers()]
        for param in shadow.parameters():
            param.requires_grad_(False)

    def step(self, model):
        """Update the EMA copy: ``ema = decay * ema + (1 - decay) * param``; buffers are copied."""
        # Only the live model is wrapped (it has `.module`, the EMA copy does
        # not): its state-dict keys then carry a 'module.' prefix.
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        prefix = 'module.' if needs_module else ''

        with torch.no_grad():
            live_state = model.state_dict()
            ema_state = self.ema.state_dict()

            for key in self.param_keys:
                live_value = live_state[prefix + key].detach()
                ema_value = ema_state[key]
                ema_state[key].copy_(ema_value * self.decay + (1. - self.decay) * live_value)

            for key in self.buffer_keys:
                ema_state[key].copy_(live_state[prefix + key])

In [31]:
def accuracy(output, target, topk=(1, )):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: score tensor indexed as (sample, class).
        target: tensor of true class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        List with one 0-d NumPy array per entry of ``topk``, in the same order.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row i holds every sample's (i+1)-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # A single path for every k. reshape() also handles the non-contiguous
        # slice that view() cannot, and correct_k is always freshly computed,
        # so a k < 1 can no longer reuse (and rescale) a previous iteration's value.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        acc = correct_k * (100.0 / batch_size)
        res.append(acc.detach().cpu().numpy())
    return res

In [32]:
def get_tqdm_config(total, leave=True, color='white'):
    """Return keyword arguments for ``tqdm`` with a colored progress bar.

    Args:
        total: value stored under the 'total' key.
        leave: value stored under the 'leave' key.
        color: one of 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan',
            'white'; any other name raises KeyError.
    """
    palette = dict(
        red=Fore.LIGHTRED_EX,
        green=Fore.LIGHTGREEN_EX,
        yellow=Fore.LIGHTYELLOW_EX,
        blue=Fore.LIGHTBLUE_EX,
        magenta=Fore.LIGHTMAGENTA_EX,
        cyan=Fore.LIGHTCYAN_EX,
        white=Fore.LIGHTWHITE_EX,
    )
    config = {
        'file': sys.stdout,
        'total': total,
        'desc': " ",
        'dynamic_ncols': True,
    }
    # The bar itself is wrapped in the chosen color code and a reset code.
    config['bar_format'] = (
        "{l_bar}%s{bar}%s| [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
        % (palette[color], Fore.RESET)
    )
    config['leave'] = leave
    return config

In [33]:
def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps,
    num_cycles=7.0/16.0, last_epoch=-1
    ):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The learning-rate multiplier rises linearly from 0 to 1 over
    ``num_warmup_steps`` and then follows
    ``max(0, cos(pi * num_cycles * progress))``, where ``progress`` is the
    fraction of the remaining ``num_training_steps`` completed so far.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def _lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_span
            return max(0.0, math.cos(math.pi * num_cycles * progress))
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

<img src='image/img7.png' width='800'>

In [34]:
class FixMatchTrainer():
    """Sets up data, model, optimizer, scheduler and logging for FixMatch, and runs training epochs.

    Args:
        args: namespace providing n_labeled, n_classes, T, batch_size, cuda,
            weight_decay, lr, nesterov, warmup, total_steps, use_ema,
            ema_decay, eval_step, epochs, threshold and lambda_u.

    Attributes set in ``__init__``: ``experiment_dir``, ``labeled_loader``,
    ``unlabeled_loader``, ``model``, ``criterion``, ``optimizer``,
    ``scheduler``, ``writer`` and, only when ``args.use_ema`` is true,
    ``ema_model``.
    """
    def __init__(self, args):
        self.args = args

        root_dir = './FixMatch'
        data_dir = os.path.join(root_dir, 'data')

        # Directory where trained models and TensorBoard logs are written.
        self.experiment_dir = os.path.join(root_dir, 'results')
        os.makedirs(self.experiment_dir, exist_ok=True)

        # One sub-directory per main hyper-parameter setting: "<n_labeled>_<T>".
        name_exp = "_".join([str(self.args.n_labeled), str(self.args.T)])
        self.experiment_dir = os.path.join(self.experiment_dir, name_exp)
        os.makedirs(self.experiment_dir, exist_ok=True)

        print("==> Preparing CIFAR10 dataset")
        labeled_set, unlabeled_set, _, _ = get_cifar10(self.args, data_dir=data_dir)

        # RandomSampler plays the same role as DataLoader(shuffle=True);
        # SequentialSampler the same role as DataLoader(shuffle=False).
        #
        # Author's note on num_workers: on Windows 10 worker processes start
        # one after another, while Linux can start them concurrently.
        # num_workers=0 is recommended on Windows; on Linux, tune it so that
        # CPU and GPU utilisation are maximised.
        self.labeled_loader = DataLoader(
            labeled_set,
            sampler=RandomSampler(labeled_set),
            batch_size=self.args.batch_size,
            num_workers=0,
            drop_last=True
        )

        # NOTE(review): the unlabeled loader uses the same batch size as the
        # labeled one; this class does not read args.mu. Confirm this is intended.
        self.unlabeled_loader = DataLoader(
            unlabeled_set,
            sampler=RandomSampler(unlabeled_set),
            batch_size=self.args.batch_size,
            num_workers=0,
            drop_last=True
        )

        # Build WideResNet
        print("==> Preparing WideResNet")
        self.model = WideResNet(self.args.n_classes).to(self.args.cuda)
        self.model.zero_grad()

        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.cuda)

        # Weight decay is applied only to parameters whose name contains
        # neither 'bias' nor 'bn'.
        no_decay = ['bias', 'bn']
        grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(
                nd in n for nd in no_decay)], 'weight_decay': self.args.weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(
                nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        self.optimizer = torch.optim.SGD(grouped_parameters, lr=self.args.lr,
                            momentum=0.9, nesterov=self.args.nesterov)

        # Some schedulers are stepped every iteration and others once per
        # epoch; this one is stepped every iteration (see train()).
        # https://pytorch.org/docs/stable/optim.html
        self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
                                                    self.args.warmup,
                                                    self.args.total_steps)

        if self.args.use_ema:
            self.ema_model = WeightEMA(self.model, self.args.ema_decay)

        self.writer = SummaryWriter(self.experiment_dir)

    def train(self, epoch):
        """Run ``args.eval_step`` training iterations and return the mean losses.

        Args:
            epoch: epoch counter, used for the progress description and the
                TensorBoard step.

        Returns:
            Tuple ``(total_loss, labeled_loss, unlabeled_loss)`` averaged over
            the iterations of this call.
        """
        losses_t, losses_x, losses_u, mask_probs = 0.0, 0.0, 0.0, 0.0

        self.model.train()
        iter_labeled = iter(self.labeled_loader)
        iter_unlabeled = iter(self.unlabeled_loader)

        with tqdm(**get_tqdm_config(total=self.args.eval_step,
                leave=True, color='blue')) as pbar:
            for batch_idx in range(self.args.eval_step):
                # A loader can run out before eval_step iterations are done;
                # restart it in that case. Only StopIteration is caught so
                # that genuine data-loading errors are not hidden.
                try:
                    inputs_x, targets_x = next(iter_labeled)
                except StopIteration:
                    iter_labeled = iter(self.labeled_loader)
                    inputs_x, targets_x = next(iter_labeled)
                real_B = inputs_x.size(0)

                try:
                    (inputs_u_w, inputs_u_s), _ = next(iter_unlabeled)
                except StopIteration:
                    iter_unlabeled = iter(self.unlabeled_loader)
                    (inputs_u_w, inputs_u_s), _ = next(iter_unlabeled)

                # One forward pass over labeled, weak and strong inputs together.
                inputs = torch.cat((inputs_x, inputs_u_w, inputs_u_s), dim=0).to(self.args.cuda)
                targets_x = targets_x.to(self.args.cuda)

                logits = self.model(inputs)

                logits_x = logits[:real_B]
                logits_u_w, logits_u_s = logits[real_B:].chunk(2)
                del(logits)

                # Labeled loss
                loss_x = F.cross_entropy(logits_x, targets_x, reduction='mean')

                # Unlabeled loss: pseudo-labels come from the weak view, using
                # a softmax sharpened by temperature T; gradients are blocked.
                pseudo_labels = torch.softmax(logits_u_w.detach()/self.args.T, dim=-1)
                max_prob, targets_u = torch.max(pseudo_labels, dim=-1)  # (max, argmax)
                # mask is 1.0 where max_prob >= threshold and 0.0 elsewhere.
                mask = max_prob.ge(self.args.threshold).float()

                # Cross-entropy between the pseudo-labels and the strong-view
                # logits, counted only for confident samples.
                loss_u = (F.cross_entropy(logits_u_s, targets_u, reduction='none')*mask).mean()

                # Total loss
                loss = loss_x + self.args.lambda_u * loss_u
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                if self.args.use_ema:
                    self.ema_model.step(self.model)

                self.model.zero_grad()

                losses_x += loss_x.item()
                losses_u += loss_u.item()
                losses_t += loss.item()
                # NOTE(review): this accumulates the mean max-probability, not
                # the fraction of samples that pass the threshold.
                mask_probs += max_prob.mean().item()

                # Each call runs eval_step iterations, so the global step must
                # advance by eval_step per epoch. Scaling by batch_size made
                # the steps of consecutive epochs overlap.
                self.writer.add_scalars(
                    'Training steps', {
                        'Total_loss': losses_t/(batch_idx+1),
                        'Labeled_loss':losses_x/(batch_idx+1),
                        'Unlabeled_loss':losses_u/(batch_idx+1),
                        'Mask probs': mask_probs/(batch_idx+1)
                    }, global_step=epoch*self.args.eval_step+batch_idx
                )

                pbar.set_description(
                    '[Train(%4d/ %4d)-Total: %.3f|Labeled: %.3f|Unlabeled: %.3f]'%(
                        (batch_idx+1), self.args.eval_step,
                        losses_t/(batch_idx+1), losses_x/(batch_idx+1), losses_u/(batch_idx+1)
                    )
                )
                pbar.update(1)

            # Final description shows the epoch counter instead of the iteration.
            pbar.set_description(
                '[Train(%4d/ %4d)-Total: %.3f|Labeled: %.3f|Unlabeled: %.3f]'%(
                    epoch, self.args.epochs,
                    losses_t/(batch_idx+1), losses_x/(batch_idx+1), losses_u/(batch_idx+1)
                )
            )
        return losses_t/(batch_idx+1), losses_x/(batch_idx+1), losses_u/(batch_idx+1)

In [35]:
def FixMatch_parser():
    """Build the argument parser holding the FixMatch hyper-parameters.

    ``--nesterov`` and ``--use-ema`` default to True. Because a ``store_true``
    flag can only ever set True, the companion flags ``--no-nesterov`` and
    ``--no-use-ema`` are provided to switch them off.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="FixMatch PyTorch Implementation for LG Electornics education")

    # method arguments
    parser.add_argument('--n-labeled', type=int, default=3000)
    parser.add_argument('--n-classes', type=int, default=10)
    parser.add_argument("--expand-labels", action="store_true",
                        help="expand labels to fit eval steps")

    # training hyperparameters
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--total-steps', default=2**20, type=int)
    parser.add_argument('--eval-step', type=int, default=1024)
    parser.add_argument('--lr', type=float, default=0.03)
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--nesterov', action='store_true', default=True)
    parser.add_argument('--no-nesterov', dest='nesterov', action='store_false',
                        help="disable Nesterov momentum")
    parser.add_argument('--warmup', type=float, default=0.0)

    parser.add_argument('--use-ema', action='store_true', default=True)
    parser.add_argument('--no-use-ema', dest='use_ema', action='store_false',
                        help="disable the EMA model")
    parser.add_argument('--ema-decay', type=float, default=0.999)

    parser.add_argument('--mu', type=int, default=7)
    parser.add_argument('--T', type=float, default=1.0)

    parser.add_argument('--threshold', type=float, default=0.95)
    parser.add_argument('--lambda-u', type=float, default=1.0)
    return parser

In [36]:
# Parse the default hyper-parameters (an empty argv, since this runs in a notebook).
parser = FixMatch_parser()
args = parser.parse_args([])
# Use the first GPU when available; otherwise fall back to the CPU instead of
# failing when the model is moved to a missing device.
args.cuda = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# args.epochs = math.ceil(args.total_steps/args.eval_step)
args.epochs = 100

trainer = FixMatchTrainer(args)

best_loss = np.inf

# Per-epoch mean losses returned by trainer.train().
losses, losses_x, losses_u = [], [], []

train_losses, train_top1s, train_top5s = [], [], []
val_losses, val_top1s, val_top5s = [], [], []
test_losses, test_top1s, test_top5s = [], [], []
for epoch in range(1, args.epochs+1, 1):
    loss, loss_x, loss_u = trainer.train(epoch)
    losses.append(loss)
    losses_x.append(loss_x)
    losses_u.append(loss_u)

    # Overwrite the checkpoints after every epoch.
    torch.save(trainer.model.state_dict(), os.path.join(trainer.experiment_dir, 'model.pth'))
    # trainer.ema_model only exists when EMA is enabled.
    if args.use_ema:
        torch.save(trainer.ema_model.ema.state_dict(), os.path.join(trainer.experiment_dir, 'ema_model.pth'))


==> Preparing CIFAR10 dataset
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./FixMatch/data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:12<00:00, 13243014.46it/s]


Extracting ./FixMatch/data/cifar-10-python.tar.gz to ./FixMatch/data
==> Preparing WideResNet
[Train(   1/  100)-Total: 1.242|Labeled: 1.198|Unlabeled: 0.044]: 100%|[94m██████████[39m| [01:06<00:00, 15.34it/s]
[Train(   2/  100)-Total: 0.778|Labeled: 0.618|Unlabeled: 0.160]: 100%|[94m██████████[39m| [01:06<00:00, 15.41it/s]
[Train(   3/  100)-Total: 0.625|Labeled: 0.364|Unlabeled: 0.262]: 100%|[94m██████████[39m| [01:06<00:00, 15.36it/s]
[Train(   4/  100)-Total: 0.556|Labeled: 0.240|Unlabeled: 0.316]: 100%|[94m██████████[39m| [01:06<00:00, 15.31it/s]
[Train(   5/  100)-Total: 0.508|Labeled: 0.177|Unlabeled: 0.331]: 100%|[94m██████████[39m| [01:06<00:00, 15.30it/s]
[Train(   6/  100)-Total: 0.490|Labeled: 0.145|Unlabeled: 0.344]: 100%|[94m██████████[39m| [01:06<00:00, 15.37it/s]
[Train(   7/  100)-Total: 0.462|Labeled: 0.116|Unlabeled: 0.346]: 100%|[94m██████████[39m| [01:06<00:00, 15.33it/s]
[Train(   8/  100)-Total: 0.456|Labeled: 0.106|Unlabeled: 0.349]: 100%|[94m████