<a href="https://colab.research.google.com/github/gcosma/cormtest/blob/main/LANCET_Personalisation_Paper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tool for analysing condition progression in patients with Multiple Long Term Conditions and Intellectual Disabilities

In [1]:
#@title Step 1: You need to make some installations. This is very simple. On the next line you will see a little triangle that's in a little black circle on the left of this box. Click on it. When the triangle turns to a green tick you will see a message that says "Done! Now click on Step 2".
# Install necessary packages
!pip -q install pandas numpy networkx matplotlib pyvis ipywidgets ipython plotly seaborn
# Display a message after installation
print("Done! Now click on Step 2.")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m756.0/756.0 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hDone! Now click on Step 2.


In [2]:
#@title Step 2: The arrow on the left just before the start of this sentence hides and shows the code. If you click it accidentally, just click it again and the code will be hidden. Below this sentence there is another arrow in a circle. To execute the code, click on the arrow that's in the circle.
import pandas as pd
import numpy as np
from itertools import combinations
import os
from google.colab import files
from IPython.display import HTML, display
import base64
import io
import ipywidgets as widgets

# Global variables
# Module-level state shared by the functions in this cell. initialize_data()
# fills the first four from the uploaded CSV; analyze_condition_combinations()
# stores its latest output in results_df.
global_data = None  # DataFrame read from the uploaded CSV
total_patients_in_group = 0  # TotalPatientsInGroup value from the CSV's first row
gender = ''  # 'Female' / 'Male' / 'Unknown Gender', inferred from the file name
age_group = ''  # '<45' / '45-64' / '65+' / 'Unknown Age Group', inferred from the file name
results_df = None  # latest result of analyze_condition_combinations()

# Condition categories dictionary
# Maps each long-term condition name to its body-system category.
# NOTE(review): not referenced by the functions in this cell — confirm whether it is still needed here.
condition_categories = {
    "Anaemia": "Blood",
    "Cardiac Arrhythmias": "Circulatory",
    "Coronary Heart Disease": "Circulatory",
    "Heart Failure": "Circulatory",
    "Hypertension": "Circulatory",
    "Peripheral Vascular Disease": "Circulatory",
    "Stroke": "Circulatory",
    "Barretts Oesophagus": "Digestive",
    "Chronic Constipation": "Digestive",
    "Chronic Diarrhoea": "Digestive",
    "Cirrhosis": "Digestive",
    "Dysphagia": "Digestive",
    "Inflammatory Bowel Disease": "Digestive",
    "Reflux Disorders": "Digestive",
    "Hearing Loss": "Ear",
    "Addisons Disease": "Endocrine",
    "Diabetes": "Endocrine",
    "Polycystic Ovary Syndrome": "Endocrine",
    "Thyroid Disorders": "Endocrine",
    "Visual Impairment": "Eye",
    "Chronic Kidney Disease": "Genitourinary",
    # Both capitalisations are listed — presumably both occur in the source files; confirm.
    "Menopausal And Perimenopausal": "Genitourinary",
    "Menopausal and Perimenopausal": "Genitourinary",
    "Dementia": "Mental",
    "Mental Illness": "Mental",
    "Tourette": "Mental",
    "Chronic Arthritis": "Musculoskeletal",
    "Chronic Pain Conditions": "Musculoskeletal",
    "Osteoporosis": "Musculoskeletal",
    "Cancer": "Neoplasms",
    "Cerebral Palsy": "Nervous",
    "Epilepsy": "Nervous",
    "Insomnia": "Nervous",
    "Multiple Sclerosis": "Nervous",
    "Neuropathic Pain": "Nervous",
    "Parkinsons": "Nervous",
    "Bronchiectasis": "Respiratory",
    "Chronic Airway Diseases": "Respiratory",
    "Chronic Pneumonia": "Respiratory",
    "Interstitial Lung Disease": "Respiratory",
    "Psoriasis": "Skin"
}

def load_file():
    """Ask the user to upload a CSV file and parse it.

    The cohort's gender and age group are inferred from tokens in the uploaded
    file name ('females'/'males' and 'below45'/'45to64'/'65plus').

    Returns:
        tuple: ``(data, total_patients, gender, age_group)`` on success, where
        ``total_patients`` is the first ``TotalPatientsInGroup`` value; or
        ``(None, None, None, None)`` when nothing was uploaded or parsing failed.
    """
    failure = (None, None, None, None)

    print("Please upload your CSV file.")
    uploaded = files.upload()
    if not uploaded:
        print("No file uploaded. Exiting.")
        return failure

    file_name, file_content = next(iter(uploaded.items()))

    try:
        data = pd.read_csv(io.BytesIO(file_content))
        total_patients = data['TotalPatientsInGroup'].iloc[0]

        print(f"Uploaded file: {file_name}")

        lowered_name = file_name.lower()
        # 'females' is tested before 'males' because 'males' is a substring of 'females'.
        gender_tokens = (('females', 'Female'), ('males', 'Male'))
        age_tokens = (('below45', '<45'), ('45to64', '45-64'), ('65plus', '65+'))
        gender = next((label for token, label in gender_tokens if token in lowered_name),
                      'Unknown Gender')
        age_group = next((label for token, label in age_tokens if token in lowered_name),
                         'Unknown Age Group')

        print(f"File loaded successfully. Total patients: {total_patients}")
        print(f"Gender: {gender}, Age Group: {age_group}")
        return data, total_patients, gender, age_group
    except Exception as e:
        print(f"Error loading file: {str(e)}")
        return failure

def initialize_data(data, total_patients, gender_info, age_group_info):
    """Publish a freshly loaded dataset and its cohort metadata to the module globals.

    Args:
        data: The parsed DataFrame (stored as ``global_data``).
        total_patients: Cohort size (stored as ``total_patients_in_group``).
        gender_info: Cohort gender label (stored as ``gender``).
        age_group_info: Cohort age-group label (stored as ``age_group``).
    """
    global global_data, total_patients_in_group, gender, age_group
    global_data, total_patients_in_group = data, total_patients
    gender, age_group = gender_info, age_group_info

def analyze_condition_combinations(min_percentage, min_frequency, min_size=3, max_size=7):
    """Enumerate multi-condition combinations supported by the pairwise data.

    Rows of ``global_data`` are kept when ``Percentage >= min_percentage`` and
    ``PairFrequency >= min_frequency``. Condition names are cleaned by removing
    parenthesised qualifiers and replacing underscores with spaces.

    For every combination of ``min_size`` .. ``max_size`` distinct conditions:
      * the combination's frequency is the smallest pair frequency among its
        member pairs (0 when any pair is absent from the filtered data);
      * prevalence is that frequency as a percentage of ``total_patients_in_group``;
      * the "Total odds ratio" is observed frequency divided by the product of
        each condition's summed pair frequency share, scaled by the group size.
    Combinations with zero prevalence are dropped.

    Args:
        min_percentage: Minimum pair ``Percentage`` for a row to be used.
        min_frequency: Minimum ``PairFrequency`` for a row to be used.
        min_size: Smallest combination size to enumerate (default 3; must be >= 2).
        max_size: Largest combination size to enumerate (default 7).

    Returns:
        pandas.DataFrame sorted by prevalence (descending). It always has the
        result columns and is empty when there is no data or no combination
        qualifies, so callers can safely call ``len()`` on it. The same frame is
        stored in the module-level ``results_df``.

    Raises:
        ValueError: if ``min_size`` is smaller than 2.
    """
    global results_df

    result_columns = ['Combination', 'NumConditions', 'Minimum Pair Frequency',
                      'Prevalence of the combination (%)', 'Total odds ratio']

    if min_size < 2:
        raise ValueError("min_size must be at least 2 so every combination contains a pair")

    if global_data is None or global_data.empty:
        print("No data available. Please load a valid CSV file.")
        results_df = pd.DataFrame(columns=result_columns)
        return results_df

    if pd.isna(total_patients_in_group) or total_patients_in_group <= 0:
        # Prevalence and expected counts divide by the group size.
        print("Total patient count is missing or zero; cannot compute prevalence.")
        results_df = pd.DataFrame(columns=result_columns)
        return results_df

    data = global_data[(global_data['Percentage'] >= min_percentage) &
                       (global_data['PairFrequency'] >= min_frequency)].copy()

    # Clean condition names: drop "(...)" qualifiers, then underscores -> spaces.
    for column in ('ConditionA', 'ConditionB'):
        data[column] = (data[column]
                        .str.replace(r'\s*\([^)]*\)', '', regex=True)
                        .str.replace('_', ' ', regex=False))

    unique_conditions = pd.unique(data[['ConditionA', 'ConditionB']].values.ravel('K'))

    # Pair frequencies are stored under both orderings so lookups are order-independent;
    # per-condition frequency is the sum of the pair frequencies it takes part in.
    pair_frequency_map = {}
    condition_frequency_map = {}

    for _, row in data.iterrows():
        key1 = f"{row['ConditionA']}_{row['ConditionB']}"
        key2 = f"{row['ConditionB']}_{row['ConditionA']}"
        pair_frequency_map[key1] = row['PairFrequency']
        pair_frequency_map[key2] = row['PairFrequency']

        condition_frequency_map[row['ConditionA']] = condition_frequency_map.get(row['ConditionA'], 0) + row['PairFrequency']
        condition_frequency_map[row['ConditionB']] = condition_frequency_map.get(row['ConditionB'], 0) + row['PairFrequency']

    result_data = []

    largest_size = min(max_size, len(unique_conditions))
    for k in range(min_size, largest_size + 1):
        for comb in combinations(unique_conditions, k):
            pair_frequencies = [pair_frequency_map.get(f"{a}_{b}", 0) for a, b in combinations(comb, 2)]
            frequency = min(pair_frequencies)
            prevalence = (frequency / total_patients_in_group) * 100

            observed = frequency
            expected = total_patients_in_group
            for condition in comb:
                expected *= (condition_frequency_map[condition] / total_patients_in_group)
            odds_ratio = observed / expected if expected != 0 else float('inf')

            result_data.append({
                'Combination': ' + '.join(comb),
                'NumConditions': len(comb),
                'Minimum Pair Frequency': frequency,
                'Prevalence of the combination (%)': prevalence,
                'Total odds ratio': odds_ratio
            })

    # Passing explicit columns keeps the frame well-formed even when no combination
    # was generated (e.g. fewer than min_size conditions survived the filters).
    results_df = pd.DataFrame(result_data, columns=result_columns)
    results_df = results_df.sort_values('Prevalence of the combination (%)', ascending=False)
    results_df = results_df[results_df['Prevalence of the combination (%)'] > 0]

    print(f'Analysis complete. {len(results_df)} combinations found.')
    return results_df

def save_results_to_csv(filename="condition_combinations.csv"):
    """Write the latest analysis results to ``filename`` and trigger a browser download.

    Reads the module-level ``results_df``. If it is None or empty, only a message
    is printed and no file is written.
    """
    if results_df is None or results_df.empty:
        print("No results available to save. Please run the analysis first.")
        return

    results_df.to_csv(filename, index=False)
    print(f"Results saved to {filename}")
    files.download(filename)

# Main execution block
# Interactive driver: upload a CSV, repeatedly prompt for filter thresholds,
# show the resulting condition combinations, and optionally save them.
if __name__ == "__main__":
    try:
        # Load the file
        data, total_patients, gender, age_group = load_file()

        if data is not None:
            # Initialize the data
            initialize_data(data, total_patients, gender, age_group)

            # Valid input ranges depend only on the loaded data, so compute them once.
            min_freq_range = (global_data['PairFrequency'].min(), global_data['PairFrequency'].max())
            min_percentage_range = (global_data['Percentage'].min(), global_data['Percentage'].max())

            while True:
                # Get user input for minimum pair frequency
                while True:
                    try:
                        min_frequency = int(input(f"Enter the minimum pair frequency [{min_freq_range[0]}-{min_freq_range[1]}]: "))
                        if min_freq_range[0] <= min_frequency <= min_freq_range[1]:
                            break
                        print(f"Error: Value must be between {min_freq_range[0]} and {min_freq_range[1]}. Please try again.")
                    except ValueError:
                        print("Error: Please enter a valid integer.")

                # Get user input for minimum percentage
                while True:
                    try:
                        min_percentage = float(input(f"Enter the minimum prevalence percentage of a pair (%) [{min_percentage_range[0]:.2f}-{min_percentage_range[1]:.2f}]: "))
                        if min_percentage_range[0] <= min_percentage <= min_percentage_range[1]:
                            break
                        print(f"Error: Value must be between {min_percentage_range[0]:.2f} and {min_percentage_range[1]:.2f}. Please try again.")
                    except ValueError:
                        print("Error: Please enter a valid number.")

                # Analyze condition combinations
                results = analyze_condition_combinations(min_percentage, min_frequency)
                if results is None:
                    # Guard against a None result (no data available) so len() below cannot fail.
                    results = pd.DataFrame()

                # Display results
                print(f"\nTotal number of condition combinations: {len(results)}")
                print("All condition combinations:")

                # Show all rows/columns without truncation. option_context restores the
                # previous settings afterwards (reset_option would force pandas defaults).
                with pd.option_context('display.max_rows', None,
                                       'display.max_columns', None,
                                       'display.width', None,
                                       'display.max_colwidth', None):
                    display(results)

                # Ask if user wants to save results
                save_choice = input("Do you want to save the results to a CSV file? (yes/no): ").strip().lower()
                if save_choice in ('yes', 'y'):
                    save_results_to_csv()

                # Ask if user wants to run again
                run_again = input("Do you want to run the analysis again with different parameters? (yes/no): ").strip().lower()
                if run_again not in ('yes', 'y'):
                    break

            print("Analysis completed. Thank you for using the script!")
        else:
            print("Failed to load data. Please run the script again and upload a valid CSV file.")

    except KeyboardInterrupt:
        print("\nScript execution interrupted by user. Exiting gracefully...")
    except Exception as e:
        print(f"An unexpected error occurred: {str(e)}")
    finally:
        print("Script execution completed.")

print("You can run the cell again to start a new analysis with a different file.")

Please upload your CSV file.


Saving Females_fdr_significant_high_freq_odds_ratio_analysis_below45.csv to Females_fdr_significant_high_freq_odds_ratio_analysis_below45.csv
Uploaded file: Females_fdr_significant_high_freq_odds_ratio_analysis_below45.csv
File loaded successfully. Total patients: 6397
Gender: Female, Age Group: <45
Enter the minimum pair frequency [71-484]: 100
Enter the minimum prevalence percentage of a pair (%) [1.11-7.57]: 2
Analysis complete. 7 combinations found.

Total number of condition combinations: 7
All condition combinations:


Unnamed: 0,Combination,NumConditions,Minimum Pair Frequency,Prevalence of the combination (%),Total odds ratio
290,Chronic Airway Diseases + Insomnia + Reflux Disorders,3,239,3.736126,6.990363
249,Chronic Pain Conditions + Chronic Airway Diseases + Reflux Disorders,3,203,3.173363,7.49397
83,Cerebral Palsy + Dysphagia + Epilepsy,3,139,2.172894,19.137554
245,Chronic Pain Conditions + Chronic Airway Diseases + Insomnia,3,136,2.125997,6.818377
252,Chronic Pain Conditions + Mental Illness + Insomnia,3,136,2.125997,7.520877
262,Chronic Pain Conditions + Insomnia + Reflux Disorders,3,136,2.125997,4.293926
1165,Chronic Pain Conditions + Chronic Airway Diseases + Insomnia + Reflux Disorders,4,136,2.125997,30.184885


Do you want to save the results to a CSV file? (yes/no): no
Do you want to run the analysis again with different parameters? (yes/no): o
Analysis completed. Thank you for using the script!
Script execution completed.
You can run the cell again to start a new analysis with a different file.


In [3]:
# @title Personalised predictions with sensitivity analysis
import pandas as pd
import numpy as np
import math
import random
import io
import base64
from IPython.display import HTML, display
import ipywidgets as widgets
from pyvis.network import Network
from google.colab import files
from matplotlib import patches
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

# Global variables
# Module-level state for this cell (re-initialised when the cell is re-run).
# load_data() fills the first four from the uploaded CSV.
global_data = None  # DataFrame read from the uploaded CSV
total_patients_in_group = 0  # TotalPatientsInGroup value from the CSV's first row
gender = ''  # 'Female' / 'Male' / 'Unknown Gender', inferred from the file name
age_group = ''  # '<45' / '45-64' / '65+' / 'Unknown Age Group', inferred from the file name
results_df = None  # NOTE(review): not assigned by the functions visible in this cell — confirm usage

# Maps each condition name to the body-system category used for colouring and
# for counting cross-system pairs. Lookups elsewhere fall back to 'Other' for
# names not listed here.
# NOTE(review): category labels differ from the earlier cell (e.g. "Cardiovascular"
# vs "Circulatory", "Mental health" vs "Mental"); confirm this split is intended.
condition_categories = {
    "Addisons Disease": "Endocrine",
    "Anaemia": "Blood",
    "Barretts Oesophagus": "Digestive",
    "Bronchiectasis": "Respiratory",
    "Cancer": "Neoplasms",
    "Cardiac Arrhythmias": "Cardiovascular",
    "Cerebral Palsy": "Nervous",
    "Chronic Airway Diseases": "Respiratory",
    "Chronic Arthritis": "Musculoskeletal",
    "Chronic Constipation": "Digestive",
    "Chronic Diarrhoea": "Digestive",
    "Chronic Kidney Disease": "Genitourinary",
    "Chronic Pain Conditions": "Musculoskeletal",
    "Chronic Pneumonia": "Respiratory",
    "Cirrhosis": "Digestive",
    "Coronary Heart Disease": "Cardiovascular",
    "Dementia": "Mental health",
    "Diabetes": "Endocrine",
    "Dysphagia": "Digestive",
    "Epilepsy": "Nervous",
    "Heart Failure": "Cardiovascular",
    "Hearing Loss": "Ear",
    "Hypertension": "Cardiovascular",
    "Inflammatory Bowel Disease": "Digestive",
    "Insomnia": "Nervous",
    "Interstitial Lung Disease": "Respiratory",
    "Mental Illness": "Mental",
    "Menopausal and Perimenopausal": "Genitourinary",
    # Capitalised variant, as listed in the earlier cell's mapping; without it the
    # condition would be categorised as 'Other'.
    "Menopausal And Perimenopausal": "Genitourinary",
    "Multiple Sclerosis": "Nervous",
    "Neuropathic Pain": "Nervous",
    "Osteoporosis": "Musculoskeletal",
    "Parkinsons": "Nervous",
    "Peripheral Vascular Disease": "Circulatory",
    "Polycystic Ovary Syndrome": "Endocrine",
    "Psoriasis": "Skin",
    "Reflux Disorders": "Digestive",
    "Stroke": "Nervous",
    "Thyroid Disorders": "Endocrine",
    "Tourette": "Mental health",
    "Visual Impairment": "Eye"
}

# Display colour for each body-system category. "Other" is the fallback category
# produced by condition_categories.get(..., 'Other') for unmapped conditions, so it
# needs a colour too; otherwise SYSTEM_COLORS[category] raises KeyError.
SYSTEM_COLORS = {
    "Endocrine": "#BA55D3",     # Medium Orchid
    "Blood": "#DC143C",         # Crimson
    "Digestive": "#32CD32",     # Lime Green
    "Respiratory": "#48D1CC",   # Medium Turquoise
    "Neoplasms": "#800080",     # Purple
    "Cardiovascular": "#FF4500", # Orange Red
    "Nervous": "#FFD700",       # Gold
    "Musculoskeletal": "#4682B4", # Steel Blue
    "Genitourinary": "#DAA520", # Goldenrod
    "Mental health": "#8B4513", # Saddle Brown
    "Mental": "#A0522D",       # Sienna
    "Ear": "#4169E1",          # Royal Blue
    "Eye": "#20B2AA",          # Light Sea Green
    "Circulatory": "#FF6347",   # Tomato
    "Skin": "#F08080",         # Light Coral
    "Other": "#808080"         # Grey (fallback for unmapped conditions)
}

def parse_iqr(iqr_string):
    """Parse an IQR string of the form ``'median [Q1-Q3]'`` into three floats.

    Example: ``'2.5 [1.0-4.0]'`` -> ``(2.5, 1.0, 4.0)``.

    Returns ``(0.0, 0.0, 0.0)`` when the value is missing (e.g. NaN/None) or is
    not in the expected format. Only parsing errors are caught, so interrupts
    (KeyboardInterrupt/SystemExit) and unrelated errors still propagate.
    """
    try:
        median_str, iqr = iqr_string.split(' [')
        q1, q3 = iqr.strip(']').split('-')
        return float(median_str), float(q1), float(q3)
    except (AttributeError, TypeError, ValueError):
        # AttributeError: non-string input such as NaN/None;
        # TypeError: e.g. bytes input; ValueError: wrong shape or non-numeric parts.
        return 0.0, 0.0, 0.0

def load_data():
    """Ask the user to upload a CSV and, on success, store it in the module globals.

    Sets ``global_data``, ``total_patients_in_group`` (first
    ``TotalPatientsInGroup`` value), ``gender`` and ``age_group``. Gender and age
    group are inferred from tokens in the file name ('females'/'males',
    'below45'/'45to64'/'65plus').

    The globals are only assigned once the file has been fully parsed. A failed
    upload therefore never leaves a half-initialised ``global_data`` behind, which
    would make ``run_patient_prediction`` skip re-loading on its next call.

    Returns:
        bool: True if the file was loaded, False otherwise.
    """
    global global_data, total_patients_in_group, gender, age_group

    print("Please upload your CSV file.")
    uploaded = files.upload()

    if not uploaded:
        print("No file uploaded. Exiting.")
        return False

    file_name = next(iter(uploaded))
    file_content = uploaded[file_name]

    try:
        data = pd.read_csv(io.BytesIO(file_content))
        total_patients = data['TotalPatientsInGroup'].iloc[0]
    except Exception as e:
        print(f"Error loading file: {str(e)}")
        return False

    file_name_lower = file_name.lower()
    # 'females' must be tested before 'males' because 'males' is a substring of it.
    if 'females' in file_name_lower:
        detected_gender = 'Female'
    elif 'males' in file_name_lower:
        detected_gender = 'Male'
    else:
        detected_gender = 'Unknown Gender'

    if 'below45' in file_name_lower:
        detected_age_group = '<45'
    elif '45to64' in file_name_lower:
        detected_age_group = '45-64'
    elif '65plus' in file_name_lower:
        detected_age_group = '65+'
    else:
        detected_age_group = 'Unknown Age Group'

    # Commit all state together only after parsing succeeded.
    global_data = data
    total_patients_in_group = total_patients
    gender = detected_gender
    age_group = detected_age_group

    print(f"File loaded successfully. Total patients: {total_patients_in_group}")
    print(f"Gender: {gender}, Age Group: {age_group}")
    return True

def perform_sensitivity_analysis(data, or_thresholds=(2.0, 3.0, 4.0, 5.0), verbose=True):
    """Summarise how the trajectory set changes as the odds-ratio cut-off rises.

    For each threshold, rows with ``OddsRatio >= threshold`` are kept. For those
    rows it reports the number of trajectories, an estimated population coverage,
    the number of distinct cross-body-system pairs, and the median of the per-row
    median / Q1 / Q3 durations.

    Args:
        data: Trajectory table with columns ``TotalPatientsInGroup``, ``ConditionA``,
            ``ConditionB``, ``OddsRatio``, ``PairFrequency``,
            ``MedianDurationYearsWithIQR``, ``DirectionalPercentage`` and ``Precedence``.
        or_thresholds: Odds-ratio cut-offs to evaluate (default 2, 3, 4, 5).
        verbose: When True (default), print the cross-system pairs found per threshold.

    Returns:
        pandas.DataFrame with one row per threshold and columns ``OR_Threshold``,
        ``Num_Trajectories``, ``Coverage_Percent``, ``System_Pairs``,
        ``Median_Duration``, ``Q1_Duration``, ``Q3_Duration`` and ``Top_Patterns``.
    """
    results = []
    total_patients = data['TotalPatientsInGroup'].iloc[0]

    # The top-5 patterns are ranked over the *unfiltered* data, so they are the same
    # for every threshold; compute them once.
    top_patterns = data.nlargest(5, 'OddsRatio')[
        ['ConditionA', 'ConditionB', 'OddsRatio', 'PairFrequency',
         'MedianDurationYearsWithIQR', 'DirectionalPercentage', 'Precedence']
    ].to_dict('records')

    for threshold in or_thresholds:
        filtered_data = data[data['OddsRatio'] >= threshold]
        n_trajectories = len(filtered_data)

        # Coverage heuristic kept from the original analysis: half the summed pair
        # frequencies as a patient estimate, capped at 100%.
        total_pairs = filtered_data['PairFrequency'].sum()
        estimated_unique_patients = total_pairs / 2
        coverage = min((estimated_unique_patients / total_patients) * 100, 100.0)

        # Distinct unordered pairs of *different* body systems ('Other' if unmapped).
        system_pairs = set()
        for _, row in filtered_data.iterrows():
            sys_a = condition_categories.get(row['ConditionA'], 'Other')
            sys_b = condition_categories.get(row['ConditionB'], 'Other')
            if sys_a != sys_b:
                system_pairs.add(tuple(sorted([sys_a, sys_b])))

        if verbose:
            print(f"\nFor threshold {threshold}:")
            print(f"Total system pairs found: {len(system_pairs)}")
            print("System pairs:")
            for pair in sorted(system_pairs):
                print(f"{pair[0]} - {pair[1]}")

        # parse_iqr returns zeros for unparseable strings, so zeros are dropped here.
        # NOTE(review): this also drops genuine 0.0 values (e.g. a Q1 of 0 years),
        # which biases the Q1 summary upwards — confirm this is intended.
        duration_stats = filtered_data['MedianDurationYearsWithIQR'].apply(parse_iqr)
        medians = [x[0] for x in duration_stats if x[0] > 0]
        q1s = [x[1] for x in duration_stats if x[1] > 0]
        q3s = [x[2] for x in duration_stats if x[2] > 0]

        results.append({
            'OR_Threshold': threshold,
            'Num_Trajectories': n_trajectories,
            'Coverage_Percent': round(coverage, 2),
            'System_Pairs': len(system_pairs),
            'Median_Duration': round(np.median(medians) if medians else 0, 2),
            'Q1_Duration': round(np.median(q1s) if q1s else 0, 2),
            'Q3_Duration': round(np.median(q3s) if q3s else 0, 2),
            # Independent copies per row so mutating one row cannot affect the others.
            'Top_Patterns': [dict(pattern) for pattern in top_patterns]
        })

    return pd.DataFrame(results)

def create_summary_table(results_df):
    """Render the sensitivity-analysis summary as an HTML heading plus table.

    Args:
        results_df: DataFrame with one row per threshold and the columns
            ``OR_Threshold``, ``Num_Trajectories``, ``Coverage_Percent``,
            ``System_Pairs``, ``Median_Duration``, ``Q1_Duration`` and
            ``Q3_Duration`` (as produced by ``perform_sensitivity_analysis``).

    Returns:
        str: HTML fragment. The ``<table>`` is explicitly closed so markup that is
        concatenated after it (e.g. the patterns table) is not nested inside it.
    """
    summary_html = """
    <h3>Sensitivity Analysis Results</h3>
    <table style='width:50%; border-collapse: collapse; margin-bottom: 20px;'>
        <tr style='background-color: #f2f2f2;'>
            <th style='padding: 8px; border: 1px solid #ddd;'>OR Threshold</th>
            <th style='padding: 8px; border: 1px solid #ddd;'>Trajectories</th>
            <th style='padding: 8px; border: 1px solid #ddd;'>Coverage (%)</th>
            <th style='padding: 8px; border: 1px solid #ddd;'>System Pairs</th>
            <th style='padding: 8px; border: 1px solid #ddd;'>Median Years [IQR]</th>
        </tr>
    """

    for _, row in results_df.iterrows():
        summary_html += f"""
        <tr>
            <td style='padding: 8px; border: 1px solid #ddd;'>{row['OR_Threshold']}</td>
            <td style='padding: 8px; border: 1px solid #ddd;'>{row['Num_Trajectories']}</td>
            <td style='padding: 8px; border: 1px solid #ddd;'>{row['Coverage_Percent']:.2f}</td>
            <td style='padding: 8px; border: 1px solid #ddd;'>{row['System_Pairs']}</td>
            <td style='padding: 8px; border: 1px solid #ddd;'>{row['Median_Duration']:.2f} [{row['Q1_Duration']:.2f}-{row['Q3_Duration']:.2f}]</td>
        </tr>
        """
    summary_html += "</table>"
    return summary_html

def create_patterns_table(results_df):
    """Render the top-5 strongest trajectories as an HTML heading plus table.

    Only the first row's ``Top_Patterns`` entry is read, because
    ``perform_sensitivity_analysis`` stores the same top-5 list for every threshold.
    Each pattern dict needs the keys ``ConditionA``, ``ConditionB``, ``OddsRatio``,
    ``PairFrequency``, ``MedianDurationYearsWithIQR``, ``DirectionalPercentage``
    and ``Precedence``.
    """
    header_html = """
    <h3>Top 5 Strongest Trajectories</h3>
    <table style='width:80%; border-collapse: collapse; margin-bottom: 20px;'>
        <tr style='background-color: #f2f2f2;'>
            <th style='padding: 8px; border: 1px solid #ddd;'>Condition A</th>
            <th style='padding: 8px; border: 1px solid #ddd;'>Condition B</th>
            <th style='padding: 8px; border: 1px solid #ddd;'>Odds Ratio</th>
            <th style='padding: 8px; border: 1px solid #ddd;'>Pair Frequency</th>
            <th style='padding: 8px; border: 1px solid #ddd;'>Years (Median [IQR])</th>
            <th style='padding: 8px; border: 1px solid #ddd;'>Direction (%)</th>
        </tr>
    """

    def _render_row(entry):
        # One <tr> per trajectory; whitespace matches the header's layout.
        return f"""
        <tr>
            <td style='padding: 8px; border: 1px solid #ddd;'>{entry['ConditionA']}</td>
            <td style='padding: 8px; border: 1px solid #ddd;'>{entry['ConditionB']}</td>
            <td style='padding: 8px; border: 1px solid #ddd;'>{entry['OddsRatio']:.2f}</td>
            <td style='padding: 8px; border: 1px solid #ddd;'>{entry['PairFrequency']}</td>
            <td style='padding: 8px; border: 1px solid #ddd;'>{entry['MedianDurationYearsWithIQR']}</td>
            <td style='padding: 8px; border: 1px solid #ddd;'>{entry['DirectionalPercentage']:.1f}% {entry['Precedence']}</td>
        </tr>
        """

    top_rows = results_df.iloc[0]['Top_Patterns']
    body_html = "".join(_render_row(entry) for entry in top_rows)
    return header_html + body_html + "</table>"

def create_comprehensive_plot(results_df):
    """Plot trajectory count, coverage and cross-system pairs against the OR threshold.

    Bars on the left axis show the number of trajectories per threshold, annotated
    with the median [IQR] duration. The red line on the right axis shows
    population coverage. Circle area on that line is proportional to the number
    of body-system pairs.

    Note: this updates the global matplotlib ``rcParams`` (fonts, grid, figure
    size), which also affects figures created later in the session.

    Args:
        results_df: Output of ``perform_sensitivity_analysis``.

    Returns:
        matplotlib.figure.Figure: the figure; the caller is responsible for closing it.
    """
    plt.rcParams.update({
        'font.size': 11,
        'axes.labelsize': 12,
        'axes.titlesize': 14,
        'axes.grid': True,
        'grid.alpha': 0.2,
        'figure.figsize': (12, 8)
    })

    fig, ax_count = plt.subplots()
    ax_coverage = ax_count.twinx()

    thresholds = results_df['OR_Threshold'].values
    counts = results_df['Num_Trajectories']
    coverage = results_df['Coverage_Percent']

    ax_count.bar(thresholds, counts, alpha=0.3, color='navy', width=0.2)
    ax_coverage.plot(thresholds, coverage, 'r-o', linewidth=2)

    # Scale circle areas so the largest is 500. When no threshold has any
    # cross-system pair the maximum is 0, which would otherwise divide by zero.
    max_pairs = results_df['System_Pairs'].max()
    if max_pairs > 0:
        sizes = (results_df['System_Pairs'] / max_pairs) * 500
    else:
        sizes = np.zeros(len(results_df))
    ax_coverage.scatter(thresholds, coverage, s=sizes, alpha=0.5, color='darkred')

    # Add IQR information inside bars (label-based lookup keeps counts aligned with rows).
    for i, row in results_df.iterrows():
        ax_count.text(row['OR_Threshold'], counts[i] * 0.5,
                      f"Median: {row['Median_Duration']:.1f}y\nIQR: [{row['Q1_Duration']:.1f}-{row['Q3_Duration']:.1f}]",
                      ha='center', va='center', fontsize=10)

    ax_count.set_xlabel('Minimum Odds Ratio Threshold')
    ax_count.set_ylabel('Number of Disease Trajectories')
    ax_coverage.set_ylabel('Population Coverage (%)')

    legend_elements = [
        patches.Patch(facecolor='navy', alpha=0.3,
                      label='Number of Disease Trajectories\n(Height of bars)'),
        Line2D([0], [0], color='r', marker='o',
               label='Population Coverage %\n(Red line)'),
        Line2D([0], [0], marker='o', color='darkred', alpha=0.5,
               label='Body System Pairs\n(Size of circles)',
               markersize=10, linestyle='None')
    ]
    ax_count.legend(handles=legend_elements, loc='upper right')

    if gender and age_group:
        ax_count.set_title(f'Impact of Odds Ratio Threshold on Disease Trajectory Analysis in {gender}s {age_group}')
    else:
        ax_count.set_title('Impact of Odds Ratio Threshold on Disease Trajectory Analysis in General Population')

    fig.tight_layout()

    return fig


def display_sensitivity_results(results_df):
    """Render the sensitivity tables and the summary plot inline.

    Raises:
        ValueError: if ``results_df`` is empty (checked before anything is rendered).

    Any error raised while rendering is printed instead of propagated.
    """
    if results_df.empty:
        raise ValueError("Empty results DataFrame")

    try:
        tables_html = "".join((create_summary_table(results_df),
                               create_patterns_table(results_df)))
        display(HTML(tables_html))
        figure = create_comprehensive_plot(results_df)
        display(figure)
        plt.close(figure)
    except Exception as err:
        print(f"Error displaying results: {str(err)}")

def create_trajectory_graph(patient_conditions, time_horizon=None, time_margin=None, min_or=2.0, seed=None):
    """Create a network graph visualizing disease trajectories for given patient conditions.

    Args:
        patient_conditions (list): Initial patient conditions. Each is drawn as a
            starred node, and rows where it is ``ConditionA`` supply onward edges.
        time_horizon (float, optional): Maximum years to look ahead (compared with
            the median of ``MedianDurationYearsWithIQR``). When None or 0, no
            time filtering is applied.
        time_margin (float, optional): Extra tolerance on ``time_horizon`` as a
            fraction (e.g. 0.1 = +10%). None or 0 means no extra tolerance; the
            horizon itself is still applied.
        min_or (float, optional): Minimum odds ratio threshold. Defaults to 2.0.
        seed (int, optional): Seed for the random jitter of node positions.
            None (default) uses the global ``random`` module state.

    Returns:
        Network: A pyvis Network object containing the trajectory graph.
    """
    net = Network(notebook=True, bgcolor='white', font_color='black', height="1200px", width="100%", cdn_resources='in_line')

    # Colour for categories without an entry in SYSTEM_COLORS (e.g. the 'Other'
    # fallback for unmapped conditions), so they render grey instead of raising KeyError.
    fallback_color = '#808080'
    # Jitter source: the global random module unless a reproducible seed is requested.
    rng = random if seed is None else random.Random(seed)

    # Filter data based on odds ratio
    filtered_data = global_data[global_data['OddsRatio'] >= min_or].copy()

    # Apply the time window once, up front, instead of re-parsing IQR strings per condition.
    if time_horizon:
        max_years = time_horizon * (1 + (time_margin or 0))
        median_years = filtered_data['MedianDurationYearsWithIQR'].apply(lambda x: parse_iqr(x)[0])
        filtered_data = filtered_data[median_years <= max_years]

    # Rows describing onward trajectories from each initial condition.
    outgoing = {
        condition: filtered_data[filtered_data['ConditionA'] == condition]
        for condition in patient_conditions
    }

    # Find all connected conditions
    connected_conditions = set()
    for rows in outgoing.values():
        connected_conditions.update(rows['ConditionB'])

    # Define active conditions (initial + connected)
    active_conditions = set(patient_conditions) | connected_conditions

    # Get active categories
    active_categories = {condition_categories[cond] for cond in active_conditions if cond in condition_categories}

    # Place each active category evenly on a circle; nodes cluster around their category.
    category_positions = {}
    radius = 300
    for i, category in enumerate(sorted(active_categories)):
        angle = i * (2 * math.pi / len(active_categories))
        category_positions[category] = {
            'x': radius * math.cos(angle),
            'y': radius * math.sin(angle)
        }

    # Add legend
    legend_start_x = -600
    legend_start_y = -300
    net.add_node(
        "legend_title",
        label="Legend",
        x=legend_start_x,
        y=legend_start_y - 50,
        size=0,
        font={'size': 20, 'bold': True},
        physics=False,
        fixed=True
    )

    for i, category in enumerate(sorted(active_categories)):
        color = SYSTEM_COLORS.get(category, fallback_color)
        net.add_node(
            f"legend_{category}",
            label=f"{category}",
            x=legend_start_x + 50,
            y=legend_start_y + (i * 60),
            size=10,
            shape='dot',
            color={'background': color, 'border': color},
            font={'size': 16, 'align': 'left'},
            physics=False,
            fixed=True
        )

    # Add nodes
    for condition in active_conditions:
        category = condition_categories.get(condition, "Other")
        base_color = SYSTEM_COLORS.get(category, fallback_color)
        pos = category_positions.get(category, {'x': 0, 'y': 0})

        x = pos['x'] + rng.uniform(-50, 50)
        y = pos['y'] + rng.uniform(-50, 50)

        if condition in patient_conditions:
            # Highlight patient conditions
            net.add_node(condition,
                        label=f"★ {condition}",
                        size=30,
                        title=f"{condition}\nCategory: {category}",
                        color={
                            'background': f"{base_color}50",
                            'border': '#000000',
                            'highlight': {'border': '#000000', 'background': f"{base_color}50"}
                        },
                        borderWidth=2,
                        x=x,
                        y=y,
                        physics=False)
        else:
            # Regular nodes
            net.add_node(condition,
                        label=condition,
                        size=20,
                        title=f"{condition}\nCategory: {category}",
                        color={
                            'background': f"{base_color}50",
                            'border': base_color,
                            'highlight': {'border': base_color, 'background': f"{base_color}50"}
                        },
                        x=x,
                        y=y,
                        physics=False)

    # Add edges; arrow direction follows the majority temporal order (DirectionalPercentage).
    for condition_a in patient_conditions:
        relevant_data = outgoing[condition_a]

        for _, row in relevant_data.iterrows():
            condition_b = row['ConditionB']
            if condition_b not in patient_conditions:
                edge_width = max(1, min(8, math.log2(row['OddsRatio'] + 1)))
                prevalence = (row['PairFrequency'] / total_patients_in_group) * 100
                directional_percentage = row['DirectionalPercentage']

                # Determine edge direction based on DirectionalPercentage
                if directional_percentage >= 50:
                    source, target = condition_a, condition_b
                else:
                    source, target = condition_b, condition_a
                    directional_percentage = 100 - directional_percentage

                edge_label = (f"OR: {row['OddsRatio']:.1f}\n"
                            f"Years: {row['MedianDurationYearsWithIQR']}\n"
                            f"n={row['PairFrequency']} ({prevalence:.1f}%)\n"
                            f"Proceeds: {directional_percentage:.1f}%")

                net.add_edge(source,
                           target,
                           label=edge_label,
                           title=edge_label,
                           width=edge_width,
                           color={'color': 'rgba(128,128,128,0.7)', 'highlight': 'black'},
                           arrows={'to': {'enabled': True}},
                           smooth={'type': 'curvedCW', 'roundness': 0.2})

    # Set network options
    net.set_options('''{
        "physics": {"enabled": false},
        "edges": {
            "font": {"size": 14, "align": "horizontal", "multi": true},
            "arrows": {"to": {"enabled": true, "scaleFactor": 1}},
            "smooth": {"type": "curvedCW", "roundness": 0.2}
        }
    }''')

    return net

def display_pyvis_graph(net):
    """Show a pyvis network inline beside a reading guide, plus a download button.

    The network's HTML page is written to ``graph_<random int>.html`` in the
    current working directory so the "Save Graph as HTML" button can pass it to
    ``files.download``; the same HTML is embedded in an iframe through a
    base64 ``data:`` URL.

    Args:
        net: A pyvis ``Network``; only its ``generate_html()`` method is used.
    """
    html = net.generate_html()
    unique_id = f"graph_{np.random.randint(0, 1000000)}"
    filename = f"{unique_id}.html"

    # The file on disk is only needed by the download button below.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(html)

    # Encode the in-memory page directly instead of re-reading the file just written.
    b64 = base64.b64encode(html.encode('utf-8')).decode('ascii')

    display(HTML(f"""
    <div style="display: flex; justify-content: space-between;">
        <div style="width: 25%;">
            <h3>How to Read the Graph:</h3>
            <ul>
                <li>Nodes represent conditions, colored by body system category.</li>
                <li>★ marks initial conditions.</li>
                <li>Edge labels show:
                    <ul>
                        <li>OR: Odds Ratio</li>
                        <li>Years: Median time [IQR]</li>
                        <li>n: Patient pairs</li>
                        <li>Proceeds: Percentage first condition precedes second</li>
                    </ul>
                </li>
                <li>Edge thickness represents odds ratio strength</li>
            </ul>
        </div>
        <div style="width: 75%;">
            <iframe src="data:text/html;base64,{b64}" width="100%" height="800px"></iframe>
        </div>
    </div>
    """))

    save_button = widgets.Button(description="Save Graph as HTML")
    save_output = widgets.Output()

    def on_save_button_clicked(b):
        """Send the saved HTML file to the browser as a download."""
        with save_output:
            save_output.clear_output()
            files.download(filename)
            print(f"Graph saved as {filename}")

    save_button.on_click(on_save_button_clicked)
    display(save_button, save_output)

def run_patient_prediction():
    """Load data if needed, then show the trajectory-graph controls.

    Displays a data summary, a "Run Sensitivity Analysis" button, a condition
    picker, time-horizon / time-margin / minimum-OR sliders, and a "Run
    Analysis" button that renders the graph built by ``create_trajectory_graph``.
    Returns early (after printing a message) if no data could be loaded.
    """
    global global_data

    if global_data is None:
        if not load_data():
            print("Failed to load data. Please upload a valid CSV file.")
            return

    unique_conditions = sorted(set(global_data['ConditionA'].unique()) | set(global_data['ConditionB'].unique()))
    longest_median = global_data['MedianDurationYearsWithIQR'].apply(lambda x: parse_iqr(x)[0]).max()
    # IntSlider refuses max < min (min is 1 below). Clamp so that an empty frame
    # (NaN max, which math.ceil rejects) or all-unparseable durations (0.0 max)
    # still produce a usable slider.
    max_years = max(1, math.ceil(longest_median)) if pd.notna(longest_median) else 1

    sensitivity_button = widgets.Button(description="Run Sensitivity Analysis")
    sensitivity_output = widgets.Output()

    def on_sensitivity_button_clicked(b):
        """Run the sensitivity analysis on the loaded data and display it."""
        with sensitivity_output:
            sensitivity_output.clear_output()
            results = perform_sensitivity_analysis(global_data)
            display_sensitivity_results(results)

    sensitivity_button.on_click(on_sensitivity_button_clicked)

    min_or_slider = widgets.FloatSlider(
        value=2.0, min=1.0, max=10.0, step=0.5,
        description='Min OR:',
        continuous_update=False,
        layout=widgets.Layout(width='70%')
    )

    condition_select = widgets.SelectMultiple(
        options=unique_conditions,
        description='Select initial conditions:',
        layout=widgets.Layout(width='70%', height='200px')
    )

    time_horizon_slider = widgets.IntSlider(
        value=min(5, max_years), min=1, max=max_years,
        description='Time Horizon (years):',
        continuous_update=False,
        layout=widgets.Layout(width='70%')
    )

    time_margin_slider = widgets.FloatSlider(
        value=0.1, min=0, max=0.5, step=0.05,
        description='Time Margin:',
        continuous_update=False,
        layout=widgets.Layout(width='70%')
    )

    run_button = widgets.Button(description="Run Analysis")
    output = widgets.Output()

    def on_run_button_clicked(b):
        """Build and display the trajectory graph for the selected conditions."""
        with output:
            output.clear_output()
            patient_conditions = list(condition_select.value)

            if not patient_conditions:
                print("Please select at least one initial condition.")
                return

            net = create_trajectory_graph(
                patient_conditions,
                time_horizon_slider.value,
                time_margin_slider.value,
                min_or_slider.value
            )
            display_pyvis_graph(net)

    run_button.on_click(on_run_button_clicked)

    data_summary = widgets.HTML(
        value=f"<p>Data Summary:<br>Maximum progression time: {max_years} years<br>"
              f"Total patients: {total_patients_in_group}<br>"
              f"Gender: {gender}<br>Age Group: {age_group}</p>"
    )

    display(data_summary)
    display(sensitivity_button, sensitivity_output)
    display(condition_select, time_horizon_slider, time_margin_slider, min_or_slider, run_button, output)

if __name__ == "__main__":
    # Notebook kernels execute cells as __main__, so running this cell starts the tool.
    print("Welcome to the Patient Trajectory Predictor!\nPlease upload your CSV file to begin.")
    run_patient_prediction()

Welcome to the Patient Trajectory Predictor!
Please upload your CSV file to begin.
Please upload your CSV file.


Saving Females_fdr_significant_high_freq_odds_ratio_analysis_below45.csv to Females_fdr_significant_high_freq_odds_ratio_analysis_below45 (1).csv
File loaded successfully. Total patients: 6397
Gender: Female, Age Group: <45


HTML(value='<p>Data Summary:<br>Maximum progression time: 15 years<br>Total patients: 6397<br>Gender: Female<b…

Button(description='Run Sensitivity Analysis', style=ButtonStyle())

Output()

SelectMultiple(description='Select initial conditions:', layout=Layout(height='200px', width='70%'), options=(…

IntSlider(value=5, continuous_update=False, description='Time Horizon (years):', layout=Layout(width='70%'), m…

FloatSlider(value=0.1, continuous_update=False, description='Time Margin:', layout=Layout(width='70%'), max=0.…

FloatSlider(value=2.0, continuous_update=False, description='Min OR:', layout=Layout(width='70%'), max=10.0, m…

Button(description='Run Analysis', style=ButtonStyle())

Output()

In [4]:
import pandas as pd
import numpy as np
import math
import random
import io
import base64
from IPython.display import HTML, display
import ipywidgets as widgets
from pyvis.network import Network
from google.colab import files
from matplotlib import patches
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

# Global variables
# Module-level state shared by the functions in this cell. load_data() assigns
# the first four after an upload; the analysis functions only read them.
global_data = None            # DataFrame parsed from the uploaded CSV (None until loaded)
total_patients_in_group = 0   # cohort size, taken from the first TotalPatientsInGroup value
gender = ''                   # 'Female' / 'Male' / 'Unknown Gender', inferred from the file name
age_group = ''                # '<45' / '45-64' / '65+' / 'Unknown Age Group', inferred from the file name
results_df = None             # NOTE(review): not read or written by any function in this cell

# System categories dictionary
# Maps each condition name (as it appears in the ConditionA / ConditionB
# columns of the uploaded CSV) to a body-system label. The label is shown as a
# tag in the HTML report and is looked up in SYSTEM_COLORS to colour network
# nodes; callers fall back to 'Other' for conditions not listed here.
# NOTE(review): some body systems use two different labels ("Cardiovascular"
# vs "Circulatory", "Mental health" vs "Mental"), so e.g. Peripheral Vascular
# Disease is tagged and coloured differently from Hypertension. Confirm this
# split is intentional before comparing systems across conditions.
condition_categories = {
    "Addisons Disease": "Endocrine",
    "Anaemia": "Blood",
    "Barretts Oesophagus": "Digestive",
    "Bronchiectasis": "Respiratory",
    "Cancer": "Neoplasms",
    "Cardiac Arrhythmias": "Cardiovascular",
    "Cerebral Palsy": "Nervous",
    "Chronic Airway Diseases": "Respiratory",
    "Chronic Arthritis": "Musculoskeletal",
    "Chronic Constipation": "Digestive",
    "Chronic Diarrhoea": "Digestive",
    "Chronic Kidney Disease": "Genitourinary",
    "Chronic Pain Conditions": "Musculoskeletal",
    "Chronic Pneumonia": "Respiratory",
    "Cirrhosis": "Digestive",
    "Coronary Heart Disease": "Cardiovascular",
    "Dementia": "Mental health",
    "Diabetes": "Endocrine",
    "Dysphagia": "Digestive",
    "Epilepsy": "Nervous",
    "Heart Failure": "Cardiovascular",
    "Hearing Loss": "Ear",
    "Hypertension": "Cardiovascular",
    "Inflammatory Bowel Disease": "Digestive",
    "Insomnia": "Nervous",
    "Interstitial Lung Disease": "Respiratory",
    "Mental Illness": "Mental",  # NOTE(review): only entry labelled "Mental"; Dementia/Tourette use "Mental health"
    "Menopausal and Perimenopausal": "Genitourinary",
    "Multiple Sclerosis": "Nervous",
    "Neuropathic Pain": "Nervous",
    "Osteoporosis": "Musculoskeletal",
    "Parkinsons": "Nervous",
    "Peripheral Vascular Disease": "Circulatory",  # NOTE(review): only entry labelled "Circulatory"; other heart/vessel entries use "Cardiovascular"
    "Polycystic Ovary Syndrome": "Endocrine",
    "Psoriasis": "Skin",
    "Reflux Disorders": "Digestive",
    "Stroke": "Nervous",
    "Thyroid Disorders": "Endocrine",
    "Tourette": "Mental health",
    "Visual Impairment": "Eye"
}

# System colors dictionary
# Hex colour for each body-system label used in condition_categories. Network
# nodes whose system is not listed here fall back to '#808080' (grey).
# NOTE(review): "Mental"/"Mental health" and "Circulatory"/"Cardiovascular"
# get similar but distinct shades, presumably because condition_categories
# uses both spellings; confirm whether they should be merged.
SYSTEM_COLORS = {
    "Endocrine": "#BA55D3",     # Medium Orchid
    "Blood": "#DC143C",         # Crimson
    "Digestive": "#32CD32",     # Lime Green
    "Respiratory": "#48D1CC",   # Medium Turquoise
    "Neoplasms": "#800080",     # Purple
    "Cardiovascular": "#FF4500", # Orange Red
    "Nervous": "#FFD700",       # Gold
    "Musculoskeletal": "#4682B4", # Steel Blue
    "Genitourinary": "#DAA520", # Goldenrod
    "Mental health": "#8B4513", # Saddle Brown
    "Mental": "#A0522D",       # Sienna
    "Ear": "#4169E1",          # Royal Blue
    "Eye": "#20B2AA",          # Light Sea Green
    "Circulatory": "#FF6347",   # Tomato
    "Skin": "#F08080"          # Light Coral
}

def parse_iqr(iqr_string):
    """Parse a duration string of the form ``'median [Q1-Q3]'``.

    Example: ``'2.5 [1.0-4.0]'`` -> ``(2.5, 1.0, 4.0)``. The space before
    ``[`` and whitespace around the numbers/brackets are optional.

    Args:
        iqr_string: A raw ``MedianDurationYearsWithIQR`` cell value.

    Returns:
        tuple[float, float, float]: ``(median, q1, q3)``. Anything that does
        not match the format (including None/NaN cells and negative bounds)
        yields ``(0.0, 0.0, 0.0)``. Note that callers' time-horizon filters
        therefore treat unparseable rows as a 0-year duration.
    """
    try:
        median_str, iqr = str(iqr_string).split('[', 1)
        q1_str, q3_str = iqr.strip().rstrip(']').split('-', 1)
        return float(median_str), float(q1_str), float(q3_str)
    except ValueError:
        # Malformed text (bad unpack or non-numeric part) is the only expected
        # failure; anything else (e.g. KeyboardInterrupt) is allowed to propagate.
        return 0.0, 0.0, 0.0

def _infer_cohort_from_filename(file_name):
    """Return ``(gender, age_group)`` labels inferred from keywords in a file name."""
    name = file_name.lower()

    # 'females' contains 'males', so it must be tested first.
    if 'females' in name:
        gender_label = 'Female'
    elif 'males' in name:
        gender_label = 'Male'
    else:
        gender_label = 'Unknown Gender'

    if 'below45' in name:
        age_label = '<45'
    elif '45to64' in name:
        age_label = '45-64'
    elif '65plus' in name:
        age_label = '65+'
    else:
        age_label = 'Unknown Age Group'

    return gender_label, age_label


def load_data():
    """Prompt for a CSV upload and make it the active trajectory dataset.

    The file must contain every column the analysis functions read and at
    least one row, and its first ``TotalPatientsInGroup`` value must be a
    positive number. Only when all checks pass are the globals
    ``global_data``, ``total_patients_in_group``, ``gender`` and ``age_group``
    replaced, together. On failure they are left untouched. This matters
    because ``run_patient_prediction`` only reloads while ``global_data`` is
    None, so a half-loaded frame would otherwise stick.

    Returns:
        bool: True if a file was uploaded and loaded successfully, else False.
    """
    global global_data, total_patients_in_group, gender, age_group

    # Columns read by the analysis / visualisation functions in this cell.
    required_columns = (
        'ConditionA', 'ConditionB', 'OddsRatio', 'PairFrequency',
        'DirectionalPercentage', 'MedianDurationYearsWithIQR', 'TotalPatientsInGroup',
    )

    print("Please upload your CSV file.")
    uploaded = files.upload()

    if not uploaded:
        print("No file uploaded. Exiting.")
        return False

    file_name = next(iter(uploaded))
    file_content = uploaded[file_name]

    try:
        data = pd.read_csv(io.BytesIO(file_content))
        missing = [col for col in required_columns if col not in data.columns]
        if missing:
            raise ValueError(f"missing required column(s): {', '.join(missing)}")
        if data.empty:
            raise ValueError("the file contains no data rows")
        total = int(data['TotalPatientsInGroup'].iloc[0])
        if total <= 0:
            raise ValueError(f"TotalPatientsInGroup must be positive, got {total}")
    except Exception as e:
        print(f"Error loading file: {str(e)}")
        return False

    # Commit all globals together, only after validation succeeded.
    global_data = data
    total_patients_in_group = total
    gender, age_group = _infer_cohort_from_filename(file_name)

    print(f"File loaded successfully. Total patients: {total_patients_in_group}")
    print(f"Gender: {gender}, Age Group: {age_group}")
    return True

def create_personalized_trajectory_analysis(patient_conditions, time_horizon=None, time_margin=None, min_or=2.0):
    """Build an HTML report of likely disease progressions for a patient.

    For each condition the patient already has, the rows of ``global_data``
    whose ``ConditionA`` is that condition are listed in a table, highest odds
    ratio first. Progressions into conditions the patient already has are
    omitted, and conditions with nothing left to show get no table.

    Args:
        patient_conditions: Condition names the patient currently has.
        time_horizon: Optional horizon in years. When it is truthy and
            ``time_margin`` is not None, only pairs whose median duration (as
            parsed by ``parse_iqr``) is <= ``time_horizon * (1 + time_margin)``
            are kept. A margin of 0 applies the horizon strictly.
        time_margin: Fractional tolerance added to ``time_horizon``.
        min_or: Minimum odds ratio for a pair to be shown.

    Returns:
        IPython.display.HTML: the rendered report. Text taken from the data is
        HTML-escaped before being embedded.
    """
    from html import escape  # stdlib; escapes text read from the uploaded CSV

    patient_conditions = list(patient_conditions)

    def get_risk_level(odds_ratio):
        """Map an odds ratio to a (label, badge colour) pair."""
        if odds_ratio >= 5:
            return "High", "#dc3545"
        elif odds_ratio >= 3:
            return "Moderate", "#ffc107"
        else:
            return "Low", "#28a745"

    def render_row(condition_a, row):
        """Return the table-row HTML for one ConditionA -> ConditionB pair."""
        condition_b = row['ConditionB']
        system_b = condition_categories.get(condition_b, 'Other')
        median, q1, q3 = parse_iqr(row['MedianDurationYearsWithIQR'])
        prevalence = (row['PairFrequency'] / total_patients_in_group * 100) if total_patients_in_group else 0.0
        risk_level, color = get_risk_level(row['OddsRatio'])

        # Below 50%, ConditionB usually comes first: report the reverse order
        # with the complementary confidence.
        if row['DirectionalPercentage'] >= 50:
            direction = f"{condition_a} → {condition_b}"
            confidence = row['DirectionalPercentage']
        else:
            direction = f"{condition_b} → {condition_a}"
            confidence = 100 - row['DirectionalPercentage']

        return f"""
                        <tr>
                            <td><span class="risk-badge" style="background-color: {color}">{risk_level}</span></td>
                            <td>
                                <strong>{escape(str(condition_b))}</strong><br>
                                <span class="system-tag">{escape(system_b)}</span>
                            </td>
                            <td class="timeline-indicator">
                                Typically {median:.1f} years<br>
                                Range: {q1:.1f} to {q3:.1f} years
                            </td>
                            <td>
                                OR: {row['OddsRatio']:.1f}<br>
                                {row['PairFrequency']} cases ({prevalence:.1f}%)
                            </td>
                            <td>
                                {confidence:.1f}% confidence in progression order<br>
                                {escape(direction)}
                            </td>
                        </tr>
                    """

    candidates = global_data[global_data['OddsRatio'] >= min_or]

    # Apply the time horizon once for all conditions. Test `time_margin is not
    # None` rather than its truthiness so a margin of 0 still filters.
    if time_horizon and time_margin is not None:
        max_duration = time_horizon * (1 + time_margin)
        medians = candidates['MedianDurationYearsWithIQR'].apply(lambda x: parse_iqr(x)[0])
        candidates = candidates[medians <= max_duration]

    # Progressions into conditions the patient already has are not new risks;
    # dropping them up front also prevents rendering empty tables.
    candidates = candidates[~candidates['ConditionB'].isin(patient_conditions)]

    style = """
    <style>
        .patient-analysis {
            font-family: Arial, sans-serif;
            max-width: 1200px;
            margin: 20px auto;
        }
        .condition-section {
            margin-bottom: 30px;
            border: 1px solid #ddd;
            border-radius: 8px;
            padding: 20px;
            background-color: #f8f9fa;
        }
        .condition-header {
            font-size: 1.2em;
            color: #2c5282;
            margin-bottom: 15px;
            padding-bottom: 10px;
            border-bottom: 2px solid #e2e8f0;
        }
        .trajectory-table {
            width: 100%;
            border-collapse: collapse;
            margin: 10px 0;
            background-color: white;
        }
        .trajectory-table th {
            background-color: #f5f5f5;
            padding: 12px;
            text-align: left;
            border: 1px solid #ddd;
        }
        .trajectory-table td {
            padding: 10px;
            border: 1px solid #ddd;
        }
        .risk-badge {
            padding: 4px 8px;
            border-radius: 4px;
            color: white;
            font-weight: bold;
        }
        .system-tag {
            display: inline-block;
            padding: 2px 6px;
            border-radius: 4px;
            background-color: #e2e8f0;
            font-size: 0.9em;
            margin-right: 5px;
        }
        .timeline-indicator {
            font-style: italic;
            color: #666;
        }
    </style>
    """

    parts = [
        style,
        '<div class="patient-analysis">'
        '<h2>Personalized Disease Trajectory Analysis</h2>'
        f'<p>Based on current conditions: {escape(", ".join(map(str, patient_conditions)))}</p>',
    ]

    sections_rendered = 0
    for condition_a in patient_conditions:
        rows = candidates[candidates['ConditionA'] == condition_a]
        if rows.empty:
            continue
        sections_rendered += 1

        system_a = condition_categories.get(condition_a, 'Other')
        parts.append(f"""
            <div class="condition-section">
                <div class="condition-header">
                    <span class="system-tag">{escape(system_a)}</span>
                    Progression Paths from {escape(str(condition_a))}
                </div>
                <table class="trajectory-table">
                    <thead>
                        <tr>
                            <th>Risk Level</th>
                            <th>Potential Progression</th>
                            <th>Expected Timeline</th>
                            <th>Statistical Support</th>
                            <th>Progression Details</th>
                        </tr>
                    </thead>
                    <tbody>
            """)
        for _, row in rows.sort_values('OddsRatio', ascending=False).iterrows():
            parts.append(render_row(condition_a, row))
        parts.append("""
                    </tbody>
                </table>
            </div>
            """)

    if not sections_rendered:
        parts.append("<p>No progression paths match the selected conditions and filters.</p>")

    # The "Low" band is whatever remains below OR 3 after the min_or filter,
    # so state the active minimum instead of a fixed OR≥2.
    parts.append(f"""
        <div style="margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 8px;">
            <h4>Understanding This Analysis:</h4>
            <ul>
                <li><strong>Risk Level:</strong> Based on odds ratio strength (High: OR≥5, Moderate: 3≤OR&lt;5, Low: OR&lt;3); only pairs with OR≥{min_or:g} are shown</li>
                <li><strong>Expected Timeline:</strong> Median years and range between which progression typically occurs</li>
                <li><strong>Statistical Support:</strong> Odds ratio and number of observed cases in the population</li>
                <li><strong>Progression Details:</strong> Confidence in the order of disease progression</li>
            </ul>
        </div>
    </div>
    """)

    return HTML("".join(parts))

def run_patient_prediction():
    """Main function to run the patient trajectory prediction analysis.

    Loads data on first use (via ``load_data``), then displays a data summary,
    a condition picker, time-horizon / time-margin / minimum-OR sliders, and a
    button that renders ``create_personalized_trajectory_analysis`` for the
    selected conditions. Returns early (after printing a message) if no data
    could be loaded.
    """
    global global_data

    if global_data is None:
        if not load_data():
            print("Failed to load data. Please upload a valid CSV file.")
            return

    unique_conditions = sorted(set(global_data['ConditionA'].unique()) | set(global_data['ConditionB'].unique()))
    longest_median = global_data['MedianDurationYearsWithIQR'].apply(lambda x: parse_iqr(x)[0]).max()
    # IntSlider refuses max < min (min is 1 below). Clamp so that an empty frame
    # (NaN max, which math.ceil rejects) or all-unparseable durations (0.0 max)
    # still produce a usable slider.
    max_years = max(1, math.ceil(longest_median)) if pd.notna(longest_median) else 1

    # Create widgets
    min_or_slider = widgets.FloatSlider(
        value=2.0,
        min=1.0,
        max=10.0,
        step=0.5,
        description='Minimum Odds Ratio:',
        continuous_update=False,
        layout=widgets.Layout(width='70%')
    )

    condition_select = widgets.SelectMultiple(
        options=unique_conditions,
        description='Select conditions:',
        layout=widgets.Layout(width='70%', height='200px')
    )

    time_horizon_slider = widgets.IntSlider(
        value=min(5, max_years),
        min=1,
        max=max_years,
        description='Time Horizon (years):',
        continuous_update=False,
        layout=widgets.Layout(width='70%')
    )

    time_margin_slider = widgets.FloatSlider(
        value=0.1,
        min=0,
        max=0.5,
        step=0.05,
        description='Time Margin:',
        continuous_update=False,
        layout=widgets.Layout(width='70%')
    )

    run_button = widgets.Button(description="Analyze Patient Trajectory")
    output = widgets.Output()

    def on_run_button_clicked(b):
        """Render the personalised analysis for the currently selected conditions."""
        with output:
            output.clear_output()
            patient_conditions = list(condition_select.value)

            if not patient_conditions:
                print("Please select at least one condition to analyze.")
                return

            # Display personalized trajectory analysis
            display(create_personalized_trajectory_analysis(
                patient_conditions,
                time_horizon_slider.value,
                time_margin_slider.value,
                min_or_slider.value
            ))

    run_button.on_click(on_run_button_clicked)

    # Display widgets
    data_summary = widgets.HTML(
        value=f"<p><strong>Data Summary:</strong><br>"
              f"Maximum progression time: {max_years} years<br>"
              f"Total patients: {total_patients_in_group:,}<br>"
              f"Gender: {gender}<br>"
              f"Age Group: {age_group}</p>"
    )

    display(widgets.HTML("<h2>Patient Trajectory Analysis</h2>"))
    display(data_summary)
    display(widgets.HTML("<p>1. Select the patient's current conditions:</p>"))
    display(condition_select)
    display(widgets.HTML("<p>2. Adjust analysis parameters (optional):</p>"))
    display(time_horizon_slider)
    display(time_margin_slider)
    display(min_or_slider)
    display(widgets.HTML("<p>3. Click to generate analysis:</p>"))
    display(run_button)
    display(output)

def create_network_visualization(patient_conditions, time_horizon=None, time_margin=None, min_or=2.0):
    """Create an interactive pyvis network of a patient's likely progressions.

    Each current condition becomes a large starred node. Every qualifying row
    of ``global_data`` whose ``ConditionA`` is a current condition, and whose
    ``ConditionB`` the patient does not already have, adds a node for
    ``ConditionB`` and a directed edge between the two. Edges point
    ConditionA -> ConditionB when ``DirectionalPercentage`` >= 50 and are
    reversed otherwise; the tooltip then reports the complementary confidence.

    Args:
        patient_conditions: Condition names the patient currently has.
        time_horizon: Optional horizon in years. When it is truthy and
            ``time_margin`` is not None, only pairs whose median duration is
            <= ``time_horizon * (1 + time_margin)`` are kept. A margin of 0
            applies the horizon strictly.
        time_margin: Fractional tolerance added to ``time_horizon``.
        min_or: Minimum odds ratio for a pair to be drawn.

    Returns:
        pyvis.network.Network: the configured, not-yet-rendered network.
    """
    net = Network(notebook=True, bgcolor='white', font_color='black', height="800px", width="100%")

    candidates = global_data[global_data['OddsRatio'] >= min_or]

    # `time_margin is not None` (not truthiness) keeps the filter active at margin 0.
    if time_horizon and time_margin is not None:
        max_duration = time_horizon * (1 + time_margin)
        medians = candidates['MedianDurationYearsWithIQR'].apply(lambda x: parse_iqr(x)[0])
        candidates = candidates[medians <= max_duration]

    # Add nodes for current conditions
    for condition in patient_conditions:
        system = condition_categories.get(condition, 'Other')
        color = SYSTEM_COLORS.get(system, '#808080')
        net.add_node(condition,
                     label=f"★ {condition}",
                     title=f"Current Condition: {condition}\nSystem: {system}",
                     color=color,
                     size=30,
                     borderWidth=2)

    # Add potential progression nodes and edges
    for condition_a in patient_conditions:
        trajectories = candidates[
            (candidates['ConditionA'] == condition_a)
            & ~candidates['ConditionB'].isin(patient_conditions)
        ]

        for _, row in trajectories.iterrows():
            condition_b = row['ConditionB']
            system = condition_categories.get(condition_b, 'Other')
            color = SYSTEM_COLORS.get(system, '#808080')
            net.add_node(condition_b,
                         label=condition_b,
                         title=f"Potential Condition: {condition_b}\nSystem: {system}",
                         color=color,
                         size=20)

            # Orient the edge by the observed order: below 50%, ConditionB
            # usually precedes ConditionA.
            directional = row['DirectionalPercentage']
            if directional >= 50:
                source, target, confidence = condition_a, condition_b, directional
            else:
                source, target, confidence = condition_b, condition_a, 100 - directional

            median, q1, q3 = parse_iqr(row['MedianDurationYearsWithIQR'])
            prevalence = (row['PairFrequency'] / total_patients_in_group * 100) if total_patients_in_group else 0.0
            edge_title = (f"OR: {row['OddsRatio']:.1f}\n"
                          f"Timeline: {median:.1f} years [{q1:.1f}-{q3:.1f}]\n"
                          f"Cases: {row['PairFrequency']} ({prevalence:.1f}%)\n"
                          f"Direction Confidence: {confidence:.1f}% ({source} → {target})")

            # Edge thickness grows with the odds ratio, clamped to [1, 8].
            width = max(1, min(8, math.log2(row['OddsRatio'] + 1)))

            net.add_edge(source,
                         target,
                         title=edge_title,
                         width=width,
                         arrows={'to': {'enabled': True}})

    # Configure network options
    net.set_options("""
    {
        "physics": {
            "enabled": true,
            "forceAtlas2Based": {
                "gravitationalConstant": -50,
                "springLength": 200
            },
            "solver": "forceAtlas2Based",
            "stabilization": {
                "iterations": 50
            }
        },
        "edges": {
            "smooth": {"type": "continuous"},
            "color": {"inherit": false, "color": "#666666"}
        },
        "interaction": {
            "hover": true,
            "tooltipDelay": 100
        }
    }
    """)

    return net

def main():
    """Entry point: print a short introduction, then launch the upload-and-analyse workflow."""
    intro_lines = (
        "Welcome to the Patient Disease Trajectory Analyzer!",
        "\nThis tool helps analyze potential disease progressions based on:",
        "- Current medical conditions",
        "- Population-level statistics",
        "- Time-based progression patterns",
        "\nTo begin, please upload your patient trajectory data file.",
    )
    for intro_line in intro_lines:
        print(intro_line)

    run_patient_prediction()


# Notebook kernels execute cells as __main__, so running this cell starts the tool.
if __name__ == "__main__":
    main()

Welcome to the Patient Disease Trajectory Analyzer!

This tool helps analyze potential disease progressions based on:
- Current medical conditions
- Population-level statistics
- Time-based progression patterns

To begin, please upload your patient trajectory data file.
Please upload your CSV file.


Saving Females_fdr_significant_high_freq_odds_ratio_analysis_below45.csv to Females_fdr_significant_high_freq_odds_ratio_analysis_below45 (2).csv
File loaded successfully. Total patients: 6397
Gender: Female, Age Group: <45


HTML(value='<h2>Patient Trajectory Analysis</h2>')

HTML(value='<p><strong>Data Summary:</strong><br>Maximum progression time: 15 years<br>Total patients: 6,397<b…

HTML(value="<p>1. Select the patient's current conditions:</p>")

SelectMultiple(description='Select conditions:', layout=Layout(height='200px', width='70%'), options=('Anaemia…

HTML(value='<p>2. Adjust analysis parameters (optional):</p>')

IntSlider(value=5, continuous_update=False, description='Time Horizon (years):', layout=Layout(width='70%'), m…

FloatSlider(value=0.1, continuous_update=False, description='Time Margin:', layout=Layout(width='70%'), max=0.…

FloatSlider(value=2.0, continuous_update=False, description='Minimum Odds Ratio:', layout=Layout(width='70%'),…

HTML(value='<p>3. Click to generate analysis:</p>')

Button(description='Analyze Patient Trajectory', style=ButtonStyle())

Output()