In [1]:
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from sklearn.metrics import confusion_matrix
import numpy as np
import models
import torch
from utils import accuracy
transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [2]:
model = models.__dict__['resnet32'](num_classes=10, use_norm=False)
checkpoint = torch.load('checkpoint/cifar10_resnet32_CE_None_exp_0.01_0/ckpt.best.pth.tar', map_location='cuda:0')
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [3]:

def validate(val_loader, model, flag='val'):
    # switch to evaluate mode

    model.eval()
    model = model.cuda(0)
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            if 0 is not None:
                input = input.cuda(0, non_blocking=True)
            target = target.cuda(0, non_blocking=True)

            # compute output
            output = model(input)
            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            # measure elapsed time

            _, pred = torch.max(output, 1)
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

        cf = confusion_matrix(all_targets, all_preds).astype(float)
        cls_cnt = cf.sum(axis=1)
        cls_hit = np.diag(cf)
        cls_acc = cls_hit / cls_cnt
        output = ('{flag} Results: Prec@1 {acc1:.3f} Prec@5 {acc5:.3f}'
                  .format(flag=flag, acc1=acc1.item(), acc5=acc5.item()))
        out_cls_acc = '%s Class Accuracy: %s' % (
            flag, (np.array2string(cls_acc, separator=',', formatter={'float_kind': lambda x: "%.3f" % x})))
        print(output)
        print(out_cls_acc)

    return cf, acc1

In [4]:
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=100, shuffle=False,
        num_workers=4, pin_memory=True)

cf, _ = validate(val_loader, model)


Files already downloaded and verified
val Results: Prec@1 57.000 Prec@5 92.000
val Class Accuracy: [0.914,0.791,0.748,0.778,0.829,0.253,0.363,0.207,0.315,0.126]


In [5]:
a = np.copy(cf)

In [6]:
a = a - np.diag(a) * np.identity(a.shape[0])

In [7]:
temp = np.argmax(a[-3:,:], axis=1)

In [9]:
a[-3:,:]

array([[ 23.,   0.,  78., 211., 451.,  27.,   2.,   0.,   1.,   0.],
       [524.,   9.,  29.,  66.,  48.,   0.,   9.,   0.,   0.,   0.],
       [296., 155.,  45., 158., 161.,   5.,   3.,   2.,  49.,   0.]])

In [9]:
tails = np.asarray([78, 79, 88, 89, 98, 99])

In [10]:
temp = temp[tails-70]

In [11]:
temp

array([18, 44, 50,  8,  2, 61])

In [18]:
tail_to_head = {'cifar10': {7:4, 8:0, 9:0},'cifar100':{78:18, 79:44, 88:50, 89:8, 98:2, 99:61}}
new_tail_to_head = {'cifar10': {8:0},'cifar100':{78:18, 79:44, 88:50, 89:8, 98:2, 99:61}}
t_to_h_list = []
model = models.__dict__['resnet32'](num_classes=10, use_norm=False)
checkpoint = torch.load('checkpoint/cifar10_resnet32_CE_None_exp_0.1_0/ckpt.best.pth.tar', map_location='cuda:0')
model.load_state_dict(checkpoint['state_dict'])
model.eval()
model = model.cuda(0)
for tail in new_tail_to_head['cifar10'].keys():
    train = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
#     tail = 8
    head= new_tail_to_head['cifar10'][tail]
    print(head, tail)
    idx = [i for i, item in enumerate(train.targets) if item==head]
    idx = np.array(idx)
    train = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    train.targets = [train.targets[i] for i in idx]
    train.data = [train.data[i] for i in idx]
    print(len(train.data))
    loader = torch.utils.data.DataLoader(
        train, batch_size=100, shuffle=False,
        num_workers=4, pin_memory=True)
    preds = []

    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            if 0 is not None:
                input = input.cuda(0, non_blocking=True)
            target = target.cuda(0, non_blocking=True)
            # compute output
            output = model(input).cpu().numpy()[:, tail]
            preds.append(output)
    preds = np.array(preds)
    print(preds.shape)
    preds = np.hstack(preds)
    print(preds.shape)
    inds = np.flip(np.argsort(preds))
    arr = preds[inds]
    all_list = np.stack([inds, arr])
    t_to_h_list.append({'head':head, 'tail':tail, 'samples':all_list})
import pickle
with open('./data/new_cifar10_resnet32_CE_None_exp_0.1_0.pickle', 'wb') as f:
    pickle.dump(t_to_h_list, f)


Files already downloaded and verified
0 8
Files already downloaded and verified
1000
(10, 100)
(1000,)
