In [None]:
# Environment setup for a fresh (Colab) runtime.
# import_ipynb is installed and then imported immediately; presumably this is what lets
# later `import` statements load project notebooks as modules -- TODO confirm.
!pip install import_ipynb
import import_ipynb
# kornia is not imported directly in this notebook; presumably it is a dependency of
# utils / distill.py -- TODO confirm.
!pip install kornia

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting import_ipynb
  Downloading import_ipynb-0.1.4-py3-none-any.whl (4.1 kB)
Collecting jedi>=0.10
  Downloading jedi-0.18.2-py2.py3-none-any.whl (1.6 MB)
[K     |████████████████████████████████| 1.6 MB 13.0 MB/s 
Installing collected packages: jedi, import-ipynb
Successfully installed import-ipynb-0.1.4 jedi-0.18.2
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting kornia
  Downloading kornia-0.6.8-py2.py3-none-any.whl (551 kB)
[K     |████████████████████████████████| 551 kB 14.9 MB/s 
Installing collected packages: kornia
Successfully installed kornia-0.6.8


In [None]:
import os
import argparse
import torch
import torch.nn as nn
from tqdm import tqdm
from utils import get_dataset, get_network, get_daparam,\
    TensorDataset, epoch, ParamDiffAug
import copy

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

class myArgs:
    """Hard-coded stand-in for command-line arguments used by ``main``.

    Every option becomes an instance attribute, created in the order listed
    below (``main`` prints ``args.__dict__``, so the order is visible).
    """

    def __init__(self):
        defaults = {
            # Dataset / model selection (passed to get_dataset / get_network in main).
            'dataset': 'CIFAR10',
            'subset': 'imagenette',
            'model': 'ConvNet',
            # Number of teacher networks main trains.
            'num_experts': 10,
            # Initial SGD learning rate for each teacher.
            'lr_teacher': 0.01,
            # Batch size of the training DataLoader built in main.
            'batch_train': 256,
            # Batch size forwarded to get_dataset.
            'batch_real': 256,
            # Kept as a string; main converts it to a bool by comparing with 'True'.
            'dsa': 'False',
            # Not read directly by main; presumably consumed inside utils -- TODO confirm.
            'dsa_strategy': 'color_crop_cutout_flip_scale_rotate',
            # Dataset root and output root for the saved replay buffers.
            'data_path': 'data',
            'buffer_path': './buffers',
            # Epochs each teacher is trained for.
            'train_epochs': 20,
            # When False, main appends "_NO_ZCA" to the CIFAR save directory.
            'zca': False,
            # When True, main multiplies the learning rate by 0.1 at the scheduled epoch.
            'decay': False,
            # SGD momentum and weight decay.
            'mom': 0,
            'l2': 0,
            # main writes a buffer file each time this many trajectories are collected.
            'save_interval': 5,
        }
        for option, value in defaults.items():
            setattr(self, option, value)


def _snapshot_params(net):
    """Return an independent CPU copy of every parameter tensor of ``net``."""
    # clone() is essential: when the network already lives on the CPU, .cpu() returns
    # a tensor sharing storage with the live parameter, so every recorded snapshot
    # would be overwritten in place by later optimizer steps.
    return [p.detach().clone().cpu() for p in net.parameters()]


def _save_trajectories(trajectories, save_dir):
    """Save ``trajectories`` to the first unused ``replay_buffer_{n}.pt`` in ``save_dir``."""
    n = 0
    while os.path.exists(os.path.join(save_dir, "replay_buffer_{}.pt".format(n))):
        n += 1
    target = os.path.join(save_dir, "replay_buffer_{}.pt".format(n))
    print("Saving {}".format(target))
    torch.save(trajectories, target)


def main():
    """Train ``args.num_experts`` teacher networks and save their parameter trajectories.

    For every teacher, the parameters are recorded once before training and once
    after each epoch. The collected trajectories are written to
    ``<buffer_path>/<dataset>[...]/<model>/replay_buffer_{n}.pt`` each time
    ``args.save_interval`` of them have accumulated; any remainder is written
    after the last teacher so no trajectory is lost.

    Configuration comes from ``myArgs``; nothing is returned.
    """

    args = myArgs()

    args.dsa = args.dsa == 'True'
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()

    (channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test,
     testloader, loader_train_dict, class_map, class_map_inv) = get_dataset(
        args.dataset, args.data_path, args.batch_real, args.subset, args=args)

    # print('\n================== Exp %d ==================\n '%exp)
    print('Hyper-parameters: \n', args.__dict__)

    save_dir = os.path.join(args.buffer_path, args.dataset)
    if args.dataset == "ImageNet":
        # NOTE(review): myArgs defines no `res` attribute; unless get_dataset sets it
        # on args, this branch raises AttributeError -- confirm before using ImageNet.
        save_dir = os.path.join(save_dir, args.subset, str(args.res))
    if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca:
        save_dir += "_NO_ZCA"
    save_dir = os.path.join(save_dir, args.model)
    os.makedirs(save_dir, exist_ok=True)

    # Organize the real dataset: one stacked image tensor, one label tensor and,
    # per class, the list of sample indices.
    images_all = []
    labels_all = []
    indices_class = [[] for c in range(num_classes)]
    print("BUILDING DATASET")
    for i in tqdm(range(len(dst_train))):
        sample = dst_train[i]
        images_all.append(torch.unsqueeze(sample[0], dim=0))
        labels_all.append(class_map[torch.tensor(sample[1]).item()])

    for i, lab in tqdm(enumerate(labels_all)):
        indices_class[lab].append(i)
    images_all = torch.cat(images_all, dim=0).to("cpu")
    labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu")

    for c in range(num_classes):
        print('class c = %d: %d real images'%(c, len(indices_class[c])))

    for ch in range(channel):
        print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

    criterion = nn.CrossEntropyLoss().to(args.device)

    trajectories = []

    dst_train = TensorDataset(copy.deepcopy(images_all.detach()), copy.deepcopy(labels_all.detach()))
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)

    # Set augmentation for whole-dataset training.
    args.dc_aug_param = get_daparam(args.dataset, args.model, args.model, None)
    args.dc_aug_param['strategy'] = 'crop_scale_rotate'  # for whole-dataset training
    print('DC augmentation parameters: \n', args.dc_aug_param)

    for it in range(0, args.num_experts):

        # Train one freshly initialised teacher network on the real training set.
        teacher_net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
        teacher_net.train()
        lr = args.lr_teacher
        teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2)
        teacher_optim.zero_grad()

        # First entry is the initial (untrained) parameter state.
        timestamps = [_snapshot_params(teacher_net)]

        # Epoch index after which the learning rate is cut (only when args.decay is set).
        lr_schedule = [args.train_epochs // 2 + 1]

        for e in range(args.train_epochs):

            train_loss, train_acc = epoch("train", dataloader=trainloader, net=teacher_net, optimizer=teacher_optim,
                                        criterion=criterion, args=args, aug=True)

            test_loss, test_acc = epoch("test", dataloader=testloader, net=teacher_net, optimizer=None,
                                        criterion=criterion, args=args, aug=False)

            print("Itr: {}\tEpoch: {}\tTrain Acc: {}\tTest Acc: {}".format(it, e, train_acc, test_acc))

            timestamps.append(_snapshot_params(teacher_net))

            if e in lr_schedule and args.decay:
                lr *= 0.1
                # A new optimizer is built with the reduced rate (this also resets any momentum state).
                teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2)
                teacher_optim.zero_grad()

        trajectories.append(timestamps)

        if len(trajectories) == args.save_interval:
            _save_trajectories(trajectories, save_dir)
            trajectories = []

    # Flush whatever is left when num_experts is not a multiple of save_interval.
    if trajectories:
        _save_trajectories(trajectories, save_dir)



In [None]:
# Run the teacher-trajectory generation defined in the previous cell; replay buffers
# are written below the directory derived from args.buffer_path.
main()

Files already downloaded and verified
Files already downloaded and verified
Hyper-parameters: 
 {'dataset': 'CIFAR10', 'subset': 'imagenette', 'model': 'ConvNet', 'num_experts': 10, 'lr_teacher': 0.01, 'batch_train': 256, 'batch_real': 256, 'dsa': False, 'dsa_strategy': 'color_crop_cutout_flip_scale_rotate', 'data_path': 'data', 'buffer_path': './buffers', 'train_epochs': 20, 'zca': False, 'decay': False, 'mom': 0, 'l2': 0, 'save_interval': 5, 'device': 'cuda', 'dsa_param': <utils.ParamDiffAug object at 0x7fc56393efd0>}
BUILDING DATASET


100%|██████████| 50000/50000 [00:11<00:00, 4188.96it/s]
50000it [00:00, 2410048.61it/s]


class c = 0: 5000 real images
class c = 1: 5000 real images
class c = 2: 5000 real images
class c = 3: 5000 real images
class c = 4: 5000 real images
class c = 5: 5000 real images
class c = 6: 5000 real images
class c = 7: 5000 real images
class c = 8: 5000 real images
class c = 9: 5000 real images
real images channel 0, mean = -0.0000, std = 1.2211
real images channel 1, mean = -0.0002, std = 1.2211
real images channel 2, mean = 0.0002, std = 1.3014
DC augmentation parameters: 
 {'crop': 4, 'scale': 0.2, 'rotate': 45, 'noise': 0.001, 'strategy': 'crop_scale_rotate'}
Itr: 0	Epoch: 0	Train Acc: 0.34212	Test Acc: 0.439
Itr: 0	Epoch: 1	Train Acc: 0.45014	Test Acc: 0.4859
Itr: 0	Epoch: 2	Train Acc: 0.49666	Test Acc: 0.549
Itr: 0	Epoch: 3	Train Acc: 0.52748	Test Acc: 0.5514
Itr: 0	Epoch: 4	Train Acc: 0.55556	Test Acc: 0.5962
Itr: 0	Epoch: 5	Train Acc: 0.57562	Test Acc: 0.6024
Itr: 0	Epoch: 6	Train Acc: 0.59322	Test Acc: 0.6354
Itr: 0	Epoch: 7	Train Acc: 0.60972	Test Acc: 0.6292
Itr: 0	Epoch

In [None]:
# Install the packages needed before running distill.py on a fresh runtime.
# kornia is also installed by the first cell of this notebook, so that line only
# matters when this cell is run on its own.
!pip install kornia
# wandb: the distill.py output below shows Weights & Biases run tracking.
!pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting kornia
  Downloading kornia-0.6.8-py2.py3-none-any.whl (551 kB)
[K     |████████████████████████████████| 551 kB 5.0 MB/s 
Installing collected packages: kornia
Successfully installed kornia-0.6.8
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.13.5-py2.py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 5.1 MB/s 
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.11-py3-none-any.whl (10 kB)
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle
  Downloading setproctitle-1.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB)
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_

In [None]:
# Launch distill.py as a separate process for CIFAR10 / ConvNet with --pix_init=real.
# The meaning of each flag is defined in distill.py, which is not part of this notebook.
!python distill.py --dataset=CIFAR10 --model=ConvNet --ipc=10 --syn_steps=30 --expert_epochs=2 --max_start_epoch=15 --Iteration=1000 --num_eval=2 --pix_init=real

CUDNN STATUS: True
Files already downloaded and verified
Files already downloaded and verified
[34m[1mwandb[0m: Currently logged in as: [33mfredshi1997[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.13.5
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/content/wandb/run-20221202_020558-27h4lg8i[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33matomic-field-23[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/fredshi1997/dip[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/fredshi1997/dip/runs/27h4lg8i[0m
Hyper-parameters: 
 {'dataset': 'CIFAR10', 'subset': 'imagenette', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_eval': 2, 'eval_it': 100, 'epoch_eval_train': 1000, 'Iteration': 1000, 'lr_img': 1000, 'lr_lr': 1e-05, 'lr_teacher': 0.01, 'lr_init': 0.01, 'batch_real': 256, 'batch_syn': 100, 'batch_train': 256

In [None]:
import torch 
import copy
from torchvision.utils import save_image
from utils import get_dataset

class myArgs:
    """Minimal argument holder for this cell: it only carries the ``zca`` flag.

    NOTE: this rebinds the name ``myArgs`` that the buffer-generation cell also
    defines. ``main()`` looks the class up at call time and reads attributes such
    as ``dataset`` that are not set here, so ``main()`` must not be re-run after
    this cell without first re-running the cell that defines the full class.
    """

    def __init__(self):
        vars(self).update(zca=False)

args = myArgs()
# Only channel / mean / std are of interest here, and only for the de-normalisation
# step that is currently disabled below.
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset('CIFAR10', 'data', 256, 'imagenette', args=args)

distilled_set_name = '/content/logged_files/CIFAR10/clear-river-13/images_900.pt'
# map_location='cpu' lets a tensor saved from a GPU run load on a CPU-only runtime.
# NOTE: torch.load unpickles the file -- only load checkpoints produced by your own runs.
data = torch.load(distilled_set_name, map_location='cpu')

# Work on a private CPU copy so the loaded tensor itself is never modified.
image_syn_vis = data.detach().cpu().clone()
# De-normalisation is intentionally disabled; if it is re-enabled it must stay
# before the clamp below.
# for ch in range(channel):
#     image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch]
# Restrict pixel values to [0, 1] before writing the image grid (a single clamp is
# equivalent to the previous, duplicated masked assignments).
image_syn_vis = image_syn_vis.clamp(min=0.0, max=1.0)
save_image(image_syn_vis, '/content/distilled_from_real_final.png', nrow=10)

# distilled_set_name = '/content/logged_files/CIFAR10/clear-river-13/images_0.pt'
# data = torch.load(distilled_set_name)


# image_syn_vis = copy.deepcopy(data.detach().cpu())
# for ch in range(channel):
#     image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch]
# image_syn_vis[image_syn_vis < 0] = 0.0
# image_syn_vis[image_syn_vis > 1] = 1.0
# save_image(image_syn_vis, '/content/distilled_from_real_initial.png', nrow=10)



Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Same distill.py invocation as the earlier run, but with --pix_init=noise and --dsa=False.
# The meaning of each flag is defined in distill.py, which is not part of this notebook.
!python distill.py --dataset=CIFAR10 --model=ConvNet --ipc=10 --syn_steps=30 --expert_epochs=2 --max_start_epoch=15 --Iteration=1000 --num_eval=2 --pix_init=noise --dsa=False

CUDNN STATUS: True
Files already downloaded and verified
Files already downloaded and verified
[34m[1mwandb[0m: Currently logged in as: [33mfredshi1997[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.13.5
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/content/wandb/run-20221202_013003-jd8mz99u[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mefficient-waterfall-22[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/fredshi1997/dip[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/fredshi1997/dip/runs/jd8mz99u[0m
Hyper-parameters: 
 {'dataset': 'CIFAR10', 'subset': 'imagenette', 'model': 'ConvNet', 'ipc': 10, 'eval_mode': 'S', 'num_eval': 2, 'eval_it': 100, 'epoch_eval_train': 1000, 'Iteration': 1000, 'lr_img': 1000, 'lr_lr': 1e-05, 'lr_teacher': 0.01, 'lr_init': 0.01, 'batch_real': 256, 'batch_syn': 100, 'batch_trai

In [None]:
# Recursively archive /content/logged_files into a single zip file.
!zip -r /content/logged_files.zip /content/logged_files
# google.colab is only importable on a Colab runtime.
from google.colab import files
# Hand the archive to the browser as a download.
files.download("/content/logged_files.zip")

  adding: content/logged_files/ (stored 0%)
  adding: content/logged_files/CIFAR10/ (stored 0%)
  adding: content/logged_files/CIFAR10/atomic-field-23/ (stored 0%)
  adding: content/logged_files/CIFAR10/atomic-field-23/vis_CIFAR10_ConvNet_10ipc_iter100.png (deflated 0%)
  adding: content/logged_files/CIFAR10/atomic-field-23/vis_CIFAR10_ConvNet_10ipc_iter300.png (deflated 0%)
  adding: content/logged_files/CIFAR10/atomic-field-23/labels_200.pt (deflated 76%)
  adding: content/logged_files/CIFAR10/atomic-field-23/labels_600.pt (deflated 76%)
  adding: content/logged_files/CIFAR10/atomic-field-23/vis_CIFAR10_ConvNet_10ipc_iter1000.png (deflated 0%)
  adding: content/logged_files/CIFAR10/atomic-field-23/images_300.pt (deflated 8%)
  adding: content/logged_files/CIFAR10/atomic-field-23/labels_500.pt (deflated 76%)
  adding: content/logged_files/CIFAR10/atomic-field-23/images_200.pt (deflated 8%)
  adding: content/logged_files/CIFAR10/atomic-field-23/images_1000.pt (deflated 8%)
  adding: co

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>