# Using the Grouping FlatMap in Safetab

### Import Libraries

In [None]:
from pyspark.sql import SparkSession
import pandas as pd

from tmlt.analytics.privacy_budget import PureDPBudget
from tmlt.analytics.session import Session
from tmlt.analytics.query_builder import QueryBuilder

### Load Data and Set Up the Session

This example uses a dummy dataset with QRACE and REGION columns to illustrate a SafeTab-like workflow. A session is first created from this dummy dataset. Then a dataframe that maps each iteration code to DETAILED_ONLY="True" or DETAILED_ONLY="False" (as in SafeTab) is added as a public data source.

In [None]:
spark = SparkSession.builder.getOrCreate()

# Private data: four toy records, each with a QRACE code and a REGION.
sdf = spark.createDataFrame(
    pd.DataFrame(
        {
            "QRACE": ["1", "2", "3", "4"],
            "REGION": ["01", "02", "03", "04"],
        }
    )
)

# Session over the private data, created with the total privacy budget.
budget = PureDPBudget(10)
session = Session.from_dataframe(
    privacy_budget=budget,
    source_id="private",
    dataframe=sdf,
)

# For Safetab -- public lookup table flagging which iteration codes are
# detailed-only. The flags are the strings "True"/"False", not booleans.
detailed_only_df = spark.createDataFrame(
    pd.DataFrame(
        {
            "ITERATION_CODE": ["1", "2", "3", "4"],
            "DETAILED_ONLY": ["True", "False", "True", "False"],
        }
    )
)
session.add_public_dataframe("detail", detailed_only_df)

### Grouping FlatMap

The iteration-code flat map is built by mapping each QRACE code to one or more iteration codes. Here the `grouping` argument is set to `True`. This requires every groupby aggregation on this query to include ITERATION_CODE (the new column) as a groupby column.

Because a view is then created from this query, the view inherits the grouping column, and ITERATION_CODE remains a required groupby column for any query on this view.

In [None]:
# Each QRACE code maps to one or more iteration codes.
race_mapping = {"1": ["1", "2"], "2": ["2", "3"], "3": ["3", "4"], "4": ["4"]}

# max_num_rows caps how many rows the flat map keeps per input row and also
# scales the transformation's stability. Derive it from the mapping instead of
# hard-coding it: a too-small value would silently drop iteration codes if the
# mapping grows, and a too-large one would waste privacy budget.
max_iteration_codes = max(len(codes) for codes in race_mapping.values())

# grouping=True makes ITERATION_CODE a required groupby column for every
# aggregation on this query (and on anything derived from it).
# augment=True keeps the original columns alongside the new ITERATION_CODE.
iteration_code_flatmap = QueryBuilder("private").flat_map(
    lambda row: [
        {"ITERATION_CODE": iter_code} for iter_code in race_mapping[row["QRACE"]]
    ],
    max_num_rows=max_iteration_codes,
    new_column_types={"ITERATION_CODE": "VARCHAR"},
    augment=True,
    grouping=True,
).join_public("detail")

# Create a view with the flatmap applied.
session.create_view(iteration_code_flatmap, "private_with_iteration", cache=True)

print("View Grouping Column:", session.get_grouping_column("private_with_iteration"))

### Partition the Session into DetailedOnly and NonDetailedOnly

When the view is partitioned into detailed_only and non_detailed_only sessions, the private sources of these sessions also retain the required grouping column.

In [None]:
# Split the view on its DETAILED_ONLY column. Each split key also names the
# private source inside the resulting session; `budget` is passed as the
# privacy budget for the partition.
partition_splits = {"detailed_only": "True", "non_detailed_only": "False"}
new_sessions = session.partition_and_create(
    "private_with_iteration",
    privacy_budget=budget,
    attr_name="DETAILED_ONLY",
    splits=partition_splits,
)

detailed_session, non_detailed_session = (
    new_sessions["detailed_only"],
    new_sessions["non_detailed_only"],
)

# Show that each child's private source still carries the grouping column.
for heading, child_session, child_source in (
    ("Detailed Session Grouping Column:", detailed_session, "detailed_only"),
    ("Non-detailed Session Grouping Column:", non_detailed_session, "non_detailed_only"),
):
    print(heading, child_session.get_grouping_column(child_source))

### Do the Detailed Only Query

One of the queries in SafeTab is on the detailed_only table. Because of the grouping flat map, this query must include ITERATION_CODE as a groupby column; omitting it results in an error.

In [None]:
# The detailed-only session only holds rows whose ITERATION_CODE is flagged
# DETAILED_ONLY == "True" in the public lookup table. Take the ITERATION_CODE
# domain from that public table: listing the other codes would only add rows
# whose true count is zero by construction (i.e. pure noise).
detailed_iteration_codes = sorted(
    row["ITERATION_CODE"]
    for row in detailed_only_df.filter(detailed_only_df["DETAILED_ONLY"] == "True")
    .select("ITERATION_CODE")
    .collect()
)

# ITERATION_CODE must be a groupby column because of the grouping flat map;
# leaving it out raises an error.
detailed_only_query = (
    QueryBuilder("detailed_only")
    .groupby_domains(
        {
            "REGION": ["01", "02", "03", "04"],
            "ITERATION_CODE": detailed_iteration_codes,
        }
    )
    .count(name="COUNT")
)

# Spend half of the session's budget on this query.
answer = detailed_session.evaluate(
    query_expr=detailed_only_query,
    privacy_budget=PureDPBudget(budget.epsilon / 2),
)

answer.show()