In [6]:
import torch
from modules import DefaultModule
from risk_control import risk_control


lm = DefaultModule.load_from_checkpoint("/home/jiajie/Code/private/Selective/SelectiveNet/ckpt/CIFAR-100-default/epoch=04-val_acc@1=24.07-05-31-02:49.ckpt", strict=False)
net = lm.model
net.eval();

In [85]:
from dataloaders import *
from torch.utils.data import DataLoader
from PIL import Image
import torchvision
transform = torchvision.transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])
transform_x = lambda X: transform(torch.Tensor(X.transpose(0, 3, 1, 2)/255))

valset = cifar100(split='valid')
x_val, x_test = valset.data[:5000], valset.data[5000:]
y_val, y_test = valset.targets[:5000], valset.targets[5000:]

x_val, x_test = transform_x(x_val), transform_x(x_test)
y_val, y_test = torch.LongTensor(y_val), torch.LongTensor(y_test)
#valset.transform(Image.fromarray(valset.data[0]))

Files already downloaded and verified


In [80]:
y_pred = net(x_val)

In [98]:
residuals

tensor([ True,  True,  True,  ...,  True,  True, False])

In [145]:
kappa = torch.max(y_pred,1)[1]
residuals = (torch.argmax(y_pred, 1)[1] != y_val)
bound_cal = risk_control()
[theta, b_star] = bound_cal.bound(0.15, 0.001, kappa, residuals)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221,

In [96]:
import pickle
import numpy as np
from scipy.stats import binom
import scipy
import math
from scipy.optimize import fsolve
import random

class risk_control:

    def calculate_bound(self,delta,m,erm):
        #This function is a solver for the inverse of binomial CDF based on binary search.
        precision = 1e-7
        def func(b):
            return (-1*delta) + scipy.stats.binom.cdf(int(m*erm),m,b)
        a=erm #start binary search from the empirical risk
        c=1   # the upper bound is 1
        b = (a+c)/2 #mid point
        funcval  =func(b)
        while abs(funcval)>precision:
            if a == 1.0 and c == 1.0:
                b = 1.0
                break
            elif funcval>0:
                a=b
            else:
                c=b
            b = (a + c) / 2
            funcval = func(b)
        return b

    def bound(self,rstar,delta,kappa,residuals,split=True):
        # A function to calculate the risk bound proposed in the paper, the algorithm is based on algorithm 1 from the paper.
        #Input: rstar - the requested risk bound
        #       delta - the desired delta
        #       kappa - rating function over the points (higher values is more confident prediction)
        #       residuals - a vector of the residuals of the samples 0 is correct prediction and 1 corresponding to an error
        #       split - is a boolean controls whether to split train and test
        #Output - [theta, bound] (also prints latex text for the tables in the paper)

        # when spliting to train and test this represents the fraction of the validation size
        valsize = 0.5

        probs = kappa
        FY = residuals


        if split:
            idx = list(range(len(FY)))
            print(idx)
            random.shuffle(idx)
            slice = round(len(FY)*(1-valsize))
            print(slice)
            print(FY)
            print(idx[slice:])
            FY_val = FY[idx[slice:]]
            probs_val = probs[idx[slice:]]
            FY = FY[idx[:slice]]
            probs = probs[idx[:slice]]
        m = len(FY)

        probs_idx_sorted = np.argsort(probs)

        a=0
        b = m-1
        deltahat = delta/math.ceil(math.log2(m))

        for q in range(math.ceil(math.log2(m))+1):
            # the for runs log(m)+1 iterations but actually the bound calculated on only log(m) different candidate thetas
            mid = math.ceil((a+b)/2)

            mi = len(FY[probs_idx_sorted[mid:]])
            theta = probs[probs_idx_sorted[mid]]
            risk = sum(FY[probs_idx_sorted[mid:]])/mi
            if split:
                testrisk = sum(FY_val[probs_val>=theta])/len(FY_val[probs_val>=theta])
                testcov = len(FY_val[probs_val>=theta])/len(FY_val)
            bound = self.calculate_bound(deltahat,mi,risk)
            coverage = mi/m
            if bound>rstar:
                a=mid
            else:
                b=mid

        if split:
            print("%.2f & %.4f & %.4f & %.4f & %.4f & %.4f  \\\\" % (rstar,risk,coverage,testrisk,testcov,bound))
        else:
            print("%.2f & %.4f & %.4f & %.4f   \\\\" % (rstar,risk,coverage,bound))
        return [theta,bound]


In [138]:
valsize = 0.5

probs = kappa
FY = residuals

idx = list(range(len(FY)))
random.shuffle(idx)
slice = round(len(FY)*(1-valsize))
FY_val = FY[idx[slice:]]
probs_val = probs[idx[slice:]]
FY = FY[idx[:slice]]
probs = probs[idx[:slice]]

In [140]:
probs.shape

torch.Size([2500])