In [None]:
#!pip install --upgrade data_repo_client

In [9]:
# Imports
import import_ipynb
import ingest_pipeline_utilities as utils
import data_repo_client
from google.cloud import bigquery
import pandas as pd
pd.set_option('display.max_rows', None)

# TDR Reader Management

## Remove Undesired Readers from TDR Datasets

In [None]:
# Function to remove erroneous readers from snapshot
def clean_up_ad_readers(snapshot_id, readers, max_attempts=1):
    """Remove every member of a snapshot's "reader" policy that is not allowed.

    Args:
        snapshot_id: ID of the TDR snapshot whose "reader" policy is cleaned.
        readers: Emails that must be kept as readers. If the list contains the
            placeholder '$AUTH_DOMAIN', each auth domain recorded on the
            snapshot's first source dataset is also kept, as
            "<auth_domain>@firecloud.org". The list is never modified.
        max_attempts: How many times to try each deletion before giving up on
            that member (default 1, i.e. no retry).

    Returns:
        None. Progress, failures and the remaining readers are printed.
    """
    print("Cleaning up readers for {}...".format(snapshot_id))
    # Work on a copy: appending auth-domain groups to the caller's list would
    # leak one snapshot's groups into the allow-list used for the next snapshot.
    allowed_readers = list(readers)
    api_client = utils.refresh_tdr_api_client()
    snapshots_api = data_repo_client.SnapshotsApi(api_client=api_client)

    # Retrieve snapshot, grab auth_domain
    if '$AUTH_DOMAIN' in allowed_readers:
        snapshot_response = snapshots_api.retrieve_snapshot(id=snapshot_id)
        print("Snapshot name: {}".format(snapshot_response.name))
        try:
            auth_domain_list = snapshot_response.source[0].dataset_properties["auth_domains"]
        except Exception:
            # Best effort: a snapshot without recorded auth domains adds nothing.
            auth_domain_list = []
        allowed_readers.extend(ad + "@firecloud.org" for ad in auth_domain_list or [])

    # Retrieve snapshot policies and delete readers that aren't in reader list
    snapshot_policy_response = snapshots_api.retrieve_snapshot_policies(id=snapshot_id)
    delete_count = 0
    attempts = max(1, max_attempts)
    for policy in snapshot_policy_response.policies:
        if policy.name != "reader":
            continue
        for policymember in policy.members:
            if policymember in allowed_readers:
                continue
            # A fresh client is built before every delete
            # NOTE(review): presumably to avoid an expired token on long runs - confirm.
            api_client = utils.refresh_tdr_api_client()
            snapshots_api = data_repo_client.SnapshotsApi(api_client=api_client)
            last_error = None
            for _ in range(attempts):
                try:
                    snapshots_api.delete_snapshot_policy_member(id=snapshot_id, policy_name="reader", member_email=policymember)
                except Exception as exc:
                    last_error = exc
                else:
                    delete_count += 1
                    last_error = None
                    break
            if last_error is not None:
                # Keep going with the other members, but do not hide the failure.
                print(f"\tFailed to remove reader {policymember}: {last_error}")

    # Print results
    snapshot_policy_response = snapshots_api.retrieve_snapshot_policies(id=snapshot_id)
    print(f"\t{delete_count} erroneous readers deleted.")

    for role in snapshot_policy_response.policies:
        if role.name == "reader":
            rem_readers = ", ".join(role.members)
            print(f"\tRemaining readers: {rem_readers}")

# Clean-up snapshots
# Only the emails listed here are kept as readers; every other reader is removed.
# Add '$AUTH_DOMAIN' to the list to also keep each snapshot's auth-domain groups.
reader_list = ["azul-anvil-prod@firecloud.org"]
snapshot_id_list = ['b0fc6253-d274-4e53-9977-85d943116f7c']
for snapshot_id in snapshot_id_list:
    clean_up_ad_readers(snapshot_id=snapshot_id, readers=reader_list)


## Add Auth Domain Users to TDR Datasets

In [None]:
# Function to add readers to snapshot (not yet implemented)


In [None]:
# Inspect the current access policies (and their members) of a single snapshot
snapshot_id = '5567b767-9fae-4d12-9242-732c2d436eab'
api_client = utils.refresh_tdr_api_client()
snapshots_api = data_repo_client.SnapshotsApi(api_client=api_client)
snapshot_policy_response = snapshots_api.retrieve_snapshot_policies(id=snapshot_id)
# Last expression of the cell, so the notebook renders the policy list
snapshot_policy_response.policies

# Snapshot Row Count Collection

In [None]:
def return_row_counts(snapshot_id, results_list):
    """Append ``[snapshot_id, row_count]`` for one TDR snapshot to ``results_list``.

    The row count is the total number of ``datarepo_row_id`` values across all
    distinct tables of the snapshot, obtained with a single BigQuery query.

    Args:
        snapshot_id: ID of the TDR snapshot to count.
        results_list: List of ``[snapshot_id, row_count]`` pairs; it is extended
            in place with one new pair.

    Returns:
        The same ``results_list`` object. The recorded count is 0 when the
        snapshot cannot be retrieved, has no tables, or the query fails; the
        reason for a failure is printed.
    """
    # Grab access information from schema
    api_client = utils.refresh_tdr_api_client()
    snapshots_api = data_repo_client.SnapshotsApi(api_client=api_client)
    try:
        response = snapshots_api.retrieve_snapshot(id=snapshot_id, include=["TABLES", "ACCESS_INFORMATION"]).to_dict()
        # De-duplicate table names; sort so the generated query is deterministic.
        table_names = sorted({table_entry["name"] for table_entry in response["tables"] or []})
        bq_project = response["access_information"]["big_query"]["project_id"]
        bq_dataset = response["access_information"]["big_query"]["dataset_name"]
    except Exception as exc:
        # `Exception` (not a bare except) so a kernel interrupt still stops the loop.
        print(f"Could not retrieve snapshot {snapshot_id}: {exc}")
        results_list.append([snapshot_id, 0])
        return results_list

    # Nothing to count: skip the query, which would be an invalid `FROM ()`.
    if not table_names:
        results_list.append([snapshot_id, 0])
        return results_list

    # Build row count query
    table_selects = [
        "SELECT datarepo_row_id FROM `{project}.{dataset}.{table}` ".format(project=bq_project, dataset=bq_dataset, table=table_name)
        for table_name in table_names
    ]
    row_count_subquery = "UNION ALL ".join(table_selects)
    row_count_query = "SELECT COUNT(*) AS row_count FROM ({subquery})".format(subquery=row_count_subquery)

    # Execute query and write results to results list
    try:
        client = bigquery.Client()
        df_results = client.query(row_count_query).result().to_dataframe()
        row_count = df_results["row_count"].values[0]
    except Exception as exc:
        print(f"Row count query failed for snapshot {snapshot_id}: {exc}")
        row_count = 0
    results_list.append([snapshot_id, row_count])
    return results_list
    
# Snapshots whose row counts should be collected
snapshot_id_list = ['bb7eaad8-b02c-455c-964d-c9242019d9e5']

# Accumulate one [snapshot_id, row_count] pair per snapshot
results_list = []
for snapshot_id in snapshot_id_list:
    results_list = return_row_counts(snapshot_id, results_list)

# Tabulate the collected pairs and render them in the notebook
results_df = pd.DataFrame(data=results_list, columns=["snapshot_id", "row_count"])
display(results_df)


In [None]:
# Show the row-count table again without re-running the queries
display(results_df)