# Synthesize a Database with Gretel Relational

This notebook uses Gretel Relational Synthetics to synthesize a sample telecommunications database. Run the example below, then compare the synthetic data with the real-world data from the example database.

<img src="https://gretel-blueprints-pub.s3.us-west-2.amazonaws.com/rdb/telecom_db.png"  width="70%" height="70%">

## Getting Started

In [None]:
%%capture
!pip install -U gretel-trainer

In [None]:
from gretel_trainer.relational import *

In [None]:
# Download the sample telecom SQLite database into the current working directory;
# the connector cell below opens it by the relative path "telecom.db".
!wget https://gretel-blueprints-pub.s3.amazonaws.com/rdb/telecom.db

## Define Source Data

### Input data via database connector

In [None]:
# Input data from database
from gretel_trainer.relational import sqlite_conn

# Relative path of the SQLite file fetched by the download cell.
db_path = "telecom.db"
# Build a connector for that file, then extract its contents into
# `relational_data`, the object later cells query for primary keys,
# foreign keys, and table data.
sqlite = sqlite_conn(path=db_path)
relational_data = sqlite.extract()


### Alternatively, manually define data from CSVs 


In [None]:
# @title
# #Alternatively, manually define relational data
# #Uncomment code to run


# from gretel_trainer.relational import RelationalData
# import pandas as pd

# csv_dir = "/content"

# tables = [
#     #("table_name", "primary_key")
#     ("account", "account_id"),
#     ("client", "client_id"),
#     ("invoice", "invoice_id"),
#     ("location", "location_id"),
#     ("payment", "payment_id"),
# ]

# foreign_keys = [
#     #("fkey_table.fkey", "pkey_table.pkey")
#     ("account.client_id", "client.client_id"),
#     ("location.client_id", "client.client_id"),
#     ("invoice.account_id", "account.account_id"),
#     ("payment.invoice_id", "invoice.invoice_id"),
# ]

# relational_data = RelationalData()

# for table, pk in tables:
#     relational_data.add_table(name=table, primary_key=pk, data=pd.read_csv(f"{csv_dir}/{table}.csv"))

# for fk, ref in foreign_keys:
#     relational_data.add_foreign_key(foreign_key=fk, referencing=ref)

In [None]:
#@title Preview source data
#@markdown #### Confirm referential integrity by joining two tables
#@markdown Every record in the child table matches a distinct record in the parent table. Therefore, the number of records in the joined data will match the number of records in the child table.


from IPython.display import display, HTML

def join_tables(parent: str, child: str, relational_data=relational_data):
  """Left-join a child table onto its parent table and preview the result.

  Prints the row counts of the child table, the parent table, and the joined
  data so they can be compared, then returns the first rows of the join.

  Args:
    parent: Name of the table that the child's foreign key references.
    child: Name of the table that holds the foreign key.
    relational_data: Object exposing ``get_primary_key``, ``get_foreign_keys``
      and ``get_table_data``. Defaults to the notebook-level ``relational_data``
      as it was bound when this cell was executed.

  Returns:
    The first rows (``DataFrame.head()``) of the joined data.

  Raises:
    ValueError: If ``child`` has no foreign key that references ``parent``.
  """
  p_key = relational_data.get_primary_key(parent)

  # Keep only the child's foreign-key columns that point at the requested parent.
  matching_fks = [
    fk.column_name
    for fk in relational_data.get_foreign_keys(child)
    if fk.parent_table_name == parent
  ]
  if not matching_fks:
    # Fail early with a readable message instead of letting the merge below
    # raise an opaque KeyError on an empty column name.
    raise ValueError(
      "The input parent and child table must be related: "
      f"'{child}' has no foreign key referencing '{parent}'."
    )
  # If several foreign keys reference the parent, the last one listed is used.
  f_key = matching_fks[-1]

  parent_df = relational_data.get_table_data(parent)
  child_df = relational_data.get_table_data(child)

  # The child frame is on the left of the merge, so its foreign-key column is
  # the left key and the parent's primary key is the right key.
  joined_data = child_df.merge(parent_df, how="left", left_on=f_key, right_on=p_key)

  print(f"Number of records in {child} table:\t {len(child_df)}")
  print(f"Number of records in {parent} table:\t {len(parent_df)}")
  print(f"Number of records in joined data:\t {len(joined_data)}")

  return joined_data.head()


# Tables to preview; `child_table` must hold a foreign key that references `parent_table`.
parent_table = "client" #@param {type:"string"}
child_table = "account" #@param {type:"string"}

# \033[1m ... \033[0m are ANSI escapes that render the heading in bold.
print("\033[1m Source Data: \033[0m")
# Keep the joined preview in `source_data`: the synthetic-data comparison cell
# later in the notebook reuses its columns.
source_data = join_tables(parent_table, child_table)
display(source_data)

## Set up Relational Model and Create Project

During this step, you will be prompted to input your API key, which can be found in the [Gretel Console](https://console.gretel.ai/users/me/key).

Relational Synthetics will use Amplify by default. Alternatively, you can set `gretel_model="actgan"` or `gretel_model="lstm"`. 

In [None]:
from gretel_trainer.relational import MultiTable

# Wrap the extracted relational data in a MultiTable, the object used in the
# following cells for training, generation, and reading the synthetic tables.
multitable = MultiTable(
    relational_data,
    # NOTE(review): presumably the display name of the Gretel project that gets created — confirm.
    project_display_name="Synthesize Telecom Database",
    # Optional overrides, left commented out so the defaults apply.
    # The model defaults to Amplify; "actgan" and "lstm" are the documented alternatives.
    #gretel_model="amplify",
    # NOTE(review): presumably a status-polling interval; unit not shown here — confirm in the gretel-trainer docs.
    #refresh_interval=60
)

## Synthesize Database

In [None]:
# Train the synthetic models for the tables held by `multitable`.
# NOTE(review): presumably submits Gretel jobs and may take a while — confirm expected runtime.
multitable.train()

In [None]:
# Generate the synthetic database; the cells below read the results from
# `multitable.synthetic_output_tables`.
multitable.generate(record_size_ratio=1)       # To adjust the amount of data generated, change record_size_ratio parameter
    

## View Results

In [None]:
#@title Compare an individual table
table = "payment" #@param {type:"string"}
from IPython.display import display, HTML


# First five rows of the chosen source table.
source_table = multitable.relational_data.get_table_data(table).head(5)
# First five synthetic rows, restricted to (and ordered by) the source table's columns.
synth_table = multitable.synthetic_output_tables[table][source_table.columns].head(5)
# \033[1m switches the heading to bold.
# NOTE(review): unlike the other headings in this notebook, no \033[0m reset follows.
print("\033[1m Source Table:")
display(source_table)
print("\n\n\033[1m Synthesized Table:")
display(synth_table)

In [None]:
#@title Examine joined tables to confirm referential integrity
#@markdown As with the original data, every record in the synthesized child table matches a distinct record in its synthesized parent table. The number of records in the joined data matches the number of records in the child table, confirming referential integrity has been maintained in the synthetic database.
import logging 
from IPython.display import display, HTML

def join_synth_tables(parent: str, child: str, multitable=multitable):
  """Left-join a synthetic child table onto its synthetic parent and preview it.

  Key metadata is read from ``multitable.relational_data``; the table contents
  come from ``multitable.synthetic_output_tables``. Prints the row counts of the
  child table, the parent table, and the joined data so they can be compared,
  then returns the first rows of the join.

  Args:
    parent: Name of the table that the child's foreign key references.
    child: Name of the table that holds the foreign key.
    multitable: Object exposing ``relational_data`` and
      ``synthetic_output_tables``. Defaults to the notebook-level ``multitable``
      as it was bound when this cell was executed.

  Returns:
    The first rows (``DataFrame.head()``) of the joined synthetic data.

  Raises:
    ValueError: If ``child`` has no foreign key that references ``parent``.
  """
  p_key = multitable.relational_data.get_primary_key(parent)

  # Keep only the child's foreign-key columns that point at the requested parent.
  matching_fks = [
    fk.column_name
    for fk in multitable.relational_data.get_foreign_keys(child)
    if fk.parent_table_name == parent
  ]
  if not matching_fks:
    # Fail early with a readable message instead of letting the merge below
    # raise an opaque KeyError on an empty column name.
    raise ValueError(
      "The input parent and child table must be related: "
      f"'{child}' has no foreign key referencing '{parent}'."
    )
  # If several foreign keys reference the parent, the last one listed is used.
  f_key = matching_fks[-1]

  parent_df = multitable.synthetic_output_tables[parent]
  child_df = multitable.synthetic_output_tables[child]

  # The child frame is on the left of the merge, so its foreign-key column is
  # the left key and the parent's primary key is the right key.
  joined_data = child_df.merge(parent_df, how="left", left_on=f_key, right_on=p_key)

  print(f"Number of records in {child} table:\t {len(child_df)}")
  print(f"Number of records in {parent} table:\t {len(parent_df)}")
  print(f"Number of records in joined data:\t {len(joined_data)}")
  return joined_data.head()


# Re-declares the table pair; these assignments overwrite the values set in the
# source-data preview cell.
parent_table = "client" #@param {type:"string"}
child_table = "account" #@param {type:"string"}

print("\n\n\033[1m Synthesized Data:\033[0m")
# `source_data` comes from the earlier "Preview source data" cell, so that cell
# must have been run first; selecting its columns shows the synthetic join in
# the same column order as the source join.
display(join_synth_tables(parent_table, child_table)[source_data.columns])

### View Relational Synthetic Report

In [None]:
# View relational report
import IPython
from smart_open import open

# NOTE(review): `_working_dir` and `_synthetics_run` are private MultiTable
# attributes and may change between gretel-trainer releases — confirm whether a
# public accessor for the report path exists.
report_path = str(multitable._working_dir / multitable._synthetics_run.identifier / "relational_report.html")

# Read the report inside a context manager so the file handle is closed as soon
# as the contents are loaded (`open` here is smart_open's, imported above).
with open(report_path) as report_file:
    report_html = report_file.read()

# Last expression of the cell, so the notebook renders the report inline.
IPython.display.HTML(data=report_html)

## [Optional] Save Synthesized Data to a Database

In [None]:
# Write the synthetic tables to a separate SQLite file; the path differs from
# the source database's "telecom.db".
output_db_path = "synthetic_telecom.db"
# Reuses `sqlite_conn`, imported in the "Input data from database" cell.
output_conn = sqlite_conn(output_db_path)
output_conn.save(
    multitable.synthetic_output_tables,
    # NOTE(review): presumably prepended to each table name in the output database — confirm.
    prefix="synth_"
    )