Lets load our model and generate new predictions

In [None]:
# Load the combined healthcare transformer model
import tensorflow as tf
from keras import models
import numpy as np
from custom_layers import TokenAndPositionEmbedding, TransformerBlock


try:
    # Load the combined model (includes text vectorization + transformer)
    model = models.load_model('../models/healthcare_transformer.keras')
    
    # Extract vocabulary from the vectorization layer
    for layer in model.layers:
        if hasattr(layer, 'get_vocabulary'):
            vocab = layer.get_vocabulary()
            break
    
except Exception as e:
    print(f"Error loading model: {e}")


  saveable.load_own_variables(weights_store.get(inner_path))


In [7]:

# Input diagnosis sequence
diagnosis_sequence = "6826 485"

# Convert to tensor
text_tensor = tf.constant([diagnosis_sequence])
print(f"Tensor shape: {text_tensor.shape}")

# Get model predictions
predictions = model.predict(text_tensor, verbose=0)
if isinstance(predictions, list):
    y_pred = predictions[0]
else:
    y_pred = predictions

print(f"Model output shape: {y_pred.shape}")

# Get probabilities for next token
current_length = len(diagnosis_sequence.split())
print(f"Current sequence length: {current_length}")
print(f"Getting predictions from position: {current_length}")

next_token_probs = y_pred[0][current_length]
print(f"Probability vector shape: {next_token_probs.shape}")

# Get top 5 predictions
k = 5
top_5_indices = np.argsort(next_token_probs)[-k:][::-1]
top_5_probs = next_token_probs[top_5_indices]

print(f"\nTop {k} next diagnosis predictions:")
for i, (idx, prob) in enumerate(zip(top_5_indices, top_5_probs)):
    diagnosis_code = vocab[idx] if idx < len(vocab) else f"Token_{idx}"
    print(f"{i+1}. {diagnosis_code}: {prob:.4f}")

# Get the best prediction (argmax)
best_idx = top_5_indices[0]
best_code = vocab[best_idx] if best_idx < len(vocab) else f"Token_{best_idx}"
best_prob = top_5_probs[0]

print(f"\nBest prediction: {best_code} (probability: {best_prob:.4f})")
print(f"Extended sequence: {diagnosis_sequence} {best_code}")

Tensor shape: (1,)
Model output shape: (1, 3, 2740)
Current sequence length: 2
Getting predictions from position: 2
Probability vector shape: (2740,)

Top 5 next diagnosis predictions:
1. 486: 3.2188
2. 4280: 3.1891
3. 0389: 2.9358
4. 51881: 2.9218
5. v5789: 2.8318

Best prediction: 486 (probability: 3.2188)
Extended sequence: 6826 485 486


In [11]:
# Simple prediction script
def simple_predict(diagnosis_sequence, top_k=5):
    """
    Simple function to get next diagnosis predictions.
    
    Args:
        diagnosis_sequence: String of diagnosis codes (e.g., "6826 485")
        top_k: Number of top predictions to return
    
    Returns:
        List of (diagnosis_code, probability) tuples
    """
    # Convert to tensor
    text_tensor = tf.constant([diagnosis_sequence])
    
    # Get predictions
    predictions = model.predict(text_tensor, verbose=0)
    if isinstance(predictions, list):
        y_pred = predictions[0]
    else:
        y_pred = predictions
    
    # Get probabilities for next token
    current_length = len(diagnosis_sequence.split())
    if current_length >= y_pred.shape[1]:
        print(f"Warning: Sequence too long ({current_length} >= {y_pred.shape[1]})")
        return []
    
    next_token_probs = y_pred[0][current_length]
    
    # Get top-k predictions
    top_k_indices = np.argsort(next_token_probs)[-top_k:][::-1]
    top_k_probs = next_token_probs[top_k_indices]
    
    # Convert to diagnosis codes
    results = []
    for idx, prob in zip(top_k_indices, top_k_probs):
        if idx < len(vocab):
            code = vocab[idx]
            if code and code not in ['', '[UNK]', '[PAD]']:
                results.append((code, float(prob)))
        
    return results

# Test examples
print("Simple Healthcare Diagnosis Predictor")
print("=" * 40)

test_cases = [
    "6826",
    "6826 485", 
    "1970"
]

for test_seq in test_cases:
    print(f"\nInput: '{test_seq}'")
    predictions = simple_predict(test_seq, top_k=3)
    print("Top 3 predictions:")
    for i, (code, prob) in enumerate(predictions, 1):
        print(f"  {i}. {code}: {prob:.4f}")

Simple Healthcare Diagnosis Predictor

Input: '6826'
Top 3 predictions:
  1. 486: 3.2188
  2. 4280: 3.1891
  3. 0389: 2.9358

Input: '6826 485'
Top 3 predictions:
  1. 486: 3.2188
  2. 4280: 3.1891
  3. 0389: 2.9358

Input: '1970'
Top 3 predictions:
  1. 486: 3.2188
  2. 4280: 3.1891
  3. 0389: 2.9358


In [12]:
def iterative_diagnosis_prediction_v2(initial_sequence, num_iterations=5, show_progress=True):
    """
    Improved iterative prediction using cleaner numpy logic.
    
    Args:
        initial_sequence: Starting diagnosis codes (e.g., "6826 485 4589")
        num_iterations: Number of prediction iterations to perform
        show_progress: Whether to display the prediction process
    
    Returns:
        Final extended sequence and prediction history
    """
    
    current_sequence = initial_sequence.strip()
    prediction_history = []
    
    print(f"Starting iterative prediction from: '{current_sequence}'")
    print("=" * 60)
    
    for iteration in range(num_iterations):
        try:
            # Convert current sequence to tensor
            text_tensor = tf.constant([current_sequence])
            
            # Get model predictions (including attention scores if available)
            predictions = model.predict(text_tensor, verbose=0)
            
            # Handle multiple outputs (predictions and attention scores)
            if isinstance(predictions, list):
                y_pred = predictions[0]
                attention_scores = predictions[1] if len(predictions) > 1 else None
            else:
                y_pred = predictions
                attention_scores = None
            
            # Get the probabilities for the next token (at the position after input)
            next_token_probs = y_pred[0][len(current_sequence.split())-1]
            
            # Get top-5 predictions using numpy argsort (cleaner approach)
            k = 5
            top_5_indices = np.argsort(next_token_probs)[-k:][::-1]
            top_5_probs = next_token_probs[top_5_indices]
            
            # Convert indices to diagnosis codes
            top_5_codes = []
            for idx, prob in zip(top_5_indices, top_5_probs):
                if idx < len(vocab):
                    code = vocab[idx]
                    if code and code not in ['', '[UNK]', '[PAD]']:
                        top_5_codes.append((code, float(prob)))
                    else:
                        top_5_codes.append((f"Unknown_{idx}", float(prob)))
                else:
                    top_5_codes.append((f"Token_{idx}", float(prob)))
            
            # Get the selected diagnosis code (argmax)
            selected_code = top_5_codes[0][0] if top_5_codes else vocab[top_5_indices[0]]
            
            if show_progress:
                print(f"\nIteration {iteration + 1}:")
                print(f"Current sequence: {current_sequence}")
                print(f"Top 5 next predictions:")
                for i, (code, prob) in enumerate(top_5_codes):
                    print(f"Diagnosis Code: {code}, Probability: {prob:.4f}")
                print(f"Selected (argmax): {selected_code}")
                       
            # Store prediction info
            step_info = {
                'iteration': iteration + 1,
                'input_sequence': current_sequence,
                'top_5_predictions': top_5_codes,
                'selected_code': selected_code,
                'confidence': float(top_5_probs[0]),
                'attention_scores': attention_scores
            }
            prediction_history.append(step_info)
            
            # Append selected code to sequence
            current_sequence += f" {selected_code}"
            
        except Exception as e:
            print(f"Error in iteration {iteration + 1}: {e}")
            break
    
    print("\n" + "=" * 60)
    print(f"Final sequence: {current_sequence}")
    print(f"Extended from {len(initial_sequence.split())} to {len(current_sequence.split())} codes")
    
    return current_sequence, prediction_history


In [36]:
# Test it
final_seq, history = iterative_diagnosis_prediction_v2('6826', num_iterations=5)

Starting iterative prediction from: '6826'

Iteration 1:
Current sequence: 6826
Top 5 next predictions:
Diagnosis Code: 486, Probability: 3.2188
Diagnosis Code: 4280, Probability: 3.1891
Diagnosis Code: 0389, Probability: 2.9358
Diagnosis Code: 51881, Probability: 2.9218
Diagnosis Code: v5789, Probability: 2.8318
Selected (argmax): 486

Iteration 2:
Current sequence: 6826 486
Top 5 next predictions:
Diagnosis Code: 486, Probability: 3.2188
Diagnosis Code: 4280, Probability: 3.1891
Diagnosis Code: 0389, Probability: 2.9358
Diagnosis Code: 51881, Probability: 2.9218
Diagnosis Code: v5789, Probability: 2.8318
Selected (argmax): 486

Iteration 3:
Current sequence: 6826 486 486
Top 5 next predictions:
Diagnosis Code: 486, Probability: 3.2188
Diagnosis Code: 4280, Probability: 3.1891
Diagnosis Code: 0389, Probability: 2.9358
Diagnosis Code: 51881, Probability: 2.9218
Diagnosis Code: v5789, Probability: 2.8318
Selected (argmax): 486
Error in iteration 4: index 3 is out of bounds for axis 0 wi