In [1]:
import torch
import numpy as np
from scipy.spatial import distance
import os

import sys
sys.path.append('..')
import registry
import datafree
%matplotlib inline
%config InlineBackend.figure_format = 'pdf'

In [2]:
!nvidia-smi

Sun Aug 21 02:19:16 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.60.02    Driver Version: 510.60.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:05:00.0 Off |                  N/A |
| 30%   33C    P0    94W / 350W |      0MiB / 24576MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:09:00.0 Off |                  N/A |
| 30%   35C    P0   103W / 350W |      0MiB / 24576MiB |      0%      Defaul

In [3]:
distributed = False
gpu = 5
# gpu ='0,1'
batch_size = 128
workers = 8
# num_classes = 10
# num_classes = 100
num_classes = 200
def prepare_model(model):
    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
        return model
    elif distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if gpu is not None:
#             torch.cuda.set_device(gpu)
            model.cuda()
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            batch_size = int(batch_size / 1)
            workers = int((workers + 1 - 1) / 1)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[int(x) for x in gpu.split(',')])
            return model
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
            return model
    elif gpu is not None:
        torch.cuda.set_device(gpu)
        model = model.cuda(gpu)
        return model
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()
        return model

In [22]:
from torchvision.datasets import CIFAR10,CIFAR100
import datafree
import registry
from torch import nn
student = registry.get_model('resnet18_imagenet', num_classes=num_classes)
teacher = registry.get_model('resnet34_imagenet', num_classes=num_classes, pretrained=True).eval()
# student = registry.get_model('wrn40_1', num_classes=num_classes)
# student= registry.get_model('wrn16_2', num_classes=num_classes)
# teacher = registry.get_model('wrn40_2', num_classes=num_classes)
normalizer = datafree.utils.Normalizer(**registry.NORMALIZE_DICT['tiny_imagenet'])
# normalizer = datafree.utils.Normalizer(**registry.NORMALIZE_DICT['cifar10'])
# normalizer = datafree.utils.Normalizer(**registry.NORMALIZE_DICT['cifar100'])
student = prepare_model(student)

# teacher = teacher.to(gpu)
teacher.avgpool = nn.AdaptiveAvgPool2d(1)
num_ftrs = teacher.fc.in_features
teacher.fc = nn.Linear(num_ftrs, 200)
teacher.conv1 = nn.Conv2d(3,64, kernel_size=(3,3), stride=(1,1), padding=(1,1))
teacher.maxpool = nn.Sequential()
teacher = prepare_model(teacher)
ckpt = torch.load('../checkpoints/scratch/tiny_imagenet_resnet34_imagenet.pth', map_location='cpu')
dict_ckpt = dict()
for k, v in ckpt['state_dict'].items():
    dict_ckpt['.'.join(k.split('.')[1:])] = v
teacher.load_state_dict(dict_ckpt)
print(ckpt['best_acc1'])
# teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar10_resnet34.pth', map_location='cpu')['state_dict'])
# teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar10_wrn40_2.pth', map_location='cpu')['state_dict'])
# teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar100_wrn40_2.pth', map_location='cpu')['state_dict'])
# teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar100_resnet34.pth', map_location='cpu')['state_dict'])
teacher.eval()


61.47


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): Sequential()
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1,

In [5]:
!ls /data/lijingru/DataFree/checkpoints/datafree-dafl/

cifar10-resnet34-resnet18--dafl_interval_a2.pth
cifar10-vgg11-resnet18--dafl_interval_2.pth
log-cifar10-resnet34-resnet18-dafl_interval_a2.txt
log-cifar10-vgg11-resnet18-dafl_interval_2.txt
synthetic-dafl_interval_2.png
synthetic-dafl_interval_a2.png


In [27]:
# generator = datafree.models.generator.LargeGenerator(nz=512, ngf=64, img_size=32, nc=3)
# generator = prepare_model(generator)
# ckpt = torch.load('/data/lijingru/CMI/checkpoints/datafree-cmi/cifar100-resnet34-resnet18.pth', map_location='cpu')
# ckpt = torch.load('../checkpoints/datafree-deepinv/cifar100-resnet34-resnet18--deepinv_100a.pth', map_location='cpu')
# ckpt = torch.load('/data/lijingru/CMI/checkpoints/datafree-dafl/cifar100-resnet34-resnet18.pth', map_location='cpu')
# ckpt = torch.load('../checkpoints/datafree-cmi/cifar10-resnet34-resnet18_adv_cmi.pth', map_location='cpu')
# ckpt = torch.load('../checkpoints/datafree-cudfkd/cifar100-resnet34-resnet18--cudfkd_L2_line33_e.pth', map_location='cpu')
ckpt = torch.load('../checkpoints/datafree-probkd/tiny_imagenet-resnet34_imagenet-resnet18_imagenet--probkd_L3_line7-R0.pth', map_location='cpu')
# ckpt = torch.load('../checkpoints/datafree-cudfkd/cifar10-wrn40_2-wrn16_2--cudfkd_L2_line93_agg_2.pth', map_location='cpu')
# ckpt = torch.load('../checkpoints/datafree-probkd/cifar10-resnet34-resnet18--probkd_L2_line66.pth', map_location='cpu')
# ckpt = torch.load('../checkpoints/datafree-deepinv/cifar10-resnet34-resnet18.pth', map_location='cpu')
# ckpt = torch.load('../checkpoints/scratch_i/cifar10_resnet18.pth', map_location='cpu')
# ckpt = torch.load('../checkpoints/scratch/cifar10_wrn16_2.pth', map_location='cpu')
# ckpt = torch.load('../checkpoints/scratch/cifar10_wrn40_1.pth', map_location='cpu')
# ckpt = torch.load('../checkpoints/scratch/cifar100_resnet18.pth', map_location='cpu')
# ckpt = torch.load('../checkpoints/datafree-cudfkd/cifar100-wrn40_2-wrn40_1--cudfkd_L2_line33_d.pth', map_location='cpu')
print(ckpt['best_acc1'])
dict_ckpt = dict()
for k, v in ckpt['state_dict'].items():
    dict_ckpt['.'.join(k.split('.')[1:])] = v
# G_ckpt = ckpt['G_0']

# student = prepare_model(student)
# student.avgpool = nn.AdaptiveAvgPool2d(1)
# num_ftrs = student.fc.in_features
# student.fc = nn.Linear(num_ftrs, 200)
# student.conv1 = nn.Conv2d(3,64, kernel_size=(3,3), stride=(1,1), padding=(1,1))
# student.maxpool = nn.Sequential()
student.load_state_dict(dict_ckpt)
student.eval()


43.419999999999995


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [24]:
import torchvision.transforms as T
from registry import NORMALIZE_DICT
from torchvision import datasets
# normalizer = datafree.utils.Normalizer(**registry.NORMALIZE_DICT['cifar10'])
# val_transform = T.Compose([
#     #T.Resize((224, 224), Image.BICUBIC),
#     T.ToTensor(),
#     T.Normalize( **NORMALIZE_DICT['tiny_imagenet'] ),
# ])
val_transform = T.Compose([
    T.ToTensor(),
    T.Normalize( **NORMALIZE_DICT['tiny_imagenet'] )]
) 

# val_dst = datasets.CIFAR10('/data/lijingru/cifar10/', train=False, download=True, transform=val_transform)
# val_dst = datasets.CIFAR100('/data/lijingru/cifar100/', train=False, download=True, transform=val_transform)
val_dst = datasets.ImageFolder(os.path.join('/data/lijingru/timagenet/tiny-imagenet-200', 'val_split'), transform=val_transform)

val_loader = torch.utils.data.DataLoader(val_dst, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

In [25]:
import tqdm
probs_s = []
probs_t = []
ys = []
for x, y in tqdm.tqdm(val_loader, desc='calculate_agg'):
    s_out = student(x.to(gpu))
    t_out = teacher(x.to(gpu))
    print(np.argmax(prob_t, 1), y.numpy())
#     print(torch.sum(s_out.argmax(1).detach().cpu() == y) / 128)
    prob_s = torch.softmax(s_out.detach(), 1).cpu().numpy()
    probs_s.append(prob_s)
    prob_t = torch.softmax(t_out.detach(), 1).cpu().numpy()
    probs_t.append(prob_t)
    ys.append(y.numpy())

probs_s = np.concatenate(probs_s, 0)
probs_t = np.concatenate(probs_t, 0)
ys = np.concatenate(ys, 0)

calculate_agg:   3%|▎         | 4/157 [00:00<00:18,  8.49it/s]

[189 187 185 199 131 185 189  38 189 110  38 130 199 176 168  15] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[  0   0   0   0 196   0   0   0   0   0  38  38   0 185  13  44  36   0
   0 176 102 115  36   0   0 151   0   0   0   0   0   0  38   0  38 161
   0  32   0   0   0   0   0   0 102 102  83   0   0  36 102 176 176 113
 141  44 176 176  38 176 176 113 113 176] [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
[ 75  38  44  17  38  38 176 131 154  38  44 174 176 176   1 176 113 176
   8 176 176 113  44 176 176  44 176  39 168 156  64 156 196 176  44 176
  44  44 139  10 122 176 119   2   2  12  84 162 176   3   3  12  84  44
 196   5  92  92  44 185 102   2  39 122] [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3]
[176 152   2 199 19

calculate_agg:   7%|▋         | 11/157 [00:00<00:07, 18.90it/s]

[168 161  88   6 150   6  95  66 122   6 123   6 176   6  59  84 185 168
   6  12   6 199 111  66 164 102  16  12 106  16  88   7  17  74 156  10
  53 196   7 144 101  10  44   8  13  41  39   8   3  41   7 122 112 196
 139 132  41  41   7 176 156  76  88 102] [7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[  7   6  38 115  75   1 176   7 149  13 176 134  13 161  44 122   9  38
  44 168 156  17   8 122   8 155   8 156 168  96 176   8   8   8   8   8
 151  38  38 122  36 168 101 168 111   8   8  44 122   8 159   8   8 174
   8 135   8   8   8 168   8   8 168   8] [ 8  8  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9
  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9
  9  9  9  9 10 10 10 10 10 10 10 10 10 10 10 10]
[  8  64   9 156   9  41 130 176  37 122 149 176 156   9  41 159   3 176
 102   8 176 119 176  10 176 176 162  75  38 131  75  29  36  75 176  41
 122 

calculate_agg:  12%|█▏        | 19/157 [00:01<00:05, 24.86it/s]

[ 15 176 122  41 196 176   7  88 131  11 187 127  59 178  15  75  15 135
  36  41 199 185  37  41  36  41  44  15 160 176  17  39  84   3  39 122
 159  16 176 113  75  38 176 196  76 176  17 156 185 176  58 176 127  36
  38 176 188 176  41 176  99 156 122 180] [16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 17 17 17 17 17 17
 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17
 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17]
[114  16  38 113  16  10  76  44  10 156  41 199   2 185  36  38  10  13
 189  83  17 176  13  14  41 176 163  17  26 196  41  17  44  17 176 176
 102 176  93  13 102  42  17  93 168 113  17 113  17  17  10 130 188  17
 196  17 113 113  38  16  17  17 176 185] [17 17 17 17 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18
 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18
 18 18 18 18 18 18 19 19 19 19 19 19 19 19 19 19]
[ 17  36 196 196 176  75  18 176 112 156 176 176 176 131 189 131 176 168
  75 176 113 

calculate_agg:  15%|█▍        | 23/157 [00:01<00:05, 26.54it/s]

[ 26 148 168  26 176 129  24  32  24 176  24  75 184 156  74  99  74  26
  75  99  24  29 156  75 156 176 131 156  24  24  27  75 156 102 156 131
  25  25  41  25  17  75  25 156 102 168  25 112  25 115 156 102  38  25
  25  25  41 102  30  25 161 102  25 199] [25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 26 26 26 26
 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26
 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26]
[196  38  38  38  25 168  36 102 176 176 102 156 156  75 168 131 176 112
  25 102  26  26  26  26  24  26  27  26  26 148  38  26  12  26 127 159
  26  26  26  26  26  26  26 161  26  26  75 108  23  97  26  35  26  26
  26  26  26  26  26 150  26 102  26 156] [26 26 26 26 26 26 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27
 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27
 27 27 27 27 27 27 27 27 28 28 28 28 28 28 28 28]
[ 26  27  12  26  27 102  26  27  27 106  26  75  51  26 176  25  47  75
  27  36 131 

calculate_agg:  20%|█▉        | 31/157 [00:01<00:04, 28.51it/s]

[176  33  24 176  26 176  91 161 161  33 176 161  33  76 136  33  94 102
  12 156  12 196  58  12 156  49 102  33  13  23  12 122  33 188  12 102
  33  38  47  34  12 185  76  26   1  35  34 131   3 145  96  93 102  34
 176 105  54  34  38  38  58   7 185 106] [34 34 34 34 34 34 34 34 34 34 34 34 34 34 34 34 34 34 34 34 34 34 35 35
 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35
 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35]
[102  39  34  35  34  26   1  34  33 122   3  34 129 176 122  99  35 106
  26  25 176 102  35  35  26 113 161 105 102  26  35  56  75 161  26  26
  35 102  35 102 139 102  35 161  54 102 102  26  35 176  35 199 102  35
  75 168 102  23 105  75 168 102  54  35] [35 35 35 35 35 35 35 35 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
 36 36 36 36 36 36 36 36 36 36 37 37 37 37 37 37]
[127  26 161  35  75  58  35  35  36  36  36  36  36  36  36  36  36 187
 113  44  36 

calculate_agg:  25%|██▍       | 39/157 [00:01<00:03, 29.51it/s]

[  1  42  42 188  39 135 145 142  75 113 176  42  42  39  40 163  37  36
  39  42  42  42  39  39  40  42 152  38  21  43  42  39  42 176 175  92
  99  42  99  36  43 113 176  44  42  43 122  44  44 152  41 176  43 176
  43 113  39  44 168 112 189 168 185  36] [43 43 43 43 43 43 43 43 43 43 43 43 43 43 43 43 43 43 43 43 43 43 43 43
 44 44 44 44 44 44 44 44 44 44 44 44 44 44 44 44 44 44 44 44 44 44 44 44
 44 44 44 44 44 44 44 44 44 44 44 44 44 44 44 44]
[ 43  39  38  44  43  36 176   8  44  44  43 113 188  41 163  95  40  36
  44  43 122  43  44 156  44  44  44 176  44  44  44  44  44  44  44 176
  44  44 113  44  44  44 176  44  44  44  44  44 176  44  44  44  44  44
  44  44  44  44 176 152 176  44  44  44] [44 44 44 44 44 44 44 44 44 44 45 45 45 45 45 45 45 45 45 45 45 45 45 45
 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45
 45 45 45 45 45 45 45 45 45 45 45 45 46 46 46 46]
[ 44  44 113  44  44  44 196  44  44 196 176  38  38 176 176  45  38  38
  45  38 168 

calculate_agg:  30%|██▉       | 47/157 [00:01<00:03, 29.98it/s]

[131  49 131 176  51 176  33  66 106 176  66 102  48  51 106  34  51  13
 194 105 176 102   7  49  51 168  91 106 176  33   4 122 194  51   2 102
  51  51 161 176  52  66  52  24  73 122 122  24 131 156  52 102  52  52
 102 114  52  52  52  52  89 153 185  12] [52 52 52 52 52 52 52 52 52 52 52 52 52 52 52 52 52 52 52 52 52 52 52 52
 52 52 53 53 53 53 53 53 53 53 53 53 53 53 53 53 53 53 53 53 53 53 53 53
 53 53 53 53 53 53 53 53 53 53 53 53 53 53 53 53]
[ 52 102  52 185 102  54 131  39  24  52  52 176  52 102  52  52  41 108
 122  52  16  20  74  52  76  52 102 194  52  24  75 131  53 162  53  53
 168  53 176  53 176  53 168  53 102  27  75  20 173  56 156  50 162  48
 168 102 162 122  53 176 185  53  99 113] [53 53 53 53 53 53 53 53 53 53 53 53 54 54 54 54 54 54 54 54 54 54 54 54
 54 54 54 54 54 54 54 54 54 54 54 54 54 54 54 54 54 54 54 54 54 54 54 54
 54 54 54 54 54 54 54 54 54 54 54 54 54 54 55 55]
[176  53 176 153  75  56  66  53 168  59  53  53 156 156  26 176  35 176
 105  54  54 

calculate_agg:  32%|███▏      | 51/157 [00:02<00:03, 30.12it/s]

[139  60 131 131 143 122 101 129  13  60  60 176  60 106 106 131  38 131
  60  60 176 131  64 131 113 168 176  60  75  60  60  99  76 176  60  60
  60 176 156  60 176  60 113 176 176 113  61  76  61 113 176 113 129  61
 176 171 113 135  44 176 176 176 113 146] [61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61
 61 61 61 61 62 62 62 62 62 62 62 62 62 62 62 62 62 62 62 62 62 62 62 62
 62 62 62 62 62 62 62 62 62 62 62 62 62 62 62 62]
[ 68 176 176 130 176  65 126 156 102 176  61 176 154 176  93 113 131 176
 113  61  98  92 176 102 176  61 176 113 135  26 176 176 135 101  62 131
  80 168 113 156 176 176  62  75 176  62 151  75  73  59  80 135  44 131
  90 151 176 176  62 176 176 176 149 135] [62 62 62 62 62 62 62 62 62 62 62 62 62 62 63 63 63 63 63 63 63 63 63 63
 63 63 63 63 63 63 63 63 63 63 63 63 63 63 63 63 63 63 63 63 63 63 63 63
 63 63 63 63 63 63 63 63 63 63 63 63 63 63 63 63]
[176  62 158  59  74 112 176 161 176 176 113  60  73 176 168 176 168 147
  63  44  63 

calculate_agg:  38%|███▊      | 59/157 [00:02<00:03, 30.29it/s]

[122 140 122 131  59  98  81 156 110  69  92  69 160 122 156 176  59 131
 146  75 142  98  69  86  84 134 107 188 122  36  72  75 161  69 122 160
 156 169 156  61  67  62 159  88 153 176 156 176 176 111 156 176  75  70
 116 153 133 176 156 156 174 176 196 116] [70 70 70 70 70 70 70 70 70 70 70 70 70 70 70 70 70 70 70 70 70 70 70 70
 70 70 70 70 70 70 71 71 71 71 71 71 71 71 71 71 71 71 71 71 71 71 71 71
 71 71 71 71 71 71 71 71 71 71 71 71 71 71 71 71]
[153 116 153 176 176 111  17  84 176 153 153 156  94  88 153 133 101 133
  84 176 133 146  70 176  64 116 116  80 153 133  71  71  76  71  71  71
  71  71  71  66  71 196 102  71  71 132 196  64 115 198 164 197  71  71
  71  71  71  71  71  71  71  71  71  64] [71 71 71 71 71 71 71 71 71 71 71 71 71 71 71 71 72 72 72 72 72 72 72 72
 72 72 72 72 72 72 72 72 72 72 72 72 72 72 72 72 72 72 72 72 72 72 72 72
 72 72 72 72 72 72 72 72 72 72 72 72 72 72 72 72]
[ 64 137 197 139  71  71  71  71  71 196  71 102  71 102  71  13  72  13
 102 176  72 

calculate_agg:  43%|████▎     | 67/157 [00:02<00:02, 30.39it/s]

[176 154  84  92 146 146  85  78 103 176  78 131 176 169 114 146 121 154
 146  78 152 146  93 146  66  78 122  78  78  78 176  78 156 167 176  85
 122 154 146 110 146  78 156  78  73  67 176 159 101  79 122 133  95 101
  38 176 176 101  79 122 139 196 101 176] [79 79 79 79 79 79 79 79 79 79 79 79 79 79 79 79 79 79 79 79 79 79 79 79
 79 79 79 79 79 79 79 79 80 80 80 80 80 80 80 80 80 80 80 80 80 80 80 80
 80 80 80 80 80 80 80 80 80 80 80 80 80 80 80 80]
[176 168  69 135 134 168 168  95 110 168 102  59 145  76  79 111 131  77
  79 122  64  79 176 101  92  79  97 156 122 113 134  79  80  80 136 161
 172  74 131  64 138 111 151 172 106 131 101 135  71 176  72  80 130 115
  11  80 168   7  80 125 122 139 168 155] [80 80 80 80 80 80 80 80 80 80 80 80 80 80 80 80 80 80 81 81 81 81 81 81
 81 81 81 81 81 81 81 81 81 81 81 81 81 81 81 81 81 81 81 81 81 81 81 81
 81 81 81 81 81 81 81 81 81 81 81 81 81 81 81 81]
[ 80 106 176  83 156 150 176  80  45  36 130 151  80  72 113 156 113  38
 103 176  81 

calculate_agg:  48%|████▊     | 75/157 [00:02<00:02, 30.44it/s]

[146 159 176  38 156  81 176 145 101 160 110 117  13  75  64 123 133 176
 111 156 146 155 151  75  75  87  87  87 145 168 111 131 176 176  87 176
 159  93 116  75 176  87 125  13 146  75 113 153 101 108  75 154 125  99
  64  44 131  64  88  38  88 176  17 127] [88 88 88 88 88 88 88 88 88 88 88 88 88 88 88 88 88 88 88 88 88 88 88 88
 88 88 88 88 88 88 88 88 88 88 89 89 89 89 89 89 89 89 89 89 89 89 89 89
 89 89 89 89 89 89 89 89 89 89 89 89 89 89 89 89]
[ 99 156 168 125  64  75 154 176  88 196 135 176  88  88  82 130 131  99
 102 168 152 160 113 131 156  33 102 113 122  88  95  64  75 108  80 145
  66 184 151  44  64  92 164 154 135 113 146 122 156 131  13 168 176  89
 176 122 135  80  92 179 142 151  89 122] [89 89 89 89 89 89 89 89 89 89 89 89 89 89 89 89 89 89 89 89 90 90 90 90
 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90
 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90]
[ 89  20 174 186 102 115 128 130  89 115 156 176 174 113 123  44  67  89
 152 176  74 

calculate_agg:  50%|█████     | 79/157 [00:03<00:02, 30.45it/s]

[162 154 108 154  96 197 176 197 154 101 197 154  59  96 115 168 146 154
 176 156  96 103  96 154  81  96  33 197  96 149 154 162  69  75  88 102
 179  96 176 196 101 154 102 146 146 156  96 108 154  96 113 176 168  65
 156  98 131 176 131 131 176  92 133 131] [97 97 97 97 97 97 97 97 97 97 97 97 97 97 97 97 97 97 97 97 97 97 97 97
 97 97 97 97 97 97 97 97 97 97 97 97 98 98 98 98 98 98 98 98 98 98 98 98
 98 98 98 98 98 98 98 98 98 98 98 98 98 98 98 98]
[176 102 151 176 122  97 176  97 151 156  97  65 136 176  80 176  95 176
  97 146 174  97 168 188 122 147 128 131  64 151  64  99 176  97  93 176
 131 176 168  76  64 176  83 131 145  98  75 173  64 159  84  98 188  18
 142  68  92  98 176 176 128 176 113  98] [98 98 98 98 98 98 98 98 98 98 98 98 98 98 98 98 98 98 98 98 98 98 99 99
 99 99 99 99 99 99 99 99 99 99 99 99 99 99 99 99 99 99 99 99 99 99 99 99
 99 99 99 99 99 99 99 99 99 99 99 99 99 99 99 99]
[ 72 131 176  98 176  64  98 142 176  98  64 111 113 129 115 176  98  98
 137  64  98 

calculate_agg:  55%|█████▌    | 87/157 [00:03<00:02, 30.47it/s]

[ 36 113 176  44 105 105 129 156 135 176 140  77 102 116 105 106 112 156
 176  99 131 112 106 102 102 102  75  76 106 120 161 176  75 156 105 112
 116 176 176 127 113 173  82 105 127 112  53 105  75 173 176 105 156 168
 110 132  92 111 131 106 156  75 113 102] [106 106 106 106 106 106 106 106 106 106 106 106 106 106 106 106 106 106
 106 106 106 106 106 106 106 106 106 106 106 106 106 106 106 106 106 106
 106 106 107 107 107 107 107 107 107 107 107 107 107 107 107 107 107 107
 107 107 107 107 107 107 107 107 107 107]
[106 156 111 131 106 150  77 176 106 109 150 176 151 176 106  88 113 131
 106 106 123 176 105 106 176 176 176  75  64 131  64 176  13 122 180 176
  87 168 176 176 176 176 176 111 176 176 176 153 107 176  17 131  64 153
 176 130 114 176 176 146 196 107 107 131] [107 107 107 107 107 107 107 107 107 107 107 107 107 107 107 107 107 107
 107 107 107 107 107 107 108 108 108 108 108 108 108 108 108 108 108 108
 108 108 108 108 108 108 108 108 108 108 108 108 108 108 108 108 108 10

calculate_agg:  61%|██████    | 95/157 [00:03<00:02, 30.49it/s]

[113 160 151 127 114 114 168 176 185 131 114 176 176 114 156 114  13 114
  75  69 176  75 114 114  80 176 153 176 111 114 176 102 146 114 114 176
  80 174 111 176  99  68 114 168 131 168 196  62 168 114 132 114 117  88
 176 115 195 115 115 115 137 115 115  76] [115 115 115 115 115 115 115 115 115 115 115 115 115 115 115 115 115 115
 115 115 115 115 115 115 115 115 115 115 115 115 115 115 115 115 115 115
 115 115 115 115 116 116 116 116 116 116 116 116 116 116 116 116 116 116
 116 116 116 116 116 116 116 116 116 116]
[115 108 176 115 115 115 176 176  10 115 115 115 115 115 115 125 115 115
 115 115 115 115  74  13 115 198 145 115 196 115 115 115 115 115 115 176
 176 115 115 115 116 176 155 116  64  64 116 176  13 131 168 116 153 176
  64  80  26 102 146 176 131 176 153  13] [116 116 116 116 116 116 116 116 116 116 116 116 116 116 116 116 116 116
 116 116 116 116 116 116 116 116 117 117 117 117 117 117 117 117 117 117
 117 117 117 117 117 117 117 117 117 117 117 117 117 117 117 117 117 11

calculate_agg:  66%|██████▌   | 103/157 [00:03<00:01, 30.46it/s]

[139 125 114 176 122 122 123 123  65 176 127 156 176 123  60 156 176  44
 123  65  75 131 112  75 106 156 168 131 123 176 131 123 131 119 106 112
  87 168 176  74 172 138  60 159 123 123 123 155 156 116 123  77 168 131
  77 156 137  64 124 122 124 101 124 132] [124 124 124 124 124 124 124 124 124 124 124 124 124 124 124 124 124 124
 124 124 124 124 124 124 124 124 124 124 124 124 124 124 124 124 124 124
 124 124 124 124 124 124 125 125 125 125 125 125 125 125 125 125 125 125
 125 125 125 125 125 125 125 125 125 125]
[163  71 124 124 101 132 132 124 124 124 132 102 124 124 115 124 124 170
 124 154  65 124 113 122 122 124 138 124 124  13 146 176 124 132 124 124
 124 102 124 137  84 102 110  99  75  95 156  95 131 176 102 168  13  95
 125  88  99  99 122 176  72 186 125 125] [125 125 125 125 125 125 125 125 125 125 125 125 125 125 125 125 125 125
 125 125 125 125 125 125 125 125 125 125 126 126 126 126 126 126 126 126
 126 126 126 126 126 126 126 126 126 126 126 126 126 126 126 126 126 12

calculate_agg:  68%|██████▊   | 107/157 [00:03<00:01, 30.45it/s]

[ 75 147 113  22  92 164 171  83 131 146  84 106 176 131 132 108 152 132
 176  74 168  64  95  86  95  76  13  95  71 171 163 132 132  72 122  95
 156 146 196 168 126 132 113 135 176  92 102  67 168  75 156 102 101 137
 122  73 176 132 133 176 153 133 176 121] [133 133 133 133 133 133 133 133 133 133 133 133 133 133 133 133 133 133
 133 133 133 133 133 133 133 133 133 133 133 133 133 133 133 133 133 133
 133 133 133 133 133 133 133 133 134 134 134 134 134 134 134 134 134 134
 134 134 134 134 134 134 134 134 134 134]
[156 176 116 133 131 133 133  92 176 127 176 153 133 176 133 121 133  13
 176 155 176 133 133 176 133 176  92 153 133 164 133 131 176 176 135  64
 133 131 145 176 133 176 145 153  13 134  80  99 160 173 176  85  38 188
 102  76 176 168 134 112 134 101 134 134] [134 134 134 134 134 134 134 134 134 134 134 134 134 134 134 134 134 134
 134 134 134 134 134 134 134 134 134 134 134 134 135 135 135 135 135 135
 135 135 135 135 135 135 135 135 135 135 135 135 135 135 135 135 135 13

calculate_agg:  73%|███████▎  | 115/157 [00:04<00:01, 30.40it/s]

[123 156 140 174  72 140 140  93  72  75 131  87 141 122 155 135 188 141
 141 113 176 153  72 123 110  92 154 156 113 141 141 131  92 146 133 141
 110 122  81 151 141 123 110 117  12  69 127 160  92 159 122 110 102 176
  72 186 110 130 141  69 122 164 113 142] [142 142 142 142 142 142 142 142 142 142 142 142 142 142 142 142 142 142
 142 142 142 142 142 142 142 142 142 142 142 142 142 142 142 142 142 142
 142 142 142 142 142 142 142 142 142 142 143 143 143 143 143 143 143 143
 143 143 143 143 143 143 143 143 143 143]
[131  72 142 176  80 176 131 176 132 156  44 176 176 113 185 142 112 151
  64 111 142 131  64 176 176 142 176 142  64 142  64 176 126  87 131 142
 145  44 135 176 176 106 111 127 144 117 131 143 176 143 101 131  36  75
 143 176 131 143 143 168 114 143 143 131] [143 143 143 143 143 143 143 143 143 143 143 143 143 143 143 143 143 143
 143 143 143 143 143 143 143 143 143 143 143 143 143 143 144 144 144 144
 144 144 144 144 144 144 144 144 144 144 144 144 144 144 144 144 144 14

calculate_agg:  78%|███████▊  | 123/157 [00:04<00:01, 30.43it/s]

[150 176  17  75 149 113 135 149 130  75 149 196 150  76 176 150 176  75
 150 189 168  99 150 176 153  65 150 176 156 113 150 150 176 176 176 176
 113 150 112 150 150 168 168 113 119 150 168 113  76 176  45 131 113  41
 176 131 150 150 156 176 168 112  97  75] [151 151 151 151 151 151 151 151 151 151 151 151 151 151 151 151 151 151
 151 151 151 151 151 151 151 151 151 151 151 151 151 151 151 151 151 151
 151 151 151 151 151 151 151 151 151 151 151 151 152 152 152 152 152 152
 152 152 152 152 152 152 152 152 152 152]
[174 151 176 151 151 151 151 151 176  93  36 120 151  59 113 122  13 130
 130 130 151 114 151 151 151  98 151 142 150 176  86  80 151 151 168 151
 151 130 117 129 167 142 151 110  64 151 122 151 196 152  38 196 152  75
 152 152   4 102 152  12 152 152  73 176] [152 152 152 152 152 152 152 152 152 152 152 152 152 152 152 152 152 152
 152 152 152 152 152 152 152 152 152 152 152 152 152 152 152 152 153 153
 153 153 153 153 153 153 153 153 153 153 153 153 153 153 153 153 153 15

calculate_agg:  83%|████████▎ | 131/157 [00:04<00:00, 30.40it/s]

[168 158 158 168  13 120 176  74 119 168  75  73  64 186 110 125  59 186
 186 125  13 129  64  40 150 196  72 130  95 113 159 101 106 176 176 176
 115  80  79 119 159 159 132  99 156 132 131 148 150  92 193 168 159 156
 159 155 176 163  72  36   8 130  92  92] [160 160 160 160 160 160 160 160 160 160 160 160 160 160 160 160 160 160
 160 160 160 160 160 160 160 160 160 160 160 160 160 160 160 160 160 160
 160 160 160 160 160 160 160 160 160 160 160 160 160 160 161 161 161 161
 161 161 161 161 161 161 161 161 161 161]
[ 44 131  83 113  75 101 160  13  80  64  87 160 131  75   2 160 110 160
 160 161 160 160 132 160 113   2 131 156 186  44 131 110 160 176 122 176
  88 161 172 156 102  44 160  95 160 160 131   1 160 168 161 161 131 113
 161 161  38 189 161 176 176 131 161 161] [161 161 161 161 161 161 161 161 161 161 161 161 161 161 161 161 161 161
 161 161 161 161 161 161 161 161 161 161 161 161 161 161 161 161 161 161
 162 162 162 162 162 162 162 162 162 162 162 162 162 162 162 162 162 16

calculate_agg:  86%|████████▌ | 135/157 [00:04<00:00, 30.42it/s]

[128 167 167 133  92 139 129  99 176 171 167 128 167  67  75 167 176 168
 111 168 168 176  75  75 176 168 113  75 176 111 176 161 156 176  64 176
 168 127  93  38 131 168 111 131 176 168 176 101 168 119 176  13 156 130
 168 168 156  64 114 168 176 168  76 168] [168 168 169 169 169 169 169 169 169 169 169 169 169 169 169 169 169 169
 169 169 169 169 169 169 169 169 169 169 169 169 169 169 169 169 169 169
 169 169 169 169 169 169 169 169 169 169 169 169 169 169 169 169 170 170
 170 170 170 170 170 170 170 170 170 170]
[168 176 176 126  38  64 131 131 169  60 176  64 134 131 176 129 113 126
 134  59 156 113 176 176 176 169 102 176 169 131  83 106 176 168  60 176
 169  44 113 112 176  72 113 168 129 176 105 112 176 131 176 176 129 113
 152 129 168 102 170 195 170 170 154 170] [170 170 170 170 170 170 170 170 170 170 170 170 170 170 170 170 170 170
 170 170 170 170 170 170 170 170 170 170 170 170 170 170 170 170 170 170
 170 170 171 171 171 171 171 171 171 171 171 171 171 171 171 171 171 17

calculate_agg:  91%|█████████ | 143/157 [00:05<00:00, 30.41it/s]

[176 131 176 176 176 176 176 176 176 176 176 176 176 168 176 176 176 176
 168 186  75 177  38  41 186 177 102 102 174 139 113 122  83 177  41 177
 122  36 176 131 102 176 168  38  38 176 179 177 177  18 131  36  36 131
 189  38 113 177 176 193 152  83 183 102] [177 177 177 177 178 178 178 178 178 178 178 178 178 178 178 178 178 178
 178 178 178 178 178 178 178 178 178 178 178 178 178 178 178 178 178 178
 178 178 178 178 178 178 178 178 178 178 178 178 178 178 178 178 178 178
 179 179 179 179 179 179 179 179 179 179]
[178 177 176 174 102 168  38 176  36 176 178 178  38 178  38 102 152  18
  38 179 178 102 178  38 177 102 178 178 178  38 102 176  43 113 185 177
  38  38 183 112 178 178  36  36  93 179 178  38 177 102 188 174  38 176
 130 117 182 179 148  83 151  38  13  93] [179 179 179 179 179 179 179 179 179 179 179 179 179 179 179 179 179 179
 179 179 179 179 179 179 179 179 179 179 179 179 179 179 179 179 179 179
 179 179 179 179 180 180 180 180 180 180 180 180 180 180 180 180 180 18

calculate_agg:  96%|█████████▌| 151/157 [00:05<00:00, 30.48it/s]

[131 185 185  16 185 185 113  67  75 185 185 185 131 185 188 185 113 102
 185 130  38  75  38  38  36 113  38 188 176 187  38 179 186 186  38 186
 186  13  38 180  38 101  38  38 151 188 186 130 113  38  38  38 186 196
 186  38  38 168 102 113 186 187 186 186] [186 186 186 186 186 186 187 187 187 187 187 187 187 187 187 187 187 187
 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187
 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187 187
 187 187 188 188 188 188 188 188 188 188]
[196  38 188  38   6 187  92 186 185  38 113 130 187  38 187  92  38  38
 196 187 187 187  38  83 187 184 187  73 187 187 187 186 187 196 187 187
 187 187 185 196 186  38 187 130 102 186  38 186 187 113 187  38 187 156
 102 188  59 188 186 163 156 188 188 188] [188 188 188 188 188 188 188 188 188 188 188 188 188 188 188 188 188 188
 188 188 188 188 188 188 188 188 188 188 188 188 188 188 188 188 188 188
 188 188 188 188 188 188 189 189 189 189 189 189 189 189 189 189 189 18

calculate_agg:  99%|█████████▊| 155/157 [00:05<00:00, 30.41it/s]

[194 194 194 176 102 194 168 195  75  66 168 194 101  76 101 196 129 194
 162  22 148 122 132 195  60 197 195 102 197  41 162  96  44 126 195  29
 176  12  66 198  79 195 195 195 197 196  31  12 162 119  66 129  75 194
  66  96 194 102  64  66 195  12 176  38] [195 195 195 195 195 195 195 195 196 196 196 196 196 196 196 196 196 196
 196 196 196 196 196 196 196 196 196 196 196 196 196 196 196 196 196 196
 196 196 196 196 196 196 196 196 196 196 196 196 196 196 196 196 196 196
 196 196 196 196 197 197 197 197 197 197]
[195  95  12 195 195 176 195 194 196  23 196 196 176 196 196 196 176 196
 176 196 196 196 196 183  46 196 196 176 196  39 196 196 196 183 196 168
  14 196 176 196  14 196 196  13  72  13 196 196 196  38 196 196  46 196
  13 113  46  14 197 176 197 102 198 122] [197 197 197 197 197 197 197 197 197 197 197 197 197 197 197 197 197 197
 197 197 197 197 197 197 197 197 197 197 197 197 197 197 197 197 197 197
 197 197 197 197 197 197 197 197 198 198 198 198 198 198 198 198 198 19

calculate_agg: 100%|██████████| 157/157 [00:05<00:00, 27.28it/s]


In [26]:
if not os.path.exists('prob_loyalty_distribution/'):
    os.mkdir('prob_loyalty_distribution/')
# print(y)
print(np.sum(np.argmax(probs_s, 1) == ys) / len(val_dst))
print(np.sum(np.argmax(probs_t, 1) == ys) / len(val_dst))
print(np.sum(np.argmax(probs_t, 1) == np.argmax(probs_s, 1)) / len(val_dst))
dist = distance.jensenshannon(probs_s.T, probs_t.T)
dist[np.isnan(dist)] = 0 if np.sum(np.abs(probs_s[np.where(np.isnan(dist))] - probs_t[np.where(np.isnan(dist))])) < 1e-6 else 1
prob_loyalty = 1 - np.sqrt(dist)
print(np.nanmean(prob_loyalty))
# print(dist[1773])

0.1192
0.2485
0.291
0.25757793


In [11]:
import seaborn as sns
import matplotlib.pyplot as plt
fig = sns.displot(prob_loyalty, kind='hist', color=sns.xkcd_rgb["pale red"])
# fig.set_xlim(0, 1)
fig.set_xlabels('probability loyalty', fontsize=14)
fig.set_ylabels('Number of Images', fontsize=14)
# plt.savefig('prob_loyalty_distribution/prob_loyalty_deepinv.pdf')
# x_ticks = np.arange(0, 1.01, 0.2)
# y_ticks = np.arange(0, 700, 100)
# fig.set_xticklabels(x_ticks, fontsize=12)
# fig.set_yticklabels(y_ticks, fontsize=12)


<seaborn.axisgrid.FacetGrid at 0x7ff46350a400>

<Figure size 360x360 with 1 Axes>

### Other visualization Results
The items for further visualization:
- Difficulty visualization at different timestage.
- Ablation study of $k_{begin}$ and $k_{end}$.
- Ablation study of $\alpha_{adv}$.
- Error bar of CuDFKD, CMI, ADI and DAFL.

In [26]:
x = np.array(['1/5', '1/4', '1/3'])
y = np.array(['2/3', '3/4', '4/5'])
res2res_acc = np.array([[71.00, 71.22, 71.12], [70.81, 70.79, 70.68], [70.41, 70.52, 70.54]])
res2res_agr = np.array([[85.88, 85.85, 85.64], [85.40, 85.14, 85.58], [83.77, 83.88, 83.67]])
res2res_loyalty = np.array([[0.6659, 0.6777, 0.6751], [0.6684, 0.6624, 0.6682], [0.6469, 0.6500, 0.6467]])

In [14]:
if not os.path.exists('heatmaps/'):
    os.mkdir('heatmaps/')
fig1 = sns.heatmap(res2res_acc, xticklabels=x, yticklabels=y, annot=True, fmt='.2f', cmap='RdBu_r')
plt.savefig('heatmaps/vgg_acc.pdf')

<Figure size 432x288 with 2 Axes>

In [15]:
fig2 = sns.heatmap(res2res_agr, xticklabels=x, yticklabels=y, annot=True, fmt='.2f', cmap='RdBu_r')
plt.savefig('heatmaps/vgg_agg.pdf')

<Figure size 432x288 with 2 Axes>

In [16]:
fig3 = sns.heatmap(res2res_loyalty, xticklabels=x, yticklabels=y, annot=True, fmt='.4f', cmap='RdBu_r')
plt.savefig('heatmaps/vgg_loyalty.pdf')

<Figure size 432x288 with 2 Axes>

In [17]:
vgg2res_acc = np.array([[64.74, 64.65, 64.77], [64.66, 64.66, 64.10], [63.95, 64.00, 63.81]])
vgg2res_agr = np.array([[68.91, 68.88, 68.93], [68.80, 68.85, 68.54], [68.58, 68.19, 68.31]])
vgg2res_loyalty = np.array([[0.5067, 0.5058, 0.5072], [0.5059, 0.5054, 0.5069], [0.5001, 0.4961, 0.5021]])

In [18]:
sns.heatmap(vgg2res_acc, xticklabels=x, yticklabels=y, annot=True, fmt='.2f', cmap='RdBu_r')
plt.savefig('heatmaps/wrn_acc.pdf')

<Figure size 432x288 with 2 Axes>

In [19]:
sns.heatmap(vgg2res_agr, xticklabels=x, yticklabels=y, annot=True, fmt='.2f', cmap='RdBu_r')
plt.savefig('heatmaps/wrn_agg.pdf')

<Figure size 432x288 with 2 Axes>

In [20]:
sns.heatmap(vgg2res_loyalty, xticklabels=x, yticklabels=y, annot=True, fmt='.4f', cmap='RdBu_r')
plt.savefig('heatmaps/wrn_loyalty.pdf')

<Figure size 432x288 with 2 Axes>

In [12]:
ckpt = torch.load('../checkpoints/datafree-cudfkd/cifar100-wrn40_2-wrn40_1--cudfkd_L2_line33_d.pth', map_location='cpu')
student.load_state_dict(ckpt['state_dict'])
student.eval()
g = datafree.models.generator.DCGAN_Generator_CIFAR10(nz=512, ngf=64, nc=3, img_size=32, d=2, cond=False)
g = prepare_model(g)
g.load_state_dict(ckpt['G_0'])
g.eval()

DCGAN_Generator_CIFAR10(
  (project): Sequential(
    (0): Flatten()
    (1): Linear(in_features=512, out_features=16384, bias=True)
  )
  (main): Sequential(
    (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Upsample(scale_factor=2.0, mode=nearest)
    (2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Upsample(scale_factor=2.0, mode=nearest)
    (6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): Sigmoid()
  )
)

In [13]:
from torchvision.utils import make_grid,save_image
diff = []
now_epoch = 300
kl_metric = datafree.criterions.KLDiv(T=2., reduction='none')
with torch.no_grad():
    for i in range(20):
        z = torch.randn(batch_size, 512).to(gpu)
        x = g(z)
        t_out = teacher(normalizer(x))
        s_out = student(normalizer(x))
        kl_diff = kl_metric(t_out, s_out).sum(1).detach().cpu().numpy()
        diff.append(kl_diff)
    
diff = np.concatenate(diff)
print(diff)
with torch.no_grad():
    z = torch.randn(1024, 512).to(gpu)
    img = make_grid(x, nrow=16)
    t_out = teacher(normalizer(x))
    pesudo_label = t_out.argmax(1).cpu().numpy()
    print(pesudo_label)

#     save_image(img, 'difficulty/epoch_{}.jpg'.format(now_epoch))

[1.6653383  2.05072    1.3142852  ... 1.1158123  0.25667655 2.7491198 ]
[41 44 80 ... 86 96 67]


In [19]:
xs = []
for i in range(num_classes):
#     print(i)
    xs.append(x[pesudo_label == i].cpu())
    print(i, xs[i].shape)
    img = make_grid(xs[i], nrow=1)
    save_image(img, 'difficulty/img_epoch_{}_{}.jpg'.format(now_epoch, i))

0 torch.Size([8, 3, 32, 32])
1 torch.Size([13, 3, 32, 32])
2 torch.Size([15, 3, 32, 32])
3 torch.Size([11, 3, 32, 32])
4 torch.Size([9, 3, 32, 32])
5 torch.Size([11, 3, 32, 32])
6 torch.Size([12, 3, 32, 32])
7 torch.Size([9, 3, 32, 32])
8 torch.Size([11, 3, 32, 32])
9 torch.Size([5, 3, 32, 32])
10 torch.Size([12, 3, 32, 32])
11 torch.Size([6, 3, 32, 32])
12 torch.Size([8, 3, 32, 32])
13 torch.Size([11, 3, 32, 32])
14 torch.Size([11, 3, 32, 32])
15 torch.Size([10, 3, 32, 32])
16 torch.Size([9, 3, 32, 32])
17 torch.Size([8, 3, 32, 32])
18 torch.Size([9, 3, 32, 32])
19 torch.Size([6, 3, 32, 32])
20 torch.Size([9, 3, 32, 32])
21 torch.Size([9, 3, 32, 32])
22 torch.Size([13, 3, 32, 32])
23 torch.Size([8, 3, 32, 32])
24 torch.Size([4, 3, 32, 32])
25 torch.Size([6, 3, 32, 32])
26 torch.Size([18, 3, 32, 32])
27 torch.Size([10, 3, 32, 32])
28 torch.Size([11, 3, 32, 32])
29 torch.Size([10, 3, 32, 32])
30 torch.Size([12, 3, 32, 32])
31 torch.Size([11, 3, 32, 32])
32 torch.Size([6, 3, 32, 32])
33 

In [21]:
sns.displot(diff)

<seaborn.axisgrid.FacetGrid at 0x7fbce51f3d30>

<Figure size 360x360 with 1 Axes>

In [21]:
wrn2wrn_acc = np.array([[75.24, 75.16, 75.31], [75.11, 75.21, 75.22], [75.01, 74.89, 75.01]])
wrn2wrn_agr = np.array([[85.26, 86.06, 86.11], [85.40, 85.14, 85.58], [85.38, 85.53, 85.57]])
wrn2wrn_loyalty = np.array([[0.6472, 0.6476, 0.6476], [0.6433, 0.6437, 0.6447], [0.6393, 0.6403, 0.6407]])


In [22]:
sns.heatmap(wrn2wrn_acc, xticklabels=x, yticklabels=y, annot=True, fmt='.2f', cmap='RdBu_r')
plt.savefig('heatmaps/res_acc.pdf')

<Figure size 432x288 with 2 Axes>

In [23]:
sns.heatmap(wrn2wrn_agr, xticklabels=x, yticklabels=y, annot=True, fmt='.2f', cmap='RdBu_r')
plt.savefig('heatmaps/res_agg.pdf')

<Figure size 432x288 with 2 Axes>

In [25]:
sns.heatmap(wrn2wrn_loyalty, xticklabels=x, yticklabels=y, annot=True, fmt='.4f', cmap='RdBu_r')
plt.savefig('heatmaps/res_loyalty.pdf')

<Figure size 432x288 with 2 Axes>

In [46]:
import pandas as pd
grad_range = np.array([0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4])
adv_acc = np.array([75.10, 75.31, 75.50, 75.50, 75.51, 75.60, 75.73, 75.67])
adv_agg = np.array([85.52, 86.11, 86.22, 86.52, 86.30, 86.52, 86.32, 86.11])
adv_loyalty = np.array([0.6412, 0.6476, 0.6528, 0.6568, 0.6571, 0.6576, 0.6575, 0.6572])
if not os.path.exists('grad_adv/'):
    os.mkdir('grad_adv/')

res_dict = {
    'grad_of_adv': grad_range,
    'acc@1': adv_acc,
    'agree@1': adv_agg,
    'loyalty': adv_loyalty
}

final_df = pd.DataFrame(res_dict)
final_df

Unnamed: 0,grad_of_adv,acc@1,agree@1,loyalty
0,0.5,75.1,85.52,0.6412
1,1.0,75.31,86.11,0.6476
2,1.5,75.5,86.22,0.6528
3,2.0,75.5,86.52,0.6568
4,2.5,75.51,86.3,0.6571
5,3.0,75.6,86.52,0.6576
6,3.5,75.73,86.32,0.6575
7,4.0,75.67,86.11,0.6572


In [21]:
# now_epoch = 300
global_iter = now_epoch * 400
lamda = datafree.datasets.utils.lambda_scheduler(1.0, global_iter, alpha=0.00002)
g,v = datafree.datasets.utils.curr_v(l=torch.FloatTensor(diff), lamda=lamda, spl_type='log')

In [22]:
if not os.path.exists('difficulty/'):
    os.mkdir('difficulty/')
diffs = dict()
diffs['epoch@{}'.format(now_epoch)] = v.numpy() * diff
# sns.displot(v.numpy() * diff, kind='kde')
# plt.savefig('difficulty/difficulty_at_{}.pdf'.format(now_epoch))

In [23]:
ckpt = torch.load('../checkpoints/datafree-cudfkd/cifar100-wrn40_2-wrn40_1-10.pth', map_location='cpu')
student.load_state_dict(ckpt['state_dict'])
student.eval()
g = datafree.models.generator.DCGAN_Generator_CIFAR10(nz=512, ngf=64, nc=3, img_size=32, d=2, cond=False)
g = prepare_model(g)
g.load_state_dict(ckpt['G_0'])
g.eval()

DCGAN_Generator_CIFAR10(
  (project): Sequential(
    (0): Flatten()
    (1): Linear(in_features=512, out_features=16384, bias=True)
  )
  (main): Sequential(
    (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Upsample(scale_factor=2.0, mode=nearest)
    (2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Upsample(scale_factor=2.0, mode=nearest)
    (6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): Sigmoid()
  )
)

In [24]:
diff = []
now_epoch = 10
kl_metric = datafree.criterions.KLDiv(T=2., reduction='none')
with torch.no_grad():
    for i in range(20):
        z = torch.randn(batch_size, 512).to(gpu)
        x = g(z)
        t_out = teacher(normalizer(x))
        s_out = student(normalizer(x))
        kl_diff = kl_metric(t_out, s_out).sum(1).detach().cpu().numpy()
        diff.append(kl_diff)
    
diff = np.concatenate(diff)
print(diff)
with torch.no_grad():
    z = torch.randn(1024, 512).to(gpu)
    x = g(z)
    img = make_grid(x, nrow=16)
    t_out = teacher(normalizer(x))
    pesudo_label = t_out.argmax(1).cpu().numpy()
    print(pesudo_label)
    save_image(img, 'difficulty/epoch_{}.jpg'.format(now_epoch))

[0.44048202 0.7211295  0.7666422  ... 0.34204212 0.96407497 0.6640957 ]
[42 24 45 ... 99 60 87]


In [26]:
xs = []
for i in range(num_classes):
#     print(i)
    xs.append(x[pesudo_label == i].cpu())
    print(i, xs[i].shape)
    if xs[i].shape[0] > 0:
        img = make_grid(xs[i], nrow=1)
        save_image(img, 'difficulty/img_epoch_{}_{}.jpg'.format(now_epoch, i))

0 torch.Size([3, 3, 32, 32])
1 torch.Size([25, 3, 32, 32])
2 torch.Size([9, 3, 32, 32])
3 torch.Size([15, 3, 32, 32])
4 torch.Size([18, 3, 32, 32])
5 torch.Size([24, 3, 32, 32])
6 torch.Size([10, 3, 32, 32])
7 torch.Size([20, 3, 32, 32])
8 torch.Size([15, 3, 32, 32])
9 torch.Size([5, 3, 32, 32])
10 torch.Size([15, 3, 32, 32])
11 torch.Size([8, 3, 32, 32])
12 torch.Size([3, 3, 32, 32])
13 torch.Size([8, 3, 32, 32])
14 torch.Size([9, 3, 32, 32])
15 torch.Size([11, 3, 32, 32])
16 torch.Size([14, 3, 32, 32])
17 torch.Size([6, 3, 32, 32])
18 torch.Size([4, 3, 32, 32])
19 torch.Size([12, 3, 32, 32])
20 torch.Size([2, 3, 32, 32])
21 torch.Size([2, 3, 32, 32])
22 torch.Size([6, 3, 32, 32])
23 torch.Size([22, 3, 32, 32])
24 torch.Size([11, 3, 32, 32])
25 torch.Size([4, 3, 32, 32])
26 torch.Size([22, 3, 32, 32])
27 torch.Size([6, 3, 32, 32])
28 torch.Size([15, 3, 32, 32])
29 torch.Size([10, 3, 32, 32])
30 torch.Size([19, 3, 32, 32])
31 torch.Size([10, 3, 32, 32])
32 torch.Size([8, 3, 32, 32])
33

In [27]:
# now_epoch = 10
global_iter = now_epoch * 400
lamda = datafree.datasets.utils.lambda_scheduler(1.0, global_iter, alpha=0.00002)
g,v = datafree.datasets.utils.curr_v(l=torch.FloatTensor(diff), lamda=lamda, spl_type='log')

In [28]:
ax = sns.displot(diffs, kind='kde')
ax.set_xlabels('Difficulty', fontsize=12)
plt.savefig('difficulty/difficulty_wrn402_wrn162_cifar100.pdf')

<Figure size 455x360 with 1 Axes>

In [74]:
res_acc_dict = {
    'method': ['cudfkd', 'cudfkd', 'cudfkd', 'teacher', 'adi', 'adi', 'adi', 'dafl', 'dafl', 'dafl', 'cmi', 'cmi', 'cmi'],
    'acc@1': [95.24, 95.28, 95.15, 95.70],
    'agree@1': [98.20, 98.19, 98.09, 100.00],
    'prob_loyalty': [0.8909, 0.8903, 0.8898, 1.00],
}

In [38]:
fig, ax1 = plt.subplots()
width = 0.4
x1_list = []
x2_list = []
for i in range(adv_acc.shape[0]):
    x1_list.append(i)
    x2_list.append(i + width)

# b1 = sns.lineplot(data=final_df, x="grad_of_adv", y="acc@1",color=sns.xkcd_rgb["pale red"] )  
b1, = plt.plot(final_df["grad_of_adv"], final_df["acc@1"], '-ro', label='Accuracy@1')

# b1 = ax1.bar(x1_list, adv_acc, width=width, label='Acc@1', color=sns.xkcd_rgb["pale red"], tick_label=grad_range)
ax2 = ax1.twinx()
b2, = plt.plot(final_df["grad_of_adv"], final_df["agree@1"], '-bo', label='Agreement@1')
# plt.grid('off')
# b2 = sns.lineplot(data=final_df, x="grad_of_adv", y="agree@1", color=sns.xkcd_rgb["denim blue"])
# b2 = ax2.bar(x2_list, adv_agg, width=width, label='Agree@1', color=sns.xkcd_rgb["denim blue"], tick_label=grad_range)
ax1.set_xlabel('grad of adv', fontsize=12)
ax1.set_ylabel('Accuracy@1', fontsize=12)
ax2.set_ylabel('Agreement@1', fontsize=12)
ax1.set_ylim(74, 77)
ax2.set_ylim(85, 87)
axs = [b1, b2]
lbs = [ax.get_label() for ax in axs]
ax1.legend(axs, lbs)
plt.savefig('grad_adv/acc_agg.pdf')
# plt.legend(['a', 'b'])
# plt.grid('off')

<Figure size 432x288 with 2 Axes>

In [39]:
plt.plot(final_df["grad_of_adv"], final_df["loyalty"], '-go', label="prob_loyalty")
plt.ylim(0.64, 0.66)
x_ticks = final_df["grad_of_adv"]
y_ticks = np.arange(0.64, 0.66, 0.005)
plt.xlabel('grad of adv', fontsize=14)
plt.ylabel('probability loyalty', fontsize=14)
plt.xticks(x_ticks, fontsize=12)
plt.yticks(y_ticks, fontsize=12)
plt.savefig('grad_adv/prob_loyalty.pdf')

<Figure size 432x288 with 1 Axes>

In [40]:
curriculum_dict={
    'method': ['teacher', 'None', 'hard', 'soft', 'log','teacher', 'None', 'hard', 'soft', 'log','teacher', 'None', 'hard', 'soft', 'log'],
    'accuracy': [95.70, 93.24, 94.97, 95.01, 95.28, 100.00, 94.00, 97.71, 97.74, 98.13, 1.0000, 0.8202, 0.8813, 0.8786, 0.8909],
    'metric': ['Acc@1','Acc@1','Acc@1','Acc@1','Acc@1','Agree@1','Agree@1','Agree@1','Agree@1','Agree@1','prob loyalty','prob loyalty','prob loyalty','prob loyalty','prob loyalty']
}
curr_df = pd.DataFrame(curriculum_dict)
curr_df
if not os.path.exists('curr_strategy/'):
    os.mkdir('curr_strategy/')

In [53]:
fig, ax1 = plt.subplots()
width = 0.4
x1_list = []
x2_list = []
for i in range(5):
    x1_list.append(i)
    x2_list.append(i + width)

# b1 = sns.lineplot(data=final_df, x="grad_of_adv", y="acc@1",color=sns.xkcd_rgb["pale red"] )  
# b1, = plt.plot(final_df["grad_of_adv"], final_df["acc@1"], '-ro', label='Accuracy@1')
# b1 = sns.barplot(data=curriculum_dict, x='method', y='accuracy', color=sns.xkcd_rgb["pale red"])
b1 = ax1.bar(x1_list, curr_df['accuracy'][:5], width=width, label='Acc@1', color=sns.xkcd_rgb["pale red"], tick_label=curriculum_dict['method'][:5])
ax2 = ax1.twinx()
# b2, = plt.plot(final_df["grad_of_adv"], final_df["agree@1"], '-bo', label='Agreement@1')
# plt.grid('off')
# b2 = sns.lineplot(data=final_df, x="grad_of_adv", y="agree@1", color=sns.xkcd_rgb["denim blue"])
b2 = ax2.bar(x2_list, curr_df['accuracy'][5:10], width=width, label='Agree@1', color=sns.xkcd_rgb["denim blue"], tick_label=curriculum_dict['method'][:5])
# ax2 = ax1.twinx()
# b2 = sns.barplot(data=curriculum_dict, x='method', y='Agree@1', color=sns.xkcd_rgb["denim blue"])
ax1.set_xlabel('Curriculum strategy', fontsize=12)
ax1.set_ylabel('Accuracy@1', fontsize=12)
ax2.set_ylabel('Agreement@1', fontsize=12)
ax1.set_ylim(93, 96)
ax2.set_ylim(92, 101)
# ax2.set_yticklabels(np.arange(92, 101, 2))
axs = [b1, b2]
lbs = [ax.get_label() for ax in axs]
# ax1.legend()
ax1.legend(axs, lbs)
plt.savefig('curr_strategy/acc_agg.pdf')

<Figure size 432x288 with 2 Axes>

In [63]:
plt.bar(x1_list, curr_df['accuracy'][-5:], width=width, label='prob_loyalty', color=sns.xkcd_rgb["dark green"], tick_label=['teacher', 'None', 'hard', 'soft', 'log'])
plt.ylim(0.8, 1.01)
# x_ticks = final_df["grad_of_adv"]
# x_ticks = ['teacher', 'None', 'hard', 'soft', 'log']
y_ticks = np.arange(0.8, 1.01, 0.05)
plt.xlabel('Curriculum Strategy', fontsize=14)
plt.ylabel('probability loyalty', fontsize=14)
plt.xticks(x_ticks, fontsize=12)
plt.yticks(y_ticks, fontsize=12)
plt.legend()
plt.savefig('curr_strategy/prob_loyalty.pdf')

ConversionError: Failed to convert value(s) to axis units: ['teacher', 'None', 'hard', 'soft', 'log']

<Figure size 432x288 with 1 Axes>

In [47]:
if not os.path.exists('T/'):
    os.mkdir('T/')
T_dict = {
    'T': [4, 5, 8, 10, 12, 15, 18, 20],
    'accuracy@1': [94.96, 95.01,94.94,94.93,94.97, 95.03, 94.94, 95.28],
    'agreement@1': [97.86, 97.64, 97.79, 97.63, 97.60, 97.69, 97.74, 98.20],
    'prob_loyalty': [0.8777, 0.8774, 0.8807, 0.8739, 0.8771, 0.8801, 0.8793, 0.8909]
    
}


In [50]:
fig, ax1 = plt.subplots()
width = 0.4
x1_list = []
x2_list = []
for i in range(adv_acc.shape[0]):
    x1_list.append(i)
    x2_list.append(i + width)

# b1 = sns.lineplot(data=final_df, x="grad_of_adv", y="acc@1",color=sns.xkcd_rgb["pale red"] )  
b1, = plt.plot(T_dict["T"], T_dict['accuracy@1'], '-ro', label='Accuracy@1')

# b1 = ax1.bar(x1_list, adv_acc, width=width, label='Acc@1', color=sns.xkcd_rgb["pale red"], tick_label=grad_range)
ax2 = ax1.twinx()
b2, = plt.plot(T_dict["T"], T_dict['agreement@1'], '-bo', label='Agreement@1')
# plt.grid('off')
# b2 = sns.lineplot(data=final_df, x="grad_of_adv", y="agree@1", color=sns.xkcd_rgb["denim blue"])
# b2 = ax2.bar(x2_list, adv_agg, width=width, label='Agree@1', color=sns.xkcd_rgb["denim blue"], tick_label=grad_range)
ax1.set_xlabel('T', fontsize=12)
ax1.set_ylabel('Accuracy@1', fontsize=12)
ax2.set_ylabel('Agreement@1', fontsize=12)
ax1.set_ylim(94, 95.5)
ax2.set_ylim(97, 99)
axs = [b1, b2]
lbs = [ax.get_label() for ax in axs]
ax1.legend(axs, lbs)
plt.savefig('T/acc_agg.pdf')

<Figure size 432x288 with 2 Axes>

In [54]:
plt.plot(T_dict["T"], T_dict["prob_loyalty"], '-go', label="prob_loyalty")
plt.ylim(0.87, 0.9)
x_ticks = np.arange(4, 21, 4)
y_ticks = np.arange(0.87, 0.9, 0.005)
plt.xlabel('T', fontsize=14)
plt.ylabel('probability loyalty', fontsize=14)
plt.xticks(x_ticks, fontsize=12)
plt.yticks(y_ticks, fontsize=12)
plt.savefig('T/prob_loyalty.pdf')

<Figure size 432x288 with 1 Axes>

In [50]:
if not os.path.exists('lambda_0/'):
    os.mkdir('lambda_0/')
lambda_0_dict = {
    'lambda_0': [2, 2.2, 2.5, 2.8],
    'accuracy@1': [93.34,93.24,93.17,93.10],
    'agreement@1': [94.15, 94.02, 93.92, 93.97],
    'loyalty': [0.8183, 0.8176, 0.8173, 0.8170],
    'curriculum': ['logarithm', 'logarithm', 'logarithm', 'logarithm']
}



In [52]:
fig, ax1 = plt.subplots()
# width = 0.4
# x1_list = []
# x2_list = []
# for i in range(adv_acc.shape[0]):
#     x1_list.append(i)
#     x2_list.append(i + width)

# b1 = sns.lineplot(data=final_df, x="grad_of_adv", y="acc@1",color=sns.xkcd_rgb["pale red"] )  
b1, = plt.plot(lambda_0_dict["lambda_0"], lambda_0_dict['accuracy@1'], '-ro', label='Accuracy@1')

# b1 = ax1.bar(x1_list, adv_acc, width=width, label='Acc@1', color=sns.xkcd_rgb["pale red"], tick_label=grad_range)
ax2 = ax1.twinx()
b2, = plt.plot(lambda_0_dict["lambda_0"], lambda_0_dict['agreement@1'], '-bo', label='Agreement@1')
# plt.grid('off')
# b2 = sns.lineplot(data=final_df, x="grad_of_adv", y="agree@1", color=sns.xkcd_rgb["denim blue"])
# b2 = ax2.bar(x2_list, adv_agg, width=width, label='Agree@1', color=sns.xkcd_rgb["denim blue"], tick_label=grad_range)
ax1.set_xlabel('lambda_0', fontsize=12)
ax1.set_ylabel('Accuracy@1', fontsize=12)
ax2.set_ylabel('Agreement@1', fontsize=12)
ax1.set_ylim(93, 93.5)
ax2.set_ylim(93.8, 94.2)
axs = [b1, b2]
lbs = [ax.get_label() for ax in axs]
ax1.legend(axs, lbs)
plt.savefig('lambda_0/acc_agg.pdf')

<Figure size 432x288 with 2 Axes>

In [53]:
plt.plot(lambda_0_dict["lambda_0"], lambda_0_dict["loyalty"], '-go', label="prob_loyalty")
plt.ylim(0.815, 0.82)
x_ticks = np.arange(2, 3, 0.2)
y_ticks = np.arange(0.815, 0.82, 0.001)
plt.xlabel('lambda_0', fontsize=14)
plt.ylabel('probability loyalty', fontsize=14)
plt.xticks(x_ticks, fontsize=12)
plt.yticks(y_ticks, fontsize=12)
plt.savefig('lambda_0/prob_loyalty.pdf')

<Figure size 432x288 with 1 Axes>

In [36]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'pdf'
res_dict_below = {
    'method': ['teacher', 'teacher', 'cudfkd', 'cudfkd', 'cmi', 'cmi', 'dafl', 'dafl', 'adi', 'adi', 'adi', 'adi', 'cudfkd', 'cudfkd', 'dafl', 'dafl', 'cmi', 'cmi', 'adi', 'adi', 'cmi', 'cmi', 'dafl', 'dafl', 'cudfkd', 'cudfkd'],
    'acc@1': [95.70, 94.07, 95.28, 93.24, 94.84, 92.36, 91.45 , 81.66, 93.26, 90.38, 90.36, 90.17, 95.24, 95.15, 82.89, 81.86, 91.95, 91.85, 94.56, 94.63, 93.18, 93.36 , 89.41, 89.04, 93.38, 93.30],
    'type': ['clean', 'noisy', 'clean', 'noisy','clean', 'noisy','clean', 'noisy','clean', 'noisy', 'noisy', 'noisy', 'clean', 'clean', 'noisy', 'noisy', 'noisy', 'noisy', 'clean', 'clean', 'clean', 'clean', 'clean', 'clean', 'noisy', 'noisy'],
    'agree@1':[100.00, 100.00, 98.19, 97.25, 96.46, 93.13, 93.20, 83.85, 93.45, 92.88, 92.90, 92.92 , 98.20, 98.09, 85.03, 83.55, 94.98, 95.10, 96.96, 96.85, 95.23, 95.34, 91.06, 90.69, 97.69, 97.27],
    'loyalty':[1.0000, 1.0000, 0.8909, 0.9002, 0.8747, 0.8049, 0.7686, 0.6521, 0.8078, 0.8247, 0.8235, 0.8245, 0.8915, 0.8897, 0.6803, 0.6527, 0.8701, 0.8711, 0.8849, 0.8833, 0.8501,0.8499, 0.7298, 0.7336, 0.9063, 0.8998],
}

noisy_df = pd.DataFrame(data=res_dict_below)
noisy_df

Unnamed: 0,method,acc@1,type,agree@1,loyalty
0,teacher,95.7,clean,100.0,1.0
1,teacher,94.07,noisy,100.0,1.0
2,cudfkd,95.28,clean,98.19,0.8909
3,cudfkd,93.24,noisy,97.25,0.9002
4,cmi,94.84,clean,96.46,0.8747
5,cmi,92.36,noisy,93.13,0.8049
6,dafl,91.45,clean,93.2,0.7686
7,dafl,81.66,noisy,83.85,0.6521
8,adi,93.26,clean,93.45,0.8078
9,adi,90.38,noisy,92.88,0.8247


In [39]:
# noisy_df.columns
import os
if not os.path.exists('noisy/'):
    os.mkdir('noisy/')
ax= sns.barplot(data=noisy_df, x='method', y='acc@1', hue='type')
ax.set_ylim(80, 100)
ax.legend(loc='upper right')
plt.savefig('noisy/acc.pdf')

<Figure size 432x288 with 1 Axes>

In [40]:
ax = sns.barplot(data=noisy_df, x='method', y='agree@1', hue='type',palette="Set2")
ax.set_ylim(80, 100)
ax.legend(loc='upper right')
plt.savefig('noisy/agg.pdf')

<Figure size 432x288 with 1 Axes>

In [41]:
ax = sns.barplot(data=noisy_df, x='method', y='loyalty', hue='type', palette="Set2_r")
ax.set_ylim(0.6, 1)
ax.legend(loc='upper right')
plt.savefig('noisy/loyalty.pdf')

<Figure size 432x288 with 1 Axes>