In [None]:
import re
import argparse
import os
import shutil
import time
import math
import logging
from datetime import datetime
import itertools
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler, Sampler
import torchvision.datasets

from collections import defaultdict
from pandas import DataFrame
import threading


from mean_teacher import architectures, datasets, data, losses, ramps, cli
from mean_teacher.run_context import RunContext
from mean_teacher.data import NO_LABEL
from mean_teacher.utils import *

# networks

In [None]:
class generator(nn.Module):
    """Generator network following the infoGAN layout (https://arxiv.org/abs/1606.03657).

    Two fully connected layers (each followed by BatchNorm + ReLU) expand the
    latent vector to a ``128 x (input_size // 4) x (input_size // 4)`` feature
    map; two stride-2 transposed convolutions then upsample it by a factor of
    four and a final ``Tanh`` squashes the result into ``[-1, 1]``.

    Args:
        input_dim: size of the latent vector fed to the first linear layer.
        output_dim: number of channels of the generated image.
        input_size: side length used to size the intermediate feature map.
    """

    def __init__(self, input_dim=100, output_dim=1, input_size=32):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size

        side = self.input_size // 4
        flat_features = 128 * side * side

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, flat_features),
            nn.BatchNorm1d(flat_features),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            # kernel 4 / stride 2 / padding 1 doubles the spatial size
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Tanh(),
        )
        # NOTE(review): initialize_weights must already be defined/imported when
        # this class is instantiated.
        initialize_weights(self)

    def forward(self, input):
        """Map a batch of latent vectors to a batch of images."""
        side = self.input_size // 4
        features = self.fc(input)
        feature_map = features.view(-1, 128, side, side)
        return self.deconv(feature_map)

class discriminator(nn.Module):
    """Discriminator with an infoGAN-style trunk and a cosine-similarity head.

    Two stride-2 convolutions reduce the image to a
    ``128 x (input_size // 4) x (input_size // 4)`` feature map, a fully
    connected layer maps it to 1024 features, and the output is the product of
    the normalized features with the normalized rows of ``self.weight``.

    Args:
        input_dim: number of channels of the input image.
        output_dim: number of rows of the final weight matrix (output width).
        input_size: side length used to size the flattened feature vector.
    """

    def __init__(self, input_dim=1, output_dim=1, input_size=32):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size

        side = self.input_size // 4
        flat_features = 128 * side * side

        self.conv = nn.Sequential(
            # kernel 4 / stride 2 / padding 1 halves the spatial size
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
        )
        # Head weights are a bare Parameter (not an nn.Linear), so the
        # initialize_weights call below leaves the Xavier initialization intact.
        self.weight = torch.nn.Parameter(torch.FloatTensor(output_dim, 1024))
        torch.nn.init.xavier_uniform_(self.weight)
        initialize_weights(self)

    def forward(self, input):
        """Return one score per row of ``self.weight`` for every input image."""
        side = self.input_size // 4
        feature_map = self.conv(input)
        flattened = feature_map.view(-1, 128 * side * side)
        embedding = self.fc(flattened)
        # Both operands are normalized before the linear product.
        return F.linear(F.normalize(embedding), F.normalize(self.weight))

In [None]:
# Smoke test: push a batch of 32 uniform-random latent vectors through a
# default-constructed generator and print the shape of the generated batch.
z = torch.rand(32, 100)
x = generator()(z)
print(x.shape)

# Feed the generated batch to a default-constructed discriminator and display
# the shape of its output.
out = discriminator()(x)
out.shape

In [None]:
# Notebook-wide logger, used by the model-creation, resume and validation code.
LOG = logging.getLogger('main')

In [None]:
# Emit INFO-level (and above) log records.
logging.basicConfig(level=logging.INFO)


In [None]:
# Module-level training state; both values are overwritten from the checkpoint
# when resuming, and `global_step` is declared `global` inside train().
best_prec1 = 0
global_step = 0

In [None]:
def create_parser():
    """Build the command-line parser holding every hyper-parameter of the run.

    Returns:
        argparse.ArgumentParser: parser whose defaults describe a CIFAR-10 run.
    """
    parser = argparse.ArgumentParser(description='RMCOS-SSL')
    add = parser.add_argument

    # --- data ---
    add('--dataset', default='cifar10')
    add('--train-subdir', default='train+val', type=str)
    add('--eval-subdir', default='test', type=str)
    add('--labels', type=str,
        default="data-local/labels/cifar10/4000_balanced_labels/00.txt")
    add('--exclude-unlabeled', action="store_true", default=False)

    # --- classifier architecture ---
    # (original author's note: there does not seem to be a resnet18 option)
    add('--arch', '-a', default='cifar_shakeshake26')

    # --- GAN optimizer settings ---
    add('--lrG', default=0.0002, type=float)
    add('--lrD', default=0.0002, type=float)
    add('--beta1', default=0.5, type=float)
    add('--beta2', default=0.999, type=float)

    # --- loop / batch sizes ---
    add('--workers', type=int, default=4)
    add('--epochs', type=int, default=1)
    add('--start-epoch', type=int, default=0)
    add('--batch-size', type=int, default=128)
    add('--labeled-batch-size', type=int, default=64)
    add('--generated-batch-size', type=int, default=128)
    add('--z-dim', type=int, default=100)

    # --- classifier optimizer / schedule ---
    add('--lr', '--learning-rate', type=float, default=0.1)
    add('--initial-lr', type=float, default=0.0)
    add('--lr-rampup', type=int, default=0)
    add('--lr-rampdown-epochs', type=int, default=None)
    add('--momentum', type=float, default=0.9)
    # NOTE(review): store_true with default=True means this flag is always True.
    add('--nesterov', action="store_true", default=True)
    add('--weight-decay', type=float, default=1e-4)

    # --- mean-teacher consistency settings ---
    add('--ema-decay', type=float, default=0.999)
    add('--consistency', type=float, default=100.0)
    add('--consistency-type', type=str, default="mse")
    add('--consistency-rampup', type=int, default=5)
    add('--logit-distance-cost', type=float, default=0.01)

    # --- bookkeeping ---
    add('--checkpoint-epochs', type=int, default=1)
    add('--evaluation-epochs', type=int, default=1)
    add('--print-freq', type=int, default=10)
    add('--resume', type=str, default='')
    add('--evaluate', action="store_true", default=False)
    add('--pretrained', dest='pretrained', action='store_true',
        help='use pre-trained model')
    return parser

# Parse an empty argument list so every option takes its default; this avoids
# argparse reading the Jupyter kernel's own command line.
args = create_parser().parse_args(args=[]) # for jupyter

In [None]:
# Create a fresh, timestamped output directory for this run's checkpoints.
out_dir = "out"
date_time_now = datetime.now()

# NOTE(review): the directory name contains ':' characters (from %H:%M:%S),
# which is presumably a problem on file systems that reject colons — confirm.
checkpoint_path = "{}/{:%Y-%m-%d_%H:%M:%S}".format(out_dir, date_time_now)
# The existence check is redundant with exist_ok=True but harmless.
if not os.path.exists(checkpoint_path):
    os.makedirs(checkpoint_path, exist_ok=True)

print(checkpoint_path)


# transform data

In [None]:
class RandomTranslateWithReflect:
    """Randomly shift an image, filling exposed borders by reflection.

    The horizontal and vertical offsets are drawn independently and uniformly
    from the integers in ``[-max_translation, max_translation]``. The area
    uncovered by the shift is filled with mirrored copies of the image, so the
    output has the same size as the input.
    """

    def __init__(self, max_translation):
        self.max_translation = max_translation

    def __call__(self, old_image):
        """Return a translated copy of ``old_image`` (an image object exposing
        ``size``, ``transpose`` and usable with ``Image.paste``)."""
        shift_x, shift_y = np.random.randint(-self.max_translation,
                                             self.max_translation + 1,
                                             size=2)
        pad_x, pad_y = abs(shift_x), abs(shift_y)
        width, height = old_image.size

        mirrored_lr = old_image.transpose(Image.FLIP_LEFT_RIGHT)
        mirrored_tb = old_image.transpose(Image.FLIP_TOP_BOTTOM)
        mirrored_both = old_image.transpose(Image.ROTATE_180)

        # Canvas large enough to hold the image plus the padding on every side.
        canvas = Image.new("RGB", (width + 2 * pad_x, height + 2 * pad_y))

        # Paste offsets of the mirrored neighbours; they overlap the central
        # image by one pixel, which gives reflect (not symmetric) padding.
        left, right = pad_x - width + 1, pad_x + width - 1
        top, bottom = pad_y - height + 1, pad_y + height - 1

        canvas.paste(old_image, (pad_x, pad_y))

        for x in (right, left):
            canvas.paste(mirrored_lr, (x, pad_y))

        for y in (bottom, top):
            canvas.paste(mirrored_tb, (pad_x, y))

        for y in (top, bottom):
            for x in (left, right):
                canvas.paste(mirrored_both, (x, y))

        # Cut an original-sized window, displaced by the sampled translation.
        crop_box = (pad_x - shift_x,
                    pad_y - shift_y,
                    pad_x + width - shift_x,
                    pad_y + height - shift_y)
        return canvas.crop(crop_box)


class TransformTwice:
    """Apply the same transform twice to one input and return both results.

    With a stochastic transform this yields two differently augmented views
    of the same sample.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        """Return ``(first_view, second_view)``; the transform runs twice, in order."""
        return self.transform(inp), self.transform(inp)


def relabel_dataset(dataset, labels):
    """Rewrite ``dataset.imgs`` so only files listed in ``labels`` keep a label.

    Args:
        dataset: object exposing ``imgs`` (a list of ``(path, label)`` pairs)
            and ``class_to_idx`` (class name -> index), e.g. an ImageFolder.
            ``dataset.imgs`` is updated in place.
        labels: mapping from file basename to class name. It is NOT modified;
            the function works on a private copy.

    Returns:
        tuple: ``(labeled_idxs, unlabeled_idxs)``, both ascending lists of
        indices into ``dataset.imgs``.

    Raises:
        LookupError: if ``labels`` names files that are not in the dataset.
    """
    # Work on a copy so the caller's dictionary is not emptied as a side effect.
    remaining = dict(labels)
    labeled_idxs = []
    unlabeled_idxs = []

    for idx, (path, _) in enumerate(dataset.imgs):  # imgs holds (path, label) pairs
        filename = os.path.basename(path)
        if filename in remaining:
            # class_to_idx maps class names to indices, e.g. {'cat': 0, 'dog': 1}
            label_idx = dataset.class_to_idx[remaining.pop(filename)]
            dataset.imgs[idx] = path, label_idx
            labeled_idxs.append(idx)
        else:
            dataset.imgs[idx] = path, NO_LABEL
            unlabeled_idxs.append(idx)

    if len(remaining) != 0:
        message = "List of unlabeled contains {} unknown files: {}, ..."
        # Only the first five offending names are shown to keep the message short.
        some_missing = ', '.join(list(remaining.keys())[:5])
        raise LookupError(message.format(len(remaining), some_missing))

    print("num of labeled: " + str(len(labeled_idxs)) + '\n')
    print("num of unlabeled: " + str(len(unlabeled_idxs)) + '\n')
    return labeled_idxs, unlabeled_idxs


class TwoStreamBatchSampler(Sampler):
    """Batch sampler that draws every batch from two index pools.

    One epoch is a single shuffled pass over the primary indices; the
    secondary indices are reshuffled and reused as often as needed. Each
    batch is the primary part followed by the secondary part.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        # The primary stream fills whatever the secondary stream leaves free.
        self.primary_batch_size = batch_size - secondary_batch_size
        self.secondary_batch_size = secondary_batch_size

        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        primary_batches = grouper(iterate_once(self.primary_indices),
                                  self.primary_batch_size)
        secondary_batches = grouper(iterate_eternally(self.secondary_indices),
                                    self.secondary_batch_size)
        # zip stops when the (finite) primary stream runs out of full batches.
        return (
            first_part + second_part
            for first_part, second_part in zip(primary_batches, secondary_batches)
        )

    def __len__(self):
        # Number of complete primary batches per epoch.
        return len(self.primary_indices) // self.primary_batch_size


def iterate_once(iterable):
    """Return a single random permutation of ``iterable`` as a NumPy array."""
    shuffled = np.random.permutation(iterable)
    return shuffled


def iterate_eternally(indices):
    """Return an endless iterator over ``indices``, reshuffled on every pass.

    A new random permutation is drawn lazily each time the previous one has
    been exhausted.
    """
    endless_permutations = (
        np.random.permutation(indices) for _ in itertools.repeat(None)
    )
    return itertools.chain.from_iterable(endless_permutations)


def grouper(iterable, n):
    """Collect data into fixed-length tuples, dropping an incomplete tail.

    Example: ``grouper('ABCDEFG', 3)`` yields ``('A','B','C')``, ``('D','E','F')``.
    """
    shared_iterator = iter(iterable)
    # Passing the *same* iterator n times makes zip pull n consecutive items
    # for every output tuple.
    return zip(*itertools.repeat(shared_iterator, n))

In [None]:
import torchvision.transforms as transforms

def cifar10():
    """Return the CIFAR-10 dataset configuration used by this notebook.

    Returns:
        dict: ``train_transformation`` (two augmented views per image),
        ``eval_transformation``, ``datadir`` and ``num_classes``.
    """
    # Maps each channel from [0, 1] to [-1, 1].
    normalize = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    augmentation = transforms.Compose([
        data.RandomTranslateWithReflect(4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**normalize)
    ])
    # TransformTwice yields two independently augmented views of every image;
    # the training loop unpacks them as (input, ema_input).
    train_transformation = TransformTwice(augmentation)

    eval_transformation = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(**normalize)
    ])

    config = {
        'train_transformation': train_transformation,
        'eval_transformation': eval_transformation,
        'datadir': 'data-local/images/cifar/cifar10/by-image',
        'num_classes': 10
    }
    return config

In [None]:
# Display the configured dataset name.
args.dataset

In [None]:
# Look up the dataset factory by name in the mean_teacher `datasets` module
# (not the notebook-local cifar10 defined earlier) and call it.
dataset_config = datasets.__dict__[args.dataset]()
dataset_config

In [None]:
# Remove 'num_classes' from the config so the remaining keys can later be
# passed to create_data_loaders as keyword arguments.
# NOTE: not idempotent — re-running this cell raises KeyError.
num_classes = dataset_config.pop('num_classes')


In [None]:
# Display the number of classes.
num_classes

In [None]:
# Display the config after 'num_classes' has been popped.
dataset_config

# make dataloader

In [None]:
# Manually evaluate the condition that assert_exactly_one checks: exactly one
# of the two options may be truthy.
lst = [args.exclude_unlabeled, args.labeled_batch_size]
print(lst)
print(sum(int(bool(el)) for el in lst) == 1, ", ".join(str(el) for el in lst))
# NOTE(review): the notebook-local assert_exactly_one is defined in a later
# cell; this call relies on the name already being available (presumably via
# `from mean_teacher.utils import *`) — confirm.
assert_exactly_one([args.exclude_unlabeled, args.labeled_batch_size])

In [None]:
def assert_exactly_one(lst):
    """Assert that exactly one element of ``lst`` is truthy.

    On failure the assertion message lists all elements, comma-separated.
    """
    truthy_count = sum(1 for el in lst if el)
    assert truthy_count == 1, ", ".join(map(str, lst))
    

In [None]:
# Display the dataset configuration again before unpacking it.
dataset_config

In [None]:
# Unpack the pieces of the dataset config needed to build the training set
# and derive the train / eval image directories.
train_transformation = dataset_config["train_transformation"]
datadir = dataset_config["datadir"]
traindir = os.path.join(datadir, args.train_subdir)
evaldir = os.path.join(datadir, args.eval_subdir)

print(train_transformation)
print(traindir)
print(evaldir)

In [None]:
# Build the training ImageFolder and display its summary.
dataset = torchvision.datasets.ImageFolder(traindir, train_transformation)
dataset

In [None]:
# Display the `extensions` attribute of the dataset.
dataset.extensions

In [None]:
# Display the root directory the dataset was built from.
dataset.root

In [None]:
# Rebuild the dataset so relabelling starts from the labels read from disk.
dataset = torchvision.datasets.ImageFolder(traindir, train_transformation)

# Each line of the labels file is "<filename> <class name>".
with open(args.labels) as f:
    labels = dict(line.split(' ') for line in f.read().splitlines())

# Uses the notebook-local relabel_dataset (create_data_loaders below calls
# data.relabel_dataset instead).
labeled_idxs, unlabeled_idxs = relabel_dataset(dataset, labels)
len(labeled_idxs), type(labeled_idxs)

In [None]:
def create_data_loaders(train_transformation,
                        eval_transformation,
                        datadir,
                        args):
    """Build the training and evaluation data loaders.

    Args:
        train_transformation: transform applied to every training image.
        eval_transformation: transform applied to every evaluation image.
        datadir: dataset root containing the train / eval sub-directories.
        args: parsed options; reads ``train_subdir``, ``eval_subdir``,
            ``labels``, ``exclude_unlabeled``, ``labeled_batch_size`` and
            ``batch_size``.

    Returns:
        tuple: ``(train_loader, eval_loader)``.
    """
    train_root = os.path.join(datadir, args.train_subdir)
    eval_root = os.path.join(datadir, args.eval_subdir)

    assert_exactly_one([args.exclude_unlabeled, args.labeled_batch_size])

    train_set = torchvision.datasets.ImageFolder(train_root, train_transformation)

    # NOTE(review): when args.labels is empty, labeled_idxs / unlabeled_idxs
    # are never bound and the sampler construction below fails.
    if args.labels:
        with open(args.labels) as label_file:
            labels = dict(line.split(' ') for line in label_file.read().splitlines())
        labeled_idxs, unlabeled_idxs = data.relabel_dataset(train_set, labels)

    if args.exclude_unlabeled:
        # Supervised-only training: sample labeled images exclusively.
        labeled_sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(labeled_sampler, args.batch_size, drop_last=True)
    elif args.labeled_batch_size:
        # Every batch holds args.labeled_batch_size labeled samples; the rest
        # of the args.batch_size slots are filled with unlabeled samples.
        batch_sampler = data.TwoStreamBatchSampler(
            unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)
    else:
        assert False, "labeled batch size {}".format(args.labeled_batch_size)

    # Training loader driven entirely by the batch sampler chosen above.
    train_loader = DataLoader(train_set, batch_sampler=batch_sampler)

    eval_set = torchvision.datasets.ImageFolder(eval_root, eval_transformation)
    eval_loader = DataLoader(eval_set,
                             batch_size=args.batch_size,
                             shuffle=False,
                             drop_last=False)

    return train_loader, eval_loader

In [None]:
# dataset_config's remaining keys are passed as keyword arguments, so they
# must match create_data_loaders' parameter names.
train_loader, eval_loader = create_data_loaders(**dataset_config, args=args)

# utils

In [None]:
def initialize_weights(net):
    """Initialize every conv / transposed-conv / linear layer of ``net`` in place.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation 0.02; biases are set to zero. Layers built without a bias
    (``bias=False``) are supported. Other module types are left untouched.

    Args:
        net: an ``nn.Module`` whose sub-modules are visited via ``net.modules()``.
    """
    target_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
    for m in net.modules():
        if not isinstance(m, target_layers):
            continue
        nn.init.normal_(m.weight, 0, 0.02)
        # bias is None when the layer was created with bias=False.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

def save_images(images, size, image_path):
    """Tile ``images`` into a grid and write it to ``image_path``.

    Thin wrapper that forwards to ``imsave`` and returns its result.
    """
    result = imsave(images, size, image_path)
    return result

def imsave(images, size, path):
    """Merge ``images`` into one grid image and write it to ``path``.

    Singleton dimensions of the merged grid are squeezed away before writing.
    Returns whatever ``imageio.imwrite`` returns.
    """
    # NOTE(review): `imageio` is not imported in the visible import cell —
    # confirm it is brought into scope elsewhere.
    grid = merge(images, size)
    return imageio.imwrite(path, np.squeeze(grid))

def merge(images, size):
    """Tile a batch of images into a single grid image.

    Args:
        images: array indexed as ``(N, H, W, C)`` with ``C`` in ``{1, 3, 4}``.
        size: ``(rows, cols)`` of the grid; image ``k`` is placed at row
            ``k // cols`` and column ``k % cols``.

    Returns:
        numpy.ndarray: ``(H*rows, W*cols, C)`` for colour input, or
        ``(H*rows, W*cols)`` for single-channel input.

    Raises:
        ValueError: if the channel dimension is not 1, 3 or 4.
    """
    h, w = images.shape[1], images.shape[2]
    channels = images.shape[3]
    rows, cols = size[0], size[1]

    if channels in (3, 4):
        canvas = np.zeros((h * rows, w * cols, channels))
        for idx, image in enumerate(images):
            row, col = divmod(idx, cols)
            canvas[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = image
        return canvas

    if channels == 1:
        canvas = np.zeros((h * rows, w * cols))
        for idx, image in enumerate(images):
            row, col = divmod(idx, cols)
            # Drop the singleton channel axis when copying into the 2-D canvas.
            canvas[row * h:(row + 1) * h, col * w:(col + 1) * w] = image[:, :, 0]
        return canvas

    raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4')


def parameters_string(module):
    """Return a formatted table of the named parameters of ``module``.

    The table has one row per parameter (name, shape, element count) followed
    by a total row, and is wrapped in a leading and a trailing empty line.
    """
    row_format = "{name:<40} {shape:>20} ={total_size:>12,d}"
    named_params = list(module.named_parameters())

    header = [
        "",
        "List of model parameters:",
        "=========================",
    ]
    rows = [
        row_format.format(
            name=param_name,
            shape=" * ".join(map(str, tensor.size())),
            total_size=tensor.numel(),
        )
        for param_name, tensor in named_params
    ]
    total_row = row_format.format(
        name="all parameters",
        shape="sum of above",
        total_size=sum(int(tensor.numel()) for _, tensor in named_params),
    )
    return "\n".join(header + rows + ["=" * 75, total_row, ""])


def assert_exactly_one(lst):
    """Assert that exactly one element of ``lst`` is truthy.

    The assertion message is the comma-separated string form of all elements.
    NOTE(review): an identical function is also defined in an earlier cell;
    this later definition shadows it.
    """
    flags = [1 if el else 0 for el in lst]
    assert sum(flags) == 1, ", ".join(map(str, lst))


class AverageMeterSet:
    """A dictionary of named ``AverageMeter`` objects, created on first use."""

    def __init__(self):
        self.meters = {}

    def __getitem__(self, key):
        return self.meters[key]

    def update(self, name, value, n=1):
        """Record ``value`` (weighted by ``n``) in the meter called ``name``."""
        if name not in self.meters:
            self.meters[name] = AverageMeter()
        self.meters[name].update(value, n)

    def reset(self):
        """Reset every meter while keeping the registered names."""
        for meter in self.meters.values():
            meter.reset()

    def _collect(self, attribute, postfix):
        # Shared implementation of the four reporting methods below.
        return {name + postfix: getattr(meter, attribute)
                for name, meter in self.meters.items()}

    def values(self, postfix=''):
        """Most recent value of every meter, keyed by ``name + postfix``."""
        return self._collect('val', postfix)

    def averages(self, postfix='/avg'):
        """Running average of every meter, keyed by ``name + postfix``."""
        return self._collect('avg', postfix)

    def sums(self, postfix='/sum'):
        """Weighted sum of every meter, keyed by ``name + postfix``."""
        return self._collect('sum', postfix)

    def counts(self, postfix='/count'):
        """Accumulated weight of every meter, keyed by ``name + postfix``."""
        return self._collect('count', postfix)


class AverageMeter:
    """Tracks the latest value and the weighted running average of a quantity."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times in the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __format__(self, format):
        # Renders as "<latest> (<average>)", applying the same format spec to both.
        return "{0:{spec}} ({1:{spec}})".format(self.val, self.avg, spec=format)


def export(fn):
    """Decorator that adds ``fn`` to the ``__all__`` list of its defining module.

    Args:
        fn: the function being exported; its ``__module__`` must name a module
            present in ``sys.modules``.

    Returns:
        ``fn`` itself, unchanged, so the decorator is transparent to callers.
    """
    # Imported locally: the notebook's import cell does not import `sys`
    # explicitly, so relying on a global would depend on a star import.
    import sys

    mod = sys.modules[fn.__module__]
    if hasattr(mod, '__all__'):  # module already declares an export list
        mod.__all__.append(fn.__name__)
    else:
        mod.__all__ = [fn.__name__]
    return fn


def parameter_count(module):
    """Return the total number of scalar elements across all parameters of ``module``."""
    total = 0
    for param in module.parameters():
        total += int(param.numel())
    return total

In [None]:
# Display the model factory registered under the configured architecture name.
architectures.__dict__[args.arch]

In [None]:
# Step-by-step version of create_model (defined in the next code cell) for
# the non-EMA case.
ema = False
LOG.info("=> creating {pretrained}{ema}model '{arch}'".format(
pretrained='pre-trained ' if args.pretrained else '',
ema='EMA ' if ema else '',
arch=args.arch))

# Look up the factory by name and build the classifier.
model_factory = architectures.__dict__[args.arch]
model_params = dict(pretrained=args.pretrained, num_classes=num_classes)
model = model_factory(**model_params)

# model

In [None]:
def create_model(ema=False):
    """Instantiate the classifier named by ``args.arch``.

    Args:
        ema: when True, every parameter is detached from the autograd graph
            (the exponential-moving-average copy is not trained by
            back-propagation).

    Returns:
        The model built by the architecture factory.
    """
    pretrained_tag = 'pre-trained ' if args.pretrained else ''
    ema_tag = 'EMA ' if ema else ''
    LOG.info("=> creating {pretrained}{ema}model '{arch}'".format(
        pretrained=pretrained_tag,
        ema=ema_tag,
        arch=args.arch))

    factory = architectures.__dict__[args.arch]
    model = factory(pretrained=args.pretrained, num_classes=num_classes)

    if ema:
        for param in model.parameters():
            param.detach_()

    return model

In [None]:
# Student model (trained by the optimizer) and its EMA counterpart.
model = create_model()
ema_model = create_model(ema=True)

# Move both models to the GPU; this cell requires CUDA.
model.cuda()
ema_model.cuda()

In [None]:
# GAN pair for 3-channel 32x32 images; the generator consumes args.z_dim-sized
# latent vectors and the discriminator emits a single score per image.
G = generator(input_dim=args.z_dim, output_dim=3, input_size=32)
D = discriminator(input_dim=3, output_dim=1, input_size=32)
G.cuda()
D.cuda()



In [None]:
# Optimizers: SGD with momentum for the classifier, Adam for the generator and
# the discriminator (each with its own learning rate, shared betas).
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay,
                            nesterov=args.nesterov)
G_optimizer = torch.optim.Adam(G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
D_optimizer = torch.optim.Adam(D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

# loss

# Binary cross-entropy on raw logits; used by d_loss and g_loss below.
BCEloss = nn.BCEWithLogitsLoss().cuda()

def nll_loss_neg(y_pred, y_true):
    """Mean of ``-log(1 - p + 1e-6)``, where ``p`` is the row-wise sum of
    ``y_true * y_pred``.

    Args:
        y_pred: 2-D tensor of predictions.
        y_true: tensor of the same shape weighting each prediction.

    Returns:
        Scalar tensor.
    """
    selected = (y_true * y_pred).sum(dim=1)
    # The 1e-6 keeps the logarithm finite when `selected` reaches 1.
    penalties = -torch.log((1 - selected) + 1e-6)
    return penalties.mean()

def inverted_cross_entropy(y_pred, y_true):
    """Negated mean over all elements of ``y_true * log(1 - y_pred + 1e-6) + 1e-6``.

    Returns:
        Scalar tensor.
    """
    weighted_log = y_true * torch.log(1 - y_pred + 1e-6)
    return -(weighted_log + 1e-6).mean()

# bce_loss = nn.BCEWithLogitsLoss()

def d_loss(real, fake, y, m=0.15, s=10.0):
    """Discriminator loss: BCE-with-logits on the scaled, margin-shifted score gap.

    Args:
        real: discriminator scores for real samples.
        fake: discriminator scores for generated samples.
        y: BCE targets.
        m: margin subtracted from the real scores.
        s: scale applied to the score difference.
    """
    logits = s * ((real - m) - fake) + 1e-6
    return BCEloss(logits, y)

def g_loss(real, fake, y, m=0.15, s=10.0):
    """Generator loss: BCE-with-logits on the scaled, margin-shifted score gap.

    Args:
        real: discriminator scores for real samples.
        fake: discriminator scores for generated samples.
        y: BCE targets.
        m: margin added to the fake scores.
        s: scale applied to the score difference.
    """
    logits = s * ((fake + m) - real) + 1e-6
    return BCEloss(logits, y)

In [None]:
# Display the resume directory (empty string means "start from scratch").
args.resume

In [None]:
# Restore training state from the directory named by args.resume, if any.
if args.resume:
        LOG.info("=> loading checkpoint '{}'".format(args.resume))

        # Expected files: classifier checkpoint plus generator / discriminator
        # state dicts. Note that C_file holds the *discriminator* path (D.pkl).
        best_file = os.path.join(args.resume, 'best.ckpt')
        G_file = os.path.join(args.resume, 'G.pkl')
        C_file = os.path.join(args.resume, 'D.pkl')

        assert os.path.isfile(best_file), "=> no checkpoint found at '{}'".format(best_file)
        assert os.path.isfile(G_file), "=> no checkpoint found at '{}'".format(G_file)
        assert os.path.isfile(C_file), "=> no checkpoint found at '{}'".format(C_file)

        # NOTE(review): torch.load is called without map_location, so the
        # tensors are restored to the device they were saved from.
        checkpoint = torch.load(best_file)
        # Restore the counters, both classifier copies and the SGD optimizer.
        args.start_epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        ema_model.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

        # The GAN optimizers are not restored; only the network weights are.
        G.load_state_dict(torch.load(G_file))
        D.load_state_dict(torch.load(C_file))

        print('----------------best_precl----------------', best_prec1)

        LOG.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

In [None]:
# Enable the cuDNN benchmark flag.
cudnn.benchmark =  True

In [None]:
# Display whether evaluation-only mode was requested.
args.evaluate

In [None]:
def validate(eval_loader, model, log, global_step, epoch):
    """Evaluate ``model`` on ``eval_loader`` and return the average top-1 precision.

    Args:
        eval_loader: iterable yielding ``(input, target)`` batches.
        model: network returning two outputs per batch; only the first is scored.
        log: unused; kept for interface compatibility.
        global_step: unused; kept for interface compatibility.
        epoch: unused; kept for interface compatibility.

    Returns:
        The running average of the top-1 precision over all labeled samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # reduction='sum' replaces the deprecated size_average=False.
    class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL).to(device)

    meters = AverageMeterSet()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # torch.no_grad() replaces Variable(volatile=True), which no longer
    # disables gradient tracking.
    with torch.no_grad():
        for i, (input, target) in enumerate(eval_loader):
            meters.update('data_time', time.time() - end)

            input_var = input.to(device)
            target_var = target.to(device)

            minibatch_size = len(target_var)
            labeled_minibatch_size = target_var.ne(NO_LABEL).sum().item()
            assert labeled_minibatch_size > 0
            meters.update('labeled_minibatch_size', labeled_minibatch_size)

            # compute output (the second head is not used during validation)
            output1, output2 = model(input_var)
            class_loss = class_criterion(output1, target_var) / minibatch_size

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output1, target_var, topk=(1, 5))
            meters.update('class_loss', class_loss.item(), labeled_minibatch_size)
            meters.update('top1', prec1[0], labeled_minibatch_size)
            meters.update('error1', 100.0 - prec1[0], labeled_minibatch_size)
            meters.update('top5', prec5[0], labeled_minibatch_size)
            meters.update('error5', 100.0 - prec5[0], labeled_minibatch_size)

            # measure elapsed time
            meters.update('batch_time', time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                LOG.info(
                    'Test: [{0}/{1}]\t'
                    'Time {meters[batch_time]:.3f}\t'
                    'Data {meters[data_time]:.3f}\t'
                    'Class {meters[class_loss]:.4f}\t'
                    'Prec@1 {meters[top1]:.3f}\t'
                    'Prec@5 {meters[top5]:.3f}'.format(
                        i, len(eval_loader), meters=meters))

    LOG.info(' * Prec@1 {top1.avg:.3f}\tPrec@5 {top5.avg:.3f}'
          .format(top1=meters['top1'], top5=meters['top5']))

    return meters['top1'].avg

In [None]:
def save_checkpoint(state, is_best, dirpath, epoch):
    """Serialize ``state`` to ``<dirpath>/best.ckpt``.

    NOTE(review): ``is_best`` and ``epoch`` are accepted but ignored, so every
    call overwrites ``best.ckpt`` regardless of whether the state is actually
    the best one — confirm this is intended.
    """
    torch.save(state, os.path.join(dirpath, 'best.ckpt'))


def adjust_learning_rate(optimizer, epoch, step_in_epoch, total_steps_in_epoch):
    """Set the learning rate of every parameter group for the current step.

    The schedule is a linear warm-up from ``args.initial_lr`` to ``args.lr``
    (https://arxiv.org/abs/1706.02677), optionally multiplied by a single
    cosine ramp-down cycle (https://arxiv.org/abs/1608.03983).

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: index of the current epoch.
        step_in_epoch: index of the current step within the epoch.
        total_steps_in_epoch: number of steps per epoch.
    """
    # Fractional epoch, so the schedule changes smoothly within an epoch.
    progress = epoch + step_in_epoch / total_steps_in_epoch

    warmup_factor = ramps.linear_rampup(progress, args.lr_rampup)
    new_lr = warmup_factor * (args.lr - args.initial_lr) + args.initial_lr

    if args.lr_rampdown_epochs:
        assert args.lr_rampdown_epochs >= args.epochs
        new_lr *= ramps.cosine_rampdown(progress, args.lr_rampdown_epochs)

    for group in optimizer.param_groups:
        group['lr'] = new_lr


def get_current_consistency_weight(epoch):
    """Return the consistency-loss weight for ``epoch``.

    ``args.consistency`` scaled by a sigmoid ramp-up over
    ``args.consistency_rampup`` epochs (https://arxiv.org/abs/1610.02242).
    """
    rampup_factor = ramps.sigmoid_rampup(epoch, args.consistency_rampup)
    return args.consistency * rampup_factor

def generated_weight(epoch, rampup_start=10, rampup_end=60, max_weight=0.3):
    """Return the loss weight applied to generated samples at ``epoch``.

    The weight is 0 up to and including ``rampup_start``, grows linearly
    between ``rampup_start`` and ``rampup_end``, and stays at ``max_weight``
    afterwards.

    Args:
        epoch: current epoch (integer or fractional).
        rampup_start: last epoch at which the weight is still 0.
        rampup_end: epoch from which the weight equals ``max_weight``.
        max_weight: final value of the weight.

    Returns:
        float: weight in ``[0, max_weight]``.
    """
    if epoch <= rampup_start:
        return 0.0
    # Checked before the division so rampup_start == rampup_end cannot divide by zero.
    if epoch > rampup_end:
        return max_weight
    return (epoch - rampup_start) / (rampup_end - rampup_start) * max_weight

def accuracy(output, target, topk=(1,)):
    """Compute the precision@k (in percent) for each ``k`` in ``topk``.

    Precision is measured relative to the number of labeled targets, i.e.
    those different from ``NO_LABEL``.

    Args:
        output: 2-D tensor of scores, one row per sample.
        target: 1-D tensor of class indices (``NO_LABEL`` for unlabeled samples).
        topk: tuple of ``k`` values to evaluate.

    Returns:
        list: one single-element tensor per entry of ``topk``.
    """
    maxk = max(topk)
    # Convert to a Python number first: max(tensor, 1e-8) could return the bare
    # float, on which the former `.float()` call failed when no sample was labeled.
    labeled_minibatch_size = max(int(target.ne(NO_LABEL).sum()), 1e-8)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / labeled_minibatch_size))
    return res
    
def update_ema_variables(model, ema_model, alpha, global_step):
    """Update ``ema_model``'s parameters as a moving average of ``model``'s.

    Each EMA parameter becomes ``alpha * ema + (1 - alpha) * param``. Only
    parameters are updated (the loop runs over ``parameters()``).

    Args:
        model: source model.
        ema_model: model whose parameters are updated in place.
        alpha: requested decay rate.
        global_step: number of steps taken so far; caps the effective decay.
    """
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            # Explicit product instead of the deprecated add_(scalar, tensor) overload.
            ema_param.mul_(alpha).add_(param.detach() * (1 - alpha))

def visualize_results(G, epoch):
    """Sample 64 images from ``G`` and save them as an 8x8 grid PNG.

    The file is written to ``generated_images/<args.dataset>/epochNNN.png``.
    ``G`` is switched to eval mode and left there. Requires CUDA.

    Args:
        G: generator taking ``(N, args.z_dim)`` latent vectors.
        epoch: epoch number used in the file name.
    """
    G.eval()

    output_dir = 'generated_images/' + args.dataset
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    tot_num_samples = 64
    grid_side = int(np.floor(np.sqrt(tot_num_samples)))

    # Latent vectors are drawn uniformly from [0, 1).
    latent = torch.rand((tot_num_samples, args.z_dim)).cuda()

    # Rescale with x * 0.5 + 0.5 (maps the Tanh range [-1, 1] to [0, 1]).
    generated = G(latent).mul(0.5).add(0.5)

    # NCHW tensor -> NHWC numpy array, the layout expected by save_images.
    generated_nhwc = generated.cpu().data.numpy().transpose(0, 2, 3, 1)

    shown = grid_side * grid_side
    save_images(generated_nhwc[:shown, :, :, :],
                [grid_side, grid_side],
                output_dir + '/' + 'epoch%03d' % epoch + '.png')
                          
def train(train_loader, model, ema_model, optimizer, G, D, G_optimizer, D_optimizer, epoch, log, BCEloss):
    """Run one training epoch of the classifier, discriminator and generator.

    For every mini-batch three updates are performed, in this order:

    1. Classifier step: ``model`` is trained on the classification loss, the
       consistency loss against ``ema_model``, the optional residual-logit
       loss, and ``generated_weight(epoch)`` times an inverted cross-entropy
       on generated samples. ``ema_model`` is then updated from ``model`` via
       ``update_ema_variables``.
    2. Discriminator step: ``D`` is trained with ``d_loss`` on the real batch
       and on freshly generated samples.
    3. Generator step: ``G`` is trained with ``g_loss`` and, after epoch 10,
       additionally with an NLL loss of the classifier on its own predictions
       for the generated samples.

    After the last batch a grid of generated images is saved with
    ``visualize_results``.

    Relies on module-level state: ``args``, ``LOG``, ``global_step`` (which is
    incremented once per batch) and the helpers ``adjust_learning_rate``,
    ``inverted_cross_entropy``, ``d_loss`` and ``g_loss``.

    Args:
        train_loader: iterable yielding ``((input, ema_input), target)``.
        model: classifier being trained.
        ema_model: EMA copy of ``model``; only evaluated and EMA-updated here.
        optimizer: optimizer for ``model``.
        G: generator module.
        D: discriminator module.
        G_optimizer: optimizer for ``G``.
        D_optimizer: optimizer for ``D``.
        epoch: current epoch index.
        log: not used by the current body (logging goes through ``LOG``).
        BCEloss: not used by the current body.
    """
    global global_step

    # reduction='sum' is the current spelling of the deprecated size_average=False.
    class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL).cuda()
    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
    else:
        assert False, args.consistency_type
    residual_logit_criterion = losses.symmetric_mse_loss

    meters = AverageMeterSet()

    # switch to train mode
    model.train()
    ema_model.train()
    D.train()
    G.train()

    end = time.time()

    for i, ((input, ema_input), target) in enumerate(train_loader):
        # measure data loading time
        meters.update('data_time', time.time() - end)

        adjust_learning_rate(optimizer, epoch, i, len(train_loader))
        meters.update('lr', optimizer.param_groups[0]['lr'])

        input_var = input.cuda()
        ema_input_var = ema_input.cuda()
        target_var = target.cuda()

        minibatch_size = len(target_var)
        labeled_minibatch_size = target_var.data.ne(NO_LABEL).sum()
        assert labeled_minibatch_size > 0
        meters.update('labeled_minibatch_size', labeled_minibatch_size)

        # The EMA output is only used as a detached target, so skip building
        # an autograd graph for it. (The old `volatile=True` flag no longer
        # has any effect; torch.no_grad() is its replacement.)
        with torch.no_grad():
            ema_model_out = ema_model(ema_input_var)
        model_out = model(input_var)

        if isinstance(model_out, torch.Tensor):
            # Single-output model: there is no second logit head.
            assert args.logit_distance_cost < 0
            logit1 = model_out
            ema_logit = ema_model_out
        else:
            # Two-output model: (class logits, consistency logits).
            assert len(model_out) == 2
            assert len(ema_model_out) == 2
            logit1, logit2 = model_out
            ema_logit, _ = ema_model_out

        ema_logit = ema_logit.detach()

        if args.logit_distance_cost >= 0:
            class_logit, cons_logit = logit1, logit2
            res_loss = args.logit_distance_cost * residual_logit_criterion(class_logit, cons_logit) / minibatch_size
            meters.update('res_loss', res_loss.item())
        else:
            class_logit, cons_logit = logit1, logit1
            res_loss = 0

        class_loss = class_criterion(class_logit, target_var) / minibatch_size
        meters.update('class_loss', class_loss.item())

        # Logged only; not part of the optimised loss.
        ema_class_loss = class_criterion(ema_logit, target_var) / minibatch_size
        meters.update('ema_class_loss', ema_class_loss.item())

        if args.consistency:
            consistency_weight = get_current_consistency_weight(epoch)
            meters.update('cons_weight', consistency_weight)
            consistency_loss = consistency_weight * consistency_criterion(cons_logit, ema_logit) / minibatch_size
            meters.update('cons_loss', consistency_loss.item())
        else:
            consistency_loss = 0
            meters.update('cons_loss', 0)

        # Classifier loss on generated samples.
        # NOTE(review): G_ is not detached, so loss.backward() below also
        # accumulates gradients in G. They look unused because
        # G_optimizer.zero_grad() runs before the generator step - confirm
        # before adding a .detach() here.
        z_ = torch.rand((args.generated_batch_size, args.z_dim))
        z_ = z_.cuda()
        G_ = G(z_)

        C_fake_pred, _ = model(G_)

        C_fake_pred = F.softmax(C_fake_pred, dim=1)
        with torch.no_grad():
            # One-hot encoding of the classifier's own arg-max prediction; the
            # width follows the number of logits instead of a hard-coded 10.
            C_fake_wei = torch.max(C_fake_pred, 1)[1]
            C_fake_wei = C_fake_wei.view(-1, 1)
            C_fake_wei = torch.zeros(args.generated_batch_size, C_fake_pred.size(1)).cuda().scatter_(1, C_fake_wei, 1)

        C_fake_loss = inverted_cross_entropy(C_fake_pred, C_fake_wei)

        loss = class_loss + consistency_loss + res_loss + generated_weight(epoch) * C_fake_loss

        assert not (np.isnan(loss.item()) or loss.item() > 1e5), 'Loss explosion: {}'.format(loss.item())
        meters.update('loss', loss.item())

        prec1, prec5 = accuracy(class_logit.data, target_var.data, topk=(1, 5))
        meters.update('top1', prec1[0], labeled_minibatch_size)
        meters.update('error1', 100. - prec1[0], labeled_minibatch_size)
        meters.update('top5', prec5[0], labeled_minibatch_size)
        meters.update('error5', 100. - prec5[0], labeled_minibatch_size)

        ema_prec1, ema_prec5 = accuracy(ema_logit.data, target_var.data, topk=(1, 5))
        meters.update('ema_top1', ema_prec1[0], labeled_minibatch_size)
        meters.update('ema_error1', 100. - ema_prec1[0], labeled_minibatch_size)
        meters.update('ema_top5', ema_prec5[0], labeled_minibatch_size)
        meters.update('ema_error5', 100. - ema_prec5[0], labeled_minibatch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        update_ema_variables(model, ema_model, args.ema_decay, global_step)

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            LOG.info(
                'Epoch: [{0}][{1}/{2}]\t'
                'Time {meters[batch_time]:.3f}\t'
                'Data {meters[data_time]:.3f}\t'
                'Class {meters[class_loss]:.4f}\t'
                'Cons {meters[cons_loss]:.4f}\t'
                'Prec@1 {meters[top1]:.3f}\t'
                'Prec@5 {meters[top5]:.3f}'.format(
                    epoch, i, len(train_loader), meters=meters))

        # update D network
        D_optimizer.zero_grad()

        # Author's note (translated): the two loss functions here were changed
        # from the BCE formulation to d_loss / g_loss.
        D_real = D(input_var)

        G_ = G(z_)
        D_fake = D(G_)

        D_loss = d_loss(D_real, D_fake, torch.ones_like(D_real))

        D_loss.backward()
        D_optimizer.step()

        # update G network
        G_optimizer.zero_grad()

        G_ = G(z_)

        D_real = D(input_var)
        D_fake = D(G_)
        G_loss_D = g_loss(D_real, D_fake, torch.ones_like(D_fake))

        # Classifier term: NLL against the classifier's own arg-max prediction
        # on the generated samples.
        C_fake_pred, _ = model(G_)
        C_fake_pred = F.log_softmax(C_fake_pred, dim=1)
        with torch.no_grad():
            C_fake_wei = torch.max(C_fake_pred, 1)[1]
        G_loss_C = F.nll_loss(C_fake_pred, C_fake_wei)

        # NOTE(review): G_loss is only used for logging. The backward passes
        # below apply G_loss_C with weight 1 (after epoch 10), not with
        # generated_weight(epoch) - confirm which objective is intended.
        G_loss = G_loss_D + generated_weight(epoch) * G_loss_C
        if epoch <= 10:
            G_loss_D.backward()
        else:
            G_loss_D.backward(retain_graph=True)
            G_loss_C.backward()

        G_optimizer.step()

        if i % args.print_freq == 0:
            print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                  (
                      epoch, i, len(train_loader),
                      D_loss.item(),
                      G_loss.item()))

    with torch.no_grad():
        visualize_results(G, (epoch + 1))

In [None]:
# Main loop: train each epoch, evaluate periodically, and checkpoint the
# models whenever a checkpoint epoch coincides with a new best EMA precision.
for epoch in range(args.start_epoch, args.epochs):
    start_time = time.time()
    # train for one epoch
    train(train_loader, model, ema_model, optimizer, G, D, G_optimizer, D_optimizer, epoch, training_log, BCEloss)
    LOG.info("--- training epoch in %s seconds ---" % (time.time() - start_time))

    # Only an evaluation epoch can turn this flag on.
    is_best = False
    if args.evaluation_epochs and (epoch + 1) % args.evaluation_epochs == 0:
        start_time = time.time()
        LOG.info("Evaluating the primary model:")
        prec1 = validate(eval_loader, model, validation_log, global_step, epoch + 1)
        LOG.info("Evaluating the EMA model:")
        ema_prec1 = validate(eval_loader, ema_model, ema_validation_log, global_step, epoch + 1)
        LOG.info("--- validation in %s seconds ---" % (time.time() - start_time))
        # The best score is tracked on the EMA model's result.
        is_best = ema_prec1 > best_prec1
        best_prec1 = max(ema_prec1, best_prec1)

    if args.checkpoint_epochs and (epoch + 1) % args.checkpoint_epochs == 0 and is_best:
        save_checkpoint({
            'epoch': epoch + 1,
            'global_step': global_step,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'ema_state_dict': ema_model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpoint_path, epoch + 1)
        # Generator and discriminator weights are stored next to the checkpoint.
        torch.save(G.state_dict(), os.path.join(checkpoint_path, 'G.pkl'))
        torch.save(D.state_dict(), os.path.join(checkpoint_path, 'D.pkl'))