В этом ноутбуке используются обученные тестовые модели: с их помощью выполняется сэмплирование, затем закодированные (OHE) категориальные переменные переводятся обратно в исходные значения. После этого замеряется качество.

Способ нормализации: quantile transformation для числовых признаков и standard scaling для категориальных признаков после OHE-кодирования.

Добавление шума: ко всем признакам

In [1]:
import json
import pandas as pd
import torch
import matplotlib.pyplot as plt

from utils import *
from models.tabddpm_ON_QnSC.tabddpm_ON_QnSC import *

from tqdm.notebook import tqdm

from eval.base_metrics import calculate_base_metrics
from eval.similarity import calculate_similarity
from eval.mle import calculate_mle
from eval.alpha_beta import calculate_alpha_beta
from eval.detection import calculate_detection
from eval.dcr import calculate_DCR

### Подготовка

In [2]:
# Dataset folder name and model short name; both are reused to build every
# data/checkpoint/synthetic path below.
dataname, model_short = 'adult_ON_QnSC', 'tabddpm_ON_QnSC'

In [3]:
# Pick the compute device and register it, the dataset name and the
# input file locations in the shared CONFIG.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
CONFIG.add_arg('device', device)
print(f"Device: {device}")

CONFIG.add_arg('dataname', dataname)
for arg_name, arg_path in (
    ('real_path', f'./synthetic/{dataname}/initial_real.csv'),
    ('test_path', f'./synthetic/{dataname}/initial_test.csv'),
    ('info_path', f'./data/{dataname}/initial_info.json'),
):
    CONFIG.add_arg(arg_name, arg_path)

Device: cuda


In [4]:
# Checkpoint directories to evaluate, in order: the `_const_<sigma>`
# checkpoints for each constant sigma, then the `_mult_<i>` checkpoints.
const_sigmas = [0, 0.001, 0.01, 0.1, 0.25, 0.5]
n_mult_checkpoints = 10
ckpt_dir = f"./models/{model_short}/ckpt"

model_save_path_hist = [
    # The sigma is embedded in the name with '.' -> '_' (e.g. 0.001 -> "0_001").
    f"{ckpt_dir}/{model_short}_const_{str(sigma).replace('.', '_')}"
    for sigma in const_sigmas
]
model_save_path_hist += [
    f"{ckpt_dir}/{model_short}_mult_{i}" for i in range(n_mult_checkpoints)
]

In [5]:
from pickle import dump, load

# Feature-block sizes saved as JSON at preprocessing time.
with open(f'data/{dataname}/normalizers.json', 'r') as f:
    normalizers = json.load(f)

# Fitted sklearn normalizers. NOTE: pickle.load can execute arbitrary code;
# only load files produced by this project's own preprocessing.
for key in ('num_normalizer', 'cat_normalizer'):
    with open(f'./data/{dataname}/{key}_{dataname}.pkl', 'rb') as f:
        normalizers[key] = load(f)
print(f"normalizers received from `data/{dataname}` folder")

normalizers received from `data/adult_ON_QnSC` folder


In [6]:
# Echo the loaded normalizers dict (JSON block sizes plus the two unpickled normalizers).
normalizers

{'len_num_prev': 6,
 'len_cat_prev': 102,
 'len_target_prev': 1,
 'num_normalizer': QuantileTransformer(output_distribution='normal', random_state=0,
                     subsample=1000000000),
 'cat_normalizer': StandardScaler()}

### Sample + Eval

In [7]:
# Echo every argument currently stored in CONFIG.
CONFIG.get_all_args()

{'dataname': 'adult_ON_QnSC',
 'method': 'tabddpm_ON_QnSC_const_0',
 'device': 'cuda',
 'mode': 'train',
 'train': 1,
 'sample_save_path': 'synthetic/adult_ON_QnSC/tabddpm_ON_QnSC_const_0.csv',
 'sigma_scheduller_name': 'constant',
 'sigma_value': 0.001,
 'num_noise': 103,
 'real_path': './synthetic/adult_ON_QnSC/initial_real.csv',
 'test_path': './synthetic/adult_ON_QnSC/initial_test.csv',
 'info_path': './data/adult_ON_QnSC/initial_info.json',
 'save_path': './synthetic/adult_ON_QnSC/initial_tabddpm_ON_QnSC_const_0.csv'}

In [None]:
# For every checkpoint: sample synthetic data, decode it back to the original
# feature space with postsample_OHE, then compute all evaluation metrics.
# Result: overall_metrics[model_name][metric_group] -> metric values.
overall_metrics = {}

for model_save_path in model_save_path_hist:
    model_name = model_save_path.split('/')[-1]
    sample_save_path = f'./synthetic/{dataname}/{model_name}.csv'

    # The calculate_* helpers below take no path arguments; they presumably read
    # the current model's files via CONFIG — TODO confirm in eval/*.
    CONFIG.add_arg('method', model_name)
    CONFIG.add_arg('sample_save_path',
                       f"synthetic/{CONFIG.get_arg('dataname')}/{CONFIG.get_arg('method')}.csv")

    metrics = {}
    overall_metrics[model_name] = metrics

    print(model_save_path, model_name, sample_save_path)

    tabddpm_noise_ohe = TabDDPM_OHE_Noise_QnSC(CONFIG, model_save_path=model_save_path, sigmas=None)
    tabddpm_noise_ohe.sample(sample_save_path=sample_save_path)
    postsample_OHE(dataname, f'./synthetic/{dataname}/initial_{model_name}.csv', normalizers=normalizers)

    # Metric computation.
    # calculate_base_metrics fails for some checkpoints (NaN rows in the results).
    # Keep storing NaN so the results table still gets its ('base_metrics', 0)
    # placeholder column, but report the error instead of hiding it.
    try:
        metrics['base_metrics'] = calculate_base_metrics(make_binary=True, value=' >50K')
    except Exception as e:
        print(f"[{model_name}] calculate_base_metrics failed: {type(e).__name__}: {e}")
        metrics['base_metrics'] = np.nan
    metrics['similarity'] = calculate_similarity()
    metrics['mle'] = calculate_mle()
    metrics['detection'] = calculate_detection()
    metrics['DCR'] = calculate_DCR()
    metrics['quality'] = calculate_alpha_beta()

In [11]:
def _prefixed_frame(group, values):
    """Return a one-row DataFrame of `values` with columns nested under `group`."""
    frame = pd.DataFrame([values])
    frame.columns = pd.MultiIndex.from_tuples([(group, col) for col in frame.columns])
    return frame


# Build the results table: one wide row per model, with identifier columns under
# an empty top-level header followed by each metric group's values.
rows = []
for model_type, groups in overall_metrics.items():
    row_parts = [_prefixed_frame('', {'Model': model_short, 'Type': model_type, 'Data': dataname})]
    row_parts.extend(_prefixed_frame(group, values) for group, values in groups.items())
    rows.append(pd.concat(row_parts, axis=1))

final_metrics_table = pd.concat(rows)

In [12]:
# Hide the ('base_metrics', 0) placeholder created when base metrics failed and
# NaN was stored. errors='ignore' avoids a KeyError when no model failed.
final_metrics_table.drop(columns=('base_metrics', 0), errors='ignore')

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,similarity,similarity,similarity,similarity,similarity,similarity,similarity,...,detection,DCR,quality,quality,base_metrics,base_metrics,base_metrics,base_metrics,base_metrics,base_metrics
Unnamed: 0_level_1,Model,Type,Data,"Column Shapes Score, %","Column Pair Trends Score, %","Overall Score (Average), %","Error rate (%) of column-wise density estimation, %","Error rate (%) of column-wise density estimation std, %","Error rate (%) of pair-wise column correlation score, %","Error rate (%) of pair-wise column correlation score std, %",...,Score,Score,alpha precision,beta recall,Original Logistic,Synthetic Logistic,Original Tree,Synthetic Tree,"Accuracy Loss Logistic, %","Accuracy Loss Tree, %"
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_const_0,adult_ON_QnSC,97.761944,95.860595,96.81127,2.238056,1.406972,4.139405,2.181991,...,0.888111,0.666196,0.922654,0.463209,,,,,,
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_const_0_001,adult_ON_QnSC,97.912022,95.827279,96.86965,2.087978,1.216163,4.172721,2.150887,...,0.888216,0.668192,0.93127,0.455224,0.657095,0.651727,0.659361,0.623906,0.817014,5.377264
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_const_0_01,adult_ON_QnSC,98.227327,96.069651,97.148489,1.772673,0.943924,3.930349,1.823743,...,0.924525,0.66767,0.93725,0.478229,0.657095,0.644119,0.66094,0.62303,1.974756,5.735764
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_const_0_1,adult_ON_QnSC,97.577879,95.433992,96.505936,2.422121,1.565294,4.566008,2.765825,...,0.883197,0.666718,0.915629,0.449843,0.657095,0.632085,0.660812,0.624122,3.806212,5.55223
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_const_0_25,adult_ON_QnSC,98.005794,96.362806,97.1843,1.994206,1.286598,3.637194,2.687609,...,0.922326,0.668591,0.937418,0.43264,0.657095,0.653177,0.660136,0.63481,0.59623,3.836479
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_const_0_5,adult_ON_QnSC,97.806988,95.469228,96.638108,2.193012,0.878965,4.530772,4.474248,...,0.891523,0.671509,0.97292,0.310906,0.657095,0.620161,0.660327,0.617277,5.620826,6.519481
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_mult_0,adult_ON_QnSC,94.681367,90.343251,92.512309,5.318633,2.778405,9.656749,4.195256,...,0.84759,0.665858,0.946847,0.433343,0.657095,0.608306,0.660545,0.630422,7.424993,4.560296
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_mult_1,adult_ON_QnSC,96.826469,93.142738,94.984604,3.173531,1.732151,6.857262,3.043006,...,0.875991,0.670465,0.947879,0.4449,0.657095,0.640202,0.660539,0.627513,2.57087,4.999925
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_mult_2,adult_ON_QnSC,97.720381,93.94949,95.834936,2.279619,1.289328,6.05051,3.457592,...,0.886685,0.661773,0.943019,0.460733,0.657095,0.627952,0.66005,0.63148,4.435208,4.328476
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_mult_3,adult_ON_QnSC,97.76911,94.897559,96.333335,2.23089,1.429313,5.102441,2.969206,...,0.887242,0.671816,0.94073,0.466556,0.657095,0.64389,0.660186,0.630087,2.009561,4.559181


In [17]:
pd.set_option('display.max_columns', None)
# Rounded table, best overall similarity first. errors='ignore' keeps the
# placeholder drop from raising KeyError when no model's base metrics failed.
(
    final_metrics_table.round(3)
    .drop(columns=('base_metrics', 0), errors='ignore')
    .sort_values([('similarity', 'Overall Score (Average), %')], ascending=False)
)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,similarity,similarity,similarity,similarity,similarity,similarity,similarity,mle,mle,detection,DCR,quality,quality,base_metrics,base_metrics,base_metrics,base_metrics,base_metrics,base_metrics
Unnamed: 0_level_1,Model,Type,Data,"Column Shapes Score, %","Column Pair Trends Score, %","Overall Score (Average), %","Error rate (%) of column-wise density estimation, %","Error rate (%) of column-wise density estimation std, %","Error rate (%) of pair-wise column correlation score, %","Error rate (%) of pair-wise column correlation score std, %",ROC - AUC обучения на синтетических данных,"ROC - AUC обучения на синтетических данных, std",Score,Score,alpha precision,beta recall,Original Logistic,Synthetic Logistic,Original Tree,Synthetic Tree,"Accuracy Loss Logistic, %","Accuracy Loss Tree, %"
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_mult_5,adult_ON_QnSC,98.134,96.582,97.358,1.866,1.056,3.418,1.81,0.875,0.008,0.918,0.671,0.936,0.48,0.657,0.652,0.661,0.632,0.757,4.37
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_const_0_25,adult_ON_QnSC,98.006,96.363,97.184,1.994,1.287,3.637,2.688,0.877,0.007,0.922,0.669,0.937,0.433,0.657,0.653,0.66,0.635,0.596,3.836
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_mult_6,adult_ON_QnSC,98.029,96.325,97.177,1.971,1.368,3.675,2.252,0.867,0.008,0.919,0.668,0.94,0.473,,,,,,
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_const_0_01,adult_ON_QnSC,98.227,96.07,97.148,1.773,0.944,3.93,1.824,0.857,0.007,0.925,0.668,0.937,0.478,0.657,0.644,0.661,0.623,1.975,5.736
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_mult_8,adult_ON_QnSC,97.98,96.015,96.998,2.02,1.162,3.985,2.016,0.882,0.008,0.904,0.67,0.923,0.48,0.657,0.651,0.661,0.632,0.886,4.395
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_mult_7,adult_ON_QnSC,97.92,96.069,96.995,2.08,1.301,3.931,2.19,0.879,0.005,0.9,0.667,0.919,0.476,,,,,,
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_mult_9,adult_ON_QnSC,97.863,95.953,96.908,2.137,1.345,4.047,2.291,0.876,0.008,0.875,0.665,0.93,0.458,,,,,,
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_const_0_001,adult_ON_QnSC,97.912,95.827,96.87,2.088,1.216,4.173,2.151,0.857,0.008,0.888,0.668,0.931,0.455,0.657,0.652,0.659,0.624,0.817,5.377
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_const_0,adult_ON_QnSC,97.762,95.861,96.811,2.238,1.407,4.139,2.182,0.877,0.006,0.888,0.666,0.923,0.463,,,,,,
0,tabddpm_ON_QnSC,tabddpm_ON_QnSC_const_0_5,adult_ON_QnSC,97.807,95.469,96.638,2.193,0.879,4.531,4.474,0.835,0.007,0.892,0.672,0.973,0.311,0.657,0.62,0.66,0.617,5.621,6.519


In [15]:
import os

# Save the combined metrics table. exist_ok=True replaces the race-prone
# check-then-create pattern.
out_dir = f'./eval/total/{dataname}'
os.makedirs(out_dir, exist_ok=True)
# NOTE: "metrcs" (sic) in the filename is kept unchanged in case other code reads this path.
final_metrics_table.to_csv(f'{out_dir}/{model_short}_final_metrcs_table.csv', index=False)