In [1]:
##### DEMO - Dataset condensation on PhysioNet-2012, 80 condensed samples

### import modules
import numpy as np
from dataset.data_loaders import create_data_loader, build_data_getter
import torch,os,copy,logging
from configs.config_DM import get_args, get_config
from utils.train_utils import get_net, eval_net
from utils.misc import init_logging
from utils.metric_tracker import MetricTracker, TensorboardWriter
from dataset.meta import ds_name_mapping as ds_mp
from dataset.meta import net_name_mapping as net_mp

In [2]:
# Get arguments and configs: get_args is called with strict=False and the
# result is turned into the run configuration `cf` further down.
this_args = get_args(strict=False)

## Settings for PhysioNet-2012, applied on top of the parsed arguments.
# NOTE(review): dm_ipc presumably ends up as cf.dm.ipc (condensed samples per
# class, 40 x 2 classes = 80 in this demo) - confirm against get_config.
physio_overrides = {
    "dataset": "physio",
    "dm_ipc": 40,
    "save_dir_name": "PhysioNet_DC",
    "pre_process": "std",
}
for option_name, option_value in physio_overrides.items():
    setattr(this_args, option_name, option_value)

cf = get_config(this_args)

# Create the saving directory first, then hand the root logger and that
# directory to init_logging.
os.makedirs(cf.save_dir, exist_ok=True)
log_root = logging.getLogger()
init_logging(log_root, cf.save_dir)

logging.info(f"Dataset: {ds_mp[cf.ds_name]}, results will be saved to: {cf.save_dir}")

2022-12-22,11:45:35-Dataset: PhysioNet-2012, results will be saved to: ../snapshots/PhysioNet_DC


In [3]:
# Build the real validation/test data loaders and fetch the raw training data.
logging.info("Loading train set and creating validation/test data loader for: {}".format(ds_mp[cf.ds_name]))

### The datasets are not included; please download and pre-process them yourself.
loader_options = dict(
    path=cf.data_root,
    train_batch=cf.train_batch,
    val_batch=128,
    test_batch=128,
    pre_process=cf.pre_process,
)
# The first returned element is discarded; the loaders, the training data and
# labels, and `prpr` are kept.
# NOTE(review): the discarded element is presumably the train loader and `prpr`
# the fitted pre-processor - confirm against cf.data_loader_fn.
_, val_loader, test_loader, tr_data, tr_lb, prpr = cf.data_loader_fn(**loader_options)

logging.info(f"Number of samples - train: {tr_data.shape[0]}, validation: {len(val_loader.dataset)}, test: {len(test_loader.dataset)}")


2022-12-22,11:45:35-Loading train set and creating validation/test data loader for: PhysioNet-2012
2022-12-22,11:45:36-Number of samples - train: 5120, validation: 1280, test: 1600


In [4]:
# Find the class number: cf.num_class == 1 is mapped to two classes.
# NOTE(review): presumably a single-output binary setup - confirm in the config.
num_classes = 2 if cf.num_class == 1 else cf.num_class

# Build the original train data getter: get_data(c, n) returns n random
# samples of class c and exposes per-class indices via get_data.indices_class.
get_data = build_data_getter(cf.ds_name, tr_data, tr_lb, cf.device)

# Shape of the condensed dataset: (num_classes * ipc, syn_time_dim, fea_dim).
syn_shape = (num_classes * cf.dm.ipc, cf.dm.syn_time_dim, cf.fea_dim,)

logging.info(f"Initialising condensed dataset from scratch, condensed samples: {num_classes * cf.dm.ipc}")
# The condensed data is the only trainable tensor, hence requires_grad=True.
data_syn = torch.randn(size=syn_shape, dtype=torch.float, requires_grad=True, device=cf.device)

# Byte size taken from tensor metadata; this avoids copying the whole tensor
# to host memory (detach().cpu().numpy()) only to read its size.
syn_nbytes = data_syn.element_size() * data_syn.nelement()

logging.info("Original train data shape: {}, size: {:.3f} MBs ".format(tr_data.shape, tr_data.nbytes/(1024**2)))
logging.info("Condensed data shape: {}, size: {:.3f} MBs ".format(syn_shape, syn_nbytes/(1024**2)))

if cf.ds_name in ("mimic3", "physio", "covid_b"):
    # Class-major labels matching the layout of data_syn: [0,0,0, ..., 1,1,1, ]
    # Built directly on the target device and then cast to the label dtype.
    label_syn = (
        torch.arange(num_classes, device=cf.device)
        .repeat_interleave(cf.dm.ipc)
        .to(cf.label_dtype)
    )
else:
    raise NotImplementedError("Dataset {} not implemented".format(cf.ds_name))

logging.info("Using Adam optimizer for DC learning ...")
# Only data_syn is handed to the optimizer.
optimizer_data = torch.optim.Adam([data_syn, ], lr=cf.dm.lr_data)

logging.info("Learning condensed dataset on networks: {}".format([net_mp[e] for e in cf.dm.train_net]))
logging.info("Evaluating condensed dataset on networks: {}".format([net_mp[e] for e in cf.dm.eval_net]))

2022-12-22,11:45:36-Initialising condensed dataset from scratch, condensed samples: 80
2022-12-22,11:45:37-Original train data shape: (5120, 48, 47), size: 88.125 MBs 
2022-12-22,11:45:37-Condensed data shape: (80, 48, 47), size: 0.688 MBs 
2022-12-22,11:45:37-Using Adam optimizer for DC learning ...
2022-12-22,11:45:37-Learning condensed dataset on networks: ['TCN-α', 'LSTM-α', 'ViT-α']
2022-12-22,11:45:37-Evaluating condensed dataset on networks: ['TCN-α', 'LSTM-α', 'ViT-α', 'ViT-β', 'TRSF-α', 'TRSF-β', 'TCN-β', 'TCN-γ', 'LSTM-β', 'RNN-α', 'RNN-β']


In [5]:
# Set up the tensorboard writer and the metric trackers that report through it.
logging.info("Creating tensborboard writer ...")
writer = TensorboardWriter(cf.save_dir, cf.enable_tensorboard)

# Tracker for the "mmd_loss" value updated during condensation.
train_metrics = MetricTracker("mmd_loss", writer=writer)

# One tracked key per evaluation network, formatted as "syn_auc (<net>)".
eval_keys = tuple(map("syn_auc ({})".format, cf.dm.eval_net))
eval_metric = MetricTracker(*eval_keys, writer=writer)

# Filled after evaluation with the per-network AUC lists.
all_test_auc = {}

2022-12-22,11:45:37-Creating tensborboard writer ...


In [6]:
logging.info('DC learning starts ...')
optimizer_data.zero_grad()

#### Learn condensed data ######
# `it` is intentionally left in the notebook namespace: later cells read it.
for it in range(cf.dm.iteration + 1):

    # Randomly pick one of the training architectures and build a fresh model.
    tr_net = np.random.choice(cf.dm.train_net)
    net = get_net(tr_net, **cf[tr_net+"_args"]).to(cf.device)
    net.train()
    # Freeze the network parameters: only data_syn is updated by the optimizer.
    for weight in net.parameters():
        weight.requires_grad = False

    # MMD loss: per class, squared distance between the mean network output on
    # real samples and on the condensed samples, summed over classes.
    loss = torch.tensor(0.0).to(cf.device)
    for c in range(num_classes):
        # The real batch size must not exceed the number of samples of this class.
        n_real = min(len(get_data.indices_class[c]), cf.dm.batch_real)
        real_batch = get_data(c, n_real)

        # Condensed samples are stored class by class, ipc rows per class.
        syn_start = c * cf.dm.ipc
        syn_batch = data_syn[syn_start : syn_start + cf.dm.ipc]

        # Real forward pass first (detached), then the condensed forward pass
        # through which gradients reach data_syn.
        real_output = net(real_batch).detach()
        syn_output = net(syn_batch)

        real_mean = torch.mean(real_output, dim=0)
        syn_mean = torch.mean(syn_output, dim=0)
        loss += torch.sum((real_mean - syn_mean) ** 2)

    # Update the condensed data.
    optimizer_data.zero_grad()
    loss.backward()
    optimizer_data.step()

    # Report the loss averaged over classes.
    loss_avg = loss.item() / num_classes
    if train_metrics.writer is not None:
        train_metrics.writer.set_step(it)
    train_metrics.update("mmd_loss", loss_avg)

    if it % cf.dm.logging_iter == 0:
        logging.info('iter = {:04d}, MMD loss = {:.7f}'.format(it, loss_avg))

logging.info('Learning completed.')

2022-12-22,11:45:37-DC learning starts ...
2022-12-22,11:45:38-iter = 0000, MMD loss = 0.0000165
2022-12-22,11:46:28-iter = 1000, MMD loss = 0.0000003
2022-12-22,11:47:16-iter = 2000, MMD loss = 0.0000528
2022-12-22,11:48:03-iter = 3000, MMD loss = 0.0000462
2022-12-22,11:48:52-iter = 4000, MMD loss = 0.0001433
2022-12-22,11:49:41-iter = 5000, MMD loss = 0.0002874
2022-12-22,11:50:29-iter = 6000, MMD loss = 0.0003952
2022-12-22,11:51:18-iter = 7000, MMD loss = 0.0000014
2022-12-22,11:52:06-iter = 8000, MMD loss = 0.0000193
2022-12-22,11:52:54-iter = 9000, MMD loss = 0.0000048
2022-12-22,11:53:43-iter = 10000, MMD loss = 0.0000014
2022-12-22,11:54:31-iter = 11000, MMD loss = 0.0000022
2022-12-22,11:55:19-iter = 12000, MMD loss = 0.0000025
2022-12-22,11:56:08-iter = 13000, MMD loss = 0.0000007
2022-12-22,11:56:58-iter = 14000, MMD loss = 0.0006449
2022-12-22,11:57:48-iter = 15000, MMD loss = 0.0001399
2022-12-22,11:58:36-iter = 16000, MMD loss = 0.0000690
2022-12-22,11:59:25-iter = 17000

In [7]:
#### Evaluate condensed data ####
# For every evaluation architecture, train cf.dm.num_eval fresh models on the
# condensed data and collect their test AUCs.
aucs = dict()
for this_net in cf.dm.eval_net:

    logging.info("Evaluating condensed data on network: {} ...".format(net_mp[this_net]))
    net_aucs = []
    aucs[this_net] = net_aucs

    for it_eval in range(cf.dm.num_eval):
        # Fresh model for every repetition.
        net_eval = get_net(this_net, **cf[this_net+"_args"]).to(cf.device)

        # Work on detached copies so evaluation cannot modify the learned data.
        data_syn_eval = copy.deepcopy(data_syn.detach())
        label_syn_eval = copy.deepcopy(label_syn.detach())

        # Data loader over the condensed dataset.
        syn_train_loader = create_data_loader(
            data_syn_eval,
            label_syn_eval,
            batch_size=cf.train_batch,
            sampler=False,
        )

        # Train/evaluate the network on the condensed data; only the second
        # returned value (test AUC) is used.
        _, test_auc = eval_net(
            net_eval, syn_train_loader, val_loader, test_loader,
            lr=cf.lr,
            epochs=cf.epochs,
            weight_decay=cf.weight_decay,
            save_dir=cf.save_dir,
            val_metric=cf.val_metric,
            device=cf.device,
            early_stop=cf.early_stop,
            early_stop_metric=cf.early_stop_metric,
        )

        logging.info("Eval {:02d}/{:02d}, test auc: {:.4f}".format(it_eval+1, cf.dm.num_eval, test_auc), )
        net_aucs.append(test_auc)

    mean_auc = np.mean(net_aucs)
    logging.info("Condensed data test AUC ({}): {:.4f}±{:.4f}\n".format(
        net_mp[this_net], mean_auc, np.std(net_aucs)))

    # `it` is the loop variable left over from the condensation cell, so that
    # cell must have been run before this one.
    if eval_metric.writer is not None:
        eval_metric.writer.set_step(it, mode="eval")
    eval_metric.update("syn_auc ({})".format(this_net), mean_auc)

2022-12-22,12:05:12-Evaluating condensed data on network: TCN-α ...
2022-12-22,12:05:15-Eval 01/05, test auc: 0.8136
2022-12-22,12:05:19-Eval 02/05, test auc: 0.8013
2022-12-22,12:05:22-Eval 03/05, test auc: 0.8087
2022-12-22,12:05:26-Eval 04/05, test auc: 0.8032
2022-12-22,12:05:29-Eval 05/05, test auc: 0.8135
2022-12-22,12:05:29-Condensed data test AUC (TCN-α): 0.8081±0.0051

2022-12-22,12:05:29-Evaluating condensed data on network: LSTM-α ...
2022-12-22,12:05:32-Eval 01/05, test auc: 0.8112
2022-12-22,12:05:35-Eval 02/05, test auc: 0.8081
2022-12-22,12:05:38-Eval 03/05, test auc: 0.8241
2022-12-22,12:05:41-Eval 04/05, test auc: 0.8180
2022-12-22,12:05:44-Eval 05/05, test auc: 0.8224
2022-12-22,12:05:44-Condensed data test AUC (LSTM-α): 0.8167±0.0062

2022-12-22,12:05:44-Evaluating condensed data on network: ViT-α ...
2022-12-22,12:05:50-Eval 01/05, test auc: 0.8037
2022-12-22,12:05:55-Eval 02/05, test auc: 0.7957
2022-12-22,12:06:01-Eval 03/05, test auc: 0.7985
2022-12-22,12:06:07-E

In [8]:
# Record the evaluation under the iteration at which it was taken; `it` is the
# loop variable left over from the condensation cell.
all_test_auc["iter_" + str(it)] = aucs

# Per-network AUC lists (same list objects as in `aucs`) ...
final_test_auc = {net_key: aucs[net_key] for net_key in cf.dm.eval_net}
# ... and all individual AUCs flattened in evaluation-network order.
auc_all = [auc for net_key in cf.dm.eval_net for auc in aucs[net_key]]

## Overall test performance across every evaluation network and repetition.
logging.info("Condensed data ({}) test AUC (all {} networks): {:.4f}±{:.4f}".format(data_syn.shape[0],len(cf.dm.eval_net), np.mean(auc_all), np.std(auc_all)))


2022-12-22,12:08:19-Condensed data (80) test AUC (all 11 networks): 0.8038±0.0138


In [9]:
# Save the condensed dataset together with the collected AUC results.
syn_data_save_path = os.path.join(cf.save_dir, "syn_data.pt")

# Detached CPU copies of the condensed data and labels, stored as a list with
# one [data, labels] entry.
data_save = [
    [copy.deepcopy(data_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())],
]

checkpoint = {
    'syn_dataset': data_save,
    'final_test_auc': final_test_auc,
    "all_test_auc": all_test_auc,
}
torch.save(checkpoint, syn_data_save_path)
logging.info("Condensed dataset saved to : {}".format(syn_data_save_path))

2022-12-22,12:08:19-Condensed dataset saved to : ../snapshots/PhysioNet_DC/syn_data.pt
