In [1]:
import os
import numpy as np

# Root directory of all experiment outputs (reconstructor / classifier /
# attacker sub-trees). Set the MPMLD_OUTPUT_DIR environment variable to point
# elsewhere instead of editing this machine-specific default path.
base_path = os.environ.get('MPMLD_OUTPUT_DIR', '/mnt/disk1/heonseok/MPMLD/output0727')

In [65]:
def get_model_list(beta_list, z_dim_list, setsize_list, lr_list, ref_list, weight_list):
    """Enumerate reconstructor and classifier run names for a hyper-parameter grid.

    The grid is the Cartesian product of the six lists, iterated with
    ``beta_list`` outermost and ``weight_list`` innermost. Each ``weight`` must
    provide at least 5 entries. They are written into the run name, converted
    to float, as rw, cc, cm, mc, mm (in that order).

    Side effects: prints every reconstructor name quoted with a trailing comma,
    then one comma-terminated line per grid point with its raw spec values.

    Returns:
        (recon_model_list, class_model_list): the reconstructor run names and
        the matching classifier names (reconstructor name + '_ResNet18').
    """
    name_template = 'VAE{}_z{}_setsize{}_lr{}_ref{}_rw{}_cc{}_cm{}_mc{}_mm{}'

    # Same ordering as six nested for-loops: last list varies fastest.
    grid = [
        (beta, z_dim, setsize, lr, ref, weight)
        for beta in beta_list
        for z_dim in z_dim_list
        for setsize in setsize_list
        for lr in lr_list
        for ref in ref_list
        for weight in weight_list
    ]

    recon_model_list, recon_spec_list = [], []
    for beta, z_dim, setsize, lr, ref, weight in grid:
        loss_weights = [weight[k] for k in range(5)]
        recon_model_list.append(name_template.format(
            beta, z_dim, setsize, lr, ref, *[float(w) for w in loss_weights]))
        # The spec keeps the weights exactly as given (not float-converted).
        recon_spec_list.append([beta, z_dim, setsize, lr, ref] + loss_weights)

    for name in recon_model_list:
        print("'{}',".format(name))

    class_model_list = [name + '_ResNet18' for name in recon_model_list]

    for spec in recon_spec_list:
        print(''.join(str(value) + ',' for value in spec))

    return recon_model_list, class_model_list

In [3]:
def collate_reconstruction_results(dataset, recon_model_list, root=None, n_repeats=5):
    """Print each reconstructor's MSE array averaged over its repeats.

    For every model, loads
    ``<root>/<dataset>/reconstructor/<model>/repeat<k>/mse.npy`` for k in
    ``range(n_repeats)``, averages the loaded arrays element-wise, and prints
    the result as one comma-separated line (4 decimals). The quoted model
    names are printed first so the result rows can be matched to them.

    Args:
        dataset: dataset sub-directory name (e.g. 'SVHN').
        recon_model_list: reconstructor run directory names.
        root: output root directory; defaults to the global ``base_path``.
        n_repeats: number of ``repeat<k>`` directories to look for.

    Missing repeat files are skipped and the average uses only the repeats
    actually found. A model with no results prints ``nan``, keeping one
    output row per model.
    """
    # Resolved at call time so existing callers keep using the notebook global.
    if root is None:
        root = base_path

    for recon_model in recon_model_list:
        print('\'' + recon_model + '\',')

    for recon_model in recon_model_list:
        recon_path = os.path.join(root, dataset, 'reconstructor', recon_model)

        # Fresh per model: previously the running sum leaked across models
        # whenever repeat0 was missing, and the divisor was always 5.
        mse_per_repeat = []
        for repeat in range(n_repeats):
            recon_mse_path = os.path.join(recon_path, 'repeat{}'.format(repeat), 'mse.npy')
            try:
                mse_per_repeat.append(np.load(recon_mse_path, allow_pickle=True))
            except FileNotFoundError:
                continue

        if not mse_per_repeat:
            print('nan')
            continue

        mse_avg = np.mean(np.stack(mse_per_repeat), axis=0)
        print(','.join('{:.4f}'.format(mse) for mse in mse_avg))


In [70]:
def collate_disentanglement_result(dataset, recon_model_list, root=None, n_repeats=5):
    """Print each reconstructor's probe accuracies averaged over its repeats.

    For every model, loads
    ``<root>/<dataset>/reconstructor/<model>/repeat<k>/acc.npy`` (a pickled
    dict) for k in ``range(n_repeats)``. It then prints one line with the mean
    of these six keys, in order:
    class_acc_full, class_acc_content, class_acc_style,
    membership_acc_full, membership_acc_content, membership_acc_style.

    Args:
        dataset: dataset sub-directory name (e.g. 'SVHN').
        recon_model_list: reconstructor run directory names.
        root: output root directory; defaults to the global ``base_path``.
        n_repeats: number of ``repeat<k>`` directories to look for.

    Missing repeat files are skipped. A value with no repeats prints ``nan``
    directly, instead of going through numpy's empty-mean RuntimeWarning.
    """
    # Resolved at call time so existing callers keep using the notebook global.
    if root is None:
        root = base_path

    # Printed column order.
    acc_keys = [
        'class_acc_full', 'class_acc_content', 'class_acc_style',
        'membership_acc_full', 'membership_acc_content', 'membership_acc_style',
    ]

    for recon_model in recon_model_list:
        print('\'' + recon_model + '\',')

    for recon_model in recon_model_list:
        recon_path = os.path.join(root, dataset, 'reconstructor', recon_model)

        values = {key: [] for key in acc_keys}
        for repeat in range(n_repeats):
            recon_acc_path = os.path.join(recon_path, 'repeat{}'.format(repeat), 'acc.npy')
            try:
                recon_acc = np.load(recon_acc_path, allow_pickle=True).item()
            except FileNotFoundError:
                continue
            for key in acc_keys:
                values[key].append(recon_acc[key])

        print(','.join(
            '{:.4f}'.format(np.average(values[key]) if values[key] else float('nan'))
            for key in acc_keys
        ))


In [5]:
def collate_classification_result(dataset, class_model_list, recon_types=None, root=None, n_repeats=5):
    """Print classifier train/valid/test accuracy averaged over repeats.

    For every classifier model and every reconstruction type, loads
    ``<root>/<dataset>/classifier/<model>/<recon_type>/repeat<k>/acc.npy``, a
    pickled dict with 'train', 'valid' and 'test' entries, for k in
    ``range(n_repeats)``. It prints one comma-separated line per model:
    train, valid, test for the first recon type, then for the next, and so on.

    Args:
        dataset: dataset sub-directory name (e.g. 'SVHN').
        class_model_list: classifier run directory names.
        recon_types: reconstruction sub-directories; defaults to the global
            ``recon_list`` (previously a hidden global dependency).
        root: output root directory; defaults to the global ``base_path``.
        n_repeats: number of ``repeat<k>`` directories to look for.

    Missing repeat files are skipped; a value with no repeats prints ``nan``.
    """
    # Resolved at call time so existing callers keep using the notebook globals.
    if recon_types is None:
        recon_types = recon_list
    if root is None:
        root = base_path

    splits = ['train', 'valid', 'test']

    for class_model in class_model_list:
        print('\'' + class_model + '\',')

    for class_model in class_model_list:
        class_path = os.path.join(root, dataset, 'classifier', class_model)

        result_list = []
        for recon_type in recon_types:
            acc_per_split = {split: [] for split in splits}
            for repeat in range(n_repeats):
                class_acc_path = os.path.join(class_path, recon_type, 'repeat{}'.format(repeat), 'acc.npy')
                try:
                    class_acc = np.load(class_acc_path, allow_pickle=True).item()
                except FileNotFoundError:
                    continue
                for split in splits:
                    acc_per_split[split].append(class_acc[split])

            for split in splits:
                accs = acc_per_split[split]
                result_list.append(np.average(accs) if accs else float('nan'))

        print(','.join('{:.4f}'.format(result) for result in result_list))

In [48]:
def collate_attack_result(dataset, class_model_list, metric='acc', recon_types=None,
                          attack_type_list=('stat', 'black'), root=None, n_repeats=5):
    """Print membership-inference attack results averaged over repeats.

    For every classifier model, reconstruction type and attack type, loads
    ``<root>/<dataset>/attacker/<model>/<recon_type>/repeat<k>/<attack_type>/<metric>.npy``
    for k in ``range(n_repeats)`` and averages it over the repeats found.
    It prints one comma-separated line per model with one column per
    (recon_type, attack_type), recon type outer and attack type inner.

    'stat' files hold the value directly. 'black' (and 'white') files hold a
    pickled dict whose 'test' entry is used.

    Args:
        dataset: dataset sub-directory name (e.g. 'SVHN').
        class_model_list: classifier run directory names.
        metric: result file stem, e.g. 'acc' or 'auroc'.
        recon_types: reconstruction sub-directories; defaults to the global ``recon_list``.
        attack_type_list: attack sub-directories to collect.
        root: output root directory; defaults to the global ``base_path``.
        n_repeats: number of ``repeat<k>`` directories to look for.

    A column with no repeats prints ``nan`` so rows stay aligned.
    """
    # Resolved at call time so existing callers keep using the notebook globals.
    if recon_types is None:
        recon_types = recon_list
    if root is None:
        root = base_path

    for class_model in class_model_list:
        print('\'' + class_model + '\',')

    for class_model in class_model_list:
        attack_path = os.path.join(root, dataset, 'attacker', class_model)
        result_list = []
        for recon_type in recon_types:
            for attack_type in attack_type_list:
                # Collect all repeats first, then average once. Previously the
                # list was reset and emitted inside the repeat loop, giving one
                # column per repeat (and shifted columns when a repeat was missing).
                acc_list = []
                for repeat in range(n_repeats):
                    attack_file = os.path.join(attack_path, recon_type, 'repeat{}'.format(repeat),
                                               attack_type, '{}.npy'.format(metric))
                    try:
                        attack_acc = np.load(attack_file, allow_pickle=True)
                    except FileNotFoundError:
                        continue
                    if attack_type == 'stat':
                        acc_list.append(attack_acc)
                    elif attack_type in ('black', 'white'):
                        acc_list.append(attack_acc.item()['test'])
                result_list.append(np.average(acc_list) if acc_list else float('nan'))

        print(','.join('{:.4f}'.format(result) for result in result_list))


In [135]:
# Todo : refactoring
# Baseline row: the classifier trained on raw data (no reconstructor) and the
# attacks run against it, each averaged over the repeats whose files exist.
# Prints: train, valid, test accuracy, then the stat and black attack results.
# NOTE(review): relies on `dataset` and `base_path` already being defined;
# `dataset` is set in the configuration cell further down, so run that cell first.
metric = 'acc'
# metric = 'auroc'

# Defined explicitly here: this list was otherwise only local to
# collate_attack_result, so this cell raised NameError on a fresh kernel.
attack_type_list = [
    'stat',
    'black',
]


def _mean_or_nan(values):
    """Mean of ``values``; nan when empty (avoids numpy's 'Mean of empty slice' warning)."""
    return np.average(values) if len(values) > 0 else float('nan')


classifier_with_raw_data = 'original_setsize1000_FCNClassifierA'
class_path = os.path.join(base_path, dataset, 'classifier', classifier_with_raw_data)
attack_path = os.path.join(base_path, dataset, 'attacker', classifier_with_raw_data)
train_acc_list = []
valid_acc_list = []
test_acc_list = []
stat_acc_list = []
black_acc_list = []
for repeat in range(5):
    class_repeat_path = os.path.join(class_path, 'repeat{}'.format(repeat))
    attack_repeat_path = os.path.join(attack_path, 'repeat{}'.format(repeat))

    # Each file is loaded independently, so one missing file no longer
    # drops the remaining results of the same repeat.
    try:
        class_acc = np.load(os.path.join(class_repeat_path, 'acc.npy'), allow_pickle=True).item()
    except FileNotFoundError:
        pass
    else:
        train_acc_list.append(class_acc['train'])
        valid_acc_list.append(class_acc['valid'])
        test_acc_list.append(class_acc['test'])

    for attack_type in attack_type_list:
        try:
            attack_acc = np.load(os.path.join(attack_repeat_path, attack_type, '{}.npy'.format(metric)), allow_pickle=True)
        except FileNotFoundError:
            continue
        if attack_type == 'stat':
            stat_acc_list.append(attack_acc)
        elif attack_type == 'black':
            black_acc_list.append(attack_acc.item()['test'])

print('{:.4f} {:.4f} {:.4f} {:.4f} {:.4f}'.format(
    _mean_or_nan(train_acc_list), _mean_or_nan(valid_acc_list), _mean_or_nan(test_acc_list),
    _mean_or_nan(stat_acc_list), _mean_or_nan(black_acc_list),
))
# print(np.average(train_acc_list), np.average(valid_acc_list), np.average(test_acc_list))
# print(np.average(stat_acc_list), np.average(black_acc_list), np.average(test_acc_list))


0.9486 0.6470 0.5900 0.6420 nan


  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


In [66]:
# Experiment configuration: the dataset and the hyper-parameter grid whose
# reconstructor / classifier / attacker runs are collated by the cells below.
# Values swept in earlier experiments are kept in trailing comments.
dataset = 'SVHN'

beta_list = [1e-06]        # also tried: 1e-05, 1e-04, 1e-03, 0.01, 0.1, 1.0
z_dim_list = ['64']        # also tried: '16', '128', '256'
setsize_list = ['10000']   # also tried: '500', '1000', '2000'
lr_list = ['0.001']        # also tried: '0.01', '0.1'
ref_list = ['1.0']         # also tried: '0.1'

# Loss weights, one row per run, in run-name order: [rw, cc, cm, mc, mm].
# Earlier sweep (ref 1.0 + permuted ref) used rw=100 with the same five
# cc/cm/mc/mm patterns; [100., 0., 1., 1., 0.] was the best result at 0727 6:37.
weight_list = [
    [1, 0, 1, 1, 0],
    [1, 0, 10, 1, 0],
    [1, 0, 1, 10, 0],
    [1, 0, 10, 10, 0],
    [1, 1, 1, 1, 1],
]

recon_model_list, class_model_list = get_model_list(
    beta_list, z_dim_list, setsize_list, lr_list, ref_list, weight_list)

# Reconstruction-variant sub-directories under each classifier / attacker run.
recon_list = ['cb_mb', 'cz_mb', 'cb_mz']

'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw1.0_cc0.0_cm1.0_mc1.0_mm0.0',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw1.0_cc0.0_cm10.0_mc1.0_mm0.0',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw1.0_cc0.0_cm1.0_mc10.0_mm0.0',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw1.0_cc0.0_cm10.0_mc10.0_mm0.0',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw1.0_cc1.0_cm1.0_mc1.0_mm1.0',
1e-06,64,10000,0.001,1.0,1,0,1,1,0,
1e-06,64,10000,0.001,1.0,1,0,10,1,0,
1e-06,64,10000,0.001,1.0,1,0,1,10,0,
1e-06,64,10000,0.001,1.0,1,0,10,10,0,
1e-06,64,10000,0.001,1.0,1,1,1,1,1,


In [71]:
collate_disentanglement_result(dataset=dataset, recon_model_list=recon_model_list)

'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw1.0_cc0.0_cm1.0_mc1.0_mm0.0',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw1.0_cc0.0_cm10.0_mc1.0_mm0.0',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw1.0_cc0.0_cm1.0_mc10.0_mm0.0',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw1.0_cc0.0_cm10.0_mc10.0_mm0.0',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw1.0_cc1.0_cm1.0_mc1.0_mm1.0',
0.5148,0.4786,0.1926,0.5000,0.5000,0.5000
0.6501,0.6350,0.1926,0.5000,0.5000,0.5009
0.7095,0.6356,0.1926,0.5000,0.5000,0.5112
0.2252,0.2107,0.1923,0.5000,0.5000,0.5000
0.1942,0.1940,0.1924,0.5000,0.5000,0.5000


In [10]:
collate_classification_result(dataset=dataset, class_model_list=class_model_list)

'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw100.0_cc0.0_cm1.0_mc1.0_mm0.0_ResNet18',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw100.0_cc0.0_cm10.0_mc1.0_mm0.0_ResNet18',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw100.0_cc0.0_cm1.0_mc10.0_mm0.0_ResNet18',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw100.0_cc0.0_cm10.0_mc10.0_mm0.0_ResNet18',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw100.0_cc1.0_cm1.0_mc1.0_mm1.0_ResNet18',
1.0000,0.8820,0.8665,0.3204,0.1960,0.1885,0.6960,0.5100,0.4835
1.0000,0.8600,0.8485,0.9773,0.1830,0.1685,1.0000,0.6195,0.6230
1.0000,0.8820,0.8690,0.2248,0.1730,0.1725,0.9944,0.3820,0.3730
1.0000,0.8570,0.8435,0.1791,0.1560,0.1575,0.5810,0.2695,0.2695
1.0000,0.8345,0.8110,0.3609,0.1520,0.1595,1.0000,0.5250,0.5210


In [50]:
collate_attack_result(dataset=dataset, class_model_list=class_model_list)

'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw100.0_cc0.0_cm1.0_mc1.0_mm0.0_ResNet18',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw100.0_cc0.0_cm10.0_mc1.0_mm0.0_ResNet18',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw100.0_cc0.0_cm1.0_mc10.0_mm0.0_ResNet18',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw100.0_cc0.0_cm10.0_mc10.0_mm0.0_ResNet18',
'VAE1e-06_z64_setsize10000_lr0.001_ref1.0_rw100.0_cc1.0_cm1.0_mc1.0_mm1.0_ResNet18',
0.6942,0.6567,0.5312,0.6763,0.5372,0.5653
0.6886,0.6767,0.8206,0.9187,0.7965,0.7713
0.7152,0.6793,0.5133,0.6123,0.7507,0.7837
0.7020,0.6880,0.5053,0.5310,0.5394,0.5673
0.7180,0.6997,0.5247,0.6063,0.7500,0.7550
