In [None]:
#| default_exp pipeline

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
try:
    from .data_reader import read_csv_batches, process_csv_batch
    from .api_clients import call_polaris_api, call_polaris_api_bulk, call_parser_api, extract_target_direction
    from .derivations import derive_all_fields
    from .core import get_db_connection, insert_prescription_items
except ImportError:
    from dosagelimitsindex.data_reader import read_csv_batches, process_csv_batch
    from dosagelimitsindex.api_clients import call_polaris_api, call_polaris_api_bulk, call_parser_api, extract_target_direction
    from dosagelimitsindex.derivations import derive_all_fields
    from dosagelimitsindex.core import get_db_connection, insert_prescription_items

from typing import List, Dict, Any
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

In [None]:
#| export
def group_records_by_prescription(records: List[Dict]) -> Dict[str, List[Dict]]:
    """
    Bucket flattened records by their ``prescription_id`` for the API calls.

    Buckets appear in first-seen order and keep the records' original order.
    The result is a ``defaultdict(list)``, so indexing an unknown id yields
    (and inserts) an empty list instead of raising.
    """
    buckets: Dict[str, List[Dict]] = defaultdict(list)
    for row in records:
        prescription_id = row['prescription_id']
        buckets[prescription_id].append(row)
    return buckets

In [None]:
#| export
def process_single_prescription(prescription_records: List[Dict]) -> List[tuple]:
    """
    Process a single prescription through both APIs and derivations.
    Returns flattened rows ready for database insertion.
    (Original non-bulk version - kept for reference)

    Steps:
    1. Send the prescription header (taken from the first record) and its
       items to the Polaris API, then store each item's extracted
       ``target_direction`` on its record. The input dicts are mutated.
    2. Send the records, now carrying ``target_direction``, to the parser API.
    3. For every parsed item, emit one row per dosage field, or a single row
       with ``None`` dosage values when the parser returned none. Each row is
       extended with the values computed by ``derive_all_fields``.

    Args:
        prescription_records: Flattened records that all belong to the same
            prescription.

    Returns:
        Row tuples in the column order passed to ``insert_prescription_items``.
        An empty input yields an empty list.

    Raises:
        ValueError: If Polaris returns fewer results than there are items, or
            the parser returns an item whose ``seq`` matches no record.
    """
    if not prescription_records:
        return []

    first = prescription_records[0]
    prescription_id = first['prescription_id']

    prescription_items = [
        {
            'seq': rec['seq'],
            'code': rec['code'],
            'drug': rec['drug'],
            'form': rec['form'],
            'route': rec['route'],
            'original_direction': rec['original_direction'],
            'additional_instructions': rec['additional_instructions'],
        }
        for rec in prescription_records
    ]

    polaris_response = call_polaris_api(
        prescription_id=prescription_id,
        pharmacy_name=first['pharmacy_name'],
        pharmacy_code=first['pharmacy_code'],
        prescription_date=first['prescription_date'],
        patient_age=first['patient_age'],
        patient_gender=first['patient_gender'],
        prescription_items=prescription_items
    )

    # Polaris results are matched to the request's items by position.
    results = polaris_response['results']
    if len(results) < len(prescription_records):
        raise ValueError(
            f"Polaris returned {len(results)} results for "
            f"{len(prescription_records)} items of prescription {prescription_id}"
        )
    for rec, result in zip(prescription_records, results):
        rec['target_direction'] = extract_target_direction(result['target_direction'])

    parser_items = [
        {
            'seq': rec['seq'],
            'code': rec['code'],
            'drug': rec['drug'],
            'form': rec['form'],
            'route': rec['route'],
            'original_direction': rec['original_direction'],
            'additional_instructions': rec['additional_instructions'],
            'target_direction': rec['target_direction'],
        }
        for rec in prescription_records
    ]

    parser_payload = {
        'prescription_id': prescription_id,
        'pharmacy_name': first['pharmacy_name'],
        'pharmacy_code': first['pharmacy_code'],
        'prescription_date': first['prescription_date'],
        'patient_age': first['patient_age'],
        'patient_gender': first['patient_gender'],
        'prescription_items': parser_items
    }

    parser_response = call_parser_api(parser_payload)

    # Index records by seq once instead of rescanning per parsed item.
    # setdefault keeps the first record for a duplicated seq, like a linear search.
    records_by_seq: Dict[Any, Dict] = {}
    for rec in prescription_records:
        records_by_seq.setdefault(rec['seq'], rec)

    db_rows = []
    for item_response in parser_response['prescription_items']:
        seq = item_response['seq']
        try:
            base_rec = records_by_seq[seq]
        except KeyError:
            raise ValueError(
                f"Parser returned unknown seq {seq!r} for prescription {prescription_id}"
            ) from None

        # Items without parsed dosage fields still produce one row.
        dosage_fields_list = item_response.get('dosage_fields', [])
        if not dosage_fields_list:
            dosage_fields_list = [{}]

        for dosage_seq, dosage_field in enumerate(dosage_fields_list, start=1):
            derived = derive_all_fields(dosage_field)

            row = (
                base_rec['prescription_id'],
                base_rec['pharmacy_name'],
                base_rec['pharmacy_code'],
                base_rec['prescription_date'],
                base_rec['patient_age'],
                base_rec['patient_gender'],
                base_rec['seq'],
                base_rec['code'],
                base_rec['drug'],
                base_rec['form'],
                base_rec['route'],
                base_rec['original_direction'],
                base_rec['additional_instructions'],
                base_rec['target_direction_manual'],
                base_rec['target_direction'],
                dosage_seq,
                dosage_field.get('text'),
                dosage_field.get('strength'),
                dosage_field.get('strength_max'),
                dosage_field.get('strength_unit'),
                dosage_field.get('strength_numerator'),
                dosage_field.get('strength_numerator_unit'),
                dosage_field.get('strength_denominator'),
                dosage_field.get('strength_denominator_unit'),
                dosage_field.get('dosage'),
                dosage_field.get('dosage_max'),
                dosage_field.get('dosage_unit'),
                dosage_field.get('frequency'),
                dosage_field.get('frequency_max'),
                dosage_field.get('period'),
                dosage_field.get('period_max'),
                dosage_field.get('period_unit'),
                dosage_field.get('duration'),
                dosage_field.get('duration_max'),
                dosage_field.get('duration_unit'),
                dosage_field.get('as_needed'),
                dosage_field.get('indication'),
                derived.get('dosage_per_administration'),
                derived.get('dosage_per_administration_unit'),
                derived.get('dosage_numerator_per_administration'),
                derived.get('dosage_numerator_per_administration_unit'),
                derived.get('dosage_denominator_per_administration'),
                derived.get('dosage_denominator_per_administration_unit'),
                derived.get('dosage_per_period'),
                derived.get('dosage_per_period_unit'),
                derived.get('dosage_numerator_per_period'),
                derived.get('dosage_numerator_per_period_unit'),
                derived.get('dosage_denominator_per_period'),
                derived.get('dosage_denominator_per_period_unit'),
                derived.get('dosage_per_duration'),
                derived.get('dosage_per_duration_unit'),
                derived.get('dosage_numerator_per_duration'),
                derived.get('dosage_numerator_per_duration_unit'),
                derived.get('dosage_denominator_per_duration'),
                derived.get('dosage_denominator_per_duration_unit'),
            )
            db_rows.append(row)

    return db_rows


In [None]:
#| export
def process_batch_end_to_end(chunk, conn, max_workers: int = 20):
    """
    Run one CSV chunk through the per-prescription (non-bulk) pipeline.

    Each prescription in the chunk is handed to ``process_single_prescription``
    on a thread pool. Rows from successful prescriptions are inserted with a
    single ``insert_prescription_items`` call once every task has finished.
    A failing prescription is reported on stdout and counted, not re-raised.
    (Original non-bulk version - kept for reference)

    Returns:
        Dict with ``prescriptions_processed``, ``rows_inserted``, ``errors``
        and ``batch_time`` (wall-clock seconds for the whole chunk).
    """
    started_at = time.time()

    by_prescription = group_records_by_prescription(process_csv_batch(chunk))

    collected_rows = []
    ok_count = 0
    failure_count = 0

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {}
        for pid, recs in by_prescription.items():
            pending[pool.submit(process_single_prescription, recs)] = pid

        for done in as_completed(pending):
            pid = pending[done]
            try:
                collected_rows.extend(done.result())
                ok_count += 1
            except Exception as exc:
                print(f"Error processing {pid}: {exc}")
                failure_count += 1

    if collected_rows:
        insert_prescription_items(conn, collected_rows)

    elapsed = time.time() - started_at

    return dict(
        prescriptions_processed=ok_count,
        rows_inserted=len(collected_rows),
        errors=failure_count,
        batch_time=elapsed,
    )

In [None]:
#| export
# Prescription-level fields copied from the first record into every API payload.
_PRESCRIPTION_HEADER_KEYS = (
    'prescription_id', 'pharmacy_name', 'pharmacy_code',
    'prescription_date', 'patient_age', 'patient_gender',
)
# Per-item fields sent to the Polaris API.
_POLARIS_ITEM_KEYS = (
    'seq', 'code', 'drug', 'form', 'route',
    'original_direction', 'additional_instructions',
)
# Per-item fields sent to the parser API (Polaris fields plus its output).
_PARSER_ITEM_KEYS = _POLARIS_ITEM_KEYS + ('target_direction',)
# Record fields at the start of each DB row; dosage_seq follows them.
_ROW_RECORD_KEYS = _PRESCRIPTION_HEADER_KEYS + _POLARIS_ITEM_KEYS + (
    'target_direction_manual', 'target_direction',
)
# Parsed dosage-field keys in DB column order (missing keys become None).
_ROW_DOSAGE_KEYS = (
    'text',
    'strength', 'strength_max', 'strength_unit',
    'strength_numerator', 'strength_numerator_unit',
    'strength_denominator', 'strength_denominator_unit',
    'dosage', 'dosage_max', 'dosage_unit',
    'frequency', 'frequency_max',
    'period', 'period_max', 'period_unit',
    'duration', 'duration_max', 'duration_unit',
    'as_needed', 'indication',
)
# derive_all_fields keys in DB column order (missing keys become None).
_ROW_DERIVED_KEYS = (
    'dosage_per_administration', 'dosage_per_administration_unit',
    'dosage_numerator_per_administration', 'dosage_numerator_per_administration_unit',
    'dosage_denominator_per_administration', 'dosage_denominator_per_administration_unit',
    'dosage_per_period', 'dosage_per_period_unit',
    'dosage_numerator_per_period', 'dosage_numerator_per_period_unit',
    'dosage_denominator_per_period', 'dosage_denominator_per_period_unit',
    'dosage_per_duration', 'dosage_per_duration_unit',
    'dosage_numerator_per_duration', 'dosage_numerator_per_duration_unit',
    'dosage_denominator_per_duration', 'dosage_denominator_per_duration_unit',
)


def _build_prescription_payload(prescription_records: List[Dict], item_keys: tuple) -> Dict[str, Any]:
    """Build an API payload: header from the first record plus one item dict per record."""
    first = prescription_records[0]
    payload: Dict[str, Any] = {key: first[key] for key in _PRESCRIPTION_HEADER_KEYS}
    payload['prescription_items'] = [
        {key: rec[key] for key in item_keys} for rec in prescription_records
    ]
    return payload


def _apply_polaris_response(prescription_records: List[Dict], polaris_response: Dict) -> None:
    """Store each item's extracted target_direction on its record, matching results by position."""
    results = polaris_response['results']
    if len(results) < len(prescription_records):
        raise ValueError(
            f"Polaris returned {len(results)} results for {len(prescription_records)} "
            f"items of prescription {prescription_records[0]['prescription_id']}"
        )
    for rec, result in zip(prescription_records, results):
        rec['target_direction'] = extract_target_direction(result['target_direction'])


def _build_db_row(base_rec: Dict, dosage_seq: int, dosage_field: Dict, derived: Dict) -> tuple:
    """Flatten one record, one parsed dosage field and its derived values into a DB row tuple."""
    return (
        tuple(base_rec[key] for key in _ROW_RECORD_KEYS)
        + (dosage_seq,)
        + tuple(dosage_field.get(key) for key in _ROW_DOSAGE_KEYS)
        + tuple(derived.get(key) for key in _ROW_DERIVED_KEYS)
    )


def _rows_from_parser_response(prescription_records: List[Dict], parser_response: Dict) -> List[tuple]:
    """Expand a parser response into DB rows: one per dosage field, at least one per item."""
    # Index by seq once; setdefault keeps the first record for a duplicated
    # seq, matching the original linear search.
    records_by_seq: Dict[Any, Dict] = {}
    for rec in prescription_records:
        records_by_seq.setdefault(rec['seq'], rec)

    rows = []
    for item_response in parser_response['prescription_items']:
        seq = item_response['seq']
        try:
            base_rec = records_by_seq[seq]
        except KeyError:
            raise ValueError(
                f"Parser returned unknown seq {seq!r} for prescription "
                f"{prescription_records[0]['prescription_id']}"
            ) from None

        # Items without parsed dosage fields still produce one row of None dosage values.
        dosage_fields_list = item_response.get('dosage_fields') or [{}]
        for dosage_seq, dosage_field in enumerate(dosage_fields_list, start=1):
            derived = derive_all_fields(dosage_field)
            rows.append(_build_db_row(base_rec, dosage_seq, dosage_field, derived))
    return rows


def process_prescriptions_bulk(prescription_groups: List[tuple], bulk_size: int = 10) -> List[tuple]:
    """
    Process multiple prescriptions using bulk Polaris API.

    ``prescription_groups`` are ``(prescription_id, records)`` pairs, as produced
    by ``group_records_by_prescription(...).items()``. They are sent to
    ``call_polaris_api_bulk`` in slices of ``bulk_size``. Each prescription's
    records are then updated with ``target_direction`` (the input dicts are
    mutated) and sent to the parser API one prescription at a time.

    Returns:
        DB row tuples for all prescriptions, in input order. The column order
        is the same as the rows built by ``process_single_prescription``.

    Raises:
        ValueError: If ``bulk_size`` is less than 1, or an API response does not
            line up with the request (too few results or an unknown ``seq``).
            Any error aborts the whole call, and no rows are returned.
    """
    if bulk_size < 1:
        # A negative step would silently skip every prescription.
        raise ValueError(f"bulk_size must be >= 1, got {bulk_size}")

    all_db_rows = []

    for start in range(0, len(prescription_groups), bulk_size):
        batch = prescription_groups[start:start + bulk_size]

        polaris_results = call_polaris_api_bulk([
            _build_prescription_payload(prescription_records, _POLARIS_ITEM_KEYS)
            for _, prescription_records in batch
        ])
        if len(polaris_results) < len(batch):
            raise ValueError(
                f"Bulk Polaris returned {len(polaris_results)} results "
                f"for {len(batch)} prescriptions"
            )

        # Bulk results are matched to the requested prescriptions by position.
        for (_, prescription_records), polaris_response in zip(batch, polaris_results):
            _apply_polaris_response(prescription_records, polaris_response)
            parser_response = call_parser_api(
                _build_prescription_payload(prescription_records, _PARSER_ITEM_KEYS)
            )
            all_db_rows.extend(_rows_from_parser_response(prescription_records, parser_response))

    return all_db_rows


In [None]:
#| export
def process_batch_parallel_bulk(chunk, conn, max_workers: int = 20, bulk_size: int = 10):
    """
    Process a CSV batch with parallel prescription processing using bulk API.

    The chunk's prescriptions are split into at most ``max_workers`` contiguous
    slices. Each slice is handled by ``process_prescriptions_bulk`` on its own
    thread. Rows from all successful slices are inserted with one
    ``insert_prescription_items`` call at the end.

    A slice whose worker raises is reported on stdout and skipped. All of its
    prescriptions count as errors, because ``process_prescriptions_bulk``
    returns nothing when it fails.

    Returns:
        Dict with ``prescriptions_processed`` (prescriptions in slices that
        succeeded), ``rows_inserted``, ``errors`` (prescriptions in slices that
        failed, matching ``process_batch_end_to_end``) and ``batch_time``
        (wall-clock seconds for the whole chunk).
    """
    batch_start = time.time()

    records = process_csv_batch(chunk)
    grouped = group_records_by_prescription(records)

    prescription_groups = list(grouped.items())
    all_db_rows = []
    prescriptions_processed = 0
    errors = 0

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Ceiling division keeps the slice count <= max_workers and the slices
        # as large as possible. Floor division could yield one-prescription
        # slices, which defeats the bulk Polaris call.
        slice_size = max(1, -(-len(prescription_groups) // max_workers))

        # Map each future to its slice so success and failure can be counted per prescription.
        futures = {}
        for start in range(0, len(prescription_groups), slice_size):
            worker_groups = prescription_groups[start:start + slice_size]
            futures[executor.submit(process_prescriptions_bulk, worker_groups, bulk_size)] = worker_groups

        for future in as_completed(futures):
            worker_groups = futures[future]
            try:
                db_rows = future.result()
            except Exception as e:
                print(f"Error processing batch of {len(worker_groups)} prescriptions "
                      f"(starting at {worker_groups[0][0]}): {e}")
                errors += len(worker_groups)
                continue
            all_db_rows.extend(db_rows)
            prescriptions_processed += len(worker_groups)

    if all_db_rows:
        insert_prescription_items(conn, all_db_rows)

    batch_time = time.time() - batch_start

    return {
        'prescriptions_processed': prescriptions_processed,
        'rows_inserted': len(all_db_rows),
        'errors': errors,
        'batch_time': batch_time
    }


In [None]:
#| export
if __name__ == "__main__":
    # CLI: python pipeline.py [csv_path] [chunksize] [max_workers] [bulk_size]
    import sys

    filepath = sys.argv[1] if len(sys.argv) > 1 else 'dataset/prescriptions.csv'
    chunksize = int(sys.argv[2]) if len(sys.argv) > 2 else 100
    max_workers = int(sys.argv[3]) if len(sys.argv) > 3 else 20
    bulk_size = int(sys.argv[4]) if len(sys.argv) > 4 else 10

    conn = get_db_connection()

    total_prescriptions = 0
    total_rows = 0
    total_errors = 0
    batch_num = 0
    overall_start = time.time()

    print(f"Starting processing: {filepath}")
    print(f"Batch size: {chunksize} CSV rows")
    print(f"Parallel workers: {max_workers}")
    print(f"Bulk API size: {bulk_size} prescriptions\n")

    try:
        for chunk in read_csv_batches(filepath, chunksize):
            batch_num += 1

            result = process_batch_parallel_bulk(chunk, conn, max_workers, bulk_size)

            total_prescriptions += result['prescriptions_processed']
            total_rows += result['rows_inserted']
            total_errors += result['errors']

            elapsed = time.time() - overall_start
            rate = total_prescriptions / elapsed if elapsed > 0 else 0

            print(f"Batch {batch_num}: {result['prescriptions_processed']} prescriptions, "
                  f"{result['rows_inserted']} rows in {result['batch_time']:.2f}s | "
                  f"Total: {total_prescriptions} ({rate:.1f}/s)")
    finally:
        # Release the DB connection even if a batch raises or the run is interrupted (Ctrl-C).
        conn.close()

    total_time = time.time() - overall_start
    print(f"\nCompleted! Total: {total_prescriptions} prescriptions, "
          f"{total_rows} rows, {total_errors} errors in {total_time:.2f}s")