In [None]:
%additional_python_modules backoff,aws-requests-auth


In [None]:
from datetime import datetime
from typing import Union, List

from pyspark.sql import DataFrame
from pyspark.sql.session import SparkSession
from pyspark.sql.functions import col, from_unixtime


class LabelsLogLoader(object):
    r"""
    Loads label logs and groups them into the set of labeled counterpart ids per applicant.

    Filters are configured with the ``filter_*`` methods and applied when :meth:`load` is called.
    """

    def __init__(self, spark: SparkSession, log_path: str):
        r"""
        :param spark:    Spark session used to read the logs
        :param log_path: Path to label logs; read as CSV with a header row
        """
        self.__spark = spark
        self.__log_path = log_path
        self.__filter_source_types = None
        self.__filter_source_ids = None
        self.__filter_labels = None
        self.__filter_applicants = None
        self.__filter_duplicates = None
        self.__filter_before_timestamp = None
        self.__filter_after_timestamp = None
        self.__drop_duplicates = None

    @staticmethod
    def _as_list(value):
        r"""
        Normalises filter input: a single value is wrapped into a list, lists and None are returned unchanged

        :param value: Single value, list of values or None
        :return: List of values, or None when the filter should be unset
        """
        if value is None or isinstance(value, list):
            return value
        return [value]

    @staticmethod
    def _isin(column_name: str, values: list, match_null: bool):
        r"""
        Builds a `column IN (values)` condition, ignoring None entries in `values`

        :param column_name: Name of column to test
        :param values:      Accepted values (may contain None)
        :param match_null:  If True and None is among `values`, rows where the column is null match too
        :return: Column condition
        """
        condition = col(column_name).isin([value for value in values if value is not None])
        if match_null and None in values:
            condition = condition | col(column_name).isNull()
        return condition

    def filter_by_source_type(self, source_types: Union[Union[str, None], List[Union[str, None]]]):
        r"""
        Filters by source types

        :param source_types: Source type or list of source types to filter by. None is also supported.
                             To search only when no source type was specified, pass a list with None inside.
                             Passing just None (outside of list) will unset the filter
        """
        self.__filter_source_types = self._as_list(source_types)

    def filter_by_source_id(self, source_ids: Union[Union[str, None], List[Union[str, None]]]):
        r"""
        Filters by source ids

        :param source_ids: Source id or list of source ids to filter by. None is also supported.
                           To search only when no source id was specified, pass a list with None inside.
                           Passing just None (outside of list) will unset the filter
        """
        self.__filter_source_ids = self._as_list(source_ids)

    def filter_by_label(self, labels: Union[Union[str, None], List[Union[str, None]]]):
        r"""
        Filters by labels

        :param labels: Label or list of labels to filter by. None is also supported.
                       To search only when no label was specified, pass a list with None inside.
                       Passing just None (outside of list) will unset the filter
        """
        self.__filter_labels = self._as_list(labels)

    def filter_by_applicant(self, ids: Union[str, List[str], None]):
        r"""
        Filters by applicant being evaluated (labeled), i.e. by column `applicantId1`

        :param ids: Id or list of ids to filter by.
                    Passing just None (outside of list) will unset the filter
        """
        self.__filter_applicants = self._as_list(ids)

    def filter_by_duplicate(self, ids: Union[str, List[str], None]):
        r"""
        Filters by duplicate being matched (labeled), i.e. by column `applicantId2`

        :param ids: Id or list of ids to filter by.
                    Passing just None (outside of list) will unset the filter
        """
        self.__filter_duplicates = self._as_list(ids)

    def __to_timestamp(self, timestamp: Union[None, int, datetime]) -> Union[None, int]:
        r"""
        Converts input to valid timestamp value for filters

        :param timestamp: Timestamp input; datetime is converted to epoch milliseconds,
                          anything else is passed through unchanged
        :return: Timestamp output
        """
        if isinstance(timestamp, datetime):
            return round(timestamp.timestamp() * 1000)
        return timestamp

    def filter_after_timestamp(self, timestamp: Union[None, int, datetime]):
        r"""
        Filters logs that occurred at or after specific timestamp

        :param timestamp: Timestamp or None to unset
        """
        self.__filter_after_timestamp = self.__to_timestamp(timestamp)

    def filter_before_timestamp(self, timestamp: Union[None, int, datetime]):
        r"""
        Filters logs that occurred at or before specific timestamp

        :param timestamp: Timestamp or None to unset
        """
        self.__filter_before_timestamp = self.__to_timestamp(timestamp)

    def drop_duplicates(self, drop_duplicates: bool = True):
        r"""
        Should duplicates be dropped.

        NOTE: `load` groups rows by applicant, so its output already holds a single row per applicant
        whether or not this flag is set.

        :param drop_duplicates: If True (or non specified) will drop duplicates. If False will ignore this behaviour.
        """
        self.__drop_duplicates = drop_duplicates

    def load(self, date_field: bool = False, time_field: bool = False) -> DataFrame:
        r"""
        Loads logs, applies filters and groups labeled pairs per applicant

        Every labeled pair (applicantId1, applicantId2) is counted in both directions, so each id that
        appears on either side of a pair gets a row listing all ids it was paired with.

        Returned DataFrame has following schema
        root
         |-- applicantId: string (nullable = true)
         |-- duplicateIds: array (nullable = true)
         |    |-- element: string
         |-- timestamp: long (nullable = true)

        `timestamp` is the latest timestamp among the label entries the applicant appears in.

        If `date_field` is set to true then field representing date from timestamp in format `yyyy-MM-dd` is added
          |-- date: string (nullable = true)

        If `time_field` is set to true then field representing time from timestamp in format `HH:mm:ss` is added
          |-- time: string (nullable = true)

        :param date_field: Should field `date` be added
        :param time_field: Should field `time` be added

        :return: Dataframe of grouped labels
        """
        # Imported here because the cell-level imports do not provide the aggregate functions
        from pyspark.sql.functions import collect_set, max as spark_max

        conditions = []
        if self.__filter_source_types is not None:
            conditions.append(self._isin('sourceType', self.__filter_source_types, match_null=True))
        if self.__filter_source_ids is not None:
            conditions.append(self._isin('sourceId', self.__filter_source_ids, match_null=True))
        if self.__filter_labels is not None:
            conditions.append(self._isin('label', self.__filter_labels, match_null=True))
        if self.__filter_applicants is not None:
            conditions.append(self._isin('applicantId1', self.__filter_applicants, match_null=False))
        if self.__filter_duplicates is not None:
            conditions.append(self._isin('applicantId2', self.__filter_duplicates, match_null=False))
        if self.__filter_before_timestamp is not None:
            conditions.append(col('timestamp') <= self.__filter_before_timestamp)
        if self.__filter_after_timestamp is not None:
            conditions.append(col('timestamp') >= self.__filter_after_timestamp)

        results = self.__spark.read.csv(self.__log_path, header=True).sort('timestamp', ascending=False)
        # Successive filters are equivalent to AND-ing all conditions together
        for condition in conditions:
            results = results.filter(condition)

        # CSV columns are read as strings; cast so that the latest timestamp is picked numerically
        timestamp_ms = col('timestamp').cast('long').alias('timestamp')
        pairs = results.select(col('applicantId1'), col('applicantId2'), timestamp_ms)
        mirrored = results.select(
            col('applicantId2').alias('applicantId1'),
            col('applicantId1').alias('applicantId2'),
            timestamp_ms
        )
        # Columns are matched by name: a positional union of a merely renamed frame would just
        # repeat the original rows instead of adding the swapped pairs
        results = pairs.unionByName(mirrored)

        # The latest timestamp is carried through the aggregation so `date`/`time` can be derived below
        results = results.groupBy('applicantId1').agg(
            collect_set('applicantId2').alias('duplicateIds'),
            spark_max('timestamp').alias('timestamp')
        )

        if self.__drop_duplicates:
            # Grouping already leaves one row per applicant; kept so the flag retains its code path
            results = results.dropDuplicates(['applicantId1'])
        results = results.withColumnRenamed('applicantId1', 'applicantId')

        if date_field:
            results = results.withColumn('date', from_unixtime(col('timestamp')/1000, 'yyyy-MM-dd'))

        if time_field:
            results = results.withColumn('time', from_unixtime(col('timestamp')/1000, 'HH:mm:ss'))

        return results



In [None]:
from typing import Union, List
from pyspark.sql.session import SparkSession
from pyspark.sql.functions import col, from_json, size, explode_outer as pyspark_explode, from_unixtime
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType, StructField, ArrayType, StringType, DoubleType
from datetime import datetime


class QueryLogLoader:
    r"""
    Tool to help parse and filter detection call logs
    """

    def __init__(self, spark: SparkSession, logs_path: str):
        r"""

        :param spark: Spark session
        :param logs_path:  Path to logs (read as JSON)
        """
        self.__spark = spark
        self.__logs_path = logs_path
        self.__filter_applicants = None
        self.__filter_max_results = None
        self.__filter_metric_target = None
        self.__filter_metric = None
        self.__filter_models = None
        self.__filter_has_matches = None
        self.__filter_after_timestamp = None
        self.__filter_before_timestamp = None

    @staticmethod
    def _as_list(value):
        r"""
        Normalises filter input: a single value is wrapped into a list, lists and None are returned unchanged

        :param value: Single value, list of values or None
        :return: List of values, or None when the filter should be unset
        """
        if value is None or isinstance(value, list):
            return value
        return [value]

    @staticmethod
    def _isin(column_name: str, values: list, match_null: bool):
        r"""
        Builds a `column IN (values)` condition, ignoring None entries in `values`

        :param column_name: Name of column to test
        :param values:      Accepted values (may contain None)
        :param match_null:  If True and None is among `values`, rows where the column is null match too
        :return: Column condition
        """
        condition = col(column_name).isin([value for value in values if value is not None])
        if match_null and None in values:
            condition = condition | col(column_name).isNull()
        return condition

    def filter_by_applicants(self, applicants: Union[Union[str, None], List[Union[str, None]]]):
        r"""
        Filters by requested applicants

        :param applicants: Applicant or list of applicants to filter by. None is also supported.
                           To search only when no applicant was specified, pass a list with None inside.
                           Passing just None (outside of list) will unset the filter
        """
        self.__filter_applicants = self._as_list(applicants)

    def filter_by_max_results(self, max_results: Union[int, List[int], None] = None):
        r"""
        Filters by requested max results

        :param max_results: Max results or list of max results to filter by.
                            Passing just None will unset the filter.
        """
        self.__filter_max_results = self._as_list(max_results)

    def filter_by_metric_target(self, metric_target: Union[float, List[float], None] = None):
        r"""
        Filters by requested metric targets

        :param metric_target: Metric target or list of metric targets to filter by.
                              Passing just None will unset the filter.
        """
        self.__filter_metric_target = self._as_list(metric_target)

    def filter_by_metric(self, metric: Union[str, List[str], None] = None):
        r"""
        Filters by requested metrics

        :param metric: Metric or list of metrics to filter by.
                       Passing just None will unset the filter.
        """
        self.__filter_metric = self._as_list(metric)

    def filter_by_models(self, models: Union[Union[str, None], List[Union[str, None]]]):
        r"""
        Filters by requested models

        :param models: Model or list of models to filter by. None is also supported.
                       To search only when no model was specified, pass a list with None inside.
                       Passing just None (outside of list) will unset the filter
        """
        self.__filter_models = self._as_list(models)

    def filter_by_has_matches(self, has_matches: Union[None, bool]):
        r"""
        Filter based on did response have matches

        :param has_matches: True, False to set filter, None to unset filter
        :return:
        """
        self.__filter_has_matches = has_matches

    def __to_timestamp(self, timestamp: Union[None, int, datetime]) -> Union[None, int]:
        r"""
        Converts input to valid timestamp value for filters

        :param timestamp: Timestamp input; datetime is converted to epoch milliseconds,
                          anything else is passed through unchanged
        :return: Timestamp output
        """
        if isinstance(timestamp, datetime):
            return round(timestamp.timestamp() * 1000)
        return timestamp

    def filter_after_timestamp(self, timestamp: Union[None, int, datetime]):
        r"""
        Filters logs that occurred at or after specific timestamp

        :param timestamp: Timestamp or None to unset
        """
        self.__filter_after_timestamp = self.__to_timestamp(timestamp)

    def filter_before_timestamp(self, timestamp: Union[None, int, datetime]):
        r"""
        Filters logs that occurred at or before specific timestamp

        :param timestamp: Timestamp or None to unset
        """
        self.__filter_before_timestamp = self.__to_timestamp(timestamp)

    def load(self, explode: bool = False, date_field: bool = False, time_field: bool = False) -> DataFrame:
        r"""
        Loads logs and applies filters

        The request is always separated into columns. With `explode` left False the matches of a
        response stay together in one array column, so the returned DataFrame has following schema
        root
         |-- model: string (nullable = true)
         |-- applicantId: string (nullable = true)
         |-- maxResults: long (nullable = true)
         |-- metric: string (nullable = true)
         |-- metricTarget: double (nullable = true)
         |-- match: array (nullable = true)
         |    |-- element: struct (containsNull = true)
         |    |    |-- duplicateId: string (nullable = true)
         |    |    |-- probability: double (nullable = true)
         |-- timestamp: long (nullable = true)

         If `explode` is set to True, then each of the matches produces new row, also separated
         into columns. DataFrame schema is then as follows:
         root
          |-- model: string (nullable = true)
          |-- applicantId: string (nullable = true)
          |-- maxResults: long (nullable = true)
          |-- metric: string (nullable = true)
          |-- metricTarget: double (nullable = true)
          |-- duplicateId: string (nullable = true)
          |-- probability: double (nullable = true)
          |-- timestamp: long (nullable = true)
        For cases where no matches were found, explode will create a row where duplicateId and probability are null

        If `date_field` is set to true then field representing date from timestamp in format `yyyy-MM-dd` is added
          |-- date: string (nullable = true)

        If `time_field` is set to true then field representing time from timestamp in format `HH:mm:ss` is added
          |-- time: string (nullable = true)

        :param explode:    Should rows be exploded (description above)
        :param date_field: Should field `date` be added
        :param time_field: Should field `time` be added

        :return: Dataframe of logs
        """

        conditions = []
        if self.__filter_applicants is not None:
            conditions.append(self._isin('smRequest.applicantId', self.__filter_applicants, match_null=True))
        if self.__filter_max_results is not None:
            conditions.append(self._isin('smRequest.maxResults', self.__filter_max_results, match_null=False))
        if self.__filter_metric_target is not None:
            conditions.append(self._isin('smRequest.metricTarget', self.__filter_metric_target, match_null=False))
        if self.__filter_metric is not None:
            conditions.append(self._isin('smRequest.metric', self.__filter_metric, match_null=False))
        if self.__filter_models is not None:
            conditions.append(self._isin('model', self.__filter_models, match_null=True))
        if self.__filter_has_matches is not None:
            conditions.append(
                (size('smResponse.matches') > 0) if self.__filter_has_matches else (size('smResponse.matches') == 0)
            )
        if self.__filter_before_timestamp is not None:
            conditions.append(col('timestamp') <= self.__filter_before_timestamp)
        if self.__filter_after_timestamp is not None:
            conditions.append(col('timestamp') >= self.__filter_after_timestamp)

        # Response is JSON stored as string. This needs to be parsed.
        response_schema = StructType([
            StructField(
                name='matches',
                nullable=False,
                dataType=ArrayType(
                    elementType=StructType(
                        [
                            StructField(name='duplicateId', dataType=StringType(), nullable=False),
                            StructField(name='probability', dataType=DoubleType(), nullable=False)
                        ]
                    ),
                    containsNull=False
                )
            )
        ])

        results = self.__spark.read.json(path=self.__logs_path) \
            .withColumn('smResponse', from_json(col('smResponse'), response_schema))
        # Successive filters are equivalent to AND-ing all conditions together
        for condition in conditions:
            results = results.filter(condition)

        request_columns = [
            col('model'),
            col('smRequest.applicantId').alias('applicantId'),
            col('smRequest.maxResults').alias('maxResults'),
            col('smRequest.metric').alias('metric'),
            col('smRequest.metricTarget').alias('metricTarget'),
        ]
        if explode:
            # Explode first, then unpack the match struct: `smResponse` is not available after the first select
            results = results.select(
                *request_columns,
                pyspark_explode(col('smResponse.matches')).alias('match'),
                col('timestamp')
            ).select(
                'model', 'applicantId', 'maxResults', 'metric', 'metricTarget',
                col('match.duplicateId').alias('duplicateId'),
                col('match.probability').alias('probability'),
                'timestamp'
            )
        else:
            results = results.select(
                *request_columns,
                col('smResponse.matches').alias('match'),
                col('timestamp')
            )

        if date_field:
            results = results.withColumn('date', from_unixtime(col('timestamp')/1000, 'yyyy-MM-dd'))

        if time_field:
            results = results.withColumn('time', from_unixtime(col('timestamp')/1000, 'HH:mm:ss'))

        return results

In [None]:
from datetime import datetime
from typing import Union

class LabeledRequestsSearcher(object):
    r"""
    Finds detection requests for which a label was recorded shortly after the request was made
    """

    def __init__(self, query_log_loader: QueryLogLoader, labels_log_loader: LabelsLogLoader):
        r"""
        :param query_log_loader:  Loader of detection (query) logs; its filters are overwritten by `find_requests`
        :param labels_log_loader: Loader of label logs; its filters are overwritten by `find_requests`
        """
        self.__query_log_loader = query_log_loader
        self.__labels_log_loader = labels_log_loader

    def find_requests(self,
                      since: datetime,
                      until: datetime,
                      max_results: int = 5,
                      metric: str = "PRECISION",
                      metric_target: float = 0.85,
                      models: Union[str, None] = None,
                      source_id: str = "ALTEREGO_ON_HIRE",
                      source_type: str = "VERIFIER",
                      labels: Union[str, list, None] = None,
                      minutes_delta: int = 5):
        r"""
        Finds requests that were followed by a label for the same applicant

        A request and a label are paired when they share the applicant id and the label was recorded
        strictly after the request, less than `minutes_delta` minutes later.

        :param since:         Only logs at or after this time are considered (None for no lower bound)
        :param until:         Only logs at or before this time are considered (None for no upper bound)
        :param max_results:   Requested max results to filter query logs by
        :param metric:        Requested metric to filter query logs by
        :param metric_target: Requested metric target to filter query logs by
        :param models:        Model(s) to filter query logs by; None selects requests without a model
        :param source_id:     Source id to filter label logs by
        :param source_type:   Source type to filter label logs by
        :param labels:        Label(s) to filter label logs by; None selects DUPLICATE and NON_DUPLICATE
        :param minutes_delta: Maximum time between request and label, in minutes
        :return: DataFrame of matching requests, one row per
                 (applicantId, maxResults, metric, metricTarget) combination
        """
        # Setup Query log loader options
        self.__query_log_loader.filter_by_metric(metric)
        self.__query_log_loader.filter_by_metric_target(metric_target)
        self.__query_log_loader.filter_by_applicants(None)
        self.__query_log_loader.filter_by_max_results(max_results)
        self.__query_log_loader.filter_by_has_matches(True)
        self.__query_log_loader.filter_by_models([None] if models is None else models)
        self.__query_log_loader.filter_before_timestamp(until)
        self.__query_log_loader.filter_after_timestamp(since)

        # Setup Label log loader options
        self.__labels_log_loader.filter_by_applicant(None)
        self.__labels_log_loader.filter_by_duplicate(None)
        self.__labels_log_loader.filter_by_source_type(source_type)
        self.__labels_log_loader.filter_by_source_id(source_id)
        self.__labels_log_loader.filter_by_label(['DUPLICATE', 'NON_DUPLICATE'] if labels is None else labels)
        self.__labels_log_loader.filter_before_timestamp(until)
        self.__labels_log_loader.filter_after_timestamp(since)
        self.__labels_log_loader.drop_duplicates()

        query_logs = self.__query_log_loader.load(explode=False).cache()
        label_logs = self.__labels_log_loader.load().cache()

        # Timestamps are in milliseconds, hence the 1000 * 60 factor for minutes
        joined = label_logs.join(
            query_logs,
            (query_logs.applicantId == label_logs.applicantId)
            & (query_logs.timestamp < label_logs.timestamp)
            & (label_logs.timestamp - query_logs.timestamp < 1000 * 60 * minutes_delta),
            'inner')

        # Columns to remove from the result, per side of the join. Columns are dropped through the
        # frame they came from so that same-named columns of the other side are left alone; a column
        # a loader did not produce is simply skipped.
        # NOTE: `applicantId` of both sides is deliberately kept.
        label_only = ('timestamp', 'recordId', 'employeeId', 'sourceType', 'sourceId', 'label')
        query_only = ('model', 'probability', 'timestamp', 'duplicateId')
        for frame, names in ((label_logs, label_only), (query_logs, query_only)):
            for name in names:
                if name in frame.columns:
                    joined = joined.drop(frame[name])

        return joined.dropDuplicates(['applicantId', 'maxResults', 'metric', 'metricTarget'])

In [None]:
from datetime import datetime, timedelta

import backoff
import requests
from aws_requests_auth.aws_auth import AWSRequestsAuth
from botocore.client import BaseClient
from requests.auth import AuthBase


class AuthenticatedLambdaClient(object):
    r"""
    HTTP client that signs requests for the `execute-api` service with credentials
    obtained by assuming a role through STS
    """

    def __init__(
            self,
            sts_client: BaseClient,
            sts_role: str,
            sts_session: str,
            hostname: str,
            region: str,
            stage: str,
            token_life: int = 900,
            token_buffer_time: int = 60):
        r"""
        :param sts_client:        STS client used to assume the role
        :param sts_role:          ARN of the role to assume
        :param sts_session:       Session name used when assuming the role
        :param hostname:          API hostname
        :param region:            AWS region used for request signing
        :param stage:             API stage, first path segment of every request
        :param token_life:        Requested lifetime of assumed credentials, in seconds
        :param token_buffer_time: Credentials this close (in seconds) to expiry are renewed
        """
        self.__sts_client = sts_client
        self.__sts_role = sts_role
        self.__sts_session = sts_session
        self.__hostname = hostname
        self.__region = region
        self.__stage = stage
        self.__token_life = token_life
        self.__token_buffer_time = token_buffer_time
        # Last assume_role response; None until the first request needs credentials
        self.__token = None

    def __is_token_valid(self) -> bool:
        r"""
        Checks if token is still valid (not close to expiry or already expired)

        :return: True if token is valid, False otherwise
        """
        if self.__token is None:
            return False
        expiration = self.__token['Credentials']['Expiration']
        # "Now" is taken in the timezone of the expiration so both values are comparable
        remaining = expiration - datetime.now(expiration.tzinfo)
        return remaining > timedelta(seconds=self.__token_buffer_time)

    def __get_token(self) -> dict:
        r"""
        Gets valid (non expired / close to expiry) token, assuming the role again when needed

        :return: Valid token
        """
        if self.__is_token_valid():
            return self.__token
        self.__token = self.__sts_client.assume_role(
            RoleArn=self.__sts_role,
            RoleSessionName=self.__sts_session,
            DurationSeconds=self.__token_life
        )
        return self.__token

    def __get_auth(self) -> AWSRequestsAuth:
        r"""
        Creates authenticator for the call from the current token
        """
        issued = self.__get_token()['Credentials']
        return AWSRequestsAuth(
            aws_access_key=issued['AccessKeyId'],
            aws_secret_access_key=issued['SecretAccessKey'],
            aws_token=issued['SessionToken'],
            aws_host=self.__hostname,
            aws_region=self.__region,
            aws_service='execute-api'
        )

    def __create_uri(self, path: str) -> str:
        r"""
        Builds full request URI from API path; a single leading slash of `path` is ignored
        """
        if path.startswith('/'):
            path = path[1:]
        return f'https://{self.__hostname}/{self.__stage}/{path}'

    @backoff.on_exception(backoff.expo, requests.exceptions.ConnectionError, jitter=backoff.full_jitter, max_tries=5)
    def request(self, method: str, path: str, auth: AuthBase = None, **kwargs) -> requests.Response:
        r"""
        Submits an API request to lambda. Connection errors are retried with exponential backoff,
        up to 5 tries.

        :param method: Request method
        :param path:   API path (not including hostname and stage)
        :param auth:   Optional authentication, STS token if not specified
        :param kwargs: Other requests.request parameters
        :return: Response of the call
        """
        target = self.__create_uri(path)
        signer = auth if auth is not None else self.__get_auth()
        return requests.request(method=method, url=target, auth=signer, **kwargs)

    def post(self, path: str, auth: AuthBase = None, **kwargs) -> requests.Response:
        r"""
        Submits a POST request to lambda

        :param path:   API path (not including hostname and stage)
        :param auth:   Optional authentication, STS token if not specified
        :param kwargs: Other requests.request parameters
        :return: Response of the call
        """
        return self.request('POST', path, auth, **kwargs)

    def get(self, path: str, auth: AuthBase = None, **kwargs) -> requests.Response:
        r"""
        Submits a GET request to lambda

        :param path:   API path (not including hostname and stage)
        :param auth:   Optional authentication, STS token if not specified
        :param kwargs: Other requests.request parameters
        :return: Response of the call
        """
        return self.request('GET', path, auth, **kwargs)

    

class AlterEgoClient(object):
    r"""
    Client of the duplicate query API
    """

    def __init__(self, lambda_client: AuthenticatedLambdaClient):
        r"""
        :param lambda_client: Client used to submit the API calls
        """
        self.__lambda_client = lambda_client

    def query_duplicates(self, applicant_id: str, metric: str, metric_target: float, max_results: int, model: str = None, applicant_namespace: str = 'ICIMS', results_namespace: str = 'ICIMS') -> dict:
        r"""
        Queries duplicates of an applicant by POSTing to `/duplicate/query`

        :param applicant_id:        Id of applicant to find duplicates for
        :param metric:              Requested metric
        :param metric_target:       Requested metric target
        :param max_results:         Requested maximum number of results
        :param model:               Optional model; left out of the request when None
        :param applicant_namespace: Namespace of the applicant id
        :param results_namespace:   Namespace of the results
        :return: Parsed JSON response on status 200, otherwise a dict with single key `error`
        """
        applicant = {
            'namespace': applicant_namespace,
            'id': applicant_id
        }
        request = {
            'applicantId': applicant,
            'metric': metric,
            'metricTarget': metric_target,
            'resultsNamespace': results_namespace,
            'maxResults': max_results
        }
        if model is not None:
            request['model'] = model
        print(request)

        response = self.__lambda_client.post(path='/duplicate/query', json=request)
        if response.status_code == 200:
            return response.json()
        return {"error":f"{response.status_code} {response.text} error"}


In [None]:
import json
import sys
from datetime import datetime
from typing import Union, Optional

import dateutil
from boto3 import client
from pyspark import SparkContext
from pyspark.sql.types import Row

class LabeledModelReplayJob(object):
    r"""
    Replays previously logged, labeled requests against a chosen model and prints the outcome
    """

    def __init__(self, alter_ego_client: AlterEgoClient, labeled_requests_searcher: LabeledRequestsSearcher):
        r"""
        :param alter_ego_client:          Client used to re-submit the requests
        :param labeled_requests_searcher: Searcher used to find the requests to replay
        """
        self.__alter_ego_client = alter_ego_client
        self.__labeled_requests_searcher = labeled_requests_searcher
        # Requests to replay; None until loaded by get_replay_data() (or by replay() on demand)
        self.requests = None

    def __handle_row(self, row: Row, model: str, current: int, total: int):
        r"""
        Re-submits a single logged request using `model` and prints the request together with its result

        :param row:     Logged request; applicantId, metric, metricTarget and maxResults are read from it
        :param model:   Model to replay the request with
        :param current: 1-based position of this request in the replay
        :param total:   Total number of requests found
        """
        result = self.__alter_ego_client.query_duplicates(
            applicant_id=row.applicantId,
            metric=row.metric,
            metric_target=row.metricTarget,
            max_results=row.maxResults,
            model=model
        )
        data = {
            'total': total,
            'current': current,
            'input': {
                'applicantId': row.applicantId,
                'metric': row.metric,
                'metricTarget': row.metricTarget,
                'maxResults': row.maxResults,
                'model': model
            },
            'output': result
        }
        print(row.asDict())

        # TODO: compare `row.duplicateIds` (labeled ground truth) against the model output
        print(f'Labeled Replay: {json.dumps(data)}')

    def get_replay_data(self, input_models: Union[str, None], since: datetime, until: datetime):
        r"""
        Finds the labeled requests to replay and keeps them (cached) in `self.requests`

        :param input_models: Model the original requests were made with, None for requests without a model
        :param since:        Lower time bound of requests, None for no bound
        :param until:        Upper time bound of requests, None for no bound
        """
        self.requests = self.__labeled_requests_searcher.find_requests(since=since, until=until, models=input_models).cache()

    def replay(self,
               input_models: Union[str, None],
               replay_model: str,
               since: Optional[datetime] = None,
               until: Optional[datetime] = None,
               max_rows: Optional[int] = 1):
        r"""
        Replays found requests against `replay_model`

        Requests are (re)loaded through `get_replay_data` when a time bound is given here or when
        nothing was loaded yet; otherwise the previously loaded requests are reused.

        :param input_models: Model the original requests were made with (used when requests are loaded here)
        :param replay_model: Model to replay the requests with
        :param since:        Optional lower time bound of requests
        :param until:        Optional upper time bound of requests
        :param max_rows:     Maximum number of requests to replay (non-negative), None to replay all of them.
                             NOTE: defaults to 1, which is what this job did before the limit became a parameter
        """
        if self.requests is None or since is not None or until is not None:
            self.get_replay_data(input_models, since, until)

        # Single collect; the number of rows found doubles as the reported total
        rows = self.requests.collect()
        total = len(rows)
        selected = rows if max_rows is None else rows[:max_rows]
        for current, row in enumerate(selected, start=1):
            self.__handle_row(row=row, model=replay_model, current=current, total=total)


def parse_time(in_time: str) -> Union[datetime, None]:
    r"""
    Parse time input
    :param in_time:  Time input; surrounding whitespace is ignored
    :return: Parsed time input, None when the input is empty
    """
    text = in_time.strip()
    if not text:
        return None
    # `import dateutil` on its own does not load the `parser` submodule, so import it explicitly
    from dateutil import parser as date_parser
    return date_parser.parse(text)


def parse_model(model: str) -> Optional[str]:
    r"""
    Parse model input
    :param model:      Name of model to filter for; surrounding whitespace is ignored
    :return: Model name, or None when the input is empty or spells `none` (in any letter case)
    """
    name = model.strip()
    if name == '' or name.lower() == 'none':
        return None
    return name


In [None]:
# Job parameters. Keys are spelled like command line flags (`--name`); values are hard-coded in this notebook.
args = {
    '--hostname': 'oxnjq1wfm8.execute-api.us-west-2.amazonaws.com',
    '--stage': 'beta',
    '--query_logs': 's3a://alterego-data-exports-prod/detections/monthly/20230129.gz',
    '--label_logs': 's3a://alterego-data-exports-prod/labels/20230130T0509/',
    '--in_model': 'COMPOUND_B',
    '--out_model': 'COMPOUND_B_OS',
    '--since': '2023-01-01 00:00:00',
    '--until': '2023-01-30 00:00:00',
    '--sts_role': 'arn:aws:iam::171837983301:role/CustomerRoleAlterEgoExperiments',
    '--sts_session': 'LabeledModelReplay',
}
# Parameter names without the `--` prefix; their order fixes the order of the variables unpacked below
ARG_KEYS = [
    'sts_role', 'sts_session', 'hostname', 'stage', 'query_logs',
    'label_logs', 'in_model', 'out_model', 'since', 'until',
]

(sts_role, sts_session, hostname, stage, query_logs,
 label_logs, in_model, out_model, since, until) = [args[f"--{key}"] for key in ARG_KEYS]

In [None]:
# Setup job: API client that signs its calls with credentials of the assumed STS role
lambda_client = AuthenticatedLambdaClient(
    sts_client=client('sts'),
    sts_role=sts_role,
    sts_session=sts_session,
    hostname=hostname,
    region='us-west-2',
    stage=stage,
)
alter_ego_client = AlterEgoClient(lambda_client=lambda_client)

In [None]:
# Wire the log loaders into the searcher, and the searcher plus API client into the replay job
searcher = LabeledRequestsSearcher(
    query_log_loader=QueryLogLoader(spark=spark, logs_path=query_logs),
    labels_log_loader=LabelsLogLoader(spark=spark, log_path=label_logs),
)
replay_job = LabeledModelReplayJob(alter_ego_client=alter_ego_client, labeled_requests_searcher=searcher)


In [None]:
# The time window is applied when the requests are loaded, so load them before replaying
replay_job.get_replay_data(parse_model(in_model), parse_time(since), parse_time(until))
replay_job.replay(parse_model(in_model), parse_model(out_model))