### Setup methods and data utils

## Steps:
    - This assumes Jupyter Notebook and the `python-dateutil` package are already installed in your Python environment
    - Cells 1-5 can be run to accomplish the following:
        - import packages
        - set up variables for the exercise
        - establish utility methods
        - set up the database and process the data files
    - The remaining cells will perform the analyses and output results

In [None]:
from pathlib import Path
from collections import OrderedDict
from dateutil.parser import parse
import csv
import json
import sqlite3
import datetime
import typing as t
import copy

In [None]:
# Generic type variable available for annotations.
T = t.TypeVar("T")
# Location of the pharmacy claims inputs. NOTE(review): not referenced by the
# visible cells -- data files are discovered by globbing the working directory.
pharmacy_dir = Path.cwd() / 'pharmacy_claims'
# Table / CSV-file-name prefixes, processed in this order during ingestion.
file_type = ['enrollment', 'pharmacy_claims', 'analysis']

In [None]:
# Column definitions per table, keyed by table name. Used to build the
# CREATE TABLE statements and to clean and load the CSV rows. Each field has:
#   name   - CSV header / column name (lower-cased when the DDL is generated)
#   type   - SQLite column type written into the DDL; fields typed "date" are
#            also checked against "format" when the CSV rows are cleaned
#   format - strptime pattern for date fields, "" for everything else
# NOTE(review): numeric-looking fields (allowed_amount, days_supply,
# quantity_dispensed, ...) are declared "text", so SQL arithmetic and
# comparisons on them rely on SQLite's type coercion -- confirm intended.
schema_dict = {
    "enrollment": {
        "fields": [
            {"name": "birth_date", "type": "date", "format": "%Y-%m-%d"},
            {"name": "card_id", "type": "text", "format": ""},
            {"name": "enrollment_end_date", "type": "date", "format": "%Y-%m-%d"},
            {"name": "enrollment_start_date", "type": "date", "format": "%Y-%m-%d"},
            {"name": "first_name", "type": "text", "format": ""},
            {"name": "gender", "type": "text", "format": ""},
            {"name": "last_name", "type": "text", "format": ""}
        ]
    },
    "pharmacy_claims": {
        "fields": [
            {"name": "allowed_amount", "type": "text", "format": ""},
            {"name": "card_id", "type": "text", "format": ""},
            {"name": "claim_line_number", "type": "text", "format": ""},
            {"name": "claim_status", "type": "text", "format": ""},
            {"name": "days_supply", "type": "text", "format": ""},
            {"name": "fill_date", "type": "date", "format": "%Y-%m-%d"},
            {"name": "maintenance_drug_flag", "type": "text", "format": ""},
            {"name": "medication_name", "type": "text", "format": ""},
            {"name": "ndc_11", "type": "text", "format": ""},
            # NOTE(review): date-like but typed "text", so it is not date-checked/normalized
            {"name": "paid_date", "type": "text", "format": ""},
            {"name": "pharmacy_id", "type": "text", "format": ""},
            {"name": "pharmacy_name", "type": "text", "format": ""},
            {"name": "pharmacy_npi", "type": "text", "format": ""},
            {"name": "pharmacy_tax_id", "type": "text", "format": ""},
            {"name": "prescriber_id", "type": "text", "format": ""},
            {"name": "prescriber_npi", "type": "text", "format": ""},
            {"name": "quantity_dispensed", "type": "text", "format": ""},
            {"name": "refill_number", "type": "text", "format": ""},
            {"name": "retail_mail_flag", "type": "text", "format": ""},
            {"name": "run_date", "type": "date", "format": "%Y-%m-%d"},
            {"name": "rxtype", "type": "text", "format": ""},
            {"name": "specialty_drug_flag", "type": "text", "format": ""},
            {"name": "strength_units", "type": "text", "format": ""},
            {"name": "strength_value", "type": "text", "format": ""},
            {"name": "uploadDate", "type": "date", "format": "%Y-%m-%d"},
            {"name": "claim_number", "type": "text", "format": ""}
        ]
    },
    # Output table for the prepared claims; filled by the Question 3 cell's
    # "insert into analysis", whose SELECT columns follow this field order.
    "analysis": {
        "fields": [
            {"name": "card_id", "type": "text", "format": ""},
            {"name": "claim_number", "type": "text", "format": ""},
            {"name": "fill_date", "type": "date", "format": "%Y-%m-%d"},
            {"name": "ndc_11", "type": "text", "format": ""},
            {"name": "claim_status", "type": "text", "format": ""},
            {"name": "days_supply", "type": "text", "format": ""},
            {"name": "allowed_amount", "type": "text", "format": ""}
        ]
    }
}

In [None]:
def get_json_from_file(json_file_path: Path):
    """Open *json_file_path* and return its decoded JSON content (usually a dict)."""
    with open(json_file_path) as json_handle:
        decoded = json.load(json_handle)
    return decoded

class CSVDialect(csv.Dialect):
    """CSV dialect used when reading the input data files.

    Comma-separated fields, ``"`` as the quote character with doubled quotes
    inside quoted fields, backslash as the escape character, minimal quoting,
    and strict parsing (malformed quoting raises ``csv.Error``).
    """

    # separators and quoting
    delimiter = ","
    quotechar = '"'
    quoting = csv.QUOTE_MINIMAL
    doublequote = True
    escapechar = "\\"
    # record terminator (the csv reader itself accepts either \r or \n)
    lineterminator = "\r\n"
    # raise on bad CSV input instead of guessing
    strict = True
    
def convert_ndc(ndc11: str):
    """Turn an 11-character NDC string into the hyphenated 9-digit lookup key.

    The first five characters are joined to the remainder with a hyphen and
    the final two characters are dropped, e.g. ``"12345678901"`` ->
    ``"12345-6789"``. (Presumably the 5-4-2 labeler/product/package layout,
    with the package segment removed.)
    """
    labeler_part = ndc11[:5]
    remainder = ndc11[5:]
    hyphenated = "-".join((labeler_part, remainder))
    return hyphenated[:-2]
    
def generate_ddl_scripts(schemas: dict):
    """Build one ``CREATE TABLE IF NOT EXISTS`` statement per schema.

    Args:
        schemas: mapping of table name -> ``{"fields": [{"name": ..., "type": ...}, ...]}``.
            Statements are produced in the mapping's iteration order.

    Returns:
        list[str]: DDL statements. Column names are lower-cased and
        double-quoted; column types are copied verbatim from ``"type"``.
    """
    ddl_list = []
    # Use the schemas passed in (not a module global) so callers control
    # exactly which tables get created.
    for schema_type, schema in schemas.items():
        column_defs = ", ".join(
            f'"{field["name"].lower()}" {field["type"]}' for field in schema["fields"]
        )
        ddl_list.append(
            f'create table if not exists "{schema_type}" (\n    {column_defs}\n)'
        )
    return ddl_list
              
def establish_sqlite_db(
    path: t.Optional[t.Union[Path, str]] = None,
    schemas: dict = schema_dict) -> sqlite3.Connection:
    """Create a fresh SQLite database with one table per schema and connect to it.

    Args:
        path: database file location. An existing regular file at this path
            is deleted first, so every run starts from an empty database.
            ``None`` (the default) or ``":memory:"`` opens an in-memory
            database; ``""`` opens sqlite's private temporary database.
        schemas: table definitions handed to ``generate_ddl_scripts``.

    Returns:
        sqlite3.Connection opened with ``check_same_thread=False``.
    """
    if path is None:
        # Path(None) and sqlite3.connect(None) both raise TypeError.
        path = ":memory:"

    # Only delete real files; never try to unlink a directory or a special name.
    if str(path) not in ("", ":memory:") and Path(path).is_file():
        print(f"Database {path!s} already exists. Deleting..")
        Path(path).unlink()

    conn = sqlite3.connect(path, check_same_thread=False)
    print(f"DB setup at {path}")

    # generate schema ddl and add tables
    cur = conn.cursor()
    try:
        for ddl in generate_ddl_scripts(schemas):
            cur.execute(ddl)
        conn.commit()
    finally:
        cur.close()
    return conn

def get_line_count(path: Path) -> int:
    """Return the number of lines in the file at *path*.

    The file is read in 64 KiB binary chunks, so large files are counted
    without loading them fully into memory. Every newline byte ends a line,
    and a final line without a trailing newline is also counted: both
    ``b"a\\nb"`` and ``b"a\\nb\\n"`` contain 2 lines.

    NOTE: this counts physical lines. A CSV field containing an embedded
    newline makes the result larger than the number of CSV records.
    """
    newline_count = 0
    last_byte = b""
    with open(path, "rb") as handle:
        while True:
            chunk = handle.read(2**16)
            if not chunk:
                break
            newline_count += chunk.count(b"\n")
            last_byte = chunk[-1:]
    # A non-empty file that doesn't end in a newline has one more (unterminated) line.
    if last_byte and last_byte != b"\n":
        newline_count += 1
    return newline_count

def csv_to_ordered_dict(csv_path: Path, fields: list):
    """Read *csv_path* into row dicts, setting aside rows with the wrong shape.

    Rows are parsed with ``csv.DictReader`` (keys come from the file's header
    line) using ``CSVDialect``. A row is quarantined when:
      * the header's column count differs from ``len(fields)``;
      * it has more values than the header (DictReader stores the surplus
        under the ``None`` key); or
      * it has fewer values than the header (DictReader fills the missing
        columns with ``None``).

    Args:
        csv_path: CSV file to read.
        fields: expected field names (only their count is checked).

    Returns:
        tuple[list[dict], list[dict]]: (accepted rows, quarantined rows).
    """
    retval, quarantine = [], []
    with open(csv_path, newline="") as csvfile:
        reader = csv.DictReader(csvfile, dialect=CSVDialect)
        for rownum, row in enumerate(reader, start=1):
            malformed = (
                len(row) != len(fields)
                or None in row            # extra values beyond the header
                or None in row.values()   # row shorter than the header
            )
            if malformed:
                print(f'Row {rownum}: Wrong number of fields found')
                quarantine.append(row)
            else:
                retval.append(row)
    return retval, quarantine

def _normalize_date(value: str, fmt: str) -> str:
    """Return *value* as a string in *fmt*, re-parsing it with dateutil if it doesn't already match."""
    try:
        datetime.datetime.strptime(value, fmt)
        return value
    except ValueError:
        # Store a formatted string rather than a datetime: sqlite3 would write
        # the datetime as 'YYYY-MM-DD HH:MM:SS', which no longer compares equal
        # to (or sorts correctly against) plain 'YYYY-MM-DD' values.
        return parse(value).strftime(fmt)


def _clean_record(record: dict, fields: list):
    """Return an OrderedDict of stripped, date-normalized values in schema order, or None if a date can't be parsed."""
    cleaned = OrderedDict()
    for field in fields:
        raw = record[field["name"]]
        value = "" if raw is None else raw.strip()
        if field["type"] == "date" and value != "":
            try:
                value = _normalize_date(value, field["format"])
            except (ValueError, OverflowError):
                print(f'Unparseable {field["name"]} value {value!r}; record quarantined')
                return None
        cleaned[field["name"]] = value
    return cleaned


def process_data_file(data_path: Path, schema: dict):
    """Read, clean and validate one data file; isolate problematic records.

    Each accepted CSV row is turned into an ``OrderedDict`` in schema field
    order. Every value is whitespace-stripped, and ``"date"`` fields that don't
    match their declared ``"format"`` are re-parsed with dateutil and rewritten
    in that format.

    Args:
        data_path: CSV file to process.
        schema: table schema (``{"fields": [...]}``) describing the file.

    Returns:
        tuple[list[OrderedDict], list[dict]]: (clean rows, quarantined rows).
        The quarantine holds rows with the wrong field count plus rows that
        contain an unparseable date.
    """
    fields = schema["fields"]
    fieldnames = [f["name"] for f in fields]
    data, quarantine = csv_to_ordered_dict(data_path, fieldnames)

    rows = []
    for record in data:
        cleaned = _clean_record(record, fields)
        if cleaned is None:
            quarantine.append(record)
        else:
            rows.append(cleaned)
    return rows, quarantine


def load_data_to_db(
    data: t.Iterable[t.Mapping[str, t.Any]],
    schema_type: str,
    schema: dict,
    db_path: Path,
    conn: sqlite3.Connection):
    """Insert processed records into the ``schema_type`` table and commit.

    Args:
        data: records keyed by schema field name (e.g. the rows produced by
            ``process_data_file``).
        schema_type: destination table name.
        schema: schema whose ``"fields"`` define the column list and order.
        db_path: unused; kept for call-site compatibility.
        conn: open connection to write through.

    Rows are inserted one at a time. A row that is missing a field, or that
    sqlite rejects, is reported and skipped instead of aborting the load.
    """
    field_names = [f["name"] for f in schema["fields"]]
    ordered_fields = ", ".join(f'"{name}"' for name in field_names)
    qmarks = ",".join(["?"] * len(field_names))
    insert_stmt = f'insert into "{schema_type}" ({ordered_fields}) values ({qmarks});'

    cur = conn.cursor()
    try:
        for record_dict in data:
            try:
                # Bind by field name so column alignment never depends on dict order.
                record = [record_dict[name] for name in field_names]
                cur.execute(insert_stmt, record)
            except KeyError as missing:
                print(f"Skipping record without field {missing} for {schema_type}")
            except sqlite3.Error as sql_er:
                # str(err) is safe even when err.args holds non-str values.
                print(f"SQLiteError: {sql_er.__class__} - {sql_er}")
    finally:
        cur.close()
    print(f"StopIteration Encountered. Ending db load for {schema_type}.")
    conn.commit()
    
    
def ingest_data_files(schemas: dict, file_type: list):
    """Create ``testDB.db`` in the working directory and load every matching CSV into it.

    For each name in *file_type*, every file matching ``**/<name>*.csv`` under
    the current working directory is cleaned with ``process_data_file`` and
    loaded into the table of the same name using ``schemas[name]``. Names with
    no schema entry, or with no matching files, are skipped.

    Args:
        schemas: table schemas keyed by table name.
        file_type: table / file-name prefixes to ingest, in load order.

    Returns:
        sqlite3.Connection to the freshly created database.
    """
    # setup db
    db_path = Path.cwd().joinpath('testDB.db')
    conn = establish_sqlite_db(db_path, schemas)

    for tfile in file_type:
        schema = schemas.get(tfile)
        if schema is None:
            print(f"No schema defined for {tfile}; skipping")
            continue

        # sorted() keeps the load order deterministic across filesystems
        for data_path in sorted(Path.cwd().glob(f'**/{tfile}*.csv')):
            line_count = get_line_count(data_path)
            rowdata, quarantine = process_data_file(data_path, schema)
            # minus one for the header line
            print(f"Starting db load of {data_path.name}: {max(line_count - 1, 0)} records expected")
            if quarantine:
                print(f"{len(quarantine)} record(s) quarantined from {data_path.name}")
            load_data_to_db(rowdata, tfile, schema, db_path, conn)
    return conn

### Process and load files into db

In [None]:
# Build the database and load every input file; later cells query `conn`.
conn = ingest_data_files(schemas=schema_dict, file_type=file_type)

### Perform analysis

#### Question 1. How many patients were enrolled in the program as of July 1st, 2020?

In [None]:
# Patients whose enrollment span covers 2020-07-01: they started on or before
# that date and either have no end date (loaded as '' or NULL) or end on/after it.
cur = conn.cursor()
cur.execute("""
select count(distinct card_id)
from enrollment
where enrollment_start_date <= date('2020-07-01')
  and (enrollment_end_date is null
       or enrollment_end_date = ''
       or enrollment_end_date >= date('2020-07-01'));
""")
cur.fetchone()[0]

#### Question 2: How many rows are there in the initial pharmacy claims data set?

##### This response assumes the initial pharmacy claims data set includes all of the supplied files.

In [None]:
# Total number of loaded pharmacy claim rows across all input files.
cur = conn.execute("select count(*) from pharmacy_claims;")
cur.fetchone()[0]

#### Question 3: How many prepared claims do you have at the end of step 3?

In [None]:
# Prepare claims into `analysis`:
#   * drop DENIED claims;
#   * drop PAID claims whose matching REVERSAL (same card, claim number, fill
#     date and NDC) nets days supply to <= 0, and drop those matched reversals;
#   * keep only rows with a positive days supply.
# NOTE: EXCEPT returns distinct rows, so byte-identical claim rows collapse into one.
cur = conn.cursor()
# Empty the table first so re-running this cell doesn't append duplicates.
cur.execute("delete from analysis;")
insert_query = """
insert into analysis (card_id, claim_number, fill_date, ndc_11, claim_status, days_supply, allowed_amount)
WITH paid AS (
    SELECT
        card_id,
        claim_number,
        fill_date,
        ndc_11,
        claim_status,
        allowed_amount,
        days_supply
    FROM
        pharmacy_claims
    WHERE
        claim_status in('PAID')
),
reversal AS (
    SELECT
        card_id,
        claim_number,
        fill_date,
        ndc_11,
        claim_status,
        allowed_amount,
        days_supply
    FROM
        pharmacy_claims
    WHERE
        claim_status in('REVERSAL')
),
paid_reversed AS (
    SELECT
        p.card_id,
        p.claim_number,
        p.fill_date,
        p.ndc_11,
        p.claim_status,
        r.claim_status AS r_claim_status,
        p.days_supply,
        r.days_supply AS r_days_supply,
        p.days_supply + r.days_supply AS net_days_supply,
        p.allowed_amount,
        r.allowed_amount AS r_allowed_amount,
        p.allowed_amount + r.allowed_amount AS net_allowed_amount
    FROM
        paid AS p
        INNER JOIN reversal r ON p.card_id = r.card_id
            AND p.claim_number = r.claim_number
            AND p.fill_date = r.fill_date
            AND p.ndc_11 = r.ndc_11
)
SELECT
    *
FROM (
    SELECT
        ph.card_id,
        ph.claim_number,
        ph.fill_date,
        ph.ndc_11,
        ph.claim_status,
        ph.days_supply,
        ph.allowed_amount
    FROM
        pharmacy_claims ph
WHERE
    claim_status <> 'DENIED'
EXCEPT
SELECT
    pr.card_id,
    pr.claim_number,
    pr.fill_date,
    pr.ndc_11,
    pr.claim_status,
    pr.days_supply,
    pr.allowed_amount
FROM
    paid_reversed AS pr
WHERE
    claim_status = 'PAID'
    AND pr.net_days_supply <= 0
EXCEPT
SELECT
    pr2.card_id,
    pr2.claim_number,
    pr2.fill_date,
    pr2.ndc_11,
    pr2.r_claim_status AS claim_status,
    pr2.r_days_supply AS days_supply,
    pr2.r_allowed_amount AS allowed_amount
FROM
    paid_reversed AS pr2
WHERE
    pr2.r_claim_status = 'REVERSAL') AS foo
WHERE
    -- days_supply is a TEXT column; compare numerically, not lexically
    CAST(foo.days_supply AS REAL) > 0;
"""
cur.execute(insert_query)
conn.commit()
cur.execute('select count(*) from analysis;')
cur.fetchone()[0]

#### Question 4: What is the highest allowed_amount? Which patient and generic drug does it correspond to?

In [None]:
# Lookup table keyed by hyphenated 9-digit NDC (the form convert_ndc produces).
ndc_dict = get_json_from_file(Path.cwd() / 'ndc9_lookup.json')

In [None]:
# Per patient and NDC, total days supply and allowed amount for fills in H1 2020.
# Enrollment card ids carry an "ID" prefix that the claims data lacks, so it is
# stripped before joining. String literals use single quotes: in SQLite,
# double quotes denote identifiers.
patient_med_query = """
WITH patient AS (
    SELECT
        replace(e.card_id,
            'ID',
            '') AS card_id,
        e.first_name,
        e.last_name,
        e.gender
    FROM
        enrollment e
    GROUP BY
        e.card_id,
        e.first_name,
        e.last_name,
        e.gender
),
meds AS (
    SELECT
        card_id,
        ndc_11,
        sum(days_supply) AS days_supply,
        sum(allowed_amount) AS allowed_amount
    FROM
        analysis
    WHERE
        fill_date BETWEEN date('2020-01-01')
        AND date('2020-06-30')
    GROUP BY
        card_id,
        ndc_11
)
SELECT
    pt.card_id, pt.first_name, pt.last_name, md.ndc_11, md.days_supply, md.allowed_amount
FROM
    patient pt
    JOIN meds md ON pt.card_id = md.card_id;
"""
cur = conn.cursor()
cur.execute(patient_med_query)
pt_med_results = cur.fetchall()

if not pt_med_results:
    print('no patient medication totals found for 2020-01-01 to 2020-06-30')
else:
    # row layout: (card_id, first_name, last_name, ndc_11, days_supply, allowed_amount);
    # max() returns the first row among ties, like a strict '>' scan.
    highest_allowed = max(pt_med_results, key=lambda row: row[5])
    ndc9 = convert_ndc(highest_allowed[3])
    generic_med = ndc_dict.get(ndc9, {}).get('genericName', f'<generic not found for ndc {ndc9}>')
    print(f'highest allowed_amount: {highest_allowed[5]}')
    print(f'highest allowed_amount patient: {highest_allowed[1]} {highest_allowed[2]} {generic_med}')

#### Question 5: How many unique generic names are there for the patient Abe Lincoln?

In [None]:
# Per-NDC H1 2020 totals for patients enrolled as of 2020-07-01. An empty end
# date is loaded as '' (or may be NULL), and "= NULL" never matches in SQL,
# so both cases are tested explicitly. String literals use single quotes
# because SQLite treats double-quoted text as identifiers.
pt_summary_query = """
WITH patient AS (
    SELECT DISTINCT
        replace(e.card_id,
            'ID',
            '') AS card_id,
        e.first_name,
        e.last_name,
        e.gender,
        e.enrollment_start_date,
        e.enrollment_end_date
    FROM
        enrollment e
    WHERE
        enrollment_start_date <= date('2020-07-01')
        AND (enrollment_end_date IS NULL
            OR enrollment_end_date = ''
            OR enrollment_end_date >= date('2020-07-01'))
),
meds AS (
    SELECT
        card_id,
        ndc_11,
        sum(days_supply) AS days_supply,
        sum(allowed_amount) AS allowed_amount
    FROM
        analysis
    WHERE
        fill_date BETWEEN date('2020-01-01')
        AND date('2020-06-30')
    GROUP BY
        card_id,
        ndc_11
)
SELECT
    pt.card_id,
    pt.first_name,
    pt.last_name,
    md.ndc_11,
    md.days_supply,
    md.allowed_amount
FROM
    patient pt
    JOIN meds md ON pt.card_id = md.card_id;
"""

cur = conn.cursor()
cur.execute(pt_summary_query)
pt_summary_results = cur.fetchall()

In [None]:
# build a hash table of patients
pt_hash = {}
template = {
    "first_name": "",
    "last_name": "",
    "med_summary": {}
}

med_template = {
    "allowed_amount": None,
    "days_supply": None
}
while pt_summary_results:
    card_id = pt_summary_results[0][0]
    try:
        pt_hash[card_id]
    except KeyError:
        pt_template = copy.deepcopy(template)
        pt_template["first_name"] = pt_summary_results[0][1]
        pt_template["last_name"] = pt_summary_results[0][2]
        pt_hash[card_id] = pt_template
    
    ndc9 = convert_ndc(pt_summary_results[0][3])
    try:
        drug_info = ndc_dict[ndc9]
        generic = drug_info['genericName']
        
        pt_med_template = copy.deepcopy(med_template)
        pt_med_template["allowed_amount"] = pt_summary_results[0][5]
        pt_med_template["days_supply"] = pt_summary_results[0][4]
        pt_hash[card_id]["med_summary"][generic] = pt_med_template
    except KeyError:
        print(f'generic drug not found for ndc: {ndc9}')
    pt_summary_results.pop(0)

In [None]:
# Write one JSON summary per patient to ./results and report Abe Lincoln's
# number of distinct generic medications.
results_dir = Path.cwd().joinpath("results")
# open() fails if the directory is missing, so create it up front.
results_dir.mkdir(parents=True, exist_ok=True)
for key, summary in pt_hash.items():
    with open(results_dir.joinpath(f"patient-ID{key}.json"), "w") as writer:
        json.dump(summary, writer, indent=4)

    if summary["first_name"] == "Abe" and summary["last_name"] == "Lincoln":
        print(f"Abe Lincoln has [ {len(summary['med_summary'])} ] generic medications")