In [None]:
%%capture
# Install the gretel-client
!pip install -U gretel-client

In [None]:
import boto3
import os
import json
import time

import pandas as pd

from botocore.exceptions import ClientError
from gretel_client import configure_session
from gretel_client.projects.models import read_model_config
from gretel_client.projects import create_or_get_unique_project
from gretel_client.helpers import poll
from smart_open import open

## 1. Gather unprocessed data files from the S3 source bucket

In [None]:
def get_unprocessed_files_with_extension(source_bucket, dest_bucket, extension='.csv'):
    """Return source-bucket keys ending in ``extension`` that have no synthetic output yet.

    A source key ``<stem><ext>`` counts as already processed when the object
    ``<stem>_synth.csv`` exists in ``dest_bucket``.

    Args:
        source_bucket: Name of the S3 bucket to scan.
        dest_bucket: Name of the S3 bucket holding the ``*_synth.csv`` outputs.
        extension: Suffix a key must end with to be considered (default ``'.csv'``).

    Returns:
        List of unprocessed source keys, in S3 listing order.

    Raises:
        botocore.exceptions.ClientError: for any S3 error other than the
            "object not found" response from ``head_object`` (e.g. access denied),
            so such errors are not silently mistaken for "not yet processed".
    """
    s3 = boto3.client('s3')

    # A single list_objects_v2 call returns at most 1000 keys; paginate so the
    # whole bucket is scanned. 'Contents' is absent for an empty bucket/page.
    paginator = s3.get_paginator('list_objects_v2')

    files = []
    for page in paginator.paginate(Bucket=source_bucket):
        for obj in page.get('Contents', []):
            key = obj['Key']
            if not key.endswith(extension):
                continue
            synth_key = f'{os.path.splitext(key)[0]}_synth.csv'
            try:
                s3.head_object(Bucket=dest_bucket, Key=synth_key)
            except ClientError as e:
                # head_object signals a missing object with a 404-style code;
                # anything else is a real failure and is re-raised.
                error_code = e.response.get('Error', {}).get('Code')
                if error_code in ('404', 'NoSuchKey', 'NotFound'):
                    files.append(key)
                else:
                    raise
    return files

In [None]:
# Buckets to read raw data from / write results to, and the file type to pick up.
source_bucket = 'gretel-source-data-bucket'
dest_bucket = 'gretel-destination-data-bucket'
extension = '.csv'

# Only keys that do not yet have a `<name>_synth.csv` in the destination bucket.
s3_files = get_unprocessed_files_with_extension(source_bucket, dest_bucket, extension=extension)

# Per-file job metadata keyed by the object key minus its extension:
# the S3 URI of the data plus its row and column counts.
gretel_dict = {}
for s3_file in s3_files:
    key = os.path.splitext(s3_file)[0]
    data_source = f's3://{source_bucket}/{s3_file}'
    df = pd.read_csv(data_source)
    n_rows, n_cols = df.shape
    gretel_dict[key] = {
        'data_source': data_source,
        'nb_rows': n_rows,
        'nb_cols': n_cols,
    }

display(gretel_dict)

## 2. Run Gretel Transform and Synthetics

In [None]:
# Define some helper functions

def get_secret(secret_name="prod/Gretel/ApiKey",
               region_name="us-east-1",
               secret_key="gretelApiKey"):
    """Fetch the Gretel API key from AWS Secrets Manager.

    The secret's ``SecretString`` is parsed as JSON and the value stored under
    ``secret_key`` is returned. All arguments default to the values this
    notebook has always used, so ``get_secret()`` behaves as before.

    Args:
        secret_name: Secrets Manager secret ID / name.
        region_name: AWS region hosting the secret.
        secret_key: Key inside the JSON secret whose value is returned.

    Returns:
        The value stored under ``secret_key`` (presumably the API key string).

    Raises:
        botocore.exceptions.ClientError: if the secret cannot be retrieved; for
            the possible error codes see
            https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
        KeyError: if the response has no ``SecretString`` or the JSON lacks
            ``secret_key``.
    """
    session = boto3.session.Session()
    client = session.client(service_name='secretsmanager', region_name=region_name)

    # Any ClientError propagates to the caller unchanged. The returned
    # SecretString has already been decrypted by the Secrets Manager service.
    response = client.get_secret_value(SecretId=secret_name)

    secret = json.loads(response['SecretString'])
    return secret[secret_key]


def track_status_greteljob(job_dict,
                           project_name,
                           model_id_key,
                           record_id_key=None,
                           poll_interval=30):
    """Block until every Gretel job referenced in ``job_dict`` has completed.

    For each entry, the model whose id is stored under ``model_id_key`` is
    looked up. If ``record_id_key`` is given, the record handler with that id
    is tracked instead of the model itself. Jobs are waited on one after another.

    Args:
        job_dict: Mapping of job name -> dict holding the id(s) to track.
        project_name: Gretel project that contains the models.
        model_id_key: Key in each entry holding the model id.
        record_id_key: Optional key in each entry holding a record-handler id;
            when set, that record handler's status is tracked.
        poll_interval: Seconds to sleep between status checks (default 30).

    Raises:
        RuntimeError: if a job reaches a failure status ('error', 'lost' or
            'cancelled'), which can never turn into 'completed'.
    """
    # Terminal failure statuses. A tuple (== comparison) is used rather than a
    # set so that str-valued Enum statuses also match.
    failed_statuses = ('error', 'lost', 'cancelled')

    project = create_or_get_unique_project(name=project_name)

    for key, job_info in job_dict.items():
        model_id = job_info[model_id_key]
        model = project.get_model(model_id)
        if record_id_key:
            job_id = job_info[record_id_key]
            job = model.get_record_handler(job_id)
        else:
            job = model
            job_id = model_id

        while True:
            job.refresh()
            status = job.status
            if status == 'completed':
                print(f"Processing {key} with {job_id} is complete.")
                break
            if status in failed_statuses:
                raise RuntimeError(
                    f"Gretel job {job_id} for {key} ended with status {status!r}."
                )
            print(key, job_id, status)
            time.sleep(poll_interval)


def gretel_transform_train(job_dict,
                           project_name,
                           data_source_key,
                           config=None):
    """Train one Gretel Transform model per entry of ``job_dict`` and wait for them.

    Each model is trained on ``job_dict[key][data_source_key]`` and named
    ``transform-<key>``. Its id is written back to
    ``job_dict[key]['transform_model_id']``, so ``job_dict`` is mutated in place.

    Args:
        job_dict: Mapping of job name -> per-job dict.
        project_name: Gretel project to create the models in.
        data_source_key: Key in each per-job dict holding the training data.
        config: Model config to use; when falsy, the transforms config from the
            gretelai/gdpr-helpers GitHub repository is loaded.
    """
    project = create_or_get_unique_project(name=project_name)

    model_config = config if config else read_model_config(
        "https://raw.githubusercontent.com/gretelai/gdpr-helpers/main/src/config/transforms_config.yaml"
    )

    # Submit every training job first, then wait for all of them together.
    for name, job in job_dict.items():
        transform_model = project.create_model_obj(
            model_config=model_config,
            data_source=job[data_source_key],
        )
        transform_model.name = f"transform-{name}"
        transform_model.submit_cloud()
        job['transform_model_id'] = transform_model.model_id

    print("Status of transform training jobs:")
    track_status_greteljob(job_dict,
                           project_name=project_name,
                           model_id_key="transform_model_id")


def gretel_transform_run(job_dict,
                         project_name,
                         data_source_key,
                         config=None):
    """Run each trained Transform model over its data and load the de-identified result.

    Reads ``job_dict[key]['transform_model_id']`` (set by ``gretel_transform_train``).
    It submits a record-handler job on ``job_dict[key][data_source_key]`` and
    records its id under ``'transform_record_id'``. It then waits for all jobs
    and stores each job's output under ``'deidentified_data_source'`` as a
    pandas DataFrame. ``job_dict`` is mutated in place.

    Args:
        job_dict: Mapping of job name -> per-job dict.
        project_name: Gretel project containing the transform models.
        data_source_key: Key in each per-job dict holding the data to transform.
        config: Accepted but not used by this function.
    """
    project = create_or_get_unique_project(name=project_name)

    # Submit one transform run per entry.
    for name, job in job_dict.items():
        transform_model = project.get_model(job['transform_model_id'])
        handler = transform_model.create_record_handler_obj(
            data_source=job[data_source_key],
        )
        handler.submit_cloud()
        job['transform_record_id'] = handler.record_id

    print("Status of transform run jobs:")
    track_status_greteljob(
        job_dict,
        project_name=project_name,
        model_id_key="transform_model_id",
        record_id_key="transform_record_id",
    )

    # Pull each finished job's gzipped CSV "data" artifact into a DataFrame.
    for job in job_dict.values():
        transform_model = project.get_model(job["transform_model_id"])
        handler = transform_model.get_record_handler(job["transform_record_id"])
        data_url = handler.get_artifact_link("data")
        job['deidentified_data_source'] = pd.read_csv(data_url, compression="gzip")


def gretel_synthetics_train(job_dict,
                            project_name,
                            data_source_key,
                            config=None):
    """Train one Gretel synthetics model per entry of ``job_dict`` and wait for them.

    Each model is trained on ``job_dict[key][data_source_key]`` and named
    ``synthetics-<key>``. Its id is written back to ``job_dict[key]['model_id']``,
    so ``job_dict`` is mutated in place.

    Args:
        job_dict: Mapping of job name -> per-job dict.
        project_name: Gretel project to create the models in.
        data_source_key: Key in each per-job dict holding the training data.
        config: Model config to use; when falsy, the
            ``synthetics/tabular-actgan`` config is loaded.
    """
    project = create_or_get_unique_project(name=project_name)
    model_config = config or read_model_config("synthetics/tabular-actgan")

    for name, job in job_dict.items():
        synth_model = project.create_model_obj(
            model_config=model_config,
            data_source=job[data_source_key],
        )
        synth_model.name = f"synthetics-{name}"
        synth_model.submit_cloud()
        job['model_id'] = synth_model.model_id

    print("Status of synthetics training jobs:")
    track_status_greteljob(job_dict,
                           project_name=project_name,
                           model_id_key="model_id")


def gretel_synthetics_run(job_dict,
                          project_name,
                          num_records=None):
    """Generate synthetic records from each trained synthetics model and wait for them.

    For every entry, a record-handler job is submitted on the model stored under
    ``job_dict[key]['model_id']``. Its id is written to
    ``job_dict[key]['record_id']``, so ``job_dict`` is mutated in place.

    Args:
        job_dict: Mapping of job name -> per-job dict.
        project_name: Gretel project containing the synthetics models.
        num_records: Number of records to generate for every model. When falsy,
            each model generates as many records as its own source file had
            (``job_dict[key]['nb_rows']``).
    """
    project = create_or_get_unique_project(name=project_name)

    for key, job in job_dict.items():
        model = project.get_model(job['model_id'])
        # Resolve the count per job without touching `num_records`. Reassigning
        # the parameter would make every later file reuse the first file's row count.
        records_to_generate = num_records or job['nb_rows']
        record_handler = model.create_record_handler_obj(
            params={"num_records": records_to_generate},
        )
        record_handler.submit_cloud()
        job['record_id'] = record_handler.record_id

    print("Status of synthetics run jobs:")
    track_status_greteljob(
        job_dict,
        project_name=project_name,
        model_id_key="model_id",
        record_id_key="record_id"
    )

In [None]:
# Configure a Gretel session

# Gretel project that every transform/synthetics job below is created in.
GRETEL_PROJECT_NAME = 'aws-lambda-gretel-project'

# Configure the Gretel client with the API key read from AWS Secrets Manager.
GRETEL_API_KEY = get_secret()
configure_session(
    api_key=GRETEL_API_KEY,
    cache="yes",
    validate=True,
)

In [None]:
# Use Gretel Transform to de-identify the data

# Both transform steps read the raw S3 file of each entry and use the same project:
# first train one transform model per file, then run it over that same file.
transform_kwargs = dict(
    data_source_key="data_source",
    project_name=GRETEL_PROJECT_NAME,
)

gretel_transform_train(gretel_dict, **transform_kwargs)
gretel_transform_run(gretel_dict, **transform_kwargs)

In [None]:
# Use Gretel's differential-privacy tabular config rather than the ACTGAN
# default of gretel_synthetics_train. Train on each de-identified DataFrame,
# then generate synthetic records from the trained models.
config = read_model_config("synthetics/tabular-differential-privacy")

gretel_synthetics_train(
    gretel_dict,
    project_name=GRETEL_PROJECT_NAME,
    data_source_key="deidentified_data_source",
    config=config,
)

gretel_synthetics_run(gretel_dict, project_name=GRETEL_PROJECT_NAME)

## 3. Write artifacts to the S3 destination bucket

In [None]:
# Save each synthetics model's SQS report (full HTML and JSON summary) plus
# its generated synthetic data to the S3 destination bucket.

s3 = boto3.client('s3')
# Every model lives in the same project, so look it up once, outside the loop.
project = create_or_get_unique_project(name=GRETEL_PROJECT_NAME)

for key, job in gretel_dict.items():
    model = project.get_model(job['model_id'])

    # `open` is smart_open's (imported at the top), which can read the artifact
    # link; the `with` block makes sure the stream is closed.
    with open(model.get_artifact_link("report")) as report_file:
        html_data = report_file.read()
    s3.put_object(
        Body=html_data,
        Bucket=dest_bucket,
        Key=f'{key}_report.html'
    )

    # save SQS report summary
    s3.put_object(
        Body=json.dumps(model.get_report_summary()),
        Bucket=dest_bucket,
        Key=f'{key}_report_summary.json'
    )

    # Synthetic records: gzipped CSV artifact -> `<key>_synth.csv` in the bucket.
    rh = model.get_record_handler(job['record_id'])
    synth_df = pd.read_csv(rh.get_artifact_link("data"), compression="gzip")
    synth_df.to_csv(f's3://{dest_bucket}/{key}_synth.csv', index=False)