In [None]:
import pandas as pd
import json
from sdv.metadata import Metadata
from sdv.evaluation.single_table import evaluate_quality
from sdv.evaluation.single_table import get_column_plot

import os

In [None]:
# Evaluate the quality of each synthetic dataset against its real training set
# and export plots: an overall "Column Shapes" figure per set plus one
# real-vs-synthetic plot per column.
metadata = Metadata.load_from_json(filepath='../2_modeling/gc_metadata.json')

QUALITY_THRESHOLD = 0.85  # minimum acceptable overall score from evaluate_quality
REPORT_DIR = 'quality_reports/gc'

for i in range(10):
    print(f'Running diagnostic for set {i}')
    real_data = pd.read_csv(f'../data/train/set_{i}.csv')
    synthetic_data = pd.read_csv(f'../data/synthetic/gc/set_{i}.csv')

    report = evaluate_quality(
        real_data=real_data,
        synthetic_data=synthetic_data,
        metadata=metadata,
    )
    score = report.get_score()  # compute once; reused in the check and the message
    if score < QUALITY_THRESHOLD:
        print(f'Warning: Quality score for set {i} is below {QUALITY_THRESHOLD}')
        print(f'Score: {score}')
        # NOTE(review): `break` aborts ALL remaining sets, not just this one.
        # Use `continue` instead if only this set should be skipped.
        break
    print(report.get_properties())

    # Create the output directories BEFORE writing any image. plotly's
    # write_image does not create missing parent directories. Creating the
    # per-set directory also creates REPORT_DIR itself.
    set_dir = f'{REPORT_DIR}/set_{i}'
    os.makedirs(set_dir, exist_ok=True)

    fig = report.get_visualization(property_name='Column Shapes')
    fig.update_layout(
        title=dict(
            text=f'Quality Report for Set {i}',
            x=0.5,
            y=0.95,
            font=dict(
                family='Helvetica, Arial, sans-serif',
                size=24,
                color='black'
            )
        ),
        margin=dict(l=40, r=40, t=80, b=40),
        paper_bgcolor='white',
        plot_bgcolor='white'
    )
    fig.write_image(
        f'{REPORT_DIR}/set_{i}.png',
        width=1200,
        height=800,
        scale=2
    )

    # One real-vs-synthetic comparison plot per column declared in the metadata.
    for column_name in metadata.get_column_names():
        fig = get_column_plot(
            real_data=real_data,
            synthetic_data=synthetic_data,
            metadata=metadata,
            column_name=column_name
        )
        fig.update_layout(
            title=f'Quality Report for column {column_name} in set {i}'
        )
        fig.write_image(
            f'{set_dir}/column_{column_name}.png',
            width=1200,
            height=800,
            scale=2
        )