In [1]:
def _by_sex(male, female):
    """Return a reference interval mapping keyed by 'male' / 'female'."""
    return {'male': male, 'female': female}


def _entry(normal, low, high):
    """Bundle one parameter's normal interval with its low/high condition lists."""
    return {'normal': normal, 'low': low, 'high': high}


# Reference table used for rule-based labelling.
# For every blood parameter:
#   'normal' - either a (min, max) tuple or a {'male': ..., 'female': ...} mapping of such tuples
#   'low'    - condition labels attached when the value is below the normal minimum
#   'high'   - condition labels attached when the value is above the normal maximum
# An empty list means no label is attached for that direction.
parameter_ranges = {
    'Hemoglobin': _entry(
        _by_sex(male=(13.5, 17.5), female=(12.0, 15.5)),
        low=['Anemia', 'Blood loss', 'Chronic disease', 'Nutritional deficiency', 'Bone marrow disorder', 'Kidney disease'],
        high=['Dehydration', 'Polycythemia vera', 'Lung disease', 'High altitude adaptation'],
    ),
    'RBC': _entry(
        _by_sex(male=(4.5, 5.9), female=(4.0, 5.2)),
        low=['Anemia', 'Bone marrow failure', 'Nutritional deficiency', 'Chronic inflammation', 'Hemolysis'],
        high=['Dehydration', 'Polycythemia vera', 'Hypoxia', 'Kidney tumor'],
    ),
    'HCT': _entry(
        _by_sex(male=(40, 50), female=(36, 46)),
        low=['Anemia', 'Bleeding', 'Nutritional deficiency', 'Bone marrow disorder'],
        high=['Dehydration', 'Polycythemia vera', 'Chronic lung disease'],
    ),
    'MCV': _entry(
        (80, 100),
        low=['Iron deficiency anemia', 'Thalassemia', 'Chronic disease'],
        high=['Vitamin B12 deficiency', 'Folate deficiency', 'Liver disease', 'Hypothyroidism'],
    ),
    'MCH': _entry(
        (27, 33),
        low=['Iron deficiency anemia', 'Thalassemia'],
        high=['Macrocytic anemia', 'Reticulocytosis'],
    ),
    'MCHC': _entry(
        (32, 36),
        low=['Iron deficiency anemia', 'Thalassemia'],
        high=['Hereditary spherocytosis', 'Hemoglobin C disease'],
    ),
    'RDW-CV': _entry(
        (11.5, 14.5),
        low=[],
        high=['Iron deficiency anemia', 'Vitamin B12 deficiency', 'Hemoglobinopathy', 'Myelodysplasia'],
    ),
    # Same direction of interpretation as RDW-CV, with a shorter label list.
    'RDW-SD': _entry(
        (39, 46),
        low=[],
        high=['Iron deficiency anemia', 'Vitamin B12 deficiency'],
    ),
    'WBC': _entry(
        (4.0, 11.0),
        low=['Viral infection', 'Bone marrow disorder', 'Autoimmune disease', 'Severe infection'],
        high=['Bacterial infection', 'Leukemia', 'Inflammation', 'Stress response'],
    ),
    'NEU%': _entry(
        (40, 70),
        low=['Viral infection', 'Autoimmune disorder', 'Chemotherapy effect'],
        high=['Bacterial infection', 'Acute inflammation', 'Steroid use'],
    ),
    'LYM%': _entry(
        (20, 40),
        low=['HIV/AIDS', 'Immunosuppression', 'Radiation exposure'],
        high=['Viral infection', 'Chronic infection', 'Lymphoma'],
    ),
    'MON%': _entry(
        (2, 10),
        low=[],
        high=['Chronic infection', 'Autoimmune disease', 'Myeloproliferative disorder'],
    ),
    'EOS%': _entry(
        (0, 6),
        low=[],
        high=['Allergic disorder', 'Parasitic infection', 'Autoimmune disease'],
    ),
    'BAS%': _entry(
        (0, 2),
        low=[],
        high=['Allergic reaction', 'Chronic inflammation', 'Myeloproliferative disorder'],
    ),
    'LYM#': _entry(
        (1.0, 4.0),
        low=['HIV/AIDS', 'Immunosuppression'],
        high=['Viral infection', 'Lymphoma'],
    ),
    'GRA#': _entry(
        (1.8, 7.0),
        low=['Chemotherapy effect', 'Bone marrow failure'],
        high=['Bacterial infection', 'Inflammation'],
    ),
    'PLT': _entry(
        (150, 450),
        low=['Viral infection', 'Autoimmune disorder', 'Bone marrow disorder'],
        high=['Inflammation', 'Iron deficiency', 'Myeloproliferative disorder'],
    ),
    'ESR': _entry(
        _by_sex(male=(0, 15), female=(0, 20)),
        low=[],
        high=['Inflammation', 'Infection', 'Autoimmune disease', 'Malignancy'],
    ),
}

In [2]:
def generate_labels(row, reference_ranges=None):
    """Derive rule-based condition labels for a single blood report.

    Every parameter in the reference table that is present (and not NaN) in
    ``row`` is compared against its normal interval. Values below the minimum
    contribute the parameter's 'low' labels, values above the maximum its
    'high' labels. Parameters with sex-specific intervals are skipped when the
    row's sex is missing or not a key of the interval mapping.

    Args:
        row: Mapping-like record (e.g. a ``pandas.Series``) holding a 'Sex'
            entry and any number of blood parameter values.
        reference_ranges: Optional table with the same layout as the
            module-level ``parameter_ranges``. Defaults to that table.

    Returns:
        Alphabetically sorted list of distinct condition labels, or
        ``['Normal']`` when no label was triggered.

    Raises:
        KeyError: If ``row`` has no 'Sex' entry.
    """
    if reference_ranges is None:
        reference_ranges = parameter_ranges

    # Only a string can name a sex; NaN/None/other types mean "unknown".
    # Whitespace and case are normalised so ' Male ' matches the 'male' key.
    raw_sex = row['Sex']
    sex = raw_sex.strip().lower() if isinstance(raw_sex, str) else None

    conditions = set()
    for param, ranges in reference_ranges.items():
        if param not in row or pd.isna(row[param]):
            continue

        value = row[param]
        normal = ranges['normal']

        if isinstance(normal, dict):
            # Sex-specific interval: cannot be evaluated without a known sex.
            if sex not in normal:
                continue
            normal_min, normal_max = normal[sex]
        else:
            normal_min, normal_max = normal

        if value < normal_min:
            conditions.update(ranges.get('low', ()))
        elif value > normal_max:
            conditions.update(ranges.get('high', ()))

    # Sorting makes the label order reproducible across runs; iterating a set
    # of strings directly depends on per-process hash randomisation.
    return sorted(conditions) if conditions else ['Normal']

In [3]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import classification_report, hamming_loss
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.impute import SimpleImputer
import joblib

In [4]:
# Load the synthetic reports and attach one rule-derived label list per row.
df = pd.read_csv('synthetic_blood_reports.csv').assign(
    Conditions=lambda reports: reports.apply(generate_labels, axis=1)
)

In [5]:
# Preview the first five labelled reports.
df.head(n=5)

Unnamed: 0,Age,Sex,Hemoglobin,RBC,WBC,PLT,MCV,MCH,MCHC,RDW-CV,NEU%,LYM%,MON%,EOS%,BAS%,LYM#,GRA#,ESR,Conditions
0,69,male,14.9,4.64,6.9,214.0,93.0,28.9,34.7,12.3,46.0,49.1,8.9,0.1,0.0,3.39,3.17,10.0,"[Lymphoma, Viral infection, Chronic infection]"
1,32,male,15.0,5.15,9.5,291.0,91.0,32.0,34.5,13.4,70.6,43.8,5.3,2.6,0.0,4.15,6.68,8.0,"[Acute inflammation, Viral infection, Lymphoma..."
2,89,female,13.3,4.18,9.8,263.0,97.0,30.5,34.8,13.0,57.2,46.1,7.9,2.2,0.0,4.5,5.59,14.0,"[Lymphoma, Viral infection, Chronic infection]"
3,78,male,15.0,5.17,9.5,159.0,88.0,31.1,33.4,13.4,56.1,33.7,6.8,1.1,2.3,3.21,5.35,14.0,"[Allergic reaction, Myeloproliferative disorde..."
4,38,male,13.9,4.45,7.7,296.0,89.0,31.5,35.5,13.9,60.3,37.6,8.7,2.1,0.0,2.91,4.66,17.0,"[Infection, Anemia, Chronic inflammation, Nutr..."


In [6]:
# Features are every column except the label lists.
X = df.drop(columns=['Conditions'])

# Convert to multi-label format: one 0/1 indicator column per distinct condition.
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(df['Conditions'])

In [7]:
# Preprocess features
# Median-impute every numeric column. Calling fit() and then transform() on the
# same frame is what fit_transform() does in a single call.
# NOTE(review): the imputer is fitted before the train/test split further down,
# so held-out rows contribute to the medians; confirm this is acceptable.
num_cols = X.select_dtypes(include=np.number).columns
imputer = SimpleImputer(strategy='median').fit(X[num_cols])
X[num_cols] = imputer.transform(X[num_cols])

# Expand the categorical 'Sex' column into one indicator column per category.
X = pd.get_dummies(X, columns=['Sex'])

In [8]:
# Preview the first five rows of the preprocessed feature matrix.
X.head(n=5)

Unnamed: 0,Age,Hemoglobin,RBC,WBC,PLT,MCV,MCH,MCHC,RDW-CV,NEU%,LYM%,MON%,EOS%,BAS%,LYM#,GRA#,ESR,Sex_female,Sex_male
0,69.0,14.9,4.64,6.9,214.0,93.0,28.9,34.7,12.3,46.0,49.1,8.9,0.1,0.0,3.39,3.17,10.0,False,True
1,32.0,15.0,5.15,9.5,291.0,91.0,32.0,34.5,13.4,70.6,43.8,5.3,2.6,0.0,4.15,6.68,8.0,False,True
2,89.0,13.3,4.18,9.8,263.0,97.0,30.5,34.8,13.0,57.2,46.1,7.9,2.2,0.0,4.5,5.59,14.0,True,False
3,78.0,15.0,5.17,9.5,159.0,88.0,31.1,33.4,13.4,56.1,33.7,6.8,1.1,2.3,3.21,5.35,14.0,False,True
4,38.0,13.9,4.45,7.7,296.0,89.0,31.5,35.5,13.9,60.3,37.6,8.7,2.1,0.0,2.91,4.66,17.0,False,True


In [9]:
# Hold out 20% of the rows for evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42, test_size=0.2)

In [10]:
# Base learner: a class-balanced random forest, cloned once per condition column.
forest = RandomForestClassifier(
    random_state=42,
    class_weight='balanced',
    min_samples_split=5,
    max_depth=12,
    n_estimators=200,
)

# n_jobs=-1 lets the per-condition forests be fitted in parallel.
model = MultiOutputClassifier(forest, n_jobs=-1)

model.fit(X_train, y_train)

In [11]:
y_pred = model.predict(X_test)

# Subset (exact-match) accuracy: a row counts as correct only when every
# condition column is predicted correctly. This is the quantity
# model.score(X_test, y_test) reports; computing it from y_pred avoids running
# a second full prediction pass over the test set.
subset_accuracy = np.mean(np.all(y_test == y_pred, axis=1))

print("\n Test Accuracy:")
print(subset_accuracy)

print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=mlb.classes_, zero_division=0))
# Hamming loss: fraction of individual (row, condition) cells predicted wrongly.
print("\nHamming Loss:", hamming_loss(y_test, y_pred))



 Test Accuracy:
0.953

Classification Report:
                             precision    recall  f1-score   support

         Acute inflammation       1.00      1.00      1.00       133
          Allergic reaction       1.00      1.00      1.00       950
                     Anemia       1.00      1.00      1.00       926
         Autoimmune disease       1.00      1.00      1.00       788
        Autoimmune disorder       1.00      1.00      1.00        56
        Bacterial infection       1.00      1.00      1.00       219
                 Blood loss       1.00      1.00      1.00       332
       Bone marrow disorder       1.00      1.00      1.00       378
        Bone marrow failure       1.00      1.00      1.00       839
            Chronic disease       1.00      1.00      1.00       363
          Chronic infection       1.00      1.00      1.00       614
       Chronic inflammation       1.00      1.00      1.00      1381
                Dehydration       1.00      0.82      0

In [12]:
# Persist everything inference needs: the fitted model, the label binarizer,
# the fitted imputer and the exact training column order.
artifacts = {
    'blood_report_model.pkl': model,
    'label_binarizer.pkl': mlb,
    'imputer.pkl': imputer,
    'training_columns.pkl': X_train.columns.tolist(),
}
for artifact_path, artifact in artifacts.items():
    joblib.dump(artifact, artifact_path)

print("Saved")

Saved


In [13]:
def predict_patient_conditions(patient_data, threshold=0.0):
    """Score one patient's blood panel with the persisted model.

    The function depends only on the artifacts saved to disk (model, label
    binarizer, imputer, training column list); it does not read any notebook
    variable, so it also works in a fresh process.

    Args:
        patient_data: Mapping of column name to value for a single patient.
            Must contain 'Sex'. Numeric measurements the imputer was fitted on
            but that are absent here are filled with the imputer's learned
            value. Keys that are not training columns are ignored.
        threshold: A condition is reported only when its predicted probability
            is strictly greater than this value. The default of 0.0 reports
            every condition with a non-zero probability.

    Returns:
        dict mapping condition name to its probability rounded to two
        decimals, or ``{'Normal': 1.0}`` when no condition exceeds
        ``threshold``.

    Raises:
        KeyError: If ``patient_data`` has no 'Sex' entry.
    """
    # NOTE: joblib.load unpickles arbitrary objects. Only load artifact files
    # that were produced by a trusted training run.
    model = joblib.load('blood_report_model.pkl')
    mlb = joblib.load('label_binarizer.pkl')
    imputer = joblib.load('imputer.pkl')
    train_cols = joblib.load('training_columns.pkl')

    patient_df = pd.DataFrame([patient_data])

    # Take the numeric column list from the fitted imputer itself instead of
    # from a notebook-global feature frame.
    num_cols = list(imputer.feature_names_in_)

    # reindex() adds any missing measurement as NaN so the imputer can fill it,
    # and puts the columns in the order the imputer was fitted on.
    numeric = pd.DataFrame(
        imputer.transform(patient_df.reindex(columns=num_cols)),
        columns=num_cols,
        index=patient_df.index,
    )
    sex_dummies = pd.get_dummies(patient_df[['Sex']], columns=['Sex'])

    # Align with the training layout: indicator columns for categories not
    # present in this single row are added as 0, extra columns are dropped.
    features = pd.concat([numeric, sex_dummies], axis=1).reindex(
        columns=train_cols, fill_value=0
    )

    # predict_proba returns one (n_samples, n_classes) array per condition.
    probs = model.predict_proba(features)
    predictions = {}

    for i, condition in enumerate(mlb.classes_):
        # Locate the "present" class explicitly: a condition that had a single
        # class in the training split yields a one-column probability array.
        label_classes = list(model.estimators_[i].classes_)
        if 1 not in label_classes:
            continue
        prob = float(probs[i][0][label_classes.index(1)])
        if prob > threshold:
            predictions[condition] = round(prob, 2)

    return predictions or {'Normal': 1.0}

# Example 1: a panel with many out-of-range values. It has its own name so that
# the second example below does not silently overwrite it.
abnormal_patient = {
    'Age': 58,
    'Sex': 'female',
    'Hemoglobin': 10.8,  # Low (anemia range)
    'RBC': 3.5,         # Low
    'HCT': 32,          # Low
    'MCV': 72,          # Low (microcytic)
    'MCH': 22,          # Low
    'MCHC': 29,         # Low
    'RDW-CV': 18.5,     # High
    'RDW-SD': 52,       # High
    'WBC': 3.2,         # Low
    'NEU%': 38,         # Low (reference range in parameter_ranges is 40-70)
    'LYM%': 55,         # High (lymphocytosis)
    'MON%': 5,          # Normal
    'EOS%': 2,          # Normal
    'BAS%': 0,          # Normal
    'LYM#': 1.76,       # Calculated
    'GRA#': 1.22,       # Calculated
    'PLT': 112,         # Low
    'ESR': 42           # High
}

# Example 2: a panel where every value lies inside its reference range.
normal_patient = {
    'Age': 45,                          # years
    'Sex': 'male',                      # 'male' or 'female'
    'Hemoglobin': 14.3,                 # g/dL
    'RBC': 5.1,                         # million cells/μL
    'HCT': 44,                          # %
    'MCV': 88,                          # fL
    'MCH': 29,                          # pg
    'MCHC': 33,                         # g/dL
    'RDW-CV': 13,                       # %
    'RDW-SD': 42,                       # fL
    'WBC': 6.5,                         # ×10³/μL
    'NEU%': 58,                         # %
    'LYM%': 35,                         # %
    'MON%': 5,                          # %
    'EOS%': 2,                          # %
    'BAS%': 0,                          # %
    'LYM#': 2.3,                        # ×10³/μL
    'GRA#': 3.8,                        # ×10³/μL
    'PLT': 240,                         # ×10³/μL
    'ESR': 10                           # mm/hr
}

# The demo scores the normal panel; assign abnormal_patient instead to try the
# other example. Keys that are not training columns are dropped before prediction.
sample_patient = normal_patient

print("Predicted Conditions:", predict_patient_conditions(sample_patient))

Predicted Conditions: {'Allergic reaction': 0.04, 'Anemia': 0.02, 'Autoimmune disease': 0.02, 'Autoimmune disorder': 0.01, 'Bacterial infection': 0.0, 'Blood loss': 0.01, 'Bone marrow disorder': 0.03, 'Bone marrow failure': 0.03, 'Chronic disease': 0.02, 'Chronic infection': 0.02, 'Chronic inflammation': 0.16, 'Dehydration': 0.0, 'Folate deficiency': 0.0, 'Hemoglobinopathy': 0.01, 'Hemolysis': 0.03, 'High altitude adaptation': 0.0, 'Hypothyroidism': 0.0, 'Infection': 0.01, 'Inflammation': 0.01, 'Iron deficiency': 0.02, 'Iron deficiency anemia': 0.02, 'Kidney disease': 0.01, 'Liver disease': 0.0, 'Lung disease': 0.0, 'Lymphoma': 0.01, 'Macrocytic anemia': 0.01, 'Malignancy': 0.01, 'Myelodysplasia': 0.01, 'Myeloproliferative disorder': 0.04, 'Normal': 0.83, 'Nutritional deficiency': 0.02, 'Polycythemia vera': 0.0, 'Reticulocytosis': 0.01, 'Thalassemia': 0.02, 'Viral infection': 0.01, 'Vitamin B12 deficiency': 0.02}
