# Deep Learning Homework 2 Part B

In [1]:
# check whether the torch cuda is ok
import torch
import os
os.environ['CUDA_VISIBLE_DEVICES']='0,1'
torch.cuda.is_available()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
%matplotlib inline

In [2]:
import torch.nn as nn
from torchvision import models
def create_model(args):
    model = models.resnet18(pretrained=False)
    in_features = model.fc.in_features
    model.fc=nn.Linear(in_features,args.num_classes)
    model.name = "resnet18"
    print(args.device)
    model.to(args.device)
    criteria_x = nn.CrossEntropyLoss().to(args.device)
    criteria_u = nn.CrossEntropyLoss(reduction='none').to(args.device)
    return model,criteria_x,criteria_u

In [3]:
import types 
args=types.SimpleNamespace()

args.gpu_id = 1
args.num_workers = 4
args.dataset = 'dataset'
args.num_labeled = 250
args.expand_labels = True
args.arch = 'resnet'
args.total_steps = 2**20
args.eval_steps = 1024
args.start_epoch = 0
args.batch_size = 16
args.lr = 0.03
args.warmu = 0
args.wdecay = 5e-4
args.momentum = 0.9
args.nesterov = True
args.use_ema = True
args.ema_decay = 0.999
args.mu = 7
args.lambda_u = 1
args.T = 1
args.threshold = 0.95 
args.out = 'result'
args.resume = 'resume'
args.seed = None
# args.amp = True
# args.opt_level = '01'
args.local_rank = -1
args.no_progress = False
args.input_size = 224
args.num_classes = 10
args.num_images_per_epoch = 2**14
args.num_epoches = 128


args

namespace(gpu_id=1,
          num_workers=4,
          dataset='dataset',
          num_labeled=250,
          expand_labels=True,
          arch='resnet',
          total_steps=1048576,
          eval_steps=1024,
          start_epoch=0,
          batch_size=16,
          lr=0.03,
          warmu=0,
          wdecay=0.0005,
          momentum=0.9,
          nesterov=True,
          use_ema=True,
          ema_decay=0.999,
          mu=7,
          lambda_u=1,
          T=1,
          threshold=0.95,
          out='result',
          resume='resume',
          seed=None,
          local_rank=-1,
          no_progress=False,
          input_size=224,
          num_classes=10,
          num_images_per_epoch=16384,
          num_epoches=128)

In [4]:
if args.local_rank == -1:
    device = torch.device('cuda', args.gpu_id)
    args.world_size = 1
    args.n_gpu = torch.cuda.device_count()
else:
    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda', args.local_rank)
    torch.distributed.init_process_group(backend='nccl')
    args.world_size = torch.distributed.get_world_size()
    args.n_gpu = 1

args.device = device

In [5]:
import cv2
import numpy as np


## aug functions
def identity_func(img):
    return img


def autocontrast_func(img, cutoff=0):
    '''
        same output as PIL.ImageOps.autocontrast
    '''
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            high, low = ch.max(), ch.min()
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else low[0]
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
        if high <= low:
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            offset = -low * scale
            table = np.arange(n_bins) * scale + offset
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def equalize_func(img):
    '''
        same output as PIL.ImageOps.equalize
        PIL's implementation is different from cv2.equalize
    '''
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        non_zero_hist = hist[hist != 0].reshape(-1)
        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
        if step == 0: return ch
        n = np.empty_like(hist)
        n[0] = step // 2
        n[1:] = hist[:-1]
        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def rotate_func(img, degree, fill=(0, 0, 0)):
    '''
    like PIL, rotate by degree, not radians
    '''
    H, W = img.shape[0], img.shape[1]
    center = W / 2, H / 2
    M = cv2.getRotationMatrix2D(center, degree, 1)
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
    return out


def solarize_func(img, thresh=128):
    '''
        same output as PIL.ImageOps.posterize
    '''
    table = np.array([el if el < thresh else 255 - el for el in range(256)])
    table = table.clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def color_func(img, factor):
    '''
        same output as PIL.ImageEnhance.Color
    '''
    ## implementation according to PIL definition, quite slow
    #  degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
    #  out = blend(degenerate, img, factor)
    #  M = (
    #      np.eye(3) * factor
    #      + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
    #  )[np.newaxis, np.newaxis, :]
    M = (
            np.float32([
                [0.886, -0.114, -0.114],
                [-0.587, 0.413, -0.587],
                [-0.299, -0.299, 0.701]]) * factor
            + np.float32([[0.114], [0.587], [0.299]])
    )
    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
    return out


def contrast_func(img, factor):
    """
        same output as PIL.ImageEnhance.Contrast
    """
    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
    table = np.array([(
        el - mean) * factor + mean
        for el in range(256)
    ]).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def brightness_func(img, factor):
    '''
        same output as PIL.ImageEnhance.Contrast
    '''
    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def sharpness_func(img, factor):
    '''
    The differences the this result and PIL are all on the 4 boundaries, the center
    areas are same
    '''
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
        out = out.astype(np.uint8)
    return out


def shear_x_func(img, factor, fill=(0, 0, 0)):
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, factor, 0], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_x_func(img, offset, fill=(0, 0, 0)):
    '''
        same output as PIL.Image.transform
    '''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, -offset], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_y_func(img, offset, fill=(0, 0, 0)):
    '''
        same output as PIL.Image.transform
    '''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [0, 1, -offset]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def posterize_func(img, bits):
    '''
        same output as PIL.ImageOps.posterize
    '''
    out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
    return out


def shear_y_func(img, factor, fill=(0, 0, 0)):
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [factor, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def cutout_func(img, pad_size, replace=(0, 0, 0)):
    replace = np.array(replace, dtype=np.uint8)
    H, W = img.shape[0], img.shape[1]
    rh, rw = np.random.random(2)
    pad_size = pad_size // 2
    ch, cw = int(rh * H), int(rw * W)
    x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
    y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
    out = img.copy()
    out[x1:x2, y1:y2, :] = replace
    return out


### level to args
def enhance_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        return ((level / MAX_LEVEL) * 1.8 + 0.1,)
    return level_to_args


def shear_level_to_args(MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 0.3
        if np.random.random() > 0.5: level = -level
        return (level, replace_value)

    return level_to_args


def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * float(translate_const)
        if np.random.random() > 0.5: level = -level
        return (level, replace_value)

    return level_to_args


def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * cutout_const)
        return (level, replace_value)

    return level_to_args


def solarize_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 256)
        return (level, )
    return level_to_args


def none_level_to_args(level):
    return ()


def posterize_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 4)
        return (level, )
    return level_to_args


def rotate_level_to_args(MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 30
        if np.random.random() < 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


func_dict = {
    'Identity': identity_func,
    'AutoContrast': autocontrast_func,
    'Equalize': equalize_func,
    'Rotate': rotate_func,
    'Solarize': solarize_func,
    'Color': color_func,
    'Contrast': contrast_func,
    'Brightness': brightness_func,
    'Sharpness': sharpness_func,
    'ShearX': shear_x_func,
    'TranslateX': translate_x_func,
    'TranslateY': translate_y_func,
    'Posterize': posterize_func,
    'ShearY': shear_y_func,
}

translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
    'Identity': none_level_to_args,
    'AutoContrast': none_level_to_args,
    'Equalize': none_level_to_args,
    'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
    'Solarize': solarize_level_to_args(MAX_LEVEL),
    'Color': enhance_level_to_args(MAX_LEVEL),
    'Contrast': enhance_level_to_args(MAX_LEVEL),
    'Brightness': enhance_level_to_args(MAX_LEVEL),
    'Sharpness': enhance_level_to_args(MAX_LEVEL),
    'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
    'TranslateX': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'TranslateY': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'Posterize': posterize_level_to_args(MAX_LEVEL),
    'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}


class RandomAugment(object):

    def __init__(self, N=2, M=10):
        self.N = N
        self.M = M

    def get_random_ops(self):
        sampled_ops = np.random.choice(list(func_dict.keys()), self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        img = cutout_func(img, 16, replace_value)
        return img

In [6]:
import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
class LabeledDataset(Dataset):
    def __init__(self,dataset:Dataset,args,is_train=True):
        super().__init__()
        self.dataset=dataset
        self.is_train=is_train
        self.input_size = args.input_size
        if self.is_train:
            self.trans_weak = transforms.Compose([
                transforms.RandomResizedCrop(self.input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
            self.trans_strong = transforms.Compose([
                transforms.RandomResizedCrop(self.input_size),
                transforms.RandomHorizontalFlip(),
                RandomAugment(2,10),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        else:
            self.trans = transforms.Compose([
                transforms.RandomResizedCrop(self.input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
    def __getitem__(self,index):
        image,label = self.dataset[index]
        # image=np.array(image)
        # print(image.shape)
        if self.is_train:
            return self.trans_weak(image),self.trans_strong(image),label
        else:
            return self.trans(image),label
    def __len__(self):
        return len(self.dataset)

class UnlabeledDataset(Dataset):
    def __init__(self,image_dir,args,is_train=True) ->None:
        super().__init__()
        self.image_dir = image_dir
        self.image_name =os.listdir(self.image_dir)
        self.is_train = is_train
        self.input_size = args.input_size
        if self.is_train:
            self.trans_weak = transforms.Compose([
                transforms.RandomResizedCrop(self.input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
            self.trans_strong = transforms.Compose([
                transforms.RandomResizedCrop(self.input_size),
                transforms.RandomHorizontalFlip(),
                RandomAugment(2,10),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        else:
            self.trans = transforms.Compose([
                transforms.RandomResizedCrop(self.input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

    def __getitem__(self,index):
        image_item_name = self.image_name[index]
        image_item_path = os.path.join(self.image_dir,image_item_name)
        image = Image.open(image_item_path)
        # image=np.array(image)
        label = 'unlabeled'
        if self.is_train:
            return self.trans_weak(image),self.trans_strong(image),label
        else:
            return self.trans(image),label

    def __len__(self):
        return len(self.image_name)



In [7]:
import torch
import torch.distributed as dist

"""
为什么 param 和 buffer 要采用不同的的更新策略
param 是 指数移动平均数，buffer 不是
"""


class EMA(object):
    def __init__(self, model, alpha=0.999):
        self.step = 0
        self.model = model
        self.alpha = alpha
        self.shadow = self.get_model_state()
        self.backup = {}
        self.param_keys = [k for k, _ in self.model.named_parameters()]
        # num_batches_tracked, running_mean, running_var in bn
        self.buffer_keys = [k for k, _ in self.model.named_buffers()]

    def update_params(self):
        # decay = min(self.alpha, (self.step + 1) / (self.step + 10))  # ????
        decay = self.alpha
        state = self.model.state_dict()  # current params
        for name in self.param_keys:
            self.shadow[name].copy_(
                decay * self.shadow[name] + (1 - decay) * state[name]
            )
        # for name in self.buffer_keys:
        #     self.shadow[name].copy_(
        #         decay * self.shadow[name]
        #         + (1 - decay) * state[name]
        #     )

        self.step += 1

    def update_buffer(self):
        # without EMA
        state = self.model.state_dict()
        for name in self.buffer_keys:
            self.shadow[name].copy_(state[name])

    def apply_shadow(self):
        self.backup = self.get_model_state()
        self.model.load_state_dict(self.shadow)

    def restore(self):
        self.model.load_state_dict(self.backup)

    def get_model_state(self):
        return {
            k: v.clone().detach()
            for k, v in self.model.state_dict().items()
        }



In [8]:
import torch
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
from torch.utils.data.distributed import DistributedSampler

def get_train_loader(args):
    num_iters_per_epoch = args.num_images_per_epoch // args.batch_size
    datasetPath = os.path.join(args.dataset,'3-Semi-Supervised')
    labeledDatasetPath = os.path.join(datasetPath,'labeled')
    unlabeledDatasetPath = os.path.join(datasetPath,'unlabeled')
    labeledDataset = datasets.ImageFolder(labeledDatasetPath)
    unlabeledDataset = UnlabeledDataset(unlabeledDatasetPath,args)
    labeledDataset = LabeledDataset(labeledDataset,args)
    sampler_x = RandomSampler(labeledDataset, replacement=True, num_samples=num_iters_per_epoch * args.batch_size)
    batch_sampler_x = BatchSampler(sampler_x, args.batch_size, drop_last=True)  # yield a batch of samples one time
    labeledDatasetDataloader = DataLoader(
        labeledDataset,
        batch_sampler=batch_sampler_x,
        num_workers = args.num_workers
    )
    sampler_u = RandomSampler(unlabeledDataset, replacement=True, num_samples=args.mu * num_iters_per_epoch * args.batch_size)
    batch_sampler_u = BatchSampler(sampler_u, args.batch_size * args.mu, drop_last=True)
    unlabeledDatasetDataloader = DataLoader(
        unlabeledDataset,
        batch_sampler=batch_sampler_u,
        num_workers = args.num_workers
    )
    return labeledDatasetDataloader,unlabeledDatasetDataloader


def get_valid_loader(args):
    testDataset = datasets.ImageFolder(os.path.join(args.dataset,'test'))
    testDataset = LabeledDataset(testDataset,args,is_train=False)
    validDatasetDataloader = DataLoader(
        testDataset,
        shuffle=False,
        drop_last=False,
        batch_size = 64,
        num_workers = args.num_workers
    )
    return validDatasetDataloader




In [9]:
import math

import torch
from torch.optim.lr_scheduler import _LRScheduler, LambdaLR
import numpy as np


class WarmupExpLrScheduler(_LRScheduler):
    def __init__(
            self,
            optimizer,
            power,
            step_interval=1,
            warmup_iter=500,
            warmup_ratio=5e-4,
            warmup='exp',
            last_epoch=-1,
    ):
        self.power = power
        self.step_interval = step_interval
        self.warmup_iter = warmup_iter
        self.warmup_ratio = warmup_ratio
        self.warmup = warmup
        super(WarmupExpLrScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        ratio = self.get_lr_ratio()
        lrs = [ratio * lr for lr in self.base_lrs]
        return lrs

    def get_lr_ratio(self):
        if self.last_epoch < self.warmup_iter:
            ratio = self.get_warmup_ratio()
        else:
            real_iter = self.last_epoch - self.warmup_iter
            ratio = self.power ** (real_iter // self.step_interval)
        return ratio

    def get_warmup_ratio(self):
        assert self.warmup in ('linear', 'exp')
        alpha = self.last_epoch / self.warmup_iter
        if self.warmup == 'linear':
            ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
        elif self.warmup == 'exp':
            ratio = self.warmup_ratio ** (1. - alpha)
        return ratio


class WarmupPolyLrScheduler(_LRScheduler):
    def __init__(
            self,
            optimizer,
            power,
            max_iter,
            warmup_iter,
            warmup_ratio=5e-4,
            warmup='exp',
            last_epoch=-1,
    ):
        self.power = power
        self.max_iter = max_iter
        self.warmup_iter = warmup_iter
        self.warmup_ratio = warmup_ratio
        self.warmup = warmup
        super(WarmupPolyLrScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        ratio = self.get_lr_ratio()
        lrs = [ratio * lr for lr in self.base_lrs]
        return lrs

    def get_lr_ratio(self):
        if self.last_epoch < self.warmup_iter:
            ratio = self.get_warmup_ratio()
        else:
            real_iter = self.last_epoch - self.warmup_iter
            real_max_iter = self.max_iter - self.warmup_iter
            alpha = real_iter / real_max_iter
            ratio = (1 - alpha) ** self.power
        return ratio

    def get_warmup_ratio(self):
        assert self.warmup in ('linear', 'exp')
        alpha = self.last_epoch / self.warmup_iter
        if self.warmup == 'linear':
            ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
        elif self.warmup == 'exp':
            ratio = self.warmup_ratio ** (1. - alpha)
        return ratio


class WarmupCosineLrScheduler(_LRScheduler):
    '''
    This is different from official definition, this is implemented according to
    the paper of fix-match
    '''
    def __init__(
            self,
            optimizer,
            max_iter,
            warmup_iter,
            warmup_ratio=5e-4,
            warmup='exp',
            last_epoch=-1,
    ):
        self.max_iter = max_iter
        self.warmup_iter = warmup_iter
        self.warmup_ratio = warmup_ratio
        self.warmup = warmup
        super(WarmupCosineLrScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        ratio = self.get_lr_ratio()
        lrs = [ratio * lr for lr in self.base_lrs]
        return lrs

    def get_lr_ratio(self):
        if self.last_epoch < self.warmup_iter:
            ratio = self.get_warmup_ratio()
        else:
            real_iter = self.last_epoch - self.warmup_iter
            real_max_iter = self.max_iter - self.warmup_iter
            ratio = np.cos((7 * np.pi * real_iter) / (16 * real_max_iter))
        return ratio

    def get_warmup_ratio(self):
        assert self.warmup in ('linear', 'exp')
        alpha = self.last_epoch / self.warmup_iter
        if self.warmup == 'linear':
            ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
        elif self.warmup == 'exp':
            ratio = self.warmup_ratio ** (1. - alpha)
        return ratio


# from Fixmatch-pytorch
def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7./16.,
                                    last_epoch=-1):
    def _lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        no_progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        # return max(0., math.cos(math.pi * num_cycles * no_progress))

        return max(0., (math.cos(math.pi * num_cycles * no_progress) + 1) * 0.5)

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

In [10]:
from datetime import datetime
import logging
import os
import sys
import torch

# from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard import SummaryWriter


def interleave(x, bt):
    s = list(x.shape)
    return torch.reshape(torch.transpose(x.reshape([-1, bt] + s[1:]), 1, 0), [-1] + s[1:])


def de_interleave(x, bt):
    s = list(x.shape)
    return torch.reshape(torch.transpose(x.reshape([bt, -1] + s[1:]), 1, 0), [-1] + s[1:])


def setup_default_logging(args, default_level=logging.INFO,
                          format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s"):
    output_dir = os.path.join(args.dataset, f'x{args.num_labeled}')
    os.makedirs(output_dir, exist_ok=True)

    writer = SummaryWriter(comment=f'{args.dataset}_{args.num_labeled}')

    logger = logging.getLogger('train')

    logging.basicConfig(  # unlike the root logger, a custom logger can’t be configured using basicConfig()
        filename=os.path.join(output_dir, f'{time_str()}.log'),
        format=format,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=default_level)

    # print
    # file_handler = logging.FileHandler()
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(default_level)
    console_handler.setFormatter(logging.Formatter(format))
    logger.addHandler(console_handler)

    return logger, writer


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    
    _, pred = output.topk(maxk, 1, largest=True, sorted=True)  # return value, indices
    # print(pred)
    pred = pred.t()
    # print(target)
    # print(target.view(1, -1).expand_as(pred))
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # print(correct)
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class AverageMeter(object):
    """
    Computes and stores the average and current value
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        # self.avg = self.sum / (self.count + 1e-20)
        self.avg = self.sum / self.count


def time_str(fmt=None):
    if fmt is None:
        fmt = '%Y-%m-%d_%H:%M:%S'

    #     time.strftime(format[, t])
    return datetime.today().strftime(fmt)


In [11]:
from __future__ import print_function
import random
import time
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def train_one_epoch(
    epoch,
    model,
    criteria_x,
    criteria_u,
    optimizer,
    lr_schdlr,
    ema,
    labeledDatasetLoader,
    unlabeledDatasetLoader,
    lambda_u,
    n_iters,
    args,
    logger
    ):
    model.train(True)
    loss_meter = AverageMeter()
    loss_x_meter = AverageMeter()
    loss_u_meter = AverageMeter()
    # number of gradient-considered strong augmentation of unlabeled samples
    n_strong_aug_meter = AverageMeter()
    mask_meter = AverageMeter()
    epoch_start = time.time()

    labeledIter,unlabeledIter=iter(labeledDatasetLoader),iter(unlabeledDatasetLoader)
    for it in range(n_iters):
        image_labeled_weak ,image_labeled_strong, labeled_label = next(labeledIter)
        image_unlabeled_weak,image_unlabeled_strong,_ =next(unlabeledIter)
        labeled_label = labeled_label.to(args.device)
        # print(labeled_label)
        batch_size = image_labeled_weak.size(0)
        mu = int(image_unlabeled_weak.size(0)//batch_size)
        imgs = torch.cat([image_labeled_weak,image_unlabeled_weak,image_unlabeled_strong],dim=0).to(args.device)
        imgs = interleave(imgs,2*mu+1)
        logits = model(imgs)
        logits = de_interleave(logits,2*mu+1)
        logits_x = logits[:batch_size]
        logits_unlabeled_w,logits_unlabeled_s = torch.split(logits[batch_size:],batch_size*mu)
        loss_x = criteria_x(logits_x,labeled_label)
        with torch.no_grad():
            probs = torch.softmax(logits_unlabeled_w,dim=1)
            scores,labels_unlabeled_guess = torch.max(probs,dim=1)
            mask =scores.ge(args.threshold).float()
        loss_u = (criteria_u(logits_unlabeled_s,labels_unlabeled_guess)*mask).mean()
        loss = loss_x+ lambda_u *loss_u
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update_params()
        lr_schdlr.step()
        loss_meter.update(loss.item())
        loss_x_meter.update(loss_x.item())
        loss_u_meter.update(loss_u.item())
        mask_meter.update(mask.mean().item())
        n_strong_aug_meter.update(mask.sum().item())
        if (it + 1) % 128 == 0:
            t = time.time() - epoch_start
            lr_log = [pg['lr'] for pg in optimizer.param_groups]
            lr_log = sum(lr_log) / len(lr_log)

            logger.info("epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. "
                        "Mask:{:.4f} . LR: {:.4f}. Time: {:.2f}".format(
                epoch, it + 1, loss_meter.avg, loss_u_meter.avg, loss_x_meter.avg, n_strong_aug_meter.avg, mask_meter.avg, lr_log, t))

            epoch_start = time.time()
    ema.update_buffer()
    return loss_meter.avg,loss_x_meter.avg,loss_u_meter.avg,mask_meter.avg


In [12]:
import torch
def evaluate(ema,dataloader,criterion,args):
    # ema params to evaluate performance
    ema.apply_shadow()
    ema.model.to(args.device)
    ema.model.eval()

    loss_meter= AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    with torch.no_grad():
        for images,labels in dataloader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            logits = ema.model(images)
            # print(images.shape,labels.shape)
            loss = criterion(logits,labels)
            scores = torch.softmax(logits,dim=1)
            top1,top5 = accuracy(scores,labels,(1,5))
            loss_meter.update(loss.item())
            top1_meter.update(top1.item())
            top5_meter.update(top5.item())

    ema.restore()
    return top1_meter.avg,top5_meter.avg,loss_meter.avg
    

In [13]:
model,criteria_x,criteria_u = create_model(args)
num_iters_per_epoch = args.num_images_per_epoch // args.batch_size
num_iters_all = num_iters_per_epoch * args.num_epoches
ema = EMA(model,args.ema_decay)
wd_params,non_wd_params = [],[]
for name,param in model.named_parameters():
    if 'bn' in name:
        non_wd_params.append(param)
    else:
        wd_params.append(param)
# print(len(wd_params),len(non_wd_params))
param_list = [
    {
        'params':wd_params
    },
    {
        'params':non_wd_params,
        'weight_decay': 0
    }
]
optimizer = torch.optim.SGD(param_list,lr=args.lr,weight_decay=args.wdecay,momentum=args.momentum,nesterov=args.nesterov)
lr_schdlr = WarmupCosineLrScheduler(
    optimizer,
    max_iter=num_iters_all,
    warmup_iter=0
)


best_acc = -1
best_epoch = 0
logger,writer = setup_default_logging(args)
labeledDatasetDataloader,unlabeledDatasetDataloader=get_train_loader(args)
validDatasetDataloader = get_valid_loader(args)

logger.info("***** Running training *****")
logger.info(f"  Task = {args.dataset}@{args.num_labeled}")
logger.info(f"  Num Epochs = {num_iters_per_epoch}")
logger.info(f"  Batch size per GPU = {args.batch_size}")
# logger.info(f"  Total train batch size = {args.batch_size * args.world_size}")
logger.info(f"  Total optimization steps = {num_iters_all}")
logger.info("Total params: {:.2f}M".format(
    sum(p.numel() for p in model.parameters()) / 1e6))
logger.info('-----------start training--------------')
for epoch in range(args.num_epoches):
    train_loss,loss_x,loss_u,mask_mean=train_one_epoch(
        epoch=epoch,
        model=model,
        criteria_x=criteria_x,
        criteria_u=criteria_u,
        optimizer=optimizer,
        lr_schdlr=lr_schdlr,
        ema=ema,
        labeledDatasetLoader=labeledDatasetDataloader,
        unlabeledDatasetLoader=unlabeledDatasetDataloader,
        lambda_u=args.lambda_u,
        n_iters=num_iters_per_epoch,
        args=args,
        logger=logger 
    )
    top1, top5, valid_loss = evaluate(ema, validDatasetDataloader, criteria_x, args=args)

    writer.add_scalars('train/1.loss', {'train': train_loss,
                                        'test': valid_loss}, epoch)
    writer.add_scalar('train/2.train_loss_x', loss_x, epoch)
    writer.add_scalar('train/3.train_loss_u', loss_u, epoch)
    writer.add_scalar('train/4.mask_mean', mask_mean, epoch)
    writer.add_scalars('test/1.test_acc', {'top1': top1, 'top5': top5}, epoch)
    # writer.add_scalar('test/2.test_loss', loss, epoch)

    # best_acc = top1 if best_acc < top1 else best_acc
    if best_acc < top1:
        best_acc = top1
        best_epoch = epoch

    logger.info("Epoch {}. Top1: {:.4f}. Top5: {:.4f}. best_acc: {:.4f} in epoch{}".
                format(epoch, top1, top5, best_acc, best_epoch))

writer.close()

cuda:1
2022-04-12 07:36:34,344 - INFO - train -   ***** Running training *****
2022-04-12 07:36:34,345 - INFO - train -     Task = dataset@250
2022-04-12 07:36:34,346 - INFO - train -     Num Epochs = 1024
2022-04-12 07:36:34,347 - INFO - train -     Batch size per GPU = 16
2022-04-12 07:36:34,348 - INFO - train -     Total optimization steps = 131072
2022-04-12 07:36:34,349 - INFO - train -   Total params: 11.18M
2022-04-12 07:36:34,350 - INFO - train -   -----------start training--------------
2022-04-12 07:37:23,278 - INFO - train -   epoch:0, iter: 128. loss: 1.9111. loss_u: 0.0351. loss_x: 1.8759. Mask:4.2500 . LR: 0.0379. Time: 0.03
2022-04-12 07:38:11,113 - INFO - train -   epoch:0, iter: 256. loss: 1.7505. loss_u: 0.0572. loss_x: 1.6934. Mask:3.9297 . LR: 0.0351. Time: 0.03
2022-04-12 07:38:59,678 - INFO - train -   epoch:0, iter: 384. loss: 1.6410. loss_u: 0.0640. loss_x: 1.5770. Mask:4.4922 . LR: 0.0401. Time: 0.03
2022-04-12 07:39:48,019 - INFO - train -   epoch:0, iter: 512