In [1]:
!pip install import_ipynb
import import_ipynb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


## Train on the original dataset to establish an upper bound

In [None]:
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
import time 


class myArgs:
  """Plain attribute bag holding the experiment hyper-parameters.

  Consumers read options as ``args.<name>`` and attach further derived
  attributes (device, loop counts, ...) to the instance after construction.
  """

  def __init__(self):
    defaults = (
        ('method', 'DC'),             # DC/DSA
        ('dataset', 'MNIST'),         # dataset name
        ('model', 'ConvNet'),         # network architecture
        ('ipc', 10),                  # image(s) per class
        ('eval_mode', 'S'),           # S: evaluate with the same architecture as the training model
        ('num_exp', 1),               # number of experiments
        ('num_eval', 5),              # number of randomly initialized models to evaluate
        ('epoch_eval_train', 300),    # epochs to train a model with synthetic data
        ('Iteration', 100),           # training iterations
        ('lr_img', 0.1),              # learning rate for updating synthetic images
        ('lr_net', 0.01),             # learning rate for updating network parameters
        ('batch_real', 256),          # batch size for real data
        ('batch_train', 256),         # batch size for training networks
        ('init', 'real'),             # noise/real: how synthetic images are initialized
        ('dsa_strategy', 'None'),     # differentiable Siamese augmentation strategy
        ('data_path', 'data'),        # dataset path
        ('save_path', 'result'),      # path to save results
        ('dis_metric', 'ours'),       # distance metric
    )
    # Assign one by one, in this order, so vars(self) keeps the declaration order.
    for option, value in defaults:
      setattr(self, option, value)



# Build the run configuration and attach the settings derived from it.
args = myArgs()
args.outer_loop, args.inner_loop = get_loops(args.ipc)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.dsa_param = ParamDiffAug()
args.dsa = args.method == 'DSA'


channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)

net = get_network(args.model, channel, num_classes, im_size).to(args.device)

net.train()
# Optimizer for the network weights; the learning rate comes from the config
# (args.lr_net) instead of a duplicated literal.
optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)
optimizer_net.zero_grad()

criterion = nn.CrossEntropyLoss().to(args.device)

num_epochs = 20

# Train on the full real training split.
trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_net, num_epochs)

start = time.time()

for i in range(num_epochs):
  loss_avg, acc_avg = epoch('train', trainloader, net, optimizer_net, criterion, args, aug=args.dsa)
  # Printed before scheduler.step(), so this is the rate that was used for the epoch just finished.
  for param_group in optimizer_net.param_groups:
    print("Current learning rate is: {}".format(param_group['lr']))

  print("Epoch[{}/{}]: Loss: {:.4f}, ACC: {:.4f}".format(i+1,num_epochs, loss_avg, acc_avg))
  scheduler.step()

# Total wall-clock training time in seconds.
print(time.time() - start)


Current learning rate is: 0.01
Epoch[1/20]: Loss: 0.5307, ACC: 0.9073
Current learning rate is: 0.009938441702975689
Epoch[2/20]: Loss: 0.1495, ACC: 0.9681
Current learning rate is: 0.009755282581475769
Epoch[3/20]: Loss: 0.1055, ACC: 0.9756
Current learning rate is: 0.00945503262094184
Epoch[4/20]: Loss: 0.0852, ACC: 0.9799
Current learning rate is: 0.009045084971874739
Epoch[5/20]: Loss: 0.0728, ACC: 0.9824
Current learning rate is: 0.008535533905932738
Epoch[6/20]: Loss: 0.0645, ACC: 0.9843
Current learning rate is: 0.007938926261462366
Epoch[7/20]: Loss: 0.0591, ACC: 0.9857
Current learning rate is: 0.007269952498697735
Epoch[8/20]: Loss: 0.0545, ACC: 0.9868
Current learning rate is: 0.006545084971874738
Epoch[9/20]: Loss: 0.0508, ACC: 0.9875
Current learning rate is: 0.005782172325201155
Epoch[10/20]: Loss: 0.0483, ACC: 0.9883
Current learning rate is: 0.005
Epoch[11/20]: Loss: 0.0460, ACC: 0.9888
Current learning rate is: 0.004217827674798847
Epoch[12/20]: Loss: 0.0442, ACC: 0.98

In [None]:
# Evaluate the trained network on the held-out test split.
# Shuffling is unnecessary for evaluation, so the loader keeps the dataset order.
testloader = torch.utils.data.DataLoader(dst_test, batch_size=args.batch_train, shuffle=False, num_workers=0)
net.eval()
# No gradients are needed for evaluation.
# NOTE(review): assumes epoch() performs no backward pass in 'test' mode — confirm against utils.epoch.
with torch.no_grad():
  loss_avg, acc_avg = epoch('test', testloader, net, optimizer_net, criterion, args, aug=args.dsa)

print("Test set: Loss: {:.4f}, ACC: {:.4f}".format(loss_avg, acc_avg))

Test set: Loss: 0.0345, ACC: 0.9906


## Dataset distillation with gradient matching

In [None]:
import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug


class myArgs:
  """Plain attribute bag that replaces the parsed command-line arguments.

  Consumers read options as ``args.<name>`` and attach further derived
  attributes (device, loop counts, ...) to the instance after construction.
  """

  def __init__(self):
    defaults = (
        ('method', 'DC'),             # DC/DSA
        ('dataset', 'MNIST'),         # dataset name
        ('model', 'ConvNet'),         # network architecture
        ('ipc', 10),                  # image(s) per class
        ('eval_mode', 'S'),           # S: evaluate with the same architecture as the training model
        ('num_exp', 1),               # number of experiments
        ('num_eval', 5),              # number of randomly initialized models to evaluate
        ('epoch_eval_train', 300),    # epochs to train a model with synthetic data
        ('Iteration', 100),           # training iterations
        ('lr_img', 0.1),              # learning rate for updating synthetic images
        ('lr_net', 0.01),             # learning rate for updating network parameters
        ('batch_real', 256),          # batch size for real data
        ('batch_train', 256),         # batch size for training networks
        ('init', 'noise'),            # noise/real: how synthetic images are initialized
        ('dsa_strategy', 'None'),     # differentiable Siamese augmentation strategy
        ('data_path', 'data'),        # dataset path
        ('save_path', 'result'),      # path to save results
        ('dis_metric', 'ours'),       # distance metric
    )
    # Assign one by one, in this order, so vars(self) keeps the declaration order.
    for option, value in defaults:
      setattr(self, option, value)

def main():
    """Distill a small synthetic training set by gradient matching and evaluate it.

    Configuration comes from ``myArgs()``; the commented-out argparse block
    below is kept as a record of the command-line options it replaces.

    For each of ``args.num_exp`` experiments the function:

    1. Loads every training sample into one tensor on ``args.device`` and
       indexes the samples by class.
    2. Creates ``args.ipc`` learnable synthetic images per class, starting
       from random noise or from random real images (``args.init``).
    3. Runs ``args.Iteration + 1`` iterations. Each one draws a fresh network
       and alternates between updating the synthetic images (per-class
       ``match_loss`` between gradients on real and synthetic batches) and
       training the network on the current synthetic set.
    4. At the iterations listed in ``eval_it_pool``, trains ``args.num_eval``
       fresh networks on the synthetic set via ``evaluate_synset``, prints the
       test accuracy and saves an image grid of the synthetic set.

    Side effects: creates ``args.data_path`` and ``args.save_path`` if missing,
    writes PNG grids and a final ``res_*.pt`` file into ``args.save_path``, and
    prints progress. Returns ``None``.
    """

    # parser = argparse.ArgumentParser(description='Parameter Processing')
    # parser.add_argument('--method', type=str, default='DC', help='DC/DSA')
    # parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
    # parser.add_argument('--model', type=str, default='ConvNet', help='model')
    # parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
    # parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,
    # parser.add_argument('--num_exp', type=int, default=5, help='the number of experiments')
    # parser.add_argument('--num_eval', type=int, default=20, help='the number of evaluating randomly initialized models')
    # parser.add_argument('--epoch_eval_train', type=int, default=300, help='epochs to train a model with synthetic data')
    # parser.add_argument('--Iteration', type=int, default=1000, help='training iterations')
    # parser.add_argument('--lr_img', type=float, default=0.1, help='learning rate for updating synthetic images')
    # parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
    # parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    # parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    # parser.add_argument('--init', type=str, default='noise', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    # parser.add_argument('--dsa_strategy', type=str, default='None', help='differentiable Siamese augmentation strategy')
    # parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    # parser.add_argument('--save_path', type=str, default='result', help='path to save results')
    # parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')

    # args = parser.parse_args()
    args = myArgs()
    args.outer_loop, args.inner_loop = get_loops(args.ipc)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    args.dsa = args.method == 'DSA'


    # makedirs(exist_ok=True) creates missing parent directories and does not
    # fail when the directory already exists (no check-then-create race).
    os.makedirs(args.data_path, exist_ok=True)
    os.makedirs(args.save_path, exist_ok=True)

    eval_it_pool = np.arange(0, args.Iteration+1, 20).tolist() if args.eval_mode == 'S' or args.eval_mode == 'SS' else [args.Iteration] # The list of iterations when we evaluate models and record results.
    # Final accuracies are only recorded when the last iteration is evaluated,
    # so make sure it is in the pool even if Iteration is not a multiple of 20.
    if args.Iteration not in eval_it_pool:
        eval_it_pool.append(args.Iteration)
    print('eval_it_pool: ', eval_it_pool)
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)


    accs_all_exps = dict() # record performances of all experiments
    for key in model_eval_pool:
        accs_all_exps[key] = []

    data_save = []


    for exp in range(args.num_exp):
        print('\n================== Exp %d ==================\n '%exp)
        print('Hyper-parameters: \n', args.__dict__)
        print('Evaluation model pool: ', model_eval_pool)

        ''' organize the real dataset '''
        images_all = []
        labels_all = []
        indices_class = [[] for c in range(num_classes)]

        # Fetch every sample exactly once; each dst_train[i] lookup may run the
        # dataset's own loading/transform code, so a second pass would repeat it.
        for i in range(len(dst_train)):
            sample = dst_train[i]
            images_all.append(torch.unsqueeze(sample[0], dim=0))
            labels_all.append(sample[1])
            indices_class[sample[1]].append(i)
        images_all = torch.cat(images_all, dim=0).to(args.device)
        labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)

        for c in range(num_classes):
            print('class c = %d: %d real images'%(c, len(indices_class[c])))

        def get_images(c, n): # get random n images from class c
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle]

        for ch in range(channel):
            print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))


        ''' initialize the synthetic data '''
        image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
        # Each class index repeated ipc times in a row: [0,0,0, 1,1,1, ..., 9,9,9]
        label_syn = torch.arange(num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)

        if args.init == 'real':
            print('initialize synthetic data from random real images')
            for c in range(num_classes):
                image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
        else:
            print('initialize synthetic data from random noise')


        ''' training '''
        optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
        optimizer_img.zero_grad()
        criterion = nn.CrossEntropyLoss().to(args.device)
        print('%s training begins'%get_time())

        for it in range(args.Iteration+1):

            ''' Evaluate synthetic data '''
            if it in eval_it_pool:
                for model_eval in model_eval_pool:
                    print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
                    if args.dsa:
                        args.epoch_eval_train = 1000
                        args.dc_aug_param = None
                        print('DSA augmentation strategy: \n', args.dsa_strategy)
                        print('DSA augmentation parameters: \n', args.dsa_param.__dict__)
                    else:
                        args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc) # This augmentation parameter set is only for DC method. It will be muted when args.dsa is True.
                        print('DC augmentation parameters: \n', args.dc_aug_param)

                    # args.dsa is checked first, so dc_aug_param (None under DSA) is never indexed in that case.
                    if args.dsa or args.dc_aug_param['strategy'] != 'none':
                        args.epoch_eval_train = 1000  # Training with data augmentation needs more epochs.
                    else:
                        args.epoch_eval_train = 300

                    accs = []
                    for it_eval in range(args.num_eval):
                        net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
                        image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification
                        _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                        accs.append(acc_test)
                    print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

                    if it == args.Iteration: # record the final results
                        accs_all_exps[model_eval] += accs

                ''' visualize and save '''
                save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
                image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
                # Undo the per-channel normalization, then clamp to [0, 1] for saving.
                for ch in range(channel):
                    image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]
                image_syn_vis[image_syn_vis<0] = 0.0
                image_syn_vis[image_syn_vis>1] = 1.0
                save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.


            ''' Train synthetic data '''
            net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
            net.train()
            net_parameters = list(net.parameters())
            optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)  # optimizer for the network weights
            optimizer_net.zero_grad()
            loss_avg = 0
            args.dc_aug_param = None  # Mute the DC augmentation when learning synthetic data (in inner-loop epoch function) in order to be consistent with DC paper.


            for ol in range(args.outer_loop):

                ''' freeze the running mu and sigma for BatchNorm layers '''
                # Synthetic data batch, e.g. only 1 image/batch, is too small to obtain stable mu and sigma.
                # So, we calculate and freeze mu and sigma for BatchNorm layer with real data batch ahead.
                # This would make the training with BatchNorm layers easier.

                BN_flag = False
                BNSizePC = 16  # for batch normalization
                for module in net.modules():
                    if 'BatchNorm' in module._get_name(): #BatchNorm
                        BN_flag = True
                if BN_flag:
                    img_real = torch.cat([get_images(c, BNSizePC) for c in range(num_classes)], dim=0)
                    net.train() # for updating the mu, sigma of BatchNorm
                    output_real = net(img_real) # get running mu, sigma
                    for module in net.modules():
                        if 'BatchNorm' in module._get_name():  #BatchNorm
                            module.eval() # fix mu and sigma of every BatchNorm layer


                ''' update synthetic data '''
                loss = torch.tensor(0.0).to(args.device)
                for c in range(num_classes):
                    img_real = get_images(c, args.batch_real)
                    lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c
                    img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
                    lab_syn = torch.ones((args.ipc,), device=args.device, dtype=torch.long) * c

                    if args.dsa:
                        # Same seed for both calls so real and synthetic batches share one augmentation seed.
                        seed = int(time.time() * 1000) % 100000
                        img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                        img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    # Gradients on real data are a fixed target: detached and cloned.
                    output_real = net(img_real)
                    loss_real = criterion(output_real, lab_real)
                    gw_real = torch.autograd.grad(loss_real, net_parameters)
                    gw_real = list((_.detach().clone() for _ in gw_real))

                    # create_graph=True keeps the graph so the matching loss can be
                    # back-propagated to the synthetic images.
                    output_syn = net(img_syn)
                    loss_syn = criterion(output_syn, lab_syn)
                    gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)

                    loss += match_loss(gw_syn, gw_real, args)

                optimizer_img.zero_grad()
                loss.backward()
                optimizer_img.step()
                loss_avg += loss.item()

                # The network update after the last image update would be unused: skip it.
                if ol == args.outer_loop - 1:
                    break


                ''' update network '''
                image_syn_train, label_syn_train = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())  # avoid any unaware modification
                dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
                trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
                for il in range(args.inner_loop):
                    epoch('train', trainloader, net, optimizer_net, criterion, args, aug=args.dsa)


            loss_avg /= (num_classes*args.outer_loop)

            if it%10 == 0:
                print('%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))

            if it == args.Iteration: # only record the final results
                data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
                torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))


    print('\n==================== Final Results ====================\n')
    for key in model_eval_pool:
        accs = accs_all_exps[key]
        print('Run %d experiments, train on %s, evaluate %d random %s, mean  = %.2f%%  std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100))


In [None]:
# Run the gradient-matching distillation defined in main().
# NOTE(review): the recorded output below reports 'init': 'real', while the myArgs
# definition shown in this notebook sets init = 'noise' — presumably produced by an
# earlier version of the config; re-run to refresh the output.
main()

eval_it_pool:  [0, 20, 40, 60, 80, 100]

 
Hyper-parameters: 
 {'method': 'DC', 'dataset': 'MNIST', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 5, 'epoch_eval_train': 300, 'Iteration': 100, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 256, 'batch_train': 256, 'init': 'real', 'dsa_strategy': 'None', 'data_path': 'data', 'save_path': 'result', 'dis_metric': 'ours', 'outer_loop': 10, 'inner_loop': 50, 'device': 'cuda', 'dsa_param': <utils.ParamDiffAug object at 0x7efe3ea05750>, 'dsa': False}
Evaluation model pool:  ['ConvNet']
class c = 0: 5923 real images
class c = 1: 6742 real images
class c = 2: 5958 real images
class c = 3: 6131 real images
class c = 4: 5842 real images
class c = 5: 5421 real images
class c = 6: 5918 real images
class c = 7: 6265 real images
class c = 8: 5851 real images
class c = 9: 5949 real images
real images channel 0, mean = -0.0001, std = 1.0000
initialize synthetic data from random real images
[2022-11-26 00:05:57] training begins

In [None]:
# Run the gradient-matching distillation defined in main().
# NOTE(review): the recorded output below reports dataset 'CIFAR10' and model
# 'ResNet18BN_AP', which do not match the myArgs definition shown in this notebook
# (MNIST / ConvNet) — presumably produced by an edited config; re-run to refresh.
main()

eval_it_pool:  [0, 20, 40, 60, 80, 100]
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Attention: Here I will replace BN with IN in evaluation, as the synthetic set is too small to measure BN hyper-parameters.

 
Hyper-parameters: 
 {'method': 'DC', 'dataset': 'CIFAR10', 'model': 'ResNet18BN_AP', 'ipc': 10, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 5, 'epoch_eval_train': 300, 'Iteration': 100, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 256, 'batch_train': 256, 'init': 'real', 'dsa_strategy': 'None', 'data_path': 'data', 'save_path': 'result', 'dis_metric': 'ours', 'outer_loop': 10, 'inner_loop': 50, 'device': 'cuda', 'dsa_param': <utils.ParamDiffAug object at 0x7f8a44d56e50>, 'dsa': False}
Evaluation model pool:  ['ResNet18']
class c = 0: 5000 real images
class c = 1: 5000 real images
class c = 2: 5000 real images
class c = 3: 5000 real images
class c = 4: 5000 real images
class c = 5: 5000 real images
class c = 6: 5000 real images
class c = 7: 5000 real images
class c = 8: 5000 real



[2022-11-26 03:01:56] Evaluate_00: epoch = 0300 train time = 33 s train loss = 0.005744 train acc = 1.0000, test acc = 0.1986
[2022-11-26 03:02:29] Evaluate_01: epoch = 0300 train time = 27 s train loss = 0.005755 train acc = 1.0000, test acc = 0.2089
[2022-11-26 03:03:02] Evaluate_02: epoch = 0300 train time = 28 s train loss = 0.005846 train acc = 1.0000, test acc = 0.1960
[2022-11-26 03:03:35] Evaluate_03: epoch = 0300 train time = 28 s train loss = 0.005902 train acc = 1.0000, test acc = 0.2062
[2022-11-26 03:04:09] Evaluate_04: epoch = 0300 train time = 29 s train loss = 0.005653 train acc = 1.0000, test acc = 0.2030
Evaluate 5 random ResNet18, mean = 0.2025 std = 0.0047
-------------------------
[2022-11-26 03:05:30] iter = 0000, loss = 1895.7798
[2022-11-26 03:18:59] iter = 0010, loss = 2683.2079
-------------------------
Evaluation
model_train = ResNet18BN_AP, model_eval = ResNet18, iteration = 20
DC augmentation parameters: 
 {'crop': 4, 'scale': 0.2, 'rotate': 45, 'noise': 0.

In [None]:
# Run the gradient-matching distillation defined in main() (MNIST, ConvNet, init = 'noise' per the recorded output).
main()

eval_it_pool:  [0, 20, 40, 60, 80, 100]

 
Hyper-parameters: 
 {'method': 'DC', 'dataset': 'MNIST', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 5, 'epoch_eval_train': 300, 'Iteration': 100, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 256, 'batch_train': 256, 'init': 'noise', 'dsa_strategy': 'None', 'data_path': 'data', 'save_path': 'result', 'dis_metric': 'ours', 'outer_loop': 10, 'inner_loop': 50, 'device': 'cuda', 'dsa_param': <utils.ParamDiffAug object at 0x7fc13e328210>, 'dsa': False}
Evaluation model pool:  ['ConvNet']
class c = 0: 5923 real images
class c = 1: 6742 real images
class c = 2: 5958 real images
class c = 3: 6131 real images
class c = 4: 5842 real images
class c = 5: 5421 real images
class c = 6: 5918 real images
class c = 7: 6265 real images
class c = 8: 5851 real images
class c = 9: 5949 real images
real images channel 0, mean = -0.0001, std = 1.0000
initialize synthetic data from random noise
[2022-11-27 16:30:24] training begins
----



[2022-11-27 16:31:00] Evaluate_00: epoch = 1000 train time = 34 s train loss = 0.012398 train acc = 1.0000, test acc = 0.0855
[2022-11-27 16:31:37] Evaluate_01: epoch = 1000 train time = 35 s train loss = 0.008104 train acc = 1.0000, test acc = 0.0779
[2022-11-27 16:32:14] Evaluate_02: epoch = 1000 train time = 34 s train loss = 0.019059 train acc = 1.0000, test acc = 0.0687
[2022-11-27 16:32:51] Evaluate_03: epoch = 1000 train time = 34 s train loss = 0.010604 train acc = 1.0000, test acc = 0.0666
[2022-11-27 16:33:28] Evaluate_04: epoch = 1000 train time = 34 s train loss = 0.010578 train acc = 1.0000, test acc = 0.0496
Evaluate 5 random ConvNet, mean = 0.0697 std = 0.0121
-------------------------
[2022-11-27 16:33:37] iter = 0000, loss = 221.2675
[2022-11-27 16:35:10] iter = 0010, loss = 93.1256
-------------------------
Evaluation
model_train = ConvNet, model_eval = ConvNet, iteration = 20
DC augmentation parameters: 
 {'crop': 4, 'scale': 0.2, 'rotate': 45, 'noise': 0.001, 'strat

In [None]:
# Archive the distillation outputs in /content/result and download the zip
# through the Colab file helper (only available inside a Colab runtime).
!zip -r /content/result.zip /content/result
from google.colab import files
files.download("/content/result.zip")

updating: content/result/ (stored 0%)
updating: content/result/vis_DC_MNIST_ConvNet_10ipc_exp0_iter60.png (deflated 5%)
updating: content/result/vis_DC_MNIST_ConvNet_10ipc_exp0_iter40.png (deflated 4%)
updating: content/result/res_DC_MNIST_ConvNet_10ipc.pt (deflated 9%)
updating: content/result/vis_DC_MNIST_ConvNet_10ipc_exp0_iter100.png (deflated 5%)
updating: content/result/vis_DC_MNIST_ConvNet_10ipc_exp0_iter20.png (deflated 4%)
updating: content/result/vis_DC_MNIST_ConvNet_10ipc_exp0_iter0.png (deflated 5%)
updating: content/result/vis_DC_MNIST_ConvNet_10ipc_exp0_iter80.png (deflated 5%)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Load the distilled dataset and train a model on it

In [None]:
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
import time

class myArgs:
  """Plain attribute bag holding the experiment hyper-parameters.

  Consumers read options as ``args.<name>`` and attach further derived
  attributes (device, loop counts, ...) to the instance after construction.
  """

  def __init__(self):
    defaults = (
        ('method', 'DC'),             # DC/DSA
        ('dataset', 'MNIST'),         # dataset name
        ('model', 'ConvNet'),         # network architecture
        ('ipc', 10),                  # image(s) per class
        ('eval_mode', 'S'),           # S: evaluate with the same architecture as the training model
        ('num_exp', 1),               # number of experiments
        ('num_eval', 5),              # number of randomly initialized models to evaluate
        ('epoch_eval_train', 300),    # epochs to train a model with synthetic data
        ('Iteration', 100),           # training iterations
        ('lr_img', 0.1),              # learning rate for updating synthetic images
        ('lr_net', 0.01),             # learning rate for updating network parameters
        ('batch_real', 256),          # batch size for real data
        ('batch_train', 256),         # batch size for training networks
        ('init', 'real'),             # noise/real: how synthetic images are initialized
        ('dsa_strategy', 'None'),     # differentiable Siamese augmentation strategy
        ('data_path', 'data'),        # dataset path
        ('save_path', 'result'),      # path to save results
        ('dis_metric', 'ours'),       # distance metric
    )
    # Assign one by one, in this order, so vars(self) keeps the declaration order.
    for option, value in defaults:
      setattr(self, option, value)

# Build the run configuration and attach the settings derived from it.
args = myArgs()
args.outer_loop, args.inner_loop = get_loops(args.ipc)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.dsa_param = ParamDiffAug()
args.dsa = args.method == 'DSA'


# Locate the result file with the same name pattern main() uses when saving,
# relative to args.save_path, instead of a hard-coded absolute Colab path.
distilled_set_name = '%s/res_%s_%s_%s_%dipc.pt' % (args.save_path, args.method, args.dataset, args.model, args.ipc)
dataset = torch.load(distilled_set_name)
# 'data' holds one [images, labels] pair per experiment; use the first experiment.
images_distilled, labels_distilled = dataset['data'][0]
syn_data = TensorDataset(images_distilled, labels_distilled)

channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
net = get_network(args.model, channel, num_classes, im_size).to(args.device)


net.train()
# Optimizer for the network weights; the learning rate comes from the config
# (args.lr_net) instead of a duplicated literal.
optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)
optimizer_net.zero_grad()

criterion = nn.CrossEntropyLoss().to(args.device)

num_epochs = 20

# Train on the distilled set only, one synthetic sample per batch.
trainloader = torch.utils.data.DataLoader(syn_data, batch_size=1, shuffle=True, num_workers=0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_net, num_epochs)


start = time.time()

for i in range(num_epochs):
  loss_avg, acc_avg = epoch('train', trainloader, net, optimizer_net, criterion, args, aug=args.dsa)
  # Printed before scheduler.step(), so this is the rate that was used for the epoch just finished.
  for param_group in optimizer_net.param_groups:
    print("Current learning rate is: {}".format(param_group['lr']))

  print("Epoch[{}/{}]: Loss: {:.4f}, ACC: {:.4f}".format(i+1,num_epochs, loss_avg, acc_avg))
  scheduler.step()

# Total wall-clock training time in seconds.
print(time.time() - start)

# Evaluate on the real test split. Shuffling is unnecessary for evaluation.
testloader = torch.utils.data.DataLoader(dst_test, batch_size=args.batch_train, shuffle=False, num_workers=0)
net.eval()
# No gradients are needed for evaluation.
# NOTE(review): assumes epoch() performs no backward pass in 'test' mode — confirm against utils.epoch.
with torch.no_grad():
  loss_avg, acc_avg = epoch('test', testloader, net, optimizer_net, criterion, args, aug=args.dsa)

print("Test set: Loss: {:.4f}, ACC: {:.4f}".format(loss_avg, acc_avg))

Current learning rate is: 0.01
Epoch[1/20]: Loss: 3.5412, ACC: 0.0400
Current learning rate is: 0.009938441702975689
Epoch[2/20]: Loss: 1.6428, ACC: 0.3900
Current learning rate is: 0.009755282581475769
Epoch[3/20]: Loss: 1.0073, ACC: 0.6000
Current learning rate is: 0.00945503262094184
Epoch[4/20]: Loss: 0.4715, ACC: 0.8600
Current learning rate is: 0.009045084971874739
Epoch[5/20]: Loss: 0.2918, ACC: 0.9700
Current learning rate is: 0.008535533905932738
Epoch[6/20]: Loss: 0.0921, ACC: 1.0000
Current learning rate is: 0.007938926261462366
Epoch[7/20]: Loss: 0.0562, ACC: 1.0000
Current learning rate is: 0.007269952498697735
Epoch[8/20]: Loss: 0.0385, ACC: 1.0000
Current learning rate is: 0.006545084971874738
Epoch[9/20]: Loss: 0.0303, ACC: 1.0000
Current learning rate is: 0.005782172325201155
Epoch[10/20]: Loss: 0.0263, ACC: 1.0000
Current learning rate is: 0.005
Epoch[11/20]: Loss: 0.0230, ACC: 1.0000
Current learning rate is: 0.004217827674798847
Epoch[12/20]: Loss: 0.0211, ACC: 1.00

## Test trajectory matching results

In [None]:
!pip install kornia
!pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.13.5-py2.py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 36.4 MB/s 
[?25hCollecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.11-py3-none-any.whl (10 kB)
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle
  Downloading setproctitle-1.3.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.29-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 68.9 MB/s 
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.11.1-py2.py3-none-any.whl (168 kB)
[K     |█████

In [None]:
# Run the trajectory-matching distillation script (distill.py) on CIFAR10 with a
# ConvNet and 10 images per class; the remaining flags are passed through to the script.
!python distill.py --dataset=CIFAR10 --model=ConvNet --ipc=10 --syn_steps=30 --expert_epochs=2 --max_start_epoch=15 --Iteration=1000 --num_eval=2

CUDNN STATUS: True
Files already downloaded and verified
Files already downloaded and verified
[34m[1mwandb[0m: Currently logged in as: [33mfredshi1997[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.13.5
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/content/wandb/run-20221129_205400-3axihmoj[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mclear-river-13[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/fredshi1997/dip[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/fredshi1997/dip/runs/3axihmoj[0m
Hyper-parameters: 
 {'dataset': 'CIFAR10', 'subset': 'imagenette', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_eval': 2, 'eval_it': 100, 'epoch_eval_train': 1000, 'Iteration': 1000, 'lr_img': 1000, 'lr_lr': 1e-05, 'lr_teacher': 0.01, 'lr_init': 0.01, 'batch_real': 256, 'batch_syn': 100, 'batch_train': 256,

In [None]:
# Archive the /content/logged_files directory and download the zip through the
# Colab file helper (only available inside a Colab runtime).
!zip -r /content/logged_files.zip /content/logged_files
from google.colab import files
files.download("/content/logged_files.zip")

  adding: content/logged_files/ (stored 0%)
  adding: content/logged_files/CIFAR10/ (stored 0%)
  adding: content/logged_files/CIFAR10/.ipynb_checkpoints/ (stored 0%)
  adding: content/logged_files/CIFAR10/clear-river-13/ (stored 0%)
  adding: content/logged_files/CIFAR10/clear-river-13/labels_1000.pt (deflated 76%)
  adding: content/logged_files/CIFAR10/clear-river-13/images_0.pt (deflated 65%)
  adding: content/logged_files/CIFAR10/clear-river-13/labels_best.pt (deflated 76%)
  adding: content/logged_files/CIFAR10/clear-river-13/images_500.pt (deflated 8%)
  adding: content/logged_files/CIFAR10/clear-river-13/images_1000.pt (deflated 8%)
  adding: content/logged_files/CIFAR10/clear-river-13/images_900.pt (deflated 8%)
  adding: content/logged_files/CIFAR10/clear-river-13/labels_500.pt (deflated 76%)
  adding: content/logged_files/CIFAR10/clear-river-13/labels_700.pt (deflated 76%)
  adding: content/logged_files/CIFAR10/clear-river-13/images_600.pt (deflated 8%)
  adding: content/logg

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# Unpack logged_files.zip relative to the current working directory.
# NOTE(review): the recorded output shows members stored under content/logged_files/,
# so extraction presumably lands in ./content/logged_files/ — confirm this matches
# the /content/logged_files/... paths loaded in the next cell.
!unzip logged_files.zip

Archive:  /content/logged_files.zip
   creating: content/logged_files/
   creating: content/logged_files/CIFAR10/
   creating: content/logged_files/CIFAR10/atomic-field-23/
  inflating: content/logged_files/CIFAR10/atomic-field-23/vis_CIFAR10_ConvNet_10ipc_iter100.png  
  inflating: content/logged_files/CIFAR10/atomic-field-23/vis_CIFAR10_ConvNet_10ipc_iter300.png  
  inflating: content/logged_files/CIFAR10/atomic-field-23/labels_200.pt  
  inflating: content/logged_files/CIFAR10/atomic-field-23/labels_600.pt  
  inflating: content/logged_files/CIFAR10/atomic-field-23/vis_CIFAR10_ConvNet_10ipc_iter1000.png  
  inflating: content/logged_files/CIFAR10/atomic-field-23/images_300.pt  
  inflating: content/logged_files/CIFAR10/atomic-field-23/labels_500.pt  
  inflating: content/logged_files/CIFAR10/atomic-field-23/images_200.pt  
  inflating: content/logged_files/CIFAR10/atomic-field-23/images_1000.pt  
  inflating: content/logged_files/CIFAR10/atomic-field-23/vis_CIFAR10_ConvNet_10ipc_ite

Train a model on the dataset distilled by trajectory matching (synthetic images initialized from real images).

In [None]:
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image

class myArgs:
  """Fixed hyper-parameter bundle used in place of parsed command-line arguments."""

  def __init__(self):
    # vars(self) is the instance __dict__; the keyword order below is the
    # attribute order seen when the object's __dict__ is printed.
    vars(self).update(
      method='DC',
      dataset='CIFAR10',
      model='ConvNet',
      ipc=10,
      eval_mode='S',
      num_exp=1,
      num_eval=5,
      epoch_eval_train=300,
      Iteration=100,
      lr_img=0.1,
      lr_net=0.01,
      batch_real=256,
      batch_train=256,
      init='real',
      dsa_strategy='None',
      data_path='data',
      save_path='result',
      dis_metric='ours',
    )

# ---- Run configuration ----
args = myArgs()
args.outer_loop, args.inner_loop = get_loops(args.ipc)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.dsa_param = ParamDiffAug()
args.dsa = args.method == 'DSA'


# ---- Distilled (synthetic) training set ----
distilled_set_name = '/content/logged_files/CIFAR10/clear-river-13/images_900.pt'
distilled_label_name = '/content/logged_files/CIFAR10/clear-river-13/labels_900.pt'
# map_location='cpu' lets tensors that were saved from a GPU session load on a
# CPU-only runtime.
# NOTE(review): relies on epoch() moving each batch to args.device -- confirm in utils.
data = torch.load(distilled_set_name, map_location='cpu')
label = torch.load(distilled_label_name, map_location='cpu')

print(len(label))


syn_data = TensorDataset(data, label)


# ---- Real dataset (used here only for the held-out test split) and a fresh network ----
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
net = get_network(args.model, channel, num_classes, im_size).to(args.device)


net.train()
# Optimizer for the network weights; the learning rate comes from the shared config
# (args.lr_net) instead of a duplicated literal.
optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)
optimizer_net.zero_grad()

criterion = nn.CrossEntropyLoss().to(args.device)

num_epochs = 20

# One synthetic image per optimisation step.
trainloader = torch.utils.data.DataLoader(syn_data, batch_size=1, shuffle=True, num_workers=0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_net, num_epochs)


# ---- Train on the distilled set ----
for i in range(num_epochs):
  loss_avg, acc_avg = epoch('train', trainloader, net, optimizer_net, criterion, args, aug=args.dsa)
  # Printed before scheduler.step(), so this is the rate that was used for the epoch just run.
  for param_group in optimizer_net.param_groups:
    print("Current learning rate is: {}".format(param_group['lr']))

  print("Epoch[{}/{}]: Loss: {:.4f}, ACC: {:.4f}".format(i+1,num_epochs, loss_avg, acc_avg))
  scheduler.step()


# ---- Evaluate on the real test split ----
# Shuffling is unnecessary for evaluation.
testloader = torch.utils.data.DataLoader(dst_test, batch_size=args.batch_train, shuffle=False, num_workers=0)
net.eval()
loss_avg, acc_avg = epoch('test', testloader, net, optimizer_net, criterion, args, aug=args.dsa)

print("Test set: Loss: {:.4f}, ACC: {:.4f}".format(loss_avg, acc_avg))

100
Files already downloaded and verified
Files already downloaded and verified
Current learning rate is: 0.01
Epoch[1/20]: Loss: 4.4712, ACC: 0.1100
Current learning rate is: 0.009938441702975689
Epoch[2/20]: Loss: 3.7992, ACC: 0.1400
Current learning rate is: 0.009755282581475769
Epoch[3/20]: Loss: 2.6815, ACC: 0.1700
Current learning rate is: 0.00945503262094184
Epoch[4/20]: Loss: 1.8079, ACC: 0.3900
Current learning rate is: 0.009045084971874739
Epoch[5/20]: Loss: 0.9629, ACC: 0.6500
Current learning rate is: 0.008535533905932738
Epoch[6/20]: Loss: 0.2185, ACC: 0.9900
Current learning rate is: 0.007938926261462366
Epoch[7/20]: Loss: 0.0614, ACC: 1.0000
Current learning rate is: 0.007269952498697735
Epoch[8/20]: Loss: 0.0388, ACC: 1.0000
Current learning rate is: 0.006545084971874738
Epoch[9/20]: Loss: 0.0308, ACC: 1.0000
Current learning rate is: 0.005782172325201155
Epoch[10/20]: Loss: 0.0261, ACC: 1.0000
Current learning rate is: 0.005
Epoch[11/20]: Loss: 0.0231, ACC: 1.0000
Curr

Train a network on the distilled set obtained by trajectory matching (synthetic images initialized from noise)

In [None]:
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image

class myArgs:
  """Fixed hyper-parameter bundle used in place of parsed command-line arguments."""

  def __init__(self):
    # vars(self) is the instance __dict__; the keyword order below is the
    # attribute order seen when the object's __dict__ is printed.
    vars(self).update(
      method='DC',
      dataset='CIFAR10',
      model='ConvNet',
      ipc=10,
      eval_mode='S',
      num_exp=1,
      num_eval=5,
      epoch_eval_train=300,
      Iteration=100,
      lr_img=0.1,
      lr_net=0.01,
      batch_real=256,
      batch_train=256,
      init='real',
      dsa_strategy='None',
      data_path='data',
      save_path='result',
      dis_metric='ours',
    )

# ---- Run configuration ----
args = myArgs()
args.outer_loop, args.inner_loop = get_loops(args.ipc)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.dsa_param = ParamDiffAug()
args.dsa = args.method == 'DSA'


# ---- Distilled (synthetic) training set ----
distilled_set_name = '/content/logged_files/CIFAR10/efficient-waterfall-22/images_best.pt'
distilled_label_name = '/content/logged_files/CIFAR10/efficient-waterfall-22/labels_best.pt'
# map_location='cpu' lets tensors that were saved from a GPU session load on a
# CPU-only runtime.
# NOTE(review): relies on epoch() moving each batch to args.device -- confirm in utils.
data = torch.load(distilled_set_name, map_location='cpu')
label = torch.load(distilled_label_name, map_location='cpu')

print(len(label))


syn_data = TensorDataset(data, label)


# ---- Real dataset (used here only for the held-out test split) and a fresh network ----
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
net = get_network(args.model, channel, num_classes, im_size).to(args.device)


net.train()
# Optimizer for the network weights; the learning rate comes from the shared config
# (args.lr_net) instead of a duplicated literal.
optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)
optimizer_net.zero_grad()

criterion = nn.CrossEntropyLoss().to(args.device)

num_epochs = 20

# One synthetic image per optimisation step.
trainloader = torch.utils.data.DataLoader(syn_data, batch_size=1, shuffle=True, num_workers=0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_net, num_epochs)


# ---- Train on the distilled set ----
for i in range(num_epochs):
  loss_avg, acc_avg = epoch('train', trainloader, net, optimizer_net, criterion, args, aug=args.dsa)
  # Printed before scheduler.step(), so this is the rate that was used for the epoch just run.
  for param_group in optimizer_net.param_groups:
    print("Current learning rate is: {}".format(param_group['lr']))

  print("Epoch[{}/{}]: Loss: {:.4f}, ACC: {:.4f}".format(i+1,num_epochs, loss_avg, acc_avg))
  scheduler.step()


# ---- Evaluate on the real test split ----
# Shuffling is unnecessary for evaluation.
testloader = torch.utils.data.DataLoader(dst_test, batch_size=args.batch_train, shuffle=False, num_workers=0)
net.eval()
loss_avg, acc_avg = epoch('test', testloader, net, optimizer_net, criterion, args, aug=args.dsa)

print("Test set: Loss: {:.4f}, ACC: {:.4f}".format(loss_avg, acc_avg))

100
Files already downloaded and verified
Files already downloaded and verified
Current learning rate is: 0.01
Epoch[1/20]: Loss: 4.1695, ACC: 0.1300
Current learning rate is: 0.009938441702975689
Epoch[2/20]: Loss: 3.0258, ACC: 0.1400
Current learning rate is: 0.009755282581475769
Epoch[3/20]: Loss: 1.8360, ACC: 0.3700
Current learning rate is: 0.00945503262094184
Epoch[4/20]: Loss: 0.7931, ACC: 0.7600
Current learning rate is: 0.009045084971874739
Epoch[5/20]: Loss: 0.1531, ACC: 0.9900
Current learning rate is: 0.008535533905932738
Epoch[6/20]: Loss: 0.0424, ACC: 1.0000
Current learning rate is: 0.007938926261462366
Epoch[7/20]: Loss: 0.0260, ACC: 1.0000
Current learning rate is: 0.007269952498697735
Epoch[8/20]: Loss: 0.0213, ACC: 1.0000
Current learning rate is: 0.006545084971874738
Epoch[9/20]: Loss: 0.0186, ACC: 1.0000
Current learning rate is: 0.005782172325201155
Epoch[10/20]: Loss: 0.0166, ACC: 1.0000
Current learning rate is: 0.005
Epoch[11/20]: Loss: 0.0152, ACC: 1.0000
Curr

In [None]:
!pip install flopth

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image

class myArgs:
  """Fixed hyper-parameter bundle used in place of parsed command-line arguments."""

  def __init__(self):
    # vars(self) is the instance __dict__; the keyword order below is the
    # attribute order seen when the object's __dict__ is printed.
    vars(self).update(
      method='DC',
      dataset='MNIST',
      model='ConvNet',
      ipc=10,
      eval_mode='S',
      num_exp=1,
      num_eval=5,
      epoch_eval_train=300,
      Iteration=100,
      lr_img=0.1,
      lr_net=0.01,
      batch_real=256,
      batch_train=256,
      init='real',
      dsa_strategy='None',
      data_path='data',
      save_path='result',
      dis_metric='ours',
    )

# Build the run configuration and attach the derived fields.
args = myArgs()
args.outer_loop, args.inner_loop = get_loops(args.ipc)
if torch.cuda.is_available():
  args.device = 'cuda'
else:
  args.device = 'cpu'
args.dsa_param = ParamDiffAug()
args.dsa = (args.method == 'DSA')

# Load the dataset named in the config and instantiate the matching network on the
# chosen device. `net`, `im_size` and `dst_train` are read by the FLOPs cell below.
(channel, im_size, num_classes, class_names, mean, std,
 dst_train, dst_test, testloader) = get_dataset(args.dataset, args.data_path)

net = get_network(args.model, channel, num_classes, im_size)
net = net.to(args.device)



In [None]:
print(im_size)
print(len(dst_train))

from flopth import flopth

# Probe the network with the shape of the loaded dataset rather than a hard-coded
# MNIST (1, 28, 28), so the estimate stays valid if the dataset in the config changes.
in_size = ((channel, *im_size),)

# Human-readable per-image figure.
flops, params = flopth(net, in_size=in_size, bare_number=False)
print("1 image FLOPS is {0}".format(flops))

num = len(dst_train)
num_distilled = 100  # size of the distilled set; keep in sync with the message printed below

# Raw numbers so they can be scaled by the number of images.
flops, params = flopth(net, in_size=in_size, show_detail=True, bare_number=True)
giga = 1000 * 1000 * 1000
flops_original = flops * num / giga
# Cost of one pass over the distilled set = per-image FLOPs x number of distilled images.
# (The previous expression multiplied 100 by the size of the full training set and
# dropped the per-image FLOPs entirely.)
flops_distilled = flops * num_distilled / giga
print("Train entire training set of {0} images, FLOPS is {1} G".format(num, flops_original))
print("Train distilled training set of 100 images, FLOPS is {0} G".format(flops_distilled))

(28, 28)
60000
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at now.
Op GroupNorm is not supported at no

## Application of the distilled dataset: neural architecture search


In [None]:
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image

class myArgs:
  """Fixed hyper-parameter bundle used in place of parsed command-line arguments."""

  def __init__(self):
    # vars(self) is the instance __dict__; the keyword order below is the
    # attribute order seen when the object's __dict__ is printed.
    vars(self).update(
      method='DC',
      dataset='MNIST',
      model='ConvNet',
      ipc=10,
      eval_mode='S',
      num_exp=1,
      num_eval=5,
      epoch_eval_train=300,
      Iteration=100,
      lr_img=0.1,
      lr_net=0.01,
      batch_real=256,
      batch_train=256,
      init='real',
      dsa_strategy='None',
      data_path='data',
      save_path='result',
      dis_metric='ours',
    )

# ---- Run configuration ----
args = myArgs()
args.outer_loop, args.inner_loop = get_loops(args.ipc)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.dsa_param = ParamDiffAug()
args.dsa = args.method == 'DSA'


# ---- Distilled set saved by the condensation run ----
distilled_set_name = '/content/result/res_DC_MNIST_ConvNet_10ipc.pt'
# map_location='cpu' lets a result file written from a GPU session load on a CPU-only runtime.
dataset = torch.load(distilled_set_name, map_location='cpu')
syn_data = dataset['data']
# First saved entry: [images, labels].
syn_data = TensorDataset(syn_data[0][0], syn_data[0][1])

channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)

model_list = ['ConvNetD1','ConvNetD2', 'ConvNetD4']
train_list = ['Distilled', 'Original']
# One averaged accuracy per (model, training set) pair, in loop order:
# for each model, 'Distilled' first and then 'Original'.
accuracy_list = []

num_runs = 3     # independent random initialisations per configuration
num_epochs = 20

# These do not depend on the model or the run, so build them once.
criterion = nn.CrossEntropyLoss().to(args.device)
testloader = torch.utils.data.DataLoader(dst_test, batch_size=args.batch_train, shuffle=False, num_workers=0)

for model in model_list:
  for train in train_list:
    if train == 'Distilled':
      trainloader = torch.utils.data.DataLoader(syn_data, batch_size=1, shuffle=True, num_workers=0)
    else:
      trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)

    run_acc_avg = 0.0
    for i in range(num_runs):
      print(f"Dataset: {train}, model: {model}, run: {i}")
      net = get_network(model, channel, num_classes, im_size).to(args.device)
      net.train()
      # Optimizer for the network weights.
      optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)
      optimizer_net.zero_grad()
      scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_net, num_epochs)

      # `_` keeps the epoch counter from clobbering the run index `i`.
      for _ in range(num_epochs):
        loss_avg, acc_avg = epoch('train', trainloader, net, optimizer_net, criterion, args, aug=args.dsa)
        scheduler.step()

      net.eval()
      loss_avg, acc_avg = epoch('test', testloader, net, optimizer_net, criterion, args, aug=args.dsa)
      print("Test set: Loss: {:.4f}, ACC: {:.4f}".format(loss_avg, acc_avg))
      run_acc_avg += acc_avg

    run_acc_avg /= num_runs
    print(f"average test accuracy over {num_runs} runs is: {run_acc_avg}")
    accuracy_list.append(run_acc_avg)


print(accuracy_list)


Dataset: Distilled, model: ConvNetD1, run: 0
Test set: Loss: 0.5919, ACC: 0.8431
Dataset: Distilled, model: ConvNetD1, run: 1
Test set: Loss: 0.5685, ACC: 0.8296
Dataset: Distilled, model: ConvNetD1, run: 2
Test set: Loss: 0.5389, ACC: 0.8415
average test accuracy over 3 runs is: 0.8380666666666666
Dataset: Original, model: ConvNetD1, run: 0
Test set: Loss: 0.0657, ACC: 0.9808
Dataset: Original, model: ConvNetD1, run: 1
Test set: Loss: 0.0630, ACC: 0.9815
Dataset: Original, model: ConvNetD1, run: 2
Test set: Loss: 0.0640, ACC: 0.9810
average test accuracy over 3 runs is: 0.9811
Dataset: Distilled, model: ConvNetD2, run: 0
Test set: Loss: 0.3254, ACC: 0.9112
Dataset: Distilled, model: ConvNetD2, run: 1
Test set: Loss: 0.2925, ACC: 0.9163
Dataset: Distilled, model: ConvNetD2, run: 2
Test set: Loss: 0.3301, ACC: 0.9107
average test accuracy over 3 runs is: 0.9127333333333333
Dataset: Original, model: ConvNetD2, run: 0
Test set: Loss: 0.0407, ACC: 0.9875
Dataset: Original, model: ConvNetD2

## MHIST


Read in the data from the annotations CSV and the PNG images

In [2]:
from google.colab import drive
drive.mount('/content/drive')
!unzip  "/content/drive/My Drive/Colab Notebooks/mhist_dataset.zip" 

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Archive:  /content/drive/My Drive/Colab Notebooks/mhist_dataset.zip
replace mhist_dataset/annotations.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: yes
  inflating: mhist_dataset/annotations.csv  
replace mhist_dataset/images/MHIST_aaa.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: a
error:  invalid response [a]
replace mhist_dataset/images/MHIST_aaa.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: all
error:  invalid response [all]
replace mhist_dataset/images/MHIST_aaa.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: mhist_dataset/images/MHIST_aaa.png  
  inflating: mhist_dataset/images/MHIST_aab.png  
  inflating: mhist_dataset/images/MHIST_aac.png  
  inflating: mhist_dataset/images/MHIST_aad.png  
  inflating: mhist_dataset/images/MHIST_aae.png  
  inflating: mhist_dataset/images/MHIST_aaf.png  
  inflating: mhist_dataset/images/MHIST_aag.png  
  inflating: 

In [2]:
import torch
import numpy as np
import csv
from torch.utils.data import Dataset
from torchvision import transforms

class MHISTDataset(Dataset):
  """Map-style dataset over in-memory images with per-channel normalization.

  Args:
    images: indexable collection of images accepted by ``transforms.ToTensor``
      (getMHIST passes RGB arrays scaled to [0, 1]).
    labels: indexable collection of labels, aligned with ``images``.
    mean: per-channel mean handed to ``transforms.Normalize``.
    std: per-channel standard deviation handed to ``transforms.Normalize``.
  """

  def __init__(self, images, labels, mean, std):
    self.images = images
    self.labels = labels
    self.mean = mean
    self.std = std
    # Build the pipeline once here rather than on every __getitem__ call.
    # It captures mean/std as given at construction time.
    self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(self.mean, self.std),
    ])

  def __getitem__(self, idx):
    """Return ``(normalized image tensor, label)`` for position ``idx``."""
    return self.transform(self.images[idx]), self.labels[idx]

  def __len__(self):
    """Number of samples, taken from the label collection."""
    return len(self.labels)

def getMHIST():
  """Load the images listed in ``mhist_dataset/annotations.csv`` into train/test datasets.

  Each CSV row is read positionally: column 0 is the image file name (relative to
  ``mhist_dataset/images/``), column 1 is the class name ("SSA" maps to label 1,
  anything else to label 0) and column 3 is the partition ("train" or "test";
  rows with any other value are ignored).

  Returns:
    A tuple ``(channel, im_size, num_classes, class_names, mean, std, dst_train,
    dst_test, testloader)`` in the same layout the notebook unpacks from
    ``get_dataset``.

  Raises:
    FileNotFoundError: if an image named in the CSV cannot be read.
  """
  import cv2

  filename = "mhist_dataset/annotations.csv"
  path = "mhist_dataset/images/"

  # newline='' is what the csv module asks for when it is handed a file object.
  with open(filename, 'r', newline='') as csvfile:
    csvreader = csv.reader(csvfile)
    next(csvreader, None)  # skip the header row (tolerates an empty file)
    rows = [row for row in csvreader if row]  # drop blank lines

  # label HP = 0, SSA = 1
  training_image = []
  training_label = []
  testing_image = []
  testing_label = []

  for row in rows:
    img_path = path + row[0]
    image = cv2.imread(img_path)
    if image is None:
      # cv2.imread signals failure by returning None; fail here with the offending
      # path instead of an opaque error from cvtColor.
      raise FileNotFoundError("could not read image: " + img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # float32 rather than float64: half the memory for the in-memory image lists,
    # and ToTensor keeps the array dtype, so items come out as float32 tensors
    # that match the float32 network weights.
    image = image.astype(np.float32) / 255.

    label = 1 if row[1] == "SSA" else 0

    if row[3] == "train":
      training_image.append(image)
      training_label.append(label)
    elif row[3] == "test":
      testing_image.append(image)
      testing_label.append(label)

  print("Training Label: ")
  print(len(training_label))
  print("Testing Label: ")
  print(len(testing_label))

  channel = 3
  im_size = (224, 224)
  num_classes = 2
  class_names = ['HP', 'SSA']
  # Fixed per-channel normalization constants.
  # NOTE(review): presumably precomputed over the images at [0, 1] scale -- confirm provenance.
  mean = [0.73943309, 0.65267006, 0.77742641]
  std = [0.19637706, 0.24239548, 0.16935587]

  dst_train = MHISTDataset(training_image, training_label, mean, std)

  dst_test = MHISTDataset(testing_image, testing_label, mean, std)
  testloader = torch.utils.data.DataLoader(dst_test, batch_size=256, shuffle=False, num_workers=0)

  return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader

#training_image = tf.convert_to_tensor(np.array(training_image))
#testing_image = tf.convert_to_tensor(np.array(testing_image))
#training_label = tf.convert_to_tensor(np.array(training_label))
#testing_label = tf.convert_to_tensor(np.array(testing_label))

In [16]:
# Smoke-test the loader: prints the train/test split sizes and echoes the returned tuple.
getMHIST()

Training Label: 
77
Testing Label: 
23
[0.02350418 0.02076763 0.02470523]
[0.13253466 0.12098372 0.13772575]


(3,
 (224, 224),
 2,
 ['HP', 'SSA'],
 array([0.02350418, 0.02076763, 0.02470523]),
 array([0.13253466, 0.12098372, 0.13772575]),
 <__main__.TensorDataset at 0x7f45e8fc7400>,
 <__main__.TensorDataset at 0x7f45e8fc7490>,
 <torch.utils.data.dataloader.DataLoader at 0x7f45e8fc7fa0>)

In [3]:
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
import time 


class myArgs:
  """Fixed hyper-parameter bundle used in place of parsed command-line arguments."""

  def __init__(self):
    # vars(self) is the instance __dict__; the keyword order below is the
    # attribute order seen when the object's __dict__ is printed.
    # NOTE(review): this cell trains on the data returned by getMHIST(), so the
    # 'MNIST' dataset label looks like a leftover -- confirm nothing keys off it.
    vars(self).update(
      method='DC',
      dataset='MNIST',
      model='ConvNet',
      ipc=10,
      eval_mode='S',
      num_exp=1,
      num_eval=5,
      epoch_eval_train=300,
      Iteration=100,
      lr_img=0.1,
      lr_net=0.01,
      batch_real=16,
      batch_train=16,
      init='real',
      dsa_strategy='None',
      data_path='data',
      save_path='result',
      dis_metric='ours',
    )



# Run configuration plus the fields derived from it.
args = myArgs()
args.outer_loop, args.inner_loop = get_loops(args.ipc)
if torch.cuda.is_available():
  args.device = 'cuda'
else:
  args.device = 'cpu'
args.dsa_param = ParamDiffAug()
args.dsa = (args.method == 'DSA')

# Real data comes from the local loader defined above; get_dataset(args.dataset,
# args.data_path) returns the same tuple layout for the built-in datasets.
(channel, im_size, num_classes, class_names, mean, std,
 dst_train, dst_test, testloader) = getMHIST()

num_epochs = 20
criterion = nn.CrossEntropyLoss().to(args.device)

# Fresh network and its weight optimizer with a cosine-annealed learning rate.
net = get_network(args.model, channel, num_classes, im_size).to(args.device)
net.train()
optimizer_net = torch.optim.SGD(net.parameters(), lr=0.01)
optimizer_net.zero_grad()

trainloader = torch.utils.data.DataLoader(
    dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_net, num_epochs)

# Time the whole training phase on the full real training split.
start = time.time()

for i in range(num_epochs):
  loss_avg, acc_avg = epoch('train', trainloader, net, optimizer_net, criterion, args, aug=bool(args.dsa))
  # Reported before the scheduler advances, i.e. the rate used for this epoch.
  for param_group in optimizer_net.param_groups:
    print("Current learning rate is: {}".format(param_group['lr']))
  print("Epoch[{}/{}]: Loss: {:.4f}, ACC: {:.4f}".format(i+1,num_epochs, loss_avg, acc_avg))
  scheduler.step()

print(time.time() - start)

# Held-out evaluation on the test split.
testloader = torch.utils.data.DataLoader(
    dst_test, batch_size=args.batch_train, shuffle=True, num_workers=0)
net.eval()
loss_avg, acc_avg = epoch('test', testloader, net, optimizer_net, criterion, args, aug=bool(args.dsa))

print("Test set: Loss: {:.4f}, ACC: {:.4f}".format(loss_avg, acc_avg))


importing Jupyter notebook from utils.ipynb
importing Jupyter notebook from networks.ipynb
Training Label: 
2175
Testing Label: 
977
Current learning rate is: 0.01
Epoch[1/20]: Loss: 13.1348, ACC: 0.6000
Current learning rate is: 0.009938441702975689
Epoch[2/20]: Loss: 1.4539, ACC: 0.7140


KeyboardInterrupt: ignored

In [None]:
import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug


class myArgs:
  """Fixed hyper-parameter bundle used in place of parsed command-line arguments."""

  def __init__(self):
    # vars(self) is the instance __dict__; the keyword order below is the
    # attribute order seen when main() prints args.__dict__.
    # NOTE(review): main() loads its data with getMHIST() but still passes this
    # 'MNIST' label to get_daparam and into the saved image file names -- confirm
    # that is intended.
    vars(self).update(
      method='DC',
      dataset='MNIST',
      model='ConvNet',
      ipc=10,
      eval_mode='S',
      num_exp=1,
      num_eval=5,
      epoch_eval_train=300,
      Iteration=100,
      lr_img=0.1,
      lr_net=0.01,
      batch_real=16,
      batch_train=16,
      init='real',
      dsa_strategy='None',
      data_path='data',
      save_path='result',
      dis_metric='ours',
    )

def main():

    # parser = argparse.ArgumentParser(description='Parameter Processing')
    # parser.add_argument('--method', type=str, default='DC', help='DC/DSA')
    # parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
    # parser.add_argument('--model', type=str, default='ConvNet', help='model')
    # parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
    # parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,
    # parser.add_argument('--num_exp', type=int, default=5, help='the number of experiments')
    # parser.add_argument('--num_eval', type=int, default=20, help='the number of evaluating randomly initialized models')
    # parser.add_argument('--epoch_eval_train', type=int, default=300, help='epochs to train a model with synthetic data')
    # parser.add_argument('--Iteration', type=int, default=1000, help='training iterations')
    # parser.add_argument('--lr_img', type=float, default=0.1, help='learning rate for updating synthetic images')
    # parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
    # parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    # parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    # parser.add_argument('--init', type=str, default='noise', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    # parser.add_argument('--dsa_strategy', type=str, default='None', help='differentiable Siamese augmentation strategy')
    # parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    # parser.add_argument('--save_path', type=str, default='result', help='path to save results')
    # parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')

    # args = parser.parse_args()
    args = myArgs()
    args.outer_loop, args.inner_loop = get_loops(args.ipc)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    args.dsa = True if args.method == 'DSA' else False


    if not os.path.exists(args.data_path):
        os.mkdir(args.data_path)

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    eval_it_pool = np.arange(0, args.Iteration+1, 20).tolist() if args.eval_mode == 'S' or args.eval_mode == 'SS' else [args.Iteration] # The list of iterations when we evaluate models and record results.
    print('eval_it_pool: ', eval_it_pool)
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = getMHIST()
    #channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)


    accs_all_exps = dict() # record performances of all experiments
    for key in model_eval_pool:
        accs_all_exps[key] = []

    data_save = []


    for exp in range(args.num_exp):
        print('\n================== Exp %d ==================\n '%exp)
        print('Hyper-parameters: \n', args.__dict__)
        print('Evaluation model pool: ', model_eval_pool)

        ''' organize the real dataset '''
        images_all = []
        labels_all = []
        indices_class = [[] for c in range(num_classes)]

        images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
        labels_all = [dst_train[i][1] for i in range(len(dst_train))]
        for i, lab in enumerate(labels_all):
            indices_class[lab].append(i)

        images_all = torch.cat(images_all, dim=0).to(args.device)
        labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)

        for c in range(num_classes):
            print('class c = %d: %d real images'%(c, len(indices_class[c])))

        def get_images(c, n): # get random n images from class c
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle]

        for ch in range(channel):
            print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))


        ''' initialize the synthetic data '''
        image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
        label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]

        if args.init == 'real':
            print('initialize synthetic data from random real images')
            for c in range(num_classes):
                image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
        else:
            print('initialize synthetic data from random noise')


        ''' training '''
        optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
        optimizer_img.zero_grad()
        criterion = nn.CrossEntropyLoss().to(args.device)
        print('%s training begins'%get_time())

        for it in range(args.Iteration+1):

            ''' Evaluate synthetic data '''
            if it in eval_it_pool:
                for model_eval in model_eval_pool:
                    print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
                    if args.dsa:
                        args.epoch_eval_train = 1000
                        args.dc_aug_param = None
                        print('DSA augmentation strategy: \n', args.dsa_strategy)
                        print('DSA augmentation parameters: \n', args.dsa_param.__dict__)
                    else:
                        args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc) # This augmentation parameter set is only for DC method. It will be muted when args.dsa is True.
                        print('DC augmentation parameters: \n', args.dc_aug_param)

                    if args.dsa or args.dc_aug_param['strategy'] != 'none':
                        args.epoch_eval_train = 1000  # Training with data augmentation needs more epochs.
                    else:
                        args.epoch_eval_train = 300
                    
                    args.epoch_eval_train = 300

                    accs = []
                    for it_eval in range(args.num_eval):
                        net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
                        image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification
                        _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                        accs.append(acc_test)
                    print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

                    if it == args.Iteration: # record the final results
                        accs_all_exps[model_eval] += accs

                ''' visualize and save '''
                save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
                image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
                for ch in range(channel):
                    image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]
                image_syn_vis[image_syn_vis<0] = 0.0
                image_syn_vis[image_syn_vis>1] = 1.0
                save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.


            ''' Train synthetic data '''
            net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
            net.train()
            net_parameters = list(net.parameters())
            optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)  # optimizer_img for synthetic data
            optimizer_net.zero_grad()
            loss_avg = 0
            args.dc_aug_param = None  # Mute the DC augmentation when learning synthetic data (in inner-loop epoch function) in oder to be consistent with DC paper.


            for ol in range(args.outer_loop):

                ''' freeze the running mu and sigma for BatchNorm layers '''
                # Synthetic data batch, e.g. only 1 image/batch, is too small to obtain stable mu and sigma.
                # So, we calculate and freeze mu and sigma for BatchNorm layer with real data batch ahead.
                # This would make the training with BatchNorm layers easier.

                BN_flag = False
                BNSizePC = 16  # for batch normalization
                for module in net.modules():
                    if 'BatchNorm' in module._get_name(): #BatchNorm
                        BN_flag = True
                if BN_flag:
                    img_real = torch.cat([get_images(c, BNSizePC) for c in range(num_classes)], dim=0)
                    net.train() # for updating the mu, sigma of BatchNorm
                    output_real = net(img_real) # get running mu, sigma
                    for module in net.modules():
                        if 'BatchNorm' in module._get_name():  #BatchNorm
                            module.eval() # fix mu and sigma of every BatchNorm layer


                ''' update synthetic data '''
                loss = torch.tensor(0.0).to(args.device)
                for c in range(num_classes):
                    img_real = get_images(c, args.batch_real)
                    lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c
                    img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
                    lab_syn = torch.ones((args.ipc,), device=args.device, dtype=torch.long) * c

                    if args.dsa:
                        seed = int(time.time() * 1000) % 100000
                        img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                        img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    output_real = net(img_real)
                    loss_real = criterion(output_real, lab_real)
                    gw_real = torch.autograd.grad(loss_real, net_parameters)
                    gw_real = list((_.detach().clone() for _ in gw_real))

                    output_syn = net(img_syn)
                    loss_syn = criterion(output_syn, lab_syn)
                    gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)

                    loss += match_loss(gw_syn, gw_real, args)

                optimizer_img.zero_grad()
                loss.backward()
                optimizer_img.step()
                loss_avg += loss.item()

                if ol == args.outer_loop - 1:
                    break


                ''' update network '''
                image_syn_train, label_syn_train = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())  # avoid any unaware modification
                dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
                trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
                for il in range(args.inner_loop):
                    epoch('train', trainloader, net, optimizer_net, criterion, args, aug = True if args.dsa else False)


            loss_avg /= (num_classes*args.outer_loop)

            if it%10 == 0:
                print('%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))

            if it == args.Iteration: # only record the final results
                data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
                torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))


    print('\n==================== Final Results ====================\n')
    for key in model_eval_pool:
        accs = accs_all_exps[key]
        print('Run %d experiments, train on %s, evaluate %d random %s, mean  = %.2f%%  std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100))


# Run the experiment end to end. As far as the code above shows, main() learns the
# synthetic images by gradient matching, writes visualisations and the final
# result file under args.save_path, and prints the final accuracy summary.
main()

eval_it_pool:  [0, 20, 40, 60, 80, 100]
Training Label: 
2175
Testing Label: 
977

 
Hyper-parameters: 
 {'method': 'DC', 'dataset': 'MNIST', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_exp': 1, 'num_eval': 5, 'epoch_eval_train': 300, 'Iteration': 100, 'lr_img': 0.1, 'lr_net': 0.01, 'batch_real': 16, 'batch_train': 16, 'init': 'real', 'dsa_strategy': 'None', 'data_path': 'data', 'save_path': 'result', 'dis_metric': 'ours', 'outer_loop': 10, 'inner_loop': 50, 'device': 'cuda', 'dsa_param': <utils.ParamDiffAug object at 0x7f596ef6c7f0>, 'dsa': False}
Evaluation model pool:  ['ConvNet']
