В этом ноутбуке используются обученные тестовые модели: с их помощью сэмплируются синтетические данные, затем закодированные категориальные переменные переводятся обратно в исходный вид, после чего замеряется качество.

In [1]:
import json
import os
import pandas as pd
import torch
import matplotlib.pyplot as plt

from utils import *
from models.tabddpm_ohe_noise.tabddpm_ohe_noise import *

from tqdm.notebook import tqdm

from eval.base_metrics import calculate_base_metrics
from eval.similarity import calculate_similarity
from eval.mle import calculate_mle
from eval.alpha_beta import calculate_alpha_beta
from eval.detection import calculate_detection
from eval.dcr import calculate_DCR

### Подготовка

In [2]:
# Dataset identifier and the short model name used in checkpoint / output paths.
dataname, model_short = 'adult_ON', 'tabddpm_ON'

In [3]:
# Choose the compute device and register run-wide settings in the shared CONFIG.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
CONFIG.add_arg('device', device)
print(f"Device: {device}")

CONFIG.add_arg('dataname', dataname)
# The real/test CSV paths live under ./synthetic, the info JSON under ./data.
for arg_name, arg_path in (
    ('real_path', f'./synthetic/{dataname}/initial_real.csv'),
    ('test_path', f'./synthetic/{dataname}/initial_test.csv'),
    ('info_path', f'./data/{dataname}/initial_info.json'),
):
    CONFIG.add_arg(arg_name, arg_path)

Device: cuda


In [4]:
# Checkpoints to evaluate: one constant-sigma run per noise level, followed by the
# multi-schedule runs indexed 0..N_MULT_RUNS-1.
SIGMA_CONST_VALUES = [0, 0.001, 0.01, 0.1, 0.25, 0.5]
N_MULT_RUNS = 8

ckpt_dir = f"./models/{model_short}/ckpt"
# Decimal points are replaced with '_' so that e.g. 0.001 maps to "..._const_0_001".
model_save_path_hist = [
    f"{ckpt_dir}/{model_short}_const_{str(sigma).replace('.', '_')}"
    for sigma in SIGMA_CONST_VALUES
] + [
    f"{ckpt_dir}/{model_short}_mult_{i}"
    for i in range(N_MULT_RUNS)
]

### Sample + Eval

In [12]:
# Debug aid: show the current contents of CONFIG.
# NOTE(review): the saved output reflects whatever the kernel last set (it shows
# 'method' from an evaluation-loop iteration), not the state at this point of the notebook.
CONFIG.get_all_args()

{'dataname': 'adult_ON',
 'method': 'tabddpm_ON_mult_6',
 'device': 'cuda',
 'save_path': './synthetic/adult_ON/initial_tabddpm_ON_mult_6.csv',
 'mode': 'train',
 'train': 1,
 'sample_save_path': './synthetic/adult_ON/initial_tabddpm_ON_mult_6.csv',
 'sigma_scheduller_name': 'constant',
 'sigma_value': 0.001,
 'num_noise': 103,
 'real_path': './synthetic/adult_ON/initial_real.csv',
 'test_path': './synthetic/adult_ON/initial_test.csv',
 'info_path': './data/adult_ON/initial_info.json'}

In [1]:
# For every checkpoint: sample synthetic data, decode the one-hot encoded
# categoricals back to their original values, then compute evaluation metrics.
# Result: overall_metrics[model_name] -> {metric_group: metrics returned by calculate_*}.
overall_metrics = {}

for model_save_path in tqdm(model_save_path_hist, desc='models'):
    model_name = model_save_path.split('/')[-1]
    sample_save_path = f'./synthetic/{dataname}/{model_name}.csv'

    # The calculate_* helpers take no arguments; presumably they read the run
    # settings from CONFIG, so it is refreshed for each checkpoint.
    CONFIG.add_arg('method', model_name)
    CONFIG.add_arg('dataname', dataname)
    CONFIG.add_arg('save_path',
                       f"synthetic/{CONFIG.get_arg('dataname')}/{CONFIG.get_arg('method')}.csv")
    CONFIG.add_arg('sample_save_path',
                       f"synthetic/{CONFIG.get_arg('dataname')}/{CONFIG.get_arg('method')}.csv")

    print(model_save_path, model_name, sample_save_path)

    tabddpm_noise_ohe = TabDDPM_OHE_Noise(CONFIG, model_save_path=model_save_path, sigmas=None)
    tabddpm_noise_ohe.sample(sample_save_path=sample_save_path)
    # The model is not used after sampling. Drop it before the next checkpoint is
    # loaded so two models are never alive at once (presumably on the GPU when
    # device == 'cuda').
    del tabddpm_noise_ohe
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    postsample_OHE(dataname, f'./synthetic/{dataname}/initial_{model_name}.csv')

    # Compute metrics. calculate_base_metrics(make_binary=True, value=' >50K') is
    # disabled because it raises ValueError (see the error example at the end).
    metrics = {
        'similarity': calculate_similarity(),
        'mle': calculate_mle(),
        'detection': calculate_detection(),
        'DCR': calculate_DCR(),
        'quality': calculate_alpha_beta(),
    }
    # Store only fully computed results, so a failure mid-loop never leaves a
    # half-empty entry that would turn into a broken row in the results table.
    overall_metrics[model_name] = metrics

In [9]:
def _metrics_row(model_type, metric_groups, data, model_label='TabDDPM ON'):
    """Build a one-row DataFrame with two-level columns for a single model run.

    The identifying columns (Model, Type, Data) sit under the top-level label '';
    each metric group's values sit under the group's name (e.g. 'similarity').

    Args:
        model_type: checkpoint name, shown in the 'Type' column.
        metric_groups: mapping of group name -> mapping of metric name -> value.
        data: dataset name, shown in the 'Data' column.
        model_label: value for the 'Model' column.
    """
    id_part = pd.DataFrame([{'Model': model_label, 'Type': model_type, 'Data': data}])
    id_part.columns = pd.MultiIndex.from_product([[''], id_part.columns])
    parts = [id_part]
    for group_name, group_values in metric_groups.items():
        group_part = pd.DataFrame([group_values])
        group_part.columns = pd.MultiIndex.from_product([[group_name], group_part.columns])
        parts.append(group_part)
    return pd.concat(parts, axis=1)


assert overall_metrics, "overall_metrics is empty - run the Sample + Eval cell first"

# One row per evaluated checkpoint. ignore_index gives rows unique labels
# 0..n-1 instead of every row being labelled 0.
final_metrics_table = pd.concat(
    [_metrics_row(model_type, groups, dataname)
     for model_type, groups in overall_metrics.items()],
    ignore_index=True,
)

In [10]:
# Show the table rounded to 3 decimals; round() returns a copy, so the stored frame keeps full precision.
final_metrics_table.round(decimals=3)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,similarity,similarity,similarity,similarity,similarity,similarity,similarity,mle,mle,detection,DCR,quality,quality
Unnamed: 0_level_1,Model,Type,Data,"Column Shapes Score, %","Column Pair Trends Score, %","Overall Score (Average), %","Error rate (%) of column-wise density estimation, %","Error rate (%) of column-wise density estimation std, %","Error rate (%) of pair-wise column correlation score, %","Error rate (%) of pair-wise column correlation score std, %",ROC - AUC обучения на синтетических данных,"ROC - AUC обучения на синтетических данных, std",Score,Score,alpha precision,beta recall
0,TabDDPM ON,tabddpm_ON_const_0,adult_ON,37.138,24.722,30.93,62.862,24.124,75.278,32.75,0.503,0.005,0.0,0.686,0.01,0.0
0,TabDDPM ON,tabddpm_ON_const_0_001,adult_ON,39.3,25.886,32.593,60.7,22.229,74.114,32.418,0.54,0.007,0.0,0.677,0.013,0.0
0,TabDDPM ON,tabddpm_ON_const_0_01,adult_ON,41.973,26.27,34.121,58.027,21.869,73.73,32.071,0.504,0.006,0.0,0.701,0.013,0.0
0,TabDDPM ON,tabddpm_ON_const_0_1,adult_ON,26.846,21.053,23.95,73.154,21.806,78.947,32.09,0.839,0.011,0.0,0.715,0.001,0.0
0,TabDDPM ON,tabddpm_ON_const_0_25,adult_ON,41.642,26.223,33.933,58.358,21.725,73.777,32.071,0.504,0.005,0.0,0.703,0.013,0.0
0,TabDDPM ON,tabddpm_ON_const_0_5,adult_ON,30.504,23.332,26.918,69.496,25.764,76.668,32.686,0.851,0.015,0.0,0.713,0.006,0.0
0,TabDDPM ON,tabddpm_ON_mult_0,adult_ON,31.363,22.219,26.791,68.637,26.868,77.781,29.172,0.86,0.002,0.0,0.869,0.007,0.0
0,TabDDPM ON,tabddpm_ON_mult_1,adult_ON,42.049,26.293,34.171,57.951,21.911,73.707,32.113,0.496,0.007,0.0,0.702,0.013,0.0
0,TabDDPM ON,tabddpm_ON_mult_2,adult_ON,34.944,23.343,29.143,65.056,23.308,76.657,32.008,0.729,0.019,0.0,0.722,0.003,0.0
0,TabDDPM ON,tabddpm_ON_mult_3,adult_ON,60.711,39.068,49.89,39.289,34.564,60.932,33.733,0.706,0.017,0.0,0.69,0.065,0.002


In [17]:
# Persist the metrics table. The target directory is created first because
# to_csv raises OSError when the parent directory does not exist.
# NOTE(review): "metrcs" in the file name looks like a typo; it is kept because other
# code may read this exact path - rename it everywhere at once if you fix it.
metrics_table_path = f'./eval/total/{dataname}/{model_short}_final_metrcs_table.csv'
os.makedirs(os.path.dirname(metrics_table_path), exist_ok=True)
final_metrics_table.to_csv(metrics_table_path, index=False)

---
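#### Пример ошибки

Вызов `calculate_base_metrics` падает с `ValueError` (неизвестные категории при `transform`), поэтому в цикле оценки выше этот расчёт отключён.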

In [12]:
# Reproduces the failure that keeps calculate_base_metrics out of the evaluation
# loop: a ValueError about unknown categories during transform (a logistic-regression
# convergence warning is also emitted). The error is caught and printed so that
# Restart & Run All does not stop here; if it is ever fixed, the result is displayed.
try:
    base_metrics_demo = calculate_base_metrics(make_binary=True, value=' >50K')
except ValueError as err:
    print(f'{type(err).__name__}: {err}')
else:
    display(base_metrics_demo)

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


ValueError: Found unknown categories [' Doctorate', ' Prof-school', ' Assoc-acdm', ' Some-college', ' Preschool', ' Assoc-voc', ' Bachelors', ' HS-grad', ' Masters'] in column 0 during transform