# Data Exploration and Preprocessing
TCGA Reannotated Ovarian Cancer Clinical Data 

In [300]:
import numpy as np
import pandas as pd
from functools import reduce
import itertools

Import Data

In [301]:
# Villalobos 2018 reannotated TCGA data (https://ascopubs.org/doi/suppl/10.1200/CCI.17.00096)
tcga_ov_1 = pd.read_excel('https://github.com/bmurphy1993/Cancer_Reinforcement_Learning/raw/main/Data/Villalobos_TCGA/ds_CCI.17.00096-1.xlsx',
                          sheet_name='Master clinical dataset'
                          )

tcga_ov_2 = pd.read_excel('https://github.com/bmurphy1993/Cancer_Reinforcement_Learning/raw/main/Data/Villalobos_TCGA/ds_CCI.17.00096-2.xlsx',
                          sheet_name='Months'
                          )

tcga_ov_3 = pd.read_excel('https://github.com/bmurphy1993/Cancer_Reinforcement_Learning/raw/main/Data/Villalobos_TCGA/ds_CCI.17.00096-3.xlsx',
                          sheet_name='clinical_drug_all_OV.txt'
                          )

# TCGA Drug standardization (https://gdisc.bme.gatech.edu/cgi-bin/gdisc/tap5.cgi#)
drugs = pd.read_csv('https://raw.githubusercontent.com/bmurphy1993/Cancer_Reinforcement_Learning/main/Data/DrugCorrection1.csv')
drugs['Correction'] = drugs['Correction'].str.strip()

## Clean and Reorganize

TCGA 3: Clean

In [302]:
# Drop columns with all missing values
print(tcga_ov_3.shape)
print(tcga_ov_3.isnull().sum(), '\n')
tcga_ov_3_drop = tcga_ov_3.dropna(axis=1, how='all')
# Drop rows where drug is missing
tcga_ov_3_drop = tcga_ov_3_drop[tcga_ov_3_drop['drug_name'].notna()]
# Drop rows where start AND end are missing (right now it's actuall OR, see below)
tcga_ov_3_drop = tcga_ov_3_drop.dropna(how='any', subset=['days_to_drug_therapy_end', 'days_to_drug_therapy_start']) #change to how='all' when figure out what to do with start OR end missing

# Standardize drug names
    # Merge using values from TCGA drug standardization (https://gdisc.bme.gatech.edu/cgi-bin/gdisc/tap5.cgi)
tcga_ov_3_clean = tcga_ov_3_drop.merge(drugs, how='left', left_on='drug_name', right_on='OldName').drop(['OldName'], axis=1)
    # Additional replacement rules
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'Doxoribicin', 'Correction'] = 'Doxorubicin'
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'gemcitabin', 'Correction'] = 'Gemcitabine'
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'Hexlalen', 'Correction'] = 'Altretamine'
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'Cisplatin/Gemzar', 'Correction'] = 'Cisplatin' # This applies to only one line, which has another sample just for Gemzar (Gemcitabine)
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'Ilex', 'Correction'] = 'Ilex'
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'ILIZ', 'Correction'] = 'ILIZ'
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'Lily', 'Correction'] = 'Lily'

    # Print replacement rules and replace 'drug_name'
drug_name_old = tcga_ov_3_clean['drug_name']
drug_name_new = tcga_ov_3_clean['Correction']
rules = pd.DataFrame({'drug_name_old': drug_name_old, 'drug_name_new': drug_name_new}).drop_duplicates().sort_values(by=['drug_name_old']).reset_index().drop('index', axis=1)

pd.set_option('display.max_rows', None)
print('Replacement Rules:\n', rules, '\n')
pd.reset_option('max_rows')

tcga_ov_3_clean['drug_name'] = tcga_ov_3_clean['Correction']
tcga_ov_3_clean = tcga_ov_3_clean.drop('Correction', axis=1)

    # List of drugs in dataset
drug_list = [x for x in list(tcga_ov_3_clean['drug_name'].drop_duplicates()) if str(x) != 'nan']
drug_list.sort()
print('Unique Drugs:', len(drug_list), '\n', drug_list, '\n')

# Drop where therapy start = therapy end
tcga_ov_3_clean = tcga_ov_3_clean[tcga_ov_3_clean['days_to_drug_therapy_end'] != tcga_ov_3_clean['days_to_drug_therapy_start']] 

(2463, 19)
bcr_patient_barcode                 0
bcr_drug_barcode                    0
days_to_drug_therapy_end          339
days_to_drug_therapy_start        145
days_to_drug_treatment_end       2463
days_to_drug_treatment_start     2463
dosage_units                     2463
drug_category                       0
drug_dosage                      2463
drug_name                          11
initial_course                   2463
number_cycles                     391
regimen_indication                  2
regimen_indication_notes         2328
route_of_administration           243
route_of_administration_notes    2126
therapy_ongoing                   167
total_dose                        746
total_dose_units                  715
dtype: int64 

Replacement Rules:
                         drug_name_old            drug_name_new
0                      5F4 Leucovorin  Fluorouracil+Leucovorin
1                         90Y-HU3S193                  Hu3S193
2                             AMG 706      

TCGA 3: Fix/standardize time variables and fix order of therapy lines

In [303]:
# Fix values where start and end are switched
tcga_ov_3_clean.loc[tcga_ov_3_clean['days_to_drug_therapy_start'] > tcga_ov_3_clean['days_to_drug_therapy_end'], ['days_to_drug_therapy_start', 'days_to_drug_therapy_end']] = tcga_ov_3_clean.loc[tcga_ov_3_clean['days_to_drug_therapy_start'] > tcga_ov_3_clean['days_to_drug_therapy_end'], ['days_to_drug_therapy_end', 'days_to_drug_therapy_start']].values

# Set earliest drug therapy start to zero and subtract everything else by min days
ther_start = tcga_ov_3_clean.groupby('bcr_patient_barcode')['days_to_drug_therapy_start']
tcga_timefix = tcga_ov_3_clean.assign(start_day=ther_start.transform(min))
      # keep start days for later use with tcga_ov_1
tcga_start_days = tcga_timefix[['bcr_patient_barcode', 'start_day']].drop_duplicates()

tcga_timefix['therapy_start'] = tcga_timefix['days_to_drug_therapy_start'] - tcga_timefix['start_day']
tcga_timefix['therapy_end'] = tcga_timefix['days_to_drug_therapy_end'] - tcga_timefix['start_day']
tcga_timefix = tcga_timefix.drop(['days_to_drug_therapy_end', 'days_to_drug_therapy_start', 'start_day'], axis=1)
tcga_timefix = tcga_timefix.sort_values(by=['bcr_patient_barcode', 'therapy_start', 'therapy_end'])

# Set up state list for each patient: barcode, timing, drug combo
tcga_drug_lines = []
for barcode in tcga_timefix['bcr_patient_barcode'].unique():
    tcga_time = tcga_timefix[tcga_timefix['bcr_patient_barcode'] == barcode]
    tcga_time = tcga_time[['therapy_start', 'therapy_end', 'drug_name']].drop_duplicates(keep='first').values.tolist() # Drop duplicate drugs that have different dosages or administration but same timing

    points = [] # list of (offset, plus/minus, drug) tuples
    for start,stop,drug in tcga_time:
        points.append((start,'+',drug))
        points.append((stop,'-',drug))
    points.sort()

    ranges = [] # output list of (start, stop, drug_set) tuples
    current_set = []
    last_start = None
    for offset,pm,drug in points:
        if pm == '+':
            if last_start is not None:
                ranges.append([last_start,offset,list(set(current_set.copy()))])
            current_set.append(drug)
            last_start = offset
        elif pm == '-':
            ranges.append([last_start,offset,list(set(current_set.copy()))])
            current_set.remove(drug)
            last_start = offset

    # Finish off
    if last_start is not None:
        ranges.append([last_start,offset,list(set(current_set.copy()))])

    # Remove the ranges where start = stop
    range_drug = []
    for i in range(len(ranges)):
        if ranges[i][0] != ranges[i][1]: # add condition:  <& (ranges[i][2] != [])> to drop no-drug periods
            range_drug.append(ranges[i])

    # Sort drugs in each drug combo
    for i in range(len(range_drug)):
        range_drug[i][2].sort()

    # Remove overlapping/back-to-back duplicate lines. Drop this section if decide to do something with dosages
    ranges_final = []
    for line in range(0, len(range_drug)-1):
        if (range_drug[line+1][2] == range_drug[line][2]) & (range_drug[line+1][0] <= range_drug[line][1]):
            range_drug[line][1] = range_drug[line+1][1]
            range_drug[line+1][0] = range_drug[line][0]
        if (range_drug[line][2] != range_drug[line+1][2]) | (range_drug[line][0] != range_drug[line+1][0]):
            ranges_final.append(range_drug[line])
    # if (range_drug[len(range_drug)-1][2] != range_drug[len(range_drug)-2][2]) | (range_drug[len(range_drug)-1][0] != range_drug[len(range_drug)-2][0]):
    ranges_final.append(range_drug[len(range_drug)-1])

    # Add the number of previous lines of therapy
    for line in range(len(ranges_final)):
        if line == 0:
            ranges_final[line].append(0) 
        elif ranges_final[line-1][2] == []:
            ranges_final[line].append(ranges_final[line-1][3])
        else:
            ranges_final[line].append(ranges_final[line-1][3] + 1)
    
    # Add treat transition var
    for line in range(len(ranges_final)):
        try:
            if ranges_final[line+1][2] != []:
                ranges_final[line].append(1)
            else:
                ranges_final[line].append(0)
        except IndexError:
            ranges_final[line].append(0)

    # Add patient barcodes
    for line in range(len(ranges_final)):
        ranges_final[line].insert(0, barcode)
    

    tcga_drug_lines.extend(ranges_final)

tcga_drug_lines[3] = ['TCGA-04-1332', 0.0, 151.0, ['Carboplatin', 'Paclitaxel', 'Topotecan'], 0, 0] # Special case to fix. Make sure to check this if make changes above

# Back to df
lines_df = pd.DataFrame(tcga_drug_lines, columns=['bcr_patient_barcode', 'start', 'end', 'therapy', 'previous_lines', 'treat'])

# List of patient barcodes
tcga_barcodes = list(lines_df['bcr_patient_barcode'].unique())

# Notes
    # One thing to be aware of is that this code drops all values where therapy start and therapy end are equal
    # Fix: drops all back to back duplicate treatment, should keep first (see TCGA-04-1365)
tcga_drug_lines

[['TCGA-04-1331', 0.0, 133.0, ['Carboplatin', 'Paclitaxel'], 0, 0],
 ['TCGA-04-1331', 133.0, 445.0, [], 1, 1],
 ['TCGA-04-1331', 445.0, 462.0, ['Dortezomib'], 1, 0],
 ['TCGA-04-1332',
  0.0,
  151.0,
  ['Carboplatin', 'Paclitaxel', 'Topotecan'],
  0,
  0],
 ['TCGA-04-1332', 151.0, 396.0, [], 1, 1],
 ['TCGA-04-1332', 396.0, 578.0, ['Carboplatin', 'Paclitaxel'], 1, 0],
 ['TCGA-04-1332', 578.0, 943.0, [], 2, 1],
 ['TCGA-04-1332', 943.0, 1035.0, ['Carboplatin', 'Docetaxel'], 2, 0],
 ['TCGA-04-1332', 1035.0, 1066.0, [], 3, 1],
 ['TCGA-04-1332', 1066.0, 1127.0, ['Cisplatin', 'Docetaxel'], 3, 1],
 ['TCGA-04-1332', 1127.0, 1188.0, ['Doxorubicin'], 4, 0],
 ['TCGA-04-1336', 0.0, 94.0, ['Carboplatin', 'Docetaxel'], 0, 0],
 ['TCGA-04-1336', 94.0, 129.0, [], 1, 1],
 ['TCGA-04-1336', 129.0, 492.0, ['Docetaxel'], 1, 0],
 ['TCGA-04-1338', 0.0, 212.0, ['Carboplatin', 'Paclitaxel'], 0, 0],
 ['TCGA-04-1338', 212.0, 426.0, [], 1, 1],
 ['TCGA-04-1338', 426.0, 510.0, ['Docetaxel'], 1, 0],
 ['TCGA-04-1338', 

TCGA 1: Clean

In [304]:
# Keep subset of variables 
print('NaNs: ', tcga_ov_1.isnull().sum())
tcga_ov_1_keep = tcga_ov_1[['bcr_patient_barcode',
                            'total_days_overall_survival',
                            'outcome_overall_survival_censoring',
                            # 'vital_status',
                            # 'days_to_tumor_progression',
                            # 'days_to_death',
                            # 'days_to_last_followup',
                            # 'days_to_tumor_recurrence',
                            # 'time_to_failure',  
                            # 'Cycles_of_adjuvant_therapy',
                            # 'Adjuvant_chemotherapy_dose_intensity',
                            'age_at_initial_pathologic_diagnosis',
                            # 'anatomic_organ_subdivision',
                            # 'days_to_birth',
                            # 'initial_pathologic_diagnosis_method',
                            # 'person_neoplasm_cancer_status',
                            # 'pretreatment_history',
                            # 'primary_therapy_outcome_success', # ask EKO about this
                            'race',
                            # 'residual_tumor',
                            # 'site_of_tumor_first_recurrence',
                            # 'tissue_source_site',
                            'tumor_grade',
                            # 'tumor_residual_disease',
                            'tumor_stage',
                            # 'tumor_tissue_site'
                            # 'year_of_initial_pathologic_diagnosis',
                            # 'Days off platinum prior to recurrence 1st line',
                            # 'Last day of platinum 1st line',
                            # 'Chemotherapy number of lines of therapy'
                            ]]

# Drop cases that don't have a survival metric
tcga_ov_1_keep = tcga_ov_1_keep.dropna(subset=['total_days_overall_survival'])
tcga_ov_1_keep.drop(tcga_ov_1_keep[tcga_ov_1_keep['total_days_overall_survival'] == 'cannot assess'].index, inplace=True)

# Only keep samples that are in the cleaned 'lines' data
tcga_ov_1_keep = tcga_ov_1_keep[tcga_ov_1_keep['bcr_patient_barcode'].isin(tcga_barcodes)].sort_values(by=['bcr_patient_barcode']).reset_index().drop('index', axis=1)

# Adjust final survival by start of therapy day
tcga_start_days = tcga_start_days[tcga_start_days['bcr_patient_barcode'].isin(tcga_barcodes)].sort_values(by=['bcr_patient_barcode']).reset_index().drop('index', axis=1)
tcga_ov_1_keep['total_days_overall_survival'] = tcga_ov_1_keep['total_days_overall_survival'] - tcga_start_days['start_day'] # this needs work, getting some negative values

tcga_ov_1_keep

NaNs:  bcr_patient_barcode                                 4
total_days_overall_survival                        28
outcome_overall_survival_censoring                 16
vital_status                                       34
days_to_tumor_progression                         570
                                                 ... 
5th_chemo_regimen_days_outcome                    524
6th_chemo_regimen_days_outcome                    552
Days off platinum prior to recurrence 1st line     99
Last day of platinum 1st line                     108
Chemotherapy number of lines of therapy            27
Length: 72, dtype: int64


Unnamed: 0,bcr_patient_barcode,total_days_overall_survival,outcome_overall_survival_censoring,age_at_initial_pathologic_diagnosis,race,tumor_grade,tumor_stage
0,TCGA-04-1331,1300,1,79.0,WHITE,G3,IIIC
1,TCGA-04-1332,1217,1,70.0,WHITE,G3,IIIC
2,TCGA-04-1336,1445,0,55.0,WHITE,G3,IIIB
3,TCGA-04-1338,1418,0,78.0,WHITE,G3,IIIC
4,TCGA-04-1342,531,1,80.0,WHITE,G2,IV
...,...,...,...,...,...,...,...
455,TCGA-61-2113,627,1,54.0,WHITE,G3,IIC
456,TCGA-61-2610,1550,1,61.0,WHITE,G3,IIIC
457,TCGA-61-2611,417,1,40.0,WHITE,G3,IIIC
458,TCGA-61-2612,173,1,63.0,WHITE,G3,IIIC


Add death and no treatment dummies to drug lines data

In [305]:
# Merge in final death event for each patient
lines_df_2 = lines_df.merge(tcga_ov_1_keep[['bcr_patient_barcode', 'outcome_overall_survival_censoring']], on='bcr_patient_barcode').rename(columns={'bcr_patient_barcode': 'patient', 'outcome_overall_survival_censoring': 'death'})

# Only keep death event on last line
lines_df_2.loc[lines_df_2['patient'] == lines_df_2['patient'].shift(-1), 'death'] = 0

# Dummy for transition no treatment
lines_df_2['no_treat'] = 1 - lines_df_2['treat'] - lines_df_2['death']

# Function to get indices of a therapy
def get_index_pos(my_list, val):
    return [i for i, x in enumerate(my_list) if x == val]

## Create MDP objects

* This should probably be the start of the next notebook

State and Action set

In [306]:
# States
state_set = ['N', 'T', 'D']

# Combos resulting from actions
combos = []
for i in range(len(tcga_drug_lines)):
    combos.append(tcga_drug_lines[i][3])

combos.sort()
combos = list(combos for combos,_ in itertools.groupby(combos))
print(len(combos))
combos

128


[[],
 ['Abagovomab'],
 ['Aldesleukin'],
 ['Altretamine'],
 ['Amifostine', 'Carboplatin', 'Cisplatin', 'Paclitaxel'],
 ['Amifostine', 'Carboplatin', 'Paclitaxel'],
 ['Amifostine', 'Cisplatin', 'Paclitaxel'],
 ['Amifostine', 'Paclitaxel'],
 ['Aminocamptothecin'],
 ['Anastrozole'],
 ['Anastrozole', 'Doxorubicin'],
 ['Anastrozole', 'Tamoxifen'],
 ['Bevacizumab'],
 ['Bevacizumab', 'Carboplatin'],
 ['Bevacizumab', 'Carboplatin', 'Docetaxel'],
 ['Bevacizumab', 'Carboplatin', 'Gemcitabine'],
 ['Bevacizumab', 'Carboplatin', 'Paclitaxel'],
 ['Bevacizumab', 'Cisplatin', 'Paclitaxel'],
 ['Bevacizumab', 'Cyclophosphamide'],
 ['Bevacizumab', 'Docetaxel'],
 ['Bevacizumab', 'Docetaxel', 'Oxaliplatin'],
 ['Bevacizumab', 'Doxorubicin'],
 ['Bevacizumab', 'Fluorouracil+Leucovorin'],
 ['Bevacizumab', 'Gemcitabine'],
 ['Bevacizumab', 'Paclitaxel'],
 ['Bevacizumab', 'Topotecan'],
 ['Bevacizumab', 'Topotecan', 'Vinorelbine'],
 ['Bevacizumab', 'Vinorelbine'],
 ['CBP501'],
 ['CEP -11981'],
 ['Capecitabine'],
 [

Transition matrix

In [307]:
# this cell is just for looking at stuff
lines_df_2

Unnamed: 0,patient,start,end,therapy,previous_lines,treat,death,no_treat
0,TCGA-04-1331,0.0,133.0,"[Carboplatin, Paclitaxel]",0,0,0,1
1,TCGA-04-1331,133.0,445.0,[],1,1,0,0
2,TCGA-04-1331,445.0,462.0,[Dortezomib],1,0,1,0
3,TCGA-04-1332,0.0,151.0,"[Carboplatin, Paclitaxel, Topotecan]",0,0,0,1
4,TCGA-04-1332,151.0,396.0,[],1,1,0,0
...,...,...,...,...,...,...,...,...
1912,TCGA-61-2610,500.0,561.0,[],2,1,0,0
1913,TCGA-61-2610,561.0,1101.0,[Doxorubicin],2,0,1,0
1914,TCGA-61-2611,0.0,238.0,"[Carboplatin, Paclitaxel]",0,0,1,0
1915,TCGA-61-2612,0.0,131.0,"[Carboplatin, Paclitaxel]",0,0,1,0


In [308]:
# States = {No Treatment, Treatment, Death}
nactions = len(combos) # action set is list of unique drug combos
nstates = len(state_set) # three state probabilities

# Initialize transition matrix
transitions = np.zeros((nactions,nstates-1,nstates))
print('T_a shape:\n',
      '   N  T  D\n',
      'N', transitions[0][0],'\n',
      'T', transitions[0][1])

# Calculate probabilities
for i in range(nactions):
    inds = get_index_pos(list(lines_df_2['therapy']), combos[i])
    
    n_prob = lines_df_2[['no_treat']].iloc[inds, :].sum().values[0] / len(inds)
    t_prob = lines_df_2[['treat']].iloc[inds, :].sum().values[0] / len(inds)
    d_prob = lines_df_2[['death']].iloc[inds, :].sum().values[0] / len(inds)

    if combos[i] == []:
        # N --> 
        transitions[i][0][0] = n_prob
        transitions[i][0][1] = t_prob
        transitions[i][0][2] = d_prob
        # T -->
        transitions[i][1][0] = 0
        transitions[i][1][1] = 0
        transitions[i][1][2] = 0
    else:
        # N --> 
        transitions[i][0][0] = 0
        transitions[i][0][1] = 0
        transitions[i][0][2] = 0
        # T -->
        transitions[i][1][0] = n_prob
        transitions[i][1][1] = t_prob
        transitions[i][1][2] = d_prob

# Get tranistion probabilities function
def get_probs(act_choice, curr_state, trans_set=transitions, act_set=combos, states_set=state_set):
    return trans_set[act_set.index(act_choice)][states_set.index(curr_state)]

transitions[0:10]

T_a shape:
    N  T  D
 N [0. 0. 0.] 
 T [0. 0. 0.]


array([[[0.        , 1.        , 0.        ],
        [0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.8       , 0.        , 0.2       ]],

       [[0.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.33333333, 0.66666667, 0.        ]]])

In [309]:
# States = {Death, Survive}
nactions = len(combos) # action set is list of unique drug combos
nstates = 2 # two state probabilities

# Initialize transition matrix
transitions_2 = np.zeros((nactions,nstates))

# Calculate probabilities
for i in range(nactions):
    inds = get_index_pos(list(lines_df_2['therapy']), combos[i])

    d_prob = lines_df_2[['death']].iloc[inds, :].sum().values[0] / len(inds)

    transitions_2[i][0] = d_prob
    transitions_2[i][1] = 1 - d_prob

Reward matrix

In [310]:
# get to this later

## Environment

Designed based on: https://github.com/MJeremy2017/reinforcement-learning-implementation/blob/master/GridWorld/gridWorld.py

In [311]:
import random

State Class

In [346]:
# Globals
START_STATE = 'T'
START_COMBO = []
DETERMINISTIC = True
LEARN_RATE = 0.2
EXP_RATE = 0.3

class State:
    def __init__(self, state=START, combo=START_COMBO):
        self.state = state
        self.combo = combo
        self.isEnd = False
        self.determine = DETERMINISTIC

    # Give reward: for now, simple rewards
    def giveReward(self):
        if self.state == 'N':
            return 1 
        elif self.state == 'T':
            return 0.5
        elif self.state == 'D':
            return -1 
    
    def isEndFunc(self):
        if self.state == 'D':
            self.isEnd = True

    # Update step
    def nxtPosition(self, action):
        act = action[0]
        drug = action[1]
        if self.determine:
            # Update drug combo
            if act == '+': 
                self.combo.append(drug)
            elif act == '-':
                self.combo.remove(drug)
            else:
                pass
            self.combo.sort()
            
            # Choose probabilities from transition matrix
            probs = get_probs(self.combo, self.state)
            # Return next state
            nxtState = random.choices(state_set, weights=probs, k=1)[0]
            return nxtState

Agent Class

In [357]:
class Agent:
    def __init__(self):
        self.State = State()        
        self.s_a = []
        self.actions = [['+','Carboplatin'], ['+', 'Paclitaxel'], ['','']]

        self.lr = LEARN_RATE
        self.exp_rate = EXP_RATE

        # initial state and action reward
        self.state_values = {'N':1, 'T':0.5, 'D':-1}       # not sure if need these, need action values instead. Can it be like: Combo: Value - xt?
        self.statevals_list = [1, 0.5, -1]      # for [N, T, D]
        self.s_a_values = {} 
        for i in state_set:
            for j in combos:
                self.s_a_values[repr([i, j])] = 0
    
    # Choose Action: Try with just Carboplatin and Paclitaxel
    def chooseAction(self):
        # choose action with most expected value
        mx_exp_reward = 0
        if self.State.combo == []: 
            self.actions = [['+','Carboplatin'], ['+', 'Paclitaxel'], ['','']]
        elif self.State.combo == ['Carboplatin', 'Paclitaxel']: 
            self.actions = [['-','Carboplatin'], ['-', 'Paclitaxel'], ['','']]
        elif self.State.combo == ['Carboplatin']:
            self.actions = [['+', 'Paclitaxel'], ['-','Carboplatin'], ['','']]
        elif self.State.combo == ['Paclitaxel']:
            self.actions = [['+','Carboplatin'], ['-', 'Paclitaxel'], ['','']]

        if np.random.uniform(0, 1) <= self.exp_rate:
            action = self.actions[np.random.choice(len(self.actions))]
        else:
            # greedy action
            for a in self.actions:
                # if the action is deterministic
                ####################################### Don't need this if use combos as actions
                test_combo = self.State.combo.copy()
                act = a[0]
                drug = a[1]
                if act == '+': 
                    test_combo.append(drug)
                elif act == '-':
                    test_combo.remove(drug)
                else:
                    pass
                test_combo.sort()
                #######################################
                probs = get_probs(test_combo, self.State.state)

                exp_reward = sum(x * y for x, y in zip(probs, self.statevals_list))
                if exp_reward >= mx_exp_reward:
                    action = a
                    mx_exp_reward = exp_reward
        return action

    def takeAction(self, action):
        position = self.State.nxtPosition(action)
        return State(state=position, combo=self.State.combo)

    def reset(self):
        self.s_a = []
        self.State = State()

    def play(self, rounds=10):
        i = 0
        while i < rounds:
            # to the end of game back propagate reward
            if self.State.isEnd:
                # back propagate
                reward = self.State.giveReward()
                # explicitly assign end state to reward values
                self.s_a_values[repr([self.State.state, self.State.combo])] = reward  # this is optional
                for s in reversed(self.s_a):
                    reward = self.s_a_values[repr([s[0], s[1]])] + self.lr * (reward - self.s_a_values[repr([s[0], s[1]])]) # I think i need to update this formula because using (s,a) now
                    self.s_a_values[repr([s[0], s[1]])] = reward
                # Calculate total reward
                total_reward = 0
                for s in self.s_a:
                    total_reward += self.state_values[s[0]]
                print(self.s_a)
                print('Total reward: ', total_reward)

                self.reset()
                i += 1
            else:
                action = self.chooseAction()
                # append trace
                self.s_a.append([self.State.state, self.State.combo])
                # print('current state {} action {}'.format(self.State.state, self.State.combo))
                # by taking the action, it reaches the next state
                self.State = self.takeAction(action)
                # mark is end
                self.State.isEndFunc()
                # print('nxt state', self.State.state)
                # print('---------------------')

Treat patients

In [358]:
if __name__ == '__main__':
    ag = Agent()
    ag.play()
    print(ag.s_a_values)

[['T', ['Carboplatin', 'Paclitaxel']], ['N', ['Carboplatin', 'Paclitaxel']]]
Total reward:  1.5
[['T', ['Carboplatin', 'Paclitaxel']], ['N', ['Carboplatin', 'Paclitaxel']], ['T', ['Carboplatin', 'Paclitaxel']], ['N', ['Carboplatin', 'Paclitaxel']], ['T', ['Carboplatin', 'Paclitaxel']], ['N', ['Carboplatin', 'Paclitaxel']], ['T', ['Carboplatin', 'Paclitaxel']], ['N', ['Carboplatin', 'Paclitaxel']], ['T', ['Carboplatin', 'Paclitaxel']], ['N', ['Carboplatin', 'Paclitaxel']]]
Total reward:  7.5
[['T', []], ['N', []], ['T', []], ['N', []], ['T', []], ['T', []], ['N', []], ['T', []]]
Total reward:  5.5
[['T', ['Carboplatin']], ['N', ['Carboplatin']]]
Total reward:  1.5
[['T', ['Carboplatin']], ['T', ['Carboplatin']], ['N', ['Carboplatin']], ['T', ['Carboplatin']], ['N', ['Carboplatin']], ['T', ['Carboplatin']], ['N', ['Carboplatin']], ['T', ['Carboplatin']], ['N', ['Carboplatin']]]
Total reward:  6.5
[['T', ['Carboplatin']]]
Total reward:  0.5
[['T', ['Carboplatin', 'Paclitaxel']], ['N', ['C

## Scratch/notes

Notes
* Add in "remission" as a state (i.e. no treatment)
  - Calculate similar to probability of death: probability that a no-treatment period follows a treatment
  - Then need to calculate probability following "remission" for each:
    - Death
    - remission, i.e. no treatment
    - More treatment i.e. "progression/recurrence" 
  - No-treatment ("[]") will not be in the action set
    - Unless already in the no-treatment state
    - Importantly, this means that the agent cannot decide to take a patient off of treatement. Not being on chemo is not a choice made by the AI "doctor", but rather, is a result of the previous treatment. Agent can, however, choose not to add drugs if already achieved remission (i.e. in the no-treatment state) 

* "Grid world" verion of action space
  - 56 dimensional object with up to 5 drugs activated at once
  - Action set is [Start drug {n}, Stop drug {n}, Do nothing]
    - Cannot [Stop drug {n}] if resulting state is no treatment
    - Cannot [Start drug {n}] if resulting len(combo) > 5
    - Cannot start a drug that is already in the combo
  - \> 4mm possible combinations, only 127 drug combos in data - is this necessary? 

* Rewards
  - I think in the final version this should be a probability distribution returning a number of days of survival/until treatment failure after each action
  - For starters, just try to get the average days survived when therapy doesn't result in death.
  - Also need to think of something for when patients don't die
    - Maybe just probability of death and a final non-death stop state?
    - Death vs. no death in the end doesn't matter unless the no deaths get a final reward

* Bellman action value equation: Q(s,a) = r + γ(max(Q(s’,a’))
  - "This says that the Q-value for a given state (s) and action (a) should represent the current reward (r) plus the maximum discounted (γ) future reward expected according to our own table for the next state (s’) we would end up in"

* Replay memory?

Intermittant reward matrix for each patient

In [None]:
# fix because including no drug periods and using df

# tcga_rewards = {}
# for barcode in tcga_lines_keys:
#     rewards = []
#     for line in range(len(tcga_drug_lines[barcode])):
#         try:
#             rewards.append(tcga_drug_lines[barcode][line+1][0] - tcga_drug_lines[barcode][line][0])
#         except:
#             rewards.append(tcga_ov_1_keep.loc[tcga_ov_1_keep['bcr_patient_barcode'] == barcode, 'total_days_overall_survival'].iloc[0] - tcga_drug_lines[barcode][line][0])

#     tcga_rewards[barcode] = rewards

# tcga_rewards

In [None]:
# negs = []
# for barcode in tcga_rewards:
#     if tcga_rewards[barcode][-1] < 0:
#         negs.append(barcode)

# negs

In [None]:
# Add in final state for each patient: [time, 'death']
# Figure out what to do with patients where overall survival is < the end of the last therapy line

In [None]:
# define transition matrix
# def transition(state, action): 
#     if state,action = (living, drug A): 
#         then return 1 if np.random() < 0.8 …. 
        
#     if state, action = (living, drug B), return 1 
    
#     if ..
