<a href="https://colab.research.google.com/github/matinmoezzi/ups_conformal_classification/blob/main/train-ups-raps.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab bootstrap: clone the project, enter it, and install the versions pinned in requirements.txt.
!git clone https://github.com/matinmoezzi/ups_conformal_classification
%cd ups_conformal_classification
!pip install -r requirements.txt

Cloning into 'ups_conformal_classification'...
remote: Enumerating objects: 42, done.[K
remote: Counting objects: 100% (42/42), done.[K
remote: Compressing objects: 100% (34/34), done.[K
remote: Total 42 (delta 11), reused 29 (delta 6), pack-reused 0[K
Unpacking objects: 100% (42/42), done.
/content/ups_conformal_classification
Collecting numpy==1.16.2
[?25l  Downloading https://files.pythonhosted.org/packages/91/e7/6c780e612d245cca62bc3ba8e263038f7c144a96a54f877f3714a0e8427e/numpy-1.16.2-cp37-cp37m-manylinux1_x86_64.whl (17.3MB)
[K     |████████████████████████████████| 17.3MB 163kB/s 
[?25hCollecting scikit-learn==0.21.1
[?25l  Downloading https://files.pythonhosted.org/packages/ef/52/3254e511ef1fc88d31edf457d90ecfd531931d4202f1b8ee0c949e9478f6/scikit_learn-0.21.1-cp37-cp37m-manylinux1_x86_64.whl (6.7MB)
[K     |████████████████████████████████| 6.7MB 18.4MB/s 
[?25hCollecting scipy==1.2.1
[?25l  Downloading https://files.pythonhosted.org/packages/3e/7e/5cee36eee5b31946872

In [1]:
# Enter the cloned repository (relative path: the working directory must be its parent)
# so that the project-local imports below resolve.
%cd ups_conformal_classification

/content/ups_conformal_classification


In [21]:
import argparse
import logging
import math
import os
import pickle
import random
import shutil
import time
from collections import OrderedDict
from copy import deepcopy
from datetime import datetime
from re import search

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, random_split
from tqdm import tqdm

# Project-local imports. Keep this order: the star imports may shadow earlier names.
from data.cifar import get_cifar10, get_cifar100
from utils import AverageMeter, accuracy
from utils.utils import *
from utils.train_util import train_initial, train_regular
from utils.evaluate import test
from utils.pseudo_labeling_util import pseudo_labeling
from utils.misc import AverageMeter, accuracy
from utils.utils import enable_dropout
from conformal_classification.utils import *
# from conformal_classification.conformal import *


In [14]:
def raps_pseudo_labeling(args, data_loader, model, itr):
    """Select positive and negative pseudo-labels for every sample in ``data_loader``.

    ``model`` must return a ``(logits, prediction_sets)`` pair; only the logits are
    used for the selection. Unless ``args.no_uncertainty`` is set, dropout is enabled
    and 10 stochastic forward passes are averaged, with their standard deviation
    serving as the uncertainty estimate.

    Args:
        args: namespace providing ``no_uncertainty``, ``no_progress``, ``device``,
            ``temp_nl``, ``tau_p``, ``tau_n``, ``kappa_p``, ``kappa_n``,
            ``class_blnc`` and ``num_classes``.
        data_loader: iterable of ``(inputs, targets, indexs, _)`` batches.
        model: network returning ``(logits, prediction_sets)``.
        itr: current pseudo-labeling iteration; class balancing is applied while
            ``itr < args.class_blnc - 1``.

    Returns:
        Tuple ``(avg_loss, avg_top1, positive_accuracy, n_positive,
        negative_accuracy, n_negative_labels, n_negative_samples, pseudo_label_dict)``.
        Both accuracies are percentages and are ``0.0`` when nothing was selected.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()
    pseudo_idx = []
    pseudo_target = []
    pseudo_maxstd = []
    gt_target = []
    idx_list = []
    gt_list = []
    nl_mask = []
    model.eval()
    if not args.no_uncertainty:
        f_pass = 10
        enable_dropout(model)
    else:
        f_pass = 1

    if not args.no_progress:
        data_loader = tqdm(data_loader)

    with torch.no_grad():
        for batch_idx, (inputs, targets, indexs, _) in enumerate(data_loader):
            data_time.update(time.time() - end)
            inputs = inputs.to(args.device)
            targets = targets.to(args.device)
            out_prob = []
            out_prob_nl = []
            for _ in range(f_pass):
                # The second output (prediction sets) is not used by the selection below.
                outputs, _pred_sets = model(inputs)
                out_prob.append(F.softmax(outputs, dim=1)) #for selecting positive pseudo-labels
                out_prob_nl.append(F.softmax(outputs/args.temp_nl, dim=1)) #for selecting negative pseudo-labels
            out_prob = torch.stack(out_prob)
            out_prob_nl = torch.stack(out_prob_nl)
            out_std = torch.std(out_prob, dim=0)
            out_std_nl = torch.std(out_prob_nl, dim=0)
            out_prob = torch.mean(out_prob, dim=0)
            out_prob_nl = torch.mean(out_prob_nl, dim=0)
            max_value, max_idx = torch.max(out_prob, dim=1)
            max_std = out_std.gather(1, max_idx.view(-1,1))
            out_std_nl = out_std_nl.cpu().numpy()
            max_idx_np = max_idx.cpu().numpy()

            #selecting negative pseudo-labels
            interm_nl_mask = ((out_std_nl < args.kappa_n) * (out_prob_nl.cpu().numpy() < args.tau_n)) *1

            #the predicted (argmax) class is never used as a negative label
            interm_nl_mask[np.arange(len(max_idx_np)), max_idx_np] = 0
            nl_mask.extend(interm_nl_mask)

            idx_list.extend(indexs.numpy().tolist())
            gt_list.extend(targets.cpu().numpy().tolist())

            #selecting positive pseudo-labels
            if not args.no_uncertainty:
                selected_idx = (max_value>=args.tau_p) * (max_std.squeeze(1) < args.kappa_p)
            else:
                selected_idx = max_value>=args.tau_p

            pseudo_maxstd.extend(max_std.squeeze(1)[selected_idx].cpu().numpy().tolist())
            pseudo_target.extend(max_idx[selected_idx].cpu().numpy().tolist())
            # indexs lives on the CPU, so the mask must be moved there before indexing
            pseudo_idx.extend(indexs[selected_idx.cpu()].numpy().tolist())
            gt_target.extend(targets[selected_idx].cpu().numpy().tolist())

            loss = F.cross_entropy(outputs, targets.to(dtype=torch.long))
            prec1, prec5 = accuracy(outputs[selected_idx], targets[selected_idx], topk=(1, 5))

            losses.update(loss.item(), inputs.shape[0])
            top1.update(prec1.item(), inputs.shape[0])
            top5.update(prec5.item(), inputs.shape[0])
            batch_time.update(time.time() - end)
            end = time.time()
            if not args.no_progress:
                data_loader.set_description("Pseudo-Labeling Iter: {batch:4}/{iter:4}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. top1: {top1:.2f}. top5: {top5:.2f}. ".format(
                    batch=batch_idx + 1,
                    iter=len(data_loader),
                    data=data_time.avg,
                    bt=batch_time.avg,
                    loss=losses.avg,
                    top1=top1.avg,
                    top5=top5.avg,
                ))
        if not args.no_progress:
            data_loader.close()

    pseudo_target = np.array(pseudo_target)
    gt_target = np.array(gt_target)
    pseudo_maxstd = np.array(pseudo_maxstd)
    pseudo_idx = np.array(pseudo_idx)

    #class balance the selected pseudo-labels
    if itr < args.class_blnc-1:
        class_counts = [int(np.sum(pseudo_target == class_idx)) for class_idx in range(args.num_classes)]
        #the floor of 25 avoids degenerate cases when the minimum count for a certain class is very low
        min_count = max(25, min(class_counts, default=5000000))

        blnc_idx_list = []
        for class_idx in range(args.num_classes):
            class_positions = np.where(pseudo_target == class_idx)[0]
            if len(class_positions) > 0:
                lowest_std_first = np.argsort(pseudo_maxstd[class_positions])
                #select the samples with lowest uncertainty
                blnc_idx_list.extend(class_positions[lowest_std_first[:min_count]])

        # explicit integer dtype so an empty selection is still a valid index array
        blnc_idx_list = np.array(blnc_idx_list, dtype=np.int64)
        pseudo_target = pseudo_target[blnc_idx_list]
        pseudo_idx = pseudo_idx[blnc_idx_list]
        gt_target = gt_target[blnc_idx_list]

    pseudo_hits = (pseudo_target == gt_target)
    if len(pseudo_hits) > 0:
        pseudo_labeling_acc = (pseudo_hits.sum()/len(pseudo_hits))*100
    else:
        pseudo_labeling_acc = 0.0  # nothing selected: report 0 instead of dividing by zero
    print(f'Pseudo-Labeling Accuracy (positive): {pseudo_labeling_acc}, Total Selected: {len(pseudo_idx)}')

    pseudo_nl_mask = []
    pseudo_nl_idx = []
    nl_gt_list = []

    # set lookup instead of scanning the whole pseudo_idx array for every sample
    positive_idx_set = set(pseudo_idx.tolist())
    for sample_idx, sample_mask, sample_gt in zip(idx_list, nl_mask, gt_list):
        if sample_idx not in positive_idx_set and sample_mask.sum() > 0:
            pseudo_nl_mask.append(sample_mask)
            pseudo_nl_idx.append(sample_idx)
            nl_gt_list.append(sample_gt)

    nl_gt_list = np.array(nl_gt_list, dtype=np.int64)
    pseudo_nl_mask = np.array(pseudo_nl_mask)
    # 1 for every class that is NOT the ground truth, i.e. where a negative label is correct
    one_hot_targets = 1 - np.eye(args.num_classes)[nl_gt_list]
    flat_pseudo_nl_mask = pseudo_nl_mask.reshape(-1)
    chosen = flat_pseudo_nl_mask == 1
    flat_one_hot_targets = one_hot_targets.reshape(-1)[chosen]
    flat_pseudo_nl_mask = flat_pseudo_nl_mask[chosen]

    nl_accuracy = (flat_pseudo_nl_mask == flat_one_hot_targets)
    if len(nl_accuracy) > 0:
        nl_accuracy_final = (nl_accuracy.sum()/len(nl_accuracy))*100
    else:
        nl_accuracy_final = 0.0  # no negative labels selected
    print(f'Pseudo-Labeling Accuracy (negative): {nl_accuracy_final}, Total Selected: {len(nl_accuracy)}, Unique Samples: {len(pseudo_nl_mask)}')
    pseudo_label_dict = {'pseudo_idx': pseudo_idx.tolist(), 'pseudo_target':pseudo_target.tolist(), 'nl_idx': pseudo_nl_idx, 'nl_mask': pseudo_nl_mask.tolist()}

    return losses.avg, top1.avg, pseudo_labeling_acc, len(pseudo_idx), nl_accuracy_final, len(nl_accuracy), len(pseudo_nl_mask), pseudo_label_dict

In [15]:
from PIL import Image
from torchvision import datasets
from torchvision import transforms
from data.augmentations import RandAugment, CutoutRandom

In [16]:
def get_cifar10_customized(root='data/datasets', n_lbl=4000, ssl_idx=None, pseudo_lbl=None, itr=0, split_txt=''):
    """Build the labeled CIFAR-10 training set, optionally extended with pseudo-labels.

    Args:
        root: directory holding (or receiving) the CIFAR-10 download.
        n_lbl: number of labeled samples for a freshly created split.
        ssl_idx: path to a pickled ``{'lbl_idx', 'unlbl_idx'}`` split. When ``None`` a
            new split is drawn and saved under ``data/splits``.
        pseudo_lbl: optional path to a pickled pseudo-label dict with the keys
            ``pseudo_idx``, ``pseudo_target``, ``nl_idx`` and ``nl_mask``.
        itr: unused; kept for signature compatibility.
        split_txt: suffix used in the file name of a newly saved split.

    Returns:
        A ``CustomizedCIFAR10SSL`` dataset over the labeled (plus pseudo-labeled) indices.
    """
    os.makedirs(root, exist_ok=True) #create the root directory for saving data
    # augmentations
    transform_train = transforms.Compose([
        RandAugment(3,4),  #from https://arxiv.org/pdf/1909.13719.pdf. For CIFAR-10 M=3, N=4
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32, padding=int(32*0.125), padding_mode='reflect'),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
        CutoutRandom(n_holes=1, length=16, random=True)
    ])

    # NOTE: pickle.load can execute arbitrary code; only load split / pseudo-label
    # files that were produced by this project.
    if ssl_idx is None:
        base_dataset = datasets.CIFAR10(root, train=True, download=True)
        train_lbl_idx, train_unlbl_idx = lbl_unlbl_split(base_dataset.targets, n_lbl, 10)

        os.makedirs('data/splits', exist_ok=True)
        lbl_unlbl_dict = {'lbl_idx': train_lbl_idx, 'unlbl_idx': train_unlbl_idx}
        split_path = os.path.join('data/splits', f'cifar10_basesplit_{n_lbl}_{split_txt}.pkl')
        with open(split_path, "wb") as f:  # closed (and flushed) deterministically
            pickle.dump(lbl_unlbl_dict, f)
    else:
        with open(ssl_idx, 'rb') as f:
            lbl_unlbl_dict = pickle.load(f)
        train_lbl_idx = lbl_unlbl_dict['lbl_idx']
        train_unlbl_idx = lbl_unlbl_dict['unlbl_idx']

    lbl_idx = train_lbl_idx
    if pseudo_lbl is not None:
        with open(pseudo_lbl, 'rb') as f:
            pseudo_lbl_dict = pickle.load(f)
        pseudo_idx = pseudo_lbl_dict['pseudo_idx']
        pseudo_target = pseudo_lbl_dict['pseudo_target']
        nl_idx = pseudo_lbl_dict['nl_idx']
        nl_mask = pseudo_lbl_dict['nl_mask']
        # list(...) keeps this a concatenation even if either side is an ndarray
        lbl_idx = np.array(list(lbl_idx) + list(pseudo_idx))

        #balance the labeled and unlabeled data
        if len(nl_idx) > len(lbl_idx):
            exapand_labeled = len(nl_idx) // len(lbl_idx)
            lbl_idx = np.hstack([lbl_idx for _ in range(exapand_labeled)])

            if len(lbl_idx) < len(nl_idx):
                diff = len(nl_idx) - len(lbl_idx)
                lbl_idx = np.hstack((lbl_idx, np.random.choice(lbl_idx, diff)))
            else:
                assert len(lbl_idx) == len(nl_idx)
    else:
        pseudo_idx = None
        pseudo_target = None
        nl_idx = None
        nl_mask = None

    # Only the labeled dataset is returned. The negative-label dataset that used to be
    # built here was discarded immediately (and relied on a name that is not imported
    # in this notebook), so it is no longer constructed.
    train_lbl_dataset = CustomizedCIFAR10SSL(
        root, lbl_idx, train=True, transform=transform_train,
        pseudo_idx=pseudo_idx, pseudo_target=pseudo_target,
        nl_idx=nl_idx, nl_mask=nl_mask)

    return train_lbl_dataset


class CustomizedCIFAR10SSL(datasets.CIFAR10):
    """CIFAR-10 subset whose labels can be overridden by pseudo-labels.

    Unlike a dataset that also yields sample indices or masks, ``__getitem__`` here
    returns plain ``(image, target)`` pairs.
    """

    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=True, pseudo_idx=None, pseudo_target=None,
                 nl_idx=None, nl_mask=None):
        """Load CIFAR-10, apply pseudo/negative labels, then keep only ``indexs``.

        ``nl_mask`` rows are written at ``nl_idx`` and ``pseudo_target`` values at
        ``pseudo_idx`` before the subset is taken, so both refer to positions in the
        full dataset.
        """
        super().__init__(root, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)

        labels = np.array(self.targets)
        # One row per sample, one column per class; all ones unless a mask is supplied.
        negative_mask = np.ones((len(labels), len(np.unique(labels))))

        if nl_mask is not None:
            negative_mask[nl_idx] = nl_mask

        if pseudo_target is not None:
            labels[pseudo_idx] = pseudo_target

        if indexs is None:
            self.indexs = np.arange(len(labels))
        else:
            indexs = np.array(indexs)
            self.data = self.data[indexs]
            labels = labels[indexs]
            negative_mask = negative_mask[indexs]
            self.indexs = indexs

        self.targets = labels
        self.nl_mask = negative_mask

    def __getitem__(self, index):
        """Return the (optionally transformed) image and target at ``index``."""
        raw_image, label = self.data[index], self.targets[index]
        image = Image.fromarray(raw_image)

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

In [23]:
from conformal_classification.utils import validate, get_logits_targets, sort_sum
from scipy.special import softmax

In [24]:
class ConformalModel(nn.Module):
    """Wrap a classifier so that ``forward`` also returns a conformal prediction set.

    Construction calibrates the wrapper on ``calib_loader``: a temperature is fitted
    with ``platt``, ``kreg``/``lamda`` are chosen with ``pick_parameters`` when either
    is missing, and the threshold ``Qhat`` is computed from the calibration logits.

    Args:
        model: the underlying classifier (called as ``model(*args, **kwargs)``).
        calib_loader: loader over calibration data; its dataset must expose ``classes``.
        alpha: miscoverage level; ``Qhat`` is the ``1 - alpha`` quantile of the scores.
        kreg, lamda: regularisation parameters; tuned automatically when ``None``.
        randomized: default for the randomized set construction in ``forward``.
        allow_zero_sets: default for whether empty sets may be returned.
        pct_paramtune: fraction of the calibration logits reserved for tuning.
        batch_size: batch size of the internal logit loaders.
        lamda_criterion: passed through to ``pick_parameters``.
    """

    def __init__(self, model, calib_loader, alpha, kreg=None, lamda=None, randomized=True, allow_zero_sets=False, pct_paramtune=0.3, batch_size=32, lamda_criterion='size'):
        super(ConformalModel, self).__init__()
        self.model = model
        self.alpha = alpha
        # initialize (1.3 is usually a good value)
        self.T = torch.Tensor([1.3])
        self.T, calib_logits = platt(self, calib_loader)
        self.randomized = randomized
        self.allow_zero_sets = allow_zero_sets
        self.num_classes = len(calib_loader.dataset.classes)

        # identity checks: `== None` would be evaluated elementwise for array-like values
        if kreg is None or lamda is None:
            kreg, lamda, calib_logits = pick_parameters(
                model, calib_logits, alpha, kreg, lamda, randomized, allow_zero_sets, pct_paramtune, batch_size, lamda_criterion)

        # every rank from kreg onwards is penalised by lamda
        self.penalties = np.zeros((1, self.num_classes))
        self.penalties[:, kreg:] += lamda

        calib_loader = DataLoader(
            calib_logits, batch_size=batch_size, shuffle=False, pin_memory=True)

        self.Qhat = conformal_calibration_logits(self, calib_loader)

    def forward(self, *args, randomized=None, allow_zero_sets=None, **kwargs):
        """Run the wrapped model and return ``(logits, prediction_sets)``.

        ``randomized`` / ``allow_zero_sets`` fall back to the values given at
        construction time when left as ``None``.
        """
        if randomized is None:
            randomized = self.randomized
        if allow_zero_sets is None:
            allow_zero_sets = self.allow_zero_sets
        logits = self.model(*args, **kwargs)

        with torch.no_grad():
            logits_numpy = logits.detach().cpu().numpy()
            scores = softmax(logits_numpy/self.T.item(), axis=1)

            I, ordered, cumsum = sort_sum(scores)

            S = gcq(scores, self.Qhat, I=I, ordered=ordered, cumsum=cumsum,
                    penalties=self.penalties, randomized=randomized, allow_zero_sets=allow_zero_sets)

        return logits, S

# Computes the conformal calibration


def conformal_calibration(cmodel, calib_loader):
    """Compute the conformal threshold ``Qhat`` by running the model on raw inputs.

    Args:
        cmodel: object exposing ``model``, ``T``, ``penalties`` and ``alpha``.
        calib_loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        The ``1 - alpha`` quantile ('higher' interpolation) of the calibration scores.
    """
    print("Conformal calibration")
    # Run on whatever device the wrapped model lives on instead of assuming CUDA.
    parameters = getattr(cmodel.model, 'parameters', None)
    first_param = next(parameters(), None) if callable(parameters) else None
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Collect per-batch scores and concatenate once rather than re-copying every batch.
    batch_scores = [np.array([])]
    with torch.no_grad():
        for x, targets in tqdm(calib_loader):
            logits = cmodel.model(x.to(device)).detach().cpu().numpy()
            scores = softmax(logits/cmodel.T.item(), axis=1)

            I, ordered, cumsum = sort_sum(scores)

            batch_scores.append(giq(scores, targets, I=I, ordered=ordered, cumsum=cumsum,
                                    penalties=cmodel.penalties, randomized=True, allow_zero_sets=True))

    E = np.concatenate(batch_scores)
    Qhat = np.quantile(E, 1-cmodel.alpha, interpolation='higher')

    return Qhat

# Temperature scaling


def platt(cmodel, calib_loader, max_iters=10, lr=0.01, epsilon=0.01):
    """Fit the softmax temperature for ``cmodel.model`` on ``calib_loader``.

    Returns:
        ``(T, logits_dataset)``: the temperature produced by ``platt_logits`` and the
        cached output of ``get_logits_targets`` so callers can reuse it.
    """
    print("Begin Platt scaling.")
    # Evaluate the network once and keep the result, so later steps can work on logits only.
    cached_logits = get_logits_targets(cmodel.model, calib_loader)

    temperature = platt_logits(
        cmodel,
        torch.utils.data.DataLoader(
            cached_logits, batch_size=calib_loader.batch_size, shuffle=False, pin_memory=True),
        max_iters=max_iters,
        lr=lr,
        epsilon=epsilon,
    )

    print(f"Optimal T={temperature.item()}")
    return temperature, cached_logits


"""


        INTERNAL FUNCTIONS


"""

# Precomputed-logit versions of the above functions.


class ConformalModelLogits(nn.Module):
    """Conformal wrapper that operates on precomputed logits instead of raw inputs.

    Args:
        model: forwarded to ``pick_parameters``; not called by this class itself.
        calib_loader: loader yielding ``(logits, target)`` batches.
        alpha: miscoverage level.
        kreg, lamda: regularisation parameters; tuned when ``None`` (unless ``naive``
            or ``LAC`` is set).
        randomized, allow_zero_sets: defaults used by ``forward``.
        naive: skip calibration and use ``Qhat = 1 - alpha`` without penalties.
        LAC: threshold ``1 - softmax score`` of the true class instead of using ``gcq``.
        pct_paramtune, batch_size, lamda_criterion: passed to ``pick_parameters``.
    """

    def __init__(self, model, calib_loader, alpha, kreg=None, lamda=None, randomized=True, allow_zero_sets=False, naive=False, LAC=False, pct_paramtune=0.3, batch_size=32, lamda_criterion='size'):
        super(ConformalModelLogits, self).__init__()
        self.model = model
        self.alpha = alpha
        self.randomized = randomized
        self.LAC = LAC
        self.allow_zero_sets = allow_zero_sets
        self.T = platt_logits(self, calib_loader)

        if (kreg is None or lamda is None) and not naive and not LAC:
            kreg, lamda, calib_logits = pick_parameters(
                model, calib_loader.dataset, alpha, kreg, lamda, randomized, allow_zero_sets, pct_paramtune, batch_size, lamda_criterion)
            calib_loader = DataLoader(
                calib_logits, batch_size=batch_size, shuffle=False, pin_memory=True)

        # one penalty per class; the class count is read off the first logit vector
        self.penalties = np.zeros((1, calib_loader.dataset[0][0].shape[0]))
        if kreg is not None and not naive and not LAC:
            self.penalties[:, kreg:] += lamda
        self.Qhat = 1-alpha
        if not naive and not LAC:
            self.Qhat = conformal_calibration_logits(self, calib_loader)
        elif not naive and LAC:
            gt_locs_cal = np.array([np.where(np.argsort(x[0]).flip(dims=(0,)) == x[1])[
                                   0][0] for x in calib_loader.dataset])
            scores_cal = 1-np.array([np.sort(torch.softmax(calib_loader.dataset[i][0]/self.T.item(), dim=0))[
                                    ::-1][gt_locs_cal[i]] for i in range(len(calib_loader.dataset))])
            n_cal = scores_cal.shape[0]
            # The finite-sample correction can push the level above 1 for small
            # calibration sets, which np.quantile rejects; cap it at 1.
            quantile_level = min(1.0, np.ceil((n_cal+1) * (1-alpha)) / n_cal)
            self.Qhat = np.quantile(scores_cal, quantile_level)

    def forward(self, logits, randomized=None, allow_zero_sets=None):
        """Return ``(logits, prediction_sets)`` for a batch of logits.

        ``randomized`` / ``allow_zero_sets`` default to the constructor values.
        """
        if randomized is None:
            randomized = self.randomized
        if allow_zero_sets is None:
            allow_zero_sets = self.allow_zero_sets

        with torch.no_grad():
            logits_numpy = logits.detach().cpu().numpy()
            scores = softmax(logits_numpy/self.T.item(), axis=1)

            if not self.LAC:
                I, ordered, cumsum = sort_sum(scores)

                S = gcq(scores, self.Qhat, I=I, ordered=ordered, cumsum=cumsum,
                        penalties=self.penalties, randomized=randomized, allow_zero_sets=allow_zero_sets)
            else:
                # LAC: keep every class whose (1 - score) is below the threshold
                S = [np.where((1-scores[i, :]) < self.Qhat)[0]
                     for i in range(scores.shape[0])]

        return logits, S


def conformal_calibration_logits(cmodel, calib_loader):
    """Compute the conformal threshold from precomputed logits.

    Iterates ``(logits, targets)`` batches, scores each batch with ``giq`` using
    ``cmodel.T`` and ``cmodel.penalties``, and returns the ``1 - cmodel.alpha``
    quantile ('higher' interpolation) of all scores.
    """
    # Seeded with an empty array so the concatenation matches an empty-start accumulator.
    score_chunks = [np.array([])]
    with torch.no_grad():
        for batch_logits, batch_targets in calib_loader:
            probabilities = softmax(
                batch_logits.detach().cpu().numpy()/cmodel.T.item(), axis=1)

            order, sorted_probs, running_mass = sort_sum(probabilities)

            score_chunks.append(
                giq(probabilities, batch_targets, I=order, ordered=sorted_probs,
                    cumsum=running_mass, penalties=cmodel.penalties,
                    randomized=True, allow_zero_sets=True))

    all_scores = np.concatenate(score_chunks)
    return np.quantile(all_scores, 1-cmodel.alpha, interpolation='higher')


def platt_logits(cmodel, calib_loader, max_iters=10, lr=0.01, epsilon=0.01):
    """Fit a single softmax temperature ``T`` on precomputed logits with SGD.

    Args:
        cmodel: unused; kept so the signature matches the other calibration helpers.
        calib_loader: iterable of ``(logits, targets)`` batches.
        max_iters: maximum number of passes over ``calib_loader``.
        lr: SGD learning rate.
        epsilon: stop early once a full pass changes ``T`` by less than this.

    Returns:
        ``T`` as a one-element ``nn.Parameter`` (initialised at 1.3), on CUDA when
        available and on the CPU otherwise.
    """
    # Use the GPU when there is one, but do not crash on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    nll_criterion = nn.CrossEntropyLoss().to(device)

    T = nn.Parameter(torch.Tensor([1.3]).to(device))

    optimizer = optim.SGD([T], lr=lr)
    for _ in range(max_iters):
        T_old = T.item()
        for x, targets in calib_loader:
            optimizer.zero_grad()
            # Only T is optimised, so the logits do not need gradients.
            out = x.to(device)/T
            loss = nll_criterion(out, targets.long().to(device))
            loss.backward()
            optimizer.step()
        if abs(T_old - T.item()) < epsilon:
            break
    return T

# CORE CONFORMAL INFERENCE FUNCTIONS

# Generalized conditional quantile function.


def gcq(scores, tau, I, ordered, cumsum, penalties, randomized, allow_zero_sets):
    """Build one prediction set per row from sorted scores and a threshold ``tau``.

    Args:
        scores: array whose second dimension is the number of classes.
        tau: threshold compared against the penalised cumulative scores.
        I, ordered, cumsum: per-row class order, sorted scores and their running sum.
        penalties: array of shape ``(1, n_classes)``.
        randomized: if true, each set is shrunk by one class at random, based on how
            far ``tau`` reaches into the last included class.
        allow_zero_sets: if false, empty sets are bumped to size one.

    Returns:
        A list with ``I[row, :size]`` for every row.
    """
    cumulative_penalty = np.cumsum(penalties, axis=1)
    n_classes = scores.shape[1]
    # (#ranks whose penalised cumulative mass is <= tau) + 1, capped at the class count
    within_tau = ((cumsum + cumulative_penalty) <= tau).sum(axis=1) + 1
    base_sizes = np.minimum(within_tau, n_classes)

    sizes = base_sizes
    if randomized:
        V = np.zeros(base_sizes.shape)
        for row, size in enumerate(base_sizes):
            last = size - 1  # 0-based rank of the last class in the base set
            mass_before_last = cumsum[row, last] - ordered[row, last]
            V[row] = 1/ordered[row, last] * \
                (tau - mass_before_last - cumulative_penalty[0, last])

        # drop the last class whenever the uniform draw is >= V
        sizes = base_sizes - (np.random.random(V.shape) >= V).astype(int)

    if tau == 1.0:
        # always predict max size if alpha==0. (Avoids numerical error.)
        sizes[:] = cumsum.shape[1]

    if not allow_zero_sets:
        # allow the user the option to never have empty sets (will lead to incorrect coverage if 1-alpha < model's top-1 accuracy
        sizes[sizes == 0] = 1

    # Construct S from equation (5)
    return [I[row, 0:sizes[row]] for row in range(I.shape[0])]

# Get the 'p-value'


def get_tau(score, target, I, ordered, cumsum, penalty, randomized, allow_zero_sets):  # For one example
    """Return the threshold at which ``target`` enters the prediction set.

    ``I``, ``ordered`` and ``cumsum`` are single-row 2-D arrays; ``penalty`` is indexed
    as a 1-D array. ``score`` is accepted but not used.
    """
    position = np.where(I == target)  # (row indices, rank indices) of the target
    mass_through_target = cumsum[position]

    if not randomized:
        return mass_through_target + penalty[0]

    # Drawn before branching so the random stream is consumed on every randomized path.
    U = np.random.random()

    if position == (0, 0):
        # target is the top-ranked class
        if allow_zero_sets:
            return U * mass_through_target + penalty[0]
        return mass_through_target + penalty[0]

    row, rank = position
    mass_before_target = cumsum[(row, rank-1)]
    penalty_through_target = (penalty[0:(rank[0]+1)]).sum()
    return U * ordered[position] + mass_before_target + penalty_through_target

# Computes the conformity score (tau) of every example in a batch.


def giq(scores, targets, I, ordered, cumsum, penalties, randomized, allow_zero_sets):
    """
        Generalized inverse quantile conformity score function.
        E from equation (7) in Romano, Sesia, Candes.  Find the minimum tau in [0, 1] such that the correct label enters.

        Calls get_tau once per row with single-row slices and returns the results
        as a 1-D array with one entry per example.
    """
    n_examples = scores.shape[0]
    E = -np.ones((n_examples,))
    for row in range(n_examples):
        one_row = slice(row, row+1)  # keeps the slices two-dimensional
        E[row] = get_tau(
            scores[one_row, :],
            targets[row].item(),
            I[one_row, :],
            ordered[one_row, :],
            cumsum[one_row, :],
            penalties[0, :],
            randomized=randomized,
            allow_zero_sets=allow_zero_sets,
        )

    return E

# AUTOMATIC PARAMETER TUNING FUNCTIONS


def pick_kreg(paramtune_logits, alpha):
    """Choose ``kreg`` from the rank of the true label on the tuning data.

    Args:
        paramtune_logits: iterable of items whose element 0 is a logit vector and
            element 1 the true label (assumes element 0 is a torch tensor, since
            ``.flip(dims=...)`` is called on its argsort - TODO confirm).
        alpha: miscoverage level.

    Returns:
        The ``1 - alpha`` quantile of the 0-based true-label ranks, plus one.
    """
    # Ascending argsort flipped gives descending order; np.where then locates the
    # position of the true label in that order.
    gt_locs_kstar = np.array([np.where(np.argsort(x[0]).flip(dims=(0,)) == x[1])[
                             0][0] for x in paramtune_logits])
    # 'higher' returns an observed rank rather than interpolating; +1 turns the
    # 0-based rank into a count.
    kstar = np.quantile(gt_locs_kstar, 1-alpha, interpolation='higher') + 1
    return kstar


def pick_lamda_size(model, paramtune_loader, alpha, kreg, randomized, allow_zero_sets):
    """Pick the ``lamda`` from a fixed grid that gives the smallest average set size.

    Args:
        model: forwarded to ``ConformalModelLogits``.
        paramtune_loader: loader over ``(logits, target)`` batches used for tuning.
        alpha, kreg, randomized, allow_zero_sets: forwarded to ``ConformalModelLogits``.

    Returns:
        The first grid value whose average set size is strictly below every earlier
        one and below the number of classes; ``0`` (no regularisation, the same
        starting value ``pick_lamda_adaptiveness`` uses) when no grid value manages that.
    """
    # Number of classes, read from the logit width so a batch of size 1 also works.
    best_size = next(iter(paramtune_loader))[0].shape[1]
    # Fallback so the function never returns an unbound name.
    lamda_star = 0
    # Use the paramtune data to pick lamda.  Does not violate exchangeability.
    # predefined grid, change if more precision desired.
    for temp_lam in [0.001, 0.01, 0.1, 0.2, 0.5]:
        conformal_model = ConformalModelLogits(model, paramtune_loader, alpha=alpha, kreg=kreg,
                                               lamda=temp_lam, randomized=randomized, allow_zero_sets=allow_zero_sets, naive=False)
        top1_avg, top5_avg, cvg_avg, sz_avg = validate(
            paramtune_loader, conformal_model, print_bool=False)
        if sz_avg < best_size:
            best_size = sz_avg
            lamda_star = temp_lam
    return lamda_star


def pick_lamda_adaptiveness(model, paramtune_loader, alpha, kreg, randomized, allow_zero_sets, strata=[[0, 1], [2, 3], [4, 6], [7, 10], [11, 100], [101, 1000]]):
    """Pick the ``lamda`` from a fixed grid with the smallest ``get_violation`` value.

    Starts from ``lamda = 0`` with a best violation of 1; a grid value replaces the
    current choice only when its violation is strictly smaller.
    """
    # Use the paramtune data to pick lamda.  Does not violate exchangeability.
    # predefined grid, change if more precision desired.
    candidate_grid = (0, 1e-5, 1e-4, 8e-4, 9e-4, 1e-3, 1.5e-3, 2e-3)

    chosen_lamda, lowest_violation = 0, 1
    for candidate in candidate_grid:
        trial_model = ConformalModelLogits(
            model, paramtune_loader, alpha=alpha, kreg=kreg, lamda=candidate,
            randomized=randomized, allow_zero_sets=allow_zero_sets, naive=False)
        violation = get_violation(trial_model, paramtune_loader, strata, alpha)
        if violation < lowest_violation:
            lowest_violation, chosen_lamda = violation, candidate
    return chosen_lamda


def pick_parameters(model, calib_logits, alpha, kreg, lamda, randomized, allow_zero_sets, pct_paramtune, batch_size, lamda_criterion):
    """Split off tuning data and choose ``kreg`` / ``lamda`` where they are missing.

    Args:
        model: forwarded to the lamda pickers.
        calib_logits: dataset of ``(logits, target)`` items; it is split at random.
        alpha: miscoverage level.
        kreg, lamda: kept as given unless ``None``.
        randomized, allow_zero_sets: forwarded to the lamda pickers.
        pct_paramtune: fraction of ``calib_logits`` reserved for tuning (rounded up).
        batch_size: batch size of the tuning loader.
        lamda_criterion: ``"size"`` or ``"adaptiveness"``.

    Returns:
        ``(kreg, lamda, calib_logits)`` where ``calib_logits`` is the remaining
        calibration subset, disjoint from the tuning data.

    Raises:
        ValueError: if ``lamda`` is ``None`` and ``lamda_criterion`` is not recognised.
    """
    num_paramtune = int(np.ceil(pct_paramtune * len(calib_logits)))
    paramtune_logits, calib_logits = random_split(
        calib_logits, [num_paramtune, len(calib_logits)-num_paramtune])
    # Tune on the held-out tuning split only; tuning on the calibration split would
    # reuse the data that later determines the threshold.
    paramtune_loader = DataLoader(
        paramtune_logits, batch_size=batch_size, shuffle=False, pin_memory=True)

    if kreg is None:
        kreg = pick_kreg(paramtune_logits, alpha)
    if lamda is None:
        if lamda_criterion == "size":
            lamda = pick_lamda_size(
                model, paramtune_loader, alpha, kreg, randomized, allow_zero_sets)
        elif lamda_criterion == "adaptiveness":
            lamda = pick_lamda_adaptiveness(
                model, paramtune_loader, alpha, kreg, randomized, allow_zero_sets)
        else:
            # fail here with a clear message instead of returning lamda=None
            raise ValueError(
                f"unknown lamda_criterion {lamda_criterion!r}; expected 'size' or 'adaptiveness'")
    return kreg, lamda, calib_logits


def get_violation(cmodel, loader_paramtune, strata, alpha):
    """Worst-case deviation from ``1 - alpha`` coverage across set-size strata.

    Args:
        cmodel: callable taking a batch of logits and returning ``(output, sets)``,
            where each set is an array of class indices.
        loader_paramtune: iterable of ``(logits, target)`` batches.
        strata: ``[low, high]`` pairs (inclusive) of set sizes to group by.
        alpha: miscoverage level.

    Returns:
        ``max`` over non-empty strata of ``|coverage - (1 - alpha)|``; ``0`` when
        every stratum is empty.
    """
    set_sizes = []
    covered = []
    for logit, target in loader_paramtune:
        # This is a 'dummy model' which takes logits, for efficiency.
        _, S = cmodel(logit)
        for j, pred_set in enumerate(S):
            set_sizes.append(pred_set.size)
            covered.append(int(target[j] in list(pred_set)))

    # Plain arrays built once; no per-batch DataFrame growth and no pandas needed.
    set_sizes = np.asarray(set_sizes)
    covered = np.asarray(covered, dtype=float)

    wc_violation = 0
    for stratum in strata:
        in_stratum = (set_sizes >= stratum[0]) & (set_sizes <= stratum[1])
        if not in_stratum.any():
            continue
        stratum_violation = abs(covered[in_stratum].mean()-(1-alpha))
        wc_violation = max(wc_violation, stratum_violation)
    return wc_violation  # the violation

In [25]:
# Let cuDNN benchmark and cache its convolution algorithms (only relevant when running on CUDA).
cudnn.benchmark = True

In [26]:
run_started = datetime.today().strftime('%d-%m-%y_%H%M') #start time to create unique experiment name
parser = argparse.ArgumentParser(description='UPS Training')
parser.add_argument('--out', default=f'outputs', help='directory to output the result')
parser.add_argument('--gpu-id', default='0', type=int,
                    help='id(s) for CUDA_VISIBLE_DEVICES')
parser.add_argument('--num-workers', type=int, default=8,
                    help='number of workers')
parser.add_argument('--dataset', default='cifar10', type=str,
                    choices=['cifar10', 'cifar100'],
                    help='dataset names')
parser.add_argument('--n-lbl', type=int, default=4000,
                    help='number of labeled data')
parser.add_argument('--arch', default='cnn13', type=str,
                    choices=['wideresnet', 'cnn13', 'shakeshake'],
                    help='architecture name')
parser.add_argument('--iterations', default=20, type=int,
                    help='number of total pseudo-labeling iterations to run')
parser.add_argument('--epchs', default=1024, type=int,
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int,
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--batchsize', default=128, type=int,
                    help='train batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                    help='initial learning rate, default 0.03')
parser.add_argument('--warmup', default=0, type=float,
                    help='warmup epochs (unlabeled data based)')
parser.add_argument('--wdecay', default=5e-4, type=float,
                    help='weight decay')
parser.add_argument('--nesterov', action='store_true', default=True,
                    help='use nesterov momentum')
parser.add_argument('--resume', default='', type=str,
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--seed', type=int, default=-1,
                    help="random seed (-1: don't use random seed)")
parser.add_argument('--no-progress', action='store_true',
                    help="don't use progress bar")
parser.add_argument('--dropout', default=0.3, type=float,
                    help='dropout probs')
parser.add_argument('--num-classes', default=10, type=int,
                    help='total classes')
parser.add_argument('--class-blnc', default=10, type=int,
                    help='total number of class balanced iterations')
parser.add_argument('--tau-p', default=0.70, type=float,
                    help='confidece threshold for positive pseudo-labels, default 0.70')
parser.add_argument('--tau-n', default=0.05, type=float,
                    help='confidece threshold for negative pseudo-labels, default 0.05')
parser.add_argument('--kappa-p', default=0.05, type=float,
                    help='uncertainty threshold for positive pseudo-labels, default 0.05')
parser.add_argument('--kappa-n', default=0.005, type=float,
                    help='uncertainty threshold for negative pseudo-labels, default 0.005')
parser.add_argument('--temp-nl', default=2.0, type=float,
                    help='temperature for generating negative pseduo-labels, default 2.0')
parser.add_argument('--no-uncertainty', action='store_true',
                    help='use uncertainty in the pesudo-label selection, default true')
parser.add_argument('--split-txt', default='run1', type=str,
                    help='extra text to differentiate different experiments. it also creates a new labeled/unlabeled split')
parser.add_argument('--model-width', default=2, type=int,
                    help='model width for WRN-28')
parser.add_argument('--model-depth', default=28, type=int,
                    help='model depth for WRN')
parser.add_argument('--test-freq', default=10, type=int,
                    help='frequency of evaluations')


# A notebook cell has no real command line, so the arguments are spelled out
# here and handed to the parser explicitly.
options = (
    "--dataset", "cifar10",
    "--n-lbl", "1000",
    "--class-blnc", "7",
    "--split-txt", "run1",
    "--arch", "cnn13",
)
args = parser.parse_args(options)

# Echo the key configuration values between two banner rows.
print(
    '########################################################################',
    '########################################################################',
    f'dataset:                                  {args.dataset}',
    f'number of labeled samples:                {args.n_lbl}',
    f'architecture:                             {args.arch}',
    f'number of pseudo-labeling iterations:     {args.iterations}',
    f'number of epochs:                         {args.epchs}',
    f'batch size:                               {args.batchsize}',
    f'lr:                                       {args.lr}',
    f'value of tau_p:                           {args.tau_p}',
    f'value of tau_n:                           {args.tau_n}',
    f'value of kappa_p:                         {args.kappa_p}',
    f'value of kappa_n:                         {args.kappa_n}',
    '########################################################################',
    '########################################################################',
    sep='\n',
)

# Dataset name -> factory that builds the labeled / negative-label / unlabeled / test datasets.
DATASET_GETTERS = dict(cifar10=get_cifar10, cifar100=get_cifar100)

# The experiment name encodes the main hyper-parameters and the run start stamp.
exp_name = '_'.join(map(format, (
    'exp', args.dataset, args.n_lbl, args.arch, args.split_txt, args.epchs,
    args.class_blnc, args.tau_p, args.tau_n, args.kappa_p, args.kappa_n, run_started,
)))

# The run is pinned to the CPU; the CUDA variant is kept for reference:
# device = torch.device('cuda', args.gpu_id)
device = torch.device('cpu')
args.device = device
args.exp_name = exp_name
args.dtype = torch.float32

# A seed of -1 means "do not seed".
if args.seed != -1:
    set_seed(args)

args.out = os.path.join(args.out, args.exp_name)
start_itr = 0

# When resuming from an existing directory, continue from the highest
# pseudo-labeling iteration found there and keep writing into that directory.
if args.resume and os.path.isdir(args.resume):
    resume_files = os.listdir(args.resume)
    resume_itrs = [
        int(fname.replace('.pkl', '').rsplit('_', 1)[-1])
        for fname in resume_files
        if 'pseudo_labeling_iteration' in fname
    ]
    if resume_itrs:
        start_itr = max(resume_itrs)
    args.out = args.resume

os.makedirs(args.out, exist_ok=True)
writer = SummaryWriter(args.out)

# The class count follows the dataset; any other dataset keeps the parsed value.
args.num_classes = {'cifar10': 10, 'cifar100': 100}.get(args.dataset, args.num_classes)

# Main driver: for every pseudo-labeling iteration, build the datasets, train a
# fresh model, reload the saved checkpoint, generate pseudo-labels for the
# unlabeled set and persist them for the next iteration.
# The loop is wrapped in try/finally so the TensorBoard writer is closed even
# when the run fails or is interrupted from the notebook.
try:
    for itr in range(start_itr, args.iterations):
        if itr == 0 and args.n_lbl < 4000: #use a smaller batchsize to increase the number of iterations
            args.batch_size = 64
            args.epochs = 1024
        else:
            args.batch_size = args.batchsize
            args.epochs = args.epchs

        # Hand the split file to the dataset getter only when it exists on disk.
        split_path = f'data/splits/{args.dataset}_basesplit_{args.n_lbl}_{args.split_txt}.pkl'
        lbl_unlbl_split = split_path if os.path.exists(split_path) else None

        # Pseudo-labels are stored under index itr+1 at the end of iteration itr,
        # so iteration itr (> 0) reads the file with its own index.
        pseudo_lbl_dict = f'{args.out}/pseudo_labeling_iteration_{itr}.pkl' if itr > 0 else None

        lbl_dataset, nl_dataset, unlbl_dataset, test_dataset = DATASET_GETTERS[args.dataset](
            'data/datasets', args.n_lbl, lbl_unlbl_split, pseudo_lbl_dict, itr, args.split_txt)

        model = create_model(args)
        model.to(args.device)

        # Share of each batch given to the negative-label dataset, proportional to its size.
        nl_batchsize = int((float(args.batch_size) * len(nl_dataset))/(len(lbl_dataset) + len(nl_dataset)))

        if itr == 0:
            lbl_batchsize = args.batch_size
            args.iteration = len(lbl_dataset) // args.batch_size
        else:
            lbl_batchsize = args.batch_size - nl_batchsize
            args.iteration = (len(lbl_dataset) + len(nl_dataset)) // args.batch_size

        lbl_loader = DataLoader(
            lbl_dataset,
            sampler=RandomSampler(lbl_dataset),
            batch_size=lbl_batchsize,
            num_workers=args.num_workers,
            drop_last=True)

        nl_loader = DataLoader(
            nl_dataset,
            sampler=RandomSampler(nl_dataset),
            batch_size=nl_batchsize,
            num_workers=args.num_workers,
            drop_last=True)

        test_loader = DataLoader(
            test_dataset,
            sampler=SequentialSampler(test_dataset),
            batch_size=args.batch_size,
            num_workers=args.num_workers)

        unlbl_loader = DataLoader(
            unlbl_dataset,
            sampler=SequentialSampler(unlbl_dataset),
            batch_size=args.batch_size,
            num_workers=args.num_workers)

        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=args.nesterov)
        args.total_steps = args.epochs * args.iteration
        scheduler = get_cosine_schedule_with_warmup(optimizer, args.warmup * args.iteration, args.total_steps)

        # Defaults are set BEFORE the resume block so that values restored from a
        # checkpoint are not overwritten afterwards.
        start_epoch = 0
        best_acc = 0
        test_acc = 0.0

        if args.resume and itr == start_itr and os.path.isdir(args.resume):
            resume_itrs = [int(item.replace('.pth.tar','').split("_")[-1]) for item in resume_files if 'checkpoint_iteration_' in item]
            if resume_itrs:
                checkpoint_itr = max(resume_itrs)
                resume_model = os.path.join(args.resume, f'checkpoint_iteration_{checkpoint_itr}.pth.tar')
                if os.path.isfile(resume_model) and checkpoint_itr == itr:
                    # map_location lets a checkpoint saved on another device load on args.device.
                    checkpoint = torch.load(resume_model, map_location=args.device)
                    best_acc = checkpoint['best_acc']
                    # Keeps the final log line meaningful when no further epoch runs.
                    test_acc = checkpoint.get('acc', 0.0)
                    start_epoch = checkpoint['epoch']
                    model.load_state_dict(checkpoint['state_dict'])
                    optimizer.load_state_dict(checkpoint['optimizer'])
                    scheduler.load_state_dict(checkpoint['scheduler'])

        model.zero_grad()

        # NOTE(review): this override cuts training to a single epoch per iteration,
        # although total_steps / the LR schedule above were computed from the
        # configured epoch count. It looks like a debugging leftover; remove it for
        # real runs.
        args.epochs = 1
        for epoch in range(start_epoch, args.epochs):
            if itr == 0:
                train_loss = train_initial(args, lbl_loader, model, optimizer, scheduler, epoch, itr)
            else:
                train_loss = train_regular(args, lbl_loader, nl_loader, model, optimizer, scheduler, epoch, itr)

            # Evaluate every test_freq epochs in the second half of training and
            # always on the last epoch; other epochs log zeros.
            test_loss = 0.0
            test_acc = 0.0
            if (epoch > (args.epochs+1)/2 and epoch%args.test_freq==0) or epoch == (args.epochs-1):
                test_loss, test_acc = test(args, test_loader, model)

            global_step = (itr*args.epochs)+epoch
            writer.add_scalar('train/1.train_loss', train_loss, global_step)
            writer.add_scalar('test/1.test_acc', test_acc, global_step)
            writer.add_scalar('test/2.test_loss', test_loss, global_step)

            is_best = test_acc > best_acc
            best_acc = max(test_acc, best_acc)
            # Unwrap a wrapper exposing `.module` so the plain model weights are saved.
            model_to_save = model.module if hasattr(model, "module") else model
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model_to_save.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, is_best, args.out, f'iteration_{itr}')

        # Reload the checkpoint written for this iteration before pseudo-labeling.
        checkpoint = torch.load(f'{args.out}/checkpoint_iteration_{itr}.pth.tar', map_location=args.device)
        model.load_state_dict(checkpoint['state_dict'])
        model.zero_grad()

        ## create calibration holdout data
        # NOTE(review): the CIFAR-10 getter is used regardless of args.dataset; confirm for cifar100.
        lbl_dataset_custom = get_cifar10_customized('data/datasets', args.n_lbl, lbl_unlbl_split, pseudo_lbl_dict, itr, args.split_txt)
        calib_lbl_loader = DataLoader(
            lbl_dataset_custom,
            sampler=RandomSampler(lbl_dataset_custom),
            batch_size=lbl_batchsize,
            num_workers=args.num_workers,
            drop_last=True)
        # NOTE(review): conformal_model is built but not passed to pseudo_labeling below,
        # which still receives the plain model; confirm whether that is intended.
        conformal_model = ConformalModel(model, calib_lbl_loader, alpha=0.1, lamda_criterion='size') # conformal model for generating pseudo-labels

        #pseudo-label generation and selection
        pl_loss, pl_acc, pl_acc_pos, total_sel_pos, pl_acc_neg, total_sel_neg, unique_sel_neg, pseudo_label_dict = pseudo_labeling(args, unlbl_loader, model, itr)

        writer.add_scalar('pseudo_labeling/1.regular_loss', pl_loss, itr)
        writer.add_scalar('pseudo_labeling/2.regular_acc', pl_acc, itr)
        writer.add_scalar('pseudo_labeling/3.pseudo_acc_positive', pl_acc_pos, itr)
        writer.add_scalar('pseudo_labeling/4.total_sel_positive', total_sel_pos, itr)
        writer.add_scalar('pseudo_labeling/5.pseudo_acc_negative', pl_acc_neg, itr)
        writer.add_scalar('pseudo_labeling/6.total_sel_negative', total_sel_neg, itr)
        writer.add_scalar('pseudo_labeling/7.unique_samples_negative', unique_sel_neg, itr)

        # Persist the selected pseudo-labels for the next iteration.
        with open(os.path.join(args.out, f'pseudo_labeling_iteration_{itr+1}.pkl'),"wb") as f:
            pickle.dump(pseudo_label_dict,f)

        with open(os.path.join(args.out, 'log.txt'), 'a+') as ofile:
            ofile.write(f'############################# PL Iteration: {itr+1} #############################\n')
            ofile.write(f'Last Test Acc: {test_acc}, Best Test Acc: {best_acc}\n')
            ofile.write(f'PL Acc (Positive): {pl_acc_pos}, Total Selected (Positive): {total_sel_pos}\n')
            ofile.write(f'PL Acc (Negative): {pl_acc_neg}, Total Selected (Negative): {total_sel_neg}, Unique Negative Samples: {unique_sel_neg}\n\n')
finally:
    # Close the writer on every exit path (normal completion, error, interrupt).
    writer.close()

Pseudo-Labeling Iter:   81/ 766. Data: 0.012s. Batch: 0.168s. Loss: 2.2511. top1: 0.00. top5: 0.00. :  11%|█         | 81/766 [00:13<01:51,  6.14it/s]

KeyboardInterrupt: ignored