In [3]:
import torch
import numpy as np
import sys
sys.path.append('..')
import datafree
%config InlineBackend.figure_format = 'pdf'

In [4]:
!nvidia-smi

Thu Nov  3 11:47:53 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.60.02    Driver Version: 510.60.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:05:00.0 Off |                  N/A |
| 71%   64C    P2   328W / 350W |  15897MiB / 24576MiB |     95%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:09:00.0 Off |                  N/A |
| 72%   65C    P2   332W / 350W |  13911MiB / 24576MiB |     99%      Default |
|       

In [5]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


accs = {
    'DAFL': [11.9, 81.10],
    'ADI': [(28 / 60 + 9 / 3600)*50, 89.46],
    'DFQ': [(0.25 + 34/3600)*50, 90.84],
    'CMI': [(1 + 23.5/60)*50, 91.13],
    'CuDFKD': [13 + 3/60, 91.61],
    'AdaDFKD(S)': [14 + 28/60, 92.04],
    'AdaDFKD(G)': [15 + 34/60, 92.19]
}

acc = {
    'Time(h)': [11.9, (28 / 60 + 9 / 3600)*50,(0.25 + 34/3600)*50, (1 + 23.5/60)*50, 13 + 3/60,14 + 28/60, 15 + 34/60 ],
    'Acc@1':[81.10,  89.46, 90.84, 91.13, 91.61, 92.04, 92.19],
    'method': accs.keys()
}

df = pd.DataFrame(acc)
print(df)

     Time(h)  Acc@1      method
0  11.900000  81.10        DAFL
1  23.458333  89.46         ADI
2  12.972222  90.84         DFQ
3  69.583333  91.13         CMI
4  13.050000  91.61      CuDFKD
5  14.466667  92.04  AdaDFKD(S)
6  15.566667  92.19  AdaDFKD(G)


In [8]:
markers = {"AdaDFKD(S)": "X", "AdaDFKD(G)": "X"}
sns.scatterplot(data=df, x='Time(h)', y='Acc@1', hue='method', markers=markers)

<AxesSubplot:xlabel='Time(h)', ylabel='Acc@1'>

<Figure size 432x288 with 1 Axes>

In [3]:
distributed = False
gpu = 5
# gpu ='0,1'
batch_size = 128
workers = 8
num_classes = 10
# num_classes = 100
# num_classes = 200
def prepare_model(model):
    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
        return model
    elif distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if gpu is not None:
#             torch.cuda.set_device(gpu)
            model.cuda()
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            batch_size = int(batch_size / 1)
            workers = int((workers + 1 - 1) / 1)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[int(x) for x in gpu.split(',')])
            return model
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
            return model
    elif gpu is not None:
        torch.cuda.set_device(gpu)
        model = model.cuda(gpu)
        return model
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()
        return model

In [4]:
from torchvision.datasets import CIFAR10,CIFAR100
import datafree
import registry
from torch import nn
student = registry.get_model('resnet18', num_classes=num_classes)
teacher = registry.get_model('resnet34', num_classes=num_classes, pretrained=True).eval()
# student = registry.get_model('wrn40_1', num_classes=num_classes)
# student= registry.get_model('wrn16_2', num_classes=num_classes)
# teacher = registry.get_model('wrn40_2', num_classes=num_classes)
# normalizer = datafree.utils.Normalizer(**registry.NORMALIZE_DICT['tiny_imagenet'])
# normalizer = datafree.utils.Normalizer(**registry.NORMALIZE_DICT['cifar10'])
normalizer = datafree.utils.Normalizer(**registry.NORMALIZE_DICT['cifar100'])
student = prepare_model(student)

# teacher = teacher.to(gpu)
# teacher.avgpool = nn.AdaptiveAvgPool2d(1)
# num_ftrs = teacher.fc.in_features
# teacher.fc = nn.Linear(num_ftrs, 200)
# teacher.conv1 = nn.Conv2d(3,64, kernel_size=(3,3), stride=(1,1), padding=(1,1))
# teacher.maxpool = nn.Sequential()
teacher = prepare_model(teacher)
# ckpt = torch.load('../checkpoints/scratch/tiny_imagenet_resnet34_imagenet.pth', map_location='cpu')
# dict_ckpt = dict()
# for k, v in ckpt['state_dict'].items():
#     dict_ckpt['.'.join(k.split('.')[1:])] = v
# teacher.load_state_dict(dict_ckpt)

teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar10_resnet34.pth', map_location='cpu')['state_dict'])
# teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar10_wrn40_2.pth', map_location='cpu')['state_dict'])
# teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar100_wrn40_2.pth', map_location='cpu')['state_dict'])
# print(ckpt['best_acc1'])
# teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar100_resnet34.pth', map_location='cpu')['state_dict'])
teacher.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [5]:
!ls /data/lijingru/DFKD/run/cr4_sim_normalize_pos_c

buffer.pt


In [6]:
anchor_bank = torch.load('/data/lijingru/DFKD/run/cr6_sim_normalize_pos/buffer.pt', map_location='cpu')
# all_anchor = anchor_bank.reshape(-1, 512)
print(anchor_bank.shape)

torch.Size([10, 768, 512])


In [21]:
import torch.nn.functional as F
def difficulty_loss(anchor, teacher, t_out, logit_t, ds='cifar10', hard_factor=0., tau=10, device='cpu', d_neg_fea=None):
    batch_size = anchor.size(0)
    with torch.no_grad():
        # t_logit, anchor_t_out = teacher(anchor.to(device).detach(), return_features=True)
        t_logit = teacher(anchor.to(device).detach())
        anchor_t_out = anchor.to(device)
        # pseudo_label = pseudo_label.argmax(1)
    # loss = 0.
    pos_loss = 0.
    neg_loss = 0.
    if ds == 'cifar10':
        normalized_anchor_t_out, normalized_t_out = F.normalize(anchor_t_out, dim=1), F.normalize(t_out, dim=1)
        d = torch.mm(normalized_anchor_t_out, normalized_t_out.T)
        N_an, N_batch = d.size()
        
        sorted_d, indice_d = torch.sort(d, dim=1)
        d_pos = sorted_d[:, -int(0.05 * N_batch):]
        d_neg = sorted_d[:, :-int(0.05 * N_batch)]
        # n_neg = d_neg.size(1)
        d_mask = torch.zeros_like(indice_d)
        d_mask = d_mask.scatter(1, indice_d[:, -int(0.1*N_batch):], 1)
        p_t_anchor = torch.softmax(t_logit, 1)
        p_t_batch = torch.softmax(logit_t, 1)
        kld_matrix = -torch.mm(p_t_anchor, p_t_batch.T.log()) + torch.diag(torch.mm(p_t_anchor, p_t_anchor.T.log())).unsqueeze(1)
        l_kld = ((kld_matrix * d_mask).sum(1) / d_mask.sum(1)).mean()
        # Get positive DA index
        p_pos = torch.softmax(d_pos / tau, dim=1)
        p_da_pos = torch.quantile(p_pos, q=1-hard_factor, dim=1).unsqueeze(1)
        pos_loss = torch.sum(p_pos * torch.log(p_pos / p_da_pos).abs(), dim=1).mean()
        # Get Negative DA index
        
        if d_neg_fea is not None:
            d = torch.cat([d_neg, d_pos], 1)
            d_mask = torch.zeros_like(d)
#             d_mask[:, ]
        p_total = torch.softmax(d / tau, dim=1)
        # Out supervised loss.
        print(d_mask, d_mask.shape)
        neg_loss = -((d_mask * p_total.log()).sum(1) / (d_mask.sum(1))).mean()
#         print(pos_loss, neg_loss, l_kld)
        
        return pos_loss, indice_d, neg_loss, l_kld

In [14]:
tg = datafree.models.generator.DCGAN_Generator_CIFAR10(nz=512, ngf=64, nc=3, img_size=32, d=2, cond=False, type='normal', widen_factor=1)
tg.load_state_dict(torch.load('/data/lijingru/DFKD/checkpoints/datafree-improved_cudfkd/cifar10-resnet34-resnet18--cr6_sim_normalize_pos.pth', map_location='cpu')['G'])
prepare_model(tg)

DCGAN_Generator_CIFAR10(
  (project): Sequential(
    (0): Flatten()
    (1): Linear(in_features=512, out_features=16384, bias=True)
  )
  (main): Sequential(
    (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Upsample(scale_factor=2.0, mode=nearest)
    (2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Upsample(scale_factor=2.0, mode=nearest)
    (6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): Sigmoid()
  )
)

In [22]:
z = torch.randn(512, 512).to(gpu)
x = normalizer(tg(z))
t_out, t_feat = teacher(x, return_features=True)

a, b, c, d = difficulty_loss(anchor_bank[2], teacher.linear, t_feat, logit_t=t_out, device=gpu, hard_factor=0., tau=0.07)
print(a, b, c, d)

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 1, 0,  ..., 0, 1, 0],
        [0, 0, 0,  ..., 0, 1, 0],
        ...,
        [0, 1, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 1, 0]], device='cuda:1') torch.Size([768, 512])
tensor(0.6864, device='cuda:1', grad_fn=<MeanBackward0>) tensor([[228, 303, 256,  ..., 100, 253, 286],
        [403, 220, 114,  ..., 189, 298, 500],
        [206, 220, 372,  ...,  48, 146, 510],
        ...,
        [ 15,  32, 363,  ...,  65, 456,  18],
        [401, 312, 421,  ..., 286, 149, 315],
        [250, 203, 403,  ...,  48, 500, 499]], device='cuda:1') tensor(4.5037, device='cuda:1', grad_fn=<NegBackward>) tensor(1.1287, device='cuda:1', grad_fn=<MeanBackward0>)


In [26]:
from torchvision.utils import save_image,make_grid
import matplotlib.pyplot as plt
indexs = b[:, :-int(0.1*512)]
img = make_grid(x[indexs[0]], nrow=7, padding=2, value_range=(0, 1))
# plt.show(img)

In [19]:
def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [27]:
save_image(img, 'neg1.jpg')

In [None]:
make_grid()