<a href="https://colab.research.google.com/github/chuunibian/fhir-claimresource-model/blob/main/fhir.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

`SyntheaBundleModel` takes a Synthea FHIR bundle (.json) and uses the parser to extract information from it. What gets extracted is flexible; the first goal is to extract all of the claims.

In [None]:
# Macros: notebook-wide configuration constants.
MAX_CLAIM_C = 30        # rows per patient input: older claims dropped, shorter histories zero-padded
BATCH_SIZE_C = 8        # patients per DataLoader batch
CLAIM_ENCODE_SIZE = 22  # length of the per-claim feature vector built by encode_claim

In [None]:
from google.colab import drive
drive.mount('/content/drive')
from pathlib import Path
import os

# Directory of Synthea-generated FHIR bundles (one .json bundle per file).
fhir_test_data = "/content/drive/MyDrive/synthea_sample_data_fhir_latest"
data_path = Path(fhir_test_data)

# Which bundle the single-file demo cells below use.
SAMPLE_FILE_INDEX = 5

# os.listdir order is arbitrary: keep only .json bundles and sort so the pick is reproducible.
files = sorted(name for name in os.listdir(data_path) if name.endswith('.json'))

if len(files) <= SAMPLE_FILE_INDEX:
  raise FileNotFoundError(
      f"Need more than {SAMPLE_FILE_INDEX} .json bundles in {data_path}, found {len(files)}")

file = data_path / files[SAMPLE_FILE_INDEX]
print(file)





Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/synthea_sample_data_fhir_latest/Caryl47_Kassulke119_4569671e-ed39-055f-8e78-422b96c9896b.json


In [None]:
# Input Validation Which will be done at begenning
try:
  import fhir.resources
except ImportError:
  !pip install fhir.resources

from fhir.resources.R4B.bundle import Bundle

def validate_fhir_bundle(json_string):
    """Check whether ``json_string`` validates as a FHIR R4B ``Bundle``.

    Args:
        json_string: Raw JSON text of a FHIR Bundle.

    Returns:
        A ``(is_valid, message)`` tuple. ``is_valid`` is True when
        ``Bundle.model_validate_json`` accepts the input; otherwise it is
        False and ``message`` contains the validation error text.
    """
    try:
        # Only the validation matters; the parsed model is discarded.
        Bundle.model_validate_json(json_string)
    except Exception as e:
        # Deliberately broad: any parse/validation failure is reported, not raised.
        return False, f"Invalid FHIR Bundle: {str(e)}"
    return True, "Valid FHIR Bundle"



In [None]:
from datetime import datetime

class Claim:
  '''
  Wrapper around a single FHIR ``Claim`` resource.

  ``claim_instance`` is the resource as a plain dict (one ``entry[i]['resource']``
  of a bundle loaded with ``json.load``). Each Claim object represents one claim
  resource and exposes flat attributes used by the feature encoders:

    claim_id, status            -> raw ``id`` / ``status`` value, or None
    type_of_claim,
    type_of_subclaim, priority  -> display (preferred) or code of the first coding
                                   of ``type`` / ``subType`` / ``priority``, or None
    bill_period                 -> {'start': int, 'end': int} as YYYYMMDDHHMMSS
                                   wall-clock numbers (either key may be missing),
                                   or None
    billable_duration_days      -> whole days from billablePeriod start to end
                                   (clamped at 0), or None if either bound is missing
    claim_creation              -> raw ``created`` string, or None
    total                       -> {'value': ..., 'currency': ...} ('currency' only
                                   when present), or None
    number_of_diagnoses,
    number_of_items,
    number_of_drugs, insurance  -> number of list entries (0 when absent)
  '''
  def __init__(self, claim_instance):
    self.claim_id = self._extract_claim_id(claim_instance)
    self.status = self._extract_status(claim_instance)
    self.type_of_claim = self._extract_type_of_claim(claim_instance)
    self.type_of_subclaim = self._extract_type_of_subclaim(claim_instance)
    self.bill_period = self._extract_bill_period(claim_instance)
    self.claim_creation = self._extract_claim_creation(claim_instance)
    self.priority = self._extract_priority(claim_instance)
    self.total = self._extract_total(claim_instance)
    self.number_of_diagnoses = self._extract_number_of_diagnoses(claim_instance)
    self.number_of_items = self._extract_number_of_items(claim_instance)
    # NOTE(review): currently identical to number_of_items (counts ``item`` entries);
    # meant as a drug count for pharmacy claims.
    self.number_of_drugs = self._extract_number_of_items(claim_instance)
    # Computed from the raw ISO strings: bill_period stores plain integers.
    self.billable_duration_days = self._calculate_billable_duration_days(claim_instance)
    self.insurance = self._extract_insurance(claim_instance)

  @staticmethod
  def _parse_fhir_datetime(date_str):
    '''Parse a FHIR dateTime string; a trailing 'Z' is accepted on every Python version.'''
    if date_str.endswith('Z'):
      date_str = date_str[:-1] + '+00:00'
    return datetime.fromisoformat(date_str)

  @staticmethod
  def _first_coding_value(concept):
    '''Return the display (preferred) or code of the first coding of a CodeableConcept dict, else None.'''
    codings = concept.get('coding') if isinstance(concept, dict) else None
    if codings:
      first = codings[0]
      if 'display' in first:
        return first['display']
      if 'code' in first:
        return first['code']
    return None

  def _extract_claim_id(self, claim_instance):
    return claim_instance.get('id')

  def _extract_insurance(self, claim_instance):
    # Number of insurance entries (a count, not a 0/1 coverage flag).
    return len(claim_instance.get("insurance", []))

  def _extract_status(self, claim_instance):
    return claim_instance.get('status')

  def _extract_type_of_claim(self, claim_instance):
    return self._first_coding_value(claim_instance.get('type'))

  def _extract_type_of_subclaim(self, claim_instance):
    return self._first_coding_value(claim_instance.get('subType'))

  def _extract_bill_period(self, claim_instance):
    '''Return billablePeriod bounds as YYYYMMDDHHMMSS ints, or None if no bound is present.'''
    if 'billablePeriod' not in claim_instance:
      return None
    period = claim_instance['billablePeriod']

    def convert_to_numeric(date_str):
      # '2014-07-26T14:21:41-04:00' -> 20140726142141 (local wall-clock; UTC offset dropped)
      return int(self._parse_fhir_datetime(date_str).strftime('%Y%m%d%H%M%S'))

    bounds = {key: convert_to_numeric(period[key]) for key in ('start', 'end') if period.get(key)}
    return bounds or None

  def _extract_claim_creation(self, claim_instance):
    return claim_instance.get('created')

  def _extract_priority(self, claim_instance):
    return self._first_coding_value(claim_instance.get('priority'))

  def _extract_total(self, claim_instance):
    total_obj = claim_instance.get('total')
    if not total_obj:
      return None
    value = total_obj.get('value')
    if value is None:
      return None
    currency = total_obj.get('currency')
    return {'value': value, 'currency': currency} if currency else {'value': value}

  def _extract_number_of_diagnoses(self, claim_instance):
    # Counts ``diagnosis`` entries. NOTE(review): an earlier TODO questioned this;
    # claims without a diagnosis list (e.g. the pharmacy claim in the demo) give 0.
    return len(claim_instance.get("diagnosis", []))

  def _extract_number_of_items(self, claim_instance):
    return len(claim_instance.get('item', []))

  def _calculate_billable_duration_days(self, claim_instance=None):
    '''Whole days between billablePeriod start and end, clamped at 0; None if unavailable.

    When ``claim_instance`` is given, the raw ISO strings are used (timezone-aware).
    Otherwise the numeric ``self.bill_period`` values are parsed back; their UTC
    offsets were dropped, so the result is wall-clock based.
    '''
    if claim_instance is not None:
      period = claim_instance.get('billablePeriod') or {}
      parse = self._parse_fhir_datetime
    else:
      period = self.bill_period or {}
      parse = lambda value: datetime.strptime(str(value), '%Y%m%d%H%M%S')

    start = period.get('start')
    end = period.get('end')
    if not (start and end):
      return None
    try:
      duration = (parse(end) - parse(start)).days
    except (ValueError, TypeError, AttributeError):
      # Unparseable value, or mixed naive/aware datetimes.
      return None
    return max(duration, 0)

  def __repr__(self):
    return f"Claim(claim_id={self.claim_id})"



Methods that use instance variables need a `self` parameter. Python passes the instance in as `self`, so the method knows which object it is operating on.

In [None]:
import json
import os
from typing import Dict, List, Any, Optional, Union

class SyntheaGenerationFHIRCustomParser:
  """
  Custom parser that extracts resources from a Synthea-generated FHIR R4
  bundle (.json file) and wraps each one in a lightweight Python object.

  Main job: find every resource of a given type (currently ``Claim``) in the
  bundle's ``entry`` list, pass each one to the matching wrapper-class
  constructor, and return the wrapped objects.
  """

  def __init__(self, json_path: str):
    # Path of the bundle file, kept for reference/debugging.
    self.file_path = json_path
    # Whole bundle as nested dicts/lists.
    with open(json_path, 'r', encoding='utf-8') as bundle_file:
      self.parsed_json = json.load(bundle_file)

  def _iter_resources(self, resource_type: str):
    """Yield each entry resource whose ``resourceType`` equals ``resource_type`` exactly."""
    entries = self.parsed_json.get('entry', []) if isinstance(self.parsed_json, dict) else []
    for entry in entries:
      resource = entry.get('resource')
      # Exact match: a substring test would also pick up e.g. 'ClaimResponse'.
      if isinstance(resource, dict) and resource.get('resourceType') == resource_type:
        yield resource

  def get_list_of_claims(self) -> 'Dict[int, Claim]':
    """Wrap every ``Claim`` resource in the bundle.

    Returns:
      Dict mapping 0-based position (in bundle order) to a ``Claim`` object;
      empty when the bundle has no claims or no ``entry`` list.
    """
    return {index: Claim(resource)
            for index, resource in enumerate(self._iter_resources('Claim'))}

  def get_list_of_medications(self):
    """Placeholder for medication extraction; not implemented yet (returns None)."""
    pass




In [None]:
import json
import os
from datetime import datetime
from typing import Dict, List, Any, Optional, Union
from collections import defaultdict

class SyntheaBundleModel:
    """
    Per-bundle model for Synthea FHIR R4 data.

    For one bundle file, holds the collections of wrapped resources
    extracted by ``SyntheaGenerationFHIRCustomParser``. Currently only
    claims are exposed:

      list_of_claim -- dict from 0-based position (bundle order) to Claim
    """

    def __init__(self, json_path: str):
        self.list_of_claim = SyntheaGenerationFHIRCustomParser(json_path).get_list_of_claims()



In [None]:
# Demo: parse the sample bundle and inspect a few attributes of its claims.
fhir_json_instance = SyntheaBundleModel(file)
claims_demo = fhir_json_instance.list_of_claim

print(len(claims_demo))

print(claims_demo[1].claim_id)  # from claim obj get such attrib

# Indices 1 and 3 assume the sample bundle has at least 4 claims.
sample_claim = claims_demo[3]
for attribute in ('type_of_claim', 'number_of_diagnoses', 'bill_period',
                  'billable_duration_days', 'number_of_items', 'priority',
                  'status', 'insurance', 'total'):
  print(f"{attribute}: {getattr(sample_claim, attribute)}")




36
2b85d8e7-a230-5a3e-f47b-b6bf1c8d2f9b
pharmacy
0
{'start': 20140726142141, 'end': 20140726144950}
None
1
0
normal
active
1
{'value': 299.3, 'currency': 'USD'}


Notes:

Python is essentially doing the following under the hood:

It creates a function object for each method (e.g. `__init__` and `say_hello` in an example class `MyClass`).
These function objects are stored in the __dict__ of the class object (MyClass).
When an instance is created from the class, these function objects are linked to the instance through the class object. The instance doesn't directly store the function; instead, it looks up the function in the class object.

In [None]:
# Model-input design:
# - First, maybe try a single aggregated vector per patient.
# - Then try at most X claims per patient, probably the most recent ones. If a
#   patient has fewer than X claims, take all existing claims and zero-pad, so the
#   autoencoder input has shape [max claims X, static length of each claim].

'\nFirst maybe try aggregated single vector\n\nthen try X max amount of claims probably most recent ones or if not enough to meet\nX then take all exisitng and then just do padding then input to autoencoder will be\n[Max claim X static length of each claim]\n'

In [None]:
import numpy as np
import torch

OneHotEnc_claim_types = ['institutional', 'professional', 'oral', 'pharmacy', 'vision', 'hearing', 'others']
OneHotEnc_status_types = ['active', 'cancelled', 'rejected', 'pending', 'completed', 'others']
OneHotEnc_priority_types = ['normal', 'stat', 'deferred', 'others']
BinaryFlag_insuranceCoverage = [0,1] # 0 is no 1 is yes

'''
returns a static lengthed vector encoding for the passed in claim
'''
def encode_claim(claim: Claim) -> np.ndarray:

  claim_vector = np.array([])

  # print(claim)

  # Concat billable period start 19960210094418 L(1)
  # claim_vector = np.append(claim_vector, claim.bill_period['start'])

  # Concat billable period end L(1)
  # claim_vector = np.append(claim_vector, claim.bill_period['end'])

  # Concat type ONE HOT ENC L(7)
  type_temp = [0] * len(OneHotEnc_claim_types)
  if claim.type_of_claim in OneHotEnc_claim_types:
    type_temp[OneHotEnc_claim_types.index(claim.type_of_claim)] = 1
  else:
    type_temp[OneHotEnc_claim_types.index('others')] = 1

  claim_vector = np.append(claim_vector, type_temp)

  # Concat status ONE HOT ENC L(6)
  type_status = [0] * len(OneHotEnc_status_types)
  if claim.status in OneHotEnc_status_types:
    type_status[OneHotEnc_status_types.index(claim.status)] = 1
  else:
    type_status[OneHotEnc_status_types.index('others')] = 1

  claim_vector = np.append(claim_vector, type_status)

  # Concat priority ONE HOT ENC L(4)
  priority_status = [0] * len(OneHotEnc_priority_types)
  if claim.priority in OneHotEnc_priority_types:
    priority_status[OneHotEnc_priority_types.index(claim.priority)] = 1
  else:
    priority_status[OneHotEnc_priority_types.index('others')] = 1

  claim_vector = np.append(claim_vector, priority_status)

  # Concat insurance L(1)
  if claim.insurance is not None:
    claim_vector = np.append(claim_vector, claim.insurance)

  # Concat number of drugs L(1)
  claim_vector = np.append(claim_vector, claim.number_of_drugs)

  # Concat number of items L(1)
  claim_vector = np.append(claim_vector, claim.number_of_items)

  # Concat number of diagnoises L(1)
  claim_vector = np.append(claim_vector, claim.number_of_diagnoses)

  # Concat total costs L(1)
  # TODO: maybe need normalization based on type of currency currently only assume usd
  # claim_vector = np.append(claim_vector, claim.total['value'])
  claim_vector = np.append(claim_vector, 1)


  return claim_vector


def create_model_input(list_of_claim: Union[List, Dict], max_claim: int, claim_vector_size: int,
                       verbose: bool = True) -> torch.Tensor:
  '''
  Build the fixed-size model input for one patient.

  Args:
    list_of_claim: claims in bundle order, either a sequence or a dict keyed by
      position (as returned by SyntheaGenerationFHIRCustomParser.get_list_of_claims).
    max_claim: number of rows. When there are more claims, only the last
      ``max_claim`` are kept (assumed most recent, based on bundle order — TODO confirm).
    claim_vector_size: expected length of each encode_claim vector.
    verbose: print the real claim count when the input is padded.

  Returns:
    float32 tensor of shape (max_claim, claim_vector_size). Real claims fill the
    first rows; the remaining rows are zeros (padding).
  '''
  claims = list(list_of_claim.values()) if isinstance(list_of_claim, dict) else list(list_of_claim)
  model_input = np.zeros((max_claim, claim_vector_size))

  if len(claims) < max_claim and verbose:
    print(f"Padded person claim count {len(claims)}")

  # Keep the trailing max_claim claims (explicit start index so max_claim == 0 keeps none).
  kept = claims[len(claims) - max_claim:] if len(claims) > max_claim else claims
  for row, claim in enumerate(kept):
    model_input[row, :] = encode_claim(claim)

  return torch.from_numpy(model_input).float()


In [None]:
def encode_claim2(claim: 'Claim') -> np.ndarray:
  '''
  Second attempt at encoding a claim (still a flat 1-D vector, length 24).

  Same layout as encode_claim, with two differences:
    - it is prefixed with billable-period start and end (YYYYMMDDHHMMSS numbers);
    - it ends with the real total value instead of a constant placeholder.

  Missing billable-period bounds or a missing total become 0, so the length
  stays fixed. NOTE(review): the raw timestamps (~2e13) will dominate unnormalized losses.
  '''
  period = claim.bill_period or {}
  total = claim.total or {}

  def one_hot(value, categories):
    # Single 1 at value's position; unknown values go to the trailing 'others' slot.
    encoding = [0] * len(categories)
    encoding[categories.index(value if value in categories else 'others')] = 1
    return encoding

  features = (
      [period.get('start', 0), period.get('end', 0)]
      + one_hot(claim.type_of_claim, OneHotEnc_claim_types)
      + one_hot(claim.status, OneHotEnc_status_types)
      + one_hot(claim.priority, OneHotEnc_priority_types)
      + [
          # Always emit this slot so the vector length stays constant.
          claim.insurance if claim.insurance is not None else 0,
          claim.number_of_drugs,
          claim.number_of_items,
          claim.number_of_diagnoses,
          # TODO: normalize by currency; currently assumes USD.
          total.get('value', 0),
      ]
  )
  return np.array(features, dtype=float)

Possible improvement for `create_model_input` above: warn when a patient has too few claims (i.e., the input is mostly padding).


Could also try masking.

Normalizing based on the number of claims is another option.

input_tensor: Your padded claims data (shape: [max_claim, claim_vector_size])
mask: A separate vector indicating which positions contain real data (shape: [max_claim])



In [None]:
# End-to-end check: parse the sample bundle and build its model-input tensor.
fhir_json_instance = SyntheaBundleModel(file)

tensor = create_model_input(fhir_json_instance.list_of_claim, MAX_CLAIM_C, CLAIM_ENCODE_SIZE)
assert tensor.shape == (MAX_CLAIM_C, CLAIM_ENCODE_SIZE), tensor.shape
print(tensor.dtype)
print(tensor.shape)

torch.float32
torch.Size([30, 22])


In [None]:
'''
Model

Need to study it internally
'''

import torch
import torch.nn as nn

class LSTMAutoEncoder(nn.Module):
  '''
  Sequence autoencoder over per-patient claim sequences.

  Input and output are (batch, seq_len, input_dim) tensors (batch_first=True).
  The encoder LSTM's last-layer final hidden state is projected to a
  ``latent_dim`` bottleneck, then projected back to ``hidden_dim``. That vector
  is repeated at every time step and decoded by a second LSTM. Finally, a
  linear layer maps the decoder's hidden states back to ``input_dim``.

  The linear read-out matters because LSTM outputs are tanh-bounded to (-1, 1).
  Without it, count-valued features such as 2 or 12 could never be reconstructed.
  '''
  def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int, num_layers: int):
    super().__init__()
    self.encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)

    self.latent_fc = nn.Linear(hidden_dim, latent_dim)
    self.decoder_fc = nn.Linear(latent_dim, hidden_dim)

    self.decoder = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
    # Unbounded output layer (see class docstring).
    self.output_fc = nn.Linear(hidden_dim, input_dim)

  def forward(self, x):
    seq_len = x.size(1)

    # Encode: h_n is (num_layers, batch, hidden_dim); take the last layer.
    _, (h_n, _) = self.encoder(x)
    latent_vector = self.latent_fc(h_n[-1])  # (batch, latent_dim)

    # Decode: repeat the expanded latent at every step -> (batch, seq_len, hidden_dim).
    decoder_input = self.decoder_fc(latent_vector).unsqueeze(1).repeat(1, seq_len, 1)
    decoded, _ = self.decoder(decoder_input)
    return self.output_fc(decoded)  # (batch, seq_len, input_dim)



In [None]:
'''
Test out model shapes
'''

try:
  import torchinfo
except:
  !pip install torchinfo
  import torchinfo

from torchinfo import summary

model_0 = LSTMAutoEncoder(CLAIM_ENCODE_SIZE, 64,16, 2)
summary(model_0, input_size = [8, 30, 22])

# print(list(model_0.parameters()))

Layer (type:depth-idx)                   Output Shape              Param #
LSTMAutoEncoder                          [8, 30, 22]               --
├─LSTM: 1-1                              [8, 30, 128]              209,920
├─Linear: 1-2                            [8, 32]                   4,128
├─Linear: 1-3                            [8, 128]                  4,224
├─LSTM: 1-4                              [8, 30, 22]               17,424
Total params: 235,696
Trainable params: 235,696
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 54.63
Input size (MB): 0.02
Forward/backward pass size (MB): 0.30
Params size (MB): 0.94
Estimated Total Size (MB): 1.26

In [None]:
'''
Custom dataset
meant to be input into dataloader
'''

import os
import pathlib
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Tuple, Dict, List

class fhirJsonCustomDataset(Dataset):
  '''
  Dataset with one item per FHIR bundle (.json) file in ``target_dir``.

  ``__getitem__`` parses the bundle and returns the float32 tensor from
  ``create_model_input``, shape (MAX_CLAIM_C, CLAIM_ENCODE_SIZE).
  '''
  def __init__(self, target_dir: str):
    # Join with target_dir itself (not a global path), and sort because
    # os.listdir order is arbitrary.
    self.file_paths = sorted(
        os.path.join(target_dir, file_name)
        for file_name in os.listdir(target_dir)
        if file_name.endswith('.json'))

  def __len__(self) -> int:
    return len(self.file_paths)

  def __getitem__(self, index: int) -> torch.Tensor:
    bundle_model = SyntheaBundleModel(self.file_paths[index])
    return create_model_input(bundle_model.list_of_claim, MAX_CLAIM_C, CLAIM_ENCODE_SIZE)

# Build the training dataset: one item per bundle file in the data directory
# (108 individual files in the sample run).
train_fhir_json_dataset = fhirJsonCustomDataset(data_path)

In [None]:
'''
Dataloader
'''
from torch.utils.data import DataLoader

# Shuffled batches of BATCH_SIZE_C patients; num_workers=0 loads in the main process.
train_dataloader_custom = DataLoader(train_fhir_json_dataset, batch_size = BATCH_SIZE_C, num_workers=0, shuffle=True)

test_input = next(iter(train_dataloader_custom)) # gets 1 batch of BATCH_SIZE_C patients

print(test_input.shape) # [BATCH_SIZE_C, MAX_CLAIM_C, CLAIM_ENCODE_SIZE] = [8, 30, 22]

Padded person claim count 19
torch.Size([8, 30, 22])


In [None]:
'''
Training
'''
# Train the autoencoder to reconstruct its own input.
# NOTE(review): the loss also covers zero-padded rows; masking them is a possible improvement.
epochs = 3

loss_fn = nn.MSELoss() # Mean sq err
optimizer = torch.optim.SGD(params=model_0.parameters(), lr=.01)

debug_list = []

model_0.train()  # set once: nothing in the loop switches to eval mode
for epoch in range(epochs):
  epoch_losses = []
  for batch, X in enumerate(train_dataloader_custom):
    reconstruction = model_0(X)  # single forward pass
    loss = loss_fn(reconstruction, X)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    epoch_losses.append(loss.item())
    debug_list.append(f"Epoch {epoch} Batch {batch} Loss: {loss.item()}")
  # Per-epoch mean makes it easy to see whether training is converging.
  print(f"Epoch {epoch} mean loss: {sum(epoch_losses) / max(len(epoch_losses), 1)}")

for item in debug_list:
  print(item)




Padded person claim count 20
Padded person claim count 28
Padded person claim count 25
Padded person claim count 23
Padded person claim count 29
Padded person claim count 25
Padded person claim count 27
Padded person claim count 23
Padded person claim count 14
Padded person claim count 21
Padded person claim count 19
Padded person claim count 12
Padded person claim count 14
Padded person claim count 16
Padded person claim count 21
Padded person claim count 8
Padded person claim count 13
Padded person claim count 17
Padded person claim count 21
Padded person claim count 25
Padded person claim count 28
Padded person claim count 12
Padded person claim count 28
Padded person claim count 25
Padded person claim count 23
Padded person claim count 25
Padded person claim count 21
Padded person claim count 27
Padded person claim count 17
Padded person claim count 16
Padded person claim count 13
Padded person claim count 14
Padded person claim count 21
Padded person claim count 23
Padded person c

In [None]:
'''
Test Predicting
'''

# Reconstruct the sample patient's claims and report the reconstruction error.
fhir_json_instance_test = SyntheaBundleModel(file)

tensor_test = create_model_input(fhir_json_instance_test.list_of_claim, MAX_CLAIM_C, CLAIM_ENCODE_SIZE)
batch_test = tensor_test.unsqueeze(0)  # add batch dim -> (1, MAX_CLAIM_C, CLAIM_ENCODE_SIZE)

model_0.eval()
with torch.no_grad():  # inference only: no autograd graph needed
  reconstruction = model_0(batch_test)

print(batch_test)
print(reconstruction)

# Per-sample mean squared error over all claims and features.
error = torch.mean((reconstruction - batch_test) ** 2, dim=(1, 2))
print(error)

tensor([[[ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,
           0.,  0.,  0.,  1.,  2.,  2.,  0.,  1.],
         [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,
           0.,  0.,  0.,  1.,  3.,  3.,  0.,  1.],
         [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,
           0.,  0.,  0.,  1.,  2.,  2.,  1.,  1.],
         [ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,
           0.,  0.,  0.,  1.,  1.,  1.,  0.,  1.],
         [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,
           0.,  0.,  0.,  1.,  2.,  2.,  1.,  1.],
         [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,
           0.,  0.,  0.,  1.,  3.,  3.,  0.,  1.],
         [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,
           0.,  0.,  0.,  1., 12., 12.,  0.,  1.],
         [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,
   

In [None]:
# 3/26/2025
# The LSTM autoencoder seems to fail: not tweaked much yet, but the error goes very
# high and never decreases.

SyntaxError: incomplete input (<ipython-input-47-d8085cecdefb>, line 1)