In [1]:
import torch
import numpy as np
from scipy.spatial import distance
import os

import sys
sys.path.append('..')
import registry
import datafree

In [2]:
!nvidia-smi

Sat Aug  6 04:37:12 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.60.02    Driver Version: 510.60.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:05:00.0 Off |                  N/A |
| 71%   65C    P2   320W / 350W |  16687MiB / 24576MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:09:00.0 Off |                  N/A |
| 51%   54C    P2   177W / 350W |  24043MiB / 24576MiB |     32%      Default |
|       

In [3]:
distributed = False
gpu = 2
batch_size = 128
workers = 8
num_classes = 10
def prepare_model(model):
    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
        return model
    elif distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(gpu)
            model.cuda(gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            batch_size = int(batch_size / ngpus_per_node)
            workers = int((workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
            return model
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
            return model
    elif gpu is not None:
        torch.cuda.set_device(gpu)
        model = model.cuda(gpu)
        return model
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()
        return model

In [4]:
from torchvision.datasets import CIFAR10,CIFAR100
import datafree
import registry
student = registry.get_model('resnet18', num_classes=num_classes)
teacher = registry.get_model('resnet34', num_classes=num_classes, pretrained=True).eval()
normalizer = datafree.utils.Normalizer(**registry.NORMALIZE_DICT['cifar10'])
student = prepare_model(student)
teacher = prepare_model(teacher)
teacher.load_state_dict(torch.load('../checkpoints/scratch/cifar10_resnet34.pth', map_location='cpu')['state_dict'])
teacher.eval()


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [5]:
!ls /data/lijingru/DataFree/checkpoints/datafree-dafl/

cifar10-resnet34-resnet18--dafl_interval.pth
cifar10-vgg11-resnet18--dafl_interval_2.pth
log-cifar10-resnet34-resnet18-dafl_interval.txt
log-cifar10-vgg11-resnet18-dafl_interval_2.txt
log-cifar10-wrn40_2-wrn16_1-dafl_interval.txt
synthetic-dafl_interval_2.png
synthetic-dafl_interval.png


In [6]:
# generator = datafree.models.generator.LargeGenerator(nz=512, ngf=64, img_size=32, nc=3)
# generator = prepare_model(generator)
g = datafree.models.generator.DCGAN_Generator_CIFAR10(nz=512, ngf=64, nc=3, img_size=32, d=2, cond=False)
g = prepare_model(g)
ckpt = torch.load('/data/lijingru/DataFree/checkpoints/datafree-cudfkd/cifar10-resnet34-resnet18--cudfkd_L2_line98_2.pth', map_location='cpu')
print(ckpt.keys())
G_ckpt = ckpt['G_0']
student.load_state_dict(ckpt['state_dict'])
g.load_state_dict(G_ckpt)

dict_keys(['epoch', 'arch', 'state_dict', 'best_acc1', 'optimizer', 'scheduler', 'G_0'])


<All keys matched successfully>

In [10]:
normalizer = datafree.utils.Normalizer(**registry.NORMALIZE_DICT['cifar10'])
z = torch.randn(batch_size, 512).to(gpu)
x = g(z)
t_out = teacher(normalizer(x))
s_out = student(normalizer(x))
t_prob = torch.softmax(t_out, 1).detach().cpu().numpy()
s_prob = torch.softmax(s_out, 1).detach().cpu().numpy()
js = distance.jensenshannon(t_prob.T, s_prob.T)
import numpy as np
prob = 1 - np.sqrt(js)
print(prob.mean())

0.41153455
