<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/PLDM_JEPA_FP_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install av -q
!pip install colab-env -q
import colab_env

In [None]:
!nvidia-smi

Thu Jul 24 16:41:35 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   48C    P8             12W /   72W |       0MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## Conceptual Modifications to Cell 1: All Setup, Definitions, and Model Instantiations

In [None]:
# Cell 1: Conceptual Modifications - Aviation Data Definitions

import torch
import numpy as np
import os
import glob
import av
import json
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoVideoProcessor, AutoModel
from tqdm.auto import tqdm
import logging
import datetime
import pytz
# Configure root logging once for the notebook. The format must include
# %(message)s; without it only the timestamp and level are printed and the
# actual text of every logging.info/warning/error call is dropped.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')

import google.generativeai as genai
from google.colab import userdata

# Read the Gemini API key from Colab Secrets. userdata.get raises (instead of
# returning None) when the secret does not exist or the notebook has not been
# granted access, so both cases are mapped onto the "unconfigured" path below
# rather than aborting the cell.
try:
    GOOGLE_API_KEY = userdata.get('GEMINI')
except (userdata.SecretNotFoundError, userdata.NotebookAccessError) as secret_err:
    print(f"Could not read 'GEMINI' from Colab Secrets: {secret_err}")
    GOOGLE_API_KEY = None

if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)
    print("Google Generative AI configured successfully using Colab Secrets.")
else:
    print("WARNING: GOOGLE_API_KEY not found in Colab Secrets. Please ensure it's set for API calls.")
    print("API calls will likely fail. Proceeding with unconfigured API.")

class AgentConfig:
    """Static configuration shared by the classifier and the LLM step.

    CLASS_LABELS is index-aligned with the classifier's output logits, so the
    order of the entries must not change.
    """

    # Name of the Gemini model passed to genai.GenerativeModel.
    LLM_MODEL_NAME: str = "gemini-2.5-flash"

    # Label strings; position i corresponds to class index i.
    CLASS_LABELS = [
        "airplane landing", "airplane takeoff",
        "airport ground operations", "in-flight cruise",
        "emergency landing", "pre-flight check/maintenance",
        "en-route cruise", "climb phase",
        "descent phase", "holding pattern",
    ]

# Number of classifier output classes, derived from the label list so the two
# can never drift apart.
num_classes = len(AgentConfig.CLASS_LABELS)

# File the trained classifier head's state_dict is written to and later
# reloaded from. Kept at module scope so every cell can see it.
CLASSIFIER_SAVE_PATH = "classifier_head_trained_on_tartan_aviation_sample.pth"

# Airport records keyed by four-letter code; the planning cell looks up
# "CYUL" and "LFPG" from this table.
AIRPORTS = {
    "CYUL": {"name": "Montreal-Trudeau International", "lat": 45.4706, "lon": -73.7408, "elevation_ft": 118},
    "LFPG": {"name": "Paris-Charles de Gaulle", "lat": 49.0097, "lon": 2.5479, "elevation_ft": 392},
}

# Aircraft performance figures keyed by model name; units are encoded in the
# field names (knots, kg/hour, nm, fpm, ft, kg).
AIRCRAFT_PERFORMANCE = {
    "Boeing777_300ER": {
        "cruise_speed_knots": 490,
        "fuel_burn_kg_per_hour": 7000,
        "max_range_nm": 7900,
        "climb_rate_fpm": 2500,
        "descent_rate_fpm": 2000,
        "typical_cruise_altitude_ft": 37000,
        "fuel_capacity_kg": 145000
    }
}

# Hugging Face repo id passed to AutoModel / AutoVideoProcessor.from_pretrained.
hf_repo = "facebook/vjepa2-vitg-fpc64-256"
# Google Drive directory holding the feature files and feature_label_map.json.
EXTRACTED_FEATURES_DIR = "/content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/"

# Size of a flattened [2048, 1408] feature map; used as the LatentProjector input width.
TOTAL_FLATTENED_VJEPA_DIM = 2048 * 1408

# Width of the projected latent state used by the dynamics predictor.
CONCEPTUAL_PLDM_LATENT_DIM = 1024

# Dimensions handed to LatentDynamicsPredictor(latent_dim, action_dim).
latent_dim_pldm = CONCEPTUAL_PLDM_LATENT_DIM
action_dim = 8

def load_and_process_video(video_path, processor_instance, model_instance, device_instance, num_frames_to_sample=16):
    """
    Load a video, sample frames at a regular interval, and extract V-JEPA features.

    Args:
        video_path: Path to the video file on disk.
        processor_instance: Video processor called as
            ``processor_instance(videos=..., return_tensors="pt")``.
        model_instance: Encoder whose output exposes ``last_hidden_state``.
        device_instance: Device the processor outputs are moved to.
        num_frames_to_sample: Maximum number of frames to collect.

    Returns:
        ``(features, frames)`` where ``features`` is the model's raw
        ``last_hidden_state`` (not flattened or pooled here) and ``frames`` is
        the list of sampled RGB frames as ndarrays. Returns ``(None, None)``
        if the file is missing, no frame can be decoded, or any error occurs;
        errors are logged rather than raised.
    """
    frames = []
    if not os.path.exists(video_path):
        logging.error(f"ERROR: Video file '{video_path}' not found.")
        return None, None
    try:
        # Open the container in a `with` block so the underlying file handle
        # and decoder are released even if decoding fails part-way through.
        with av.open(video_path) as container:
            total_frames_in_video = container.streams.video[0].frames
            # max(1, ...) keeps the interval valid when the reported frame
            # count is smaller than the number of frames requested.
            sampling_interval = max(1, total_frames_in_video // num_frames_to_sample)
            logging.info(f"Total frames in video: {total_frames_in_video}")
            logging.info(f"Sampling interval: {sampling_interval} frames")

            for i, frame in enumerate(container.decode(video=0)):
                if len(frames) >= num_frames_to_sample:
                    break
                if i % sampling_interval == 0:
                    frames.append(frame.to_rgb().to_ndarray())

        if not frames:
            logging.error(f"ERROR: No frames could be loaded from '{video_path}'.")
            return None, None
        elif len(frames) < num_frames_to_sample:
            logging.warning(f"WARNING: Only {len(frames)} frames loaded. Requested: {num_frames_to_sample}.")

        inputs = processor_instance(videos=list(frames), return_tensors="pt")
        inputs = {k: v.to(device_instance) for k, v in inputs.items()}

        # Inference only: no gradients are needed for feature extraction.
        with torch.no_grad():
            features = model_instance(**inputs).last_hidden_state

        logging.info(f"Successfully extracted V-JEPA features with raw shape: {features.shape}")
        return features, frames

    except av.FFmpegError as e:
        logging.error(f"Error loading video with PyAV: {e}")
        logging.error("This might indicate an issue with the video file itself or PyAV installation.")
        return None, None
    except Exception as e:
        # Best-effort boundary: callers rely on (None, None) instead of an exception.
        logging.error(f"An unexpected error occurred: {e}")
        logging.error("Ensure 'av' library is installed (`pip install av`) and video file is not corrupt.")
        return None, None

class ClassifierHead(nn.Module):
    """Two-layer MLP mapping a feature vector to class logits.

    Layout: Linear(input_dim, 256) -> ReLU -> Dropout(0.3) -> Linear(256, num_classes).
    The attribute names are part of the saved state_dict keys and must stay stable.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return raw (pre-softmax) logits for ``x``."""
        hidden = self.fc1(x)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

class LatentDynamicsPredictor(torch.nn.Module):
    """MLP that predicts the next latent state from (latent state, action).

    The two inputs are concatenated on the last dimension and passed through
    Linear(latent_dim + action_dim, 256) -> ReLU -> Linear(256, latent_dim).
    """

    def __init__(self, latent_dim, action_dim):
        super().__init__()
        joint_width = latent_dim + action_dim
        self.layers = torch.nn.Sequential(
            nn.Linear(joint_width, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        )

    def forward(self, latent_state, action):
        """Return the predicted next latent state; same width as ``latent_state``."""
        return self.layers(torch.cat((latent_state, action), dim=-1))

class LatentProjector(nn.Module):
    """Single linear layer followed by ReLU, mapping input_dim -> output_dim."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.projector = nn.Linear(input_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Project ``x`` and clamp negatives to zero via ReLU."""
        projected = self.projector(x)
        return self.relu(projected)

# Build every model object the later cells rely on and place them on one device.
print("\n--- Instantiating Models and Optimizers ---")
# Pretrained encoder and its matching video processor, both loaded from hf_repo.
model = AutoModel.from_pretrained(hf_repo)
processor = AutoVideoProcessor.from_pretrained(hf_repo)

# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Projects a flattened feature map (TOTAL_FLATTENED_VJEPA_DIM) down to the latent width.
latent_projector = LatentProjector(TOTAL_FLATTENED_VJEPA_DIM, CONCEPTUAL_PLDM_LATENT_DIM)
latent_projector.to(device)

predictor = LatentDynamicsPredictor(latent_dim_pldm, action_dim)
predictor.to(device)
# One optimizer covers both the predictor and the projector parameters.
optimizer_pldm = torch.optim.Adam(list(predictor.parameters()) + list(latent_projector.parameters()), lr=0.001)

# Initial classifier head with a fixed 1408-wide input; Cell 2 re-creates it
# with the input width measured from the extracted features.
classifier = ClassifierHead(input_dim=1408, num_classes=num_classes)
classifier.to(device)

print(f"Models instantiated and moved to {device}.")
print("\nCell 1 setup complete for conceptual flight planning.")

## Cell 2: Core Execution - Feature Extraction, Classifier Training & Inference, LLM Interaction

In [None]:
#Cell 2: Core Execution Feature Extraction, Classifier Training & Inference, LLM Interaction, and PLDM Training/Planning
# This cell assumes Cell 1 has been successfully executed in the current session.
# All objects (model, processor, classifier, predictor, device, optimizer_pldm)
# and all function definitions (load_and_process_video, ClassifierHead, LatentDynamicsPredictor)
# are expected to be available from Cell 1's execution.
import os
import logging
import torch
import json
from google.colab import drive
from tqdm.auto import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import datetime
import pytz

# Mount Google Drive so EXTRACTED_FEATURES_DIR and the demo video are reachable.
print("\n--- Cell 2: Mounting Google Drive for dataset access ---")
drive.mount('/content/gdrive')
print("Google Drive mounted.")

# Stop early when the feature directory is missing; everything below reads from it.
print(f"Checking for extracted features directory: {EXTRACTED_FEATURES_DIR}")
if not os.path.exists(EXTRACTED_FEATURES_DIR):
    logging.error(f"ERROR: Extracted features directory '{EXTRACTED_FEATURES_DIR}' not found. Please create it and upload V-JEPA features.")
    # NOTE(review): exit() inside a notebook kernel may do more than stop this cell - confirm intended.
    exit()
else:
    print(f"Extracted features directory found at {EXTRACTED_FEATURES_DIR}")

# Part 1: Load and process airplane-landing.mp4 for initial observation
print(f"\n--- Cell 2: Part 1 - Loading actual video '/content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/airplane-landing.mp4' for feature extraction ---")
flight_video_path = '/content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/airplane-landing.mp4'
# load_and_process_video returns (features, frames), or (None, None) on any failure.
video_features_for_inference_raw, frames_for_pldm_planning = load_and_process_video(flight_video_path, processor, model, device_instance=device)

# Pool the raw encoder output down to the single vector the ClassifierHead consumes.
if video_features_for_inference_raw is not None:
    # Drop the leading batch axis, average over the next axis, then restore a
    # batch axis of size 1. For a [1, 2048, 1408] input this yields [1, 1408].
    # NOTE(review): the averaged axis is presumably the token/patch axis - confirm.
    pooled_features_for_classifier = video_features_for_inference_raw.squeeze(0).mean(dim=0).unsqueeze(0)
    # The classifier input width is read from the data rather than hard-coded.
    extracted_embedding_dim_for_classifier = pooled_features_for_classifier.shape[-1]
    logging.info(f"Dynamically determined extracted_embedding_dim for ClassifierHead: {extracted_embedding_dim_for_classifier}")
else:
    # Feature extraction failed; set sentinels so later cells can detect it, then stop.
    pooled_features_for_classifier = None
    extracted_embedding_dim_for_classifier = -1
    logging.error("Failed to extract video features for classifier. Exiting Cell 2.")
    exit()

# Part 2: Classifier Training
# Builds a training set (real features listed in feature_label_map.json, or
# synthetic random data as a fallback), trains the ClassifierHead for a fixed
# number of epochs, and saves its weights to CLASSIFIER_SAVE_PATH.
print(f"\n--- Cell 2: Part 2 - Starting Classifier Training ---")
print(f"Attempting to load real V-JEPA features for classifier training or generate synthetic data.")
print(f"Using device for classifier training: {device}")

try:
    # Re-initialize classifier with the correct, dynamically determined input_dim
    classifier = ClassifierHead(input_dim=extracted_embedding_dim_for_classifier, num_classes=num_classes)
    classifier.to(device)

    train_features_list = []
    train_labels_list = []

    # The map file is a JSON list of {"feature_path": ..., "label_idx": ...} items.
    map_file_path = os.path.join(EXTRACTED_FEATURES_DIR, "feature_label_map.json")
    if not os.path.exists(map_file_path):
        logging.warning(f"Feature-label map file '{map_file_path}' not found. Generating synthetic data.")
        feature_label_map = {}
    else:
        with open(map_file_path, 'r') as f:
            feature_label_map = json.load(f)

    if not feature_label_map:
        logging.warning(f"Feature-label map at {map_file_path} is empty. Generating synthetic data.")
        num_training_samples = 2_000_000
        # Synthetic data generation uses the dynamically determined input_dim
        train_features = torch.rand(num_training_samples, extracted_embedding_dim_for_classifier)
        train_labels = torch.randint(0, num_classes, (num_training_samples,))
        train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=32, shuffle=True)
        val_loader = None
        print(f"Loaded {num_training_samples} SYNTHETIC features for training.")
    else:
        for item in tqdm(feature_label_map, desc="Loading real V-JEPA features"):
            feature_path = item['feature_path']
            label_idx = item['label_idx']
            try:
                # Relative paths in the map are resolved against the features directory.
                if not os.path.isabs(feature_path):
                    feature_path = os.path.join(EXTRACTED_FEATURES_DIR, feature_path)

                if not os.path.exists(feature_path):
                    logging.warning(f"Feature file not found at {feature_path}. Skipping.")
                    continue

                feature = torch.load(feature_path, map_location=device)

                # Reduce whatever was stored to a single 1-D vector, mirroring
                # the pooling applied to the inference video.
                if feature.ndim == 3:
                    feature = feature.squeeze(0).mean(dim=0)
                elif feature.ndim == 2:
                    if feature.shape[0] == 1 and feature.shape[1] == 1408:
                        feature = feature.squeeze(0)
                    elif feature.shape[1] == 1408:
                        feature = feature.mean(dim=0)
                    else:
                        feature = feature.flatten()
                elif feature.ndim == 1:
                    pass
                else:
                    logging.warning(f"Skipping malformed feature from {feature_path}. Unexpected dimensions: {feature.ndim}")
                    continue

                # Final check after processing: width must match the classifier input.
                if feature.shape[0] != extracted_embedding_dim_for_classifier:
                    logging.warning(f"Skipping feature at {feature_path}. Dimension mismatch: expected {extracted_embedding_dim_for_classifier}, got {feature.shape[0]}.")
                    continue

                train_features_list.append(feature)
                train_labels_list.append(label_idx)

            except Exception as e:
                # One unreadable file should not abort the whole load.
                logging.error(f"Error loading feature from {feature_path}: {e}. Skipping.")

        if train_features_list:
            train_features = torch.stack(train_features_list).to(device)
            train_labels = torch.tensor(train_labels_list).to(device)
            num_training_samples = len(train_features)
            print(f"Loaded {num_training_samples} REAL V-JEPA features for training.")

            if num_training_samples < 2:
                print("WARNING: Only 1 real V-JEPA feature loaded. Training may be unstable. Consider more data.")
                train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=1, shuffle=True)
                val_loader = None
            else:
                # 80/20 train/validation split.
                dataset_size = len(train_features)
                train_size = int(0.8 * dataset_size)
                val_size = dataset_size - train_size

                if val_size == 0 and train_size > 0:
                    train_size = dataset_size
                    train_dataset_real = TensorDataset(train_features, train_labels)
                    val_dataset_real = None
                else:
                    train_dataset_real, val_dataset_real = torch.utils.data.random_split(
                        TensorDataset(train_features, train_labels), [train_size, val_size]
                    )
                train_loader = DataLoader(train_dataset_real, batch_size=32, shuffle=True)
                val_loader = DataLoader(val_dataset_real, batch_size=32, shuffle=False) if val_dataset_real else None
                print(f"Training on {len(train_dataset_real)} samples, Validation on {len(val_dataset_real) if val_dataset_real else 0} samples.")
        else:
            logging.error("No real V-JEPA features could be loaded from map file. Generating synthetic data as fallback.")
            num_training_samples = 2_000_000
            train_features = torch.rand(num_training_samples, extracted_embedding_dim_for_classifier)
            train_labels = torch.randint(0, num_classes, (num_training_samples,))
            train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=32, shuffle=True)
            val_loader = None
            print(f"Loaded {num_training_samples} SYNTHETIC features for training as fallback.")

    criterion = torch.nn.CrossEntropyLoss()
    optimizer_classifier = torch.optim.Adam(classifier.parameters(), lr=0.001)

    num_epochs = 20
    for epoch in range(num_epochs):
        classifier.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer_classifier.zero_grad()
            outputs = classifier(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer_classifier.step()
            running_loss += loss.item()

        # CrossEntropyLoss already averages within a batch, so running_loss is
        # a sum of per-batch means: divide by the number of batches, not by
        # the number of samples, to report the mean loss per batch.
        epoch_loss = running_loss / len(train_loader)

        val_loss = 0.0
        if val_loader and len(val_loader.dataset) > 0:
            classifier.eval()
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = classifier(inputs)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
            # Same normalisation as the training loss: mean over batches.
            val_loss /= len(val_loader)
            logging.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")
        else:
            logging.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}")
            print("No validation data available or validation dataset is empty.")


    print("\n--- Classifier Training Complete ---")
    torch.save(classifier.state_dict(), CLASSIFIER_SAVE_PATH)
    print(f"Classifier saved to: {CLASSIFIER_SAVE_PATH}")
except Exception as e:
    # Cell-level boundary: report the failure instead of raising out of the cell.
    logging.error(f"Error during classifier training: {e}")


In [None]:
#Part 3: Classification Inference and Gemini LLM Interaction
# Reloads the saved classifier weights, classifies the pooled features of the
# demo video, turns the prediction into a short sentence, and asks Gemini for
# an operational assessment based on that sentence.
print("\n--- Cell 2: Part 3 - Starting V-JEPA Feature-Driven Classification Inference and Gemini LLM Interaction ---")
if pooled_features_for_classifier is None:
    logging.error("ERROR: Cannot perform classifier inference as 'pooled_features_for_classifier' is None.")
else:
    try:
        pooled_features_for_inference_on_device = pooled_features_for_classifier.to(device)

        classifier.load_state_dict(torch.load(CLASSIFIER_SAVE_PATH, map_location=device))
        logging.info(f"Classifier weights loaded from: {CLASSIFIER_SAVE_PATH}")

        classifier.eval()
        with torch.no_grad():
            logits = classifier(pooled_features_for_inference_on_device)
            probabilities = torch.softmax(logits, dim=1)

        predicted_class_idx = torch.argmax(probabilities, dim=1).item()
        predicted_confidence = probabilities[0, predicted_class_idx].item()
        predicted_label = AgentConfig.CLASS_LABELS[predicted_class_idx]

        # Sentence handed to the LLM for each label; labels without an entry
        # fall back to the generic "unrecognised" sentence.
        _label_descriptions = {
            "airplane landing": "The visual system detected an airplane landing.",
            "airplane takeoff": "The visual system detected an airplane takeoff.",
            "airport ground operations": "The visual system detected airport ground operations.",
            "in-flight cruise": "The visual system detected an airplane in-flight cruise.",
            "emergency landing": "The visual system detected a possible emergency landing.",
            "pre-flight check/maintenance": "The visual system detected pre-flight check or maintenance activity.",
        }
        llm_input_description = _label_descriptions.get(
            predicted_label, "The visual system detected an unrecognised flight activity."
        )

        confidence_text = f"(Confidence: {predicted_confidence:.2f})"
        llm_input_description += f" {confidence_text}"

        print(f"\n--- AI Agent's Understanding from Classifier ---")
        # Print the confidence directly; re-splitting the description string
        # used to leave a dangling ")" in the output.
        print(f"**Primary Classification (Predicted by AI):** '{predicted_label}' {confidence_text}")
        print(f"**Description for LLM:** {llm_input_description}")
        print(f"Note: This classification's accuracy depends heavily on the quality and size of the real dataset used for classifier training.")

        print("\n--- Engaging Gemini LLM for Further Reasoning ---")
        try:
            llm_model = genai.GenerativeModel(AgentConfig.LLM_MODEL_NAME)

            prompt_for_gemini = f"""
                  You are an AI assistant for flight planning operations.
                  Current visual observation: {llm_input_description}
                  Current time (EST): {datetime.datetime.now(pytz.timezone('EST')).strftime('%Y-%m-%d %H:%M:%S EST')}

                  Based on this visual observation, provide a concise operational assessment relevant to flight planning.
                  If the observation seems random or uncertain, state that. Do not add any conversational filler.
                  """

            gemini_response = llm_model.generate_content(prompt_for_gemini)

            print("\n--- Gemini LLM Response ---")
            if gemini_response.candidates:
                for candidate in gemini_response.candidates:
                    if candidate.content and candidate.content.parts:
                        for part in candidate.content.parts:
                            print(part.text)

                print("\n--- Gemini LLM Response - END ---")
                print('\n')
            else:

                print("Gemini LLM did not provide a text response or cannot provide one.")
                if gemini_response.prompt_feedback:
                    print(f"Prompt Feedback: {gemini_response.prompt_feedback}")
                # Report the error carried by the response object itself; the
                # exception variable of the handler below does not exist here.
                if hasattr(gemini_response, 'error'):
                    print(f"LLM Error: {gemini_response.error}")

        except Exception as llm_e:
            logging.error(f"Error interacting with Gemini LLM: {llm_e}")
            logging.error("Ensure your GOOGLE_API_KEY is correctly set in Colab Secrets.")

    except Exception as e:
        logging.error(f"Error during classification inference or overall Cell 2 execution: {e}")

# Only report the feature shape when features exist; the error for the None
# case has already been logged above.
if pooled_features_for_classifier is not None:
    print(f"The V-JEPA features (shape: {pooled_features_for_classifier.shape}) are the core input that a trained classifier would learn from.")
print(f"Current time in EST: {datetime.datetime.now(pytz.timezone('EST')).strftime('%Y-%m-%d %H:%M:%S EST')}")

print("\nCell 2 execution complete.")


--- Cell 2: Part 3 - Starting V-JEPA Feature-Driven Classification Inference and Gemini LLM Interaction ---

--- AI Agent's Understanding from Classifier ---
**Primary Classification (Predicted by AI):** 'airplane landing' 1.00)
**Description for LLM:** The visual system detected an airplane landing. (Confidence: 1.00)
Note: This classification's accuracy depends heavily on the quality and size of the real dataset used for classifier training.

--- Engaging Gemini LLM for Further Reasoning ---

--- Gemini LLM Response ---
Airplane landing observed. Routine arrival operation impacting runway and gate availability.

--- Gemini LLM Response - END ---


The V-JEPA features (shape: torch.Size([1, 1408])) are the core input that a trained classifier would learn from.
Current time in EST: 2025-07-24 11:49:36 EST

Cell 2 execution complete.


This prediction comes from a classifier that was trained on the provided V-JEPA features.
For *truly accurate and high-confidence predictions* on real videos,
the classifier needs to be trained on a large, diverse dataset of *real V-JEPA features and their corresponding labels*.

## Conceptual Modifications to Cell 3: Conceptual PLDM Latent Dynamics Training

In [None]:
# Cell 3: Conceptual Modifications - PLDM Latent Dynamics Training for Real Flights

import os
import logging
import torch
from tqdm.auto import tqdm
import random

# These names repeat values already assigned in Cell 1 so the cell can be
# re-run on its own; CONCEPTUAL_PLDM_LATENT_DIM itself must still come from Cell 1.
EXTRACTED_FEATURES_DIR = "/content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/"
TOTAL_FLATTENED_VJEPA_DIM = 2048 * 1408
# Latent width used when generating the simulated dynamics data below.
latent_dim = CONCEPTUAL_PLDM_LATENT_DIM
# Width of each simulated action vector.
action_dim = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_flight_dynamics_training_data_real_conceptual(device, num_simulated_real_trajectories=100):
    """
    CONCEPTUAL FUNCTION: build simulated (state, action, next_state) samples.

    No files are read. Each trajectory starts from a small random latent and
    advances for a random 5-20 steps; every step adds the action's mean
    (scaled by 0.5, broadcast across the latent) plus small Gaussian noise.
    Widths come from the module-level ``latent_dim`` and ``action_dim``.

    Returns a flat list of ``(current_latent, action, next_latent)`` tuples of
    tensors placed on ``device``.
    """
    print("Loading conceptual real flight dynamics data for training from Montreal-Paris context.")

    state_width = latent_dim
    control_width = action_dim
    samples = []

    for _trajectory in range(num_simulated_real_trajectories):
        state = torch.rand(1, state_width).to(device) * 0.1
        steps_in_trajectory = random.randint(5, 20)

        for _step in range(steps_in_trajectory):
            control = torch.randn(1, control_width).to(device) * 0.1

            # Scalar action effect spread uniformly over every latent dimension.
            drift = control.mean() * torch.ones(1, state_width).to(device) * 0.5
            jitter = (torch.randn(1, state_width) * 0.005).to(device)
            following_state = state + drift + jitter

            samples.append((state, control, following_state))
            # Continue the trajectory from an independent copy of the new state.
            state = following_state.clone()

    print(f"Loaded {len(samples)} conceptual 'real-like' dynamics training samples.")
    return samples

def train_latent_dynamics_model(predictor_model, optimizer, training_data, epochs=10):
    """
    Train the latent dynamics predictor with an MSE next-state objective.

    Args:
        predictor_model: Module called as ``predictor_model(latent, action)``.
        optimizer: Optimizer that updates ``predictor_model``'s parameters.
        training_data: Sized sequence of ``(latent_t, action_t, latent_t_plus_1)``
            tensor tuples; each tuple is one optimisation step.
        epochs: Number of full passes over ``training_data``.

    Returns:
        None. Prints the average per-sample loss after each epoch. Tensors are
        moved to the module-level ``device`` before use.
    """
    if not training_data:
        # Nothing to fit; also avoids dividing by zero when averaging below.
        print("No training data provided; skipping latent dynamics training.")
        return

    predictor_model.train()
    print("\n-- Training Latent Dynamics Predictor for Conceptual Real Flights ---")
    num_samples = len(training_data)
    for epoch in range(epochs):
        total_loss = 0.0
        for latent_s_t, action_t, latent_s_t_plus_1 in tqdm(
            training_data,
            total=num_samples,
            desc=f"Epoch {epoch+1}/{epochs}"
        ):
            latent_s_t, action_t, latent_s_t_plus_1 = latent_s_t.to(device), action_t.to(device), latent_s_t_plus_1.to(device)

            predicted_z_t_plus_1 = predictor_model(latent_s_t, action_t)
            loss = torch.nn.functional.mse_loss(predicted_z_t_plus_1, latent_s_t_plus_1)

            # Clear gradients on the optimizer that was passed in, so the
            # function is correct for any optimizer and not only the global one.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {total_loss / num_samples}")


# Cell 3 driver: generate the simulated dynamics samples, then fit the predictor on them.
print("\n--Cell 3: Starting Conceptual PLDM Latent Dynamics Training for Montreal-Paris context ---")

dynamics_training_data = load_flight_dynamics_training_data_real_conceptual(device)

if not dynamics_training_data:
    print("Skipping Latent Dynamics Training as no conceptual data was loaded.")
else:
    train_latent_dynamics_model(predictor, optimizer_pldm, dynamics_training_data)

print("\nCell 3 execution complete.")

## Conceptual Modifications to Cell 4: Conceptual PLDM Planning

In [None]:
# Cell 4: Conceptual Modifications - PLDM Planning for Montreal to Paris

import logging
import torch
import numpy as np
import datetime
import pytz
from tqdm.auto import tqdm

# Repeats values from Cell 1 so this cell can be re-run on its own;
# CONCEPTUAL_PLDM_LATENT_DIM itself must still come from Cell 1.
TOTAL_FLATTENED_VJEPA_DIM = 2048 * 1408
# Latent width of the projected planning state.
latent_dim = CONCEPTUAL_PLDM_LATENT_DIM
# Width of each candidate action vector sampled by the planner.
action_dim = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def plan_montreal_to_paris_flight(start_airport_data, target_airport_data, aircraft_model_data,
                                  encoder_model, processor_instance, predictor_model, latent_projector_instance,
                                  planning_horizon=50, num_action_samples=100):
    """
    CONCEPTUAL FUNCTION: Plans a flight path from Montreal to Paris using the learned latent dynamics.

    At every step a batch of random candidate actions is scored by how close
    the predicted next latent state lands to the target latent state (plus
    small random placeholder costs); the cheapest action is committed and the
    latent state is advanced with it.

    Args:
        start_airport_data, target_airport_data, aircraft_model_data:
            Accepted for interface completeness; not read by this
            conceptual implementation.
        encoder_model: Only switched to eval mode; the start and target
            states are random stand-ins rather than real encodings.
        processor_instance: Not used by this conceptual implementation.
        predictor_model: Dynamics model called as ``predictor_model(latent, action)``.
        latent_projector_instance: Maps a flattened feature map to the latent width.
        planning_horizon: Number of actions to plan.
        num_action_samples: Number of random candidate actions scored per step.

    Returns:
        List of ``planning_horizon`` numpy arrays, one chosen action per step.
        Relies on the module-level ``device`` and ``action_dim``.
    """
    encoder_model.eval()
    predictor_model.eval()
    latent_projector_instance.eval()  # Ensure projector is in eval mode

    # Random stand-ins for raw encoder outputs of the start and target scenes.
    # In a real scenario these would be actual encodings of observed video.
    dummy_initial_vjepa_feature_raw = torch.rand(1, 2048, 1408).to(device)
    dummy_target_vjepa_feature_raw = torch.rand(1, 2048, 1408).to(device)

    # Flatten everything after the batch axis to match the projector's input width.
    flattened_initial_vjepa_feature = dummy_initial_vjepa_feature_raw.flatten(start_dim=1)
    flattened_target_vjepa_feature = dummy_target_vjepa_feature_raw.flatten(start_dim=1)

    # Project down to the smaller planning latent.
    with torch.no_grad():
        current_latent_state = latent_projector_instance(flattened_initial_vjepa_feature)
        target_latent_state = latent_projector_instance(flattened_target_vjepa_feature)

    print("\n--- Starting Conceptual Flight Plan: Montreal (CYUL) to Paris (LFPG) ---")
    print('\n')
    print(f"Conceptual Current Latent State Shape: {current_latent_state.shape}")
    print(f"Conceptual Target Latent State Shape: {target_latent_state.shape}")
    print('\n')

    best_action_sequence = []

    # Planning is pure inference: keep the whole rollout out of autograd so no
    # graph is accumulated across candidates and steps.
    with torch.no_grad():
        for step in range(planning_horizon):
            candidate_actions = torch.rand(num_action_samples, action_dim).to(device)

            # Score all candidates in one batched forward pass by repeating
            # the current state once per candidate.
            repeated_latent_state = current_latent_state.expand(num_action_samples, -1)
            simulated_next_latents = predictor_model(repeated_latent_state, candidate_actions)

            # 1. Goal Proximity Cost (Primary): L2 distance to the target, per candidate.
            goal_proximity_cost = torch.norm(target_latent_state - simulated_next_latents, dim=-1)

            # 2. Conceptual Fuel Cost
            conceptual_fuel_cost = torch.rand(num_action_samples).to(device) * 0.01

            # 3. Conceptual Environmental Cost
            conceptual_weather_cost = torch.rand(num_action_samples).to(device) * 0.02

            # 4. Conceptual Regulatory/Efficiency Cost
            conceptual_efficiency_cost = torch.rand(num_action_samples).to(device) * 0.005

            total_cost = goal_proximity_cost + conceptual_fuel_cost + conceptual_weather_cost + conceptual_efficiency_cost

            best_candidate_idx = torch.argmin(total_cost)
            optimal_action_for_step = candidate_actions[best_candidate_idx]

            best_action_sequence.append(optimal_action_for_step.squeeze().cpu().numpy())

            # Commit the chosen action and advance the latent state.
            current_latent_state = predictor_model(current_latent_state, optimal_action_for_step.unsqueeze(0))

    print(f"Conceptual Plan for {planning_horizon} steps (first 5 actions shown):")
    for i, action in enumerate(best_action_sequence[:5]):
        print(f"Step {i+1}: {action}")
    print("-- Conceptual Planning Complete ---")

    return best_action_sequence

# Cell 4 driver: look up the route endpoints and aircraft, then run the planner.
print("\n--Cell 4: Starting Conceptual PLDM Planning for Montreal to Paris ---")

start_airport_data = AIRPORTS["CYUL"]
target_airport_data = AIRPORTS["LFPG"]
aircraft_model_data = AIRCRAFT_PERFORMANCE["Boeing777_300ER"]

# Arguments are passed by keyword so each object's role is explicit at the call site.
conceptual_flight_plan_actions = plan_montreal_to_paris_flight(
    start_airport_data=start_airport_data,
    target_airport_data=target_airport_data,
    aircraft_model_data=aircraft_model_data,
    encoder_model=model,
    processor_instance=processor,
    predictor_model=predictor,
    latent_projector_instance=latent_projector,
)

est_now = datetime.datetime.now(pytz.timezone('EST'))
print(f"Current time in EST: {est_now.strftime('%Y-%m-%d %H:%M:%S EST')}")
print("\nCell 4 execution complete.")


--Cell 4: Starting Conceptual PLDM Planning for Montreal to Paris ---

--- Starting Conceptual Flight Plan: Montreal (CYUL) to Paris (LFPG) ---


Conceptual Current Latent State Shape: torch.Size([1, 1024])
Conceptual Target Latent State Shape: torch.Size([1, 1024])


Conceptual Plan for 50 steps (first 5 actions shown):
Step 1: [0.7455744  0.13278925 0.09921384 0.10516822 0.25165814 0.12890929
 0.16943347 0.00627178]
Step 2: [0.03834724 0.12711555 0.22410339 0.880904   0.2624029  0.28845394
 0.09839606 0.20951557]
Step 3: [0.10010779 0.14281225 0.20694387 0.03428817 0.8536696  0.48601228
 0.03118342 0.12214506]
Step 4: [0.19816    0.6673281  0.33317983 0.01353353 0.4531961  0.046157
 0.01012588 0.52020943]
Step 5: [0.11279976 0.21449488 0.11356294 0.09161949 0.20542902 0.56288326
 0.02422994 0.07400107]
-- Conceptual Planning Complete ---
Current time in EST: 2025-07-24 11:43:32 EST

Cell 4 execution complete.
