An autoregressive model built on the decoder transformer architecture, implemented with the Keras/TensorFlow framework. In layman's terms, we want to train a model similar to ChatGPT, but instead of training on natural language we train on a patient's lifelong healthcare journey, expressed as a sequence of diagnosis codes.

# Data

In [3]:
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
from fn_data_prep import create_sequences

We use synthetic claims data and, for each patient, build a sequence of diagnosis codes ordered by date, which mimics a typical text corpus.

In [5]:


claims = pd.read_csv('../data/DE1_0_2008_to_2010_Inpatient_Claims_Sample_1.csv')

claims_renamed = claims[['DESYNPUF_ID','CLM_ADMSN_DT','ICD9_DGNS_CD_1']]\
    .rename(columns={'DESYNPUF_ID': 'patient', 'CLM_ADMSN_DT': 'date', 'ICD9_DGNS_CD_1':'dx'})\
    .sort_values(['patient', 'date'])

print(claims_renamed.head(3))

dx_sequences = create_sequences(claims_renamed)

dx_sequences[:3]

            patient      date    dx
0  00013D2EFD8E45D1  20100312  7802
1  00016F745862898F  20090412  1970
2  00016F745862898F  20090831  6186
Created 15140 patient sequences


['1970 6186 29623 3569', '33811 V5789 49121 7366', '42789 5781']
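
`create_sequences` is imported from `fn_data_prep` and its source is not shown here. Judging only from the output above, it appears to join each patient's codes, in date order, into one space-separated string, and to skip patients with a single claim (the first patient in the printed frame has one claim and does not appear in the result). A minimal pure-Python sketch under those assumptions, where `build_sequences` and `min_len` are hypothetical names and not the actual implementation:

```python
from collections import defaultdict

def build_sequences(rows, min_len=2):
    """Join each patient's dx codes, in date order, into one string.

    rows: iterable of (patient, date, dx) tuples.
    min_len: assumed cutoff; patients with fewer codes are skipped.
    """
    by_patient = defaultdict(list)
    # Sorting by (patient, date) mirrors the sort_values call in the cell above.
    for patient, date, dx in sorted(rows, key=lambda r: (r[0], r[1])):
        by_patient[patient].append(str(dx))
    return [" ".join(codes) for codes in by_patient.values() if len(codes) >= min_len]

rows = [
    ("00016F745862898F", 20090831, "6186"),
    ("00013D2EFD8E45D1", 20100312, "7802"),
    ("00016F745862898F", 20090412, "1970"),
]
sequences = build_sequences(rows)
print(sequences)
```

With these three toy rows, only the second patient has enough claims to form a sequence.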

We convert the sequences into a TensorFlow dataset and group them into batches.

In [6]:
BATCH_SIZE = 32 # number of sequences per batch fed into the network
VOCAB_SIZE = 10000 # vocabulary cap: only the most frequent dx codes get their own token
MAX_LEN = 50 # max sequence length

# Convert to a TensorFlow Dataset
sequence_ds = (
    tf.data.Dataset.from_tensor_slices(dx_sequences) # each sequence string becomes one dataset element
    .batch(BATCH_SIZE)
)
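
To make the batching concrete: `batch` keeps a final partial batch by default (`drop_remainder=False`), so 15140 sequences with a batch size of 32 give 474 batches, the last one holding only 4 sequences. A pure-Python sketch of the same grouping, where `chunk` is a hypothetical helper and the integers stand in for the sequence strings:

```python
BATCH_SIZE = 32
N_SEQUENCES = 15140  # count reported by create_sequences above

def chunk(items, size):
    """Yield consecutive slices of at most `size` items."""
    for start in range(0, len(items), size):
        yield items[start:start + size]

batches = list(chunk(list(range(N_SEQUENCES)), BATCH_SIZE))
n_batches = len(batches)
last_batch_size = len(batches[-1])
print(n_batches, last_batch_size)
```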

This is followed by tokenization, i.e. converting each diagnosis code into an integer.

In [7]:
# Create a vectorisation layer
vectorize_layer = layers.TextVectorization(
    standardize="lower", # lowercase only; some dx codes contain letters (e.g. V5789)
    max_tokens=VOCAB_SIZE, # vocabulary cap; the most prevalent dx codes each get an integer token
    output_mode="int",
    output_sequence_length=MAX_LEN + 1, # pad/truncate to MAX_LEN + 1 so inputs and targets have MAX_LEN tokens after the shift
)

# Adapt the layer to the training set
vectorize_layer.adapt(sequence_ds)

vocab = vectorize_layer.get_vocabulary()  # To get words back from token indices

# Display the same example converted to ints
example_tokenised = vectorize_layer(dx_sequences[0])
print("original dx sequence\n",dx_sequences[0])
print("token representation after applying vectorization layer\n",example_tokenised.numpy())

original dx sequence
 1970 6186 29623 3569
token representation after applying vectorization layer
 [ 222 1112  377  725    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0]
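
As a rough picture of what the layer does in `int` mode: index 0 is reserved for padding, index 1 for out-of-vocabulary tokens (`[UNK]`), and the remaining indices are assigned by descending frequency. The sketch below imitates that in plain Python. `build_vocab` and `encode` are hypothetical helpers, not Keras functions, the toy sequences are made up, and tie-breaking between equally frequent codes may differ from Keras.

```python
from collections import Counter

PAD, UNK = "", "[UNK]"

def build_vocab(sequences, max_tokens):
    """Return a token list: padding, OOV, then codes by descending count."""
    counts = Counter(code for seq in sequences for code in seq.lower().split())
    keep = max_tokens - 2  # two slots go to PAD and UNK
    return [PAD, UNK] + [code for code, _ in counts.most_common(keep)]

def encode(sequence, vocab, length):
    """Map codes to ints, then truncate or zero-pad to `length`."""
    index = {code: i for i, code in enumerate(vocab)}
    ids = [index.get(code, 1) for code in sequence.lower().split()][:length]
    return ids + [0] * (length - len(ids))

toy = ["4280 4280 V5789", "4280 V5789 1970"]
vocab = build_vocab(toy, max_tokens=4)
encoded = encode("V5789 1970 4280", vocab, length=5)
print(vocab)
print(encoded)
```

Here `1970` falls outside the four-slot vocabulary, so it maps to the out-of-vocabulary index 1, and the trailing zeros are padding, as in the printed array above.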


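Next we prepare the training pairs. Each target sequence is the input sequence shifted one position to the left, so every position is trained to predict the next diagnosis code.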

Example: 

222 -> predict 1112

222 + 1112 -> predict 377

222 + 1112 + 377 -> predict 725
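
The same pairs can be enumerated programmatically. A small sketch in plain Python, using the token ids from the example above (`next_token_pairs` is a hypothetical helper for illustration only):

```python
def next_token_pairs(tokens):
    """Return (context, target) pairs: every prefix predicts the token after it."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]

pairs = next_token_pairs([222, 1112, 377, 725])
for context, target in pairs:
    print(context, "-> predict", target)
```

During training these pairs are not built one by one; a single pair of shifted arrays covers all positions of a sequence at once.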

In [8]:
def prepare_inputs(text):
    """
    Shift token sequences by 1 position so that the target for position (i) is
    the token (dx code) at position (i+1). The model will use all tokens up to
    position (i) to predict the next token.
    """
    text = tf.expand_dims(text, -1)
    tokenized_sentences = vectorize_layer(text)
    x = tokenized_sentences[:, :-1] # contains all tokens except the last one in each sequence
    y = tokenized_sentences[:, 1:] # contains all tokens except the first one, effectively shifting the sequence by one position
    return x, y

train_ds = sequence_ds.map(prepare_inputs)

for x, y in train_ds.take(1):  # Get first batch
    print("First Input (x):", x[0].numpy())  # First observation
    print("First Target (y):", y[0].numpy()) 

First Input (x): [ 222 1112  377  725    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0]
First Target (y): [1112  377  725    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0]
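
The printed arrays also show why the vectorization layer pads to `MAX_LEN + 1`: dropping one token from either end leaves `x` and `y` with exactly `MAX_LEN` positions each. A plain-Python check of that arithmetic, with `pad_to` as a hypothetical stand-in for the layer's truncation and padding:

```python
MAX_LEN = 50

def pad_to(tokens, length):
    """Truncate to `length`, then right-pad with 0 (the padding id)."""
    clipped = tokens[:length]
    return clipped + [0] * (length - len(clipped))

padded = pad_to([222, 1112, 377, 725], MAX_LEN + 1)
x, y = padded[:-1], padded[1:]  # same slicing as prepare_inputs
print(len(padded), len(x), len(y))
print(x[:5], y[:5])
```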
