In [1]:
import torch
import numpy as np
import sys
import seaborn as sns
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import os
import tqdm
sys.path.append('..')
import datafree
%config InlineBackend.figure_format = 'pdf'

In [2]:
!nvidia-smi

Thu Jun  1 22:43:37 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090         On | 00000000:05:00.0 Off |                  N/A |
| 74%   66C    P2              335W / 350W|  15455MiB / 24576MiB |    100%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090         On | 00000000:09:00.0 Off |  

In [3]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


accs = {
    'DAFL': [11.9, 81.10],
    'ADI': [(28 / 60 + 9 / 3600)*50, 89.46],
    'DFQ': [(0.25 + 34/3600)*50, 90.84],
    'CMI': [(1 + 23.5/60)*50, 91.13],
    'CuDFKD': [13 + 3/60, 91.61],
    'AdaDFKD(S)': [14 + 28/60, 92.04],
    'AdaDFKD(G)': [15 + 34/60, 92.19]
}

acc = {
    'Time(h)': [11.9, (28 / 60 + 9 / 3600)*50,(0.25 + 34/3600)*50, (1 + 23.5/60)*50, 13 + 3/60,14 + 28/60, 15 + 34/60 ],
    'Acc@1':[81.10,  89.46, 90.84, 91.13, 91.61, 92.04, 92.19],
    'method': accs.keys(),
    'size': [300, 300, 300, 300, 300, 5000, 5000]
}

df = pd.DataFrame(acc)
df.to_csv('acc.csv', index=None)
print(df)

     Time(h)  Acc@1      method  size
0  11.900000  81.10        DAFL   300
1  23.458333  89.46         ADI   300
2  12.972222  90.84         DFQ   300
3  69.583333  91.13         CMI   300
4  13.050000  91.61      CuDFKD   300
5  14.466667  92.04  AdaDFKD(S)  5000
6  15.566667  92.19  AdaDFKD(G)  5000


In [4]:
def text_df(ax, df):
    for i in range(len(df)):
        method = df.loc[i]['method']
        acc1 = df.loc[i]['Acc@1']
        time = df.loc[i]['Time(h)']
        if method == 'DFQ' or method == 'CuDFKD':
            x_shift, y_shift = -3, 0
        else:
            x_shift, y_shift = -1, 0.8
        ax.text(time+x_shift, acc1+y_shift, method, color='black')

In [5]:
markers = {"AdaDFKD(S)": "*", "AdaDFKD(G)": "*", "CuDFKD": "o", "CMI":"o", "DFQ":"o", "ADI":"o", "DAFL":"o"}
ax = sns.scatterplot(data=df, x='Time(h)', y='Acc@1', markers=markers, s=[50, 50, 50, 50, 50, 300, 300])
text_df(ax, df)
# ax.set_xscale('log', base=2, subs=[1.25, 1.5])
# ax.set_xlim(7, 64)
# ax.set_xticklabels([ 0, 8, 16, 32, 64])
ax.grid()

<Figure size 432x288 with 1 Axes>

In [3]:
distributed = False
gpu = 3
# gpu ='0,1'
batch_size = 128
workers = 8
num_classes = 10
# num_classes = 100
# num_classes = 200
def prepare_model(model):
    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
        return model
    elif distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if gpu is not None:
#             torch.cuda.set_device(gpu)
            model.cuda()
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            batch_size = int(batch_size / 1)
            workers = int((workers + 1 - 1) / 1)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[int(x) for x in gpu.split(',')])
            return model
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
            return model
    elif gpu is not None:
        torch.cuda.set_device(gpu)
        model = model.cuda(gpu)
        return model
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()
        return model

In [4]:
from torchvision.datasets import CIFAR10,CIFAR100
import datafree
import registry
from torch import nn
student = registry.get_model('resnet18', num_classes=num_classes)
teacher = registry.get_model('vgg11', num_classes=num_classes, pretrained=True).eval()
# student = registry.get_model('wrn40_1', num_classes=num_classes)
# student= registry.get_model('wrn16_2', num_classes=num_classes)
# teacher = registry.get_model('wrn40_2', num_classes=num_classes)
# normalizer = datafree.utils.Normalizer(**registry.NORMALIZE_DICT['tiny_imagenet'])
normalizer = datafree.utils.Normalizer(**registry.NORMALIZE_DICT['cifar10'])
# normalizer = datafree.utils.Normalizer(**registry.NORMALIZE_DICT['cifar100'])
student = prepare_model(student)

# teacher = teacher.to(gpu)
# teacher.avgpool = nn.AdaptiveAvgPool2d(1)
# num_ftrs = teacher.fc.in_features
# teacher.fc = nn.Linear(num_ftrs, 200)
# teacher.conv1 = nn.Conv2d(3,64, kernel_size=(3,3), stride=(1,1), padding=(1,1))
# teacher.maxpool = nn.Sequential()
teacher = prepare_model(teacher)
# ckpt = torch.load('../checkpoints/scratch/tiny_imagenet_resnet34_imagenet.pth', map_location='cpu')
# dict_ckpt = dict()
# for k, v in ckpt['state_dict'].items():
#     dict_ckpt['.'.join(k.split('.')[1:])] = v
# teacher.load_state_dict(dict_ckpt)

# teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar10_resnet34.pth', map_location='cpu')['state_dict'])
# teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar10_wrn40_2.pth', map_location='cpu')['state_dict'])
# teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar100_wrn40_2.pth', map_location='cpu')['state_dict'])
# print(ckpt['best_acc1'])
teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar10_vgg11.pth', map_location='cpu')['state_dict'])
teacher.eval()

VGG(
  (block0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block1): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block2): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block3): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
 

In [11]:
!ls ../run/cr4_sim_normalize_pos_c

ls: cannot access '../run/cr4_sim_normalize_pos_c': No such file or directory


dict_keys(['epoch', 'arch', 'state_dict', 'best_acc1', 'optimizer', 'scheduler', 'G', 'neg_bank'])
torch.Size([512, 8192])


In [3]:
import torch.nn.functional as F
def difficulty_loss(anchor, teacher, t_out, logit_t, ds='cifar10', hard_factor=0., tau=10, device='cpu', d_neg_fea=None):
    batch_size = anchor.size(0)
    with torch.no_grad():
        # t_logit, anchor_t_out = teacher(anchor.to(device).detach(), return_features=True)
        t_logit = teacher(anchor.to(device).detach())
        anchor_t_out = anchor.to(device)
        # pseudo_label = pseudo_label.argmax(1)
    # loss = 0.
    pos_loss = 0.
    neg_loss = 0.
    if ds == 'cifar10':
        normalized_anchor_t_out, normalized_t_out = F.normalize(anchor_t_out, dim=1), F.normalize(t_out, dim=1)
        d = torch.mm(normalized_anchor_t_out, normalized_t_out.T)
        N_an, N_batch = d.size()
        
        sorted_d, indice_d = torch.sort(d, dim=1)
        d_pos = sorted_d[:, -int(0.05 * N_batch):]
        d_neg = sorted_d[:, :-int(0.05 * N_batch)]
        # n_neg = d_neg.size(1)
        d_mask = torch.zeros_like(indice_d)
        d_mask = d_mask.scatter(1, indice_d[:, -int(0.1*N_batch):], 1)
        p_t_anchor = torch.softmax(t_logit, 1)
        p_t_batch = torch.softmax(logit_t, 1)
        kld_matrix = -torch.mm(p_t_anchor, p_t_batch.T.log()) + torch.diag(torch.mm(p_t_anchor, p_t_anchor.T.log())).unsqueeze(1)
        l_kld = ((kld_matrix * d_mask).sum(1) / d_mask.sum(1)).mean()
        # Get positive DA index
        p_pos = torch.softmax(d_pos / tau, dim=1)
        p_da_pos = torch.quantile(p_pos, q=1-hard_factor, dim=1).unsqueeze(1)
        pos_loss = torch.sum(p_pos * torch.log(p_pos / p_da_pos).abs(), dim=1).mean()
        # Get Negative DA index
        
        if d_neg_fea is not None:
            d = torch.cat([d_neg, d_pos], 1)
            d_mask = torch.zeros_like(d)
#             d_mask[:, ]
        p_total = torch.softmax(d / tau, dim=1)
        # Out supervised loss.
        print(d_mask, d_mask.shape)
        neg_loss = -((d_mask * p_total.log()).sum(1) / (d_mask.sum(1))).mean()
#         print(pos_loss, neg_loss, l_kld)
        
        return pos_loss, indice_d, neg_loss, l_kld

In [5]:
l_neg_dict = {}
l_neg_total = {}
for epoch in range(75, 225, 50):
    student_ckpt = torch.load('../checkpoints/datafree-adadfkd/cifar10-vgg11-resnet18--infonce_retest_2_5_temp/checkpoint_{}.pth'.format(epoch - 1), map_location='cpu')
    print(student_ckpt.keys())
    buffer = student_ckpt['neg_bank']
    print(buffer.shape)

    tg = datafree.models.generator.DCGAN_Generator_CIFAR10(nz=512, ngf=64, nc=3, img_size=32, d=2, cond=False, type='normal', widen_factor=1)
    tg.load_state_dict(student_ckpt['G'])
    prepare_model(tg)
    
    student.load_state_dict(student_ckpt['state_dict'])
    z = torch.randn(512, 512).to(gpu)
    x = normalizer(tg(z))
    t_out, t_feat = teacher(x, return_features=True)

    s_out, s_feat= student(x, return_features=True)
    tau = 0.07
    t_feat = torch.nn.functional.normalize(t_feat, dim=-1)
    s_feat = torch.nn.functional.normalize(s_feat, dim=-1)

    l_neg = t_feat.T.cpu() @ buffer
    _, l_indices = torch.sort(l_neg, dim=1)
    n, ks = l_neg.size() 
    # Ring Length
    length = 0.8
    # epoch = 99
    begin_fraction = 0.25
    end_fraction = 0.75
    hard_factor = min(max(0, (1 - length) * (epoch - int(begin_fraction * 300)) / int(300 * (end_fraction-begin_fraction))), 1 - length)

    ring_indices = l_indices[:, int(hard_factor * ks) : int((hard_factor + length) * ks)]
    l_neg_ring = torch.gather(l_neg, dim=1, index=ring_indices)
    print(l_neg.shape, l_indices.shape, l_neg_ring.shape)
    l_neg_mean = l_neg_ring.mean(1).detach().numpy()
    l_neg_dict['Epoch {}'.format(epoch)] = l_neg_mean
    l_neg_total['Epoch {}'.format(epoch)] = l_neg_ring.detach().numpy()
    
print(l_neg_dict)

dict_keys(['epoch', 'arch', 'state_dict', 'best_acc1', 'optimizer', 'scheduler', 'G', 'neg_bank'])
torch.Size([512, 8192])
torch.Size([512, 8192]) torch.Size([512, 8192]) torch.Size([512, 6553])
dict_keys(['epoch', 'arch', 'state_dict', 'best_acc1', 'optimizer', 'scheduler', 'G', 'neg_bank'])
torch.Size([512, 8192])
torch.Size([512, 8192]) torch.Size([512, 8192]) torch.Size([512, 6553])
dict_keys(['epoch', 'arch', 'state_dict', 'best_acc1', 'optimizer', 'scheduler', 'G', 'neg_bank'])
torch.Size([512, 8192])
torch.Size([512, 8192]) torch.Size([512, 8192]) torch.Size([512, 6553])
{'Epoch 75': array([0.13276483, 0.18832718, 0.2120286 , 0.28232142, 0.26459154,
       0.21497492, 0.27963513, 0.19833717, 0.371825  , 0.2533015 ,
       0.33420828, 0.34755844, 0.46352285, 0.21470517, 0.14946872,
       0.2010201 , 0.21097696, 0.15700556, 0.17508355, 0.3166058 ,
       0.17741726, 0.19945748, 0.38317114, 0.37472472, 0.4086394 ,
       0.27323613, 0.3902353 , 0.23146543, 0.43364748, 0.30908725,


In [10]:
student.load_state_dict(student_ckpt['state_dict'])
z = torch.randn(512, 512).to(gpu)
x = normalizer(tg(z))
t_out, t_feat = teacher(x, return_features=True)

s_out, s_feat= student(x, return_features=True)
tau = 0.07
t_feat = torch.nn.functional.normalize(t_feat, dim=-1)
s_feat = torch.nn.functional.normalize(s_feat, dim=-1)

l_neg = t_feat.T.cpu() @ buffer
_, l_indices = torch.sort(l_neg, dim=1)
n, ks = l_neg.size() 
# Ring Length
length = 0.8
# epoch = 99
begin_fraction = 0.25
end_fraction = 0.75
hard_factor = min(max(0, (1 - length) * (epoch - int(begin_fraction * 300)) / int(300 * (end_fraction-begin_fraction))), 1 - length)

ring_indices = l_indices[:, int(hard_factor * ks) : int((hard_factor + length) * ks)]
l_neg_ring = torch.gather(l_neg, dim=1, index=ring_indices)
print(l_neg.shape, l_indices.shape, l_neg_ring.shape)

# balance = torch.nn.functional.cross_entropy(t_out, t_out.argmax(1), reduction='none')
# _ , index = torch.sort(balance)
# img_anchor = x[index[:10]].cpu()
# label = t_out[index[:10]].argmax(1)
# print(label)
# new_img = x.cpu()[indice_neg[index[:10]]]
# imgs = new_img.reshape(5120, 3, 32, 32)

torch.Size([512, 8192]) torch.Size([512, 8192]) torch.Size([512, 6553])


In [8]:

if not os.path.exists('Density_type1/'):
    os.mkdir('Density_type1/')
fig = plt.figure(figsize=(8,5))
sns.set()
# l_neg_mean = l_neg_ring.mean(1)
sns.displot(l_neg_dict, kind='kde',legend=True)

# plt.legend()
plt.xlabel(r'$\mathbf{f}_{t}^T\mathbf{f}_{neg}, \mathbf{f}_{neg} \sim q_{neg}$', fontsize=18)
plt.ylabel('Density', fontsize=18)
plt.savefig('Density_type1/Anchor_mean1.pdf')
# plt.close()
# sns.set()
# print(l_neg_ring[-3])
# for i in tqdm.tqdm(range(n), desc='Running'):
#     sns.set()
#     ax = sns.displot(l_neg_ring[i].detach().numpy(), kind='kde')
#     plt.xlabel(r'$\mathbf{f}_{t}^T\mathbf{f}_{neg}, \mathbf{f}_{neg} \sim q_{neg}$', fontsize=18)
#     plt.ylabel('Density', fontsize=18)
#     plt.savefig('Density_type1/Anchor{}.pdf'.format(i))
#     plt.close()

AttributeError: 'FacetGrid' object has no attribute 'transAxes'

<Figure size 576x360 with 0 Axes>

<Figure size 457.85x360 with 1 Axes>

In [32]:
imgs = 100
anchor_imgs = 100
vis_neg = {}
new_vis_neg = {}
import numpy as np
for epoch in [75, 125, 175]:
    vis_neg['Epoch {}'.format(epoch)] = []
    for i in range(imgs):
#         print(i)
        idx = i * (ring_indices.shape[1] // imgs)
        vis_neg['Epoch {}'.format(epoch)].append(l_neg_total['Epoch {}'.format(epoch)][:, idx])
    vis_neg['Epoch {}'.format(epoch)] = np.stack(vis_neg['Epoch {}'.format(epoch)], 1)
    new_vis_neg['Epoch {}'.format(epoch)] = []
    for j in range(anchor_imgs):
        idx = j * (n // anchor_imgs)
        new_vis_neg['Epoch {}'.format(epoch)].append(vis_neg['Epoch {}'.format(epoch)][idx])
    new_vis_neg['Epoch {}'.format(epoch)] = np.stack(new_vis_neg['Epoch {}'.format(epoch)], 0)


In [36]:
plt.figure(figsize=(8, 5))
sns.heatmap(new_vis_neg['Epoch 75'], vmin=0.05, vmax=0.28, cmap='coolwarm', xticklabels=False, yticklabels=False)
plt.savefig('heatmap75.pdf')

<Figure size 576x360 with 2 Axes>

In [37]:
plt.figure(figsize=(8, 5))
sns.heatmap(new_vis_neg['Epoch 125'], vmin=0.05, vmax=0.28, cmap='coolwarm', xticklabels=False, yticklabels=False)
plt.savefig('heatmap125.pdf')

<Figure size 576x360 with 2 Axes>

In [38]:
plt.figure(figsize=(8, 5))
sns.heatmap(new_vis_neg['Epoch 175'], vmin=0.05, vmax=0.28, cmap='coolwarm', xticklabels=False, yticklabels=False)
plt.savefig('heatmap175.pdf')

<Figure size 576x360 with 2 Axes>

In [10]:
print(balance)

tensor([3.5971e-02, 4.9913e-03, 4.8435e-04, 1.1491e-04, 6.2307e-02, 5.7498e-03,
        1.7951e-04, 6.6437e-04, 4.7887e-04, 6.1286e-01, 1.2127e-02, 3.6090e-04,
        1.7019e-03, 1.2994e-05, 2.5305e-02, 1.5854e-04, 4.3419e-04, 1.8120e-05,
        1.4477e-01, 3.5065e-04, 1.4989e-03, 1.1420e-04, 2.0437e-01, 2.1738e-02,
        3.8385e-05, 7.2150e-01, 8.9808e-04, 1.7343e-04, 3.8223e-04, 1.4066e-04,
        3.9219e-05, 8.0208e-04, 5.1520e-01, 1.2898e-04, 1.2065e-01, 3.7653e-03,
        2.8852e-03, 6.1629e-05, 4.8340e-04, 5.2118e-02, 2.6861e-02, 2.1320e-01,
        2.8856e-04, 2.6995e-01, 4.6598e-01, 2.6723e-04, 5.1318e-04, 2.8630e-03,
        3.2706e-04, 3.8578e-03, 2.0919e-01, 2.8749e-04, 4.1369e-04, 6.8200e-04,
        5.4666e-03, 1.8004e-02, 6.1529e-04, 1.0475e+00, 3.0418e-04, 4.0590e-02,
        2.5578e-02, 2.0443e-01, 8.7438e-04, 4.4751e-02, 1.0169e-02, 2.3386e-04,
        7.3311e-05, 1.8726e-04, 1.7403e-04, 4.6731e-04, 1.4777e-03, 7.5695e-05,
        9.5024e-04, 7.8147e-04, 3.0604e-

In [13]:
from torchvision.utils import save_image,make_grid
import matplotlib.pyplot as plt
img_list = []
for i in torch.arange(0.01, 1.01, 0.1):
    print(i)
    img_list.append(new_img[:, -int(512 * i), :, :, :].unsqueeze(1))
    
img_list = torch.cat(img_list, 1)
this_img = img_list.reshape(100, 3, 32, 32)
img = make_grid(img_anchor, nrow=1, padding=2, normalize=True)
neg_img = make_grid(this_img, nrow=10, padding=2, normalize=True)
print(img.max(), img.min())

tensor(0.0100)
tensor(0.1100)
tensor(0.2100)
tensor(0.3100)
tensor(0.4100)
tensor(0.5100)
tensor(0.6100)
tensor(0.7100)
tensor(0.8100)
tensor(0.9100)
tensor(1.) tensor(0.)


In [19]:
def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [14]:
save_image(img, 'anchor.jpg')
save_image(neg_img, 'neg2.jpg')