## Introduction

In the previous presentation, I introduced a paper on dataset distillation (also called dataset condensation), a technique that synthesizes a small dataset that can stand in for a much larger original one. In this assignment, I walk through a simple implementation that generates such a synthetic dataset with the distribution matching method. The main idea of distribution matching is to minimize the empirical maximum mean discrepancy (MMD) between the original dataset $T$ and the synthetic dataset $S$:

$$E_{\theta \sim P_\theta}|| \frac{1}{|T|} \sum_{i=1}^{|T|} \phi_\theta (x_i) - \frac{1}{|S|} \sum_{i=1}^{|S|} \phi_\theta (s_i)||^2$$

where we first sample the network parameters $\theta$ from the parameter distribution $P_\theta$ and use them to construct an embedding function $\phi_\theta$ that maps each input into a lower-dimensional embedding space. The MMD then measures the distance between the embedded distributions of $T$ and $S$, and we compute its gradient with respect to the synthetic images to minimize this distance. The paper also applies differentiable data augmentation when calculating the MMD; therefore, the final objective is:


$$E_{\theta \sim P_\theta}|| \frac{1}{|T|} \sum_{i=1}^{|T|} \phi_\theta (A(x_i)) - \frac{1}{|S|} \sum_{i=1}^{|S|} \phi_\theta (A(s_i))||^2$$

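where $A(\cdot)$ is the differentiable data augmentation function; the same randomly sampled augmentation is applied to the real and the synthetic batches.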

In this assignment, I adopt ConvNet as the embedding function to generate a synthetic dataset for CIFAR-10, and then evaluate this synthetic dataset on an image classification task, again using ConvNet.

## Reference 
B. Zhao and H. Bilen, "Dataset Condensation with Distribution Matching," 2023 IEEE/CVF Winter Conference on Applications of Computer Vision (WACV), Waikoloa, HI, USA, 2023, pp. 6503-6512, doi: 10.1109/WACV56688.2023.00645.

https://github.com/VICO-UoE/DatasetCondensation/tree/master

#### 1. Import libraries, select the device, and set the hyperparameters

In [2]:
import os
import copy
import time
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset
import torch.nn.functional as F
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Hyper-parameters for synthesizing the condensed dataset ---
Iterations = 20000      # number of outer optimization steps on the synthetic images
ipc = 10                # synthetic images per class
lr_img = 1.0            # SGD learning rate applied to the synthetic pixels
batch_real = 256        # real images sampled per class at every step

# --- Output bookkeeping ---
dataset_name = 'cifar10'    # used in the names of the saved files
save_path = './results'     # directory for visualizations and the final .pt file

#### 2. Load dataset

In [6]:
# CIFAR-10 geometry and the per-channel normalization statistics.
channel = 3
im_size = (32, 32)
num_classes = 10
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
dst_train = datasets.CIFAR10('./data', train=True, download=True, transform=transform)  # no augmentation
dst_test = datasets.CIFAR10('./data', train=False, download=True, transform=transform)
class_names = dst_train.classes
testloader = torch.utils.data.DataLoader(dst_test, batch_size=256, shuffle=False, num_workers=0)
lr_net = 0.01  # NOTE(review): not referenced elsewhere in this notebook -- confirm before removing

# Materialize the whole normalized training set on `device` once, so that the
# per-class sampling in `get_images` is a cheap index operation.
# Labels are read from `dst_train.targets`, so each image is decoded and
# transformed exactly once (indexing `dst_train[i]` a second time just to get
# the label would run the full transform again).
labels_all = [int(lab) for lab in dst_train.targets]
images_all = torch.stack([dst_train[i][0] for i in range(len(dst_train))], dim=0).to(device)

# indices_class[c] lists the positions of every training sample of class c.
indices_class = [[] for _ in range(num_classes)]
for i, lab in enumerate(labels_all):
    indices_class[lab].append(i)

labels_all = torch.tensor(labels_all, dtype=torch.long, device=device)
def get_images(c, n):
    """Return up to ``n`` randomly chosen training images of class ``c``.

    Sampling is without replacement: the class's index list is shuffled with
    NumPy's global RNG and its first ``n`` entries index into ``images_all``.
    """
    shuffled = np.random.permutation(indices_class[c])
    chosen = shuffled[:n]
    return images_all[chosen]

Files already downloaded and verified
Files already downloaded and verified


#### 3. Initialize the synthetic dataset

In [7]:
# Total number of synthetic images: number of classes * images per class.
# They live in one leaf tensor that the image optimizer updates directly; the
# random values are immediately overwritten with real samples (initialization
# from real data).
image_syn = torch.randn(size=(num_classes * ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=device)
with torch.no_grad():
    for c in range(num_classes):
        # rows [c*ipc, (c+1)*ipc) hold class c
        image_syn[c*ipc:(c+1)*ipc] = get_images(c, ipc)
# Labels follow the same class-contiguous layout: [0]*ipc + [1]*ipc + ... .
# Built directly with torch: converting a list of numpy arrays through
# torch.tensor is slow and triggers a UserWarning.
label_syn = torch.arange(num_classes, dtype=torch.long, device=device).repeat_interleave(ipc)

  label_syn = torch.tensor([np.ones(ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]


#### 4. Define the model
Here, both the embedding function and the evaluation model use ConvNet.

In [8]:
# Swish activation
class Swish(nn.Module):
    """Swish / SiLU activation: ``Swish(x) = x * sigmoid(x)``.

    Parameter-free; it only wraps the element-wise expression so that it can
    be placed inside an ``nn.Sequential``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        gate = torch.sigmoid(input)
        return torch.mul(input, gate)

# ConvNet
class ConvNet(nn.Module):
    """Stack of conv blocks followed by a single linear classifier.

    Each of the ``net_depth`` blocks is
    ``Conv2d(3x3) -> [norm] -> activation -> [pooling]``. ``embed`` returns the
    flattened output of the conv stack (the feature map used as phi_theta for
    distribution matching); ``forward`` additionally applies the classifier.

    Args:
        channel: number of input channels.
        num_classes: number of output logits.
        net_width: number of filters in every conv layer.
        net_depth: number of conv blocks.
        net_act: 'sigmoid' | 'relu' | 'leakyrelu' | 'swish'.
        net_norm: 'batchnorm' | 'layernorm' | 'instancenorm' | 'groupnorm' | 'none'.
        net_pooling: 'maxpooling' | 'avgpooling' | 'none'.
        im_size: input spatial size (H, W). A height of 28 is treated as 32
            because the first conv pads single-channel inputs by 3
            (28 + 2*3 - 3 + 1 = 32).

    Raises:
        ValueError: if ``net_act``, ``net_norm`` or ``net_pooling`` is unknown.
    """

    def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size = (32,32)):
        super().__init__()

        self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size)
        num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2]
        self.classifier = nn.Linear(num_feat, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes)."""
        return self.classifier(self.embed(x))

    def embed(self, x):
        """Return the flattened conv features, shape (N, num_feat)."""
        out = self.features(x)
        return out.view(out.size(0), -1)

    def _get_activation(self, net_act):
        if net_act == 'sigmoid':
            return nn.Sigmoid()
        elif net_act == 'relu':
            return nn.ReLU(inplace=True)
        elif net_act == 'leakyrelu':
            return nn.LeakyReLU(negative_slope=0.01)
        elif net_act == 'swish':
            return Swish()
        # exit() is an interactive-shell helper from `site`; raise instead so
        # the caller (or notebook) sees a normal, catchable error.
        raise ValueError('unknown activation function: %s' % net_act)

    def _get_pooling(self, net_pooling):
        if net_pooling == 'maxpooling':
            return nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'avgpooling':
            return nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'none':
            return None
        raise ValueError('unknown net_pooling: %s' % net_pooling)

    def _get_normlayer(self, net_norm, shape_feat):
        # shape_feat = [c, h, w] of the feature map being normalized
        if net_norm == 'batchnorm':
            return nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == 'layernorm':
            return nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == 'instancenorm':
            # one group per channel == instance normalization with affine params
            return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == 'groupnorm':
            return nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == 'none':
            return None
        raise ValueError('unknown net_norm: %s' % net_norm)

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
        """Build the conv stack and return it with the final [c, h, w] feature shape."""
        layers = []
        in_channels = channel
        if im_size[0] == 28:
            im_size = (32, 32)
        shape_feat = [in_channels, im_size[0], im_size[1]]
        for d in range(net_depth):
            layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
            shape_feat[0] = net_width
            if net_norm != 'none':
                layers += [self._get_normlayer(net_norm, shape_feat)]
            layers += [self._get_activation(net_act)]
            in_channels = net_width
            if net_pooling != 'none':
                layers += [self._get_pooling(net_pooling)]
                shape_feat[1] //= 2
                shape_feat[2] //= 2

        return nn.Sequential(*layers), shape_feat


#### 5. Define the data augmentation

In [9]:
# define the parameters for data augmentation
class ParamDiffAug():
    """Container for the hyper-parameters of the differentiable augmentations.

    ``aug_mode`` controls how ``DiffAugment`` uses a strategy string: 'S'
    applies one randomly chosen family per call, 'M' applies every family.
    ``Siamese`` and ``latestseed`` are not set here; ``DiffAugment`` assigns
    them on every call.
    """

    def __init__(self):
        # how DiffAugment combines strategies: 'S' (single, sampled) or 'M' (multiple)
        self.aug_mode = 'S'
        # geometric augmentations
        self.prob_flip = 0.5         # probability of a horizontal flip
        self.ratio_scale = 1.2       # scale factors drawn from [1/1.2, 1.2)
        self.ratio_rotate = 15.0     # rotation angle drawn from [-15, 15) degrees
        self.ratio_crop_pad = 0.125  # max translation as a fraction of image size
        self.ratio_cutout = 0.5      # cutout patch is 0.5 x 0.5 of the image
        # photometric augmentations
        self.brightness = 1.0        # additive offset drawn from [-0.5, 0.5)
        self.saturation = 2.0        # saturation factor drawn from [0, 2)
        self.contrast = 0.5          # contrast factor drawn from [0.5, 1.5)

# setup the random seed for augmentation
def set_seed_DiffAug(param):
    """Reseed torch's global RNG from ``param.latestseed`` and advance it.

    A ``latestseed`` of -1 means "unseeded", and the RNG is left untouched.
    Otherwise torch is seeded with the current value and the counter is
    incremented, so successive draws in one augmentation call get distinct
    but reproducible seeds.
    """
    seed = param.latestseed
    if seed != -1:
        torch.random.manual_seed(seed)
        param.latestseed = seed + 1

# We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans.
def rand_scale(x, param):
    """Randomly rescale each image in ``x`` with an affine warp.

    Per-image scale factors ``sx`` and ``sy`` are drawn uniformly from
    ``[1/ratio, ratio)`` with ``ratio = param.ratio_scale``. A factor of 1
    keeps the original size; 0.5 enlarges the content 2x (the grid samples
    half the field of view). With ``param.Siamese`` every image in the batch
    uses the first image's factors.

    Args:
        x: image batch of shape (N, C, H, W), as required by ``F.affine_grid``.
        param: ParamDiffAug-like object providing ``ratio_scale``,
            ``Siamese`` and ``latestseed``.
    """
    ratio = param.ratio_scale
    n = x.shape[0]
    set_seed_DiffAug(param)
    sx = torch.rand(n) * (ratio - 1.0/ratio) + 1.0/ratio
    set_seed_DiffAug(param)
    sy = torch.rand(n) * (ratio - 1.0/ratio) + 1.0/ratio
    # Batched 2x3 affine matrices [[sx, 0, 0], [0, sy, 0]], built without a
    # Python loop over the batch.
    theta = torch.zeros(n, 2, 3, dtype=torch.float)
    theta[:, 0, 0] = sx
    theta[:, 1, 1] = sy
    if param.Siamese:  # Siamese augmentation: share image 0's transform
        theta[:] = theta[0]
    # align_corners=False is PyTorch's default since 1.3; passing it
    # explicitly keeps that behaviour and silences the per-call UserWarning.
    grid = F.affine_grid(theta, x.shape, align_corners=False).to(x.device)
    return F.grid_sample(x, grid, align_corners=False)


def rand_rotate(x, param):  # [-180, 180], 90: anticlockwise 90 degree
    """Rotate each image by a random angle in ``[-ratio, ratio)`` degrees.

    ``ratio = param.ratio_rotate``. The rotation is applied with an affine
    grid and bilinear sampling, so it is differentiable with respect to
    ``x``. With ``param.Siamese`` every image uses the first image's angle.

    Args:
        x: image batch of shape (N, C, H, W), as required by ``F.affine_grid``.
        param: ParamDiffAug-like object providing ``ratio_rotate``,
            ``Siamese`` and ``latestseed``.
    """
    ratio = param.ratio_rotate
    n = x.shape[0]
    set_seed_DiffAug(param)
    angle = (torch.rand(n) - 0.5) * 2 * ratio / 180 * float(np.pi)  # radians
    cos_a = torch.cos(angle)
    # Batched rotation matrices [[cos, -sin, 0], [sin, cos, 0]].
    theta = torch.zeros(n, 2, 3, dtype=torch.float)
    theta[:, 0, 0] = cos_a
    theta[:, 0, 1] = torch.sin(-angle)
    theta[:, 1, 0] = torch.sin(angle)
    theta[:, 1, 1] = cos_a
    if param.Siamese:  # Siamese augmentation: share image 0's transform
        theta[:] = theta[0]
    # Explicit align_corners=False (the default) silences the UserWarning.
    grid = F.affine_grid(theta, x.shape, align_corners=False).to(x.device)
    return F.grid_sample(x, grid, align_corners=False)


def rand_flip(x, param):
    """Flip each image along its last (width) axis with probability ``param.prob_flip``.

    Under ``param.Siamese`` the decision drawn for the first image is reused
    for the whole batch.
    """
    set_seed_DiffAug(param)
    coin = torch.rand(x.size(0), 1, 1, 1, device=x.device)
    if param.Siamese:
        coin[:] = coin[0]
    flipped = x.flip(3)
    return torch.where(coin < param.prob_flip, flipped, x)


def rand_brightness(x, param):
    """Add a random per-image offset in ``[-b/2, b/2)`` with ``b = param.brightness``.

    Under ``param.Siamese`` all images receive the first image's offset.
    """
    set_seed_DiffAug(param)
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    if param.Siamese:
        offset[:] = offset[0]
    return x + (offset - 0.5) * param.brightness


def rand_saturation(x, param):
    """Scale each pixel's deviation from its across-channel mean by a random factor.

    The factor is drawn per image from ``[0, param.saturation)``; under
    ``param.Siamese`` the first image's factor is used for the whole batch.
    """
    gray = x.mean(dim=1, keepdim=True)
    set_seed_DiffAug(param)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    if param.Siamese:
        factor[:] = factor[0]
    return (x - gray) * (factor * param.saturation) + gray


def rand_contrast(x, param):
    """Scale each image's deviation from its global mean by a random factor.

    The factor is drawn per image from ``[c, c + 1)`` with
    ``c = param.contrast``; under ``param.Siamese`` the first image's factor
    is shared by the batch.
    """
    image_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    set_seed_DiffAug(param)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    if param.Siamese:
        factor[:] = factor[0]
    return (x - image_mean) * (factor + param.contrast) + image_mean


def rand_crop(x, param):
    """Randomly translate each image by up to ``param.ratio_crop_pad`` of its size.

    Integer shifts are drawn per image from ``[-shift, shift]`` with
    ``shift = int(size * ratio + 0.5)``, independently for each spatial axis.
    The batch is padded with a 1-pixel zero border and source indices are
    clamped into that padded range, so pixels shifted in from outside the
    image read zeros. With ``param.Siamese`` all images use the first image's
    shift.

    Args:
        x: image batch of shape (N, C, H, W).
        param: ParamDiffAug-like object providing ``ratio_crop_pad``,
            ``Siamese`` and ``latestseed``.
    """
    ratio = param.ratio_crop_pad
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    set_seed_DiffAug(param)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    set_seed_DiffAug(param)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    if param.Siamese:  # Siamese augmentation: share image 0's shift
        translation_x[:] = translation_x[0]
        translation_y[:] = translation_y[0]
    # indexing='ij' is the current default; passing it explicitly keeps that
    # behaviour and silences the torch.meshgrid UserWarning.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
        indexing='ij',
    )
    # +1 accounts for the 1-pixel pad; clamping maps out-of-range sources onto the zero border.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    # gather in NHWC layout, then permute back to NCHW
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, param):
    ratio = param.ratio_cutout
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    set_seed_DiffAug(param)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    set_seed_DiffAug(param)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    if param.Siamese:  # Siamese augmentation:
        offset_x[:] = offset_x[0]
        offset_y[:] = offset_y[0]
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x

# Maps each token of a '_'-separated strategy string to the augmentation
# functions it expands to; 'color' chains three photometric transforms.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    crop=[rand_crop],
    cutout=[rand_cutout],
    flip=[rand_flip],
    scale=[rand_scale],
    rotate=[rand_rotate],
)
def DiffAugment(x, strategy='', seed = -1, param = None):
    """Apply the differentiable augmentation(s) named in ``strategy`` to ``x``.

    Args:
        x: image batch; assumed (N, C, H, W) -- the individual transforms
            index dims 2/3 as height/width.
        strategy: '_'-separated tokens that are keys of ``AUGMENT_FNS``
            (e.g. 'color_crop_flip'). '', None, 'none' and 'None' are no-ops.
        seed: -1 draws independent random parameters for every image.
            Any other value turns on Siamese mode: the RNG is seeded from
            ``seed`` and all images in the batch share the first image's
            parameters, so two calls with the same seed (real and synthetic
            batches) are meant to receive the same augmentation.
        param: ParamDiffAug-like object; its ``Siamese`` and ``latestseed``
            attributes are overwritten by this call.

    Returns:
        The augmented tensor (made contiguous), or ``x`` itself when the
        strategy is a no-op.

    Raises:
        ValueError: if ``param.aug_mode`` is neither 'M' nor 'S'.
        KeyError: if a strategy token is not in ``AUGMENT_FNS``.
    """
    if not strategy or strategy in ('None', 'none'):
        return x

    param.Siamese = seed != -1
    param.latestseed = seed

    if param.aug_mode == 'M':  # apply every family in the strategy, in order
        for p in strategy.split('_'):
            for f in AUGMENT_FNS[p]:
                x = f(x, param)
    elif param.aug_mode == 'S':  # sample a single family per call
        pbties = strategy.split('_')
        set_seed_DiffAug(param)
        p = pbties[torch.randint(0, len(pbties), size=(1,)).item()]
        for f in AUGMENT_FNS[p]:
            x = f(x, param)
    else:
        # exit() is an interactive-shell helper; raise a catchable error instead.
        raise ValueError('unknown augmentation mode: %s' % param.aug_mode)
    return x.contiguous()

# data augmentation params
# Augmentation setup shared by synthesis and evaluation: all families are
# listed, and with the default aug_mode 'S' one of them is sampled per call.
dsa_strategy = 'color_crop_cutout_flip_scale_rotate'
dsa_param = ParamDiffAug()

#### 6. Get the embedding function from the neural network

In [10]:
def get_network(channel, num_classes, im_size=(32, 32)):
    """Build a freshly initialized ConvNet (depth 3, width 128, ReLU, instance norm, avg pooling).

    The global torch RNG is reseeded from the wall clock first, so each call
    gets different initial weights (a new theta ~ P_theta per distillation
    iteration). NOTE: this reseeding also affects every later torch random
    draw in the process.

    When ``device`` is CUDA and more than one GPU is visible, the model is
    wrapped in ``nn.DataParallel``; the underlying model is then ``net.module``.
    """
    net_width, net_depth, net_act, net_norm, net_pooling = 128, 3, 'relu', 'instancenorm', 'avgpooling'
    torch.random.manual_seed(int(time.time() * 1000) % 100000)
    net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    # `device` is a torch.device, and torch.device == 'cuda' is always False.
    # Compare the device *type* instead; torch.device() also accepts a str.
    if torch.device(device).type == 'cuda' and torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    return net.to(device)


#### 7. The training iteration for generating the synthetic data

In [11]:
# Distribution-matching loop. Every iteration samples a freshly initialized
# ConvNet (a new theta ~ P_theta), embeds per-class real and synthetic batches
# under the same augmentation seed, and takes one SGD step on the synthetic
# pixels to shrink the squared distance between their mean embeddings.
os.makedirs(save_path, exist_ok=True)  # save_image / torch.save below need the directory to exist
optimizer_img = torch.optim.SGD([image_syn, ], lr=lr_img, momentum=0.5)  # optimizer_img for synthetic data
optimizer_img.zero_grad()
for it in range(Iterations + 1):
    ''' Train synthetic data '''
    net = get_network(channel, num_classes).to(device)  # get a random model
    net.train()
    # The network is only a fixed random feature extractor: freeze it so that
    # backward() produces gradients for the synthetic images alone.
    for param in list(net.parameters()):
        param.requires_grad = False
    # get_network may return an nn.DataParallel wrapper whose `embed` lives on
    # `.module`; checking the wrapper type stays correct whether or not the
    # wrap actually happened.
    embed = net.module.embed if isinstance(net, nn.DataParallel) else net.embed

    ''' update synthetic data '''
    loss = torch.tensor(0.0).to(device)
    for c in range(num_classes):
        img_real = get_images(c, batch_real)
        img_syn = image_syn[c*ipc:(c+1)*ipc].reshape((ipc, channel, im_size[0], im_size[1]))

        # Same seed for both calls -> Siamese augmentation of the real and synthetic batches.
        seed = int(time.time() * 1000) % 100000
        img_real = DiffAugment(img_real, dsa_strategy, seed=seed, param=dsa_param)
        img_syn = DiffAugment(img_syn, dsa_strategy, seed=seed, param=dsa_param)

        output_real = embed(img_real).detach()
        output_syn = embed(img_syn)

        # squared L2 distance between the class-wise mean embeddings
        loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)

    optimizer_img.zero_grad()
    loss.backward()
    optimizer_img.step()
    loss_avg = loss.item() / num_classes  # mean per-class loss, for logging only

    if it % 100 == 0:
        print('iter = %05d, loss = %.4f' % (it, loss_avg))

        ''' visualize and save '''
        save_name = os.path.join(save_path, 'vis_%s_%dipc_iter%d.png'%(dataset_name, ipc, it))
        image_syn_vis = image_syn.detach().cpu().clone()
        # undo the per-channel Normalize transform, then clip to the displayable [0, 1] range
        for ch in range(channel):
            image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch]
        image_syn_vis.clamp_(0.0, 1.0)
        save_image(image_syn_vis, save_name, nrow=ipc)  # Trying normalize = True/False may get better visual effects.

    if it == Iterations:  # only record the final results
        data_save = [[image_syn.detach().cpu().clone(), label_syn.detach().cpu().clone()]]
        torch.save({'data': data_save}, os.path.join(save_path, 'res_%s_%dipc.pt'%(dataset_name, ipc)))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


iter = 00000, loss = 24.1172
iter = 00100, loss = 7.5307
iter = 00200, loss = 6.6096
iter = 00300, loss = 5.9546
iter = 00400, loss = 5.9315
iter = 00500, loss = 5.7106
iter = 00600, loss = 5.5293
iter = 00700, loss = 5.0596
iter = 00800, loss = 5.1546
iter = 00900, loss = 5.0365
iter = 01000, loss = 4.7591
iter = 01100, loss = 5.0294
iter = 01200, loss = 4.8128
iter = 01300, loss = 4.7329
iter = 01400, loss = 4.6285
iter = 01500, loss = 5.2236
iter = 01600, loss = 4.6348
iter = 01700, loss = 4.6310
iter = 01800, loss = 4.4247
iter = 01900, loss = 4.5451
iter = 02000, loss = 4.6869
iter = 02100, loss = 4.2667
iter = 02200, loss = 4.2211
iter = 02300, loss = 4.6221
iter = 02400, loss = 4.3642
iter = 02500, loss = 4.3483
iter = 02600, loss = 4.2857
iter = 02700, loss = 4.2776
iter = 02800, loss = 4.1695
iter = 02900, loss = 4.3999
iter = 03000, loss = 4.6242
iter = 03100, loss = 4.2816
iter = 03200, loss = 4.1790
iter = 03300, loss = 4.0455
iter = 03400, loss = 4.1376
iter = 03500, loss 

#### 8. Visualization for the result
The distilled CIFAR-10 dataset (10 images per class):

<img src="./vis_cifar10_10ipc_iter20000.png"/> 


#### 9. Evaluate the synthetic data on a classification task

In [12]:
def epoch(mode, dataloader, net, optimizer, criterion, aug):
    """Run one pass over ``dataloader`` and return ``(mean_loss, accuracy)``.

    Args:
        mode: 'train' updates ``net`` with ``optimizer``; any other value
            evaluates (eval mode, no autograd graph).
        dataloader: iterable of ``(images, labels)`` batches.
        net: classifier; moved to ``device``.
        optimizer: used only when ``mode == 'train'``.
        criterion: loss module; moved to ``device``.
        aug: if True, apply ``DiffAugment`` (non-Siamese) to every image batch.

    Returns:
        Sample-weighted mean loss and the fraction of correct predictions.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    loss_avg, acc_avg, num_exp = 0, 0, 0
    net = net.to(device)
    criterion = criterion.to(device)
    is_train = mode == 'train'
    net.train(is_train)  # train(False) is equivalent to eval()

    # Autograd bookkeeping is only needed when backward() will be called.
    with torch.set_grad_enabled(is_train):
        for datum in dataloader:
            img = datum[0].float().to(device)
            if aug:
                img = DiffAugment(img, dsa_strategy, param=dsa_param)
            lab = datum[1].long().to(device)
            n_b = lab.shape[0]

            output = net(img)
            loss = criterion(output, lab)

            loss_avg += loss.item() * n_b
            acc_avg += (output.argmax(dim=-1) == lab).sum().item()
            num_exp += n_b

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    if num_exp == 0:
        raise ValueError('epoch(): dataloader yielded no samples')
    return loss_avg / num_exp, acc_avg / num_exp


# parameters for network training
lr = 0.1
Epoch = 1000
batch_train = 256
net = get_network(channel, num_classes).to(device)  # get a random model
# Train on a detached copy. image_syn is a leaf with requires_grad=True, so
# feeding it directly would make every loss.backward() below also compute and
# accumulate gradients into the synthetic images (wasted work, and it mutates
# image_syn.grad).
images_train = image_syn.detach().clone().to(device)
labels_train = label_syn.detach().clone().to(device)
lr_schedule = [Epoch//2+1]  # decay the learning rate 10x once, after this epoch
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
criterion = nn.CrossEntropyLoss().to(device)

dst_train = TensorDataset(images_train, labels_train)
trainloader = torch.utils.data.DataLoader(dst_train, batch_size=batch_train, shuffle=True, num_workers=0)

for ep in range(Epoch+1):
    loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, aug=True)
    if ep in lr_schedule:
        lr *= 0.1
        # note: building a new optimizer also resets its momentum buffers
        optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)

loss_test, acc_test = epoch('test', testloader, net, optimizer, criterion, aug=False)
print('Evaluate: test loss = %.4f, test acc = %.4f' % (loss_test, acc_test))


Evaluate: test loss = 2.8109, test acc = 0.4302
