# ClinicalNotes dataset generation (v0)

Generated for Hugging Face: https://huggingface.co/datasets/jmaasch/compositional_causal_reasoning/

Version 0.

Code by Jacqueline Maasch | April 2025

In [1]:
# General importations.
import pandas as pd
import numpy as np

from utils import Utils
from clinical_notes import ClinicalNotes
from dataset_generator import DataSetGenerator

In [2]:
# Instantiate project helpers. NOTE(review): `u` is not referenced in any later
# cell of this notebook; consider removing it if Utils() has no needed side effects.
u = Utils()
# Dataset generator: used below for get_dataset, process_prompts, get_pns_dict,
# and get_internal_consistency_thresholds.
# NOTE(review): no random seed is set anywhere in this notebook, so regenerating
# the dataset presumably yields different samples — TODO confirm whether
# DataSetGenerator / ClinicalNotes accept a seed.
dg = DataSetGenerator()

In [3]:
# Output directory for all exported files. The trailing slash is required:
# later cells build file names by plain string concatenation (path + "file").
path = "static_datasets/clinical_notes/"

## Step 1: Get raw dataset.

In [4]:
# x levels of graphical complexity (captured by BCC size).
# y tasks per graphical complexity level.
# z samples per task.
# w replicates per sample.
# = x*y*z*w rows in the raw dataset.
graph_sizes = [[2,2],[3,3],[4,4],[5,5]]
n_tasks_per_size = 3
n_samples_per_task = 1000
reps_per_sample = 5

df = dg.get_dataset(task_generator = ClinicalNotes,
                    graph_sizes = graph_sizes,
                    n_tasks_per_size = n_tasks_per_size,
                    n_samples_per_task = n_samples_per_task,
                    reps_per_sample = reps_per_sample)

# Sanity check: one row per (graph size, task, sample, replicate).
n_expected_rows = len(graph_sizes) * n_tasks_per_size * n_samples_per_task * reps_per_sample
assert len(df) == n_expected_rows, f"expected {n_expected_rows} rows, got {len(df)}"

# DataFrame.info() prints its report itself and returns None; wrapping it in
# print() would append a stray "None" line to the output.
df.info()
display(df.head(5))
display(df.tail(5))

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 60000 entries, 0 to 59999
Data columns (total 19 columns):
 #   Column                                  Non-Null Count  Dtype 
---  ------                                  --------------  ----- 
 0   Task ID                                 60000 non-null  object
 1   Context ID                              60000 non-null  object
 2   Sample ID                               60000 non-null  object
 3   Replicate ID                            60000 non-null  int64 
 4   Nodes per BCC                           60000 non-null  object
 5   DAG adjacency matrix                    60000 non-null  object
 6   DAG nodes                               60000 non-null  object
 7   CCT adjacency matrix                    60000 non-null  object
 8   CCT nodes                               60000 non-null  object
 9   Exogenous variables                     60000 non-null  object
 10  Bernoulli parameters                    60000 non-null  object
 11  Gl

Unnamed: 0,Task ID,Context ID,Sample ID,Replicate ID,Nodes per BCC,DAG adjacency matrix,DAG nodes,CCT adjacency matrix,CCT nodes,Exogenous variables,Bernoulli parameters,Global quantity,Local quantities,Compositions,Causal context,Sample context,Factual queries,Counterfactual queries (cause = True),Counterfactual queries (cause = False)
0,0.0,0,0,0,"[2, 2]","[[0, 1, 0], [0, 0, 1], [0, 0, 0]]","[pain, VGJS, surgery]","[[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, ...","[pain, VGJS, surgery]","[LTMZ, G9TD, XD0U]","[0.5, 0.5, 0.5]","(pain, surgery)","[(pain, VGJS), (VGJS, surgery)]","[[(pain, VGJS), (VGJS, surgery)]]",Chronic disease G5HW23 sometimes requires surg...,"Now, we will review the history and physical n...",{'surgery': {'Prompt': 'Given these history an...,"{('pain', 'surgery'): {'Prompt': 'Now suppose ...","{('pain', 'surgery'): {'Prompt': 'Now suppose ..."
1,0.0,0,0,1,"[2, 2]","[[0, 1, 0], [0, 0, 1], [0, 0, 0]]","[pain, VGJS, surgery]","[[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, ...","[pain, VGJS, surgery]","[LTMZ, G9TD, XD0U]","[0.5, 0.5, 0.5]","(pain, surgery)","[(pain, VGJS), (VGJS, surgery)]","[[(pain, VGJS), (VGJS, surgery)]]",Chronic disease G5HW23 sometimes requires surg...,"Now, we will review the history and physical n...",{'surgery': {'Prompt': 'Given these history an...,"{('pain', 'surgery'): {'Prompt': 'Now suppose ...","{('pain', 'surgery'): {'Prompt': 'Now suppose ..."
2,0.0,0,0,2,"[2, 2]","[[0, 1, 0], [0, 0, 1], [0, 0, 0]]","[pain, VGJS, surgery]","[[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, ...","[pain, VGJS, surgery]","[LTMZ, G9TD, XD0U]","[0.5, 0.5, 0.5]","(pain, surgery)","[(pain, VGJS), (VGJS, surgery)]","[[(pain, VGJS), (VGJS, surgery)]]",Chronic disease G5HW23 sometimes requires surg...,"Now, we will review the history and physical n...",{'surgery': {'Prompt': 'Given these history an...,"{('pain', 'surgery'): {'Prompt': 'Now suppose ...","{('pain', 'surgery'): {'Prompt': 'Now suppose ..."
3,0.0,0,0,3,"[2, 2]","[[0, 1, 0], [0, 0, 1], [0, 0, 0]]","[pain, VGJS, surgery]","[[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, ...","[pain, VGJS, surgery]","[LTMZ, G9TD, XD0U]","[0.5, 0.5, 0.5]","(pain, surgery)","[(pain, VGJS), (VGJS, surgery)]","[[(pain, VGJS), (VGJS, surgery)]]",Chronic disease G5HW23 sometimes requires surg...,"Now, we will review the history and physical n...",{'surgery': {'Prompt': 'Given these history an...,"{('pain', 'surgery'): {'Prompt': 'Now suppose ...","{('pain', 'surgery'): {'Prompt': 'Now suppose ..."
4,0.0,0,0,4,"[2, 2]","[[0, 1, 0], [0, 0, 1], [0, 0, 0]]","[pain, VGJS, surgery]","[[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, ...","[pain, VGJS, surgery]","[LTMZ, G9TD, XD0U]","[0.5, 0.5, 0.5]","(pain, surgery)","[(pain, VGJS), (VGJS, surgery)]","[[(pain, VGJS), (VGJS, surgery)]]",Chronic disease G5HW23 sometimes requires surg...,"Now, we will review the history and physical n...",{'surgery': {'Prompt': 'Given these history an...,"{('pain', 'surgery'): {'Prompt': 'Now suppose ...","{('pain', 'surgery'): {'Prompt': 'Now suppose ..."


Unnamed: 0,Task ID,Context ID,Sample ID,Replicate ID,Nodes per BCC,DAG adjacency matrix,DAG nodes,CCT adjacency matrix,CCT nodes,Exogenous variables,Bernoulli parameters,Global quantity,Local quantities,Compositions,Causal context,Sample context,Factual queries,Counterfactual queries (cause = True),Counterfactual queries (cause = False)
59995,11.999,11,999,0,"[5, 5]","[[0, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, ...","[pain, OHSQ, 8GBL, VEE1, KHIU, KQGH, 83XU, WLV...","[[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, ...","[pain, KHIU, surgery]","[WTGV, U0US, Y753, CWO5, PCKG, EL98, 5QNU, QFY...","[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]","(pain, surgery)","[(pain, KHIU), (KHIU, surgery)]","[[(pain, KHIU), (KHIU, surgery)]]",Chronic disease 6QTMQG sometimes requires surg...,"Now, we will review the history and physical n...",{'surgery': {'Prompt': 'Given these history an...,"{('pain', 'surgery'): {'Prompt': 'Now suppose ...","{('pain', 'surgery'): {'Prompt': 'Now suppose ..."
59996,11.999,11,999,1,"[5, 5]","[[0, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, ...","[pain, OHSQ, 8GBL, VEE1, KHIU, KQGH, 83XU, WLV...","[[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, ...","[pain, KHIU, surgery]","[WTGV, U0US, Y753, CWO5, PCKG, EL98, 5QNU, QFY...","[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]","(pain, surgery)","[(pain, KHIU), (KHIU, surgery)]","[[(pain, KHIU), (KHIU, surgery)]]",Chronic disease 6QTMQG sometimes requires surg...,"Now, we will review the history and physical n...",{'surgery': {'Prompt': 'Given these history an...,"{('pain', 'surgery'): {'Prompt': 'Now suppose ...","{('pain', 'surgery'): {'Prompt': 'Now suppose ..."
59997,11.999,11,999,2,"[5, 5]","[[0, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, ...","[pain, OHSQ, 8GBL, VEE1, KHIU, KQGH, 83XU, WLV...","[[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, ...","[pain, KHIU, surgery]","[WTGV, U0US, Y753, CWO5, PCKG, EL98, 5QNU, QFY...","[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]","(pain, surgery)","[(pain, KHIU), (KHIU, surgery)]","[[(pain, KHIU), (KHIU, surgery)]]",Chronic disease 6QTMQG sometimes requires surg...,"Now, we will review the history and physical n...",{'surgery': {'Prompt': 'Given these history an...,"{('pain', 'surgery'): {'Prompt': 'Now suppose ...","{('pain', 'surgery'): {'Prompt': 'Now suppose ..."
59998,11.999,11,999,3,"[5, 5]","[[0, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, ...","[pain, OHSQ, 8GBL, VEE1, KHIU, KQGH, 83XU, WLV...","[[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, ...","[pain, KHIU, surgery]","[WTGV, U0US, Y753, CWO5, PCKG, EL98, 5QNU, QFY...","[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]","(pain, surgery)","[(pain, KHIU), (KHIU, surgery)]","[[(pain, KHIU), (KHIU, surgery)]]",Chronic disease 6QTMQG sometimes requires surg...,"Now, we will review the history and physical n...",{'surgery': {'Prompt': 'Given these history an...,"{('pain', 'surgery'): {'Prompt': 'Now suppose ...","{('pain', 'surgery'): {'Prompt': 'Now suppose ..."
59999,11.999,11,999,4,"[5, 5]","[[0, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, ...","[pain, OHSQ, 8GBL, VEE1, KHIU, KQGH, 83XU, WLV...","[[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, ...","[pain, KHIU, surgery]","[WTGV, U0US, Y753, CWO5, PCKG, EL98, 5QNU, QFY...","[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]","(pain, surgery)","[(pain, KHIU), (KHIU, surgery)]","[[(pain, KHIU), (KHIU, surgery)]]",Chronic disease 6QTMQG sometimes requires surg...,"Now, we will review the history and physical n...",{'surgery': {'Prompt': 'Given these history an...,"{('pain', 'surgery'): {'Prompt': 'Now suppose ...","{('pain', 'surgery'): {'Prompt': 'Now suppose ..."


In [5]:
# Persist the raw dataset (one row per replicate), without the RangeIndex.
raw_csv_file = path + "clinical_notes_v0.csv"
df.to_csv(raw_csv_file, index=False)

## Step 2: Process factual and counterfactual prompts.

In [6]:
# Split the generated data into a factual-prompt frame and a counterfactual-prompt frame.
# NOTE(review): process_prompts() takes no arguments, so it presumably works on the
# dataset produced by the earlier dg.get_dataset(...) call — TODO confirm.
prompt_frames = dg.process_prompts()
df_factual, df_cf = prompt_frames

In [7]:
# DataFrame.info() prints its own report and returns None, so call it directly
# rather than wrapping it in print() (which would emit a trailing "None").
df_factual.info()
display(df_factual.head(5))

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 120000 entries, 0 to 119999
Data columns (total 8 columns):
 #   Column         Non-Null Count   Dtype 
---  ------         --------------   ----- 
 0   Task ID        120000 non-null  object
 1   Context ID     120000 non-null  int64 
 2   Sample ID      120000 non-null  int64 
 3   Replicate ID   120000 non-null  int64 
 4   Nodes per BCC  120000 non-null  object
 5   Effect         120000 non-null  object
 6   Prompt         120000 non-null  object
 7   True           120000 non-null  int64 
dtypes: int64(4), object(4)
memory usage: 7.3+ MB
None


Unnamed: 0,Task ID,Context ID,Sample ID,Replicate ID,Nodes per BCC,Effect,Prompt,True
0,0.0,0,0,0,"[2, 2]",surgery,Chronic disease G5HW23 sometimes requires surg...,1
1,0.0,0,0,0,"[2, 2]",VGJS,Chronic disease G5HW23 sometimes requires surg...,1
2,0.0,0,0,1,"[2, 2]",surgery,Chronic disease G5HW23 sometimes requires surg...,1
3,0.0,0,0,1,"[2, 2]",VGJS,Chronic disease G5HW23 sometimes requires surg...,1
4,0.0,0,0,2,"[2, 2]",surgery,Chronic disease G5HW23 sometimes requires surg...,1


In [8]:
# DataFrame.info() prints its own report and returns None, so call it directly
# rather than wrapping it in print() (which would emit a trailing "None").
df_cf.info()
display(df_cf.head(5))

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 180000 entries, 0 to 179999
Data columns (total 12 columns):
 #   Column                  Non-Null Count   Dtype 
---  ------                  --------------   ----- 
 0   Task ID                 180000 non-null  object
 1   Context ID              180000 non-null  int64 
 2   Sample ID               180000 non-null  int64 
 3   Replicate ID            180000 non-null  int64 
 4   Nodes per BCC           180000 non-null  object
 5   Cause-effect pair       180000 non-null  object
 6   Cause                   180000 non-null  object
 7   Effect                  180000 non-null  object
 8   Prompt (cause = True)   180000 non-null  object
 9   True (cause = True)     180000 non-null  int64 
 10  Prompt (cause = False)  180000 non-null  object
 11  True (cause = False)    180000 non-null  int64 
dtypes: int64(5), object(7)
memory usage: 16.5+ MB
None


Unnamed: 0,Task ID,Context ID,Sample ID,Replicate ID,Nodes per BCC,Cause-effect pair,Cause,Effect,Prompt (cause = True),True (cause = True),Prompt (cause = False),True (cause = False)
0,0.0,0,0,0,"[2, 2]","(pain, surgery)",pain,surgery,Chronic disease G5HW23 sometimes requires surg...,1,Chronic disease G5HW23 sometimes requires surg...,0
1,0.0,0,0,0,"[2, 2]","(pain, VGJS)",pain,VGJS,Chronic disease G5HW23 sometimes requires surg...,1,Chronic disease G5HW23 sometimes requires surg...,0
2,0.0,0,0,0,"[2, 2]","(VGJS, surgery)",VGJS,surgery,Chronic disease G5HW23 sometimes requires surg...,1,Chronic disease G5HW23 sometimes requires surg...,0
3,0.0,0,0,1,"[2, 2]","(pain, surgery)",pain,surgery,Chronic disease G5HW23 sometimes requires surg...,1,Chronic disease G5HW23 sometimes requires surg...,0
4,0.0,0,0,1,"[2, 2]","(pain, VGJS)",pain,VGJS,Chronic disease G5HW23 sometimes requires surg...,1,Chronic disease G5HW23 sometimes requires surg...,0


In [9]:
# Export factual prompts (one row per effect per replicate), without the RangeIndex.
factual_csv_file = path + "clinical_notes_factual_v0.csv"
df_factual.to_csv(factual_csv_file, index=False)

In [10]:
# Export counterfactual prompts (one row per cause-effect pair per replicate),
# without the RangeIndex.
cf_csv_file = path + "clinical_notes_counterfactual_v0.csv"
df_cf.to_csv(cf_csv_file, index=False)

In [11]:
# Factual prompts for one quantity ("surgery") in one task (Context ID 0).
# Each replicate of each sample yields one such prompt, so the count should be
# n_samples_per_task * reps_per_sample.
is_task0_surgery = (df_factual["Context ID"] == 0) & (df_factual["Effect"] == "surgery")
n_factual_per_quantity = int(is_task0_surgery.sum())
assert n_factual_per_quantity == n_samples_per_task * reps_per_sample
print("\nTotal factual q's per quantity per task:", n_factual_per_quantity)


Total factual q's per quantity per task: 5000


In [12]:
# Counterfactual prompts for one cause-effect pair (pain -> surgery) in one task
# (Context ID 0). Each replicate of each sample yields one such row, so the count
# should be n_samples_per_task * reps_per_sample.
is_task0_pair = (df_cf["Context ID"] == 0) & (df_cf["Cause-effect pair"] == ("pain", "surgery"))
n_cf_per_quantity = int(is_task0_pair.sum())
assert n_cf_per_quantity == n_samples_per_task * reps_per_sample
print("\nTotal counterfactual q's per quantity per task:", n_cf_per_quantity)


Total counterfactual q's per quantity per task: 5000


## Step 3: Get ground truth PNS values.

Get a dictionary mapping each task's cause-effect pairs (and their compositions) to their PNS values.

Keys are the Context ID. Values are dictionaries whose keys are a cause-effect pair or a composition (a list of cause-effect pairs) and whose values are the finite-sample PNS computed using ground-truth response vectors.

In [13]:
# Ground-truth finite-sample PNS values, keyed by Context ID (computed quietly, then shown).
pns_dict = dg.get_pns_dict(verbose=False)
display(pns_dict)

{0: {"('pain', 'surgery')": 0.255,
  "('pain', 'VGJS')": 0.507,
  "('VGJS', 'surgery')": 0.489,
  "[('pain', 'VGJS'), ('VGJS', 'surgery')]": 0.247923},
 1: {"('pain', 'surgery')": 0.244,
  "('pain', '4ZDL')": 0.504,
  "('4ZDL', 'surgery')": 0.484,
  "[('pain', '4ZDL'), ('4ZDL', 'surgery')]": 0.243936},
 2: {"('pain', 'surgery')": 0.239,
  "('pain', 'DP8H')": 0.476,
  "('DP8H', 'surgery')": 0.479,
  "[('pain', 'DP8H'), ('DP8H', 'surgery')]": 0.22800399999999998},
 3: {"('pain', 'surgery')": 0.127,
  "('pain', 'G883')": 0.262,
  "('G883', 'surgery')": 0.466,
  "[('pain', 'G883'), ('G883', 'surgery')]": 0.122092},
 4: {"('pain', 'surgery')": 0.134,
  "('pain', 'HSP3')": 0.261,
  "('HSP3', 'surgery')": 0.524,
  "[('pain', 'HSP3'), ('HSP3', 'surgery')]": 0.13676400000000002},
 5: {"('pain', 'surgery')": 0.124,
  "('pain', '0G5T')": 0.242,
  "('0G5T', 'surgery')": 0.487,
  "[('pain', '0G5T'), ('0G5T', 'surgery')]": 0.117854},
 6: {"('pain', 'surgery')": 0.068,
  "('pain', '8YFC')": 0.118,
  

In [14]:
# Save with numpy. np.save stores the dict as a pickled 0-d object array.
pns_dict_file = path + "clinical_notes_pns_dict_v0.npy"
np.save(pns_dict_file, pns_dict)

# Round-trip check against the file actually written above (the old commented-out
# version pointed at a different, unversioned file name). allow_pickle=True is
# required to read object arrays back; only load .npy files from trusted sources.
pns_dict_loaded = np.load(pns_dict_file, allow_pickle = True).item()
assert pns_dict_loaded == pns_dict

## Step 4: Compute internal consistency thresholds.

`dg.get_internal_consistency_thresholds` returns a dictionary that maps compositions to their correctness threshold
for internal compositional consistency evaluation. Each threshold is the RAE
of the composition relative to the global quantity of interest, times a
multiplier of the user's choice.

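* $\text{RAE} = \dfrac{\lvert \text{PNS}_{\text{global}} - \text{PNS}_{\text{composition}} \rvert}{\text{PNS}_{\text{global}}}$
* $\text{Threshold} = \text{RAE} \times \text{multiplier}$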
        
This method of obtaining the threshold accounts for the innate error owed
to PNS estimation on finite samples, while the multiplier represents the
user's tolerance level for errors larger than the finite sample error.

Keys are the Context ID. Values are dictionaries whose keys are the causal composition (denoted by a list of cause-effect pairs whose PNS values are multiplied) and whose values are the internal consistency threshold.

For public use, we export thresholds with multiplier 1.0 so that the end user can select
their own multiplier downstream.

In [15]:
# Not for export: preview of thresholds with a stricter multiplier (1.25).
# Stored under its own name so it does not shadow the exported `threshold_dict`.
threshold_dict_preview = dg.get_internal_consistency_thresholds(multiplier = 1.25)
display(threshold_dict_preview)

{0: {"[('pain', 'VGJS'), ('VGJS', 'surgery')]": 0.03469117647058823},
 1: {"[('pain', '4ZDL'), ('4ZDL', 'surgery')]": 0.0003278688524590599},
 2: {"[('pain', 'DP8H'), ('DP8H', 'surgery')]": 0.05751046025104606},
 3: {"[('pain', 'G883'), ('G883', 'surgery')]": 0.04830708661417319},
 4: {"[('pain', 'HSP3'), ('HSP3', 'surgery')]": 0.025783582089552393},
 5: {"[('pain', '0G5T'), ('0G5T', 'surgery')]": 0.061955645161290304},
 6: {"[('pain', '8YFC'), ('8YFC', 'surgery')]": 0.10904411764705893},
 7: {"[('pain', '5FQY'), ('5FQY', 'surgery')]": 0.03999999999999998},
 8: {"[('pain', 'TGJG'), ('TGJG', 'surgery')]": 0.038983050847457804},
 9: {"[('pain', 'DJF5'), ('DJF5', 'surgery')]": 0.1429347826086957},
 10: {"[('pain', 'RJDR'), ('RJDR', 'surgery')]": 0.07096774193548405},
 11: {"[('pain', 'KHIU'), ('KHIU', 'surgery')]": 0.036615853658536604}}

In [16]:
# For export: multiplier 1.0 yields the bare finite-sample RAE, leaving the
# choice of tolerance multiplier to downstream users.
threshold_dict = dg.get_internal_consistency_thresholds(multiplier=1.0)
display(threshold_dict)

{0: {"[('pain', 'VGJS'), ('VGJS', 'surgery')]": 0.027752941176470588},
 1: {"[('pain', '4ZDL'), ('4ZDL', 'surgery')]": 0.00026229508196724794},
 2: {"[('pain', 'DP8H'), ('DP8H', 'surgery')]": 0.04600836820083685},
 3: {"[('pain', 'G883'), ('G883', 'surgery')]": 0.03864566929133855},
 4: {"[('pain', 'HSP3'), ('HSP3', 'surgery')]": 0.020626865671641913},
 5: {"[('pain', '0G5T'), ('0G5T', 'surgery')]": 0.049564516129032246},
 6: {"[('pain', '8YFC'), ('8YFC', 'surgery')]": 0.08723529411764715},
 7: {"[('pain', '5FQY'), ('5FQY', 'surgery')]": 0.03199999999999999},
 8: {"[('pain', 'TGJG'), ('TGJG', 'surgery')]": 0.031186440677966245},
 9: {"[('pain', 'DJF5'), ('DJF5', 'surgery')]": 0.11434782608695654},
 10: {"[('pain', 'RJDR'), ('RJDR', 'surgery')]": 0.056774193548387246},
 11: {"[('pain', 'KHIU'), ('KHIU', 'surgery')]": 0.029292682926829284}}

In [17]:
# Save with numpy. np.save stores the dict as a pickled 0-d object array.
threshold_dict_file = path + "clinical_notes_threshold_dict_v0.npy"
np.save(threshold_dict_file, threshold_dict)

# Round-trip check against the file actually written above (the old commented-out
# version pointed at a different, unversioned file name). allow_pickle=True is
# required to read object arrays back; only load .npy files from trusted sources.
threshold_dict_loaded = np.load(threshold_dict_file, allow_pickle = True).item()
assert threshold_dict_loaded == threshold_dict