# Create a PDF of the Data Quality Report for Data Designer

# 🎛️ Import things

In [None]:
import os
from pprint import pprint

from datasets import load_dataset

from navigator_helpers.llms.llm_suite import GretelLLMSuite
from navigator_helpers.tasks.evaluation.evaluation import (
    BaseEvaluationTaskSuite,
    VisualizationTaskSuite
)

# Read the Gretel API key (from https://console.gretel.ai/users/me/key) and export it
# as the 'GRETEL_PROD_API_KEY' environment variable.
# getpass() is used instead of input() so the secret is not echoed back into the
# notebook output; surrounding whitespace from copy/paste is stripped.
from getpass import getpass

gretel_prod_api_key = getpass("Enter your Gretel API key from https://console.gretel.ai/users/me/key: ").strip()
os.environ['GRETEL_PROD_API_KEY'] = gretel_prod_api_key

# 🔢 Choose Dataset for Evaluation

In [None]:
# Set the number of samples to load from the dataset for testing
# Set to None to use the full dataset
NUM_SAMPLES = 1000

# Registry of the datasets that can be evaluated, keyed by the name the user
# types at the selection prompt. Each entry contains:
#   - "dataset_kwargs": keyword arguments passed to `load_dataset(**...)`
#   - "code_lang" (optional): language tag; the loading code falls back to
#     None when the key is absent
#   - "eval_kwargs": names of the instruction / code / (optional) context
#     columns in that dataset
datasets_dict = {
    "synthetic_text_to_sql": {
        "dataset_kwargs": {
            "path": "gretelai/synthetic_text_to_sql",
            "split": "train"
        },
        "code_lang": "sql",
        "eval_kwargs":{
            "instruction_col_name": "sql_prompt",
            "code_col_name": "sql",
            "context_col_name": "sql_context"
        }
    },
    "gsm8k": {
        "dataset_kwargs": {
            "path": "openai/gsm8k",
            "name": "main",
            "split": "train"
        },
        "eval_kwargs": {
            "instruction_col_name": "question",
            "code_col_name": "answer",
        }
    },
    "synthetic_gsm8k": {
        "dataset_kwargs": {
            "path": "gretelai/synthetic-gsm8k-reflection-405b",
            "split": "train"
        },
        "eval_kwargs": {
            "instruction_col_name": "question",
            "code_col_name": "answer",
        }
    },
    "xlcost_text_to_code": {
        "dataset_kwargs": {
            "path": "codeparrot/xlcost-text-to-code",
            "split": "train"
        },
        "code_lang": "python",
        "eval_kwargs": {
            "instruction_col_name": "text",
            "code_col_name": "code",
        }
    },
}

# List the registered dataset names, then ask the user which one to load
print("Available datasets:")
for key in datasets_dict:
    print(f" - {key}")

selected_dataset = input("\nEnter the name of the dataset to load: ").strip()

if selected_dataset not in datasets_dict:
    # Unknown name: report it and leave the dataset variables untouched
    print("Error: Dataset not found. Please enter a valid dataset name.")
else:
    dataset_dict = datasets_dict[selected_dataset]
    eval_kwargs = dataset_dict["eval_kwargs"]
    # "code_lang" is optional in the registry; None when it is not given
    code_lang = dataset_dict.get("code_lang")
    dataset = load_dataset(**dataset_dict["dataset_kwargs"])

    # Keep only the first NUM_SAMPLES rows when a limit smaller than the dataset is set
    if NUM_SAMPLES is not None and NUM_SAMPLES < len(dataset):
        dataset = dataset.select(range(NUM_SAMPLES))

    dataset_df = dataset.to_pandas()

    print(f"Loaded dataset '{selected_dataset}' successfully!")

In [None]:
# tmp
# Temporary override: replaces whatever `dataset_df` the dataset-selection cell
# produced with the contents of a JSON file at a hard-coded, machine-specific
# path, then shows the first row.
# NOTE(review): this path only exists on the original author's environment —
# remove or parameterize before sharing the notebook.
import pandas as pd
dataset_df = pd.read_json('/mnt/foundation-shared/nina_xu_gretel_ai/datasets/text_to_python_v1.json')
dataset_df.head(1)

In [4]:
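# tmp: hard-coded evaluation output for a 100-row dataset, standing in for the
# commented-out evaluation cell so the report code has a `results` dict to work with
# (presumably captured from an earlier evaluation run — confirm before relying on it).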
results = {'results': {'row_uniqueness': {'percent_unique': 100.0,
   'percent_semantically_unique': 100.0,
   'non_unique_ids': [],
   'non_semantically_unique_ids': []},
  'feature_cardinality': {'id': 100,
   'domain': 10,
   'topic': 69,
   'complexity': 4,
   'prompt': 100,
   'dependency_list': 12,
   'code': 100},
  'feature_distribution': {'distribution': {'id': None,
    'domain': {'Educational Technology': 13,
     'Healthcare Technology': 12,
     'E-commerce': 12,
     'Financial Services': 11,
     'Cybersecurity': 10,
     'Aerospace Software': 10,
     'Telecommunications': 10,
     'Video Game Development': 8,
     'Artificial Intelligence': 8,
     'Automotive Software': 6},
    'topic': {'avg_length': 19.69,
     'std_length': 4.500718236733404,
     'avg_word_count': 2.26,
     'word_count_histogram': ([0, 0, 1, 0, 0, 74, 0, 23, 0, 2],
      [0.0,
       0.4,
       0.8,
       1.2000000000000002,
       1.6,
       2.0,
       2.4000000000000004,
       2.8000000000000003,
       3.2,
       3.6,
       4.0])},
    'complexity': {'Expert: Concurrency, parallel processing, and metaprogramming': 29,
     'Beginner: Basic syntax, data types, and control structures': 27,
     'Advanced: Object-oriented programming and exception handling': 26,
     'Intermediate: Functions, modules, and file handling': 18},
    'prompt': {'avg_length': 1103.08,
     'std_length': 259.7072042111884,
     'avg_word_count': 159.96,
     'word_count_histogram': ([0, 0, 0, 35, 49, 7, 5, 3, 0, 1],
      [0.0,
       35.9,
       71.8,
       107.69999999999999,
       143.6,
       179.5,
       215.39999999999998,
       251.29999999999998,
       287.2,
       323.09999999999997,
       359.0])},
    'dependency_list': {'avg_length': 60.38,
     'std_length': 1.9683351935946871,
     'avg_word_count': 5.0,
     'word_count_histogram': ([0, 0, 0, 0, 0, 0, 0, 0, 0, 100],
      [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0])},
    'code': {'avg_length': 1059.34,
     'std_length': 595.0563375416523,
     'avg_word_count': 102.05,
     'word_count_histogram': ([7, 16, 16, 18, 16, 15, 4, 3, 4, 1],
      [0.0,
       27.2,
       54.4,
       81.6,
       108.8,
       136.0,
       163.2,
       190.4,
       217.6,
       244.79999999999998,
       272.0])}},
   'score': {'id': None,
    'domain': {'gini-simpson_index': 0.8958},
    'topic': {'text_diversity_index': 0.8279853407018426},
    'complexity': {'gini-simpson_index': 0.743},
    'prompt': {'text_diversity_index': 0.5630924313331656},
    'dependency_list': {'text_diversity_index': 0.18927071370650894},
    'code': {'text_diversity_index': 0.7658162507189068}}},
  'num_words_per_record': {'average_words_per_record': 46.26166666666666,
   'word_counts_per_column': {'domain': 1.76,
    'topic': 2.26,
    'complexity': 6.54,
    'prompt': 159.96,
    'dependency_list': 5.0,
    'code': 102.05},
   'average_tokens_per_record': 496.66,
   'tokens_per_column': {'id': 1.0,
    'domain': 2.77,
    'topic': 2.78,
    'complexity': 10.71,
    'prompt': 222.83,
    'dependency_list': 20.3,
    'code': 236.27},
   'total_tokens': 49666},
  'column_notes': {'id': 'Unique ID',
   'domain': 'Seed Column',
   'topic': 'Seed Column',
   'complexity': 'Seed Column',
   'prompt': '',
   'dependency_list': 'Seed Column',
   'code': ''},
  'column_data_types': {'id': 'Other',
   'domain': 'Categorical',
   'topic': 'Text',
   'complexity': 'Categorical',
   'prompt': 'Text',
   'dependency_list': 'Text',
   'code': 'Text'}},
 'dataset_overview_statistics': {'number_of_rows': 100,
  'number_of_columns': 7,
  'number_of_categorical_columns': 2,
  'number_of_text_columns': 4,
  'number_of_numerical_columns': 0,
  'number_of_other_columns': 1,
  'number_of_seed_columns': 4,
  'data_completeness': 100.0,
  'single_row': {'id': 0,
   'domain': 'Cybersecurity',
   'topic': 'Application Security',
   'complexity': 'Beginner: Basic syntax, data types, and control structures',
   'prompt': 'Write a Python function that checks if a given password meets the following security criteria:\n\n1. It should be at least 8 characters long.\n2. It must contain at least one uppercase letter.\n3. It must contain at least one lowercase letter.\n4. It must contain at least one digit.\n5. It must contain at least one special character from the following set: !@#$%^&*().\n\nThe function should return True if the password meets all the criteria, and False otherwise. Name the function "is_secure_password".\n\n### Instructions\n    * The code should have a complexity of "Beginner: Basic syntax, data types, and control structures".\n    * Write code that might be used in the "Cybersecurity" industry within a "Application Security" context.\n    * Try to include at least 1 of the following Python packages:  `scikit-learn`, `requests`, `numpy`, `matplotlib`, `pandas`.\n    * Include only the code, without any comments or additional text.\n',
   'dependency_list': ['scikit-learn',
    'requests',
    'numpy',
    'matplotlib',
    'pandas'],
   'code': "import re\n\ndef is_secure_password(password):\n    if len(password) < 8:\n        return False\n    if not re.search(r'[A-Z]', password):\n        return False\n    if not re.search(r'[a-z]', password):\n        return False\n    if not re.search(r'\\d', password):\n        return False\n    if not re.search(r'[!@#$%^&*()]', password):\n        return False\n    return True"}}}


## Run Evaluation on dataset

In [None]:
# LLM suite instance; it is the first argument given to BaseEvaluationTaskSuite
# in the (currently commented-out) evaluation cell.
llm_suite = GretelLLMSuite()

In [None]:
# # Define a dictionary to store evaluation results
# results = {}

# # Uncomment the following lines to run individual evaluation tasks
# results.update({"row_uniqueness": BaseEvaluationTaskSuite(llm_suite, dataset_df).row_uniqueness()})
# results.update({"feature_cardinality": BaseEvaluationTaskSuite(llm_suite, dataset_df).feature_cardinality()})
# results.update({"feature_distribution": BaseEvaluationTaskSuite(llm_suite, dataset_df).feature_distribution()})
# results.update({"num_words_per_record": BaseEvaluationTaskSuite(llm_suite, dataset_df).num_words_per_record()})

# # Uncomment this line to run everything, including LLM-as-a-judge
# # results = BaseEvaluationTaskSuite(llm_suite, dataset_df, code_lang, eval_kwargs).evaluate_all()

# pprint(results)


In [43]:
# Standard library imports
from typing import List, Optional, Tuple, Union, Dict, Any
import math
import io

# Third-party imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from plotly.io import to_image
from PIL import Image as PILImage
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, PageBreak, Paragraph, Spacer, Image, Table, TableStyle
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib import colors
from reportlab.lib.units import inch
from reportlab.lib.enums import TA_CENTER

# Constants
# Size of every matplotlib figure created by the chart helpers
FIG_WIDTH = 7  # inches
FIG_HEIGHT = 3.6  # inches
# Score bands, ordered from worst to best; their colors become the gauge
# segments when gauge_and_needle_chart is called without explicit marker colors
SCORE_VALUES = [
    {"label": "Very poor", "color": "rgb(229, 60, 26)"},
    {"label": "Poor", "color": "rgb(229, 128, 26)"},
    {"label": "Average", "color": "rgb(229, 161, 26)"},
    {"label": "Good", "color": "rgb(183, 210, 45)"},
    {"label": "Excellent", "color": "rgb(72, 210, 45)"},
]
# Hex color palettes. Several PRIMARY_PALETTE entries (e.g. '#4F00A9', '#1D0B32',
# '#110420', '#EFE5FF') are also written out literally in the chart/table helpers.
PRIMARY_PALETTE = ['#2E1065', '#D3A66E', '#110420', '#4F00A9', '#F9EFDE', '#1D0B32', '#8D32FA', '#C399FF', '#EFE5FF', '#EFD7AD', '#F4E3C6', '#FBF7ED', '#A59DAD', '#D2CED6', '#E8E7EB']
SECONDARY_PALETTE = ['#052095', '#FF6BA9', '#3056F2', '#FFA8CC', '#8BB9FF', '#FFEDF5', '#E5F0FF', '#1E9C98', '#92F6F4', '#C5FEFF', '#E8FEFF', '#FF9248', '#FFB38A', '#FFD7B5', '#FFECDC', '#FF6700', '#FFCA1A', '#FFE16D', '#FFF099', '#FFFDE3', '#ECA10A']

# Set up custom color palette for seaborn
# (module-level side effect: changes the global seaborn/matplotlib theme and default palette)
sns.set_theme(style="white")
sns.set_palette(sns.color_palette(SECONDARY_PALETTE))

def create_chart(data: pd.Series, title: str, xlabel: str, ylabel: str) -> Image:
    """Render a bar chart of ``data`` and return it as a ReportLab ``Image``.

    Args:
        data: Series whose values become bar heights and whose index supplies
            the x tick labels (labels longer than 20 characters are cut to
            17 characters plus '...').
        title: Chart title.
        xlabel: X-axis label.
        ylabel: Y-axis label.

    Returns:
        A ReportLab ``Image`` backed by an in-memory PNG (300 dpi).
    """
    text_color = '#1D0B32'
    fig, ax = plt.subplots(figsize=(FIG_WIDTH, FIG_HEIGHT))
    try:
        positions = range(len(data))
        bars = ax.bar(positions, data.values, color='#4F00A9')
        ax.set_facecolor('white')
        fig.patch.set_facecolor('white')

        ax.set_title(title, fontsize=10, color=text_color)
        ax.set_xlabel(xlabel, fontsize=10, color=text_color)
        ax.set_ylabel(ylabel, fontsize=10, color=text_color)
        ax.set_xticks(positions)

        truncated_labels = [
            text[:17] + '...' if len(text) > 20 else text
            for text in map(str, data.index)
        ]
        ax.set_xticklabels(truncated_labels, rotation=45, ha='right', fontsize=8, color=text_color)

        ax.tick_params(axis='both', colors=text_color)
        ax.tick_params(axis='y', labelsize=6)

        # Print each bar's value just above it
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.2f}' if isinstance(height, float) else f'{height}',
                    ha='center', va='bottom', fontsize=8, color=text_color)

        fig.tight_layout()
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=300, bbox_inches='tight')
    finally:
        # Figures created through pyplot stay registered until closed; without
        # this every chart in the report would remain in memory.
        plt.close(fig)
    img_buffer.seek(0)
    # NOTE(review): the other chart helpers size the image with
    # FIG_WIDTH/FIG_HEIGHT (7 x 3.6 in); this one uses 7 x 4 in — confirm intended.
    return Image(img_buffer, width=7*inch, height=4*inch)

def create_pareto_chart(data: pd.DataFrame, title: str) -> Image:
    """Render a Pareto chart (bars plus cumulative-percentage line) as a ReportLab ``Image``.

    Args:
        data: DataFrame with a ``'count'`` column; the index supplies the
            category labels. Rows are plotted in the order given — this
            function does not sort them.
        title: Chart title.

    Returns:
        A ReportLab ``Image`` backed by an in-memory PNG (300 dpi), sized
        ``FIG_WIDTH`` x ``FIG_HEIGHT`` inches.
    """
    text_color = '#1D0B32'
    fig, ax1 = plt.subplots(figsize=(FIG_WIDTH, FIG_HEIGHT))
    try:
        # Second y-axis sharing the same x-axis, used for the cumulative line
        ax2 = ax1.twinx()
        positions = range(len(data))
        counts = data['count']

        ax1.bar(positions, counts, color='#4F00A9')
        ax1.set_facecolor('white')
        fig.patch.set_facecolor('white')

        ax1.set_xlabel('Categories', fontsize=10, color=text_color)
        ax1.set_ylabel('Count', fontsize=10, color=text_color)
        ax1.set_title(title, fontsize=10, color=text_color)

        cumulative_percentage = 100 * counts.cumsum() / counts.sum()
        ax2.plot(positions, cumulative_percentage, color='#FF6700', marker='D', ms=4)
        ax2.set_ylabel('Cumulative Percentage', fontsize=10, color=text_color)
        # Headroom above 100% so the last marker is not clipped
        ax2.set_ylim([0, 110])

        for axis in (ax1, ax2):
            axis.tick_params(axis='both', colors=text_color)
            axis.tick_params(axis='y', labelsize=6)

        ax1.set_xticks(positions)
        truncated_labels = [
            text[:17] + '...' if len(text) > 20 else text
            for text in map(str, data.index)
        ]
        ax1.set_xticklabels(truncated_labels, rotation=45, ha='right', fontsize=6, color=text_color)

        # Print each bar's count just above it
        for i, v in enumerate(counts):
            ax1.text(i, v, f'{v:.2f}' if isinstance(v, float) else f'{v}',
                     ha='center', va='bottom', fontsize=8, color=text_color)

        ax2.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0f}%'))

        fig.tight_layout()
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=300, bbox_inches='tight')
    finally:
        # Release the figure; pyplot keeps it alive until explicitly closed
        plt.close(fig)
    img_buffer.seek(0)
    return Image(img_buffer, width=FIG_WIDTH*inch, height=FIG_HEIGHT*inch)

def create_text_diversity_chart(text_diversity_df: pd.DataFrame) -> Image:
    """Render a bar chart of per-column text diversity indices as a ReportLab ``Image``.

    Args:
        text_diversity_df: DataFrame with a ``'diversity_index'`` column; the
            index supplies the bar labels. The y-axis is fixed to [0, 1].

    Returns:
        A ReportLab ``Image`` backed by an in-memory PNG (300 dpi), sized
        ``FIG_WIDTH`` x ``FIG_HEIGHT`` inches.
    """
    text_color = '#1D0B32'
    fig = plt.figure(figsize=(FIG_WIDTH, FIG_HEIGHT))
    try:
        # seaborn draws on the current axes of the figure created above
        ax = sns.barplot(x=text_diversity_df.index, y='diversity_index', data=text_diversity_df, color='#4F00A9')
        ax.set_facecolor('white')
        fig.patch.set_facecolor('white')

        plt.title("Text Diversity Indices", fontsize=10, color=text_color)
        plt.ylabel("Diversity Index", fontsize=10, color=text_color)
        plt.xlabel("", fontsize=10, color=text_color)
        plt.xticks(rotation=45, ha='right', fontsize=8, color=text_color)

        plt.ylim(0, 1)

        ax.tick_params(axis='both', colors=text_color)
        ax.tick_params(axis='y', labelsize=6)

        # Print each bar's value just above it
        for i, v in enumerate(text_diversity_df['diversity_index']):
            ax.text(i, v, f'{v:.2f}', ha='center', va='bottom', fontsize=8, color=text_color)
        fig.tight_layout()

        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=300, bbox_inches='tight')
    finally:
        # Release the figure; pyplot keeps it alive until explicitly closed
        plt.close(fig)
    img_buffer.seek(0)
    return Image(img_buffer, width=FIG_WIDTH*inch, height=FIG_HEIGHT*inch)

def create_histogram(counts: List[int], bins: List[float], col_name: str, data_type: str = "Text") -> Image:
    """Render a pre-binned histogram as a ReportLab ``Image``.

    Args:
        counts: Number of observations in each bin.
        bins: Bin edges; one more element than ``counts``.
        col_name: Column name, used (underscores replaced, title-cased) in the title.
        data_type: ``"Text"`` (x-axis labelled "Word Count") or ``"Numeric"``
            (x-axis labelled "Value").

    Returns:
        A ReportLab ``Image`` backed by an in-memory PNG (300 dpi), sized
        ``FIG_WIDTH`` x ``FIG_HEIGHT`` inches.

    Raises:
        AssertionError: If ``data_type`` is neither ``"Text"`` nor ``"Numeric"``.
    """
    assert data_type in ["Text", "Numeric"], f"Invalid data type: {data_type}"
    x_label = "Word Count" if data_type == "Text" else "Value"
    text_color = '#1D0B32'

    fig = plt.figure(figsize=(FIG_WIDTH, FIG_HEIGHT))
    try:
        # The data is already binned: feed one sample per bin (its left edge)
        # and let `weights` restore the real bar heights.
        plt.hist(bins[:-1], bins, weights=counts, color='#4F00A9')

        # Add labels and title
        plt.xlabel(x_label, fontsize=10, color=text_color)
        plt.ylabel("Count", fontsize=10, color=text_color)
        plt.title(f"{col_name.replace('_', ' ').title()}: {x_label} Distribution (Histogram)", fontsize=10, color=text_color)
        plt.xticks(fontsize=6, color=text_color)
        plt.yticks(fontsize=6, color=text_color)

        # Add counts above bars, centred between each pair of adjacent edges
        for left_edge, right_edge, count in zip(bins[:-1], bins[1:], counts):
            plt.text((left_edge + right_edge)/2, count, str(count),
                     ha='center', va='bottom', fontsize=8, color=text_color)

        fig.tight_layout()

        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=300, bbox_inches='tight')
    finally:
        # Release the figure; pyplot keeps it alive until explicitly closed
        plt.close(fig)
    img_buffer.seek(0)
    return Image(img_buffer, width=FIG_WIDTH*inch, height=FIG_HEIGHT*inch)

def create_schema_table(dataset_df: pd.DataFrame, data_results: Dict[str, Any]) -> Table:
    """Build the styled schema table: one row per DataFrame column.

    Args:
        dataset_df: Dataset whose columns are listed.
        data_results: Evaluation results. When it contains
            ``'num_words_per_record'``, that entry's
            ``'word_counts_per_column'`` and ``'tokens_per_column'`` mappings
            fill the 'Average Length' and 'Avg Tokens' cells; columns without
            an entry show 'N/A'.

    Returns:
        A ReportLab ``Table`` with a header row followed by one row per column
        (name, dtype, row count, % null, average length, average tokens).
    """
    schema_data = [['Column Name', 'Type', 'Total Count', '% Null', 'Average Length', 'Avg Tokens']]

    # Row count and the optional per-column statistics do not change per column
    total_count = len(dataset_df)
    num_words = data_results.get('num_words_per_record') or {}
    word_counts = num_words.get('word_counts_per_column', {})
    token_counts = num_words.get('tokens_per_column', {})

    for col in dataset_df.columns:
        dtype = str(dataset_df[col].dtype)
        null_count = dataset_df[col].isnull().sum()
        # An empty DataFrame has no rows to be null; avoid dividing by zero
        pcnt_null = (null_count / total_count) * 100 if total_count else 0.0
        # 'Average Length' is populated from the per-column average word count
        avg_length = word_counts.get(col, 'N/A')
        avg_tokens = token_counts.get(col, 'N/A')
        schema_data.append([col, dtype, total_count, f"{pcnt_null:.2f}%", avg_length, avg_tokens])

    table = Table(schema_data)
    style = TableStyle([
        # Header row
        ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#4F00A9')),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
        ('ALIGN', (0, 0), (-1, -1), 'LEFT'),
        ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
        ('FONTSIZE', (0, 0), (-1, 0), 8),
        ('BOTTOMPADDING', (0, 0), (-1, 0), 6),
        # Body rows
        ('BACKGROUND', (0, 1), (-1, -1), colors.HexColor('#EFE5FF')),
        ('TEXTCOLOR', (0, 1), (-1, -1), colors.HexColor('#110420')),
        ('ALIGN', (0, 1), (-1, -1), 'LEFT'),
        ('FONTNAME', (0, 1), (-1, -1), 'Helvetica'),
        ('FONTSIZE', (0, 1), (-1, -1), 7),
        ('TOPPADDING', (0, 1), (-1, -1), 3),
        ('BOTTOMPADDING', (0, 1), (-1, -1), 3),
        ('GRID', (0, 0), (-1, -1), 1, colors.HexColor('#4F00A9'))
    ])
    table.setStyle(style)
    return table

def create_overview_table(overview_data: List[List[str]]) -> Table:
    """Wrap ``overview_data`` in a styled two-column ReportLab ``Table``.

    Both columns are 1.5 inches wide. The first row is rendered as a header
    (white bold text on purple); every following row uses the light body style.
    """
    accent = colors.HexColor('#4F00A9')
    header_row = ((0, 0), (-1, 0))
    body_rows = ((0, 1), (-1, -1))
    whole_table = ((0, 0), (-1, -1))

    commands = [
        ('BACKGROUND', *header_row, accent),
        ('TEXTCOLOR', *header_row, colors.white),
        ('ALIGN', *whole_table, 'LEFT'),
        ('FONTNAME', *header_row, 'Helvetica-Bold'),
        ('FONTSIZE', *header_row, 8),
        ('BOTTOMPADDING', *header_row, 3),  # compact header
        ('BACKGROUND', *body_rows, colors.HexColor('#EFE5FF')),
        ('TEXTCOLOR', *body_rows, colors.HexColor('#110420')),
        ('ALIGN', *body_rows, 'LEFT'),
        ('FONTNAME', *body_rows, 'Helvetica'),
        ('FONTSIZE', *body_rows, 7),
        ('TOPPADDING', *body_rows, 3),  # compact body rows
        ('BOTTOMPADDING', *body_rows, 3),
        ('GRID', *whole_table, 0.5, accent),  # thin grid lines
    ]

    overview_table = Table(overview_data, colWidths=[1.5*inch, 1.5*inch])
    overview_table.setStyle(TableStyle(commands))
    return overview_table

def create_single_record_preview(row: Dict[str, Any]) -> str:
    """Format one record as markup text: ``<b>column:</b>\\tvalue<br/>`` per field.

    Each value is converted with ``str()`` and cut to its first 100 characters
    (followed by '...') when longer. Column names and values are XML-escaped
    so that characters such as ``<`` or ``&`` inside the data (common in code
    columns) are shown literally instead of being parsed as markup.

    Args:
        row: Mapping of column name to value for a single record.

    Returns:
        The concatenated markup for all fields; an empty string for an empty row.
    """
    # Local import keeps this helper self-contained
    from html import escape

    max_chars = 100
    parts = []
    for column, value in row.items():
        text = str(value)
        # Truncate the raw text first so the limit counts data characters,
        # not the characters added by escaping
        if len(text) > max_chars:
            text = text[:max_chars] + '...'
        parts.append(
            f"<b>{escape(str(column), quote=False)}:</b>\t{escape(text, quote=False)}<br/>"
        )
    return "".join(parts)

def _generate_pointer_path(score: int) -> str:
    theta = score * (282 - 34) / 100 - 34
    rads = math.radians(theta)
    radius = 0.45
    size = 0.025
    x1 = -1 * radius * math.cos(rads) + 0.5
    y1 = radius * math.sin(rads) + 0.5
    return f"""
    M {x1} {y1}
    L {-1 * size * math.cos(math.radians(theta - 90)) + 0.5}
        {size * math.sin(math.radians(theta - 90)) + 0.5}
    L {-1 * size * math.cos(math.radians(theta + 90)) + 0.5}
        {size * math.sin(math.radians(theta + 90)) + 0.5}
    Z"""

def gauge_and_needle_chart(score: Optional[int], display_score: bool = True, marker_colors: Optional[List[str]] = None) -> go.Figure:
    """Build a 180x180 Plotly gauge: a partial donut plus a needle pointing at ``score``.

    The gauge is a donut-style pie whose visible segments take 70 parts and a
    final transparent segment takes 30 parts, leaving a gap in the ring.

    Args:
        score: Value the needle points at; ``None`` draws a single grey arc
            with an "N/A" annotation and no needle.
        display_score: When true (and ``score`` is not ``None``), show the
            score as a number inside the gauge.
        marker_colors: Colors of the visible segments. Defaults to the colors
            in ``SCORE_VALUES``. The caller's list is not modified.

    Returns:
        The assembled ``go.Figure``.
    """
    transparent = "rgba(255, 255, 255, 0)"

    if score is None:
        fig = go.Figure(
            layout=go.Layout(
                annotations=[
                    go.layout.Annotation(
                        text="N/A",
                        font=dict(color="rgba(174, 95, 5, 1)", size=18),
                        showarrow=False,
                        xref="paper",
                        yref="paper",
                        x=0.5,
                        y=0.5,
                    )
                ]
            )
        )
        gauge_colors = ["rgb(220, 220, 220)", transparent]
        pie_values = [70, 30]
    else:
        # Work on a copy: appending the transparent filler to the argument
        # itself would leak the change back into the caller's list.
        gauge_colors = list(marker_colors) if marker_colors else [s["color"] for s in SCORE_VALUES]
        if gauge_colors[-1] != transparent:
            gauge_colors.append(transparent)
        visible_segments = len(gauge_colors) - 1
        # True division keeps the visible segments at exactly 70 parts in
        # total even when 70 is not a multiple of the segment count.
        pie_values = [70 / visible_segments for _ in range(visible_segments)]
        pie_values.append(30)
        fig = go.Figure()

    fig.update_layout(
        autosize=False,
        showlegend=False,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        height=180,
        width=180,
        margin=dict(l=0, r=0, t=0, b=0),
        paper_bgcolor="rgba(0,0,0,0)",
        hovermode=False,
        modebar=None,
    )
    fig.add_trace(
        go.Pie(
            name="gauge",
            values=pie_values,
            marker=dict(
                colors=gauge_colors,
                line=dict(width=4, color="#fafafa"),
            ),
            hole=0.75,
            direction="clockwise",
            # Keep segments in the order given rather than sorted by size
            sort=False,
            rotation=234,
            showlegend=False,
            hoverinfo="none",
            textinfo="none",
            textposition="outside",
        )
    )

    if score is not None:
        if display_score:
            fig.add_trace(
                go.Indicator(
                    mode="number", value=score, domain=dict(x=[0, 1], y=[0.28, 0.45])
                )
            )
        # Needle hub: small black circle at the centre
        fig.add_shape(
            type="circle", fillcolor="black", x0=0.475, x1=0.525, y0=0.475, y1=0.525
        )
        # Needle: filled triangle whose path depends on the score
        fig.add_shape(
            type="path",
            fillcolor="black",
            line=dict(width=0),
            path=_generate_pointer_path(score),
        )

    return fig

def create_gauge_chart(score: int) -> Image:
    """Render the gauge for ``score`` to PNG and wrap it in a 1.75 x 1.75 inch ReportLab ``Image``."""
    gauge_figure = gauge_and_needle_chart(score)
    gauge_figure.update_layout(
        margin=dict(t=20, b=20, l=20, r=20),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
    )

    # Export at 2x scale, then re-encode the PNG through PIL into a fresh buffer
    exported_png = to_image(gauge_figure, format="png", scale=2)
    png_buffer = io.BytesIO()
    PILImage.open(io.BytesIO(exported_png)).save(png_buffer, format="PNG")
    png_buffer.seek(0)

    side = 1.75*inch
    return Image(png_buffer, width=side, height=side)

def calculate_average_diversity_indexes(data: Dict[str, Any]) -> Tuple[Optional[float], Optional[float], List[str]]:
    """Average the diversity scores of the generated columns.

    A column counts as generated when its entry in
    ``data['results']['column_notes']`` is the empty string and its score
    entry is a dict.

    Args:
        data: Full evaluation output; reads ``data['results']['column_notes']``
            and ``data['results']['feature_distribution']['score']``.

    Returns:
        ``(avg_text_diversity, avg_gini_index, included_columns)``. Each
        average is ``None`` when no included column carries that score.

    Raises:
        KeyError: If ``data['results']['column_notes']`` is missing. Errors
            while walking the scores are printed and the values collected so
            far are still returned.
    """
    # The score dicts spell the key with a hyphen ('gini-simpson_index'); the
    # all-underscore spelling is accepted too for backward compatibility.
    gini_keys = ('gini-simpson_index', 'gini_simpson_index')

    text_diversity_scores: List[float] = []
    gini_scores: List[float] = []
    included_columns: List[str] = []
    column_notes = data['results']['column_notes']
    try:
        for key, value in data['results']['feature_distribution']['score'].items():
            if not isinstance(value, dict) or column_notes[key] != '':
                continue
            included_columns.append(key)
            if 'text_diversity_index' in value:
                text_diversity_scores.append(value['text_diversity_index'])
            for gini_key in gini_keys:
                if gini_key in value:
                    gini_scores.append(value[gini_key])
                    break
    except Exception as e:
        # Best effort: report the problem and fall through with partial results
        print('Error calculating average diversity indexes:', e)

    avg_text_diversity = sum(text_diversity_scores) / len(text_diversity_scores) if text_diversity_scores else None
    avg_gini_index = sum(gini_scores) / len(gini_scores) if gini_scores else None

    return avg_text_diversity, avg_gini_index, included_columns

def _distribution_headings(key: str, feature_distribution: Dict[str, Any]) -> List[str]:
    """Return one heading per score of column ``key``, e.g. 'Topic Distribution (Text Diversity Index: 0.83)'."""
    column_scores = feature_distribution.get('score', {}).get(key)
    if not isinstance(column_scores, dict):
        # No score recorded for this column: the plot is still shown, without a heading
        return []
    return [
        f"{key} Distribution ({score_key}: {round(score_value, 2)})".replace('_', ' ').title()
        for score_key, score_value in column_scores.items()
    ]


def _build_distribution_image(key: str, distribution: Dict[str, Any], data_type: Optional[str]) -> Optional[Image]:
    """Return the chart for one column's distribution, or None when there is nothing to plot."""
    if data_type == 'Categorical':
        dist_df = pd.DataFrame.from_dict(distribution, orient='index', columns=['count'])
        dist_df['count'] = pd.to_numeric(dist_df['count'], errors='coerce')
        dist_df = dist_df.dropna().sort_values('count', ascending=False)
        if dist_df.empty:
            return None
        # Handle large distributions: keep the 75 most frequent categories and
        # fold the remainder into a single 'Other' bar
        if len(dist_df) > 75:
            other_count = dist_df.iloc[75:]['count'].sum()
            dist_df = dist_df.iloc[:75].copy()
            dist_df.loc['Other'] = other_count
        return create_pareto_chart(dist_df, f"{key.replace('_', ' ').title()} Distribution (Pareto Chart)")

    if data_type == 'Text':
        counts, bins = distribution['word_count_histogram']
        return create_histogram(counts, bins, key, "Text")

    if data_type == 'Numeric':
        return create_histogram(distribution['histogram'], distribution['bin_edges'], key, "Numeric")

    # Unsupported column types, e.g. 'Other' or None
    return None


def plot_distributions(data_results: Dict[str, Any], column_subset: str, story: List[Any], styles):
    """
    Plot the distribution of each generated column or seed column.

    For every column of the requested subset that has a dict-valued
    distribution and a supported data type ('Categorical', 'Text' or
    'Numeric'), appends to ``story`` the score heading(s), the chart, and then
    either a spacer or — after every second chart — a page break. Headings are
    only added for columns that actually produce a chart. A failure on one
    column appends an error paragraph and processing continues with the next.

    data_results: expects data['results']
    column_subset: 'generated' (column note is '') or 'seed' (column note is 'Seed Column')
    story: list of flowables, extended in place
    styles: style sheet providing 'Heading2' and 'BodyText'
    """
    assert column_subset in ['generated', 'seed'], f"column_subset must be 'generated' or 'seed', not '{column_subset}'"
    column_note = "" if column_subset == 'generated' else "Seed Column"
    plot_count = 0

    try:
        feature_distribution = data_results['feature_distribution']
        for key, distribution in feature_distribution['distribution'].items():

            # Only plot distributions for the subset of columns
            if data_results['column_notes'][key] != column_note:
                continue

            data_type = data_results['column_data_types'][key] if 'column_data_types' in data_results else None

            if not (distribution and isinstance(distribution, dict)):
                continue

            try:
                headings = _distribution_headings(key, feature_distribution)
                img = _build_distribution_image(key, distribution, data_type)
                if img is None:
                    continue

                for heading in headings:
                    story.append(Paragraph(heading, styles['Heading2']))
                story.append(img)
                plot_count += 1

                # Fit 2 plots per page
                if plot_count % 2 == 0:
                    story.append(PageBreak())
                else:
                    story.append(Spacer(1, 0.2*inch))
            except Exception as e:
                story.append(Paragraph(f"Error processing {key} distribution: {str(e)}", styles['BodyText']))
    except KeyError as e:
        # A missing top-level section or column entry stops plotting altogether
        print(e)


def _configure_report_styles(styles) -> None:
    """Add the report's custom paragraph styles to ``styles`` and restyle the stock ones in place."""
    # Keep ChartTitle ahead of the BodyText tweaks below, as in the original
    # ordering. NOTE(review): reportlab appears to copy the parent's attributes
    # when a style is created, so moving this would change ChartTitle -- confirm
    # before reordering.
    styles.add(ParagraphStyle(
        name='ChartTitle',
        parent=styles['BodyText'],
        alignment=TA_CENTER,
        fontSize=8,
        leading=10,
    ))

    text_color = colors.HexColor('#110420')

    title = styles['Title']
    title.fontSize = 24
    title.alignment = TA_CENTER
    title.spaceAfter = 12
    title.textColor = text_color

    heading1 = styles['Heading1']
    heading1.fontSize = 18
    heading1.spaceAfter = 6
    heading1.textColor = text_color

    heading2 = styles['Heading2']
    heading2.fontSize = 14
    heading2.spaceBefore = 12
    heading2.spaceAfter = 6
    heading2.textColor = text_color

    body = styles['BodyText']
    body.fontSize = 10
    body.spaceBefore = 6
    body.spaceAfter = 6
    body.textColor = text_color

    # Monospaced style used for the single-row preview.
    styles.add(ParagraphStyle(
        name='RowPreview',
        parent=styles['BodyText'],
        fontName='Courier',
        fontSize=8,
        leading=10,
        spaceAfter=12,
        firstLineIndent=0,
        leftIndent=20,
    ))


def _build_key_metrics_table(data_results: Dict[str, Any], avg_text_diversity, avg_gini_index, styles):
    """Return the table of four gauge charts (with titles) shown under "Key Metrics"."""
    row_uniqueness = data_results['row_uniqueness']
    # NOTE(review): the truthiness checks render an average of exactly 0 the
    # same way as a missing one (no gauge value). Kept as-is because the helper
    # that produces the averages is not visible here -- confirm whether it
    # returns None or 0 when no column qualifies.
    gauge_values = [
        int(row_uniqueness['percent_unique']),
        int(row_uniqueness['percent_semantically_unique']),
        int(avg_text_diversity * 100) if avg_text_diversity else None,
        int(avg_gini_index * 100) if avg_gini_index else None,
    ]
    gauge_titles = [
        "Unique Rows",
        "Semantically Unique Rows",
        "Text Diversity",
        "Gini-Simpson Diversity",
    ]

    chart_row = [create_gauge_chart(value) for value in gauge_values]
    title_row = [Paragraph(title, styles['ChartTitle']) for title in gauge_titles]

    chart_table = Table([title_row, chart_row])
    chart_table.setStyle(TableStyle([
        ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
        ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
        ('BOTTOMPADDING', (0, 0), (-1, 0), 2),
        ('TOPPADDING', (0, 1), (-1, -1), 3),
    ]))
    return chart_table


def _build_overview_table(overview: Dict[str, Any], data_results: Dict[str, Any], avg_text_diversity, avg_gini_index):
    """Return the "Dataset Overview" metrics as two side-by-side tables."""
    row_uniqueness = data_results['row_uniqueness']
    word_stats = data_results['num_words_per_record']

    # The overview is split into two tables to save vertical space.
    structure_rows = [
        ["Metric", "Value"],
        ["Data Completeness", f"{overview['data_completeness']}%"],
        ["Number of Rows", f"{overview['number_of_rows']}"],
        ["Number of Columns", f"{overview['number_of_columns']}"],
        ["Categorical Columns", f"{overview['number_of_categorical_columns']}"],
        ["Text Columns", f"{overview['number_of_text_columns']}"],
        ["Numerical Columns", f"{overview['number_of_numerical_columns']}"],
        ["Seed Columns", f"{overview['number_of_seed_columns']}"],
    ]
    quality_rows = [
        ["Metric", "Value"],
        ["Unique Rows", f"{row_uniqueness['percent_unique']}%"],
        ["Semantically Unique Rows", f"{row_uniqueness['percent_semantically_unique']}%"],
        ["Avg Words per Row", f"{word_stats['average_words_per_record']:.2f}"],
        ["Avg Tokens per Row", f"{word_stats['average_tokens_per_record']:.2f}"],
        ["Total Tokens", f"{word_stats['total_tokens']}"],
        ["Avg Text Diversity", f"{avg_text_diversity:.4f}" if avg_text_diversity else "N/A"],
        ["Avg Gini-Simpson Index", f"{avg_gini_index:.4f}" if avg_gini_index else "N/A"],
    ]
    return Table([[create_overview_table(structure_rows), create_overview_table(quality_rows)]])


def _build_conclusion_text(data_results: Dict[str, Any]) -> str:
    """Return the paragraph markup for the "Conclusion" section."""
    # Takeaways are collected first and numbered afterwards so the list stays
    # contiguous (1, 2, 3, ...) when an optional section is absent.
    takeaways = []

    # Data Uniqueness
    if 'row_uniqueness' in data_results:
        row_uniqueness = data_results['row_uniqueness']
        unique = row_uniqueness.get('percent_unique', 'N/A')
        sem_unique = row_uniqueness.get('percent_semantically_unique', 'N/A')
        if unique != 'N/A' and float(unique) > 90:
            verdict = "the dataset shows a high degree of individuality in its rows. This suggests a rich and varied dataset.<br/><br/>"
        else:
            verdict = "the dataset shows some level of repetition in its rows. This may indicate patterns or recurring themes in the data.<br/><br/>"
        takeaways.append(
            f"Data Uniqueness: With {unique}% unique rows and {sem_unique}% semantically unique rows, "
            + verdict
        )

    # Column Cardinality
    if 'feature_cardinality' in data_results:
        takeaways.append(
            "Column Cardinality: The dataset contains columns with varying cardinalities. "
            "This diversity in column types allows for both granular analysis and higher-level pattern recognition.<br/><br/>"
        )

    # Distribution Patterns
    if 'feature_distribution' in data_results and 'distribution' in data_results['feature_distribution']:
        takeaways.append(
            "Distribution Patterns: The charts reveal the distribution patterns within each column, "
            "highlighting potential focus areas or biases in the data. Understanding these distributions "
            "is crucial for balanced analysis and identifying underrepresented categories.<br/><br/>"
        )

    # Text Complexity
    if 'num_words_per_record' in data_results:
        avg_words = data_results['num_words_per_record'].get('average_words_per_record', 'N/A')
        if avg_words != 'N/A':
            if float(avg_words) > 50:
                level = "the dataset shows a high level of complexity. "
            elif float(avg_words) > 20:
                level = "the dataset shows a moderate level of complexity. "
            else:
                level = "the dataset shows a low level of complexity. "
            takeaways.append(
                f"Text Complexity: With an average of {avg_words:.2f} words per row, "
                + level
                + "This gives an indication of the depth of information contained in each row.<br/><br/>"
            )

    # Text Diversity
    if 'feature_distribution' in data_results and 'score' in data_results['feature_distribution']:
        takeaways.append(
            "Text Diversity: The text diversity indices provide insight into the variety of content within text columns. "
            "Higher diversity can be beneficial for tasks requiring a broad range of examples, while lower diversity "
            "might indicate more standardized content.<br/><br/>"
        )

    conclusion_text = "This report provides a comprehensive view of the dataset's structure, content diversity, and the nature of the data it contains. Key takeaways include:<br/>"
    conclusion_text += "".join(
        f"{number}. {takeaway}" for number, takeaway in enumerate(takeaways, start=1)
    )

    conclusion_text += """
    <b>Implications for Machine Learning:</b><br/><br/> 
 
    <b>Pre-training</b><br/> 
    - The dataset's uniqueness and diversity can provide a rich foundation for pre-training language models or other AI systems.<br/>
    - High cardinality columns may help in learning broad representations, while low cardinality columns could aid in learning important categorical distinctions.<br/>
    - If text diversity is high, it could be particularly valuable for building robust language models that can handle a wide range of contexts and styles.<br/><br/>

    <b>Fine-tuning:</b><br/> 
    - The distribution patterns revealed in the charts should guide the fine-tuning process. Imbalanced categories may require techniques like weighted sampling or loss adjustment to ensure equal representation during fine-tuning.<br/>
    - Columns with high semantic uniqueness could be especially useful for fine-tuning models on specific domains or tasks, as they likely contain a wide range of relevant examples.<br/>
    - Consider the average word count per row when deciding on sequence length for transformer-based models during fine-tuning.<br/><br/>

    <b>Designing/Iterating on Data to Fill Data Gaps:</b><br/> 
    - Analyze the distribution charts to identify underrepresented categories. These areas may require additional data collection or augmentation to ensure comprehensive model performance.<br/>
    - If certain text diversity scores are low, consider ways to introduce more variety in those columns, either through data augmentation techniques or targeted data collection.<br/>
    - For columns with very high cardinality, consider if grouping or categorization might be beneficial to prevent overfitting on rare categories.<br/>
    - If semantic uniqueness is low in certain areas, it might indicate a need for more diverse examples in those categories to improve model generalization.<br/><br/>

    <b>General Considerations:</b><br/> 
    - The overall uniqueness of the dataset impacts models that require diverse examples. However, care should be taken to address any imbalances revealed in the distribution charts.<br/>
    - Monitor for potential biases in the data that could be propagated or amplified by machine learning models.<br/>
    - Consider privacy implications, especially for high-cardinality columns that might contain identifiable information.<br/>
    - The text complexity (average words per row) should inform decisions about model architecture and preprocessing steps.<br/><br/>
    """
    return conclusion_text


def _build_metric_definition_text(data_results: Dict[str, Any], included_columns) -> str:
    """Return the paragraph markup for the "Metric Definitions" section."""
    # Local import keeps this block self-contained.
    from xml.sax.saxutils import escape

    # The text below is parsed as paragraph markup, so column names containing
    # '&', '<' or '>' must be escaped before being embedded in it.
    columns_markup = escape(str(included_columns))

    metric_definition_text = f"""
    This section provides definitions for the metrics used in the report.<br/><br/>
    <b>Key Metrics</b><br/>
    Only generated columns requested by the user are included in the calculation of Key Metrics: {columns_markup}. Helper columns like ID columns, seed columns or informational columns like code validation columns, data quality evaluation columns are excluded from Key Metrics calculation. <br/>
    • <b>Unique Rows:</b> Percentage of rows that are unique in the dataset.<br/>
    • <b>Semantically Unique Rows:</b> Percentage of rows that are semantically unique, based on TF-IDF.<br/>
    • <b>Text Diversity:</b> Average Text Diversity Index (defined below) across all text columns, with higher values indicating more diverse content.<br/>
    • <b>Gini-Simpson Diversity:</b> Average Gini-Simpson Index (defined below) across all categorical columns. Higher values indicating greater diversity.<br/><br/>

    <b>Dataset Overview</b><br/>
    The enhanced dataset overview provides key metrics about the structure, uniqueness, complexity, and quality of the data:<br/>
    • <b>Number of Rows and Columns:</b> Indicates the size and dimensionality of the dataset.<br/>
    • <b>Categorical and Numerical Columns:</b> Gives insight into the types of data present, helping to guide appropriate analysis techniques.<br/>
    • <b>Data Completeness:</b> Shows the overall percentage of non-null values across all columns, indicating the dataset's overall quality and potential need for imputation.<br/>
    • <b>Unique and Semantically Unique Rows:</b> Demonstrates the level of data diversity and potential redundancy in the dataset.<br/>
    • <b>Average Words per Row:</b> Provides an indication of the typical complexity or detail level of each entry.<br/>
    • <b>Average Tokens per Row and Total Tokens:</b> These metrics correspond to tokens used in Large Language Models (LLMs), giving an estimate of the dataset's complexity from an LLM processing perspective.<br/>
    • <b>Average Text Diversity:</b> Average Text Diversity Index (defined below) across all text columns, with higher values indicating more diverse content.<br/>
    • <b>Average Gini-Simpson Index:</b> Average Gini-Simpson Index (defined below) across all categorical columns. Higher values indicating greater diversity.<br/><br/>

    <b>Dataset Schema & Preview</b><br/>
    The schema table provides an overview of each column in the dataset, including the data type, the count of non-null and null values, and the average length (where applicable). This information is crucial for understanding the structure of the data and identifying potential data quality issues such as missing values or unexpected data types.<br/>
    • <b>Data Type:</b> Categorical, Numeric, Text or Other. Categorical columns are those whose percentage of unique values are low; Text columns are non-Categorical columns with at least 2 spaces per Row, on average.<br/>
    • <b>Total Count:</b> Total number of values in the Column.<br/>
    • <b>% Null:</b> Percentage of null values in the Column.<br/>
    • <b>Average Length:</b> Average character count of the values (for each text column).<br/>
    • <b>Avg Tokens:</b> Average number of tokens in the values (for each text column).<br/><br/>

    """

    # The cardinality definition is only relevant when that section was rendered.
    if 'feature_cardinality' in data_results:
        metric_definition_text += """
        <b>Column Cardinality</b><br/>
        • <b>Column cardinality:</b> Represents the number of unique values for each column in the dataset. Higher cardinality indicates more diverse values within a Column.<br/><br/>

        """

    metric_definition_text += """
    <b>Column Distributions</b><br/>
    Column distributions show the frequency of different values within each column. These visualizations help identify common patterns, imbalances, or biases in the data.<br/>
    • <b>Pareto Chart:</b> A Pareto chart illustrates the distribution of domain in the dataset. The bars represent the count for each category, while the line shows the cumulative percentage. Only the top 75 categories are shown individually. The remaining categories are grouped as 'Other'. This visualization helps identify the most significant categories and their relative importance.<br/>
    • <b>Gini-Simpson Index:</b> A diversity index for categorical columns. It quantifies the probability that two values taken at random from the column (with replacement) are different. Higher values indicate greater diversity.<br/>
    • <b>Text Diversity Index:</b> A diversity index for text columns. It is defined as the average correlation between each row's TF-IDF vector and the dataset's TF-IDF matrix. Higher values indicate greater diversity.<br/><br/>
    """
    return metric_definition_text


def create_report_pdf(data: Dict[str, Any], dataset_df: pd.DataFrame, output_filename: str = 'enhanced_data_quality_report.pdf') -> None:
    """Render the data quality report described by ``data`` to a PDF file.

    The report contains, in order: key-metric gauges, the dataset overview
    tables, a single-row preview, the schema table, optional column
    cardinality, seed/generated column distributions, optional word-count and
    text-diversity charts, a conclusion, and the metric definitions.

    Args:
        data: Evaluation output. Must contain ``'results'`` (with
            ``'row_uniqueness'`` and ``'num_words_per_record'``) and
            ``'dataset_overview_statistics'``; ``'feature_cardinality'`` and
            ``'feature_distribution'`` under ``'results'`` are optional.
        dataset_df: The evaluated dataset, passed to the schema table builder.
        output_filename: Path of the PDF to write.

    Raises:
        KeyError: If one of the required keys listed above is missing.
    """
    doc = SimpleDocTemplate(output_filename, pagesize=letter)
    styles = getSampleStyleSheet()
    _configure_report_styles(styles)

    # Average diversity indexes, only for generated columns
    avg_text_diversity, avg_gini_index, included_columns = calculate_average_diversity_indexes(data)
    data_results = data['results']
    overview = data['dataset_overview_statistics']

    story = []

    # Title and key metrics
    story.append(Paragraph("Data Quality Report", styles['Title']))
    story.append(Spacer(1, 0.2*inch))
    story.append(Paragraph("Key Metrics", styles['Heading1']))
    story.append(_build_key_metrics_table(data_results, avg_text_diversity, avg_gini_index, styles))

    # Dataset Overview
    story.append(Paragraph("Dataset Overview", styles['Heading1']))
    story.append(Paragraph("This section provides key metrics on the structure, uniqueness, complexity, and quality of the data.", styles['BodyText']))
    story.append(_build_overview_table(overview, data_results, avg_text_diversity, avg_gini_index))
    story.append(Spacer(1, 0.2*inch))

    # Single Row Preview
    # TODO: cut off the preview if it would spill into the next page?
    story.append(Paragraph("Single Row Preview", styles['Heading1']))
    preview_text = create_single_record_preview(overview['single_row'])
    story.append(Paragraph(preview_text, styles['RowPreview']))
    story.append(PageBreak())

    # Dataset Schema
    story.append(Paragraph("Dataset Schema & Preview", styles['Heading1']))
    story.append(Paragraph("The schema table provides an overview of each column in the dataset.", styles['BodyText']))
    story.append(create_schema_table(dataset_df, data_results))
    story.append(Spacer(1, 0.2*inch))

    # Column Cardinality
    if 'feature_cardinality' in data_results:
        story.append(Paragraph("Column Cardinality", styles['Heading1']))
        feature_cardinality = pd.DataFrame.from_dict(data_results['feature_cardinality'], orient='index', columns=['cardinality'])
        story.append(create_chart(feature_cardinality['cardinality'], "Column Cardinality", "Columns", "Cardinality"))
        story.append(PageBreak())

    # Distribution Visualizations
    story.append(Paragraph("Seed Column Distributions", styles['Heading1']))
    plot_distributions(data_results, 'seed', story, styles)
    story.append(PageBreak())
    story.append(Paragraph("Generated Column Distributions", styles['Heading1']))
    plot_distributions(data_results, 'generated', story, styles)
    story.append(PageBreak())

    # Word Count per Column
    if 'word_counts_per_column' in data_results['num_words_per_record']:
        story.append(Paragraph("Average Word Count per Column", styles['Heading1']))
        word_count = pd.DataFrame.from_dict(data_results['num_words_per_record']['word_counts_per_column'], orient='index', columns=['avg_words'])
        word_count = word_count.sort_values('avg_words', ascending=False)
        story.append(create_chart(word_count['avg_words'], "Average Word Count per Column", "Columns", "Average Word Count"))
        story.append(Spacer(1, 0.2*inch))

    # Text Diversity Indices
    if 'feature_distribution' in data_results and 'score' in data_results['feature_distribution']:
        # Only columns whose score entry carries a text diversity index are charted.
        text_diversity = {
            key: value['text_diversity_index']
            for key, value in data_results['feature_distribution']['score'].items()
            if isinstance(value, dict) and 'text_diversity_index' in value
        }
        if text_diversity:
            story.append(Paragraph("Text Diversity Indices", styles['Heading1']))
            text_diversity_df = pd.DataFrame.from_dict(text_diversity, orient='index', columns=['diversity_index'])
            story.append(create_text_diversity_chart(text_diversity_df))
            story.append(Spacer(1, 0.2*inch))

    # Conclusion
    story.append(Paragraph("Conclusion", styles['Heading1']))
    story.append(Paragraph(_build_conclusion_text(data_results), styles['BodyText']))

    # Metric Definitions
    story.append(Paragraph("Metric Definitions", styles['Heading1']))
    story.append(Paragraph(_build_metric_definition_text(data_results, included_columns), styles['BodyText']))

    # Build the PDF
    doc.build(story)
    print(f"PDF created: {output_filename}")

In [None]:
# Write the data quality report to 'data_quality_report.pdf'.
create_report_pdf(
    data=results,
    dataset_df=dataset_df,
    output_filename='data_quality_report.pdf',
)

In [53]:

# field = 'Example'
# counts = [7, 7, 7, 10, 6, 6, 3, 1, 1, 2]
# bins = [0.0, 27.7, 55.4, 83.1, 110.8, 138.5, 166.2, 193.9, 221.6, 249.29999999999998, 277.0]


In [None]:
# Scratch check: the template only names {prompt} and {response}, so the extra
# `context` keyword shows that str.format() ignores unused keyword arguments.
test_str = "here's the {prompt} and here's the {response}"
test_str.format(**{'prompt': 'prompt', 'response': 'response', 'context': 'context'})

In [None]:
# Display the 'num_words_per_record' entry of the evaluation results (the
# report reads its per-row word and token averages from this entry).
results['results']['num_words_per_record']

In [None]:
# Check whether the note stored for the 'id' column is an empty string.
column_notes = results['results']['column_notes']
column_notes['id'] == ''