In [None]:
-- Session context for the SQL cells below: they run as AI_ENGINEER and
-- unqualified table names (PATIENTS, LAB_RESULTS, ...) resolve in this schema.
use role AI_ENGINEER;
use schema AI_DEVELOPMENT.SI_CLINICAL_TRIAL;

# Imports

In [None]:
# Faker is used by the generator cell to synthesize patient names and countries.
# NOTE(review): consider pinning a version (faker==<x.y.z>); seeded Faker output
# may differ between releases, which would change the generated dataset.
!pip install faker --quiet

In [None]:
# Import python packages
import streamlit as st  # NOTE(review): `st` is not used in the cells shown here; confirm before removing
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from faker import Faker
import random
import warnings
# Suppresses every Python warning for the rest of the kernel session, which can
# also hide pandas deprecation/dtype warnings raised by later cells.
warnings.filterwarnings('ignore')

# We can also use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session
# Snowpark session of this notebook; used later by session.write_pandas to
# persist the generated tables.
session = get_active_session()

# 1. Generate Clinical Trial Data

In [None]:
"""
Clinical Trial Data Generator for DrugX
=====================================

This script generates a comprehensive, synthetic dataset for a clinical trial of a fictional drug named "DrugX".
The data contains realistic patterns and specific anomalies that allow users to ask business-relevant questions
using natural language through Snowflake Intelligence.
"""

import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from faker import Faker
import random
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
np.random.seed(42)
random.seed(42)
fake = Faker()
Faker.seed(42)

# ============================================================================
# SECTION 1: SETUP AND CONFIGURATION
# ============================================================================

# Constants - easily configurable
NUM_PATIENTS = 500
TRIAL_DURATION_DAYS = 365
TRIAL_START_DATE = datetime(2023, 1, 1)
# End of the observation window; also used as every patient's planned END_DATE.
TRIAL_END_DATE = TRIAL_START_DATE + timedelta(days=TRIAL_DURATION_DAYS)

# Treatment arms mapped to their daily dose in mg.
TREATMENT_GROUPS = {
    'DrugX - Low Dose': 100,   # mg
    'DrugX - High Dose': 300,  # mg
    'Placebo': 0               # mg
}

# Per-test simulation parameters: values are drawn from N(mean, std), and
# mean ± 2·std is published as the NORMAL_RANGE label. The WBC unit uses the
# Greek letter mu (μ, U+03BC) for microlitres.
LAB_NORMAL_RANGES = {
    'White Blood Cell Count': {'mean': 7.0, 'std': 1.5, 'unit': '10^3/μL'},
    'Liver Function Test': {'mean': 25.0, 'std': 8.0, 'unit': 'U/L'},
    'Kidney Function Test': {'mean': 1.0, 'std': 0.2, 'unit': 'mg/dL'},
    'Cholesterol Level': {'mean': 180.0, 'std': 30.0, 'unit': 'mg/dL'}
}

print("🧬 Clinical Trial Data Generator Starting...")
print(f"📊 Generating data for {NUM_PATIENTS} patients over {TRIAL_DURATION_DAYS} days")
print("=" * 60)

# ============================================================================
# SECTION 2: GENERATING STRUCTURED DATA TABLES
# ============================================================================

def generate_patients_table():
    """Generate the PATIENTS table with realistic demographic data.

    Builds NUM_PATIENTS rows with sequential IDs (P0001, P0002, ...), a random
    gender, a Faker first name matching that gender, a Faker last name and
    country, and an age drawn from N(45, 15), truncated to int and clamped to
    [18, 75].

    DATE_OF_BIRTH is derived from AGE relative to TRIAL_START_DATE, so AGE is
    the age at enrollment. The seeded output also no longer depends on the day
    the notebook is run.

    Returns:
        pd.DataFrame with columns PATIENT_ID, FIRST_NAME, LAST_NAME,
        DATE_OF_BIRTH (datetime.date), GENDER, COUNTRY, AGE (int).
    """
    print("👥 Generating PATIENTS table...")

    patients_data = []

    for i in range(1, NUM_PATIENTS + 1):
        # Generate realistic patient data
        gender = random.choice(['Male', 'Female'])

        # Age distribution: mostly adults, clamped to 18-75
        age = np.random.normal(45, 15)
        age = max(18, min(75, int(age)))

        # Anchor to the trial start rather than datetime.now(): keeps the
        # seeded dataset reproducible and DATE_OF_BIRTH consistent with AGE.
        birth_date = TRIAL_START_DATE - timedelta(days=age * 365.25)

        patient = {
            'PATIENT_ID': f'P{i:04d}',
            'FIRST_NAME': fake.first_name_male() if gender == 'Male' else fake.first_name_female(),
            'LAST_NAME': fake.last_name(),
            'DATE_OF_BIRTH': birth_date,  # Keep as datetime object
            'GENDER': gender,
            'COUNTRY': fake.country(),
            'AGE': age
        }
        patients_data.append(patient)

    df = pd.DataFrame(patients_data)
    # Ensure date column has correct datatype
    df['DATE_OF_BIRTH'] = pd.to_datetime(df['DATE_OF_BIRTH']).dt.date
    return df

def generate_treatments_table(patients_df):
    """Generate the TREATMENTS table with drug assignments.

    Patients are split as evenly as possible across TREATMENT_GROUPS. Each arm
    gets len(patients_df) // n_arms slots, and the remainder slots go to randomly
    chosen arms. The slots are then shuffled and assigned to patients in row
    order. Each patient starts treatment on a random day in the first month of
    the trial (day 0-30) and ends on TRIAL_END_DATE.

    Args:
        patients_df: PATIENTS frame; only PATIENT_ID is read.

    Returns:
        pd.DataFrame with columns PATIENT_ID, TREATMENT_GROUP, DOSAGE_MG,
        START_DATE and END_DATE (datetime.date), one row per patient.
    """
    print("💊 Generating TREATMENTS table...")

    treatment_names = list(TREATMENT_GROUPS.keys())
    # Size the assignment pool from the patients actually passed in (not the
    # NUM_PATIENTS constant) so it always matches the patient list.
    num_patients = len(patients_df)
    patients_per_group, remainder = divmod(num_patients, len(treatment_names))

    # Ensure roughly equal distribution across treatment groups
    treatment_assignments = []
    for treatment in treatment_names:
        treatment_assignments.extend([treatment] * patients_per_group)

    # Handle remainder patients
    for _ in range(remainder):
        treatment_assignments.append(random.choice(treatment_names))

    # Shuffle assignments
    random.shuffle(treatment_assignments)

    treatments_data = []
    for patient_id, treatment_group in zip(patients_df['PATIENT_ID'], treatment_assignments):
        # Start date within first month of trial
        start_offset = random.randint(0, 30)
        start_date = TRIAL_START_DATE + timedelta(days=start_offset)

        treatments_data.append({
            'PATIENT_ID': patient_id,
            'TREATMENT_GROUP': treatment_group,
            'DOSAGE_MG': TREATMENT_GROUPS[treatment_group],
            'START_DATE': start_date,  # Keep as datetime object
            'END_DATE': TRIAL_END_DATE  # Keep as datetime object
        })

    # Explicit columns keep the date conversion below valid for an empty input.
    df = pd.DataFrame(
        treatments_data,
        columns=['PATIENT_ID', 'TREATMENT_GROUP', 'DOSAGE_MG', 'START_DATE', 'END_DATE'],
    )
    # Ensure date columns have correct datatype
    df['START_DATE'] = pd.to_datetime(df['START_DATE']).dt.date
    df['END_DATE'] = pd.to_datetime(df['END_DATE']).dt.date
    return df

def generate_lab_results_table(patients_df, treatments_df):
    """Generate the LAB_RESULTS table with diverse test types.

    Every patient gets each test in LAB_NORMAL_RANGES at 13 monthly time points
    (TRIAL_START_DATE + 0, 30, ..., 360 days). Values are drawn from
    N(mean, std) identically for every arm, floored at 0.1 and rounded to 2
    decimals. Arm-specific shifts are applied afterwards by the inject_*
    functions.

    Args:
        patients_df: PATIENTS frame; only PATIENT_ID is read.
        treatments_df: TREATMENTS frame. Not used for value generation because
            all arms share one baseline; kept for signature compatibility.

    Returns:
        pd.DataFrame with columns PATIENT_ID, TEST_DATE (datetime.date),
        TEST_NAME, TEST_VALUE, UNIT, NORMAL_RANGE.
    """
    print("🔬 Generating LAB_RESULTS table...")

    # The reference-range label depends only on the test, so format it once.
    range_labels = {
        test_name: f"{spec['mean'] - 2 * spec['std']:.1f}-{spec['mean'] + 2 * spec['std']:.1f}"
        for test_name, spec in LAB_NORMAL_RANGES.items()
    }
    # Monthly schedule: months 0 to 12 inclusive.
    test_dates = [TRIAL_START_DATE + timedelta(days=month * 30) for month in range(13)]

    lab_results_data = []
    for patient_id in patients_df['PATIENT_ID']:
        for test_date in test_dates:
            for test_name, normal_range in LAB_NORMAL_RANGES.items():
                # Same baseline for every arm; floor keeps values positive.
                test_value = max(0.1, np.random.normal(normal_range['mean'], normal_range['std']))

                lab_results_data.append({
                    'PATIENT_ID': patient_id,
                    'TEST_DATE': test_date,  # Keep as datetime object
                    'TEST_NAME': test_name,
                    'TEST_VALUE': round(test_value, 2),
                    'UNIT': normal_range['unit'],
                    'NORMAL_RANGE': range_labels[test_name]
                })

    # Explicit columns keep the date conversion below valid for an empty input.
    df = pd.DataFrame(
        lab_results_data,
        columns=['PATIENT_ID', 'TEST_DATE', 'TEST_NAME', 'TEST_VALUE', 'UNIT', 'NORMAL_RANGE'],
    )
    # Ensure date column has correct datatype
    df['TEST_DATE'] = pd.to_datetime(df['TEST_DATE']).dt.date
    return df

def generate_adverse_events_table(patients_df, treatments_df):
    """Generate the ADVERSE_EVENTS table with background (arm-independent) side effects.

    Each patient independently has a 30% chance of experiencing 1-3 events.
    Each event gets a day drawn uniformly from [1, TRIAL_DURATION_DAYS] after
    TRIAL_START_DATE, plus a random event name, severity and resolved flag.
    Dose-dependent events are added later by
    inject_dose_dependent_adverse_events.

    NOTE(review): event dates are not bounded below by the patient's treatment
    START_DATE (which can be up to day 30), so an event may precede treatment.

    Args:
        patients_df: PATIENTS frame; only PATIENT_ID is read.
        treatments_df: TREATMENTS frame. Not used; kept for signature compatibility.

    Returns:
        pd.DataFrame with columns PATIENT_ID, EVENT_DATE (datetime.date),
        EVENT_NAME, SEVERITY, RESOLVED.
    """
    print("⚠️  Generating ADVERSE_EVENTS table...")

    # Common adverse events
    adverse_events_list = [
        'nausea', 'fatigue', 'headache', 'elevated blood pressure',
        'dizziness', 'insomnia', 'rash', 'diarrhea', 'constipation',
        'muscle pain', 'joint pain', 'dry mouth'
    ]
    severities = ['Mild', 'Moderate', 'Severe']

    adverse_events_data = []
    for patient_id in patients_df['PATIENT_ID']:
        # Only 30% of patients experience at least one background event.
        if random.random() >= 0.3:
            continue

        for _ in range(random.randint(1, 3)):
            event_date = TRIAL_START_DATE + timedelta(days=random.randint(1, TRIAL_DURATION_DAYS))
            adverse_events_data.append({
                'PATIENT_ID': patient_id,
                'EVENT_DATE': event_date,  # Keep as datetime object
                'EVENT_NAME': random.choice(adverse_events_list),
                'SEVERITY': random.choice(severities),
                'RESOLVED': random.choice([True, False])
            })

    # Explicit columns keep the date conversion below valid when no patient
    # drew an event (possible for small NUM_PATIENTS).
    df = pd.DataFrame(
        adverse_events_data,
        columns=['PATIENT_ID', 'EVENT_DATE', 'EVENT_NAME', 'SEVERITY', 'RESOLVED'],
    )
    # Ensure date column has correct datatype
    df['EVENT_DATE'] = pd.to_datetime(df['EVENT_DATE']).dt.date
    return df

def generate_clinical_visits_table(patients_df, treatments_df):
    """Generate the CLINICAL_VISITS table with monthly visit data.

    For each patient, 13 scheduled visits at TRIAL_START_DATE + 0, 30, ..., 360
    days are simulated. Each visit is independently missed (dropped) with 7.5%
    probability. Vitals come from normal distributions clamped to 90-180
    systolic, 60-110 diastolic and 50-120 bpm. About 20% of visits carry the note
    'Minor concerns noted'; the rest carry 'Patient doing well'.

    NOTE(review): WEIGHT_KG is drawn fresh at every visit, not per patient, so a
    patient's weight varies freely between visits.

    Args:
        patients_df: PATIENTS frame; only PATIENT_ID is read.
        treatments_df: TREATMENTS frame. Not used because vitals are
            arm-independent here (the high-dose heart-rate effect is injected
            later); kept for signature compatibility.

    Returns:
        pd.DataFrame with columns PATIENT_ID, VISIT_DATE (datetime.date),
        VISIT_TYPE, SYSTOLIC_BP, DIASTOLIC_BP, HEART_RATE, WEIGHT_KG, NOTES.
    """
    print("🏥 Generating CLINICAL_VISITS table...")

    clinical_visits_data = []

    for patient_id in patients_df['PATIENT_ID']:
        # Generate monthly visits: months 0 to 12 inclusive
        for month in range(0, 13):
            visit_date = TRIAL_START_DATE + timedelta(days=month * 30)

            # 7.5% chance of a missed visit (row is skipped entirely)
            if random.random() < 0.075:
                continue

            # Generate realistic vital signs, then clamp to plausible ranges
            base_systolic = np.random.normal(120, 15)
            base_diastolic = np.random.normal(80, 10)
            base_heart_rate = np.random.normal(70, 12)

            systolic_bp = max(90, min(180, int(base_systolic)))
            diastolic_bp = max(60, min(110, int(base_diastolic)))
            heart_rate = max(50, min(120, int(base_heart_rate)))

            clinical_visits_data.append({
                'PATIENT_ID': patient_id,
                'VISIT_DATE': visit_date,  # Keep as datetime object
                'VISIT_TYPE': 'Scheduled Follow-up',
                'SYSTOLIC_BP': systolic_bp,
                'DIASTOLIC_BP': diastolic_bp,
                'HEART_RATE': heart_rate,
                'WEIGHT_KG': round(np.random.normal(70, 15), 1),
                'NOTES': 'Patient doing well' if random.random() > 0.2 else 'Minor concerns noted'
            })

    # Explicit columns keep the date conversion below valid for an empty input.
    df = pd.DataFrame(
        clinical_visits_data,
        columns=['PATIENT_ID', 'VISIT_DATE', 'VISIT_TYPE', 'SYSTOLIC_BP',
                 'DIASTOLIC_BP', 'HEART_RATE', 'WEIGHT_KG', 'NOTES'],
    )
    # Ensure date column has correct datatype
    df['VISIT_DATE'] = pd.to_datetime(df['VISIT_DATE']).dt.date
    return df

# Generate all base tables
patients_df = generate_patients_table()
treatments_df = generate_treatments_table(patients_df)
lab_results_df = generate_lab_results_table(patients_df, treatments_df)
adverse_events_df = generate_adverse_events_table(patients_df, treatments_df)
clinical_visits_df = generate_clinical_visits_table(patients_df, treatments_df)

# Fail fast if the base tables violate the keys declared in the metadata cell
# (PATIENTS/TREATMENTS primary keys, PATIENT_ID foreign keys to PATIENTS).
assert patients_df['PATIENT_ID'].is_unique, "duplicate PATIENT_ID in PATIENTS"
assert treatments_df['PATIENT_ID'].is_unique, "duplicate PATIENT_ID in TREATMENTS"
for child_df in (treatments_df, lab_results_df, adverse_events_df, clinical_visits_df):
    assert child_df['PATIENT_ID'].isin(patients_df['PATIENT_ID']).all(), "orphan PATIENT_ID"

print("✅ Base tables generated successfully!")
print(f"   - Patients: {len(patients_df)} records")
print(f"   - Treatments: {len(treatments_df)} records")
print(f"   - Lab Results: {len(lab_results_df)} records")
print(f"   - Adverse Events: {len(adverse_events_df)} records")
print(f"   - Clinical Visits: {len(clinical_visits_df)} records")
print()

# ============================================================================
# SECTION 3: INJECTING ANOMALIES AND KEY INSIGHTS
# ============================================================================

print("🎯 Injecting anomalies and key insights...")

def inject_dose_dependent_adverse_events():
    """Inject a pattern where high-dose patients get nausea/fatigue in the first 30 days.

    Randomly selects 30% (rounded down) of 'DrugX - High Dose' patients. Each one
    receives a Moderate, resolved 'nausea' event on day 5-30 and a Moderate
    'fatigue' event on day 10-30 (relative to TRIAL_START_DATE), with the fatigue
    event's resolved flag chosen at random. The events are appended to the global
    adverse_events_df, which is rebound to the concatenated frame.
    """
    global adverse_events_df
    print("   📈 Injecting dose-dependent adverse events...")

    high_dose_patients = treatments_df[treatments_df['TREATMENT_GROUP'] == 'DrugX - High Dose']['PATIENT_ID'].tolist()

    # Select 30% of high-dose patients for this anomaly
    affected_patients = random.sample(high_dose_patients, int(len(high_dose_patients) * 0.3))

    new_events = []
    for patient_id in affected_patients:
        # Add nausea within first 30 days
        event_date = TRIAL_START_DATE + timedelta(days=random.randint(5, 30))
        new_events.append({
            'PATIENT_ID': patient_id,
            'EVENT_DATE': event_date,  # Keep as datetime object
            'EVENT_NAME': 'nausea',
            'SEVERITY': 'Moderate',
            'RESOLVED': True
        })

        # Add fatigue within first 30 days
        event_date = TRIAL_START_DATE + timedelta(days=random.randint(10, 30))
        new_events.append({
            'PATIENT_ID': patient_id,
            'EVENT_DATE': event_date,  # Keep as datetime object
            'EVENT_NAME': 'fatigue',
            'SEVERITY': 'Moderate',
            'RESOLVED': random.choice([True, False])
        })

    # With too few high-dose patients nobody is selected; an empty frame has no
    # EVENT_DATE column, so skip instead of raising KeyError.
    if not new_events:
        return

    new_events_df = pd.DataFrame(new_events)
    # Ensure date column has correct datatype for new events
    new_events_df['EVENT_DATE'] = pd.to_datetime(new_events_df['EVENT_DATE']).dt.date
    adverse_events_df = pd.concat([adverse_events_df, new_events_df], ignore_index=True)

def inject_liver_function_anomaly():
    """Inject a hepatotoxicity pattern in high-dose patients after 90 days.

    Randomly selects 15% (rounded down) of 'DrugX - High Dose' patients. Their
    'Liver Function Test' results dated on or after TRIAL_START_DATE + 90 days
    are multiplied by per-result factors drawn uniformly from [2.0, 3.5). The
    values are re-rounded to 2 decimals to match the generator's precision. The
    global lab_results_df is modified in place.
    """
    global lab_results_df
    print("   🔬 Injecting liver function anomalies...")

    high_dose_patients = treatments_df[treatments_df['TREATMENT_GROUP'] == 'DrugX - High Dose']['PATIENT_ID'].tolist()

    # Select 15% of high-dose patients for liver issues
    affected_patients = random.sample(high_dose_patients, int(len(high_dose_patients) * 0.15))

    # TEST_DATE holds datetime.date values, so compare against a date cutoff
    cutoff_date = TRIAL_START_DATE.date() + timedelta(days=90)

    for patient_id in affected_patients:
        mask = (lab_results_df['PATIENT_ID'] == patient_id) & \
               (lab_results_df['TEST_NAME'] == 'Liver Function Test') & \
               (lab_results_df['TEST_DATE'] >= cutoff_date)

        # Raise liver values to roughly 2-3.5x their baseline
        factors = np.random.uniform(2.0, 3.5, int(mask.sum()))
        lab_results_df.loc[mask, 'TEST_VALUE'] = (lab_results_df.loc[mask, 'TEST_VALUE'] * factors).round(2)

def inject_cholesterol_anomaly():
    """Inject a cholesterol increase in low-dose patients after 6 months.

    Randomly selects 20% (rounded down) of 'DrugX - Low Dose' patients. Their
    'Cholesterol Level' results dated on or after TRIAL_START_DATE + 180 days
    are multiplied by per-result factors drawn uniformly from [1.4, 1.8), i.e.
    a 40-80% increase. The values are re-rounded to 2 decimals to match the
    generator's precision. The global lab_results_df is modified in place.
    """
    global lab_results_df
    print("   💊 Injecting cholesterol anomalies...")

    low_dose_patients = treatments_df[treatments_df['TREATMENT_GROUP'] == 'DrugX - Low Dose']['PATIENT_ID'].tolist()

    # Select 20% of low-dose patients for cholesterol issues
    affected_patients = random.sample(low_dose_patients, int(len(low_dose_patients) * 0.2))

    # TEST_DATE holds datetime.date values, so compare against a date cutoff
    cutoff_date = TRIAL_START_DATE.date() + timedelta(days=180)

    for patient_id in affected_patients:
        mask = (lab_results_df['PATIENT_ID'] == patient_id) & \
               (lab_results_df['TEST_NAME'] == 'Cholesterol Level') & \
               (lab_results_df['TEST_DATE'] >= cutoff_date)

        # Increase cholesterol values by 40-80%
        factors = np.random.uniform(1.4, 1.8, int(mask.sum()))
        lab_results_df.loc[mask, 'TEST_VALUE'] = (lab_results_df.loc[mask, 'TEST_VALUE'] * factors).round(2)

def inject_cardiac_side_effects():
    """Inject a slightly elevated heart rate in high-dose patients.

    Every CLINICAL_VISITS row belonging to a 'DrugX - High Dose' patient has its
    HEART_RATE multiplied by a factor drawn uniformly from [1.05, 1.10). All heart
    rates are then capped at 120 bpm and rounded back to whole beats per minute,
    so the column stays an integer column.

    The scaling is done on a float copy. Writing float products straight into the
    integer column via .loc would silently upcast it to fractional floats, and
    pandas 2.1+ deprecates that.
    """
    global clinical_visits_df
    print("   ❤️  Injecting cardiac side effects...")

    high_dose_patients = treatments_df[treatments_df['TREATMENT_GROUP'] == 'DrugX - High Dose']['PATIENT_ID'].tolist()

    # Increase heart rate by 5-10% for high-dose patients
    mask = clinical_visits_df['PATIENT_ID'].isin(high_dose_patients)
    factors = np.random.uniform(1.05, 1.10, int(mask.sum()))

    heart_rate = clinical_visits_df['HEART_RATE'].astype(float)
    heart_rate.loc[mask] = heart_rate.loc[mask] * factors

    # Keep values within the realistic range and as whole bpm
    clinical_visits_df['HEART_RATE'] = heart_rate.clip(upper=120).round().astype(int)

# Apply all anomalies. They share the global `random`/NumPy RNG streams, so
# changing this order changes which patients and values are affected.
inject_dose_dependent_adverse_events()
inject_liver_function_anomaly()
inject_cholesterol_anomaly()
inject_cardiac_side_effects()

print("✅ Anomalies injected successfully!")
print()


def _write_table(frame, table_name):
    """Write a pandas frame to Snowflake as `table_name`, preview it, and return the result.

    Uses session.write_pandas with auto_create_table, use_logical_type and
    overwrite enabled, so any existing table of that name is replaced. The
    returned Snowpark object is shown (first rows) before being returned.
    """
    written = session.write_pandas(
        frame,
        table_name=table_name,
        auto_create_table=True,
        use_logical_type=True,
        overwrite=True,
    )
    written.show()
    return written


# Each *_df name is rebound from the local pandas DataFrame to the Snowpark
# table object returned by write_pandas.
patients_df = _write_table(patients_df, 'PATIENTS')
treatments_df = _write_table(treatments_df, 'TREATMENTS')
lab_results_df = _write_table(lab_results_df, 'LAB_RESULTS')
adverse_events_df = _write_table(adverse_events_df, 'ADVERSE_EVENTS')
clinical_visits_df = _write_table(clinical_visits_df, 'CLINICAL_VISITS')

# 2. Add Metadata

In [None]:
-- Table Descriptions
COMMENT ON TABLE PATIENTS IS 'A dimension table containing demographic and baseline information for all patients enrolled in the DrugX clinical trial. This table provides the core patient context for analyzing treatment outcomes, adverse events, and clinical measurements.';
COMMENT ON TABLE TREATMENTS IS 'A dimension table that records the treatment assignment for each patient in the clinical trial. It specifies which treatment group (DrugX Low Dose, DrugX High Dose, or Placebo) each patient was assigned to, along with dosage information and treatment duration.';
COMMENT ON TABLE LAB_RESULTS IS 'A fact table that stores all laboratory test results collected throughout the clinical trial. This includes various biomarkers and safety parameters measured at regular intervals to monitor patient health and drug effects.';
COMMENT ON TABLE ADVERSE_EVENTS IS 'A fact table that records all adverse events experienced by patients during the clinical trial. This critical safety data tracks the occurrence, severity, and resolution status of side effects potentially related to the study drug.';
COMMENT ON TABLE CLINICAL_VISITS IS 'A fact table that captures vital signs and clinical observations recorded during each patient visit. This includes blood pressure, heart rate, weight measurements, and clinical notes from healthcare providers.';

-- Column Descriptions
COMMENT ON COLUMN PATIENTS.PATIENT_ID IS 'The primary key and unique identifier for each patient enrolled in the clinical trial (format: P0001, P0002, etc.).';
COMMENT ON COLUMN PATIENTS.FIRST_NAME IS 'The first name of the patient, used for identification and communication purposes.';
COMMENT ON COLUMN PATIENTS.LAST_NAME IS 'The last name of the patient, used for identification and communication purposes.';
COMMENT ON COLUMN PATIENTS.DATE_OF_BIRTH IS 'The patient''s date of birth, used to calculate age and ensure appropriate dosing and safety monitoring.';
COMMENT ON COLUMN PATIENTS.GENDER IS 'The patient''s gender (Male or Female), important for analyzing gender-specific drug effects and adverse events.';
COMMENT ON COLUMN PATIENTS.COUNTRY IS 'The country where the patient is enrolled, enabling geographic analysis of treatment outcomes and regulatory compliance.';
COMMENT ON COLUMN PATIENTS.AGE IS 'The patient''s age at enrollment, calculated from date of birth and used for age-stratified analyses.';

COMMENT ON COLUMN TREATMENTS.PATIENT_ID IS 'The foreign key linking to PATIENTS table, identifying which patient received this treatment assignment.';
COMMENT ON COLUMN TREATMENTS.TREATMENT_GROUP IS 'The treatment arm assigned to the patient: ''DrugX - Low Dose'', ''DrugX - High Dose'', or ''Placebo''.';
COMMENT ON COLUMN TREATMENTS.DOSAGE_MG IS 'The daily dosage of the study drug in milligrams (100mg for low dose, 300mg for high dose, 0mg for placebo).';
COMMENT ON COLUMN TREATMENTS.START_DATE IS 'The date when the patient began taking the assigned treatment.';
COMMENT ON COLUMN TREATMENTS.END_DATE IS 'The planned end date for the patient''s treatment period.';

COMMENT ON COLUMN LAB_RESULTS.PATIENT_ID IS 'The foreign key linking to PATIENTS table, identifying which patient the lab result belongs to.';
COMMENT ON COLUMN LAB_RESULTS.TEST_DATE IS 'The date when the laboratory test was performed.';
COMMENT ON COLUMN LAB_RESULTS.TEST_NAME IS 'The name of the laboratory test performed (e.g., ''White Blood Cell Count'', ''Liver Function Test'').';
COMMENT ON COLUMN LAB_RESULTS.TEST_VALUE IS 'The numerical result of the laboratory test.';
-- Unit examples use the Greek letter mu (μL = microlitre).
COMMENT ON COLUMN LAB_RESULTS.UNIT IS 'The unit of measurement for the test result (e.g., ''10^3/μL'', ''U/L'', ''mg/dL'').';
COMMENT ON COLUMN LAB_RESULTS.NORMAL_RANGE IS 'The normal reference range for this test, used to identify abnormal values.';

COMMENT ON COLUMN ADVERSE_EVENTS.PATIENT_ID IS 'The foreign key linking to PATIENTS table, identifying which patient experienced the adverse event.';
COMMENT ON COLUMN ADVERSE_EVENTS.EVENT_DATE IS 'The date when the adverse event occurred or was first reported.';
COMMENT ON COLUMN ADVERSE_EVENTS.EVENT_NAME IS 'The name or description of the adverse event (e.g., ''nausea'', ''fatigue'', ''headache'').';
COMMENT ON COLUMN ADVERSE_EVENTS.SEVERITY IS 'The severity classification of the adverse event: ''Mild'', ''Moderate'', or ''Severe''.';
COMMENT ON COLUMN ADVERSE_EVENTS.RESOLVED IS 'Boolean indicator of whether the adverse event has been resolved (TRUE) or is ongoing (FALSE).';

COMMENT ON COLUMN CLINICAL_VISITS.PATIENT_ID IS 'The foreign key linking to PATIENTS table, identifying which patient the visit data belongs to.';
COMMENT ON COLUMN CLINICAL_VISITS.VISIT_DATE IS 'The date when the clinical visit occurred.';
COMMENT ON COLUMN CLINICAL_VISITS.VISIT_TYPE IS 'The type of clinical visit (typically ''Scheduled Follow-up'' for regular monitoring visits).';
COMMENT ON COLUMN CLINICAL_VISITS.SYSTOLIC_BP IS 'The systolic blood pressure measurement in mmHg recorded during the visit.';
COMMENT ON COLUMN CLINICAL_VISITS.DIASTOLIC_BP IS 'The diastolic blood pressure measurement in mmHg recorded during the visit.';
COMMENT ON COLUMN CLINICAL_VISITS.HEART_RATE IS 'The heart rate measurement in beats per minute recorded during the visit.';
COMMENT ON COLUMN CLINICAL_VISITS.WEIGHT_KG IS 'The patient''s weight in kilograms recorded during the visit.';
COMMENT ON COLUMN CLINICAL_VISITS.NOTES IS 'Clinical notes and observations recorded by healthcare providers during the visit.';

-- Primary / Foreign Keys
-- PATIENTS holds one row per PATIENT_ID. TREATMENTS also holds exactly one row
-- per patient, so PATIENT_ID is its primary key as well as its link to PATIENTS.
-- NOTE(review): on standard Snowflake tables these constraints are informational
-- (not enforced apart from NOT NULL); confirm for your table type.
alter table PATIENTS   add constraint PK_PATIENTS   primary key (PATIENT_ID);
alter table TREATMENTS add constraint PK_TREATMENTS primary key (PATIENT_ID);

-- Every child table points back to PATIENTS.
alter table TREATMENTS      add constraint FK_TREATMENTS_PATIENT      foreign key (PATIENT_ID) references PATIENTS (PATIENT_ID);
alter table LAB_RESULTS     add constraint FK_LAB_RESULTS_PATIENT     foreign key (PATIENT_ID) references PATIENTS (PATIENT_ID);
alter table ADVERSE_EVENTS  add constraint FK_ADVERSE_EVENTS_PATIENT  foreign key (PATIENT_ID) references PATIENTS (PATIENT_ID);
alter table CLINICAL_VISITS add constraint FK_CLINICAL_VISITS_PATIENT foreign key (PATIENT_ID) references PATIENTS (PATIENT_ID);

# 2b. Search Services for High-Cardinality Columns

In [None]:
-- Search service over the distinct COUNTRY values in PATIENTS, refreshed on
-- AI_WH with a 12-hour target lag.
create or replace cortex search service _CA_COUNTRY
    on COUNTRY
    warehouse = AI_WH
    target_lag = '12 hour'
    embedding_model = 'snowflake-arctic-embed-l-v2.0'
as (
    select distinct COUNTRY from PATIENTS
);

In [None]:
-- Search service over the distinct TEST_NAME values in LAB_RESULTS, refreshed
-- on AI_WH with a 12-hour target lag.
create or replace cortex search service _CA_TEST_NAME
    on TEST_NAME
    warehouse = AI_WH
    target_lag = '12 hour'
    embedding_model = 'snowflake-arctic-embed-l-v2.0'
as (
    select distinct TEST_NAME from LAB_RESULTS
);

In [None]:
-- Search service over the distinct EVENT_NAME values in ADVERSE_EVENTS,
-- refreshed on AI_WH with a 12-hour target lag.
create or replace cortex search service _CA_EVENT_NAME
    on EVENT_NAME
    warehouse = AI_WH
    target_lag = '12 hour'
    embedding_model = 'snowflake-arctic-embed-l-v2.0'
as (
    select distinct EVENT_NAME from ADVERSE_EVENTS
);

# 3. Create Semantic View

In [None]:
create or replace semantic view AI_DEVELOPMENT.SI_CLINICAL_TRIAL.CLINICAL_TRIAL_DATA_MODEL tables (
    ADVERSE_EVENTS with synonyms =(
        'adverse events',
        'side effects',
        'clinical trial safety data',
        'patient adverse reactions',
        'trial safety records',
        'adverse reactions',
        'safety incidents',
        'patient safety data'
    ) comment = 'A fact table that records all adverse events experienced by patients during the clinical trial. This critical safety data tracks the occurrence, severity, and resolution status of side effects potentially related to the study drug.',
    CLINICAL_VISITS with synonyms =(
        'patient_visits',
        'clinical_encounters',
        'medical_visits',
        'patient_encounters',
        'healthcare_visits',
        'clinical_appointments'
    ) comment = 'A fact table that captures vital signs and clinical observations recorded during each patient visit. This includes blood pressure, heart rate, weight measurements, and clinical notes from healthcare providers.',
    LAB_RESULTS with synonyms =(
        'lab results',
        'test results',
        'clinical trial data',
        'patient lab data',
        'biomarker results',
        'safety parameter data',
        'laboratory test data',
        'patient test results',
        'clinical trial lab results'
    ) comment = 'A fact table that stores all laboratory test results collected throughout the clinical trial. This includes various biomarkers and safety parameters measured at regular intervals to monitor patient health and drug effects.',
    PATIENTS primary key (PATIENT_ID) with synonyms =(
        'patients',
        'patient_info',
        'patient_demographics',
        'patient_data',
        'clinical_trial_participants',
        'trial_enrollees',
        'patient_enrollees',
        'study_participants'
    ) comment = 'A dimension table containing demographic and baseline information for all patients enrolled in the DrugX clinical trial. This table provides the core patient context for analyzing treatment outcomes, adverse events, and clinical measurements.',
    TREATMENTS primary key (PATIENT_ID) with synonyms =(
        'treatment_assignments',
        'patient_treatments',
        'clinical_trial_treatments',
        'treatment_groups',
        'patient_medication',
        'study_drug_assignments'
    ) comment = 'A dimension table that records the treatment assignment for each patient in the clinical trial. It specifies which treatment group (DrugX Low Dose, DrugX High Dose, or Placebo) each patient was assigned to, along with dosage information and treatment duration.'
) relationships (
    PATIENTS_X_ADVERSE_EVENTS as ADVERSE_EVENTS(PATIENT_ID) references PATIENTS(PATIENT_ID),
    PATIENTS_X_CLINICAL_VISITS as CLINICAL_VISITS(PATIENT_ID) references PATIENTS(PATIENT_ID),
    PATIENTS_X_LAB_RESULTS as LAB_RESULTS(PATIENT_ID) references PATIENTS(PATIENT_ID),
    PATIENTS_X_TREATMENTS as TREATMENTS(PATIENT_ID) references PATIENTS(PATIENT_ID)
) facts (
    PUBLIC CLINICAL_VISITS.DIASTOLIC_BP as DIASTOLIC_BP with synonyms =(
        'diastolic_blood_pressure',
        'diastolic_pressure',
        'resting_diastolic_pressure',
        'minimum_blood_pressure',
        'low_blood_pressure',
        'diastolic_blood_pressure_reading'
    ) comment = 'The diastolic blood pressure measurement in mmHg recorded during the visit.',
    PUBLIC CLINICAL_VISITS.HEART_RATE as HEART_RATE with synonyms =(
        'heart_rate_bpm',
        'pulse_rate',
        'beats_per_minute',
        'bpm',
        'cardiac_rate',
        'heart_beat_rate'
    ) comment = 'The heart rate measurement in beats per minute recorded during the visit.',
    PUBLIC CLINICAL_VISITS.SYSTOLIC_BP as SYSTOLIC_BP with synonyms =(
        'top_number',
        'systolic_pressure',
        'systolic_reading',
        'blood_pressure_systolic',
        'systolic_blood_pressure_value'
    ) comment = 'The systolic blood pressure measurement in mmHg recorded during the visit.',
    PUBLIC CLINICAL_VISITS.WEIGHT_KG as WEIGHT_KG with synonyms =(
        'body_weight',
        'patient_weight',
        'weight_in_kg',
        'kilograms',
        'mass_in_kg',
        'patient_mass'
    ) comment = 'The patient''s weight in kilograms recorded during the visit.',
    PUBLIC LAB_RESULTS.TEST_VALUE as TEST_VALUE with synonyms =(
        'lab_result',
        'test_result',
        'measurement',
        'value',
        'score',
        'reading',
        'numerical_result',
        'test_score',
        'result_value'
    ) comment = 'The numerical result of the laboratory test.',
    PUBLIC PATIENTS.AGE as AGE with synonyms =(
        'years_old',
        'years_of_age',
        'patient_age',
        'age_at_enrollment',
        'years_since_birth',
        'age_in_years'
    ) comment = 'The patient''s age at enrollment, calculated from date of birth and used for age-stratified analyses.',
    PUBLIC TREATMENTS.DOSAGE_MG as DOSAGE_MG with synonyms =(
        'daily_dosage',
        'milligram_dose',
        'medication_strength',
        'prescribed_amount',
        'mg_per_day',
        'treatment_dose'
    ) comment = 'The daily dosage of the study drug in milligrams (100mg for low dose, 300mg for high dose, 0mg for placebo).'
) dimensions (
    PUBLIC ADVERSE_EVENTS.EVENT_DATE as EVENT_DATE with synonyms =(
        'occurrence_date',
        'event_occurrence',
        'adverse_event_date',
        'date_reported',
        'incident_date',
        'event_timestamp'
    ) comment = 'The date when the adverse event occurred or was first reported.',
    PUBLIC ADVERSE_EVENTS.EVENT_NAME as EVENT_NAME comment = 'The name or description of the adverse event (e.g., ''nausea'', ''fatigue'', ''headache'').',
    PUBLIC ADVERSE_EVENTS.PATIENT_ID as PATIENT_ID with synonyms =(
        'patient_key',
        'patient_identifier',
        'subject_id',
        'participant_id',
        'individual_id'
    ) comment = 'The foreign key linking to PATIENTS table, identifying which patient experienced the adverse event.',
    PUBLIC ADVERSE_EVENTS.RESOLVED as RESOLVED with synonyms =(
        'resolved_status',
        'is_resolved',
        'resolution_status',
        'resolved_indicator',
        'closed',
        'completed',
        'settled',
        'cleared'
    ) comment = 'Boolean indicator of whether the adverse event has been resolved (TRUE) or is ongoing (FALSE).',
    PUBLIC ADVERSE_EVENTS.SEVERITY as SEVERITY with synonyms =(
        'intensity',
        'level',
        'degree',
        'magnitude',
        'seriousness',
        'criticality',
        'impact',
        'classification'
    ) comment = 'The severity classification of the adverse event: ''Mild'', ''Moderate'', or ''Severe''.',
    PUBLIC CLINICAL_VISITS.NOTES as NOTES with synonyms =(
        'clinical_notes',
        'observations',
        'comments',
        'remarks',
        'medical_notes',
        'visit_summary',
        'healthcare_provider_comments',
        'patient_visit_notes'
    ) comment = 'Clinical notes and observations recorded by healthcare providers during the visit.',
    PUBLIC CLINICAL_VISITS.PATIENT_ID as PATIENT_ID with synonyms =(
        'patient_key',
        'patient_identifier',
        'patient_code',
        'subject_id',
        'participant_id',
        'individual_id'
    ) comment = 'The foreign key linking to PATIENTS table, identifying which patient the visit data belongs to.',
    PUBLIC CLINICAL_VISITS.VISIT_DATE as VISIT_DATE with synonyms =(
        'visit_timestamp',
        'appointment_date',
        'clinical_visit_date',
        'patient_visit_date',
        'encounter_date',
        'visit_day',
        'admission_date'
    ) comment = 'The date when the clinical visit occurred.',
    PUBLIC CLINICAL_VISITS.VISIT_TYPE as VISIT_TYPE with synonyms =(
        'visit_category',
        'visit_purpose',
        'appointment_type',
        'clinical_visit_reason',
        'visit_description',
        'visit_classification'
    ) comment = 'The type of clinical visit (typically ''Scheduled Follow-up'' for regular monitoring visits).',
    PUBLIC LAB_RESULTS.NORMAL_RANGE as NORMAL_RANGE with synonyms =(
        'reference_range',
        'normal_values',
        'expected_range',
        'standard_range',
        'typical_range',
        'usual_range',
        'common_range',
        'acceptable_range',
        'healthy_range'
    ) comment = 'The normal reference range for this test, used to identify abnormal values.',
    PUBLIC LAB_RESULTS.PATIENT_ID as PATIENT_ID with synonyms =(
        'patient_key',
        'patient_identifier',
        'subject_id',
        'participant_id',
        'individual_id',
        'person_id'
    ) comment = 'The foreign key linking to PATIENTS table, identifying which patient the lab result belongs to.',
    PUBLIC LAB_RESULTS.TEST_DATE as TEST_DATE with synonyms =(
        'lab_test_date',
        'test_performed_date',
        'date_of_test',
        'test_date_recorded',
        'lab_result_date',
        'sample_collection_date'
    ) comment = 'The date when the laboratory test was performed.',
    PUBLIC LAB_RESULTS.TEST_NAME as TEST_NAME with synonyms =(
        'lab_test',
        'test_type',
        'test_description',
        'laboratory_test_name',
        'test_label',
        'assay_name',
        'diagnostic_test',
        'medical_test_name'
    ) comment = 'The name of the laboratory test performed (e.g., ''White Blood Cell Count'', ''Liver Function Test'').',
    PUBLIC LAB_RESULTS.UNIT as UNIT with synonyms =(
        'unit_of_measurement',
        'measurement_unit',
        'test_unit',
        'unit_value',
        'measurement_scale',
        'unit_type'
    ) comment = 'The unit of measurement for the test result (e.g., ''10^3/ŒºL'', ''U/L'', ''mg/dL'').',
    PUBLIC PATIENTS.COUNTRY as COUNTRY with synonyms =(
        'nation',
        'state',
        'territory',
        'land',
        'region',
        'geographical_area',
        'location',
        'nationality',
        'place_of_origin'
    ) comment = 'The country where the patient is enrolled, enabling geographic analysis of treatment outcomes and regulatory compliance.',
    PUBLIC PATIENTS.DATE_OF_BIRTH as DATE_OF_BIRTH with synonyms =(
        'birth_date',
        'date_of_birth',
        'dob',
        'birthdate',
        'birthday'
    ) comment = 'The patient''s date of birth, used to calculate age and ensure appropriate dosing and safety monitoring.',
    PUBLIC PATIENTS.FIRST_NAME as FIRST_NAME with synonyms =(
        'given_name',
        'forename',
        'personal_name',
        'christian_name',
        'personal_identifier'
    ) comment = 'The first name of the patient, used for identification and communication purposes.',
    PUBLIC PATIENTS.GENDER as GENDER with synonyms =(
        'sex',
        'male_female',
        'patient_gender',
        'demographic_gender',
        'gender_identity'
    ) comment = 'The patient''s gender (Male or Female), important for analyzing gender-specific drug effects and adverse events.',
    PUBLIC PATIENTS.LAST_NAME as LAST_NAME with synonyms =(
        'surname',
        'family_name',
        'second_name',
        'patronymic',
        'full_name',
        'name'
    ) comment = 'The last name of the patient, used for identification and communication purposes.',
    PUBLIC PATIENTS.PATIENT_ID as PATIENT_ID with synonyms =(
        'patient_key',
        'unique_patient_identifier',
        'subject_id',
        'participant_id',
        'patient_code'
    ) comment = 'The primary key and unique identifier for each patient enrolled in the clinical trial (format: P0001, P0002, etc.).',
    PUBLIC TREATMENTS.END_DATE as END_DATE with synonyms =(
        'planned_end_date',
        'treatment_end_date',
        'end_of_treatment',
        'planned_completion_date',
        'treatment_completion_date',
        'termination_date'
    ) comment = 'The planned end date for the patient''s treatment period.',
    PUBLIC TREATMENTS.PATIENT_ID as PATIENT_ID with synonyms =(
        'patient_key',
        'patient_identifier',
        'subject_id',
        'participant_id',
        'individual_id'
    ) comment = 'The foreign key linking to PATIENTS table, identifying which patient received this treatment assignment.',
    PUBLIC TREATMENTS.START_DATE as START_DATE with synonyms =(
        'treatment_initiation_date',
        'start_of_treatment',
        'treatment_begin_date',
        'initiation_date',
        'start_date_of_treatment',
        'treatment_commencement_date'
    ) comment = 'The date when the patient began taking the assigned treatment.',
    PUBLIC TREATMENTS.TREATMENT_GROUP as TREATMENT_GROUP with synonyms =(
        'treatment_arm',
        'treatment_category',
        'intervention_group',
        'study_group',
        'treatment_assignment',
        'experimental_group',
        'control_group',
        'treatment_type'
    ) comment = 'The treatment arm assigned to the patient: ''DrugX - Low Dose'', ''DrugX - High Dose'', or ''Placebo''.'
) comment = 'This semantic model is a comprehensive data platform for a clinical trial, designed to provide a unified view of patient information. It combines data on **patient demographics**, **treatment assignments**, **clinical measurements** from visits, **laboratory test results**, and **adverse events** to enable detailed safety and efficacy analysis. The model allows researchers and managers to easily track side effects, evaluate drug effectiveness, and generate reports for regulatory purposes.' with extension (
    CA = '{"tables":[{"name":"ADVERSE_EVENTS","dimensions":[{"name":"EVENT_NAME","cortex_search_service":{"database":"AI_DEVELOPMENT","schema":"SI_CLINICAL_TRIAL","service":"_CA_EVENT_NAME"}},{"name":"PATIENT_ID","sample_values":["P0003","P0274","P0178"]},{"name":"RESOLVED","sample_values":["FALSE","TRUE"]},{"name":"SEVERITY","sample_values":["Moderate","Mild","Severe"]}],"time_dimensions":[{"name":"EVENT_DATE","sample_values":["2023-11-22","2023-08-19","2023-01-15"]}]},{"name":"CLINICAL_VISITS","dimensions":[{"name":"NOTES","sample_values":["Minor concerns noted","Patient doing well"]},{"name":"PATIENT_ID","sample_values":["P0003","P0002","P0001"]},{"name":"VISIT_TYPE","sample_values":["Scheduled Follow-up"]}],"facts":[{"name":"DIASTOLIC_BP","sample_values":["79","108","81"]},{"name":"HEART_RATE","sample_values":["72","63.466372484","83.922626636"]},{"name":"SYSTOLIC_BP","sample_values":["90","132","117"]},{"name":"WEIGHT_KG","sample_values":["99.5","51.8","77.6"]}],"time_dimensions":[{"name":"VISIT_DATE","sample_values":["2023-01-31","2023-01-01","2023-06-30"]}]},{"name":"LAB_RESULTS","dimensions":[{"name":"NORMAL_RANGE","sample_values":["0.6-1.4","4.0-10.0","9.0-41.0"]},{"name":"PATIENT_ID","sample_values":["P0003","P0002","P0001"]},{"name":"TEST_NAME","sample_values":["White Blood Cell Count","Liver Function Test","Kidney Function Test"],"cortex_search_service":{"database":"AI_DEVELOPMENT","schema":"SI_CLINICAL_TRIAL","service":"_CA_TEST_NAME"}},{"name":"UNIT","sample_values":["mg/dL","10^3/ŒºL","U/L"]}],"facts":[{"name":"TEST_VALUE","sample_values":["180.78","34.63","191.26"]}],"time_dimensions":[{"name":"TEST_DATE","sample_values":["2023-01-31","2023-01-01","2023-06-30"]}]},{"name":"PATIENTS","dimensions":[{"name":"COUNTRY","sample_values":["Iran","Bosnia and Herzegovina","Puerto 
Rico"],"cortex_search_service":{"database":"AI_DEVELOPMENT","schema":"SI_CLINICAL_TRIAL","service":"_CA_COUNTRY"}},{"name":"FIRST_NAME","sample_values":["Alyssa","David","Mark"]},{"name":"GENDER","sample_values":["Male","Female"]},{"name":"LAST_NAME","sample_values":["Mcclain","Morris","Mccann"]},{"name":"PATIENT_ID","sample_values":["P0196","P0003","P0178"]}],"facts":[{"name":"AGE","sample_values":["54","42","52"]}],"time_dimensions":[{"name":"DATE_OF_BIRTH","sample_values":["1997-09-08","1973-09-08","1985-09-08"]}]},{"name":"TREATMENTS","dimensions":[{"name":"PATIENT_ID","sample_values":["P0196","P0003","P0178"]},{"name":"TREATMENT_GROUP","sample_values":["DrugX - Low Dose","Placebo","DrugX - High Dose"]}],"facts":[{"name":"DOSAGE_MG","sample_values":["300","0","100"]}],"time_dimensions":[{"name":"END_DATE","sample_values":["2024-01-01"]},{"name":"START_DATE","sample_values":["2023-01-04","2023-01-15","2023-01-27"]}]}],"relationships":[{"name":"patients_x_treatments"},{"name":"patients_x_lab_results"},{"name":"patients_x_clinical_visits"},{"name":"patients_x_adverse_events"}],"verified_queries":[{"name":"How many patients are in each treatment group?","question":"How many patients are in each treatment group?","sql":"SELECT\\n  treatment_group,\\n  COUNT(patient_id) AS patient_count\\nFROM\\n  treatments\\nGROUP BY\\n  treatment_group\\nORDER BY\\n  patient_count DESC NULLS LAST","use_as_onboarding_question":false,"verified_by":"Michael Gorkow","verified_at":1757295538},{"name":"What is the gender distribution of patients in the trial?","question":"What is the gender distribution of patients in the trial?","sql":"SELECT\\n  gender,\\n  COUNT(patient_id) AS patient_count\\nFROM\\n  patients\\nGROUP BY\\n  gender\\nORDER BY\\n  patient_count DESC NULLS LAST","use_as_onboarding_question":false,"verified_by":"Michael Gorkow","verified_at":1757295586},{"name":"What types of lab tests are being performed in this trial?","question":"What types of lab tests are being 
performed in this trial?","sql":"SELECT\\n  DISTINCT test_name\\nFROM\\n  lab_results\\nORDER BY\\n  test_name","use_as_onboarding_question":false,"verified_by":"Michael Gorkow","verified_at":1757295606},{"name":"What is the rate of adverse events in each treatment group?","question":"What is the rate of adverse events in each treatment group?","sql":"WITH adverse_event_counts AS (\\n  SELECT\\n    ae.patient_id,\\n    COUNT(*) AS event_count\\n  FROM\\n    adverse_events AS ae\\n  GROUP BY\\n    ae.patient_id\\n),\\ntreatment_adverse_events AS (\\n  SELECT\\n    t.treatment_group,\\n    t.patient_id,\\n    COALESCE(aec.event_count, 0) AS event_count\\n  FROM\\n    treatments AS t\\n    LEFT JOIN adverse_event_counts AS aec ON t.patient_id = aec.patient_id\\n)\\nSELECT\\n  treatment_group,\\n  COUNT(patient_id) AS total_patients,\\n  SUM(event_count) AS total_adverse_events,\\n  SUM(event_count) / NULLIF(COUNT(patient_id), 0) AS adverse_event_rate\\nFROM\\n  treatment_adverse_events\\nGROUP BY\\n  treatment_group\\nORDER BY\\n  adverse_event_rate DESC NULLS LAST","use_as_onboarding_question":false,"verified_by":"Michael Gorkow","verified_at":1757295636},{"name":"What are the top 5 most common adverse events across all patients?","question":"What are the top 5 most common adverse events across all patients?","sql":"SELECT\\n  event_name,\\n  COUNT(*) AS event_count\\nFROM\\n  adverse_events\\nGROUP BY\\n  event_name\\nORDER BY\\n  event_count DESC NULLS LAST\\nLIMIT\\n  5","use_as_onboarding_question":false,"verified_by":"Michael Gorkow","verified_at":1757295668}],"custom_instructions":"This semantic model designed for comprehensive analysis of data from a drug''s clinical trial. It integrates patient demographics, treatment assignments, clinical visit measurements, laboratory results, and adverse event reports to provide a complete view of a patient''s journey and outcomes during the study. 
This model allows for detailed safety and efficacy analysis, such as evaluating the frequency of side effects and tracking changes in patient health metrics over time, and can be used to answer critical questions about the trial''s results.\\n\\nKey Components\\n\\nThe model is structured around a central Patients table which acts as the core demographic source. It is linked to four key data tables:\\n\\n* Adverse Events: Records all side effects and safety incidents reported by patients. This is crucial for assessing the drug''s safety profile. The table includes details like the event name, its severity, and whether it has been resolved. \\n* Clinical Visits: Captures routine observations and vital signs from patient appointments, such as heart rate, blood pressure, and weight. It also includes clinical notes from healthcare providers.\\n* Lab Results: Contains a complete record of all laboratory tests performed on patients. This is essential for monitoring biomarkers and identifying any abnormal values that could be related to the study drug.\\n* Treatments: Details the specific treatment each patient was assigned, including the treatment group (e.g., DrugX - High Dose, Placebo) and the daily dosage.\\n\\nBusiness Use Cases\\n\\nThis model is a powerful tool for researchers, data scientists, and clinical trial managers. 
It is designed to support a wide range of analytical needs, including:\\n\\n* Safety Monitoring: Track and report on adverse events, identifying common side effects and comparing their frequency and severity across different treatment groups.\\n* Efficacy Analysis: Correlate treatment assignments with changes in clinical measurements (e.g., blood pressure) and laboratory results to determine the drug''s effectiveness.\\n* Patient Demographics: Analyze patient populations by gender, age, and country to understand how different groups respond to the treatment.\\n* Regulatory Reporting: Generate accurate and verifiable reports for regulatory bodies by having a single, unified source of clinical trial data.\\n* Querying: Provides the ability to ask natural language questions such as \\"What is the rate of adverse events in each treatment group?\\" or \\"What are the top 5 most common adverse events?\\"."}'
);

# 4. Create Search Service for Clinical Trial Reports

In [None]:
# Load the DrugX reports from local Markdown files and persist them to the
# TRIAL_DOCUMENTS table that the Cortex Search service below is built on.
# Each document becomes one row: DOCUMENT_TITLE, DOCUMENT_CONTENT.
def read_document(path: str) -> str:
    """Return the full text of the UTF-8 file at ``path``, closing the handle afterwards."""
    # Explicit encoding: the documents contain non-ASCII characters, and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()

doc1 = read_document('documents/clinical_trial_description.md')
doc2 = read_document('documents/clinical_trial_results.md')
doc3 = read_document('documents/DrugX.md')

docs_df = pd.DataFrame(
    [
        ['DrugX Trial Description', doc1],
        ['DrugX Trial Results', doc2],
        ['DrugX PRESCRIBING INFORMATION', doc3],
    ],
    columns=['DOCUMENT_TITLE', 'DOCUMENT_CONTENT'],
)

# overwrite=True replaces the table contents, so re-running this cell does not
# duplicate rows.
session.write_pandas(
    docs_df,
    table_name='TRIAL_DOCUMENTS',
    auto_create_table=True,
    use_logical_type=True,
    overwrite=True,
)

In [None]:
-- Cortex Search service over the DrugX documents loaded into TRIAL_DOCUMENTS.
-- DOCUMENT_CONTENT is the searchable text column; DOCUMENT_TITLE is exposed
-- as an attribute column alongside it.
CREATE OR REPLACE CORTEX SEARCH SERVICE CLINICAL_TRIAL_REPORTS
  ON DOCUMENT_CONTENT
  ATTRIBUTES DOCUMENT_TITLE
  WAREHOUSE = AI_WH
  TARGET_LAG = '1 hour'
  EMBEDDING_MODEL = 'snowflake-arctic-embed-l-v2.0'
AS (
  -- Project the indexed and attribute columns explicitly instead of SELECT *,
  -- so columns added to TRIAL_DOCUMENTS later are not silently pulled into
  -- the service on its next refresh.
  SELECT
      DOCUMENT_TITLE,
      DOCUMENT_CONTENT
  FROM TRIAL_DOCUMENTS
);

# 5. Create Custom Tools

In [None]:
-- Python UDF used as a custom agent tool: looks up studies with the pytrials
-- ClinicalTrials client and returns a curated subset of study fields as a
-- JSON string. Outbound network access comes from ai_external_access_integration.
CREATE OR REPLACE FUNCTION find_clinical_trials(search_expr TEXT, max_studies INT)
RETURNS TEXT
LANGUAGE PYTHON
RUNTIME_VERSION = '3.11'
PACKAGES = ('pytrials','pandas')
ARTIFACT_REPOSITORY = snowflake.snowpark.pypi_shared_repository
EXTERNAL_ACCESS_INTEGRATIONS = (ai_external_access_integration)
HANDLER = 'find_clinical_trials'
AS
$$
from pytrials.client import ClinicalTrials
import pandas as pd

# Module-level so the client is built once when the handler module loads,
# not on every call.
ct = ClinicalTrials()

# Study fields returned to the caller, in output order.
RESULT_COLUMNS = [
    'NCT Number', 'Study Title', 'Study Status', 'Brief Summary',
    'Interventions', 'Primary Outcome Measures', 'Secondary Outcome Measures',
    'Sponsor', 'Sex', 'Age', 'Phases', 'Enrollment', 'Funder Type',
    'Study Design', 'Start Date', 'Completion Date', 'Locations',
]


def find_clinical_trials(search_expr: str, max_studies: int) -> str:
    """Search for studies matching ``search_expr`` and return them as JSON.

    Args:
        search_expr: Free-text search expression passed to pytrials.
        max_studies: Maximum number of studies to request.

    Returns:
        A JSON object string keyed by row position (``orient='index'``), each
        value holding the RESULT_COLUMNS fields of one study. A response with
        no rows yields ``'{}'``.
    """
    # With fmt="csv" the response is a list of rows whose first row is the header.
    trials = ct.get_full_studies(
        search_expr=search_expr,
        max_studies=max_studies,
        fmt="csv",
    )

    # An empty response has no header row; trials[0] would raise IndexError.
    if not trials:
        return '{}'

    header, rows = trials[0], trials[1:]
    result_df = pd.DataFrame.from_records(rows, columns=header)[RESULT_COLUMNS]
    return result_df.to_json(orient='index')
$$;

-- Smoke test: fetch five COVID-related studies.
SELECT find_clinical_trials('covid', 5);

# 6. Create the Agent

In [None]:
-- Clinical trial agent. Its tools point at the objects created in this
-- notebook in AI_DEVELOPMENT.SI_CLINICAL_TRIAL:
--   * Clinical-Trial-Data     -> the clinical trial semantic view (Cortex Analyst)
--   * Clinical-Trial-Reports  -> the CLINICAL_TRIAL_REPORTS Cortex Search service
--   * find-clinical-trials    -> the FIND_CLINICAL_TRIALS Python UDF
-- plus the shared AI_DEVELOPMENT.PUBLIC.SEND_MAIL procedure.
-- TODO(review): the semantic view identifier under tool_resources
-- "Clinical-Trial-Data" must match the name used in the CREATE SEMANTIC VIEW
-- statement earlier in this notebook -- confirm before running.
CREATE OR REPLACE AGENT SNOWFLAKE_INTELLIGENCE.AGENTS.CLINICAL_TRIAL_AGENT
profile='{"display_name":"Clinical Trial Agent","avatar":"AiIcon","color":"var(--chartDim_3-x11sbcwy)"}'
comment='This agent provides insights into clinical trial data.'
FROM SPECIFICATION 
$$
{
  "models": {
    "orchestration": "claude-4-sonnet"
  },
  "instructions": {
    "orchestration": "When sending emails, make sure to provide well formatted content using html.\nUse the Clinical-Trial-Data tool for questions about patients, treatment assignments, clinical visits, lab results and adverse events in the DrugX trial.\nUse the Clinical-Trial-Reports tool for questions about the DrugX trial description, the trial results report and the DrugX prescribing information.\nUse the find-clinical-trials tool only when the user asks about other publicly registered clinical trials."
  },
  "tools": [
    {
      "tool_spec": {
        "type": "cortex_analyst_text_to_sql",
        "name": "Clinical-Trial-Data",
        "description": "This semantic model provides a unified view of the DrugX clinical trial. It links patient demographics to treatment assignments, clinical visit measurements, laboratory test results and adverse events to enable safety and efficacy analysis.\n \nThe model is built on five tables:\n \n * PATIENTS: Patient demographics such as first and last name, gender, date of birth, age and country.\n * TREATMENTS: The treatment group each patient was assigned to (DrugX - Low Dose, DrugX - High Dose, Placebo), the daily dosage in mg and the treatment start and end dates.\n * CLINICAL_VISITS: Vital signs recorded at each visit (heart rate, systolic and diastolic blood pressure, weight), the visit type and clinical notes.\n * LAB_RESULTS: Laboratory test results (e.g. White Blood Cell Count, Liver Function Test, Kidney Function Test) with their unit and normal reference range.\n * ADVERSE_EVENTS: Reported adverse events with event name, date, severity (Mild, Moderate, Severe) and whether they are resolved.\n \nThis model helps you answer questions such as:\n \n * How many patients are in each treatment group?\n * What is the rate of adverse events in each treatment group?\n * What are the most common adverse events and how severe are they?\n * How do lab results and vital signs change over time by treatment group?"
      }
    },
    {
      "tool_spec": {
        "type": "cortex_search",
        "name": "Clinical-Trial-Reports",
        "description": "This tool provides access to the DrugX trial documents: the trial description, the trial results report and the DrugX prescribing information."
      }
    },
    {
      "tool_spec": {
        "type": "generic",
        "name": "find-clinical-trials",
        "description": "Use this tool to search the public ClinicalTrials.gov registry for studies matching a search expression. Returns a JSON object with each study's NCT number, title, status, summary, interventions, outcome measures, sponsor, sex, age, phases, enrollment, funder type, study design, start and completion dates and locations.",
        "input_schema": {
          "type": "object",
          "properties": {
            "max_studies": {
              "description": "Maximum number of studies to return.",
              "type": "integer"
            },
            "search_expr": {
              "description": "Search expression, e.g. a condition, drug or intervention name.",
              "type": "string"
            }
          },
          "required": [
            "max_studies",
            "search_expr"
          ]
        }
      }
    },
    {
      "tool_spec": {
        "type": "generic",
        "name": "send-email",
        "description": "Use this tool to send emails.",
        "input_schema": {
          "type": "object",
          "properties": {
            "recipient": {
              "description": "The email address of the recipient.",
              "type": "string"
            },
            "subject": {
              "description": "The subject of the email.",
              "type": "string"
            },
            "text": {
              "description": "The text of the email. Supports html code for formatted emails.",
              "type": "string"
            }
          },
          "required": [
            "recipient",
            "subject",
            "text"
          ]
        }
      }
    }
  ],
  "tool_resources": {
    "Clinical-Trial-Data": {
      "semantic_view": "AI_DEVELOPMENT.SI_CLINICAL_TRIAL.CLINICAL_TRIAL_DATA_MODEL"
    },
    "Clinical-Trial-Reports": {
      "max_results": 3,
      "name": "AI_DEVELOPMENT.SI_CLINICAL_TRIAL.CLINICAL_TRIAL_REPORTS",
      "title_column": "DOCUMENT_TITLE"
    },
    "find-clinical-trials": {
      "execution_environment": {
        "type": "warehouse",
        "warehouse": "AI_WH"
      },
      "identifier": "AI_DEVELOPMENT.SI_CLINICAL_TRIAL.FIND_CLINICAL_TRIALS",
      "name": "FIND_CLINICAL_TRIALS(VARCHAR, NUMBER)",
      "type": "function"
    },
    "send-email": {
      "execution_environment": {
        "type": "warehouse",
        "warehouse": "AI_WH"
      },
      "identifier": "AI_DEVELOPMENT.PUBLIC.SEND_MAIL",
      "name": "SEND_MAIL(VARCHAR, VARCHAR, VARCHAR)",
      "type": "procedure"
    }
  }
}
$$;