В этом ноутбуке используются обученные тестовые модели: с их помощью выполняется сэмплирование, после чего закодированные категориальные переменные переводятся обратно в исходный вид. Затем замеряется качество.

Способ нормализации: standard scaling

Добавление шума: только к категориальным признакам, закодированным one-hot (OHE)

In [18]:
import json
import pandas as pd
import torch
import matplotlib.pyplot as plt

from utils import *
from models.tabddpm_ohe_noise.tabddpm_ohe_noise import *

from tqdm.notebook import tqdm

from eval.base_metrics import calculate_base_metrics
from eval.similarity import calculate_similarity
from eval.mle import calculate_mle
from eval.alpha_beta import calculate_alpha_beta
from eval.detection import calculate_detection
from eval.dcr import calculate_DCR

### Подготовка

In [2]:
# Dataset under evaluation and the short name of the model family whose
# checkpoints are stored under ./models/<model_short>/ckpt.
dataname, model_short = 'adult_ON_SC', 'tabddpm_ON_SC'

In [3]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
CONFIG.add_arg('device', device)
print(f"Device: {device}")

# Register the dataset name and the locations of the real train/test splits
# and the dataset info file. The metric helpers are called without arguments
# later on, so they presumably read these values from CONFIG (TODO confirm
# against utils / eval).
CONFIG.add_arg('dataname', dataname)
CONFIG.add_arg('real_path', f'./synthetic/{dataname}/initial_real.csv')
CONFIG.add_arg('test_path', f'./synthetic/{dataname}/initial_test.csv')
CONFIG.add_arg('info_path', f'./data/{dataname}/initial_info.json')

Device: cuda


In [4]:
# Checkpoint paths of every trained model, in evaluation order:
# first the constant-sigma runs (sigma encoded in the name with '.' -> '_',
# e.g. 0.001 -> "const_0_001"), then the ten "mult" runs numbered 0..9.
ckpt_dir = f"./models/{model_short}/ckpt"
model_save_path_hist = [
    f"{ckpt_dir}/{model_short}_const_{str(sigma).replace('.', '_')}"
    for sigma in [0, 0.001, 0.01, 0.1, 0.25, 0.5]
]
model_save_path_hist += [f"{ckpt_dir}/{model_short}_mult_{i}" for i in range(10)]

### Sample + Eval

In [5]:
# Show the current CONFIG contents as this cell's output.
# NOTE(review): in the recorded output, 'save_path' points to adult_ON_QnSC
# rather than this notebook's dataset — confirm nothing downstream reads it.
config_snapshot = CONFIG.get_all_args()
config_snapshot

{'dataname': 'adult_ON_SC',
 'method': 'tabddpm_ON_SC_const_0',
 'device': 'cuda',
 'mode': 'train',
 'train': 1,
 'sample_save_path': 'synthetic/adult_ON_SC/tabddpm_ON_SC_const_0.csv',
 'sigma_scheduller_name': 'constant',
 'sigma_value': 0.001,
 'num_noise': 103,
 'real_path': './synthetic/adult_ON_SC/initial_real.csv',
 'test_path': './synthetic/adult_ON_SC/initial_test.csv',
 'info_path': './data/adult_ON_SC/initial_info.json',
 'save_path': './synthetic/adult_ON_QnSC/initial_tabddpm_test.csv'}

In [21]:
# For each trained checkpoint: sample synthetic data, decode the one-hot
# categorical columns back to their original form, and compute every metric
# group. Results are stored as overall_metrics[model_name][group] -> dict.
overall_metrics = {}

# NOTE(review): only the first checkpoint is evaluated (`[:1]`), yet the table
# recorded below lists ten models. Widen the slice to reproduce it; otherwise
# the CSV export at the end overwrites earlier results with a single row.
for model_save_path in model_save_path_hist[:1]:
    model_name = model_save_path.split('/')[-1]
    # One source of truth for the samples file: it is passed to the sampler
    # and also published in CONFIG for helpers that read it from there.
    sample_save_path = f'./synthetic/{dataname}/{model_name}.csv'

    CONFIG.add_arg('method', model_name)
    CONFIG.add_arg('sample_save_path', sample_save_path)

    overall_metrics[model_name] = {}

    print(model_save_path, model_name, sample_save_path)

    tabddpm_noise_ohe = TabDDPM_OHE_Noise(CONFIG, model_save_path=model_save_path, sigmas=None,
                                          dataname=dataname, device=device)
    tabddpm_noise_ohe.sample(sample_save_path=sample_save_path)
    # Decode the OHE categorical columns of the samples into the "initial_" file.
    postsample_OHE(dataname, f'./synthetic/{dataname}/initial_{model_name}.csv')

    # Compute metrics.
    # TODO: calculate_base_metrics(make_binary=True, value=' >50K') currently
    # fails with an error, so base metrics are skipped for now.
    model_metrics = overall_metrics[model_name]
    model_metrics['similarity'] = calculate_similarity()
    model_metrics['mle'] = calculate_mle()
    model_metrics['detection'] = calculate_detection()
    model_metrics['DCR'] = calculate_DCR()
    model_metrics['quality'] = calculate_alpha_beta()

In [22]:
# Build one row per evaluated model. Columns form a two-level MultiIndex:
# ('', Model/Type/Data) for the identifying fields, then
# (metric group, metric name) for every value returned by the metric helpers.
metric_rows = []
for model_name, metric_groups in overall_metrics.items():
    header = pd.DataFrame([{'Model': 'TabDDPM ON', 'Type': model_name, 'Data': dataname}])
    header.columns = pd.MultiIndex.from_tuples([('', col) for col in header.columns])
    row_parts = [header]

    for group_name, group_values in metric_groups.items():
        part = pd.DataFrame([group_values])
        part.columns = pd.MultiIndex.from_tuples([(group_name, col) for col in part.columns])
        row_parts.append(part)

    metric_rows.append(pd.concat(row_parts, axis=1))

# ignore_index gives every model a distinct row label instead of repeating 0.
final_metrics_table = pd.concat(metric_rows, ignore_index=True)

In [24]:
# Display the table ordered by model variant, with values rounded to 3 decimals.
final_metrics_table.sort_values(('', 'Type')).round(3)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,similarity,similarity,similarity,similarity,similarity,similarity,similarity,mle,mle,detection,DCR,quality,quality
Unnamed: 0_level_1,Model,Type,Data,"Column Shapes Score, %","Column Pair Trends Score, %","Overall Score (Average), %","Error rate (%) of column-wise density estimation, %","Error rate (%) of column-wise density estimation std, %","Error rate (%) of pair-wise column correlation score, %","Error rate (%) of pair-wise column correlation score std, %",ROC - AUC обучения на синтетических данных,"ROC - AUC обучения на синтетических данных, std",Score,Score,alpha precision,beta recall
0,TabDDPM ON,tabddpm_ON_SC_const_0,adult_ON_SC,89.149,52.828,70.988,10.851,16.302,47.172,43.698,0.871,0.009,0.911,0.667,0.93,0.45
0,TabDDPM ON,tabddpm_ON_SC_const_0_001,adult_ON_SC,89.096,52.392,70.744,10.904,16.445,47.608,41.796,0.871,0.009,0.928,0.67,0.928,0.447
0,TabDDPM ON,tabddpm_ON_SC_const_0_01,adult_ON_SC,30.047,26.883,28.465,69.953,25.964,73.117,32.148,0.926,0.005,0.0,0.693,0.0,0.0
0,TabDDPM ON,tabddpm_ON_SC_const_0_1,adult_ON_SC,63.131,45.859,54.495,36.869,18.228,54.141,30.985,0.502,0.009,0.124,0.666,0.0,0.0
0,TabDDPM ON,tabddpm_ON_SC_const_0_25,adult_ON_SC,63.05,45.642,54.346,36.95,18.033,54.358,30.925,0.504,0.007,0.121,0.664,0.0,0.0
0,TabDDPM ON,tabddpm_ON_SC_const_0_5,adult_ON_SC,50.832,48.522,49.677,49.168,31.896,51.478,28.919,0.885,0.005,0.0,0.654,0.0,0.0
0,TabDDPM ON,tabddpm_ON_SC_mult_0,adult_ON_SC,67.627,45.534,56.58,32.373,21.135,54.466,34.251,0.777,0.01,0.077,0.656,0.017,0.003
0,TabDDPM ON,tabddpm_ON_SC_mult_1,adult_ON_SC,47.5,44.998,46.249,52.5,36.557,55.002,30.118,0.846,0.042,0.0,0.659,0.0,0.0
0,TabDDPM ON,tabddpm_ON_SC_mult_2,adult_ON_SC,66.35,47.484,56.917,33.65,21.144,52.516,30.016,0.726,0.014,0.11,0.664,0.022,0.003
0,TabDDPM ON,tabddpm_ON_SC_mult_3,adult_ON_SC,57.031,41.458,49.245,42.969,25.194,58.542,32.075,0.676,0.048,0.009,0.664,0.0,0.0


In [27]:
import os

# Output directory for this dataset's aggregated metric tables.
directory = f'./eval/total/{dataname}'
# exist_ok avoids the check-then-create race; it is a no-op if the dir exists.
os.makedirs(directory, exist_ok=True)

In [28]:
# Save the table sorted by model variant; the row index is not written.
# NOTE(review): "metrcs" in the file name looks like a typo. It is kept as-is
# because other code may read this exact path — rename deliberately if at all.
metrics_csv_path = f'./eval/total/{dataname}/{model_short}_final_metrcs_table.csv'
sorted_table = final_metrics_table.sort_values(('', 'Type'))
sorted_table.to_csv(metrics_csv_path, index=False)