In [1]:
from sklearn.preprocessing import OneHotEncoder
from pymongo import MongoClient
from zipfile import ZipFile, ZIP_DEFLATED
from scipy.sparse import *
from sets import Set
import numpy as np
import json, sys, os, time, re, datetime

def logTime():
    """Return the current local date and time as text, for prefixing log lines."""
    now = datetime.datetime.now()
    # isoformat with a space separator produces the same text as str(now).
    return now.isoformat(' ')

In [92]:
%reload_ext autoreload
%autoreload 2
from aca_drug_feature import *
from aca_plan_feature import *
from aca_provider_feature import *

In [3]:
# Pick the MongoDB deployment: True -> on-premise host, False -> AWS instance.
local = False
if local:
    client = MongoClient('fc8iasm01', 27017)
    plan_col = client.aca.plan
    drug_col = client.aca.drug
    # No provider/facility collections are configured for the local database
    # in this notebook. Bind the names anyway so this cell no longer fails with
    # a NameError; cells that query providers still need the AWS database.
    prov_col = None
    faci_col = None
else:
    client = MongoClient('ec2-54-153-83-172.us-west-1.compute.amazonaws.com', 27017)
    plan_col = client.plans.plans
    drug_col = client.formularies.drugs
    prov_col = client.providers.providers
    faci_col = client.providers.facilities

# Distinct ids referenced by the drug collection; later cells index into these lists.
all_plan = drug_col.distinct('plans.plan_id')
all_drug = drug_col.distinct('rxnorm_id')

# Explicit "is not None" check: a collection handle should not be truth-tested.
print('%s: using %s Mongo, total drug: %d, total plan: %d, total provider: %d' % (
    logTime(), 'local' if local else 'aws', len(all_drug), len(all_plan),
    prov_col.count() if prov_col is not None else 0))
# client.formularies.collection_names()
# client.providers.collection_names()

Using aws Mongo, total drug: 46206, total plan: 6035


In [5]:
# Report document counts per collection and the sizes of the distinct-id lists.
# Each entry pairs the exact report line with a lazy getter, so the timestamp is
# still taken right before the corresponding number is fetched.
size_report = [
    ('%s: plan document: %d', lambda: plan_col.count()),
    ('%s: drug document: %d', lambda: drug_col.count()),
    ('%s: provider document: %d', lambda: prov_col.count()),
    ('%s: facility document: %d', lambda: faci_col.count()),
    ('%s: unique plan_id: %d', lambda: len(all_plan)),
    ('%s: unique rxnorm_id: %d', lambda: len(all_drug)),
]
for line_fmt, fetch_size in size_report:
    print(line_fmt % (logTime(), fetch_size()))

# multi_plan = [1 for p in plan_col.aggregate([{"$group": {"_id":"$plan_id", "count":{"$sum":1}}}]) if p['count']>1]
# print '%s: plans with multiple documents: %d' %(logTime(), sum(multi_plan))

# multi_drug = [1 for p in drug_col.aggregate([{"$group": {"_id":"$rxnorm_id", "count":{"$sum":1}}}]) if p['count']>1]
# print '%s: drugs with multiple documents: %d' %(logTime(), sum(multi_drug))

# Characters 5-6 of every plan id are collected as the state code.
state_slices = [plan_id[5:7] for plan_id in all_plan]
state_id = np.unique(state_slices)
print('%s: states in the plan: %s' % (logTime(), ', '.join(state_id)))

2016-07-09 09:32:28.660514: plan document: 12136
2016-07-09 09:32:28.744694: drug document: 1540473
2016-07-09 09:32:28.832037: provider document: 8799098
2016-07-09 09:32:28.931804: facility document: 4815321
2016-07-09 09:32:29.014277: unique plan_id: 6035
2016-07-09 09:32:29.014387: unique rxnorm_id: 46206
2016-07-09 09:32:29.016927: states in the plan: AK, AL, AR, AZ, CO, DE, FL, GA, HI, IA, IL, IN, KS, KY, LA, MA, ME, MI, MN, MO, MS, MT, NC, ND, NE, NH, NJ, NM, NV, OH, OK, OR, PA, SC, SD, TN, TX, UT, VA, WA, WI, WV, WY


### Main program

In [94]:
def _first_index(items):
    """Map every item to the position of its first occurrence in `items`.

    A lookup in the returned dict gives the same answer as items.index(item)
    without rescanning the list each time. Items must be hashable.
    """
    lookup = {}
    for pos, item in enumerate(items):
        lookup.setdefault(item, pos)
    return lookup


state = 'OR' # set to None to include all (very slow process for all)
# Match on the state slice of the plan id (characters 5-6, the same slice the
# collection summary uses) rather than a substring search over the whole id.
ex_id = all_plan if not state else [i for i in all_plan if i[5:7] == state]
n_plan = len(ex_id)
# Row of each plan in every per-plan matrix built below.
plan_row = _first_index(ex_id)
print('%s: processing %d plans for %s' % (logTime(), len(ex_id), 'all' if not state else state))

print('%s: 1/11 get formulary state space for all plans' % logTime())
all_plan_states = getFormularyAllStates1(plan_col, ex_id) + \
                  getFormularyAllStates2(plan_col, ex_id) + \
                  getFormularyAllStates3(plan_col, ex_id)
print('%s: total formulary states: %d' % (logTime(), len(all_plan_states)))

print('%s: 2/11 extract formulary states for each plan' % logTime())
plan_feature = lil_matrix((n_plan, len(all_plan_states)))
valid_plan1 = []
for f in [getFormularyStatesForPlan1,getFormularyStatesForPlan2,getFormularyStatesForPlan3]:
    for p in f(plan_col, ex_id):
        r_id = plan_row[p['_id']]
        valid_plan1.append(p['_id'])
        for s in p['plan_states']:
            # State values come from helper modules not shown here and may not
            # be hashable, so the (short) state list keeps using list.index.
            plan_feature[r_id, all_plan_states.index(s)] = 1
print('%s: complete for %d plans' % (logTime(), len(valid_plan1)))

# print '%s: 3/11 get formulary summary feature for each plan' %logTime()
# plan_sumstat = [[0]*3]*n_plan
# for p in getFormularyAggregate(plan_col, ex_id):
#     r_id = ex_id.index(p['plan'])
#     plan_sumstat[r_id] = [p['avg_copay'],p['avg_ci_rate'],p['count']]
# print '%s: complete for %d plans' %(logTime(), i)

print('%s: 4/11 get all drugs covered by all plans' % logTime())
all_rxnorm = drug_col.find({'plans.plan_id':{'$in':valid_plan1}}).distinct('rxnorm_id')
print('%s: total rx: %d' % (logTime(), len(all_rxnorm)))

print('%s: 5/11 check drug coverage for each plan' % logTime())
drug_coverage = lil_matrix((n_plan, len(all_rxnorm)))
# Column of each rxnorm id; replaces a linear scan of all_rxnorm per covered drug.
rx_col = _first_index(all_rxnorm)
valid_plan2 = []
for p in getDrugListForPlans(drug_col, valid_plan1):
    valid_plan2.append(p['plan'])
    r_id = plan_row[p['plan']]
    for r in p['drug']:
        drug_coverage[r_id, rx_col[r]] = 1
print('%s: complete for %d plans' % (logTime(), len(valid_plan2)))

print('%s: 6/11 get summary feature for drug' % logTime())
all_drug_states = getDrugAggregateAllStates(drug_col, valid_plan2)
print('%s: total drug states: %d' % (logTime(), len(all_drug_states)))

print('%s: 7/11 extract drug sumstat for each plan' % logTime())
drug_sumstat = lil_matrix((n_plan, len(all_drug_states)))
valid_plan3 = []
for p in getDrugAggregateCountForPlans(drug_col, valid_plan2):
    valid_plan3.append(p['plan'])
    r_id = plan_row[p['plan']]
    for d in p['drug_state']:
        # Key type is defined by the helper module; keep list.index (short list).
        drug_sumstat[r_id, all_drug_states.index(d['key'])] = d['cnt']
print('%s: complete for %d plans' % (logTime(), len(valid_plan3)))

print('%s: 8/11 get provider under the plans' % logTime())
all_npi = prov_col.find({'plans.plan_id':{'$in':valid_plan3}}).distinct('npi')
print('%s: total providers: %d' % (logTime(), len(all_npi)))

# Previously the slowest step: every provider triggered a scan of all_npi.
print('%s: 9/11 check provider coverage for each plan' % logTime())
provider_coverage = lil_matrix((n_plan, len(all_npi)))
# Column of each NPI in provider_coverage.
npi_col = _first_index(all_npi)
valid_plan4 = []
for p in getProviderListForPlans(prov_col, valid_plan3):
    valid_plan4.append(p['plan'])
    r_id = plan_row[p['plan']]
    for npi in p['npi']:
        provider_coverage[r_id, npi_col[npi]] = 1
print('%s: complete for %d plans' % (logTime(), len(valid_plan4)))

print('%s: 10/11 get summary feature for provider' % logTime())
all_provider_states = getProviderAllStates(prov_col, valid_plan4)
print('%s: total provider summary: %d' % (logTime(), len(all_provider_states)))

print('%s: 11/11 extract provider sumstat for each plan' % logTime())
provider_sumstat = lil_matrix((n_plan, len(all_provider_states)))
valid_plan5 = []
for p in getProviderStateForPlans(prov_col, valid_plan4):
    r_id = plan_row[p['_id']]
    valid_plan5.append(p['_id'])
    for d in p['plan_states']:
        # Key type is defined by the helper module; keep list.index (short list).
        provider_sumstat[r_id, all_provider_states.index(d['key'])] = d['count'] #[d['count'], d['location']]
print('%s: complete for %d plans' % (logTime(), len(valid_plan5)))

2016-07-09 13:09:18.225952: processing 190 plans for OR
2016-07-09 13:09:18.226057: 1/11 get formulary state space for all plans
2016-07-09 13:09:18.555002: total formulary states: 107
2016-07-09 13:09:18.555753: 2/11 extract formulary states for each plan
2016-07-09 13:09:19.195953: complete for 172 plans
2016-07-09 13:09:19.196544: 4/11 get all drugs covered by all plans
2016-07-09 13:09:22.564035: total rx: 10632
2016-07-09 13:09:22.564778: 5/11 check drug coverage for each plan
2016-07-09 13:10:39.341875: complete for 172 plans
2016-07-09 13:10:39.342471: 6/11 get summary feature for drug
2016-07-09 13:10:50.751716: total drug states: 82
2016-07-09 13:10:50.752472: 7/11 extract drug sumstat for each plan
2016-07-09 13:11:03.301169: complete for 172 plans
2016-07-09 13:11:03.301870: 8/11 get provider under the plans
2016-07-09 13:11:06.175087: total providers: 41031
2016-07-09 13:11:06.175636: 9/11 check provider coverage for each plan
2016-07-09 13:23:10.280560: complete for 164 pl

In [125]:
# combine features
# Glue the five per-plan blocks together column-wise in one pass, then keep only
# the rows of the plans that made it through the last step (valid_plan5), in
# that order. This replaces building the result one row at a time.
feature_mat = [plan_feature, drug_coverage, drug_sumstat, provider_coverage, provider_sumstat]
n_fea = sum(m.shape[1] for m in feature_mat)
kept_rows = [ex_id.index(plan_id) for plan_id in valid_plan5]
# Row selection is done on CSR; the result is converted back to LIL so
# total_feature keeps the matrix type it had before.
total_feature = hstack(feature_mat, format='csr')[kept_rows].tolil()
print('%s: feature dimension: %s' % (logTime(), total_feature.shape))


2016-07-09 13:37:29.888316: feature dimension: (164, 56716)


### Sparse matrix manual mode

In [127]:
# initialize as lil
# The definition has to run in this cell: with it commented out, a fresh kernel
# has no `test` and the lookup below raises NameError.
test = lil_matrix((3,18))
test[2] = list(range(18))  # fill row 2 with 0..17
test[2,10]

10.0

In [128]:
# convert the assembled feature matrix to CSR
t2 = total_feature.tocsr()

In [11]:
# Release the MongoDB connection before wiping the namespace.
client.close()
# %reset (without -f) asks for confirmation before deleting all variables.
%reset

Once deleted, variables cannot be recovered. Proceed (y/[n])? y
