In [37]:
import math
import pandas as pd
import numpy as np

In [38]:
data = pd.read_csv('data/warfarin.csv')
data = data[data['Therapeutic Dose of Warfarin'].notnull()]

In [39]:
data = data.reset_index(drop=True) # fix weird row indexing 

In [40]:
def assign_bucket(n):
    if n < 21:
        return 0
    elif n > 49:
        return 2
    else:
        return 1


In [41]:
def create_feature_vec(row):
    x = np.zeros([8,1])
    index = row.index[0]
    
    # Age 
    age =  row.loc[index, 'Age']
    if type(age) != float:
        x[0] = int(age[0])
    
    # Height 
    height = row.loc[index,'Height (cm)']
    if not math.isnan(height):
        x[1] = height
    
    # Weight 
    weight = row.loc[index,'Weight (kg)']
    if not math.isnan(weight):
        x[2] = weight 
    
    race = row.loc[index,'Race']
    # Race
    if race == 'Asian':
        x[3] = 1
    if race == 'Black or African American':
        x[4] = 1
    if race == 'Unknown':
        x[5] = 1

    # enzyme inducer 
    carbamazepine = int(row.loc[index,'Carbamazepine (Tegretol)']) if not math.isnan(row.loc[index,'Carbamazepine (Tegretol)']) else 0
    phenytoin = int(row.loc[index,'Phenytoin (Dilantin)']) if not math.isnan(row.loc[index,'Phenytoin (Dilantin)']) else 0
    rifampin = int(row.loc[index,'Rifampin or Rifampicin']) if not math.isnan(row.loc[index,'Rifampin or Rifampicin']) else 0

    if carbamazepine or phenytoin or rifampin:
        x[6] = 1
    
    # amiodarone 
    amiodarone = int(row.loc[index,'Amiodarone (Cordarone)']) if not math.isnan(row.loc[index,'Amiodarone (Cordarone)']) else 0
    if amiodarone: 
        x[7] = 1
    
    return x



In [42]:
def new_create_feature_vec(row):
    x = np.zeros([17,1])
    index = row.index[0]
    
    # Age 
    age =  row.loc[index, 'Age']
    if type(age) != float:
        x[0] = int(age[0])
    
    # Height 
    height = row.loc[index,'Height (cm)']
    if not math.isnan(height):
        x[1] = height
    
    # Weight 
    weight = row.loc[index,'Weight (kg)']
    if not math.isnan(weight):
        x[2] = weight 

    # VKORC1 
    vkorc1_ag = row.loc[index, 'VKORC1 genotype: -1639 G>A (3673); chr16:31015190; rs9923231; C/T']
    if type(vkorc1_ag) != float:
        if vkorc1_ag == 'A/G':
            x[3] = 1
        if vkorc1_ag == 'A/A':
            x[4] = 1
        if vkorc1_ag == 'NA':
            x[5] = 1
    
    # CYP2C9 
    cyp2c9 = row.loc[index, 'Cyp2C9 genotypes']
    if type(cyp2c9) != float:
        if cyp2c9 == '*1/*2':
            x[6] = 1
        if cyp2c9 == '*1/*3':
            x[7] = 1
        if cyp2c9 == '*2/*2':
            x[8] = 1
        if cyp2c9 == '*2/*3':
            x[9] = 1
        if cyp2c9 == '*3/*3':
            x[10] = 1
        if cyp2c9 == 'NA':
            x[11] = 1
    
    race = row.loc[index,'Race']
    # Race
    if race == 'Asian':
        x[12] = 1
    if race == 'Black or African American':
        x[13] = 1
    if race == 'Unknown':
        x[14] = 1

    # enzyme inducer 
    carbamazepine = int(row.loc[index,'Carbamazepine (Tegretol)']) if not math.isnan(row.loc[index,'Carbamazepine (Tegretol)']) else 0
    phenytoin = int(row.loc[index,'Phenytoin (Dilantin)']) if not math.isnan(row.loc[index,'Phenytoin (Dilantin)']) else 0
    rifampin = int(row.loc[index,'Rifampin or Rifampicin']) if not math.isnan(row.loc[index,'Rifampin or Rifampicin']) else 0

    if carbamazepine or phenytoin or rifampin:
        x[15] = 1
    
    # amiodarone 
    amiodarone = int(row.loc[index,'Amiodarone (Cordarone)']) if not math.isnan(row.loc[index,'Amiodarone (Cordarone)']) else 0
    if amiodarone: 
        x[16] = 1
    
    return x



In [43]:
doses = data['Therapeutic Dose of Warfarin']
doses = doses.apply(lambda dose : assign_bucket(dose)).to_numpy()
print(len(doses))
print(len(data))

5528
5528


In [48]:
r_list = []
action_list = []
T = len(data)
d = 17
R = 1
eps = 1 / np.log(T)
delta = .1
v = R * np.sqrt(24 / eps * d * np.log(1 / delta))

B_list = [np.identity(d), np.identity(d), np.identity(d)]
mu_hat_list = [np.zeros(d), np.zeros(d), np.zeros(d)]
f_list = [np.zeros(d), np.zeros(d), np.zeros(d)]
for t in range(T):
    # row = data.sample(replace=False)
    row = data.loc[[t]]
    x = new_create_feature_vec(row)
    mu_tildes = [np.random.multivariate_normal(mu_hat, v*v*np.linalg.inv(B)) for mu_hat, B in zip(mu_hat_list, B_list)]
    arm_scores = [x.T @ mu_tilde for mu_tilde in mu_tildes]

    best_a = np.argmax(arm_scores)    
    correct_dose = doses[row.index[0]]

    r = 0
    if best_a != correct_dose: 
        r = -1
    r_list.append(r)
    action_list.append(best_a)
    B_list[best_a] += x @ x.T
    f_list[best_a] += x.squeeze() * r
    mu_hat_list[best_a] = np.linalg.inv(B_list[best_a]) @ f_list[best_a]

Traceback (most recent call last):
  File "_pydevd_bundle/pydevd_cython.pyx", line 1078, in _pydevd_bundle.pydevd_cython.PyDBFrame.trace_dispatch
  File "_pydevd_bundle/pydevd_cython.pyx", line 297, in _pydevd_bundle.pydevd_cython.PyDBFrame.do_wait_suspend
  File "/usr/local/lib/python3.9/site-packages/debugpy/_vendored/pydevd/pydevd.py", line 1976, in do_wait_suspend
    keep_suspended = self._do_wait_suspend(thread, frame, event, arg, suspend_type, from_this_thread, frames_tracker)
  File "/usr/local/lib/python3.9/site-packages/debugpy/_vendored/pydevd/pydevd.py", line 2011, in _do_wait_suspend
    time.sleep(0.01)
KeyboardInterrupt


KeyboardInterrupt: 

In [46]:
sum(r_list) / len(data)


-0.6635311143270622

In [47]:
action_list

[0,
 2,
 1,
 1,
 1,
 1,
 0,
 0,
 2,
 0,
 1,
 0,
 2,
 2,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 1,
 2,
 1,
 0,
 0,
 1,
 1,
 1,
 2,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 2,
 0,
 1,
 0,
 0,
 2,
 1,
 2,
 0,
 1,
 2,
 0,
 1,
 1,
 1,
 1,
 2,
 2,
 0,
 0,
 0,
 1,
 2,
 2,
 2,
 2,
 0,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 0,
 0,
 1,
 2,
 0,
 0,
 0,
 0,
 2,
 1,
 1,
 1,
 1,
 2,
 0,
 2,
 1,
 0,
 0,
 1,
 1,
 2,
 2,
 1,
 0,
 0,
 2,
 2,
 1,
 0,
 2,
 2,
 2,
 0,
 1,
 0,
 1,
 2,
 2,
 1,
 0,
 2,
 1,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 0,
 2,
 0,
 1,
 2,
 2,
 0,
 1,
 2,
 2,
 0,
 2,
 0,
 2,
 2,
 0,
 2,
 0,
 1,
 1,
 2,
 1,
 2,
 0,
 2,
 1,
 2,
 0,
 2,
 0,
 0,
 2,
 2,
 0,
 1,
 0,
 1,
 2,
 2,
 2,
 1,
 1,
 0,
 0,
 0,
 2,
 1,
 0,
 0,
 1,
 2,
 1,
 2,
 2,
 1,
 1,
 2,
 1,
 1,
 1,
 2,
 0,
 2,
 1,
 2,
 1,
 0,
 1,
 0,
 1,
 1,
 1,
 2,
 0,
 2,
 0,
 2,
 2,
 2,
 1,
 1,
 2,
 1,
 0,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 0,
 1,
 2,
 1,
 1,
 2,
 1,
 2,
 1,
 1,
 2,
 0,
 0,
 2,
 2,
 2,
 0,
 0,
 2,
 1,
 0,
 0,
 0,
 2,
 1,
 0,
 0,
