<a href="https://colab.research.google.com/github/mr-fares10/project490/blob/main/lstm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

class JetPositionPredictor:
    """
    LSTM-based predictor of the next jet center position

    The workflow is: extract frames from a video, detect one (x, y, z)
    center per frame, turn the position series into scaled sliding-window
    sequences, train an LSTM on them, and predict the position of the frame
    that follows a window of ``sequence_length`` positions.

    Subclasses are expected to override ``detect_jet_center_in_frame`` with a
    real detector; the default implementation is a placeholder.
    """

    def __init__(self, sequence_length=10):
        """
        Initialize the jet position predictor

        Args:
            sequence_length: Number of past frames to use for prediction
        """
        self.sequence_length = sequence_length
        self.model = None
        self.scaler = MinMaxScaler()

    def extract_frames(self, video_path):
        """
        Extract frames from video file

        Args:
            video_path: Path to video file

        Returns:
            List of frames (empty if the video could not be opened or read)
        """
        frames = []
        cap = cv2.VideoCapture(video_path)

        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            # Release the capture even if reading a frame raises
            cap.release()

        return frames

    def detect_jet_centers(self, frames):
        """
        Detect the center of the jet in each frame

        Delegates the per-frame work to ``detect_jet_center_in_frame``.

        Args:
            frames: List of video frames

        Returns:
            Array of jet center coordinates (x, y, z) for each frame;
            an empty (0, 3) array when no frames are given
        """
        centers = [self.detect_jet_center_in_frame(frame) for frame in frames]

        if not centers:
            # Keep the result two-dimensional so callers can still use
            # `.shape` and column slicing on an empty series
            return np.empty((0, 3))

        return np.array(centers)

    def detect_jet_center_in_frame(self, frame):
        """
        Placeholder for your existing jet center detection in a single frame

        Args:
            frame: Single video frame

        Returns:
            (x, y, z) coordinates of jet center
        """
        # Replace with your actual jet center detection code
        # This is just a placeholder
        # Assuming your existing code extracts (x, y, z) coordinates
        return [0, 0, 0]  # Replace with actual detection

    def create_sequences(self, positions):
        """
        Create sequences for LSTM training

        Note that this (re)fits ``self.scaler`` on ``positions``.

        Args:
            positions: Array of (x, y, z) positions from each frame

        Returns:
            X: Input sequences, shape (samples, sequence_length, features)
            y: Target positions (next frame), shape (samples, features)

            When there are not more than ``sequence_length`` positions, both
            arrays have zero samples but keep the trailing dimensions.
        """
        # Scale the data
        scaled_positions = self.scaler.fit_transform(positions)

        window = self.sequence_length
        n_samples = len(scaled_positions) - window

        if n_samples <= 0:
            # Too short for even one window: return well-shaped empty arrays
            n_features = scaled_positions.shape[1]
            return np.empty((0, window, n_features)), np.empty((0, n_features))

        X = np.array([scaled_positions[i:i + window] for i in range(n_samples)])
        # The target of window i is the position right after it
        y = np.array(scaled_positions[window:])

        return X, y

    def build_model(self, input_shape):
        """
        Build the LSTM model

        The model is also stored on ``self.model``.

        Args:
            input_shape: Shape of input sequences

        Returns:
            Compiled LSTM model
        """
        model = Sequential()

        # LSTM layers
        model.add(LSTM(64, activation='relu', input_shape=input_shape, return_sequences=True))
        model.add(Dropout(0.2))

        model.add(LSTM(64, activation='relu'))
        model.add(Dropout(0.2))

        # Output layer (x, y, z)
        model.add(Dense(3))

        # Compile the model
        model.compile(optimizer='adam', loss='mse', metrics=['mae'])

        self.model = model
        return model

    def train(self, X, y, epochs=50, batch_size=32, validation_split=0.2):
        """
        Train the LSTM model

        Builds the model from the shape of ``X`` if none exists yet.

        Args:
            X: Input sequences
            y: Target positions
            epochs: Number of training epochs
            batch_size: Batch size
            validation_split: Validation data fraction

        Returns:
            Training history
        """
        if self.model is None:
            self.build_model(X.shape[1:])

        history = self.model.fit(
            X, y,
            epochs=epochs,
            batch_size=batch_size,
            validation_split=validation_split,
            verbose=1
        )

        return history

    def predict_next_position(self, sequence):
        """
        Predict the next position based on a sequence of positions

        Args:
            sequence: Sequence of recent positions (unscaled), one row per
                frame, exactly ``sequence_length`` rows

        Returns:
            Predicted next (x, y, z) position

        Raises:
            ValueError: If no model is available or the sequence does not
                contain exactly ``sequence_length`` positions
        """
        if self.model is None:
            raise ValueError("Model has not been trained yet")

        # Scale the input sequence
        scaled_sequence = self.scaler.transform(np.asarray(sequence))

        if scaled_sequence.shape[0] != self.sequence_length:
            raise ValueError(
                f"Expected {self.sequence_length} positions, "
                f"got {scaled_sequence.shape[0]}"
            )

        # Add the batch dimension expected by the model
        scaled_sequence = scaled_sequence[np.newaxis, ...]

        # Predict and inverse transform (verbose=0: this is typically called
        # once per frame, so a progress bar per call is only noise)
        scaled_prediction = self.model.predict(scaled_sequence, verbose=0)
        prediction = self.scaler.inverse_transform(scaled_prediction)

        return prediction[0]

    def evaluate_predictions(self, test_video_path):
        """
        Evaluate predictions on a test video

        Prints the error, then plots actual vs predicted positions.

        Args:
            test_video_path: Path to test video

        Returns:
            Mean squared error between predictions and actual positions

        Raises:
            ValueError: If the video yields no more than ``sequence_length``
                positions, so nothing can be predicted
        """
        # Extract frames and detect centers
        frames = self.extract_frames(test_video_path)
        actual_positions = self.detect_jet_centers(frames)

        if len(actual_positions) <= self.sequence_length:
            raise ValueError(
                f"Need more than {self.sequence_length} frames to evaluate, "
                f"got {len(actual_positions)}"
            )

        # Predict every frame from the sequence_length frames preceding it
        predicted = np.array([
            self.predict_next_position(actual_positions[i - self.sequence_length:i])
            for i in range(self.sequence_length, len(actual_positions))
        ])

        # Calculate error
        actual = actual_positions[self.sequence_length:]

        mse = np.mean(np.square(actual - predicted))
        print(f"Mean Squared Error: {mse}")

        # Visualize results
        self.visualize_predictions(actual, predicted)

        return mse

    def visualize_predictions(self, actual, predicted):
        """
        Visualize the actual vs predicted positions

        Saves the figure to 'position_predictions.png' and shows it.

        Args:
            actual: Actual positions
            predicted: Predicted positions
        """
        plt.figure(figsize=(15, 10))

        # One subplot per coordinate
        for index, axis_name in enumerate(('X', 'Y', 'Z')):
            plt.subplot(3, 1, index + 1)
            plt.plot(actual[:, index], label='Actual')
            plt.plot(predicted[:, index], label='Predicted')
            plt.title(f'{axis_name} Position')
            plt.xlabel('Frame')
            plt.ylabel('Position')
            plt.legend()

        plt.tight_layout()
        plt.savefig('position_predictions.png')
        plt.show()

    def save_model(self, filepath):
        """
        Save the trained model

        Only the Keras model is written; ``self.scaler`` is not saved here.

        Args:
            filepath: Path to save the model

        Raises:
            ValueError: If there is no model to save
        """
        if self.model is None:
            raise ValueError("No model to save")

        self.model.save(filepath)

    def load_model(self, filepath):
        """
        Load a trained model

        Only ``self.model`` is replaced; ``self.scaler`` is left untouched,
        so it must be fitted or assigned separately before predicting.

        Args:
            filepath: Path to the trained model
        """
        self.model = tf.keras.models.load_model(filepath)

# Example usage
if __name__ == "__main__":
    # Initialize predictor
    predictor = JetPositionPredictor(sequence_length=10)

    # Training phase: one detected position per video frame
    training_video = "path/to/training_video.mp4"
    frames = predictor.extract_frames(training_video)
    positions = predictor.detect_jet_centers(frames)

    # Create scaled sliding-window sequences
    X, y = predictor.create_sequences(positions)

    # Split data chronologically. Consecutive windows overlap in all but one
    # frame, so a shuffled split would put near-duplicates of the training
    # windows into the test set and make the test loss look better than it is.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)

    # Build and train model
    predictor.build_model(X_train.shape[1:])
    history = predictor.train(X_train, y_train, epochs=50)

    # Evaluate model on the held-out (latest) windows
    test_loss = predictor.model.evaluate(X_test, y_test)
    print(f"Test Loss: {test_loss}")

    # Save the model
    predictor.save_model("jet_position_model.h5")

    # Test on new video
    test_video = "path/to/test_video.mp4"
    mse = predictor.evaluate_predictions(test_video)

In [None]:
import cv2
import numpy as np
import tensorflow as tf
from collections import deque
import time

class RealtimeJetPredictor:
    """
    Frame-by-frame jet position predictor built on a trained LSTM model

    Detected positions are kept in a rolling buffer of ``sequence_length``
    entries. Once the buffer is full, each prediction estimates the position
    of the frame that follows the newest buffered one. Every prediction is
    stored together with the index of the frame it targets so that it can
    later be compared with the position actually detected in that frame.
    """

    def __init__(self, model_path, sequence_length=10, scaler=None):
        """
        Initialize real-time jet position predictor

        Args:
            model_path: Path to trained LSTM model
            sequence_length: Length of sequence for prediction
            scaler: Fitted scaler for normalization
        """
        self.model = tf.keras.models.load_model(model_path)
        self.sequence_length = sequence_length
        self.scaler = scaler

        # Buffer to store recent positions
        self.position_buffer = deque(maxlen=sequence_length)

        # Store predictions for visualization
        self.predictions = []
        self.actual_positions = []

        # For each stored prediction, the index into actual_positions of the
        # frame that the prediction refers to
        self.prediction_targets = []

    def detect_jet_center(self, frame):
        """
        Placeholder for jet center detection in a single frame
        Replace with your actual implementation

        Args:
            frame: Video frame

        Returns:
            (x, y, z) coordinates of jet center
        """
        # Replace with your actual jet detection code
        return [0, 0, 0]  # Example placeholder

    def predict_next_position(self):
        """
        Predict the next position based on recent positions

        Returns:
            Predicted (x, y, z) position for next frame, or None while the
            buffer holds fewer than ``sequence_length`` positions
        """
        if len(self.position_buffer) < self.sequence_length:
            # Not enough data for prediction
            return None

        # Prepare sequence for prediction
        recent_positions = np.asarray(list(self.position_buffer), dtype=float)

        # Test against None explicitly rather than relying on the
        # truthiness of an arbitrary scaler object
        has_scaler = self.scaler is not None

        # Scale the sequence
        if has_scaler:
            scaled_positions = self.scaler.transform(recent_positions)
        else:
            # If no scaler provided, normalize between 0-1 per window
            # This is less accurate than using the trained scaler
            min_vals = recent_positions.min(axis=0)
            range_vals = recent_positions.max(axis=0) - min_vals
            # Avoid division by zero
            range_vals[range_vals == 0] = 1
            scaled_positions = (recent_positions - min_vals) / range_vals

        # Add the batch dimension expected by the model
        scaled_sequence = scaled_positions[np.newaxis, ...]

        # Predict
        scaled_prediction = self.model.predict(scaled_sequence, verbose=0)

        # Inverse transform
        if has_scaler:
            prediction = self.scaler.inverse_transform(scaled_prediction)
        else:
            # Inverse of manual normalization
            prediction = (scaled_prediction * range_vals) + min_vals

        return prediction[0]

    def process_video_stream(self, video_source=0, display=True, predict_every=1):
        """
        Process video stream and make predictions

        Args:
            video_source: Camera index or video file path
            display: Whether to display the video with predictions
            predict_every: Make prediction every N frames

        Returns:
            Tuple (actual_positions, predictions) accumulated so far

        Raises:
            ValueError: If ``predict_every`` is smaller than 1
        """
        if predict_every < 1:
            raise ValueError("predict_every must be at least 1")

        cap = cv2.VideoCapture(video_source)
        frame_count = 0

        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                # Detect jet center in current frame
                current_position = self.detect_jet_center(frame)
                self.actual_positions.append(current_position)

                # Add to buffer
                self.position_buffer.append(current_position)

                # Make prediction every N frames
                if frame_count % predict_every == 0:
                    predicted_position = self.predict_next_position()

                    if predicted_position is not None:
                        self.predictions.append(predicted_position)
                        # The prediction refers to the frame after the one
                        # just appended, i.e. the next index of the list
                        self.prediction_targets.append(len(self.actual_positions))

                        if display:
                            # Draw actual position
                            x, y, z = current_position
                            cv2.circle(frame, (int(x), int(y)), 5, (0, 255, 0), -1)

                            # Draw predicted position (for the next frame)
                            pred_x, pred_y, pred_z = predicted_position
                            cv2.circle(frame, (int(pred_x), int(pred_y)), 5, (0, 0, 255), -1)

                            # Add text
                            cv2.putText(frame, f"Actual: ({x:.1f}, {y:.1f}, {z:.1f})",
                                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
                            cv2.putText(frame, f"Predicted: ({pred_x:.1f}, {pred_y:.1f}, {pred_z:.1f})",
                                        (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

                if display:
                    cv2.imshow('Jet Position Prediction', frame)

                    # Exit on 'q' press
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break

                frame_count += 1
        finally:
            # Free the capture and any windows even if detection or
            # prediction raises
            cap.release()
            if display:
                cv2.destroyAllWindows()

        return self.actual_positions, self.predictions

    def _aligned_results(self):
        """
        Pair each prediction with the detected position of its target frame

        Predictions whose target frame has not been observed (for example
        the one made on the last frame of a video) are left out.

        Returns:
            Tuple (actual, predicted) of equally long arrays, or None when
            no prediction has a matching detected position
        """
        targets = getattr(self, "prediction_targets", None)
        if not targets or len(targets) != len(self.predictions):
            # No usable bookkeeping: assume one prediction per frame, the
            # first one referring to the frame at index sequence_length
            targets = range(self.sequence_length,
                            self.sequence_length + len(self.predictions))

        n_frames = len(self.actual_positions)
        pairs = [
            (self.actual_positions[target], prediction)
            for target, prediction in zip(targets, self.predictions)
            if target < n_frames
        ]
        if not pairs:
            return None

        actual = np.array([observed for observed, _ in pairs], dtype=float)
        predicted = np.array([estimate for _, estimate in pairs], dtype=float)
        return actual, predicted

    def calculate_prediction_error(self):
        """
        Calculate prediction error metrics

        Each prediction is compared with the position detected in the frame
        it refers to; predictions without such a frame are ignored.

        Returns:
            Dictionary of error metrics ("MSE", "MAE", "RMSE"), or None when
            there is nothing to compare
        """
        aligned = self._aligned_results()
        if aligned is None:
            return None

        actual, predicted = aligned

        # Calculate error metrics
        difference = actual - predicted
        mse = np.mean(np.square(difference))
        mae = np.mean(np.abs(difference))

        return {
            "MSE": mse,
            "MAE": mae,
            "RMSE": np.sqrt(mse)
        }

    def visualize_results(self):
        """
        Visualize prediction results

        Plots actual vs predicted positions per coordinate, saves the figure
        to 'realtime_predictions.png', shows it, and prints the error metrics.
        """
        aligned = self._aligned_results() if self.predictions else None
        if aligned is None:
            print("No predictions available for visualization")
            return

        actual, predicted = aligned

        # Imported here because this cell does not import pyplot at the top
        import matplotlib.pyplot as plt
        plt.figure(figsize=(15, 10))

        # One subplot per coordinate
        for index, axis_name in enumerate(('X', 'Y', 'Z')):
            plt.subplot(3, 1, index + 1)
            plt.plot(actual[:, index], label='Actual')
            plt.plot(predicted[:, index], label='Predicted')
            plt.title(f'{axis_name} Position')
            plt.xlabel('Frame')
            plt.ylabel('Position')
            plt.legend()

        plt.tight_layout()
        plt.savefig('realtime_predictions.png')
        plt.show()

        # Print error metrics
        errors = self.calculate_prediction_error()
        if errors:
            for metric, value in errors.items():
                print(f"{metric}: {value}")

# Example usage
if __name__ == "__main__":
    # Initialize predictor from the saved model file.
    # No scaler is passed here, so predict_next_position falls back to
    # per-window min-max normalization instead of a fitted scaler.
    predictor = RealtimeJetPredictor(
        model_path="jet_position_model.h5",
        sequence_length=10
    )

    # Process video (use 0 for webcam or file path for video file)
    predictor.process_video_stream("path/to/test_video.mp4", display=True)

    # Plot actual vs predicted positions and print the error metrics
    predictor.visualize_results()

In [None]:
# Import your existing jet detection module
# For example:
# from jet_detector import detect_jet_center

import numpy as np
import cv2
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
from collections import deque
import matplotlib.pyplot as plt
import os

def integrate_with_existing_detector():
    """
    This function shows how to integrate your existing jet detection code
    with the LSTM prediction model

    Steps: (1) build a position series from the .jpg files of a training
    directory, (2) create scaled sliding-window sequences, (3) train the LSTM
    and save it together with the fitted scaler, (4) replay a test video while
    predicting each frame's position from the preceding detections, and
    (5) hand the results to analyze_prediction_results.

    The detector calls are placeholders that return [0, 0, 0]; replace them
    with your own detection function.

    Raises:
        ValueError: If the training directory yields no more positions than
            the sequence length, so no training sequence can be built
    """
    # Local import: only needed here to persist the scaler
    import pickle

    # Step 1: Prepare training data using your existing detector
    training_data = []

    # Example directory of training frames
    training_frames_dir = "path/to/training/frames"
    frame_files = sorted(f for f in os.listdir(training_frames_dir) if f.endswith('.jpg'))

    for frame_file in frame_files:
        frame = cv2.imread(os.path.join(training_frames_dir, frame_file))
        if frame is None:
            # cv2.imread returns None when the file cannot be decoded
            continue

        # Call your existing detection function
        # Replace this with your actual function call
        # xyz = detect_jet_center(frame)
        xyz = [0, 0, 0]  # Replace with your actual detection

        training_data.append(xyz)

    # Convert to numpy array
    training_data = np.array(training_data)

    # Step 2: Create sequences for LSTM
    sequence_length = 10

    if len(training_data) <= sequence_length:
        raise ValueError(
            f"Need more than {sequence_length} training frames, "
            f"got {len(training_data)}"
        )

    # Normalize data
    scaler = MinMaxScaler()
    normalized_data = scaler.fit_transform(training_data)

    n_samples = len(normalized_data) - sequence_length
    X = np.array([normalized_data[i:i + sequence_length] for i in range(n_samples)])
    # The target of window i is the position right after it
    y = np.array(normalized_data[sequence_length:])

    # Step 3: Build and train LSTM model
    model = tf.keras.Sequential([
        tf.keras.layers.LSTM(64, activation='relu', input_shape=(sequence_length, 3), return_sequences=True),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.LSTM(64, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(3)
    ])

    model.compile(optimizer='adam', loss='mse', metrics=['mae'])

    # Train the model
    model.fit(X, y, epochs=50, batch_size=32, validation_split=0.2)

    # Save the model and scaler (the scaler is required to reproduce the
    # normalization when the model is loaded again later)
    model.save("jet_position_model.h5")
    with open('scaler.pkl', 'wb') as f:
        pickle.dump(scaler, f)

    # Step 4: Real-time prediction with feedback loop
    def predict_with_feedback(video_path, model, scaler, sequence_length=10):
        """
        Predict jet positions with feedback loop - correcting predictions
        with actual detections

        Args:
            video_path: Path to test video
            model: Trained LSTM model
            scaler: Fitted scaler
            sequence_length: Length of input sequence

        Returns:
            Dictionary with 'actual', 'predicted' and 'corrected' position
            arrays, each starting at frame index sequence_length
        """
        cap = cv2.VideoCapture(video_path)

        # Initialize buffers
        position_buffer = deque(maxlen=sequence_length)
        actual_positions = []
        predicted_positions = []
        corrected_predictions = []

        alpha = 0.7  # Learning rate for correction
        frame_count = 0

        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                # Detect actual position
                # Replace with your actual detection function
                # actual_xyz = detect_jet_center(frame)
                actual_xyz = [0, 0, 0]  # Replace with your actual detection

                actual_positions.append(actual_xyz)

                # If we have enough frames, make a prediction. The buffer
                # holds only the previous frames at this point, so the
                # prediction refers to the current frame.
                if frame_count >= sequence_length:
                    # Get recent positions
                    recent_positions = np.array(list(position_buffer))

                    # Scale the positions
                    scaled_positions = scaler.transform(recent_positions)

                    # Reshape for prediction
                    scaled_sequence = scaled_positions.reshape(1, sequence_length, 3)

                    # Predict next position
                    scaled_prediction = model.predict(scaled_sequence, verbose=0)
                    prediction = scaler.inverse_transform(scaled_prediction)[0]

                    predicted_positions.append(prediction)

                    # Visualization (optional)
                    act_x, act_y, _ = actual_xyz
                    pred_x, pred_y, _ = prediction

                    cv2.circle(frame, (int(act_x), int(act_y)), 5, (0, 255, 0), -1)  # Actual (green)
                    cv2.circle(frame, (int(pred_x), int(pred_y)), 5, (0, 0, 255), -1)  # Predicted (red)

                    # Calculate error
                    error = np.array(actual_xyz) - prediction

                    # Store for visualization
                    if corrected_predictions:
                        # Move the prediction towards the detection by alpha
                        correction = prediction + (alpha * error)

                        corrected_predictions.append(correction)

                        # Draw corrected prediction
                        corr_x, corr_y, _ = correction
                        cv2.circle(frame, (int(corr_x), int(corr_y)), 5, (255, 0, 0), -1)  # Corrected (blue)
                    else:
                        # First correction is just actual position
                        corrected_predictions.append(actual_xyz)

                # Add current position to buffer
                position_buffer.append(actual_xyz)

                # Display
                cv2.imshow('Prediction with Feedback', frame)
                if cv2.waitKey(30) & 0xFF == ord('q'):
                    break

                frame_count += 1
        finally:
            # Free the capture and the window even if a step above raises
            cap.release()
            cv2.destroyAllWindows()

        # Return results for analysis
        return {
            'actual': np.array(actual_positions[sequence_length:]),
            'predicted': np.array(predicted_positions),
            'corrected': np.array(corrected_predictions)
        }

    # Run prediction with feedback
    results = predict_with_feedback(
        "path/to/test_video.mp4",
        model,
        scaler,
        sequence_length
    )

    # Analyze results
    analyze_prediction_results(results)

def analyze_prediction_results(results):
    """
    Analyze prediction results and visualize

    Prints MSE/MAE for the original and the corrected predictions, then
    saves and shows two figures: per-coordinate positions
    ('prediction_analysis.png') and the Euclidean error per frame
    ('prediction_error.png').

    Args:
        results: Dictionary with actual, predicted, and corrected positions
            (keys 'actual', 'predicted', 'corrected'), one row per frame
    """
    # Accept plain lists as well as arrays
    actual = np.asarray(results['actual'])
    predicted = np.asarray(results['predicted'])
    corrected = np.asarray(results['corrected'])

    pred_diff = actual - predicted
    corr_diff = actual - corrected

    # Calculate error metrics
    pred_mse = np.mean(np.square(pred_diff))
    corr_mse = np.mean(np.square(corr_diff))

    pred_mae = np.mean(np.abs(pred_diff))
    corr_mae = np.mean(np.abs(corr_diff))

    print(f"Original Prediction MSE: {pred_mse:.4f}")
    print(f"Corrected Prediction MSE: {corr_mse:.4f}")
    if pred_mse > 0:
        print(f"Improvement: {100 * (1 - corr_mse/pred_mse):.2f}%")
    else:
        # A relative improvement is undefined without a positive baseline
        # error (perfect predictions, or no data at all)
        print("Improvement: n/a (original prediction MSE is not positive)")

    print(f"Original Prediction MAE: {pred_mae:.4f}")
    print(f"Corrected Prediction MAE: {corr_mae:.4f}")

    # Visualize results
    series = [
        (actual, 'g', 'Actual'),
        (predicted, 'r', 'Predicted'),
        (corrected, 'b', 'Corrected'),
    ]

    plt.figure(figsize=(15, 12))

    # One subplot per coordinate
    for index, axis_name in enumerate(('X', 'Y', 'Z')):
        plt.subplot(3, 1, index + 1)

        for values, color, label in series:
            plt.plot(values[:, index], color=color, label=label)

        plt.title(f'{axis_name} Position')
        plt.xlabel('Frame')
        plt.ylabel('Position')
        plt.legend()

    plt.tight_layout()
    plt.savefig('prediction_analysis.png')
    plt.show()

    # Plot errors over time
    plt.figure(figsize=(15, 6))

    # Calculate Euclidean distance error
    pred_error = np.sqrt(np.sum(np.square(pred_diff), axis=1))
    corr_error = np.sqrt(np.sum(np.square(corr_diff), axis=1))

    plt.plot(pred_error, color='r', label='Original Prediction Error')
    plt.plot(corr_error, color='b', label='Corrected Prediction Error')

    plt.title('Prediction Error Over Time')
    plt.xlabel('Frame')
    plt.ylabel('Euclidean Distance Error')
    plt.legend()

    plt.tight_layout()
    plt.savefig('prediction_error.png')
    plt.show()

# Run the full integration demo (training, saving, replay and analysis)
# only when this code is executed as a script, not when it is imported.
if __name__ == "__main__":
    integrate_with_existing_detector()

Upload Your Video Data

In [None]:
# Colab-only helper for bringing local files into the notebook session
from google.colab import files
# The return value of the upload call is kept in `uploaded`
uploaded = files.upload()  # This will prompt you to select files

 Modify the Jet Detection Code

In [None]:
%%writefile jet_detector.py
# Paste your existing jet center detection code here
# Ensure it has a function that takes a frame and returns (x,y,z) coordinates

def detect_jet_center(frame):
    """
    Detect the jet center in a single frame (template to be filled in)

    Args:
        frame: Single video frame

    Returns:
        [x, y, z] coordinates of the jet center

    Raises:
        NotImplementedError: Until the template is replaced with real
            detection code
    """
    # Your existing detection code here
    # ...
    # Finish with: return [x, y, z]  # Return center coordinates
    raise NotImplementedError(
        "detect_jet_center is a template: paste your detection code and "
        "return [x, y, z]"
    )

Update the Predictor Class to Use Your Detector

In [None]:
# Import necessary libraries
import cv2
import numpy as np
import tensorflow as tf
from jet_position_predictor import JetPositionPredictor

# Modify the detector method to use your code
from jet_detector import detect_jet_center

# Create a subclass that uses your detection function
class MyJetPredictor(JetPositionPredictor):
    """JetPositionPredictor whose per-frame detection is delegated to detect_jet_center"""

    def detect_jet_center_in_frame(self, frame):
        """
        Detect the jet center in one frame with the imported detector

        Args:
            frame: Single video frame

        Returns:
            Whatever detect_jet_center returns for the frame
        """
        center = detect_jet_center(frame)
        return center

Prepare Training Data

In [None]:
# Initialize your predictor (windows of 10 past frames per prediction)
predictor = MyJetPredictor(sequence_length=10)

# Path to your uploaded video
training_video_path = "your_training_video.mp4"  # Update with actual filename

# Extract frames and detect jet centers.
# `frames` and `positions` are reused by the following cells.
print("Extracting frames...")
frames = predictor.extract_frames(training_video_path)
print(f"Extracted {len(frames)} frames")

print("Detecting jet centers...")
positions = predictor.detect_jet_centers(frames)
print(f"Detected positions shape: {positions.shape}")

# Save positions to avoid recomputing them
np.save("jet_positions.npy", positions)

# Alternatively, if detection takes too long, you can load previously saved positions
# positions = np.load("jet_positions.npy")

Create Training Sequences

In [None]:
# Create sequences for LSTM training (this also fits predictor.scaler)
print("Creating sequences...")
X, y = predictor.create_sequences(positions)
print(f"X shape: {X.shape}, y shape: {y.shape}")

# Split data into training and validation sets in chronological order.
# Consecutive windows overlap in all but one frame, so a shuffled split
# would leak near-duplicates of the training windows into the validation
# set and make the validation metrics overly optimistic.
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, shuffle=False)
print(f"Training samples: {X_train.shape[0]}, Validation samples: {X_val.shape[0]}")

Build and Train the Model

In [None]:
# Build the model from the per-sample input shape (sequence_length, features)
print("Building model...")
predictor.build_model(X_train.shape[1:])
predictor.model.summary()

# Train the model.
# validation_split holds out a further 20% of X_train during fitting;
# X_val from the previous cell is not used here.
print("Training model...")
history = predictor.train(
    X_train, y_train,
    epochs=50,
    batch_size=32,
    validation_split=0.2
)

# Plot training history
import matplotlib.pyplot as plt

# Left panel: loss (MSE, as compiled in build_model) per epoch
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper right')

# Right panel: mean absolute error per epoch
plt.subplot(1, 2, 2)
plt.plot(history.history['mae'])
plt.plot(history.history['val_mae'])
plt.title('Model MAE')
plt.ylabel('MAE')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper right')

plt.tight_layout()
plt.show()

 Evaluate on Validation Set

In [None]:
# Evaluate model on validation set
val_loss, val_mae = predictor.model.evaluate(X_val, y_val)
print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation MAE: {val_mae:.4f}")

# Make predictions on validation set
val_pred = predictor.model.predict(X_val)

# Plot predictions vs actual for the first few validation samples.
# A range of samples is plotted per coordinate: a single value drawn with a
# line style and no marker would leave the axes empty.
plt.figure(figsize=(15, 10))
num_samples = min(50, len(y_val))  # Change this to view more or fewer samples

for i in range(3):  # Plot X, Y, Z separately
    plt.subplot(3, 1, i+1)
    plt.plot(y_val[:num_samples, i], 'g-', label='Actual')
    plt.plot(val_pred[:num_samples, i], 'r--', label='Predicted')
    plt.title(f'Dimension {i} ({"XYZ"[i]})')
    plt.xlabel('Validation sample')
    plt.ylabel('Scaled position')  # targets come from create_sequences, i.e. scaled
    plt.legend()

plt.tight_layout()
plt.show()

Save the Trained Model

In [None]:
# Save the trained model
predictor.save_model("jet_position_model.h5")
print("Model saved successfully")

# Save the scaler too (needed for future predictions):
# save_model writes only the Keras model, while predictions must be scaled
# and un-scaled with the same fitted scaler.
import pickle
with open('scaler.pkl', 'wb') as f:
    pickle.dump(predictor.scaler, f)
print("Scaler saved successfully")

Test on a New Video

In [None]:
# Path to test video
test_video_path = "your_test_video.mp4"  # Update with actual filename

# Import RealtimeJetPredictor
# NOTE(review): this requires a realtime_prediction module on the import
# path; nothing shown here writes such a file, so confirm it exists or use
# the class defined in the earlier cell instead.
from realtime_prediction import RealtimeJetPredictor

# Create a subclass with your detection code
class MyRealtimePredictor(RealtimeJetPredictor):
    """RealtimeJetPredictor whose frame detection is delegated to detect_jet_center"""

    def detect_jet_center(self, frame):
        """
        Detect the jet center in one frame with the module-level detector

        Args:
            frame: Video frame

        Returns:
            Whatever the module-level detect_jet_center returns for the frame
        """
        # Inside the method the bare name resolves to the module-level
        # function, not to this method
        center = detect_jet_center(frame)
        return center

# Load the saved scaler
# NOTE: unpickling executes code from the file; only load a scaler.pkl that
# you created yourself.
with open('scaler.pkl', 'rb') as f:
    scaler = pickle.load(f)

# Initialize real-time predictor
rt_predictor = MyRealtimePredictor(
    model_path="jet_position_model.h5",
    sequence_length=10,
    scaler=scaler
)

# Process the test video
actual_positions, predictions = rt_predictor.process_video_stream(
    test_video_path,
    display=True,
    predict_every=1
)

# Visualize results
rt_predictor.visualize_results()

# Calculate prediction errors.
# calculate_prediction_error returns None when predictions and detected
# positions cannot be compared, so guard before iterating.
error_metrics = rt_predictor.calculate_prediction_error()
if error_metrics is None:
    print("Prediction error metrics are not available")
else:
    for metric, value in error_metrics.items():
        print(f"{metric}: {value:.4f}")