## Define Schemas, Data Source, and Index

In [None]:
import os
import random
import time

import datasets
import palimpzest as pz
from palimpzest.corelib import Schema
from palimpzest.datamanager import DataDirectory
from ragatouille import RAGPretrainedModel

# Input schema for the pipeline: one paper identified by its PubMed ID, together
# with its title, abstract and full text. All four fields are required strings.
class BiodexEntry(Schema):
    """A single entry in the Biodex ICSR Dataset."""

    pmid = pz.StringField(desc="The PubMed ID of the medical paper", required=True)
    title = pz.StringField(desc="The title of the medical paper", required=True)
    abstract = pz.StringField(desc="The abstract of the medical paper", required=True)
    fulltext = pz.StringField(desc="The full text of the medical paper, which contains information relevant for creating a drug safety report.", required=True)

# Extends BiodexEntry with `reactions`: a required list of strings.
# NOTE(review): the docstring below is phrased as an instruction to a model;
# presumably palimpzest forwards Schema docstrings / field `desc` values to the
# LLM -- confirm before rewording either.
class BiodexReactions(BiodexEntry):
    """
    You will be presented with the text of a medical article which is partially or entirely about
    an adverse event experienced by a patient in response to taking one or more drugs. In this task,
    you will be asked to extract a list of the primary adverse reactions which are experienced by the patient.
    """
    # alternative (disabled) field description which embeds an example output list:
    # reactions = pz.ListField(desc="The **list** of the adverse reaction(s) which may have resulted from taking the drug(s) discussed in the report.\n - For example: [\"Epstein-Barr virus\", \"infection reactivation\", \"Idiopathic interstitial pneumonia\"]", element_type=pz.StringField, required=True)
    reactions = pz.ListField(desc="The list of the adverse reaction term(s) which may have resulted from taking the drug(s) discussed in the report.", element_type=pz.StringField, required=True)

# Extends BiodexReactions with `reaction_labels`: a required list of strings.
# In this file it is the output schema of the Retrieve step, which searches on
# `reactions` and writes its results to `reaction_labels`.
class BiodexReactionLabels(BiodexReactions):
    """
    Retrieve the labels which are most relevant for the given set of inferred reactions.
    """
    reaction_labels = pz.ListField(desc="Most relevant official terms for adverse reactions for the provided `reactions`",
                                   element_type=pz.StringField, required=True)

# Final output schema: extends BiodexReactionLabels with `ranked_reaction_labels`,
# a required list of strings. This is the field scored when plan quality is computed.
# NOTE(review): the docstring reads as a task instruction; presumably it is
# consumed by the LLM-backed convert operator -- confirm before rewording.
class BiodexRankedReactions(BiodexReactionLabels):
    """
    You will be presented with the text of a medical article which is partially or entirely about
    an adverse event experienced by a patient in response to taking one or more drugs. You will also
    be presented with a list of inferred reactions, and a set of retrieved labels which were matched
    to these inferred reactions. In this task, you are asked to output a ranked list of the labels
    which are most applicable based on the context of the article. Your output list must:
    - contain only elements from `reaction_labels`
    - place the most likely label first and the least likely label last
    - you may omit labels if you think they do not describe a reaction experienced by the patient
    """
    ranked_reaction_labels = pz.ListField(desc="The ranked list of labels for adverse reactions experienced by the patient. The most likely label occurs first in the list.",
                                          element_type=pz.StringField, required=True)


class BiodexValidationSource(pz.ValidationDataSource):
    """
    Validation data source backed by the "BioDEX/BioDEX-ICSR" dataset.

    Rows 0-99 of the `train` split form the validation pool (`train_dataset`),
    which is optionally shuffled and then trimmed to `num_samples`. Rows 100-149
    of the same split form the evaluation set (`test_dataset`), which is what
    `__len__` and `getItem(val=False)` operate on.
    """

    def __init__(self, datasetId, reactions_only: bool=True, rp_at_k: int=5, num_samples: int=5, shuffle: bool=False, seed: int=42):
        """
        Load the dataset and pre-compute the label record for every entry.

        Args:
            datasetId: identifier forwarded to the base class constructor.
            reactions_only: if True, labels and metrics only cover the three
                reaction fields; otherwise `serious`, `patientsex` and `drugs`
                are included as well.
            rp_at_k: the `k` used by the rank-precision-at-k metric.
            num_samples: number of validation entries kept in `train_dataset`.
            shuffle: whether to shuffle the validation pool before trimming.
            seed: seed for the shuffle.
        """
        super().__init__(BiodexEntry, datasetId)
        self.dataset = datasets.load_dataset("BioDEX/BioDEX-ICSR")
        train_split = self.dataset['train']
        self.train_dataset = [train_split[idx] for idx in range(100)]
        self.test_dataset = [train_split[idx] for idx in range(100, 150)]

        # sample from full test dataset
        # self.test_dataset = [self.dataset['test'][idx] for idx in range(len(self.dataset['test']))]
        # self.test_dataset = self.test_dataset[:250] # use first 250 to compare directly with biodex

        self.reactions_only = reactions_only
        self.rp_at_k = rp_at_k
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.seed = seed

        # construct mapping from listing --> label (field, value) pairs
        def compute_target_record(entry, reactions_only: bool=False):
            # entry['target'] is split on newlines; lines 0-3 hold serious,
            # patientsex, drugs and reactions, each value following the last ':'
            target_lst = entry['target'].split('\n')

            def value_of(line):
                return line.split(':')[-1]

            def terms_of(line):
                return [term.strip().lower() for term in value_of(line).split(",")]

            reactions = terms_of(target_lst[3])
            label_dict = {
                "serious": int(value_of(target_lst[0])),
                "patientsex": int(value_of(target_lst[1])),
                "drugs": terms_of(target_lst[2]),
                "reactions": reactions,
                # independent copies so that mutating one label list cannot affect another
                "reaction_labels": list(reactions),
                "ranked_reaction_labels": list(reactions),
            }
            if reactions_only:
                label_dict = {
                    k: v
                    for k, v in label_dict.items()
                    if k in ["reactions", "reaction_labels", "ranked_reaction_labels"]
                }
            return label_dict

        # NOTE: labels are computed for BOTH splits so that
        #       getItem(idx, val=True, include_label=True) can find them; test
        #       entries are inserted last so they win if a pmid occurs in both
        self.label_fields_to_values = {
            entry['pmid']: compute_target_record(entry, reactions_only=reactions_only)
            for entry in self.train_dataset + self.test_dataset
        }

        # shuffle records if shuffle = True
        if shuffle:
            random.Random(seed).shuffle(self.train_dataset)

        # trim to number of samples
        self.train_dataset = self.train_dataset[:num_samples]

    def copy(self):
        """Return a new source constructed with the same configuration as this one."""
        # keyword arguments are required here: passed positionally, num_samples /
        # shuffle / seed would land in the reactions_only / rp_at_k / num_samples slots
        return BiodexValidationSource(
            self.dataset_id,
            reactions_only=self.reactions_only,
            rp_at_k=self.rp_at_k,
            num_samples=self.num_samples,
            shuffle=self.shuffle,
            seed=self.seed,
        )

    def __len__(self):
        """Number of entries in the evaluation set."""
        return len(self.test_dataset)

    def getValLength(self):
        """Number of entries in the (trimmed) validation pool."""
        return len(self.train_dataset)

    def getSize(self):
        return 0

    @staticmethod
    def _dump_bad_preds(error_dir: str, preds):
        """Write predictions which could not be scored to `<error_dir>/error-<timestamp>.txt`."""
        os.makedirs(error_dir, exist_ok=True)
        ts = time.time()
        with open(f"{error_dir}/error-{ts}.txt", "w") as f:
            f.write(str(preds))

    def getFieldToMetricFn(self):
        """
        Return a mapping from field name to its quality metric.

        A metric is either the string "exact" or a callable `(preds, targets) -> float`.
        """
        # define f1 function
        def f1_eval(preds: list, targets: list):
            if preds is None:
                return 0.0

            try:
                # compute precision and recall over the case-insensitive sets
                s_preds = {pred.lower() for pred in preds}
                s_targets = {target.lower() for target in targets}

                intersect = s_preds.intersection(s_targets)

                precision = len(intersect) / len(s_preds) if len(s_preds) > 0 else 0.0
                recall = len(intersect) / len(s_targets) if len(s_targets) > 0 else 0.0

                # compute f1 score and return
                return 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

            except Exception:
                # best-effort: malformed predictions are dumped for inspection and scored 0
                self._dump_bad_preds("f1-errors", preds)
                return 0.0

        # define rank precision at k
        def rank_precision_at_k(preds: list, targets: list):
            if preds is None:
                return 0.0

            try:
                # lower-case each list
                preds = [pred.lower() for pred in preds]
                targets = {target.lower() for target in targets}

                # compute rank-precision at k
                denom = min(self.rp_at_k, len(targets))
                if denom <= 0:
                    # no targets (or k == 0): nothing can be matched
                    return 0.0

                # slicing handles prediction lists shorter than k
                hits = sum(pred in targets for pred in preds[:self.rp_at_k])

                return hits / denom

            except Exception:
                # best-effort: malformed predictions are dumped for inspection and scored 0
                self._dump_bad_preds("rp@k-errors", preds)
                return 0.0

        # define quality eval function for drugs and reactions fields
        if self.reactions_only:
            return {
                "reactions": f1_eval,
                "reaction_labels": f1_eval,
                "ranked_reaction_labels": rank_precision_at_k,
            }

        return {
            "serious": "exact",
            "patientsex": "exact",
            "drugs": f1_eval,
            "reactions": f1_eval,
        }

    def getItem(self, idx: int, val: bool=False, include_label: bool=False):
        """
        Build the DataRecord for entry `idx`.

        Args:
            idx: position within the selected split.
            val: if True read from the validation pool, otherwise from the evaluation set.
            include_label: if True, also set the label fields on the record.
        """
        # fetch entry
        entry = self.train_dataset[idx] if val else self.test_dataset[idx]

        # create data record
        dr = pz.DataRecord(self.schema, source_id=entry['pmid'])
        dr.pmid = entry['pmid']
        dr.title = entry['title']
        dr.abstract = entry['abstract']
        dr.fulltext = entry['fulltext']

        # if requested, also return the label information
        if include_label:
            # augment data record with label info
            for field, value in self.label_fields_to_values[entry['pmid']].items():
                setattr(dr, field, value)

        return dr

# build the validation source (keeps the first 50 validation entries, unshuffled)
# and register it under the id the scan operator looks up
datasource = BiodexValidationSource(datasetId="biodex-reactions", num_samples=50, shuffle=False, seed=42)
DataDirectory().registerUserSource(src=datasource, dataset_id="biodex-reactions")

# load the pre-built retrieval index from disk
index_path = ".ragatouille/colbert/indexes/reaction-terms"
index = RAGPretrainedModel.from_index(index_path)

## Define Plan

In [None]:
# YAML description of the plan: an ordered list of operators, each written as a
# single-key mapping from operator name to its keyword arguments. The position
# in the list determines the input/output schemas assigned to the operator by
# yaml_to_biodex_reactions_physical_plan (which supports exactly three positions).
yaml_str = """
operators:
- MixtureOfAgentsConvert:
    proposer_models: [gpt-4o-mini, llama3, gpt-4o]
    aggregator_model: llama3
    temperatures: [0.4, 0.4, 0.4]
- Retrieve:
    k: 5
- TokenReducedConvertBonded:
    model: gpt-4o-mini
    token_budget: 0.5
"""

In [None]:
from palimpzest.constants import Model
from palimpzest.operators import *
from palimpzest.optimizer import PhysicalPlan
import yaml

def yaml_to_biodex_reactions_physical_plan(yaml_str):
    """
    Build a PhysicalPlan for the biodex reactions pipeline from a YAML string.

    The YAML must contain a top-level `operators` list; each item is a mapping
    whose first key is the operator name and whose value holds that operator's
    keyword arguments. Operators are matched by position to the schema chain
    BiodexEntry -> BiodexReactions -> BiodexReactionLabels -> BiodexRankedReactions,
    and a scan over the "biodex-reactions" dataset is prepended.

    Raises:
        KeyError: if more than three operators are listed, or a model name is unknown.
        ValueError: if an operator name is not one of the supported operators.
    """
    plan_dict = yaml.safe_load(yaml_str)
    operators = plan_dict['operators']

    # maps from op_idx to the input and output schema(s)
    op_idx_to_input_schema = {
        0: BiodexEntry,
        1: BiodexReactions,
        2: BiodexReactionLabels,
    }
    op_idx_to_output_schema = {
        0: BiodexReactions,
        1: BiodexReactionLabels,
        2: BiodexRankedReactions,
    }
    model_name_to_model = {
        "gpt-4o": Model.GPT_4o,
        "gpt-4o-mini": Model.GPT_4o_MINI,
        "llama3": Model.LLAMA3,
        "mixtral": Model.MIXTRAL,
    }

    # create scan operator
    scanOp = MarshalAndScanDataOp(outputSchema=BiodexEntry, dataset_id="biodex-reactions")

    ops = [scanOp]
    for op_idx, operator in enumerate(operators):
        input_schema = op_idx_to_input_schema[op_idx]
        output_schema = op_idx_to_output_schema[op_idx]

        op_name = next(iter(operator))
        # an operator written without arguments (e.g. "- Retrieve:") parses to None
        op_kwargs = operator[op_name] or {}
        depends_on = op_kwargs.get("depends_on")

        if op_name == "Retrieve":
            op = RetrieveOp(
                inputSchema=input_schema,
                outputSchema=output_schema,
                search_attr="reactions",
                output_attr="reaction_labels",
                index=index,
                verbose=True,
                **op_kwargs,
            )
        elif op_name == "MixtureOfAgentsConvert":
            op = MixtureOfAgentsConvert(
                inputSchema=input_schema,
                outputSchema=output_schema,
                proposer_models=[model_name_to_model[name] for name in op_kwargs["proposer_models"]],
                aggregator_model=model_name_to_model[op_kwargs['aggregator_model']],
                temperatures=op_kwargs['temperatures'],
                verbose=True,
                depends_on=depends_on,
            )
        elif op_name == "LLMConvertBonded":
            op = LLMConvertBonded(
                inputSchema=input_schema,
                outputSchema=output_schema,
                verbose=True,
                model=model_name_to_model[op_kwargs["model"]],
                depends_on=depends_on,
            )
        elif op_name == "TokenReducedConvertBonded":
            op = TokenReducedConvertBonded(
                inputSchema=input_schema,
                outputSchema=output_schema,
                verbose=True,
                model=model_name_to_model[op_kwargs["model"]],
                token_budget=op_kwargs["token_budget"],
                depends_on=depends_on,
            )
        else:
            # fail loudly instead of silently placing `None` into the plan
            raise ValueError(f"Unsupported operator at position {op_idx}: {op_name}")

        ops.append(op)

    # construct final plan
    plan = PhysicalPlan(operators=ops)

    return plan


In [None]:
# build the physical plan (dataset scan followed by the operators listed in `yaml_str`)
plan = yaml_to_biodex_reactions_physical_plan(yaml_str)

In [None]:
# display the constructed plan
print(plan)

## Single-Threaded Execution (works for all policies, but slow)

In [None]:
from palimpzest.corelib import SourceRecord
from palimpzest.dataclasses import OperatorStats, PlanStats
from palimpzest.datamanager import DataDirectory
from palimpzest.operators import DataSourcePhysicalOp

# Single-threaded execution: every source record is read and pushed through the
# whole operator chain before the next source record is scanned.
plan_start_time = time.time()

# initialize plan stats and operator stats
plan_stats = PlanStats(plan_id=plan.plan_id, plan_str=str(plan))
for op in plan.operators:
    op_id = op.get_op_id()
    op_name = op.op_name()
    op_details = {k: str(v) for k, v in op.get_op_params().items()}
    plan_stats.operator_stats[op_id] = OperatorStats(op_id=op_id, op_name=op_name, op_details=op_details)

# get handle to DataSource and pre-compute its size
source_operator = plan.operators[0]
datasource = DataDirectory().getRegisteredDataset(source_operator.dataset_id)
datasource_len = len(datasource)

# execute the plan one operator at a time
candidates, output_records = [], []
for current_scan_idx in range(datasource_len):
    for op_idx, operator in enumerate(plan.operators):
        op_id = operator.get_op_id()
        # only the first operator (the scan) has no predecessor
        prev_op_id = (
            plan.operators[op_idx - 1].get_op_id() if op_idx > 0 else None
        )
        next_op_id = (
            plan.operators[op_idx + 1].get_op_id()
            if op_idx + 1 < len(plan.operators)
            else None
        )

        # initialize output records and record_op_stats for this operator
        records, record_op_stats = [], []

        # invoke datasource operator(s) until we run out of source records or hit the num_samples limit
        if isinstance(operator, DataSourcePhysicalOp):
            # construct input DataRecord for DataSourcePhysicalOp
            # NOTE: this DataRecord will be discarded and replaced by the scan_operator;
            #       it is simply a vessel to inform the scan_operator which record to fetch
            candidate = DataRecord(schema=SourceRecord, source_id=current_scan_idx)
            candidate.idx = current_scan_idx
            candidate.get_item_fn = datasource.getItem

            # run DataSourcePhysicalOp on record
            record_set = operator(candidate)
            records.extend(record_set.data_records)
            record_op_stats.extend(record_set.record_op_stats)

        # otherwise, process the records in the processing queue for this operator one at a time
        else:
            for input_record in candidates:
                record_set = operator(input_record)
                records.extend(record_set.data_records)
                record_op_stats.extend(record_set.record_op_stats)

        # update plan stats
        plan_stats.operator_stats[op_id].add_record_op_stats(
            record_op_stats,
            source_op_id=prev_op_id,
            plan_id=plan.plan_id,
        )

        # update candidates: surviving records feed the next operator, or become
        # plan outputs if this was the last operator
        candidates = []
        for record in records:
            if isinstance(operator, FilterOp) and not record._passed_operator:
                continue
            if next_op_id is not None:
                candidates.append(record)
            else:
                output_records.append(record)

        # if we've filtered out all records, terminate early
        if next_op_id is not None and not candidates:
            break

# finalize plan stats
total_plan_time = time.time() - plan_start_time
plan_stats.finalize(total_plan_time)

## Parallel Execution (fast, but does not work for time policies)

In [None]:
from palimpzest.constants import PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS
from palimpzest.corelib import SourceRecord
from palimpzest.dataclasses import OperatorStats, PlanStats
from palimpzest.datamanager import DataDirectory
from palimpzest.operators import DataSourcePhysicalOp
from concurrent.futures import ThreadPoolExecutor, wait

# maximum number of worker threads given to the ThreadPoolExecutor below
NUM_WORKERS = 8

def execute_op_wrapper(operator, op_input):
    """
    Apply `operator` to `op_input` and return the pair `(result, operator)`.

    Handing the operator back alongside its result lets code which collects
    completed futures from a worker pool tell which operator produced each result.
    """
    return operator(op_input), operator

# Parallel execution: operator invocations are submitted to a thread pool; each
# finished future either feeds its records to the next operator or, for the last
# operator, contributes them to `output_records`.
plan_start_time = time.time()

# initialize plan stats and operator stats
plan_stats = PlanStats(plan_id=plan.plan_id, plan_str=str(plan))
for op in plan.operators:
    op_id = op.get_op_id()
    op_name = op.op_name()
    op_details = {k: str(v) for k, v in op.get_op_params().items()}
    plan_stats.operator_stats[op_id] = OperatorStats(op_id=op_id, op_name=op_name, op_details=op_details)

# initialize list of output records and intermediate variables
output_records = []
source_records_scanned = 0

# initialize data structures to help w/processing DAG
processing_queue = []
op_id_to_futures_in_flight = {op.get_op_id(): 0 for op in plan.operators}
op_id_to_operator = {op.get_op_id(): op for op in plan.operators}
op_id_to_prev_operator = {
    op.get_op_id(): plan.operators[idx - 1] if idx > 0 else None
    for idx, op in enumerate(plan.operators)
}
op_id_to_next_operator = {
    op.get_op_id(): plan.operators[idx + 1] if idx + 1 < len(plan.operators) else None
    for idx, op in enumerate(plan.operators)
}
op_id_to_op_idx = {op.get_op_id(): idx for idx, op in enumerate(plan.operators)}

# get handle to DataSource and pre-compute its size
source_operator = plan.operators[0]
source_op_id = source_operator.get_op_id()
datasource = DataDirectory().getRegisteredDataset(source_operator.dataset_id)
datasource_len = len(datasource)


def _submit_source_scan(executor, scan_idx):
    """Submit a scan of source record `scan_idx` to `executor` and return the future."""
    # construct input DataRecord for DataSourcePhysicalOp
    # NOTE: this DataRecord will be discarded and replaced by the scan_operator;
    #       it is simply a vessel to inform the scan_operator which record to fetch
    candidate = DataRecord(schema=SourceRecord, source_id=scan_idx)
    candidate.idx = scan_idx
    candidate.get_item_fn = datasource.getItem
    return executor.submit(execute_op_wrapper, source_operator, candidate)


# create thread pool w/max workers
futures = []
current_scan_idx = 0
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
    # create initial future to read the first source record (if there is one);
    # every finished scan schedules the scan of the next source record
    if current_scan_idx < datasource_len:
        futures.append(_submit_source_scan(executor, current_scan_idx))
        op_id_to_futures_in_flight[source_op_id] += 1
        current_scan_idx += 1

    # iterate until we have processed all operators on all records or come to an early stopping condition
    while futures:
        # get the set of futures that have (and have not) finished in the last PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS
        done_futures, not_done_futures = wait(futures, timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)

        # cast not_done_futures from a set to a list so we can append to it
        not_done_futures = list(not_done_futures)

        # process finished futures, creating new ones as needed
        new_futures = []
        for future in done_futures:
            # get the result
            record_set, operator = future.result()
            op_id = operator.get_op_id()

            # decrement future from mapping of futures in-flight
            op_id_to_futures_in_flight[op_id] -= 1

            # update plan stats
            prev_operator = op_id_to_prev_operator[op_id]
            plan_stats.operator_stats[op_id].add_record_op_stats(
                record_set.record_op_stats,
                source_op_id=prev_operator.get_op_id() if prev_operator is not None else None,
                plan_id=plan.plan_id,
            )

            # process each record output by the future's operator
            next_operator = op_id_to_next_operator[op_id]
            for record in record_set:
                # skip records which are filtered out
                if not getattr(record, "_passed_operator", True):
                    continue

                # add records to processing queue if there is a next_operator; otherwise add to output_records
                if next_operator is not None:
                    processing_queue.append((next_operator, record))
                else:
                    output_records.append(record)

            # if this operator was a source scan, update the number of source records scanned
            if op_id == source_op_id:
                source_records_scanned += len(record_set)

                # scan next record if we can still draw records from source
                if current_scan_idx < datasource_len:
                    new_futures.append(_submit_source_scan(executor, current_scan_idx))
                    op_id_to_futures_in_flight[source_op_id] += 1
                    current_scan_idx += 1

        # submit every queued (operator, record) pair for execution
        for operator, candidate in processing_queue:
            new_futures.append(executor.submit(execute_op_wrapper, operator, candidate))
            op_id_to_futures_in_flight[operator.get_op_id()] += 1

        # clear the processing queue
        processing_queue = []

        # update list of futures
        not_done_futures.extend(new_futures)
        futures = not_done_futures

# finalize plan stats
total_plan_time = time.time() - plan_start_time
plan_stats.finalize(total_plan_time)

## Compute Plan Cost, Time, and Quality

In [None]:
def compute_quality(record_set, expected_record_set, field_to_metric_fn):
    """
    Compute the quality for the given `record_set` by comparing it to the `expected_record_set`.

    Every entry of `record_set.record_op_stats` has its `quality` attribute set
    in place, and the same `record_set` object is returned.

    Args:
        record_set: object supporting `len()` and exposing `record_op_stats`, a
            list whose items carry `quality`, `generated_fields` and `record_state`
            (a mapping from field name to computed value).
        expected_record_set: iterable of expected records; expected values are
            read from them with `getattr(expected_record, field)`.
        field_to_metric_fn: optional mapping from field name to either "exact"
            or a callable `(computed, expected) -> number`. Fields without an
            entry (or a `None` mapping) are scored by exact match.

    Raises:
        Exception: if a metric is neither "exact" nor callable.
    """
    # if this operation is a failed convert, its (single) stats entry scores 0
    if len(record_set) == 0 and record_set.record_op_stats:
        record_set.record_op_stats[0].quality = 0.0

    # if this is a successful convert operation
    # NOTE: the following computation assumes we do not project out computed values
    #       (and that the validation examples provide all computed fields); even if
    #       a user program does add projection, we can ignore the projection on the
    #       validation dataset and use the champion model (as opposed to the validation
    #       output) for scoring fields which have their values projected out

    # GREEDY ALGORITHM
    # for each record in the expected output, we look for the computed record which maximizes the quality metric;
    # once we've identified that computed record we remove it from consideration for the next expected output
    for expected_record in expected_record_set:
        best_quality, best_record_op_stats = 0.0, None
        for record_op_stats in record_set.record_op_stats:
            # if we already assigned this record a quality, skip it
            if record_op_stats.quality is not None:
                continue

            # a record without generated fields has nothing to score (and would
            # otherwise cause a division by zero); it falls through to 0.0 below
            generated_fields = record_op_stats.generated_fields
            if not generated_fields:
                continue

            # accumulate the per-field score of this record against the expected record
            total_quality = 0
            for field in generated_fields:
                computed_value = record_op_stats.record_state.get(field, None)
                expected_value = getattr(expected_record, field)

                # get the metric function for this field
                metric_fn = (
                    field_to_metric_fn[field]
                    if field_to_metric_fn is not None and field in field_to_metric_fn
                    else "exact"
                )

                # compute exact match
                if metric_fn == "exact":
                    total_quality += int(computed_value == expected_value)

                # compute UDF metric
                elif callable(metric_fn):
                    total_quality += metric_fn(computed_value, expected_value)

                # otherwise, throw an exception
                else:
                    raise Exception(f"Unrecognized metric_fn: {metric_fn}")

            # average over the generated fields and update best seen so far
            quality = total_quality / len(generated_fields)
            if quality > best_quality:
                best_quality = quality
                best_record_op_stats = record_op_stats

        # set best_quality as quality for the best_record_op_stats
        if best_record_op_stats is not None:
            best_record_op_stats.quality = best_quality

    # for any records which did not receive a quality, set it to 0.0 as these are unexpected extras
    for record_op_stats in record_set.record_op_stats:
        if record_op_stats.quality is None:
            record_op_stats.quality = 0.0

    return record_set


In [None]:
# compute plan cost, runtime, and quality
plan_cost = plan_stats.total_plan_cost
plan_time = plan_stats.total_plan_time

# map each record's source id to its labeled (expected) record
expected_outputs = {}
for idx in range(datasource_len):
    data_record = datasource.getItem(idx, include_label=True)
    expected_outputs[data_record._source_id] = data_record

# the metric lookup does not depend on the record, so do it once
fields_to_metric_fn = datasource.getFieldToMetricFn()
metric_fn = fields_to_metric_fn["ranked_reaction_labels"]

# score the final ranked labels of every output record against its expected labels
qualities = []
for record in output_records:
    expected_record = expected_outputs[record._source_id]
    # a record lacking the field is scored like a `None` prediction
    computed_value = getattr(record, "ranked_reaction_labels", None)
    expected_value = getattr(expected_record, "ranked_reaction_labels")
    quality = metric_fn(computed_value, expected_value)
    qualities.append(quality)

# mean quality over the output records; 0.0 if the plan produced no output
plan_quality = sum(qualities) / len(qualities) if qualities else 0.0

In [None]:
# report the plan's total cost, total runtime, and mean quality over the output records
print(f"Plan(cost={plan_cost:.5f}, time={plan_time:.2f}, quality={plan_quality:.5f})")

## If you want to examine train / test dataset examples

In [None]:
import datasets
# load the dataset and materialize the same row ranges of the `train` split
# which BiodexValidationSource uses: rows 0-99 and rows 100-149
dataset = datasets.load_dataset("BioDEX/BioDEX-ICSR")
train_split = dataset['train']
train_dataset = [train_split[row] for row in range(100)]
test_dataset = [train_split[row] for row in range(100, 150)]