In [59]:
import os
import sys
sys.path.append('../src')

%load_ext autoreload
%autoreload 2
from ai_cdss.services.pipeline import PipelineBase

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [60]:
PATIENT_LIST = [
    775,  787,  788, 1123, 1169, 1170, 1171, 1172, 1173, 1983, 2110, 2195,
    2955, 2956, 2957, 2958, 2959, 2960, 2961, 2962, 2963, 3081, 3229, 3318, 3432
]

latent_to_clinical_mapping_nest = {
    # Functional Independence
    "BARTHEL": ["DAILY_LIVING_ACTIVITY"],  # Barthel Index measures independence in ADLs.

    # Motor Function (Spasticity & Strength)
    "ASH_PROXIMAL": ["BODY_PART_ARM", "BODY_PART_SHOULDER", "COORDINATION"],  # Ashworth scale for proximal limb spasticity.
    "MA_DISTAL": ["BODY_PART_FINGER", "BODY_PART_WRIST", "GRASPING", "PINCHING"],  # Motor Assessment for distal motor function.

    # Fatigue & Pain
    "FATIGUE": ["DIFFICULTY_COGNITIVE", "DIFFICULTY_MOTOR", "PROCESSING_SPEED", "ATTENTION"],  # Fatigue relates to cognitive/motor difficulty.
    "VAS": ["DIFFICULTY_COGNITIVE", "DIFFICULTY_MOTOR"],  # Visual Analog Scale (VAS) for perceived effort.

    # Fugl-Meyer Subscales (Motor Control & Coordination)
    "FM_A": ["BODY_PART_ARM", "BODY_PART_SHOULDER", "RANGE_OF_MOTION_H", "RANGE_OF_MOTION_V"],  # Upper Limb Motor
    "FM_B": ["BODY_PART_WRIST", "PRONATION_SUPINATION", "RANGE_OF_MOTION_H"],  # Wrist Motor
    "FM_C": ["BODY_PART_FINGER", "GRASPING", "PINCHING"],  # Hand Motor
    "FM_D": ["COORDINATION", "RANGE_OF_MOTION_H", "RANGE_OF_MOTION_V"],  # Coordination & Speed
    "FM_TOTAL": ["BODY_PART_ARM", "BODY_PART_WRIST", "BODY_PART_FINGER", "COORDINATION"],  # Full Upper Limb Score

    # Activity & Movement Quality
    "ACT_AU": ["BODY_PART_TRUNK"],  # Activity Autonomy linked to balance.
    "ACT_QOM": ["COORDINATION"],  # Quality of Movement related to balance & coordination.
}

pipeline = PipelineBase(
    PATIENT_LIST, 
    clinical_score_path="../../data/clinical_scores.csv", 
    protocol_csv_path="../../data/protocol_attributes.csv",
    mapping_dict=latent_to_clinical_mapping_nest
)

In [61]:
pipeline.extract_data()

Database engine created successfully
Data successfully saved to ../../data/app_data.csv
Database engine closed
Database engine created successfully
Data successfully saved to ../../data/plus_data.csv
Database engine closed
Database engine created successfully
Data successfully saved to ../../data/app_timeseries.csv
Database engine closed
Database engine created successfully
Data successfully saved to ../../data/plus_timeseries.csv
Database engine closed


In [62]:
pipeline.process_data()

In [78]:
pipeline.compute_scores()

In [79]:
pipeline.prescriptions

Unnamed: 0,PATIENT_ID,PROTOCOL_ID,PPF,ADHERENCE_EWMA,PARAMETER_VALUE_EWMA,PERFORMANCE_VALUE_EWMA,Score,CONTRIBUTION
0,775,10,0.145801,1.0,0.0,0.0,1.145801,"[0.0, 0.0, 0.0, 0.12336987390024973, 0.0224308..."
1,775,200,0.246830,1.0,0.0,0.0,1.246830,"[0.0, 0.0, 0.0, 0.09620541511261894, 0.0174918..."
2,775,201,0.629923,1.0,0.0,0.0,1.629923,"[0.366747892208503, 0.0, 0.0, 0.10085567035733..."
3,775,202,0.312915,1.0,0.0,0.0,1.312915,"[0.0, 0.0, 0.0, 0.09415993145159483, 0.0085599..."
4,775,203,0.318869,1.0,0.0,0.0,1.318869,"[0.0, 0.0, 0.0, 0.14447600272668604, 0.0131341..."
...,...,...,...,...,...,...,...,...
770,3432,229,0.035583,1.0,0.0,0.0,1.035583,"[0.0, 0.0, 0.0, 0.03558285063293994, 0.0, 0.0,..."
771,3432,230,0.091694,1.0,0.0,0.0,1.091694,"[0.0, 0.0, 0.0, 0.042320237338746854, 0.0, 0.0..."
772,3432,231,0.047739,1.0,0.0,0.0,1.047739,"[0.0, 0.0, 0.0, 0.047739403709085075, 0.0, 0.0..."
773,3432,232,0.035583,1.0,0.0,0.0,1.035583,"[0.0, 0.0, 0.0, 0.03558285063293994, 0.0, 0.0,..."


In [65]:
scores

Unnamed: 0,PATIENT_ID,PROTOCOL_ID,PPF,ADHERENCE_EWMA,PARAMETER_VALUE_EWMA,PERFORMANCE_VALUE_EWMA,Score
54,775,10,0.145801,1.0,0.0,0.0,1.145801
49,775,200,0.246830,1.0,0.0,0.0,1.246830
43,775,201,0.629923,1.0,0.0,0.0,1.629923
42,775,202,0.312915,1.0,0.0,0.0,1.312915
44,775,203,0.318869,1.0,0.0,0.0,1.318869
...,...,...,...,...,...,...,...
774,3432,229,0.035583,1.0,0.0,0.0,1.035583
772,3432,230,0.091694,1.0,0.0,0.0,1.091694
766,3432,231,0.047739,1.0,0.0,0.0,1.047739
771,3432,232,0.035583,1.0,0.0,0.0,1.035583


In [66]:
scores[["PATIENT_ID", "PROTOCOL_ID", "Score"]]

Unnamed: 0,PATIENT_ID,PROTOCOL_ID,Score
54,775,10,1.145801
49,775,200,1.246830
43,775,201,1.629923
42,775,202,1.312915
44,775,203,1.318869
...,...,...,...
774,3432,229,1.035583
772,3432,230,1.091694
766,3432,231,1.047739
771,3432,232,1.035583


In [67]:
n = 5    
top_n_protocols = scores.groupby('PATIENT_ID').apply(lambda x: x.nlargest(n, 'Score')).reset_index(drop=True)
top_n_protocols

  top_n_protocols = scores.groupby('PATIENT_ID').apply(lambda x: x.nlargest(n, 'Score')).reset_index(drop=True)


Unnamed: 0,PATIENT_ID,PROTOCOL_ID,PPF,ADHERENCE_EWMA,PARAMETER_VALUE_EWMA,PERFORMANCE_VALUE_EWMA,Score
0,775,222,0.629923,0.999770,1.000000,1.000000,2.629693
1,775,206,0.562494,0.853333,1.000000,0.832768,2.415827
2,775,224,0.629923,1.000000,0.673047,0.710473,2.302970
3,775,208,0.675454,0.606097,0.994970,0.974205,2.276521
4,775,214,0.509417,0.394473,0.950000,0.937500,1.853890
...,...,...,...,...,...,...,...
120,3432,214,0.417353,0.940556,0.953125,1.000000,2.311033
121,3432,219,0.551721,1.000000,0.359098,0.515625,1.910819
122,3432,225,0.822870,1.000000,0.000000,0.000000,1.822870
123,3432,227,0.754423,1.000000,0.000000,0.000000,1.754423


In [68]:
# Find top n protocols by score per patient_id
scores.groupby("PATIENT_ID").apply(lambda x: x.nlargest(1, "Score"))

  scores.groupby("PATIENT_ID").apply(lambda x: x.nlargest(1, "Score"))


Unnamed: 0_level_0,Unnamed: 1_level_0,PATIENT_ID,PROTOCOL_ID,PPF,ADHERENCE_EWMA,PARAMETER_VALUE_EWMA,PERFORMANCE_VALUE_EWMA,Score
PATIENT_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
775,45,775,222,0.629923,0.99977,1.0,1.0,2.629693
787,94,787,223,0.578258,1.0,0.461435,0.693477,2.039692
788,64,788,208,0.648354,1.0,0.921875,0.547242,2.570229
1123,126,1123,208,0.662383,0.904965,0.999576,0.980844,2.566924
1169,169,1169,222,0.629731,1.0,1.0,0.999997,2.629731
1170,186,1170,214,0.431142,1.0,0.5125,0.958375,1.943642
1171,231,1171,222,0.676896,1.0,0.611667,0.66591,2.288563
1172,250,1172,208,0.624634,1.0,0.999994,0.895814,2.624628
1173,279,1173,214,0.42909,1.0,0.992676,0.899252,2.421766
1983,331,1983,216,0.846066,1.0,0.0,0.0,1.846066


In [69]:
pipeline.data_processor.get_protocol(protocol_path="../../data/protocol_attributes.csv")

Unnamed: 0_level_0,PROTOCOL_ID,DIFFICULTY_COGNITIVE,DIFFICULTY_MOTOR,BODY_PART_FINGER,BODY_PART_WRIST,BODY_PART_ARM,BODY_PART_SHOULDER,BODY_PART_TRUNK,REACHING,GRASPING,...,ATTENTION,VISUAL_LANGUAGE,VISUALSPATIAL_PROCESSING_AWARENESS_NEGLECT,COORDINATION,MEMORY_WM,MEMORY_SEMANTIC,MATH,DAILY_LIVING_ACTIVITY,SYMBOLIC_UNDERSTANDING,SEMANTIC_PROCESSING
PROTOCOL_NAME,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Blobs,214,2,2,1,0,0,0,0,0,0,...,1,0,1,1,1,0,1,1,0,0
Twister,223,2,2,1,0,0,0,0,1,1,...,1,0,1,1,1,0,0,1,0,0
Shelves (AR),208,1,1,0,0,1,1,0,1,1,...,0,1,0,0,1,1,0,1,0,0
Place it,204,2,2,0,0,1,1,0,1,1,...,0,1,1,1,1,1,0,1,0,1
Alphabet (AR),205,1,1,0,0,1,1,0,1,0,...,0,0,1,1,0,0,0,1,0,0
Hockey (AR),219,0,2,0,0,1,1,0,1,1,...,0,1,1,1,1,1,0,1,0,0
Ducks (AR),209,0,2,0,0,1,1,1,1,1,...,1,1,1,1,1,1,0,1,0,0
Balloons (AR),206,0,2,0,0,1,1,0,1,0,...,1,1,1,1,1,1,0,1,0,0
Tubes (AR),226,2,2,0,1,1,1,0,1,1,...,1,1,1,1,1,1,0,1,0,0
Fishing day (AR),221,1,0,0,0,1,1,0,1,1,...,1,1,1,1,1,1,0,1,0,0


In [70]:
pipeline.sessions

Unnamed: 0,PATIENT_ID,HOSPITAL_ID,PARETIC_SIDE,UPPER_EXTREMITY_TO_TRAIN,HAND_RAISING_CAPACITY,COGNITIVE_FUNCTION_LEVEL,HAS_HEMINEGLIGENCE,GENDER,SKIN_COLOR,AGE,...,AR_MODE,WEEKDAY,REAL_SESSION_DURATION,PRESCRIBED_SESSION_DURATION,SESSION_DURATION,ADHERENCE,TOTAL_SUCCESS,TOTAL_ERRORS,SCORE,ADHERENCE_EWMA
0,775,40,LEFT,LEFT,LOW,MEDIUM,0,FEMALE,FDC3AD,88.0,...,NONE,FRIDAY,492.0,300.0,300,1.000000,99,8,231,1.000000
1,775,40,LEFT,LEFT,LOW,MEDIUM,0,FEMALE,FDC3AD,88.0,...,NONE,FRIDAY,338.0,300.0,300,1.000000,64,17,88,1.000000
2,775,40,LEFT,LEFT,LOW,MEDIUM,0,FEMALE,FDC3AD,88.0,...,TABLE,FRIDAY,280.0,240.0,240,1.000000,0,0,0,1.000000
3,775,40,LEFT,LEFT,LOW,MEDIUM,0,FEMALE,FDC3AD,88.0,...,TABLE,FRIDAY,391.0,300.0,300,1.000000,1,2,1,1.000000
4,775,40,LEFT,LEFT,LOW,MEDIUM,0,FEMALE,FDC3AD,88.0,...,NONE,MONDAY,472.0,300.0,300,1.000000,86,10,222,1.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2797,3432,41,RIGHT,RIGHT,HIGH,MEDIUM_HIGH,0,FEMALE,D09C80,55.0,...,NONE,MONDAY,309.0,300.0,300,1.000000,21,0,21,1.000000
2798,3432,41,RIGHT,RIGHT,HIGH,MEDIUM_HIGH,0,FEMALE,D09C80,55.0,...,NONE,MONDAY,336.0,300.0,300,1.000000,15,21,42,1.000000
2799,3432,41,RIGHT,RIGHT,HIGH,MEDIUM_HIGH,0,FEMALE,D09C80,55.0,...,NONE,MONDAY,360.0,300.0,300,1.000000,0,0,220,0.998752
2800,3432,41,RIGHT,RIGHT,HIGH,MEDIUM_HIGH,0,FEMALE,D09C80,55.0,...,NONE,MONDAY,400.0,300.0,300,1.000000,81,187,81,1.000000


In [71]:
pipeline.compute_scores()

In [72]:
pipeline.protocol_profiles.columns

Index(['BARTHEL', 'ASH_PROXIMAL', 'MA_DISTAL', 'FATIGUE', 'VAS', 'FM_A',
       'FM_B', 'FM_C', 'FM_D', 'FM_TOTAL', 'ACT_AU', 'ACT_QOM'],
      dtype='object')

In [73]:
pipeline.scoring_computer.compute_protocol_similarity(pipeline.protocol_profiles)

PROTOCOL_ID,214,223,208,204,205,219,209,206,226,221,...,216,231,10,211,220,228,232,230,233,229
PROTOCOL_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
214,1.0,0.944444,0.622222,0.813889,0.605556,0.661111,0.530556,0.558333,0.744444,0.684722,...,0.573611,0.538889,0.525,0.455556,0.380556,0.441667,0.559722,0.548611,0.518056,0.559722
223,0.944444,1.0,0.566667,0.758333,0.55,0.605556,0.475,0.502778,0.688889,0.629167,...,0.518056,0.483333,0.469444,0.4,0.325,0.386111,0.504167,0.493056,0.4625,0.504167
208,0.622222,0.566667,1.0,0.719444,0.761111,0.816667,0.658333,0.686111,0.677778,0.8125,...,0.695833,0.661111,0.647222,0.577778,0.702778,0.708333,0.640278,0.684722,0.715278,0.640278
204,0.813889,0.758333,0.719444,1.0,0.791667,0.847222,0.716667,0.744444,0.902778,0.809722,...,0.415278,0.380556,0.366667,0.297222,0.455556,0.427778,0.401389,0.473611,0.504167,0.401389
205,0.605556,0.55,0.761111,0.791667,1.0,0.944444,0.786111,0.925,0.805556,0.851389,...,0.568056,0.533333,0.519444,0.45,0.663889,0.636111,0.5125,0.556944,0.5875,0.5125
219,0.661111,0.605556,0.816667,0.847222,0.944444,1.0,0.841667,0.869444,0.861111,0.906944,...,0.5125,0.477778,0.463889,0.394444,0.608333,0.580556,0.456944,0.501389,0.531944,0.456944
209,0.530556,0.475,0.658333,0.716667,0.786111,0.841667,1.0,0.861111,0.730556,0.748611,...,0.354167,0.319444,0.305556,0.236111,0.483333,0.422222,0.298611,0.370833,0.401389,0.298611
206,0.558333,0.502778,0.686111,0.744444,0.925,0.869444,0.861111,1.0,0.758333,0.776389,...,0.493056,0.458333,0.444444,0.375,0.622222,0.561111,0.4375,0.509722,0.540278,0.4375
226,0.744444,0.688889,0.677778,0.902778,0.805556,0.861111,0.730556,0.758333,1.0,0.768056,...,0.373611,0.338889,0.325,0.255556,0.469444,0.441667,0.359722,0.431944,0.4625,0.359722
221,0.684722,0.629167,0.8125,0.809722,0.851389,0.906944,0.748611,0.776389,0.768056,1.0,...,0.605556,0.501389,0.515278,0.4875,0.543056,0.576389,0.480556,0.525,0.555556,0.480556


In [74]:
from ai_cdss.services.prescription import PrescriptionRecommender
pipeline.prescription_recommender = PrescriptionRecommender()

In [75]:
pipeline.ppf_matrix[0]

KeyError: 0

In [None]:
pipeline.prescription_recommender.recommend_protocols(pipeline.ppf_matrix[0])