# Data Exploration and Preprocessing
TCGA Reannotated Ovarian Cancer Clinical Data 

In [181]:
import numpy as np
import pandas as pd
from functools import reduce
import itertools
import math
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import namedtuple
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.preprocessing import MinMaxScaler

device = torch.device('cpu')

Import Data

In [182]:
# Villalobos 2018 reannotated TCGA data (https://ascopubs.org/doi/suppl/10.1200/CCI.17.00096)
tcga_ov_1 = pd.read_excel('./Data/Villalobos_TCGA/ds_CCI.17.00096-1.xlsx',
                          sheet_name='Master clinical dataset'
                          )

tcga_ov_2 = pd.read_excel('./Data/Villalobos_TCGA/ds_CCI.17.00096-2.xlsx',
                          sheet_name='Months'
                          )

tcga_ov_3 = pd.read_excel('./Data/Villalobos_TCGA/ds_CCI.17.00096-3.xlsx',
                          sheet_name='clinical_drug_all_OV.txt'
                          )

# TCGA Drug standardization (https://gdisc.bme.gatech.edu/cgi-bin/gdisc/tap5.cgi#)
drugs = pd.read_csv('./Data/DrugCorrection1.csv')
drugs['Correction'] = drugs['Correction'].str.strip()

## Clean and Reorganize

TCGA 3: Clean

In [183]:
# Drop columns with all missing values
print(tcga_ov_3.shape)
print(tcga_ov_3.isnull().sum(), '\n')
tcga_ov_3_drop = tcga_ov_3.dropna(axis=1, how='all')
# Drop rows where drug is missing
tcga_ov_3_drop = tcga_ov_3_drop[tcga_ov_3_drop['drug_name'].notna()]
# Drop rows where start AND end are missing (right now it's actuall OR, see below)
tcga_ov_3_drop = tcga_ov_3_drop.dropna(how='any', subset=['days_to_drug_therapy_end', 'days_to_drug_therapy_start']) #change to how='all' when figure out what to do with start OR end missing

# Standardize drug names
    # Merge using values from TCGA drug standardization (https://gdisc.bme.gatech.edu/cgi-bin/gdisc/tap5.cgi)
tcga_ov_3_clean = tcga_ov_3_drop.merge(drugs, how='left', left_on='drug_name', right_on='OldName').drop(['OldName'], axis=1)
    # Additional replacement rules
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'Doxoribicin', 'Correction'] = 'Doxorubicin'
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'gemcitabin', 'Correction'] = 'Gemcitabine'
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'Hexlalen', 'Correction'] = 'Altretamine'
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'Cisplatin/Gemzar', 'Correction'] = 'Cisplatin' # This applies to only one line, which has another sample just for Gemzar (Gemcitabine)
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'Ilex', 'Correction'] = 'Ilex'
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'ILIZ', 'Correction'] = 'ILIZ'
tcga_ov_3_clean.loc[tcga_ov_3_clean['drug_name'] == 'Lily', 'Correction'] = 'Lily'

    # Print replacement rules and replace 'drug_name'
drug_name_old = tcga_ov_3_clean['drug_name']
drug_name_new = tcga_ov_3_clean['Correction']
rules = pd.DataFrame({'drug_name_old': drug_name_old, 'drug_name_new': drug_name_new}).drop_duplicates().sort_values(by=['drug_name_old']).reset_index().drop('index', axis=1)

pd.set_option('display.max_rows', None)
print('Replacement Rules:\n', rules.head(), '\n')
pd.reset_option('max_rows')

tcga_ov_3_clean['drug_name'] = tcga_ov_3_clean['Correction']
tcga_ov_3_clean = tcga_ov_3_clean.drop('Correction', axis=1)

    # List of drugs in dataset
drug_list = [x for x in list(tcga_ov_3_clean['drug_name'].drop_duplicates()) if str(x) != 'nan']
drug_list.sort()
print('Unique Drugs:', len(drug_list), '\n', drug_list, '\n')

# Drop where therapy start = therapy end
tcga_ov_3_clean = tcga_ov_3_clean[tcga_ov_3_clean['days_to_drug_therapy_end'] != tcga_ov_3_clean['days_to_drug_therapy_start']] 

(2463, 19)
bcr_patient_barcode                 0
bcr_drug_barcode                    0
days_to_drug_therapy_end          339
days_to_drug_therapy_start        145
days_to_drug_treatment_end       2463
days_to_drug_treatment_start     2463
dosage_units                     2463
drug_category                       0
drug_dosage                      2463
drug_name                          11
initial_course                   2463
number_cycles                     391
regimen_indication                  2
regimen_indication_notes         2328
route_of_administration           243
route_of_administration_notes    2126
therapy_ongoing                   167
total_dose                        746
total_dose_units                  715
dtype: int64 

Replacement Rules:
            drug_name_old            drug_name_new
0         5F4 Leucovorin  Fluorouracil+Leucovorin
1            90Y-HU3S193                  Hu3S193
2                AMG 706                Motesanib
3  Abagovomab or Placebo        

TCGA 3: Fix/standardize time variables and fix order of therapy lines

In [184]:
# Fix values where start and end are switched
tcga_ov_3_clean.loc[tcga_ov_3_clean['days_to_drug_therapy_start'] > tcga_ov_3_clean['days_to_drug_therapy_end'], ['days_to_drug_therapy_start', 'days_to_drug_therapy_end']] = tcga_ov_3_clean.loc[tcga_ov_3_clean['days_to_drug_therapy_start'] > tcga_ov_3_clean['days_to_drug_therapy_end'], ['days_to_drug_therapy_end', 'days_to_drug_therapy_start']].values

# Set earliest drug therapy start to zero and subtract everything else by min days
ther_start = tcga_ov_3_clean.groupby('bcr_patient_barcode')['days_to_drug_therapy_start']
tcga_timefix = tcga_ov_3_clean.assign(start_day=ther_start.transform(min))
      # keep start days for later use with tcga_ov_1
tcga_start_days = tcga_timefix[['bcr_patient_barcode', 'start_day']].drop_duplicates()

tcga_timefix['therapy_start'] = tcga_timefix['days_to_drug_therapy_start'] - tcga_timefix['start_day']
tcga_timefix['therapy_end'] = tcga_timefix['days_to_drug_therapy_end'] - tcga_timefix['start_day']
tcga_timefix = tcga_timefix.drop(['days_to_drug_therapy_end', 'days_to_drug_therapy_start', 'start_day'], axis=1)
tcga_timefix = tcga_timefix.sort_values(by=['bcr_patient_barcode', 'therapy_start', 'therapy_end'])

# Set up state list for each patient: barcode, timing, drug combo
tcga_drug_lines = []
for barcode in tcga_timefix['bcr_patient_barcode'].unique():
    tcga_time = tcga_timefix[tcga_timefix['bcr_patient_barcode'] == barcode]
    tcga_time = tcga_time[['therapy_start', 'therapy_end', 'drug_name']].drop_duplicates(keep='first').values.tolist() # Drop duplicate drugs that have different dosages or administration but same timing

    points = [] # list of (offset, plus/minus, drug) tuples
    for start,stop,drug in tcga_time:
        points.append((start,'+',drug))
        points.append((stop,'-',drug))
    points.sort()

    ranges = [] # output list of (start, stop, drug_set) tuples
    current_set = []
    last_start = None
    for offset,pm,drug in points:
        if pm == '+':
            if last_start is not None:
                ranges.append([last_start,offset,list(set(current_set.copy()))])
            current_set.append(drug)
            last_start = offset
        elif pm == '-':
            ranges.append([last_start,offset,list(set(current_set.copy()))])
            current_set.remove(drug)
            last_start = offset

    # Finish off
    if last_start is not None:
        ranges.append([last_start,offset,list(set(current_set.copy()))])

    # Remove the ranges where start = stop
    range_drug = []
    for i in range(len(ranges)):
        if ranges[i][0] != ranges[i][1]: # add condition:  <& (ranges[i][2] != [])> to drop no-drug periods
            range_drug.append(ranges[i])

    # Sort drugs in each drug combo
    for i in range(len(range_drug)):
        range_drug[i][2].sort()

    # Remove overlapping/back-to-back duplicate lines. Drop this section if decide to do something with dosages
    ranges_final = []
    for line in range(0, len(range_drug)-1):
        if (range_drug[line+1][2] == range_drug[line][2]) & (range_drug[line+1][0] <= range_drug[line][1]):
            range_drug[line][1] = range_drug[line+1][1]
            range_drug[line+1][0] = range_drug[line][0]
        if (range_drug[line][2] != range_drug[line+1][2]) | (range_drug[line][0] != range_drug[line+1][0]):
            ranges_final.append(range_drug[line])
    ranges_final.append(range_drug[len(range_drug)-1])

    # Add the number of previous lines of therapy
    for line in range(len(ranges_final)):
        if line == 0:
            ranges_final[line].append(0) 
        elif ranges_final[line-1][2] == []:
            ranges_final[line].append(ranges_final[line-1][3])
        else:
            ranges_final[line].append(ranges_final[line-1][3] + 1)
    
    # Add treat transition var
    for line in range(len(ranges_final)):
        try:
            if (ranges_final[line][2] != []) & (ranges_final[line+1][2] != []):
                ranges_final[line].append(math.ceil((ranges_final[line][1] - ranges_final[line][0]) / 30))
            elif (ranges_final[line][2] != []) & (ranges_final[line+1][2] == []):
                ranges_final[line].append(math.ceil((ranges_final[line][1] - ranges_final[line][0]) / 30) - 1)
            elif (ranges_final[line][2] == []) & (ranges_final[line+1][2] != []):
                ranges_final[line].append(1)
            elif (ranges_final[line][2] == []) & (ranges_final[line+1][2] == []):
                ranges_final[line].append(0)
        except IndexError:
            ranges_final[line].append(0)
    
    # Add no treat transition var
    for line in range(len(ranges_final)):
        try:
            if (ranges_final[line][2] != []) & (ranges_final[line+1][2] != []):
                ranges_final[line].append(0)                
            elif (ranges_final[line][2] != []) & (ranges_final[line+1][2] == []):
                ranges_final[line].append(1)
            elif (ranges_final[line][2] == []) & (ranges_final[line+1][2] != []):
                ranges_final[line].append(math.ceil((ranges_final[line][1] - ranges_final[line][0]) / 30) - 1)
            elif (ranges_final[line][2] == []) & (ranges_final[line+1][2] == []):
                ranges_final[line].append(math.ceil((ranges_final[line][1] - ranges_final[line][0]) / 30))
        except IndexError:
            ranges_final[line].append(math.ceil((ranges_final[line][1] - ranges_final[line][0]) / 30)-1)

    # Add patient barcodes
    for line in range(len(ranges_final)):
        ranges_final[line].insert(0, barcode)
    

    tcga_drug_lines.extend(ranges_final)

tcga_drug_lines[3] = ['TCGA-04-1332', 0.0, 151.0, ['Carboplatin', 'Paclitaxel', 'Topotecan'], 0, 5, 1] # Special case to fix. Make sure to check this if make changes above

# Back to df
lines_df = pd.DataFrame(tcga_drug_lines, columns=['bcr_patient_barcode', 'start', 'end', 'therapy', 'previous_lines', 'treat', 'no_treat'])

# List of patient barcodes
tcga_barcodes = list(lines_df['bcr_patient_barcode'].unique())

# Notes
    # One thing to be aware of is that this code drops all values where therapy start and therapy end are equal
    # Fix: drops all back to back duplicate treatment, should keep first (see TCGA-04-1365)

TCGA 1: Clean

In [185]:
# Keep subset of variables 
print('NaNs: ', tcga_ov_1.isnull().sum())
tcga_ov_1_keep = tcga_ov_1[['bcr_patient_barcode',
                            'total_days_overall_survival',
                            'outcome_overall_survival_censoring',
                            # 'vital_status',
                            # 'days_to_tumor_progression',
                            # 'days_to_death',
                            # 'days_to_last_followup',
                            # 'days_to_tumor_recurrence',
                            # 'time_to_failure',  
                            # 'Cycles_of_adjuvant_therapy',
                            # 'Adjuvant_chemotherapy_dose_intensity',
                            'age_at_initial_pathologic_diagnosis',
                            # 'anatomic_organ_subdivision',
                            # 'days_to_birth',
                            # 'initial_pathologic_diagnosis_method',
                            # 'person_neoplasm_cancer_status',
                            # 'pretreatment_history',
                            # 'primary_therapy_outcome_success', # ask EKO about this
                            'race',
                            # 'residual_tumor',
                            # 'site_of_tumor_first_recurrence',
                            # 'tissue_source_site',
                            'tumor_grade',
                            # 'tumor_residual_disease',
                            'tumor_stage',
                            # 'tumor_tissue_site'
                            # 'year_of_initial_pathologic_diagnosis',
                            # 'Days off platinum prior to recurrence 1st line',
                            # 'Last day of platinum 1st line',
                            # 'Chemotherapy number of lines of therapy'
                            ]]

# Drop cases that don't have a survival metric
tcga_ov_1_keep = tcga_ov_1_keep.dropna(subset=['total_days_overall_survival'])
tcga_ov_1_keep.drop(tcga_ov_1_keep[tcga_ov_1_keep['total_days_overall_survival'] == 'cannot assess'].index, inplace=True)

# Only keep samples that are in the cleaned 'lines' data
tcga_ov_1_keep = tcga_ov_1_keep[tcga_ov_1_keep['bcr_patient_barcode'].isin(tcga_barcodes)].sort_values(by=['bcr_patient_barcode']).reset_index().drop('index', axis=1)

# Adjust final survival by start of therapy day
tcga_start_days = tcga_start_days[tcga_start_days['bcr_patient_barcode'].isin(tcga_barcodes)].sort_values(by=['bcr_patient_barcode']).reset_index().drop('index', axis=1)
tcga_ov_1_keep['total_days_overall_survival'] = tcga_ov_1_keep['total_days_overall_survival'] - tcga_start_days['start_day'] # this needs work, getting some negative values

tcga_ov_1_keep

NaNs:  bcr_patient_barcode                                 4
total_days_overall_survival                        28
outcome_overall_survival_censoring                 16
vital_status                                       34
days_to_tumor_progression                         570
                                                 ... 
5th_chemo_regimen_days_outcome                    524
6th_chemo_regimen_days_outcome                    552
Days off platinum prior to recurrence 1st line     99
Last day of platinum 1st line                     108
Chemotherapy number of lines of therapy            27
Length: 72, dtype: int64


Unnamed: 0,bcr_patient_barcode,total_days_overall_survival,outcome_overall_survival_censoring,age_at_initial_pathologic_diagnosis,race,tumor_grade,tumor_stage
0,TCGA-04-1331,1300.0,1,79.0,WHITE,G3,IIIC
1,TCGA-04-1332,1217.0,1,70.0,WHITE,G3,IIIC
2,TCGA-04-1336,1445.0,0,55.0,WHITE,G3,IIIB
3,TCGA-04-1338,1418.0,0,78.0,WHITE,G3,IIIC
4,TCGA-04-1342,531.0,1,80.0,WHITE,G2,IV
...,...,...,...,...,...,...,...
455,TCGA-61-2113,627.0,1,54.0,WHITE,G3,IIC
456,TCGA-61-2610,1550.0,1,61.0,WHITE,G3,IIIC
457,TCGA-61-2611,417.0,1,40.0,WHITE,G3,IIIC
458,TCGA-61-2612,173.0,1,63.0,WHITE,G3,IIIC


Add death dummy to drug lines data

In [186]:
# Merge in final death event for each patient
lines_df_2 = lines_df.merge(tcga_ov_1_keep[['bcr_patient_barcode', 'outcome_overall_survival_censoring']], on='bcr_patient_barcode').rename(columns={'bcr_patient_barcode': 'patient', 'outcome_overall_survival_censoring': 'death'})

# Only keep death event on last line
lines_df_2.loc[lines_df_2['patient'] == lines_df_2['patient'].shift(-1), 'death'] = 0

# Total events
lines_df_2['total_ev'] = lines_df_2['treat'] + lines_df_2['no_treat'] + lines_df_2['death']

# Function to get indices of a therapy
def get_index_pos(my_list, val):
    return [i for i, x in enumerate(my_list) if x == val]

In [187]:
# Create version of final dataset that only includes patients who died
d_pats = list(tcga_ov_1_keep['bcr_patient_barcode'].loc[tcga_ov_1_keep['outcome_overall_survival_censoring'] == 1])

lines_df_d = lines_df_2[lines_df_2['patient'].isin(d_pats)].reset_index().drop(['index'], axis=1)
lines_df_d

Unnamed: 0,patient,start,end,therapy,previous_lines,treat,no_treat,death,total_ev
0,TCGA-04-1331,0.0,133.0,"[Carboplatin, Paclitaxel]",0,4,1,0,5
1,TCGA-04-1331,133.0,445.0,[],1,1,10,0,11
2,TCGA-04-1331,445.0,462.0,[Dortezomib],1,0,0,1,1
3,TCGA-04-1332,0.0,151.0,"[Carboplatin, Paclitaxel, Topotecan]",0,5,1,0,6
4,TCGA-04-1332,151.0,396.0,[],1,1,8,0,9
...,...,...,...,...,...,...,...,...,...
1259,TCGA-61-2610,500.0,561.0,[],2,1,2,0,3
1260,TCGA-61-2610,561.0,1101.0,[Doxorubicin],2,0,17,1,18
1261,TCGA-61-2611,0.0,238.0,"[Carboplatin, Paclitaxel]",0,0,7,1,8
1262,TCGA-61-2612,0.0,131.0,"[Carboplatin, Paclitaxel]",0,0,4,1,5


Create version of data for regression

In [188]:
def fix_treat(row):
    if (row['patient'] == row['pat_lag']) & (row['ther_lag'] != ''):
        return 1
    else:
        return 0

def fix_death(row):
    if row['treat'] == 1:
        return 0
    else:
        return row['death']

def fix_notreat(row):
    if row['patient'] == row['pat_lag']:
        return (1 - row['treat'] - row['death'])
    else:
        return 0

def treat_state(row):
    if row['ther_str'] == '':
        return 0
    else:
        return 1
    
def make_reg_df(df): 
    df_reg = df.loc[df.index.repeat(df.total_ev)].reset_index(drop=True)
    df_reg['ther_str'] = [','.join(map(str, l)) for l in df_reg['therapy']]
    df_reg['ther_lag'] = df_reg['ther_str'].shift(-1)
    df_reg['pat_lag'] = df_reg['patient'].shift(-1)

    df_reg['treat'] = df_reg.apply(lambda row: fix_treat(row), axis=1) # prob that next state is treatment
    df_reg['death'] = df_reg.apply(lambda row: fix_death(row), axis=1) # prob that next state is no treatment
    df_reg['no_treat'] = df_reg.apply(lambda row: fix_notreat(row), axis=1) # prob that next state is death
    
    df_reg = df_reg.drop(columns=['total_ev', 'ther_lag', 'pat_lag', 'start', 'end'])
    
    df_reg['months'] =  df_reg.groupby('patient').cumcount() # time trend
    
    df_reg['treat_state'] = df_reg.apply(lambda row: treat_state(row), axis=1) # current state
    
    df_reg = df_reg.merge(pd.get_dummies(df_reg['ther_str']), left_index=True, right_index=True)
    
    return df_reg

df_d_reg = make_reg_df(lines_df_d)

df_reg = make_reg_df(lines_df_2)

df_d_reg

Unnamed: 0,patient,therapy,previous_lines,treat,no_treat,death,ther_str,months,treat_state,Unnamed: 10,...,Sargramostim,Sorafenib,Tamoxifen,"Tamoxifen,Topotecan",Topotecan,"Topotecan,Vinorelbine",Trabectedin,Vamydex,Vinorelbine,Vosaroxin
0,TCGA-04-1331,"[Carboplatin, Paclitaxel]",0,1,0,0,"Carboplatin,Paclitaxel",0,1,0,...,0,0,0,0,0,0,0,0,0,0
1,TCGA-04-1331,"[Carboplatin, Paclitaxel]",0,1,0,0,"Carboplatin,Paclitaxel",1,1,0,...,0,0,0,0,0,0,0,0,0,0
2,TCGA-04-1331,"[Carboplatin, Paclitaxel]",0,1,0,0,"Carboplatin,Paclitaxel",2,1,0,...,0,0,0,0,0,0,0,0,0,0
3,TCGA-04-1331,"[Carboplatin, Paclitaxel]",0,1,0,0,"Carboplatin,Paclitaxel",3,1,0,...,0,0,0,0,0,0,0,0,0,0
4,TCGA-04-1331,"[Carboplatin, Paclitaxel]",0,0,1,0,"Carboplatin,Paclitaxel",4,1,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5926,TCGA-61-2612,"[Carboplatin, Paclitaxel]",0,0,0,1,"Carboplatin,Paclitaxel",4,1,0,...,0,0,0,0,0,0,0,0,0,0
5927,TCGA-61-2613,"[Carboplatin, Paclitaxel]",0,1,0,0,"Carboplatin,Paclitaxel",0,1,0,...,0,0,0,0,0,0,0,0,0,0
5928,TCGA-61-2613,"[Carboplatin, Paclitaxel]",0,1,0,0,"Carboplatin,Paclitaxel",1,1,0,...,0,0,0,0,0,0,0,0,0,0
5929,TCGA-61-2613,"[Carboplatin, Paclitaxel]",0,1,0,0,"Carboplatin,Paclitaxel",2,1,0,...,0,0,0,0,0,0,0,0,0,0


## Create MDP objects

* This should probably be the start of the next notebook

State and Action set

In [189]:
# States
state_set = ['N', 'T', 'D']

# Combos resulting from actions
combos_all = list(lines_df_2['therapy'])
combos_all.sort()
combos_all = list(combos_all for combos_all,_ in itertools.groupby(combos_all))
print(len(combos_all))
print(combos_all[0:10])

combos_all_drugs = combos_all.copy()
combos_all_drugs.remove([])
print(combos_all_drugs[0:10]) 

# Combos for deceased patients
combos_d = list(lines_df_d['therapy'])

combos_d.sort()
combos_d = list(combos_d for combos_d,_ in itertools.groupby(combos_d))
print(len(combos_d))
print(combos_d[0:10])

combos_drugs_d = combos_d.copy()
combos_drugs_d.remove([])
print(combos_drugs_d[0:10]) 

128
[[], ['Abagovomab'], ['Aldesleukin'], ['Altretamine'], ['Amifostine', 'Carboplatin', 'Cisplatin', 'Paclitaxel'], ['Amifostine', 'Carboplatin', 'Paclitaxel'], ['Amifostine', 'Cisplatin', 'Paclitaxel'], ['Amifostine', 'Paclitaxel'], ['Aminocamptothecin'], ['Anastrozole']]
[['Abagovomab'], ['Aldesleukin'], ['Altretamine'], ['Amifostine', 'Carboplatin', 'Cisplatin', 'Paclitaxel'], ['Amifostine', 'Carboplatin', 'Paclitaxel'], ['Amifostine', 'Cisplatin', 'Paclitaxel'], ['Amifostine', 'Paclitaxel'], ['Aminocamptothecin'], ['Anastrozole'], ['Anastrozole', 'Doxorubicin']]
108
[[], ['Aldesleukin'], ['Altretamine'], ['Amifostine', 'Carboplatin', 'Cisplatin', 'Paclitaxel'], ['Amifostine', 'Carboplatin', 'Paclitaxel'], ['Amifostine', 'Cisplatin', 'Paclitaxel'], ['Amifostine', 'Paclitaxel'], ['Aminocamptothecin'], ['Anastrozole'], ['Anastrozole', 'Doxorubicin']]
[['Aldesleukin'], ['Altretamine'], ['Amifostine', 'Carboplatin', 'Cisplatin', 'Paclitaxel'], ['Amifostine', 'Carboplatin', 'Paclitaxel'

Static transition matrix

In [190]:
# States = {No Treatment, Treatment, Death}
combos = combos_d # Switch to _all when use full dataset
combos_drugs = combos_drugs_d

nactions = len(combos) # action set is list of unique drug combos, not including choosing no drugs
nstates = len(state_set) # three state probabilities

# Initialize transition matrix
transitions = np.zeros((nactions,nstates-1,nstates))
print('T_a shape:\n',
      '   N  T  D\n',
      'N', transitions[0][0],'\n',
      'T', transitions[0][1])

# Calculate probabilities
for i in range(nactions):
    inds = get_index_pos(list(lines_df_2['therapy']), combos[i])
    
    n_prob = lines_df_2[['no_treat']].iloc[inds, :].sum().values[0] / lines_df_2[['total_ev']].iloc[inds, :].sum().values[0]
    t_prob = lines_df_2[['treat']].iloc[inds, :].sum().values[0] / lines_df_2[['total_ev']].iloc[inds, :].sum().values[0]
    d_prob = lines_df_2[['death']].iloc[inds, :].sum().values[0] / lines_df_2[['total_ev']].iloc[inds, :].sum().values[0]
    # placeholder so that every action has a chance of death. Will update with regression to calc probs
    # if d_prob == 0:
    #   d_prob = .1
    #   n_prob = n_prob - (d_prob / 2)
    #   t_prob = t_prob - (d_prob / 2)
    ###################################################################################################

    if combos[i] == []:
        # N --> 
        transitions[i][0][0] = n_prob
        transitions[i][0][1] = t_prob
        transitions[i][0][2] = d_prob
        # T -->
        transitions[i][1][0] = 0
        transitions[i][1][1] = 0
        transitions[i][1][2] = 0
    else:
        # N --> 
        transitions[i][0][0] = 0
        transitions[i][0][1] = 0
        transitions[i][0][2] = 0
        # T -->
        transitions[i][1][0] = n_prob
        transitions[i][1][1] = t_prob
        transitions[i][1][2] = d_prob

# Get tranistion probabilities function
def get_probs(act_choice, curr_state, trans_set=transitions, act_set=combos, states_set=state_set):
    return trans_set[act_set.index(act_choice)][states_set.index(curr_state)]

transitions[0:10]

T_a shape:
    N  T  D
 N [0. 0. 0.] 
 T [0. 0. 0.]


array([[[0.84903439, 0.15096561, 0.        ],
        [0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.25      , 0.75      , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.26315789, 0.68421053, 0.05263158]],

       [[0.        , 0.        , 0.        ],
        [0.11111111, 0.88888889, 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.33333333, 0.66666667, 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.33333333, 0.66666667, 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.16666667, 0.83333333, 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.5       , 0.5       , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.2       , 0.8       , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.        , 1.        , 0.        ]]])

Transition probability regression

In [191]:
# OLS Linear, event on time and treatment, only patients who died
        # Treatment prob
X = np.array(df_d_reg.loc[:, 'months':])
y = np.array(df_d_reg['treat'])
t_reg = LinearRegression().fit(X, y)

        # No treatment prob
y = np.array(df_d_reg['no_treat'])
n_reg = LinearRegression(fit_intercept=False).fit(X, y)

        # Death prob
y = np.array(df_d_reg['death'])
d_reg = LinearRegression(fit_intercept=False).fit(X, y)

# MinMax and sum to 1 scaling for negative probs
scaler = MinMaxScaler(feature_range=(0.05, 0.95)) # maybe change min to the overall proportion of deaths to total transitions

def scaleMM(inprobs):
    scaler.fit(inprobs.reshape(-1, 1))
    outprobs = scaler.transform(inprobs)
    outprobs = outprobs / outprobs.sum()
    return outprobs

# Function to calculate transition probabilities based on regressions
def reg_probs(state, months, action):
    act_ind = combos.index(action) + 2
    x_test = np.zeros(110)
    x_test[0] = months
    x_test[1] = state_set.index(state)
    x_test[act_ind] = 1
    x_test = x_test.reshape(1,-1)
    probs = np.array([[n_reg.predict(x_test)[0], t_reg.predict(x_test)[0], d_reg.predict(x_test)[0]]])
    probs = scaleMM(probs)[0].tolist()
    return probs

# Does it work??
reg_probs('N', 0, ['Carboplatin', 'Paclitaxel'])

[0.047619047616918465, 0.9047619047214508, 0.047619047661630644]

In [192]:
# If want formatted regression output
import statsmodels.api as sm
# from scipy import stats

# X = sm.add_constant(X)
est = sm.OLS(y, X)
est2 = est.fit()
print(est2.summary())
est2.predict(x_test)

                            OLS Regression Results                            
Dep. Variable:                      y   R-squared:                       0.079
Model:                            OLS   Adj. R-squared:                  0.062
Method:                 Least Squares   F-statistic:                     4.654
Date:                Tue, 13 Apr 2021   Prob (F-statistic):           1.99e-49
Time:                        15:29:36   Log-Likelihood:                 1647.2
No. Observations:                5931   AIC:                            -3076.
Df Residuals:                    5822   BIC:                            -2347.
Df Model:                         108                                         
Covariance Type:            nonrobust                                         
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
x1             0.0003      0.000      2.397      0.0

array([-0.01596884])

Q-Network: MLP

In [193]:
# Define DQN, one hidden layer MLP for now
class DQN(nn.Module):
    def __init__(self, in_features, hidden_size, out_features):
        super(DQN, self).__init__()
        self.fc = nn.Linear(in_features, hidden_size)
        self.relu = nn.ReLU()
        self.fc_out = nn.Linear(hidden_size, out_features)

    def forward(self, x):
        out = self.fc(x)
        out = self.relu(out)
        out = self.fc_out(out)
        return out

Replay Memory

https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html

In [194]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

Model training loop

https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html

In [195]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.stack([s for s in batch.next_state
                                                if s is not None])

    state_batch = torch.stack(batch.state)
    action_batch = torch.stack(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

## Environment

Designed based on: https://github.com/MJeremy2017/reinforcement-learning-implementation/blob/master/GridWorld/gridWorld.py

State Class

In [196]:
len(combos_drugs)

107

In [197]:
# Globals
START_STATE = 'T'
START_COMBO = []
GAMMA = 0.1
EXP_RATE = 0.3
BATCH_SIZE = 100
TARGET_UPDATE = 5

policy = DQN(2, 100, len(combos_drugs)).to(device) # input features are the state and the number of previous states
target_net = DQN(2, 100, len(combos_drugs)).to(device)
target_net.load_state_dict(policy.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy.parameters())
memory = ReplayMemory(10000)

class State:
    def __init__(self, state=START_STATE):
        self.state = state
        self.isEnd = False

    # Give reward: for now, simple rewards
    def giveReward(self):
        if self.state == 'N':
            return 1 
        elif self.state == 'T':
            return 1
        elif self.state == 'D':
            return -1 
    
    def isEndFunc(self):
        if self.state == 'D':
            self.isEnd = True

    # Update step
    def nxtPosition(self, action, months):
        # Choose probabilities from transition matrix
#         probs = get_probs(action, self.state) # static probs
        probs = reg_probs(self.state, months, action) # regression probs
        # Return next state
        nxtState = random.choices(state_set, weights=probs, k=1)[0]
        return nxtState

Agent Class

In [198]:
class Agent:
    def __init__(self):
        self.State = State()        
        self.s_a = []
        self.actions = combos_drugs
#         self.lr = LEARN_RATE # this isn't doing anything right now
        self.exp_rate = EXP_RATE
        # initial state and action reward
        self.state_values = {'N':1, 'T':1, 'D':-1}       # not sure if need these, need action values instead. Can it be like: Combo: Value - xt?
        self.statevals_list = [1, 1, -1]      # for [N, T, D]
        self.s_a_values = {} 
        for i in state_set:
            for j in combos:
                self.s_a_values[repr([i, j])] = 0
    
    # Choose Action
    def chooseAction(self):
        # choose action with most expected value
        mx_exp_reward = 0

        if self.State.state == 'N':
            action = []
        else:
            if np.random.uniform(0, 1) <= self.exp_rate:
                action = self.actions[np.random.choice(len(self.actions))]
            else:
                # greedy action
                with torch.no_grad():
                    state_t = torch.FloatTensor([state_set.index(self.State.state), len(self.s_a)])
                    act = policy(state_t).max(0)[1].item()
                    action = self.actions[act]
        return action

    def takeAction(self, action):
        position = self.State.nxtPosition(action, len(self.s_a))
        return State(state=position)

    def reset(self):
        self.s_a = []
        self.State = State()

    def play(self, rounds=20):
        i = 0
        while i < rounds:
            # to the end of game back propagate reward
            if self.State.isEnd:
                # back propagate
                reward = self.State.giveReward()
                # explicitly assign end state to reward values
                # self.s_a_values[repr([self.State.state, action])] = reward  # this is optional
                # for s in reversed(self.s_a):
                #     reward = self.s_a_values[repr([s[0], s[1]])] + self.lr * (reward - self.s_a_values[repr([s[0], s[1]])]) # I think i need to update this formula because using (s,a) now
                #     self.s_a_values[repr([s[0], s[1]])] = reward
                # Calculate total reward
                total_reward = 0
                for s in self.s_a:
                    total_reward += self.state_values[s[0]]
                print(self.s_a)
                print('Total reward: ', total_reward)

                self.reset()
                i += 1
            else:
                action = self.chooseAction()
                try:
                    act_ind = torch.tensor([combos_drugs.index(action)])
                except ValueError:
                    pass
                # append trace
                self.s_a.append([self.State.state, action])
                print('current state {} action {}'.format(self.State.state, action))
                
                s_t = torch.FloatTensor([state_set.index(self.State.state), len(self.s_a)]) 
                # by taking the action, it reaches the next state
                self.State = self.takeAction(action)
                # save transition vals for memory
                s_t1 = torch.FloatTensor([state_set.index(self.State.state), len(self.s_a) + 1])
                r_t = torch.FloatTensor([self.State.giveReward()])
                # store transition in memory
                if action != []:
                    memory.push(s_t, act_ind, s_t1, r_t)
                else:
                    pass
                # perform an optimizaiton step
                optimize_model()
                # update target model
                if i % TARGET_UPDATE == 0:
                    target_net.load_state_dict(policy.state_dict())
                # mark is end
                self.State.isEndFunc()
                print('nxt state', self.State.state)
                print('---------------------')

Treat patients

In [199]:
if __name__ == '__main__':
    ag = Agent()
    ag.play(50)
    # print(ag.s_a_values)

current state T action ['Carboplatin', 'Cyclophosphamide', 'Topotecan']
nxt state T
---------------------
current state T action ['Carboplatin', 'Cyclophosphamide', 'Topotecan']
nxt state T
---------------------
current state T action ['Cisplatin', 'Cyclophosphamide']
nxt state T
---------------------
current state T action ['Cisplatin', 'Cyclophosphamide']
nxt state T
---------------------
current state T action ['Cisplatin', 'Cyclophosphamide']
nxt state T
---------------------
current state T action ['Cisplatin', 'Cyclophosphamide']
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
curr

nxt state T
---------------------
current state T action ['Cisplatin', 'Cyclophosphamide']
nxt state T
---------------------
current state T action ['Cisplatin', 'Cyclophosphamide']
nxt state T
---------------------
current state T action ['Carboplatin', 'Etoposide']
nxt state T
---------------------
current state T action ['Bevacizumab', 'Docetaxel']
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state T
---------------------
current state T action ['NOS']
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state T
---------------------
current

nxt state T
---------------------
current state T action ['Letrozole']
nxt state D
---------------------
[['T', ['Carboplatin', 'Docetaxel', 'Letrozole', 'Leuprolide']], ['T', ['Letrozole']]]
Total reward:  2
current state T action ['Docetaxel', 'Doxorubicin', 'Paclitaxel', 'Topotecan']
nxt state T
---------------------
current state T action ['Bevacizumab', 'Paclitaxel']
nxt state T
---------------------
current state T action ['Carboplatin', 'Docetaxel', 'Letrozole', 'Leuprolide']
nxt state T
---------------------
current state T action ['Bevacizumab', 'Gemcitabine']
nxt state T
---------------------
current state T action ['Carboplatin', 'Docetaxel', 'Letrozole', 'Leuprolide']
nxt state T
---------------------
current state T action ['Carboplatin', 'Docetaxel', 'Leuprolide']
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state D
-------------------

nxt state T
---------------------
current state T action ['Doxorubicin', 'Gefitinib', 'Topotecan']
nxt state D
---------------------
[['T', ['Bevacizumab', 'Paclitaxel']], ['T', ['Amifostine', 'Paclitaxel']], ['N', []], ['N', []], ['N', []], ['N', []], ['N', []], ['N', []], ['T', ['Carboplatin', 'Doxorubicin', 'Paclitaxel']], ['T', ['Letrozole']], ['T', ['Cisplatin', 'Gemcitabine']], ['T', ['Doxorubicin', 'Gefitinib', 'Topotecan']]]
Total reward:  12
current state T action ['Cisplatin', 'Cyclophosphamide']
nxt state T
---------------------
current state T action ['Carboplatin', 'Docetaxel', 'Leuprolide']
nxt state T
---------------------
current state T action ['Irofulven']
nxt state D
---------------------
[['T', ['Cisplatin', 'Cyclophosphamide']], ['T', ['Carboplatin', 'Docetaxel', 'Leuprolide']], ['T', ['Irofulven']]]
Total reward:  3
current state T action ['Anastrozole', 'Doxorubicin']
nxt state T
---------------------
current state T action ['Gemcitabine']
nxt state N
-----------

nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state T
---------------------
current state T action ['Carboplatin', 'Doxorubicin', 'Paclitaxel']
nxt state T
---------------------
current state T action ['Paclitaxel', 'Topotecan']
nxt state T
---------------------
current state T action ['Paclitaxel', 'Topotecan']
nxt state T
---------------------
current state T action ['Carboplatin', 'Doxorubicin', 'Paclitaxel']
nxt state D
---------------------
[['T', ['Sargramostim']], ['N', []], ['N', []], ['N'

nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state T
---------------------
current state T action ['Docetaxel', 'Topotecan']
nxt state T
---------------------
current state T action ['Cisplatin', 'Tamoxifen']
nxt state T
---------------------
current state T action ['Doxorubicin', 'Gemcitabine']
nxt state T
---------------------
current state T action ['Cyclophosphamide']
nxt state T
---------------------
current state T action ['Cyclophosphamide', 'Paclitaxel']
nxt state T
---------------------
current state T action ['Vinorelbine']
nxt state T
---------------------
current state T action ['Etoposide']
nxt state T
---------------------
current state T action ['Cyclophosphamide', 'Paclitaxel']
nxt state D
---------------------
[['T

nxt state T
---------------------
current state T action ['Altretamine']
nxt state T
---------------------
current state T action ['Aminocamptothecin']
nxt state N
---------------------
current state N action []
nxt state N
---------------------
current state N action []
nxt state D
---------------------
[['T', ['Cisplatin', 'Paclitaxel']], ['N', []], ['N', []], ['T', ['Doxorubicin', 'Irinotecan', 'Vinorelbine']], ['T', ['Cyclophosphamide', 'Paclitaxel']], ['T', ['Cyclophosphamide', 'Paclitaxel']], ['T', ['Cyclophosphamide', 'Paclitaxel']], ['T', ['Cisplatin', 'Topotecan']], ['T', ['Altretamine']], ['T', ['Aminocamptothecin']], ['N', []], ['N', []]]
Total reward:  12
current state T action ['Carboplatin', 'Docetaxel', 'Letrozole', 'Leuprolide']
nxt state T
---------------------
current state T action ['Anastrozole', 'Tamoxifen']
nxt state T
---------------------
current state T action ['Cisplatin', 'Topotecan']
nxt state T
---------------------
current state T action ['Altretamine']
nx

## Scratch/notes

Notes
* Add in "remission" as a state (i.e. no treatment)
  - Calculate similar to probability of death: probability that a no-treatment period follows a treatment
  - Then need to calculate probability following "remission" for each:
    - Death
    - remission, i.e. no treatment
    - More treatment i.e. "progression/recurrence" 
  - No-treatment ("[]") will not be in the action set
    - Unless already in the no-treatment state
    - Importantly, this means that the agent cannot decide to take a patient off of treatement. Not being on chemo is not a choice made by the AI "doctor", but rather, is a result of the previous treatment. Agent can, however, choose not to add drugs if already achieved remission (i.e. in the no-treatment state) 

* "Grid world" verion of action space
  - 56 dimensional object with up to 5 drugs activated at once
  - Action set is [Start drug {n}, Stop drug {n}, Do nothing]
    - Cannot [Stop drug {n}] if resulting state is no treatment
    - Cannot [Start drug {n}] if resulting len(combo) > 5
    - Cannot start a drug that is already in the combo
  - \> 4mm possible combinations, only 127 drug combos in data - is this necessary? 

* Rewards
  - I think in the final version this should be a probability distribution returning a number of days of survival/until treatment failure after each action
  - For starters, just try to get the average days survived when therapy doesn't result in death.
  - Also need to think of something for when patients don't die
    - Maybe just probability of death and a final non-death stop state?
    - Death vs. no death in the end doesn't matter unless the no deaths get a final reward

* Bellman action value equation: Q(s,a) = r + γ(max(Q(s’,a’))
  - "This says that the Q-value for a given state (s) and action (a) should represent the current reward (r) plus the maximum discounted (γ) future reward expected according to our own table for the next state (s’) we would end up in"

* Replay memory?

Intermittant reward matrix for each patient

In [200]:
# fix because including no drug periods and using df

# tcga_rewards = {}
# for barcode in tcga_lines_keys:
#     rewards = []
#     for line in range(len(tcga_drug_lines[barcode])):
#         try:
#             rewards.append(tcga_drug_lines[barcode][line+1][0] - tcga_drug_lines[barcode][line][0])
#         except:
#             rewards.append(tcga_ov_1_keep.loc[tcga_ov_1_keep['bcr_patient_barcode'] == barcode, 'total_days_overall_survival'].iloc[0] - tcga_drug_lines[barcode][line][0])

#     tcga_rewards[barcode] = rewards

# tcga_rewards

In [201]:
# negs = []
# for barcode in tcga_rewards:
#     if tcga_rewards[barcode][-1] < 0:
#         negs.append(barcode)

# negs

In [202]:
# Add in final state for each patient: [time, 'death']
# Figure out what to do with patients where overall survival is < the end of the last therapy line

In [203]:
# define transition matrix
# def transition(state, action): 
#     if state,action = (living, drug A): 
#         then return 1 if np.random() < 0.8 …. 
        
#     if state, action = (living, drug B), return 1 
    
#     if ..
