# Healthcare Sample Data Generator

This notebook generates synthetic healthcare data for clean room demonstrations and testing. It creates a star schema with one fact table and four dimension tables, ensuring complete referential integrity before persisting to Unity Catalog.

## Data Model

**Fact Table:**
* `visits` - Patient visit records linking to all dimension tables

**Dimension Tables:**
* `patients` - Patient demographic information
* `doctors` - Healthcare provider information with specialties
* `hospitals` - Hospital location data
* `diagnoses` - Medical diagnosis codes (ICD-10 format)

## Notebook Flow

1. **Configure Catalog and Schema** - Widget inputs for target catalog and schema (defaults: catalog `cr_owner_catalog`, schema `patient_visits`)
2. **Create Schema** - Ensure the target schema exists (`CREATE SCHEMA IF NOT EXISTS`)
3. **Set Default Namespace** - Execute `USE CATALOG` and `USE SCHEMA` statements so later table names resolve in the target schema
4. **Generate Sample Data** - Programmatically create randomized healthcare data:
	* Patients: 900-1100 (random)
	* Visits: 1400-1600 (random)
	* Doctors: 15-30 (random)
	* Hospitals: 3 (fixed)
	* Diagnoses: 15 (all available)
5. **Verify Referential Integrity** - Run comprehensive checks:
	* Primary key uniqueness
	* Null value detection
	* Foreign key validation
	* **Only save tables if all checks pass**

## Key Features

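* **Randomized data generation** - Row counts are drawn at random within the ranges above; the random seed is fixed (42), so repeated runs reproduce the same counts and data unless the seed is changed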
* **Referential integrity enforcement** - Tables only saved if validation succeeds
* **Overwrite mode** - Tables are created or replaced without manual drops
* **ICD-10 diagnosis codes** - Realistic medical coding standards

In [0]:
# Declare the notebook's two text widgets as (name, default value, label) and
# then read their current values into `catalog` and `schema`.
for widget_name, default_value, widget_label in (
	("catalog", "cr_owner_catalog", "Catalog"),
	("schema", "patient_visits", "Schema"),
):
	dbutils.widgets.text(widget_name, default_value, widget_label)

catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")

In [0]:
%sql
-- Create the target schema when it does not exist yet. The qualified name is
-- built from the :catalog and :schema parameter markers and resolved through
-- IDENTIFIER(). The separator is a single-quoted string literal so it is never
-- parsed as a quoted identifier, whatever the double-quoted-identifier setting.
CREATE SCHEMA IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema);

In [0]:
# Make the selected catalog and schema the session defaults, catalog first.
# NOTE(review): the names are spliced into the statement unquoted, so a value
# containing characters that are not valid in a bare identifier will fail here
# — confirm whether widget input needs quoting.
for namespace_kind, namespace_name in (("CATALOG", catalog), ("SCHEMA", schema)):
	spark.sql(f"USE {namespace_kind} {namespace_name}")

print(f"Using catalog: {catalog}, schema: {schema}")

In [0]:
from pyspark.sql import Row
import random
from datetime import datetime, timedelta

# Fix the seed so every run draws the same counts (and the same later random
# choices) from Python's `random` module.
random.seed(42)

# Table sizes: the first three are drawn from a range, the last two are fixed.
# The order of the three randint calls matters for reproducibility.
num_patients = random.randint(900, 1100)
num_visits = random.randint(1400, 1600)
num_doctors = random.randint(15, 30)
num_hospitals = 3
num_diagnoses = 15  # Use all 15 diagnoses

# Echo the planned sizes, followed by a blank separator line.
print("Generating data with:")
for entity_label, entity_count in (
	("Patients", num_patients),
	("Visits", num_visits),
	("Doctors", num_doctors),
	("Hospitals", num_hospitals),
	("Diagnoses", num_diagnoses),
):
	print(f"  {entity_label}: {entity_count}")
print()

# Dimension table: hospitals — three fixed rows of (hospital_id, name, city).
hospital_records = (
	(301, "General Hospital", "Springfield"),
	(302, "City Medical Center", "Rivertown"),
	(303, "Children's Hospital", "Lakeside"),
)
hospitals = [
	Row(hospital_id=record_id, name=record_name, city=record_city)
	for record_id, record_name, record_city in hospital_records
]
df_hospitals = spark.createDataFrame(hospitals)

# Dimension table: doctors — `num_doctors` rows with ids starting at 201.
# Name and specialty are taken from the two lists below by position; the modulo
# only comes into play if `num_doctors` exceeds a list's length.
specialties = [
	"Cardiology", "Neurology", "Pediatrics", "Orthopedics", "Dermatology",
	"Oncology", "Psychiatry", "Endocrinology", "Gastroenterology", "Pulmonology",
	"Rheumatology", "Urology", "Nephrology", "Ophthalmology", "Otolaryngology",
	"Radiology", "Anesthesiology", "Emergency Medicine", "Family Medicine", "Internal Medicine",
	"Obstetrics", "Gynecology", "Pathology", "Surgery", "Hematology",
	"Infectious Disease", "Allergy", "Sports Medicine", "Geriatrics", "Palliative Care",
]
doctor_names = [
	"Dr. Adams", "Dr. Baker", "Dr. Clark", "Dr. Davis", "Dr. Evans",
	"Dr. Foster", "Dr. Green", "Dr. Harris", "Dr. Irwin", "Dr. Johnson",
	"Dr. Kelly", "Dr. Lopez", "Dr. Murphy", "Dr. Nelson", "Dr. Owens",
	"Dr. Parker", "Dr. Quinn", "Dr. Reed", "Dr. Stone", "Dr. Turner",
	"Dr. Allen", "Dr. Bell", "Dr. Cooper", "Dr. Dixon", "Dr. Ellis",
	"Dr. Fisher", "Dr. Gray", "Dr. Hughes", "Dr. Ingram", "Dr. Jenkins",
]

doctors = [
	Row(
		doctor_id=201 + offset,
		name=doctor_names[offset % len(doctor_names)],
		specialty=specialties[offset % len(specialties)],
	)
	for offset in range(num_doctors)
]

df_doctors = spark.createDataFrame(doctors)

# Dimension table: diagnoses — (diagnosis_id, code, description) for each of
# the 15 entries; ids run 401-415.
diagnosis_records = (
	(401, "I10", "Hypertension"),
	(402, "E11", "Type 2 Diabetes"),
	(403, "J45", "Asthma"),
	(404, "F32", "Depression"),
	(405, "M79", "Fibromyalgia"),
	(406, "K21", "GERD"),
	(407, "I25", "Coronary Artery Disease"),
	(408, "J44", "COPD"),
	(409, "N18", "Chronic Kidney Disease"),
	(410, "E78", "Hyperlipidemia"),
	(411, "M81", "Osteoporosis"),
	(412, "G43", "Migraine"),
	(413, "K58", "Irritable Bowel Syndrome"),
	(414, "L40", "Psoriasis"),
	(415, "F41", "Anxiety Disorder"),
)
all_diagnoses = [
	Row(diagnosis_id=record_id, code=record_code, description=record_text)
	for record_id, record_code, record_text in diagnosis_records
]

# Keep the first `num_diagnoses` entries as the rows that are actually used.
diagnoses = all_diagnoses[:num_diagnoses]
df_diagnoses = spark.createDataFrame(diagnoses)

# Dimension table: patients — `num_patients` rows with ids starting at 101.
# Every attribute is derived from the row index; no random draws happen here.
first_names = [
	"Alice", "Bob", "Carol", "David", "Emma", "Frank", "Grace", "Henry", "Iris", "Jack",
	"Karen", "Leo", "Maria", "Nathan", "Olivia", "Paul", "Quinn", "Ryan", "Sarah", "Tom",
	"Uma", "Victor", "Wendy", "Xavier", "Yara", "Zack", "Amy", "Brian", "Chloe", "Daniel",
	"Emily", "Felix", "Gina", "Hugo", "Ivy", "James", "Kate", "Liam", "Mia", "Noah",
	"Ava", "Ben", "Cara", "Dean", "Ella", "Finn", "Gia", "Hank", "Isla", "Jake",
]
last_names = [
	"Smith", "Jones", "Lee", "Kim", "Wilson", "Miller", "Taylor", "Brown", "Chen", "Davis",
	"White", "Martinez", "Garcia", "Rodriguez", "Anderson", "Thomas", "Jackson", "Moore", "Martin", "Thompson",
	"Patel", "Harris", "Clark", "Lewis", "Walker", "Hall", "Young", "King", "Wright", "Scott",
	"Green", "Baker", "Adams", "Nelson", "Carter", "Mitchell", "Perez", "Roberts", "Turner", "Phillips",
]
genders = ["M", "F"]

patients = []
for offset in range(num_patients):
	# First names cycle fastest; the last name advances once per full pass
	# through `first_names`.
	name_pass, first_slot = divmod(offset, len(first_names))
	given_name = first_names[first_slot]
	family_name = last_names[name_pass % len(last_names)]

	# Date of birth: year cycles through 1950-2010, month through 1-12 and
	# day through 1-28 (so the day is valid for every month).
	birth_year = 1950 + offset % 61
	birth_month = 1 + offset % 12
	birth_day = 1 + offset % 28

	patients.append(
		Row(
			patient_id=101 + offset,
			name=f"{given_name} {family_name}",
			dob=f"{birth_year:04d}-{birth_month:02d}-{birth_day:02d}",
			gender=genders[offset % 2],
		)
	)

df_patients = spark.createDataFrame(patients)

# Fact table: visits — `num_visits` rows whose foreign keys are drawn at random
# from the dimension rows built above.
# The key pools are read from the dimension rows themselves (as was already
# done for diagnoses) rather than re-stated as literals/ranges, so a change to
# a dimension can never leave the fact table pointing at ids that do not exist.
patient_ids = [p.patient_id for p in patients]
doctor_ids = [d.doctor_id for d in doctors]
hospital_ids = [h.hospital_id for h in hospitals]
diagnosis_ids = [d.diagnosis_id for d in diagnoses]

patient_visits = []
start_date = datetime(2025, 1, 1)
for i in range(num_visits):
	visit_id = i + 1
	# Draw order (patient, doctor, hospital, diagnosis) is part of the seeded
	# sequence; keep it stable so the generated data stays reproducible.
	patient_id = random.choice(patient_ids)
	doctor_id = random.choice(doctor_ids)
	hospital_id = random.choice(hospital_ids)
	diagnosis_id = random.choice(diagnosis_ids)
	# Visit dates cycle through the 365 days starting at `start_date`,
	# stored as "YYYY-MM-DD" strings.
	visit_date = (start_date + timedelta(days=i % 365)).strftime("%Y-%m-%d")
	patient_visits.append(Row(
		visit_id=visit_id,
		patient_id=patient_id,
		doctor_id=doctor_id,
		hospital_id=hospital_id,
		diagnosis_id=diagnosis_id,
		visit_date=visit_date,
	))

df_patient_visits = spark.createDataFrame(patient_visits)

# Report the row count of each generated DataFrame.
print("\nGenerated:")
for generated_df, noun in (
	(df_patients, "patients"),
	(df_patient_visits, "visits"),
	(df_doctors, "doctors"),
	(df_hospitals, "hospitals"),
	(df_diagnoses, "diagnoses"),
):
	print(f"  {generated_df.count()} {noun}")

# Preview the first ten rows of the fact table, then of the patients dimension.
for preview_df in (df_patient_visits, df_patients):
	display(preview_df.limit(10))

In [0]:
# Validate the generated tables and persist them only when every check passes.
# Each metric is computed exactly once and reused for both the printed report
# and the pass/fail total, so the two can never disagree and no Spark action is
# repeated.
print("=== REFERENTIAL INTEGRITY CHECKS ===\n")


def _count_duplicate_keys(label, df, key):
	"""Print total vs. distinct `key` counts for `df` and return their difference."""
	total = df.count()
	unique = df.select(key).distinct().count()
	print(f"   {label}: {total} total, {unique} unique {key}s")
	return total - unique


def _count_null_keys(key):
	"""Print and return the number of visit rows whose `key` column is NULL."""
	nulls = df_patient_visits.filter(f"{key} IS NULL").count()
	print(f"   Null {key}s in visits: {nulls}")
	return nulls


def _find_orphan_keys(dim_df, key):
	"""Report distinct visit `key` values that have no match in `dim_df`.

	Returns a tuple of (anti-join DataFrame, number of unmatched key values).
	"""
	orphans = (
		df_patient_visits.select(key).distinct()
		.join(dim_df.select(key), key, 'left_anti')
	)
	# One collect() yields both the count and the values to list.
	orphan_values = [row[key] for row in orphans.collect()]
	print(f"   Invalid {key}s in visits: {len(orphan_values)}")
	if orphan_values:
		print(f"   Invalid {key}s: {orphan_values}")
	return orphans, len(orphan_values)


# Check 1: primary key uniqueness in every table
print("1. PRIMARY KEY UNIQUENESS:")
duplicate_keys = 0
for label, df, key in (
	("Patients", df_patients, "patient_id"),
	("Doctors", df_doctors, "doctor_id"),
	("Hospitals", df_hospitals, "hospital_id"),
	("Diagnoses", df_diagnoses, "diagnosis_id"),
	("Visits", df_patient_visits, "visit_id"),
):
	duplicate_keys += _count_duplicate_keys(label, df, key)

# Check 2: no NULL foreign keys in the fact table
print("\n2. NULL CHECKS IN KEY COLUMNS:")
null_keys = 0
for key in ("patient_id", "doctor_id", "hospital_id", "diagnosis_id"):
	null_keys += _count_null_keys(key)

# Checks 3-6: every foreign key in visits exists in its dimension table
print("\n3. FOREIGN KEY INTEGRITY:")
invalid_patients, orphan_patients = _find_orphan_keys(df_patients, "patient_id")
invalid_doctors, orphan_doctors = _find_orphan_keys(df_doctors, "doctor_id")
invalid_hospitals, orphan_hospitals = _find_orphan_keys(df_hospitals, "hospital_id")
invalid_diagnoses, orphan_diagnoses = _find_orphan_keys(df_diagnoses, "diagnosis_id")

# Summary
print("\n=== SUMMARY ===")
total_issues = (
	duplicate_keys
	+ null_keys
	+ orphan_patients
	+ orphan_doctors
	+ orphan_hospitals
	+ orphan_diagnoses
)

if total_issues == 0:
	print("✓ All referential integrity checks PASSED")
	print("✓ No duplicate primary keys")
	print("✓ No null foreign keys")
	print("✓ All foreign keys reference valid dimension records")

	# Only save tables if integrity checks pass. Table names are unqualified,
	# so they are written to the session's current catalog and schema.
	print("\n=== SAVING TABLES ===")
	for df, table_name in (
		(df_patient_visits, "visits"),
		(df_patients, "patients"),
		(df_doctors, "doctors"),
		(df_hospitals, "hospitals"),
		(df_diagnoses, "diagnoses"),
	):
		df.write.mode("overwrite").saveAsTable(table_name)
		print(f"✓ Saved table: {table_name}")

	print("\n✓ All tables created/replaced successfully in schema")
else:
	print(f"✗ Found {total_issues} integrity issue(s)")
	print("✗ TABLES NOT SAVED - Fix integrity issues before saving")