In [None]:
%%capture
!git clone https://github.com/gretelai/draw-data.git
!pip install -U ./draw-data

## Import the canvas and supporting libraries

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from ipywidgets import HBox, VBox
from sklearn.preprocessing import LabelEncoder

from draw.canvas import get_canvas
from gretel_client import configure_session
from gretel_client.helpers import poll
from gretel_client.projects import create_or_get_unique_project
from gretel_client.projects.models import read_model_config

In this notebook we will create a conditional data generation model that can generate synthetic data based on a sketch we draw on a canvas. To start, we configure our Gretel session: have your Gretel API key ready and paste it when prompted.

In [None]:
# Ask for the Gretel API key interactively rather than hard-coding it.
# cache="yes" presumably stores the key for later sessions and validate=True
# presumably checks it with the service before continuing -- confirm in the
# gretel_client docs.
configure_session(
    api_key="prompt",
    cache="yes",
    validate=True,
)

### We'll pull an example shopping dataset and subsample it to keep training time short

In [None]:
# Example e-commerce dataset used to train the seed model.
DATASET_URL = "https://gretel-public-website.s3.us-west-2.amazonaws.com/datasets/E-commerce+Dataset.csv"
SAMPLE_SIZE = 5000  # rows kept for training; fewer rows = faster training
RANDOM_STATE = 42   # fixed seed so the subsample is reproducible across runs

raw_df = pd.read_csv(DATASET_URL)

# Clean, then subsample. Never request more rows than remain after cleaning:
# DataFrame.sample(n=...) without replacement raises ValueError in that case.
train_df = (
    raw_df
    .dropna()
    .drop_duplicates()
    .pipe(lambda df: df.sample(n=min(SAMPLE_SIZE, len(df)), random_state=RANDOM_STATE))
    .reset_index(drop=True)
)

In [None]:
# Peek at the first few rows of the subsampled training data.
train_df.head(n=5)

## Setup for data editing

#### We seed on one column at a time: choose the column below and set `is_categorical` to indicate whether it holds categorical values

Here we set up some parameters about our data to help the canvas "look nice" - you can leave these as their default values or change them for your dataset. 

Feel free to give feedback on anything that isn't intuitive! We'd love to fix it.

In [None]:
# Choose the column to seed on and whether it holds categorical values.
is_categorical = True
if is_categorical:
    column = 'Product_Category'
else:
    # Note, numerical seeding often works poorly
    column = 'Sales'

# Passed to get_canvas; presumably the number of seed records generated from
# the drawing -- confirm against draw.canvas.
data_to_sample = 5000

original_column = column
if is_categorical:
    # Encode categories as integer codes so the x-axis has a numeric domain
    # (domain_min / domain_max below are computed on the codes). The fitted
    # encoder is handed to get_canvas along with the codes column name.
    le = LabelEncoder()
    column = f"{original_column}_codes"
    train_df[column] = le.fit_transform(train_df[original_column].values)
else:
    # No encoder is needed for a numeric column, but `le` must still exist:
    # the get_canvas(...) call later in the notebook passes it unconditionally
    # and would otherwise fail with NameError.
    le = None

# set desired x limits
domain_min = train_df[column].min()
domain_max = train_df[column].max()

if is_categorical:
    train_df[original_column].value_counts().plot.bar()
    plt.xticks(rotation=25)
else:
    sns.displot(train_df[original_column])

plt.xlabel(original_column)

# Only the shape of the distribution matters here, so hide the y axis.
plt.gca().yaxis.set_visible(False)

# get y axis limits from plot
y_lim = plt.gca().get_ylim()

# Category order as shown on the plot; reused by the canvas and the later
# comparison plots.
x_labels = [tick.get_text() for tick in plt.gca().get_xticklabels()]
plt.show()

## We then train a Gretel seed model - this should take a few minutes

In [None]:
# Create (or reuse) the Gretel project that will own the model.
project = create_or_get_unique_project(name="draw-your-own-data")

# Start from the default synthetics config and switch it to a seed task on a
# single column, so generation can later be conditioned on that column.
config = read_model_config("synthetics/default")
fields = [original_column]  # just seeding on one column as an example
task = dict(type="seed", attrs=dict(fields=fields))
config["models"][0]["synthetics"]["task"] = task

# The model is created from a file on disk, so persist the training frame.
# NOTE(review): when is_categorical is True, train_df also carries the helper
# `<column>_codes` column added in the setup cell, so it is written to
# train.csv too -- confirm the extra column is intended.
train_df.to_csv("train.csv", index=False)
model = project.create_model_obj(model_config=config, data_source="train.csv")

# Train remotely; poll presumably waits until the job finishes.
model.submit_cloud()
poll(model)

# The preview artifact is read as a gzip-compressed CSV.
synthetic = pd.read_csv(
    model.get_artifact_link("data_preview"),
    compression="gzip",
)
synthetic.head()

## We look at the synthetic data

Here we see that the synthetic data distribution matches the original distribution very well.

In [None]:
# Plot the synthetic preview the same way as the training data so the two
# distributions can be compared directly.
if is_categorical:
    # Use the training plot's category order (x_labels) so bars line up with
    # the earlier plot; categories absent from the preview appear as
    # zero-height bars instead of silently disappearing.
    synthetic[original_column].value_counts().reindex(x_labels, fill_value=0).plot.bar()
    plt.xticks(rotation=25)
else:
    sns.displot(synthetic[original_column])

plt.xlabel(original_column)

# turn off y axis
plt.gca().yaxis.set_visible(False)

plt.show()

## Draw your own data

The main point of this notebook is below. We can now change the distribution of our seed column to look however we want, just by drawing. There are many applications for this - imagine editing how data changes over the months of the year, or balancing a dataset between various classes.

### Click 'Download' when you're happy with your drawing

In [None]:
# Build the drawing canvas for the seed column plus its Clear / Download buttons.
canvas, clear_button, download_button = get_canvas(
    train_df,
    column,
    original_column,
    x_labels,
    is_categorical,
    domain_min,
    domain_max,
    data_to_sample,
    le,
)

# Canvas on the left, buttons stacked vertically on the right.
button_column = VBox((clear_button, download_button))
HBox((canvas, button_column))

## Conditionally generate data using the "downloaded" drawn data above

In [None]:
# Use the model to generate additional synthetic data.
import os
assert os.path.exists("seeds.csv") == True
seeds = pd.read_csv("seeds.csv")

rh = model.create_record_handler_obj(
    data_source="seeds.csv", params={"num_records": len(seeds)}
)
rh.submit_cloud()

poll(rh)

synthetic_next = pd.read_csv(rh.get_artifact_link("data"), compression="gzip")
synthetic_next

Plot the distribution of the newly generated data - it should look like what you drew!

In [None]:
# Compare the conditionally generated data against the drawing.
if is_categorical:
    # Keep the training plot's category order. reindex (unlike .loc[x_labels])
    # tolerates categories that received zero generated records -- e.g. one
    # drawn at zero height -- and plots them as empty bars instead of raising
    # KeyError.
    generated_counts = synthetic_next[original_column].value_counts().reindex(x_labels, fill_value=0)
    generated_counts.plot.bar()

    plt.xticks(rotation=25)
else:
    sns.displot(synthetic_next[original_column])

plt.xlabel(original_column)

# turn off y axis
plt.gca().yaxis.set_visible(False)

plt.show()

Now we have data where one column follows the data distribution we drew above! Enjoy and share!

In [None]:
# First few rows of the data generated from the drawn distribution.
synthetic_next.head(n=5)