In [1]:
import os
import logging
import random
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp

from utils import net_builder, get_logger, count_parameters
from train_utils import TBLog, get_SGD, get_cosine_schedule_with_warmup
from models.fixmatch.fixmatch import FixMatch
from datasets.ssl_dataset import SSL_Dataset
from datasets.data_utils import get_data_loader

from sklearn.metrics import log_loss
from scipy.special import softmax
from scipy.optimize import minimize
from pycalib.metrics import classwise_ECE, conf_ECE

In [2]:
args = {
    'save_dir': './saved_models',
    'resume': 'store_true',
    'save_name': 'fixmatch',
    'load_path': None,
    'overwrite': 'store_true',
    'epoch': 300,
    'num_train_iter': 2**20,
    'num_eval_iter': 10000,
    'num_labels': 4000,
    'batch_size': 64,
    'uratio': 7,
    'eval_batch_size': 1024,
    'hard_label': True,
    'T': 0.5,
    'p_cutoff': 0.95,
    'ema_m': 0.999,
    'ulb_loss_ratio': 1.0,
    'lr': 0.03,
    'momentum': 0.9,
    'weight_decay': 5e-4,
    'net': 'WideResNet',
    'net_from_name': False,
    'depth': 28,
    'widen_factor': 2,
    'leaky_slope': 0.1,
    'dropout': 0.0,
    'data_dir': './data',
    'dataset': 'cifar10',
    'train_sampler': 'RandomSampler',
    'num_classes': 10,
    'amp': 'store_true',
    'gpu': 0,
    'multiprocessing_distributed': 'store_true',
    'rank': 0
}

from argparse import Namespace
args = Namespace(**args)

In [3]:
save_path = os.path.join(args.save_dir, args.save_name)

global best_acc1

seed = 1
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.deterministic = True

In [4]:
args.bn_momentum = 1.0 - args.ema_m
_net_builder = net_builder(args.net, 
                            args.net_from_name,
                            {'depth': args.depth, 
                            'widen_factor': args.widen_factor,
                            'leaky_slope': args.leaky_slope,
                            'bn_momentum': args.bn_momentum,
                            'dropRate': args.dropout})

model = FixMatch(_net_builder,
                    args.num_classes,
                    args.ema_m,
                    args.T,
                    args.p_cutoff,
                    args.ulb_loss_ratio,
                    args.hard_label,
                    num_eval_iter=args.num_eval_iter)


optimizer = get_SGD(model.train_model, 'SGD', args.lr, args.momentum, args.weight_decay)
scheduler = get_cosine_schedule_with_warmup(optimizer,
                                            args.num_train_iter,
                                            num_warmup_steps=args.num_train_iter*0)

model.set_optimizer(optimizer, scheduler)
model.train_model = model.train_model.cuda()
model.eval_model = model.eval_model.cuda()

cudnn.benchmark = True

depth in <models.nets.wrn.build_WideResNet object at 0x7ff63b72e580> is overlapped by kwargs: 28 -> 28
widen_factor in <models.nets.wrn.build_WideResNet object at 0x7ff63b72e580> is overlapped by kwargs: 2 -> 2
leaky_slope in <models.nets.wrn.build_WideResNet object at 0x7ff63b72e580> is overlapped by kwargs: 0.0 -> 0.1
bn_momentum in <models.nets.wrn.build_WideResNet object at 0x7ff63b72e580> is overlapped by kwargs: 0.01 -> 0.0010000000000000009
dropRate in <models.nets.wrn.build_WideResNet object at 0x7ff63b72e580> is overlapped by kwargs: 0.0 -> 0.0


In [5]:
train_dset = SSL_Dataset(name=args.dataset, train=True, 
                            num_classes=args.num_classes, data_dir=args.data_dir)
lb_dset, ulb_dset, calib_dset = train_dset.get_ssl_dset(args.num_labels)

_eval_dset = SSL_Dataset(name=args.dataset, train=False, 
                            num_classes=args.num_classes, data_dir=args.data_dir)
eval_dset = _eval_dset.get_dset()

loader_dict = {}
dset_dict = {'train_lb': lb_dset, 'train_ulb': ulb_dset, 'eval': eval_dset}

loader_dict['train_lb'] = get_data_loader(dset_dict['train_lb'],
                                            args.batch_size,
                                            data_sampler = args.train_sampler,
                                            num_epochs=args.epoch)

loader_dict['train_ulb'] = get_data_loader(dset_dict['train_ulb'],
                                            args.batch_size*args.uratio,
                                            data_sampler = args.train_sampler,
                                            num_epochs=args.epoch)

loader_dict['eval'] = get_data_loader(dset_dict['eval'],
                                        args.eval_batch_size)

model.set_data_loader(loader_dict)

Files already downloaded and verified
Files already downloaded and verified
d
a
d
a
[!] data loader keys: dict_keys(['train_lb', 'train_ulb', 'eval'])


In [22]:
trainer = model.train
for epoch in range(args.epoch):
    print(f'Starting epoch {epoch + 1}/{args.epoch}')
    trainer(args)
    print(f'Finished epoch {epoch + 1}/{args.epoch}')

model.save_model('latest_model.pth', save_path)

Starting epoch 1/1
0 iteration, USE_EMA: True, {'train/sup_loss': tensor(2.2393, device='cuda:0'), 'train/unsup_loss': tensor(0., device='cuda:0'), 'train/total_loss': tensor(2.2393, device='cuda:0'), 'train/mask_ratio': tensor(1., device='cuda:0'), 'lr': 0.029999999999355702, 'train/prefecth_time': 0.7456192016601563, 'train/run_time': 0.2748617858886719, 'eval/loss': tensor(8.2592, device='cuda:0'), 'eval/top-1-acc': tensor(0.1000, device='cuda:0')}, BEST_EVAL_ACC: 0.09999999403953552, at 0 iters
model saved: ./saved_models/fixmatch/model_best.pth
Finished epoch 1/1
model saved: ./saved_models/fixmatch/latest_model.pth


In [8]:
checkpoint_path = os.path.join('./saved_models/fixmatch/model_best.pth')
checkpoint = torch.load(checkpoint_path)
load_model = checkpoint['train_model']

_net_builder = net_builder(args.net, 
                            args.net_from_name,
                            {'depth': args.depth, 
                            'widen_factor': args.widen_factor,
                            'leaky_slope': args.leaky_slope,
                            'dropRate': args.dropout})

net = _net_builder(num_classes=args.num_classes)

net.load_state_dict(load_model)
if torch.cuda.is_available():
    net.cuda()
net.eval()

calib_loader = get_data_loader(calib_dset,
                                args.eval_batch_size, 
                                num_workers=1)

with torch.no_grad():
    logits = np.empty((0, 10))
    encoding = []
    for image, target in calib_loader:
        image = image.type(torch.FloatTensor).cuda()
        logits = np.append(logits, net(image).cpu().detach().numpy(), axis=0)
        for value in target:
            encoding.append(np.eye(10)[value])
    
    def vector_scale_loss(x, *args):
        sm = softmax(np.multiply(logits, x[0:-1]) + x[-1], axis=1)
        return log_loss(encoding, sm)

    min_obj = minimize(vector_scale_loss, [1 for i in range(11)], method='Nelder-Mead')

    uncal_scores = softmax(logits, axis=1)
    vector_scores = softmax(np.multiply(logits, min_obj.x[0:-1]) + min_obj.x[-1], axis=1)

    conf_ece = conf_ECE(uncal_scores, encoding)
    print(conf_ece)

    conf_ece = conf_ECE(vector_scores, encoding)
    print(conf_ece)

depth in <models.nets.wrn.build_WideResNet object at 0x7ff62557bdf0> is overlapped by kwargs: 28 -> 28
widen_factor in <models.nets.wrn.build_WideResNet object at 0x7ff62557bdf0> is overlapped by kwargs: 2 -> 2
leaky_slope in <models.nets.wrn.build_WideResNet object at 0x7ff62557bdf0> is overlapped by kwargs: 0.0 -> 0.1
dropRate in <models.nets.wrn.build_WideResNet object at 0x7ff62557bdf0> is overlapped by kwargs: 0.0 -> 0.0


0.17620000000000002
0.14690000000000003


In [9]:
min_obj

       message: Optimization terminated successfully.
       success: True
        status: 0
           fun: 0.47924357241256127
             x: [ 5.179e-01  4.114e-01  5.672e-01  5.684e-01  5.539e-01
                  9.161e-01  7.049e-01  6.259e-01  4.441e-01  1.062e+00
                  2.171e+00]
           nit: 760
          nfev: 1076
 final_simplex: (array([[ 5.179e-01,  4.114e-01, ...,  1.062e+00,
                         2.171e+00],
                       [ 5.179e-01,  4.113e-01, ...,  1.062e+00,
                         2.171e+00],
                       ...,
                       [ 5.180e-01,  4.113e-01, ...,  1.062e+00,
                         2.171e+00],
                       [ 5.179e-01,  4.113e-01, ...,  1.062e+00,
                         2.171e+00]]), array([ 4.792e-01,  4.792e-01,  4.792e-01,  4.792e-01,
                        4.792e-01,  4.792e-01,  4.792e-01,  4.792e-01,
                        4.792e-01,  4.792e-01,  4.792e-01,  4.792e-01]))