In [9]:
import pandas as pd
import numpy as np
import pickle
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.model_selection import KFold, train_test_split
import time

In [17]:
# MIMIC-III Baseline Model Training: TF-IDF + Logistic Regression with Cross Validation


# 1. Load the preprocessed data
print("Loading preprocessed MIMIC-III data...")
data_path = '../data/'
processed_data_file = os.path.join(data_path, 'mimic3_data.pkl')
data = pd.read_pickle(processed_data_file)

print(f"Loaded {len(data)} records")

# 2. Create train and test sets (80% train, 20% test)
print("Creating train and test splits...")
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
print(f"Created splits: {len(train_data)} train, {len(test_data)} test")

# 3. Process ICD codes - build a set of all unique codes
# Modified code to handle mixed data types in ICD codes
print("Processing ICD codes...")
all_codes = set()
for codes in data['ICD9_CODE']:
    # Convert any non-string codes to strings and filter out NaN values
    valid_codes = [str(code) for code in codes if pd.notna(code)]
    all_codes.update(valid_codes)

# Now all codes should be strings, so sorting will work
code_to_idx = {code: i for i, code in enumerate(sorted(all_codes))}
idx_to_code = {i: code for code, i in code_to_idx.items()}

num_codes = len(code_to_idx)
print(f"Total unique ICD-9 codes: {num_codes}")


# 4. Create features and labels
def prepare_data(dataset):
    texts = dataset['TEXT'].tolist()
    
    # Convert ICD codes to multi-hot encoded vectors
    labels = np.zeros((len(dataset), num_codes), dtype=np.int8)
    for i, codes in enumerate(dataset['ICD9_CODE']):
        for code in codes:
            if code in code_to_idx:  # Just in case there's an unknown code
                labels[i, code_to_idx[code]] = 1
    
    return texts, labels

print("Preparing data...")
train_texts, train_labels = prepare_data(train_data)
test_texts, test_labels = prepare_data(test_data)

# 5. Create TF-IDF features
print("Creating TF-IDF features...")
start_time = time.time()
tfidf = TfidfVectorizer(
    max_features=10000,  # Limit features to reduce dimensionality
    min_df=5,            # Ignore terms that appear in less than 5 documents
    max_df=0.5,          # Ignore terms that appear in more than 50% of documents
    ngram_range=(1, 2),  # Use unigrams and bigrams
    stop_words='english' # Remove English stop words
)

train_features = tfidf.fit_transform(train_texts)
test_features = tfidf.transform(test_texts)

print(f"TF-IDF features shape: {train_features.shape}")
print(f"TF-IDF processing time: {time.time() - start_time:.2f} seconds")




Loading preprocessed MIMIC-III data...
Loaded 52726 records
Creating train and test splits...
Created splits: 42180 train, 10546 test
Processing ICD codes...
Total unique ICD-9 codes: 6918
Preparing data...
Creating TF-IDF features...
TF-IDF features shape: (42180, 10000)
TF-IDF processing time: 65.11 seconds


In [None]:
# 6. Cross-validation and model training
print("Setting up cross-validation...")
n_folds = 5
kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)

cv_results = {
    'fold': [],
    'micro_f1': [],
    'macro_f1': [],
    'micro_precision': [],
    'macro_precision': [],
    'micro_recall': [],
    'macro_recall': []
}

def evaluate(model, features, labels):
    # Get predictions
    pred_probs = model.predict_proba(features)
    
    # Apply threshold of 0.5 for binary classification
    preds = np.zeros_like(labels)
    for i in range(num_codes):
        # Check if this classifier exists (it may not if there was only one class for this code)
        if i < len(pred_probs) and len(pred_probs[i]) > 0:
            preds[:, i] = (pred_probs[i][:, 1] >= 0.5).astype(int)
    
    # Calculate micro and macro F1 scores
    micro_f1 = f1_score(labels, preds, average='micro', zero_division=0)
    macro_f1 = f1_score(labels, preds, average='macro', zero_division=0)
    
    # Calculate micro and macro precision and recall
    micro_precision = precision_score(labels, preds, average='micro', zero_division=0)
    macro_precision = precision_score(labels, preds, average='macro', zero_division=0)
    micro_recall = recall_score(labels, preds, average='micro', zero_division=0)
    macro_recall = recall_score(labels, preds, average='macro', zero_division=0)
    
    return {
        'micro_f1': micro_f1,
        'macro_f1': macro_f1,
        'micro_precision': micro_precision,
        'macro_precision': macro_precision,
        'micro_recall': micro_recall,
        'macro_recall': macro_recall
    }

print("Performing cross-validation...")
fold = 1
best_model = None
best_micro_f1 = 0

for train_idx, val_idx in kf.split(train_features):
    print(f"\nTraining fold {fold}/{n_folds}...")
    
    # Get fold data
    X_train_fold, X_val_fold = train_features[train_idx], train_features[val_idx]
    y_train_fold, y_val_fold = train_labels[train_idx], train_labels[val_idx]
    
    # Create a list to store estimators
    estimators = []
    trained_indices = []
    
    # Train a separate model for each ICD code that has both positive and negative examples
    for i in range(num_codes):
        # Check if there are both positive and negative examples for this code
        if np.sum(y_train_fold[:, i] == 0) > 0 and np.sum(y_train_fold[:, i] == 1) > 0:
            print(f"  Training classifier for code {i+1}/{num_codes}", end='\r')
            estimator = LogisticRegression(
                C=1.0,
                solver='saga',
                penalty='l1',
                max_iter=500,
                n_jobs=2, 
                verbose=0,
                random_state=42
            )
            try:
                estimator.fit(X_train_fold, y_train_fold[:, i])
                estimators.append(estimator)
                trained_indices.append(i)
            except Exception as e:
                print(f"  Error training classifier for code {i}: {e}")
        else:
            # Skip codes without both positive and negative examples
            pass
    
    print(f"  Trained {len(estimators)}/{num_codes} classifiers")
    
    # Create a model object to hold the estimators
    model = MultiOutputClassifier(LogisticRegression())
    model.estimators_ = estimators
    model.classes_ = [np.array([0, 1]) for _ in estimators]
    model.trained_indices = trained_indices  # Store which indices were trained
    
    # Evaluate on validation set
    metrics = evaluate(model, X_val_fold, y_val_fold)
    
    print(f"Fold {fold} validation results:")
    print(f"  Micro F1: {metrics['micro_f1']:.4f}")
    print(f"  Macro F1: {metrics['macro_f1']:.4f}")
    
    # Save metrics
    cv_results['fold'].append(fold)
    for key, value in metrics.items():
        cv_results[key].append(value)
    
    # Keep track of best model
    if metrics['micro_f1'] > best_micro_f1:
        best_micro_f1 = metrics['micro_f1']
        best_model = model
        print(f"  New best model found with micro F1: {best_micro_f1:.4f}")
    
    fold += 1

Setting up cross-validation...
Performing cross-validation...

Training fold 1/5...
  Training classifier for code 23/6918



  Training classifier for code 78/6918



  Training classifier for code 79/6918



  Training classifier for code 88/6918



  Training classifier for code 102/6918



  Training classifier for code 106/6918



  Training classifier for code 109/6918



  Training classifier for code 161/6918



  Training classifier for code 182/6918



  Training classifier for code 208/6918



  Training classifier for code 256/6918



  Training classifier for code 291/6918



  Training classifier for code 361/6918



  Training classifier for code 370/6918



  Training classifier for code 371/6918



  Training classifier for code 372/6918



  Training classifier for code 374/6918



  Training classifier for code 384/6918



  Training classifier for code 401/6918



  Training classifier for code 406/6918



  Training classifier for code 456/6918



  Training classifier for code 495/6918



  Training classifier for code 497/6918



  Training classifier for code 507/6918



  Training classifier for code 530/6918



  Training classifier for code 556/6918



  Training classifier for code 589/6918



  Training classifier for code 607/6918



  Training classifier for code 662/6918



  Training classifier for code 718/6918



  Training classifier for code 764/6918



  Training classifier for code 767/6918



  Training classifier for code 771/6918



  Training classifier for code 784/6918



  Training classifier for code 812/6918



  Training classifier for code 825/6918



  Training classifier for code 826/6918



  Training classifier for code 828/6918



  Training classifier for code 831/6918



  Training classifier for code 837/6918



  Training classifier for code 839/6918



  Training classifier for code 840/6918



  Training classifier for code 843/6918



  Training classifier for code 854/6918



  Training classifier for code 855/6918



  Training classifier for code 856/6918



  Training classifier for code 877/6918



  Training classifier for code 913/6918



  Training classifier for code 926/6918



  Training classifier for code 965/6918



  Training classifier for code 971/6918



  Training classifier for code 972/6918



  Training classifier for code 974/6918



  Training classifier for code 975/6918



  Training classifier for code 983/6918



  Training classifier for code 985/6918



  Training classifier for code 991/6918



  Training classifier for code 1000/6918



  Training classifier for code 1011/6918



  Training classifier for code 1012/6918



  Training classifier for code 1017/6918



  Training classifier for code 1034/6918



  Training classifier for code 1077/6918



  Training classifier for code 1096/6918



  Training classifier for code 1134/6918



  Training classifier for code 1157/6918



  Training classifier for code 1184/6918



  Training classifier for code 1201/6918



  Training classifier for code 1214/6918



  Training classifier for code 1221/6918



  Training classifier for code 1231/6918



  Training classifier for code 1262/6918



  Training classifier for code 1267/6918



  Training classifier for code 1307/6918



  Training classifier for code 1350/6918



  Training classifier for code 1378/6918



  Training classifier for code 1390/6918



  Training classifier for code 1437/6918



  Training classifier for code 1462/6918



  Training classifier for code 1477/6918



  Training classifier for code 1492/6918



  Training classifier for code 1525/6918



  Training classifier for code 1531/6918



  Training classifier for code 1558/6918



  Training classifier for code 1563/6918



  Training classifier for code 1577/6918



  Training classifier for code 1581/6918



  Training classifier for code 1582/6918



  Training classifier for code 1583/6918



  Training classifier for code 1585/6918



  Training classifier for code 1586/6918



  Training classifier for code 1589/6918



  Training classifier for code 1597/6918



  Training classifier for code 1648/6918



  Training classifier for code 1691/6918



  Training classifier for code 1698/6918



  Training classifier for code 1738/6918



  Training classifier for code 1752/6918



  Training classifier for code 1761/6918



  Training classifier for code 1783/6918



  Training classifier for code 1938/6918



  Training classifier for code 1939/6918



  Training classifier for code 1941/6918



  Training classifier for code 1945/6918



  Training classifier for code 1949/6918



  Training classifier for code 1954/6918



  Training classifier for code 1956/6918



  Training classifier for code 1985/6918



  Training classifier for code 1988/6918



  Training classifier for code 2005/6918



  Training classifier for code 2008/6918



  Training classifier for code 2010/6918



  Training classifier for code 2013/6918



  Training classifier for code 2018/6918



  Training classifier for code 2022/6918



  Training classifier for code 2029/6918



  Training classifier for code 2030/6918



  Training classifier for code 2044/6918



  Training classifier for code 2054/6918



  Training classifier for code 2069/6918



  Training classifier for code 2072/6918



  Training classifier for code 2075/6918



  Training classifier for code 2079/6918



  Training classifier for code 2114/6918



  Training classifier for code 2117/6918



  Training classifier for code 2132/6918



  Training classifier for code 2142/6918



  Training classifier for code 2145/6918



  Training classifier for code 2150/6918



  Training classifier for code 2171/6918



  Training classifier for code 2191/6918



  Training classifier for code 2201/6918



  Training classifier for code 2208/6918



  Training classifier for code 2209/6918



  Training classifier for code 2229/6918



  Training classifier for code 2233/6918



  Training classifier for code 2236/6918



  Training classifier for code 2237/6918



  Training classifier for code 2240/6918



  Training classifier for code 2241/6918



  Training classifier for code 2257/6918



  Training classifier for code 2259/6918



  Training classifier for code 2262/6918



  Training classifier for code 2282/6918



  Training classifier for code 2292/6918



  Training classifier for code 2324/6918



  Training classifier for code 2330/6918



  Training classifier for code 2346/6918



  Training classifier for code 2349/6918



  Training classifier for code 2369/6918



  Training classifier for code 2421/6918



  Training classifier for code 2431/6918



  Training classifier for code 2436/6918



  Training classifier for code 2449/6918



  Training classifier for code 2456/6918



  Training classifier for code 2457/6918



  Training classifier for code 2471/6918



  Training classifier for code 2474/6918



  Training classifier for code 2478/6918



  Training classifier for code 2490/6918



  Training classifier for code 2491/6918



  Training classifier for code 2499/6918



  Training classifier for code 2517/6918



  Training classifier for code 2527/6918



  Training classifier for code 2551/6918



  Training classifier for code 2624/6918



  Training classifier for code 2649/6918



  Training classifier for code 2686/6918



  Training classifier for code 2699/6918



  Training classifier for code 2709/6918



  Training classifier for code 2725/6918



  Training classifier for code 2727/6918



  Training classifier for code 2750/6918



  Training classifier for code 2755/6918



  Training classifier for code 2768/6918



  Training classifier for code 2771/6918



  Training classifier for code 2776/6918



  Training classifier for code 2789/6918



  Training classifier for code 2795/6918



  Training classifier for code 2805/6918



  Training classifier for code 2808/6918



  Training classifier for code 2812/6918



  Training classifier for code 2813/6918



  Training classifier for code 2819/6918



  Training classifier for code 2820/6918



  Training classifier for code 2828/6918



  Training classifier for code 2835/6918



  Training classifier for code 2839/6918



  Training classifier for code 2840/6918



  Training classifier for code 2847/6918



  Training classifier for code 2908/6918



  Training classifier for code 2914/6918



  Training classifier for code 2915/6918



  Training classifier for code 2946/6918



  Training classifier for code 2947/6918



  Training classifier for code 2951/6918



  Training classifier for code 2952/6918



  Training classifier for code 2957/6918



  Training classifier for code 2958/6918



  Training classifier for code 2963/6918



  Training classifier for code 2965/6918



  Training classifier for code 2970/6918



  Training classifier for code 2987/6918



  Training classifier for code 3003/6918



  Training classifier for code 3022/6918



  Training classifier for code 3035/6918



  Training classifier for code 3036/6918



  Training classifier for code 3130/6918



  Training classifier for code 3378/6918



  Training classifier for code 3402/6918



  Training classifier for code 3411/6918



  Training classifier for code 3435/6918



  Training classifier for code 3476/6918



  Training classifier for code 3478/6918



  Training classifier for code 3481/6918



  Training classifier for code 3484/6918



  Training classifier for code 3485/6918



  Training classifier for code 3491/6918



  Training classifier for code 3492/6918



  Training classifier for code 3517/6918



  Training classifier for code 3521/6918



  Training classifier for code 3522/6918



  Training classifier for code 3548/6918



  Training classifier for code 3574/6918



  Training classifier for code 3622/6918



  Training classifier for code 3638/6918



  Training classifier for code 3656/6918



  Training classifier for code 3666/6918



  Training classifier for code 3672/6918



  Training classifier for code 3682/6918



  Training classifier for code 3692/6918



  Training classifier for code 3755/6918



  Training classifier for code 3759/6918



  Training classifier for code 3764/6918



  Training classifier for code 3775/6918



  Training classifier for code 3828/6918



  Training classifier for code 3875/6918



  Training classifier for code 3887/6918



  Training classifier for code 3888/6918



  Training classifier for code 3889/6918



  Training classifier for code 3917/6918



  Training classifier for code 3918/6918



  Training classifier for code 3926/6918



  Training classifier for code 3941/6918



  Training classifier for code 3943/6918



  Training classifier for code 3950/6918



  Training classifier for code 3963/6918



  Training classifier for code 3968/6918



  Training classifier for code 3994/6918



  Training classifier for code 3995/6918



  Training classifier for code 4008/6918



  Training classifier for code 4009/6918



  Training classifier for code 4019/6918



  Training classifier for code 4029/6918



  Training classifier for code 4055/6918



  Training classifier for code 4091/6918



  Training classifier for code 4101/6918



  Training classifier for code 4104/6918



  Training classifier for code 4121/6918



  Training classifier for code 4146/6918



  Training classifier for code 4148/6918



  Training classifier for code 4149/6918



  Training classifier for code 4152/6918



  Training classifier for code 4160/6918



  Training classifier for code 4167/6918



  Training classifier for code 4168/6918



  Training classifier for code 4169/6918



  Training classifier for code 4170/6918



  Training classifier for code 4183/6918



  Training classifier for code 4184/6918



  Training classifier for code 4192/6918



  Training classifier for code 4197/6918



  Training classifier for code 4198/6918



  Training classifier for code 4200/6918



  Training classifier for code 4201/6918



  Training classifier for code 4204/6918



  Training classifier for code 4205/6918



  Training classifier for code 4207/6918



  Training classifier for code 4209/6918



  Training classifier for code 4211/6918



  Training classifier for code 4215/6918



  Training classifier for code 4224/6918



  Training classifier for code 4226/6918



  Training classifier for code 4231/6918



  Training classifier for code 4232/6918



  Training classifier for code 4234/6918



  Training classifier for code 4236/6918



  Training classifier for code 4237/6918



  Training classifier for code 4239/6918



  Training classifier for code 4240/6918



  Training classifier for code 4241/6918



  Training classifier for code 4244/6918



  Training classifier for code 4249/6918



  Training classifier for code 4250/6918



  Training classifier for code 4251/6918



  Training classifier for code 4252/6918



  Training classifier for code 4254/6918



  Training classifier for code 4258/6918



  Training classifier for code 4260/6918



  Training classifier for code 4261/6918



  Training classifier for code 4265/6918



  Training classifier for code 4268/6918



  Training classifier for code 4269/6918



  Training classifier for code 4276/6918



  Training classifier for code 4277/6918



  Training classifier for code 4280/6918



  Training classifier for code 4282/6918



  Training classifier for code 4288/6918



  Training classifier for code 4291/6918



  Training classifier for code 4293/6918



  Training classifier for code 4294/6918



  Training classifier for code 4298/6918



  Training classifier for code 4299/6918



  Training classifier for code 4304/6918



  Training classifier for code 4305/6918



  Training classifier for code 4307/6918



  Training classifier for code 4310/6918



  Training classifier for code 4313/6918



  Training classifier for code 4316/6918



  Training classifier for code 4321/6918



  Training classifier for code 4323/6918



  Training classifier for code 4326/6918



  Training classifier for code 4332/6918



  Training classifier for code 4342/6918



  Training classifier for code 4356/6918



  Training classifier for code 4360/6918



  Training classifier for code 4365/6918



  Training classifier for code 4370/6918



  Training classifier for code 4379/6918



  Training classifier for code 4389/6918



  Training classifier for code 4406/6918



  Training classifier for code 4410/6918



  Training classifier for code 4412/6918



  Training classifier for code 4422/6918



  Training classifier for code 4426/6918



  Training classifier for code 4429/6918



  Training classifier for code 4433/6918



  Training classifier for code 4444/6918



  Training classifier for code 4455/6918



  Training classifier for code 4457/6918



  Training classifier for code 4469/6918



  Training classifier for code 4489/6918



  Training classifier for code 4500/6918



  Training classifier for code 4504/6918



  Training classifier for code 4510/6918



  Training classifier for code 4511/6918



  Training classifier for code 4528/6918



  Training classifier for code 4583/6918



  Training classifier for code 4620/6918



  Training classifier for code 4621/6918



  Training classifier for code 4622/6918



  Training classifier for code 4624/6918



  Training classifier for code 4634/6918



  Training classifier for code 4635/6918



  Training classifier for code 4636/6918



  Training classifier for code 4640/6918



  Training classifier for code 4740/6918



  Training classifier for code 4741/6918



  Training classifier for code 4744/6918



  Training classifier for code 4751/6918



  Training classifier for code 4756/6918



  Training classifier for code 4760/6918



  Training classifier for code 4773/6918



  Training classifier for code 4784/6918



  Training classifier for code 4785/6918



  Training classifier for code 4791/6918



  Training classifier for code 4796/6918



  Training classifier for code 4800/6918



  Training classifier for code 4802/6918



  Training classifier for code 4810/6918



  Training classifier for code 4811/6918



  Training classifier for code 4818/6918



  Training classifier for code 4819/6918



  Training classifier for code 4821/6918



  Training classifier for code 4822/6918



  Training classifier for code 4825/6918



  Training classifier for code 4830/6918



  Training classifier for code 4844/6918



  Training classifier for code 4855/6918



  Training classifier for code 4860/6918



  Training classifier for code 4863/6918



  Training classifier for code 4919/6918



  Training classifier for code 4927/6918



  Training classifier for code 4931/6918



  Training classifier for code 4938/6918



  Training classifier for code 4941/6918



  Training classifier for code 4942/6918



  Training classifier for code 4948/6918



  Training classifier for code 4950/6918



  Training classifier for code 4954/6918



  Training classifier for code 4957/6918



  Training classifier for code 4959/6918



  Training classifier for code 4961/6918



  Training classifier for code 4965/6918



  Training classifier for code 4967/6918



  Training classifier for code 4970/6918



  Training classifier for code 4974/6918



  Training classifier for code 5007/6918



  Training classifier for code 5071/6918



  Training classifier for code 5098/6918



  Training classifier for code 5103/6918



  Training classifier for code 5127/6918



  Training classifier for code 5140/6918



  Training classifier for code 5225/6918



  Training classifier for code 5252/6918



  Training classifier for code 5290/6918



  Training classifier for code 5300/6918



  Training classifier for code 5360/6918



  Training classifier for code 5364/6918



  Training classifier for code 5430/6918

In [None]:
# 7. Final model training and evaluation
print("\nTraining final model on all training data...")
start_time = time.time()

# Create a list to store estimators
final_estimators = []
trained_indices = []

# Train a separate model for each ICD code that has both positive and negative examples
for i in range(num_codes):
    # Check if there are both positive and negative examples for this code
    if np.sum(train_labels[:, i] == 0) > 0 and np.sum(train_labels[:, i] == 1) > 0:
        print(f"Training classifier for code {i+1}/{num_codes}", end='\r')
        estimator = LogisticRegression(
            C=1.0,
            solver='saga',
            penalty='l1',
            max_iter=100,
            n_jobs=1,  # Set to 1 for individual estimator
            verbose=0,
            random_state=42
        )
        try:
            estimator.fit(train_features, train_labels[:, i])
            final_estimators.append(estimator)
            trained_indices.append(i)
        except Exception as e:
            print(f"Error training classifier for code {i}: {e}")
    else:
        # Skip codes without both positive and negative examples
        pass

print(f"Trained {len(final_estimators)}/{num_codes} classifiers")

# Create a model object to hold the estimators
final_model = MultiOutputClassifier(LogisticRegression())
final_model.estimators_ = final_estimators
final_model.classes_ = [np.array([0, 1]) for _ in final_estimators]
final_model.trained_indices = trained_indices  # Store which indices were trained

print(f"Final model training time: {time.time() - start_time:.2f} seconds")

In [None]:


# Evaluate on training set
train_metrics = evaluate(final_model, train_features, train_labels)
print("\nTraining set results:")
for key, value in train_metrics.items():
    print(f"  {key}: {value:.4f}")

# Evaluate on test set
test_metrics = evaluate(final_model, test_features, test_labels)
print("\nTest set results:")
for key, value in test_metrics.items():
    print(f"  {key}: {value:.4f}")

# 8. Analyze ICD code distribution and performance
print("\nAnalyzing ICD code distribution...")

# Count occurrences of each ICD code
code_counts = np.sum(train_labels, axis=0)
# Sort ICD codes by frequency
sorted_idx = np.argsort(-code_counts)
top_codes = [idx_to_code[i] for i in sorted_idx[:20]]
top_counts = code_counts[sorted_idx[:20]]

# Plot top 20 most frequent ICD codes
plt.figure(figsize=(12, 6))
plt.bar(range(20), top_counts)
plt.xticks(range(20), top_codes, rotation=90)
plt.title('Top 20 Most Frequent ICD-9 Codes in Training Data')
plt.xlabel('ICD-9 Code')
plt.ylabel('Frequency')
plt.tight_layout()
plt.savefig(os.path.join(data_path, 'icd_distribution.png'))
plt.show()

# 9. Plot cross-validation performance
plt.figure(figsize=(10, 6))
plt.plot(cv_results['fold'], cv_results['micro_f1'], 'o-', label='Micro F1')
plt.plot(cv_results['fold'], cv_results['macro_f1'], 'o-', label='Macro F1')
plt.xlabel('Fold')
plt.ylabel('F1 Score')
plt.title('Cross-Validation Performance')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(data_path, 'cv_performance.png'))
plt.show()

# 10. Save the model and related artifacts
print("Saving model and artifacts...")
output_dir = os.path.join(data_path, 'baseline_model')
os.makedirs(output_dir, exist_ok=True)

# Save the TF-IDF vectorizer
with open(os.path.join(output_dir, 'tfidf_vectorizer.pkl'), 'wb') as f:
    pickle.dump(tfidf, f)

# Save the trained final model
with open(os.path.join(output_dir, 'lr_model.pkl'), 'wb') as f:
    pickle.dump(final_model, f)

# Save the code mappings
with open(os.path.join(output_dir, 'code_mappings.pkl'), 'wb') as f:
    pickle.dump({
        'code_to_idx': code_to_idx,
        'idx_to_code': idx_to_code
    }, f)

# Save the CV results
cv_df.to_csv(os.path.join(output_dir, 'cv_results.csv'), index=False)

# Save the evaluation metrics
metrics = {
    'train': train_metrics,
    'test': test_metrics,
    'cv': cv_df.to_dict()
}
with open(os.path.join(output_dir, 'metrics.pkl'), 'wb') as f:
    pickle.dump(metrics, f)

print("\nBaseline model training and evaluation complete!")
print(f"Model and artifacts saved to {output_dir}")

# 11. Display example predictions
print("\nExample predictions for 5 random test samples:")
import random
sample_indices = random.sample(range(len(test_texts)), 5)

# In the display example predictions section, modify the code:
for idx in sample_indices:
    # Get true labels
    true_codes = [idx_to_code[i] for i in np.where(test_labels[idx] == 1)[0]]
    
    # Get predicted probabilities for trained classifiers only
    pred_probs = np.zeros(num_codes)
    for i, estimator_idx in enumerate(final_model.trained_indices):
        pred_probs[estimator_idx] = final_model.estimators_[i].predict_proba(test_features[idx])[0][1]
    
    # Get top 5 predicted codes with highest probability
    top_pred_indices = np.argsort(-pred_probs)[:5]
    pred_codes = [(idx_to_code[i], pred_probs[i]) for i in top_pred_indices if pred_probs[i] > 0]
    
    print(f"\nSample {idx}:")
    print(f"Text snippet: {test_texts[idx][:100]}...")
    print(f"True codes ({len(true_codes)}):")
    for code in true_codes[:5]:
        print(f"  - {code}")
    if len(true_codes) > 5:
        print(f"  - ... and {len(true_codes) - 5} more")
    
    print(f"Top 5 predicted codes (with probability):")
    for code, prob in pred_codes:
        print(f"  - {code}: {prob:.4f}")
    print("-" * 50)