<a target="_blank" href="https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/main/docs/notebooks/transform_relational_database.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Transform a Database with Gretel Relational

This notebook uses [Gretel Relational Transform](https://docs.gretel.ai/reference/relational) to redact Personally Identifiable Information (PII) in a sample telecommunications database. Run the example below and compare the transformed data with the source data.

<img src="https://gretel-blueprints-pub.s3.us-west-2.amazonaws.com/rdb/telecom_db.png"  width="70%" height="70%">

## Getting Started

In [None]:
%%capture
!pip install -U gretel-trainer

In [None]:
from gretel_trainer.relational import *

In [None]:
from gretel_client import configure_session

configure_session(api_key="prompt", cache="yes", validate=True)

In [None]:
# Download sample database
!wget https://gretel-blueprints-pub.s3.amazonaws.com/rdb/telecom.db

## Define Source Data

### Input data via database connector
For information on connecting to your own database using one of our 30+ connectors, [check out our docs](https://docs.gretel.ai/reference/relational/database-connectors).

In [None]:
# Input data from database
from gretel_trainer.relational import sqlite_conn

db_path = "telecom.db"
conn = sqlite_conn(db_path)
relational_data = conn.extract()
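
To give a sense of what a connector has to discover, the sketch below uses only Python's standard `sqlite3` module to list the tables, primary keys, and foreign keys of a small in-memory database. It is an illustration, not Gretel's implementation. The `client` and `account` definitions and the `describe_schema` helper are assumptions made for this example and are not read from `telecom.db`.

```python
import sqlite3

# Hypothetical two-table schema, created in memory for illustration only.
conn_demo = sqlite3.connect(":memory:")
conn_demo.executescript(
    """
    CREATE TABLE client (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE account (
        id INTEGER PRIMARY KEY,
        client_id INTEGER REFERENCES client(id)
    );
    """
)

def describe_schema(connection):
    """Return {table: {"primary_key": [...], "foreign_keys": [(column, parent, parent_column)]}}."""
    schema = {}
    tables = [
        row[0]
        for row in connection.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
        )
    ]
    for table in tables:
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
        columns = connection.execute(f"PRAGMA table_info('{table}')").fetchall()
        primary_key = [col[1] for col in columns if col[5] > 0]
        # PRAGMA foreign_key_list rows: (id, seq, table, from, to, on_update, on_delete, match)
        fks = connection.execute(f"PRAGMA foreign_key_list('{table}')").fetchall()
        foreign_keys = [(fk[3], fk[2], fk[4]) for fk in fks]
        schema[table] = {"primary_key": primary_key, "foreign_keys": foreign_keys}
    return schema

schema = describe_schema(conn_demo)
print(schema)
```

The resulting dictionary has the same ingredients that the manual definition in the next section supplies by hand: table names, a primary key per table, and child-to-parent key pairs.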

### Alternatively, manually define data from CSVs

In [None]:
#@title
# Alternatively, manually define relational data
# Uncomment code to run cell

# from gretel_trainer.relational import RelationalData
# import pandas as pd

# csv_dir = "/path/to/extracted_csvs"

# tables = [
#     ("events", "id"),
#     ("users", "id"),
#     ("distribution_center", "id"),
#     ("products", "id"),
#     ("inventory_items", "id"),
#     ("order_items", "id"),
# ]

# foreign_keys = [
#     ("events.user_id", "users.id"),
#     ("order_items.user_id", "users.id"),
#     ("order_items.inventory_item_id", "inventory_items.id"),
#     ("inventory_items.product_id", "products.id"),
#     ("inventory_items.product_distribution_center_id", "distribution_center.id"),
#     ("products.distribution_center_id", "distribution_center.id"),
# ]

# relational_data = RelationalData()

# for table, pk in tables:
#     relational_data.add_table(name=table, primary_key=pk, data=pd.read_csv(f"{csv_dir}/{table}.csv"))

# for fk, ref in foreign_keys:
#     relational_data.add_foreign_key(foreign_key=fk, referencing=ref)
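
Before adding tables and keys by hand, it can help to sanity-check the definitions. The sketch below uses plain Python and a shortened copy of the lists above to check that every `table.column` string in a foreign-key pair refers to a declared table. The helper name `find_unknown_tables` and the `*_demo` lists are hypothetical and not part of `gretel_trainer`.

```python
def find_unknown_tables(tables, foreign_keys):
    """Return foreign-key endpoints whose table is missing from `tables`."""
    known = {name for name, _primary_key in tables}
    unknown = []
    for fk, ref in foreign_keys:
        for endpoint in (fk, ref):
            # Each endpoint is written as "table.column".
            table_name, _, column = endpoint.partition(".")
            if not column or table_name not in known:
                unknown.append(endpoint)
    return unknown

tables_demo = [("users", "id"), ("events", "id")]
foreign_keys_demo = [
    ("events.user_id", "users.id"),
    ("order_items.user_id", "users.id"),  # "order_items" is deliberately left undeclared
]

print(find_unknown_tables(tables_demo, foreign_keys_demo))
```

An empty result means every endpoint names a declared table. It does not verify that the columns exist in the CSV files.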

In [None]:
#@title Preview source data
#@markdown #### Confirm referential integrity by joining two tables
#@markdown Every record in the child table matches a distinct record in the parent table. Therefore, the number of records in the joined data will match the number of records in the child table.


import logging

from IPython.display import display, HTML

def join_tables(parent: str, child: str, relational_data: RelationalData, tableset=None):
  p_key = relational_data.get_primary_key(parent)
  f_key = ""

  # For simplicity, if a child has multiple foreign keys to a parent, just pick one of them
  child_foreign_keys = {
    fk.parent_table_name: fk.column_name
    for fk in relational_data.get_foreign_keys(child)
  }
  if parent in child_foreign_keys:
    f_key = child_foreign_keys[parent]
  else:
    raise ValueError(f"Table '{child}' has no foreign key to table '{parent}'.")
  
  if tableset is None:
    parent_df = relational_data.get_table_data(parent)
    child_df = relational_data.get_table_data(child)
  else:
    parent_df = tableset[parent]
    child_df = tableset[child]

  joined_data = child_df.merge(parent_df, how="left", left_on=f_key, right_on=p_key)

  print(f"Number of records in {child} table:\t {len(child_df)}")
  print(f"Number of records in {parent} table:\t {len(parent_df)}")
  print(f"Number of records in joined data:\t {len(joined_data)}")

  return joined_data.head()


parent_table = "client" #@param {type:"string"}
child_table = "account" #@param {type:"string"}

print("\033[1m Source Data: \033[0m")
source_data = join_tables(parent_table, child_table, relational_data)
display(source_data)  
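
A left join keeps every child row even when no parent matches, so the row count alone does not rule out orphaned foreign keys. As a complementary check, the sketch below counts child rows whose foreign key has no matching parent. It uses the standard `sqlite3` module on made-up `client` and `account` rows, which are assumptions for illustration and are not taken from `telecom.db`.

```python
import sqlite3

demo = sqlite3.connect(":memory:")
demo.executescript(
    """
    CREATE TABLE client (id INTEGER PRIMARY KEY);
    CREATE TABLE account (id INTEGER PRIMARY KEY, client_id INTEGER);
    INSERT INTO client (id) VALUES (1), (2);
    -- Account 12 points at client 99, which does not exist (an orphan).
    INSERT INTO account (id, client_id) VALUES (10, 1), (11, 2), (12, 99);
    """
)

# Row count of the left join: one row per child because client.id is unique.
joined_rows = demo.execute(
    "SELECT COUNT(*) FROM account a LEFT JOIN client c ON a.client_id = c.id"
).fetchone()[0]

# Child rows whose foreign key found no parent.
orphans = demo.execute(
    """
    SELECT COUNT(*) FROM account a
    LEFT JOIN client c ON a.client_id = c.id
    WHERE c.id IS NULL
    """
).fetchone()[0]

print(f"joined rows: {joined_rows}, orphaned accounts: {orphans}")
```

Here the joined row count equals the number of accounts even though one account is orphaned, which is why the second query is a useful addition.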


## Create Project

This step creates a Gretel project for the database. Your API key, which `configure_session` prompts for in Getting Started, can be found in the [Gretel Console](https://console.gretel.ai/users/me/key).

In [None]:
from gretel_trainer.relational import MultiTable

multitable = MultiTable(
    relational_data,
    project_display_name="Transform Telecom Database",
    #refresh_interval=60
)

## Transform Database

### Set Transform configuration

In [None]:
from gretel_client.projects.models import read_model_config

configs = {}
for table in relational_data.list_all_tables():
    configs[table] = read_model_config("https://raw.githubusercontent.com/gretelai/gdpr-helpers/main/src/config/transforms_config.yaml")


### Train models and run transforms

In [None]:
multitable.train_transform_models(configs=configs)
multitable.run_transforms()

## View Results

In [None]:
#@title Compare an Individual Table
table = "location" #@param {type:"string"}
from IPython.display import display, HTML

source_table = multitable.relational_data.get_table_data(table).head(5)
trans_table = multitable.transform_output_tables[table][source_table.columns].head(5)
print("\033[1m Source Table:\033[0m")
display(source_table)
print("\n\n\033[1m Transformed Table:\033[0m")
display(trans_table)
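
Beyond comparing the two previews by eye, one rough way to summarize a transform is to count, per column, how many values differ between source and output. The sketch below does this for rows represented as plain dictionaries. The sample rows and the helper name `changed_counts` are hypothetical, and the approach assumes both tables keep the same row order.

```python
def changed_counts(source_rows, transformed_rows):
    """Count, per column, how many row-aligned values differ."""
    counts = {}
    for src, out in zip(source_rows, transformed_rows):
        for column, value in src.items():
            counts.setdefault(column, 0)
            if out.get(column) != value:
                counts[column] += 1
    return counts

source_rows = [
    {"id": 1, "email": "ann@example.com", "city": "Oslo"},
    {"id": 2, "email": "bob@example.com", "city": "Lima"},
]
transformed_rows = [
    {"id": 1, "email": "user1@example.org", "city": "Oslo"},
    {"id": 2, "email": "user2@example.org", "city": "Lima"},
]

print(changed_counts(source_rows, transformed_rows))
```

Columns with a count of zero were left untouched in this sample, while columns with a high count were rewritten by the transform.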

In [None]:
#@title Examine joined tables to confirm referential integrity
#@markdown As with the original data, every record in the transformed child table matches a distinct record in its transformed parent table. The number of records in the joined data matches the number of records in the child table, confirming referential integrity has been maintained in the transformed database.
import logging 
from IPython.display import display, HTML


parent_table = "client" #@param {type:"string"}
child_table = "account" #@param {type:"string"}

print("\n\n\033[1m Transformed Data:\033[0m")
display(join_tables(parent_table, child_table, multitable.relational_data, multitable.transform_output_tables)[source_data.columns])
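
One general way a transformation can keep joins intact is to map each key value deterministically, so the same source ID always becomes the same output ID in both parent and child tables. The sketch below illustrates that idea with a keyed hash from the standard library. It is not a description of how Gretel transforms keys. The secret, the rows, and the `pseudonymize` helper are assumptions for illustration.

```python
import hashlib
import hmac

SECRET = b"demo-secret"  # hypothetical key for the demo; do not reuse

def pseudonymize(value):
    """Map a key value to a stable 12-character hex token."""
    digest = hmac.new(SECRET, str(value).encode("utf-8"), hashlib.sha256)
    return digest.hexdigest()[:12]

clients = [{"id": 1}, {"id": 2}]
accounts = [
    {"id": 10, "client_id": 1},
    {"id": 11, "client_id": 2},
    {"id": 12, "client_id": 1},
]

# Apply the same mapping to the parent key and to the child's foreign key.
new_clients = [{"id": pseudonymize(c["id"])} for c in clients]
new_accounts = [
    {"id": a["id"], "client_id": pseudonymize(a["client_id"])} for a in accounts
]

# Every transformed foreign key still points at a transformed parent.
parent_ids = {c["id"] for c in new_clients}
assert all(a["client_id"] in parent_ids for a in new_accounts)
```

Truncating the digest keeps the tokens short for display. A real system would need to weigh token length against the chance of two keys colliding.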

In [None]:
#@title Accessing Output Files
#@markdown All of the Relational Transform output files can be found in your local working directory. Additionally, you can download the outputs as a single archive file from the Gretel Console using this URL:
console_url = f"https://console.gretel.ai/{multitable._project.name}/data_sources"
print(console_url)

## [Optional] Save Transformed Data to a Database


In [None]:
output_db_path = "transformed_telecom.db"
output_conn = sqlite_conn(output_db_path)
output_conn.save(
    multitable.transform_output_tables,
    prefix="trans_",
)
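
To confirm what was written, you can list tables by prefix with the standard `sqlite3` module. The sketch below runs against an in-memory database with made-up table names so that it is self-contained. Pointing `sqlite3.connect` at `transformed_telecom.db` instead would rely on the assumption that the cell above wrote its output there. The helper name `tables_with_prefix` is hypothetical.

```python
import sqlite3

check = sqlite3.connect(":memory:")
check.executescript(
    """
    CREATE TABLE client (id INTEGER PRIMARY KEY);
    CREATE TABLE trans_client (id INTEGER PRIMARY KEY);
    CREATE TABLE trans_account (id INTEGER PRIMARY KEY);
    """
)

def tables_with_prefix(connection, prefix):
    """Return the names of tables that start with `prefix`, sorted."""
    names = [
        row[0]
        for row in connection.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table'"
        )
    ]
    # Filter in Python: "_" is a single-character wildcard in SQL LIKE patterns.
    return sorted(name for name in names if name.startswith(prefix))

print(tables_with_prefix(check, "trans_"))
```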