<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/JEPA_PLDCpynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Report the GPU, driver and CUDA version visible to this runtime.
!nvidia-smi

Thu Jul 24 10:44:18 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   32C    P0             46W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
!pip install av -q

In [1]:
# Copy the sample video into the Drive-backed feature directory.
# NOTE(review): the destination only exists once Google Drive is mounted at /content/gdrive — confirm the mount has run before this cell.
!cp -pr /content/airplane-landing.mp4 /content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/

In [None]:
!pip install colab-env -q
import colab_env

In [8]:
# List the feature directory (long format, newest first, hidden files included)
# to check that the video, the saved feature tensor and the label map are present.
!ls -lta /content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/

total 22538
-rw-------+ 1 root root 23070474 Jul 24 10:47 airplane-landing.mp4
-rw-------  1 root root     6986 Jul 23 15:37 1_2023-02-22-15-21-49_feature.pt
-rw-------  1 root root      203 Jul 23 15:37 feature_label_map.json


## Cell 1: All Setup, Definitions, and Model Instantiations

In [3]:
# Cell 1: All Setup, Definitions, and Model Instantiations
# Run this cell completely after any kernel restart.

import torch
import numpy as np
import os
import glob
import av
import json
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoVideoProcessor, AutoModel
from tqdm.auto import tqdm

# Logging setup
import logging
import datetime
import pytz
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Import Google Generative AI components
from google.colab import userdata
import google.generativeai as genai

# --- API Key Setup ---
# Colab's `userdata.get` raises when the secret does not exist or when notebook
# access to it has not been granted. Catch both so the warning branch below is
# actually reachable instead of the cell aborting with a traceback.
try:
    GOOGLE_API_KEY = userdata.get('GEMINI')
except (userdata.SecretNotFoundError, userdata.NotebookAccessError) as secret_error:
    logging.warning(f"Could not read the 'GEMINI' Colab secret: {secret_error}")
    GOOGLE_API_KEY = None

if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)
    print("Google Generative AI configured successfully using Colab Secrets.")
else:
    print("WARNING: GOOGLE_API_KEY not found in Colab Secrets. Please ensure 'GEMINI' secret is set.")
    print("API calls will likely fail. Proceeding with unconfigured API.")

# --- Agent Configuration ---
class AgentConfig:
    """Static settings for the LLM side of the agent."""

    # Model name handed to `genai.GenerativeModel` when the LLM is queried.
    LLM_MODEL_NAME: str = "gemini-2.5-flash"

# --- Where the V-JEPA checkpoint and the pre-extracted features live ---
hf_repo = "facebook/vjepa2-vitg-fpc64-256"
EXTRACTED_FEATURES_DIR = "/content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/"

# --- Classifier label space; the list index is the class id ---
CLASS_LABELS = [
    "airplane landing",
    "airplane takeoff",
    "airport ground operations",
    "in-flight cruise",
    "emergency landing",
    "pre-flight check/maintenance",
]
num_classes = len(CLASS_LABELS)

# File the trained classifier head's state_dict is written to / read from.
CLASSIFIER_SAVE_PATH = "classifier_head_trained_on_tartan_aviation_sample.pth"

# --- PLDM dimensions ---
# The notebook treats the V-JEPA output as [1, 2048, 1408]; the PLDM predictor
# consumes it fully flattened, i.e. 2048 * 1408 values per state.
TOTAL_FLATTENED_VJEPA_DIM_FOR_PLDM = 2048 * 1408
latent_dim_pldm = TOTAL_FLATTENED_VJEPA_DIM_FOR_PLDM  # latent size seen by the PLDM predictor
action_dim = 8  # size of the action vector concatenated to the latent state

# --- Helper Function for Video Loading and Feature Extraction (Returns raw V-JEPA output) ---
def load_and_process_video(video_path, processor_instance, model_instance, device_instance, num_frames_to_sample=16):
    """
    Load a video, sample frames at a fixed stride, and extract V-JEPA features.

    Args:
        video_path: Path of the video file to decode.
        processor_instance: Video processor, called as
            ``processor_instance(videos=frames, return_tensors="pt")``.
        model_instance: Encoder whose output exposes ``last_hidden_state``.
        device_instance: Device the processor outputs are moved to before inference.
        num_frames_to_sample: Maximum number of frames to keep (must be >= 1).

    Returns:
        ``(features, frames)``: ``features`` is the model's ``last_hidden_state``
        exactly as returned (not flattened; the notebook expects something like
        [1, 2048, 1408]) and ``frames`` is the list of sampled RGB frames as
        ndarrays. Returns ``(None, None)`` when the file is missing, the frame
        count is invalid, no frame could be decoded, or any error occurs.
    """
    if not os.path.exists(video_path):
        logging.error(f"ERROR: Video file '{video_path}' not found.")
        return None, None

    # A non-positive count would otherwise surface as a ZeroDivisionError that is
    # reported below as a misleading "install av" message.
    if num_frames_to_sample < 1:
        logging.error(f"ERROR: num_frames_to_sample must be >= 1, got {num_frames_to_sample}.")
        return None, None

    frames = []
    try:
        # The context manager closes the container (and its file handle) even if
        # decoding fails part-way through.
        with av.open(video_path) as container:
            total_frames_in_video = container.streams.video[0].frames
            # max(1, ...) keeps the stride valid when the video is shorter than the
            # requested sample count or the reported frame count is 0.
            sampling_interval = max(1, total_frames_in_video // num_frames_to_sample)

            logging.info(f"Total frames in video: {total_frames_in_video}")
            logging.info(f"Sampling interval: {sampling_interval} frames")

            for i, frame in enumerate(container.decode(video=0)):
                if len(frames) >= num_frames_to_sample:
                    break
                if i % sampling_interval == 0:
                    frames.append(frame.to_rgb().to_ndarray())

        if not frames:
            logging.error(f"ERROR: No frames could be loaded from '{video_path}'.")
            return None, None
        elif len(frames) < num_frames_to_sample:
            logging.warning(f"WARNING: Only {len(frames)} frames loaded. Model might perform suboptimally.")

        inputs = processor_instance(videos=list(frames), return_tensors="pt")

        # Move inputs to device before model inference
        inputs = {k: v.to(device_instance) for k, v in inputs.items()}

        with torch.no_grad():
            features = model_instance(**inputs).last_hidden_state # Keep features in raw V-JEPA output shape

        logging.info(f"Successfully extracted V-JEPA features with raw shape: {features.shape}")
        return features, frames

    except av.FFmpegError as e:
        logging.error(f"Error loading video with PyAV: {e}")
        logging.error("This might indicate an issue with the video file itself or PyAV installation.")
        return None, None
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}")
        logging.error("Ensure 'av' library is installed (`pip install av`) and that the video file is not corrupted.")
        return None, None

# --- Define Classifier Head (using the working logic for input_dim) ---
class ClassifierHead(nn.Module):
    """Two-layer MLP mapping a pooled feature vector to class logits.

    Layout: Linear(input_dim, 256) -> ReLU -> Dropout(0.3) -> Linear(256, num_classes).
    Attribute names are part of the saved state_dict keys and must not change.
    """

    def __init__(self, input_dim, num_classes): # input_dim will be the pooled feature dim (1408)
        super().__init__()
        hidden_width = 256
        self.fc1 = nn.Linear(input_dim, hidden_width)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.3)
        self.fc2 = nn.Linear(hidden_width, num_classes)

    def forward(self, x):
        """Return unnormalised logits for the input batch ``x``."""
        hidden = self.fc1(x)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

# Define LatentDynamicsPredictor (still expects total flattened dim for PLDM)
class LatentDynamicsPredictor(torch.nn.Module):
    """MLP that predicts the next latent state from the current latent state and an action.

    The two inputs are concatenated on the last axis and passed through
    Linear(latent_dim + action_dim, 256) -> ReLU -> Linear(256, latent_dim).
    """

    def __init__(self, latent_dim, action_dim):
        super().__init__()
        hidden_width = 256
        # Built in this order so the Sequential indices (0, 1, 2) and therefore
        # the state_dict keys stay the same.
        input_projection = torch.nn.Linear(latent_dim + action_dim, hidden_width)
        output_projection = torch.nn.Linear(hidden_width, latent_dim)
        self.layers = torch.nn.Sequential(input_projection, torch.nn.ReLU(), output_projection)

    def forward(self, latent_state, action):
        """Return the predicted next latent state, shaped like ``latent_state``."""
        state_and_action = torch.cat((latent_state, action), dim=-1)
        return self.layers(state_and_action)

# --- Build every model and optimizer once; later cells reuse these globals ---
print("\n--- Instantiating Models and Optimizers ---")

# Single device shared by the encoder, the PLDM predictor and the classifier.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained V-JEPA encoder and its matching video processor.
model = AutoModel.from_pretrained(hf_repo).to(device)
processor = AutoVideoProcessor.from_pretrained(hf_repo)

# PLDM predictor over the fully flattened latent, with its own Adam optimizer.
predictor = LatentDynamicsPredictor(latent_dim_pldm, action_dim).to(device)
optimizer_pldm = optim.Adam(predictor.parameters(), lr=0.001)

# Provisional classifier head sized for a 1408-wide pooled feature; Cell 2
# re-creates it once the real pooled dimension has been measured.
classifier = ClassifierHead(input_dim=1408, num_classes=num_classes).to(device)

print(f"Models instantiated and moved to {device}.")
print("\nCell 1 setup complete. Remember to run this cell first after any kernel restart before running subsequent cells.")

Google Generative AI configured successfully using Colab Secrets.

--- Instantiating Models and Optimizers ---
Models instantiated and moved to cuda.

Cell 1 setup complete. Remember to run this cell first after any kernel restart before running subsequent cells.


## Cell 2: Core Execution - Feature Extraction, Classifier Training & Inference, and LLM Interaction (PLDM Training/Planning follow in Cells 3 and 4)

In [None]:
# Cell 2: Core Execution - Feature Extraction, Classifier Training & Inference, LLM Interaction (PLDM Training & Planning follow in Cells 3 and 4)
# This cell assumes Cell 1 has been successfully executed in the current session.
# All objects (model, processor, classifier, predictor, device, optimizer_pldm)
# and all function definitions (load_and_process_video, ClassifierHead, LatentDynamicsPredictor)
# are expected to be available from Cell 1's execution.

import os
import logging
import torch
import json
from google.colab import drive
from tqdm.auto import tqdm # Ensure tqdm is imported for progress bars
import torch.optim as optim # For classifier optimizer
from torch.utils.data import DataLoader, TensorDataset # For classifier training data
import datetime # For LLM timestamp
import pytz # For LLM timestamp

# --- Mounting Google Drive ---
print("\n--- Cell 2: Mounting Google Drive for dataset access ---")
drive.mount('/content/gdrive')
print("Google Drive mounted.")

print(f"Checking for extracted features directory: {EXTRACTED_FEATURES_DIR}")
if not os.path.exists(EXTRACTED_FEATURES_DIR):
    missing_dir_message = f"ERROR: Extracted features directory '{EXTRACTED_FEATURES_DIR}' not found. Please ensure Google Drive is mounted and path is correct."
    logging.error(missing_dir_message)
    # Raise rather than call exit(): inside an IPython kernel exit() does not
    # cleanly abort the cell and can take the kernel (and Cell 1's state) down.
    raise FileNotFoundError(missing_dir_message)
else:
    print(f"Extracted features directory found at {EXTRACTED_FEATURES_DIR}.")

# --- Part 1: Load and process 'airplane-landing.mp4' for initial observation and Feature Extraction ---
# The video is read from the Drive feature directory (a local copy may also exist at '/content/airplane-landing.mp4').
flight_video_path = '/content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/airplane-landing.mp4'
# Report the path that is actually loaded so the log matches the data used.
print(f"\n--- Cell 2: Part 1 - Loading actual video '{flight_video_path}' for Feature Extraction ---")

# load_and_process_video returns the encoder's raw (unflattened) output plus the sampled frames.
video_features_for_inference_raw, frames_for_pldm_planning = load_and_process_video(flight_video_path, processor, model, device)

# --- CRITICAL: Process raw V-JEPA features to match ClassifierHead's expected input (pooling to 1408) ---
if video_features_for_inference_raw is not None:
    # The raw output is expected to be [1, 2048, 1408] — presumably (batch, tokens, hidden); TODO confirm.
    # squeeze(0) drops the batch axis, mean(dim=0) averages over the 2048 axis, and
    # unsqueeze(0) restores a batch axis, giving [1, 1408] for the classifier.
    pooled_features_for_classifier = video_features_for_inference_raw.squeeze(0).mean(dim=0).unsqueeze(0).to(device)
    extracted_embedding_dim_for_classifier = pooled_features_for_classifier.shape[1] # This will be 1408
    logging.info(f"Dynamically determined extracted_embedding_dim for Classifier: {extracted_embedding_dim_for_classifier}")
else:
    # Leave well-defined sentinels behind so later cells can detect the failure.
    pooled_features_for_classifier = None
    extracted_embedding_dim_for_classifier = -1
    logging.error("Failed to extract video features for classifier. Exiting Cell 2.")
    # Stop the cell here; nothing below can run without the features.
    raise RuntimeError("Failed to extract video features for classifier.")

# --- Part 2: Classifier Training ---
print(f"\n--- Cell 2: Part 2 - Starting Classifier Training ---")
print(f"Attempting to load real V-JEPA features for classifier training from: {EXTRACTED_FEATURES_DIR}")

print(f"Using device for classifier training: {device}") # 'device' is global from Cell 1

# Number of random placeholder samples used when no real features can be loaded.
SYNTHETIC_TRAINING_SAMPLES = 2_000_000


def _build_synthetic_training_set(num_samples, feature_dim):
    """Return (features, labels, loader) of uniformly random placeholder data on `device`.

    Used only as a fallback so the training loop can run end-to-end; the labels
    are random, so a classifier trained on this data carries no real signal.
    """
    synthetic_features = torch.rand(num_samples, feature_dim).to(device)
    synthetic_labels = torch.randint(0, num_classes, (num_samples,)).to(device)
    synthetic_loader = DataLoader(TensorDataset(synthetic_features, synthetic_labels), batch_size=128, shuffle=True)
    return synthetic_features, synthetic_labels, synthetic_loader


try:
    # Re-initialize classifier with the correct, dynamically determined input dimension (1408)
    classifier = ClassifierHead(input_dim=extracted_embedding_dim_for_classifier, num_classes=num_classes).to(device)

    train_features_list = []
    train_labels_list = []

    map_file_path = os.path.join(EXTRACTED_FEATURES_DIR, "feature_label_map.json")

    if not os.path.exists(map_file_path):
        logging.warning(f"Feature-label map file '{map_file_path}' not found. Classifier will be trained on SYNTHETIC data as a fallback.")
        feature_label_map = {} # Treat as empty to trigger synthetic fallback
    else:
        with open(map_file_path, 'r') as f:
            feature_label_map = json.load(f)

    if not feature_label_map:
        logging.warning(f"Feature-label map at {map_file_path} is empty. Classifier will be trained on SYNTHETIC data as a fallback.")
        num_training_samples = SYNTHETIC_TRAINING_SAMPLES
        # Synthetic data generation uses the dynamically determined input_dim (1408)
        train_features, train_labels, train_loader = _build_synthetic_training_set(num_training_samples, extracted_embedding_dim_for_classifier)
        val_loader = None
        print(f"Loaded {num_training_samples} SYNTHETIC features for training.")

    else:
        for item in tqdm(feature_label_map, desc="Loading real V-JEPA features"):
            feature_path = item['feature_path']
            label_idx = item['label_idx']
            try:
                # Relative entries are resolved by file name inside the features directory.
                if not os.path.isabs(feature_path):
                    feature_path = os.path.join(EXTRACTED_FEATURES_DIR, os.path.basename(feature_path))

                if not os.path.exists(feature_path):
                    logging.warning(f"Feature file not found at {feature_path}. Skipping.")
                    continue

                feature = torch.load(feature_path, map_location=device)

                # Reduce whatever was saved to a single 1-D vector, using the same
                # pooling as the inference path (mean over the first non-batch axis).
                if feature.ndim == 3: # e.g. raw encoder output [1, 2048, 1408]
                    feature = feature.squeeze(0).mean(dim=0) # -> [2048, 1408] -> mean(dim=0) -> [1408]
                elif feature.ndim == 2: # Could be [2048, 1408] or [1, 1408] if already processed
                    if feature.shape[0] == 1 and feature.shape[1] == 1408: # If [1, 1408]
                        feature = feature.squeeze(0) # -> [1408]
                    elif feature.shape[1] == 1408: # If [X, 1408], pool over the first axis
                        feature = feature.mean(dim=0) # -> [1408]
                    else: # Unexpected 2D layout: flatten and let the size check below decide
                        feature = feature.flatten()
                elif feature.ndim == 1: # Already [1408]
                    pass
                else:
                    logging.warning(f"Skipping malformed feature at {feature_path}. Unexpected ndim: {feature.ndim}. Got {feature.shape}. Skipping.")
                    continue

                # Final check after processing. Should be 1D with 1408 elements.
                if feature.shape[0] != extracted_embedding_dim_for_classifier:
                     logging.warning(f"Skipping feature at {feature_path}. Expected dimension {extracted_embedding_dim_for_classifier}, but got {feature.shape[0]} after processing. Skipping.")
                     continue

                train_features_list.append(feature)
                train_labels_list.append(label_idx)
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole load.
                logging.error(f"Error loading feature from {feature_path}: {e}")

        if train_features_list:
            train_features = torch.stack(train_features_list).to(device) # Stack into [N, 1408]
            train_labels = torch.tensor(train_labels_list).to(device)
            num_training_samples = len(train_features)
            print(f"Loaded {num_training_samples} REAL V-JEPA features for training.")

            if num_training_samples < 2:
                print("WARNING: Only 1 real V-JEPA feature loaded. Training will be performed on this single sample (no train/val split).")
                train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=1, shuffle=False)
                val_loader = None
            else:
                # 80/20 train/validation split.
                dataset_size = len(train_features)
                train_size = int(0.8 * dataset_size)
                val_size = dataset_size - train_size
                if val_size == 0 and train_size > 0:
                    train_size = dataset_size
                    train_dataset_real = TensorDataset(train_features, train_labels)
                    val_dataset_real = None
                else:
                    train_dataset_real, val_dataset_real = torch.utils.data.random_split(TensorDataset(train_features, train_labels), [train_size, val_size])

                train_loader = DataLoader(train_dataset_real, batch_size=32, shuffle=True)
                val_loader = DataLoader(val_dataset_real, batch_size=32, shuffle=False) if val_dataset_real else None

                print(f"Training on {len(train_dataset_real)} samples, Validating on {len(val_dataset_real)} samples." if val_loader else f"Training on {len(train_dataset_real)} samples (No separate validation data).")
        else:
            logging.error("No real V-JEPA features could be loaded from map. Falling back to synthetic data.")
            num_training_samples = SYNTHETIC_TRAINING_SAMPLES
            train_features, train_labels, train_loader = _build_synthetic_training_set(num_training_samples, extracted_embedding_dim_for_classifier)
            val_loader = None
            print(f"Loaded {num_training_samples} SYNTHETIC features for training (due to real data load failure).")

    criterion = torch.nn.CrossEntropyLoss()
    optimizer_classifier = torch.optim.Adam(classifier.parameters(), lr=0.001)

    num_epochs = 20
    for epoch in range(num_epochs):
        classifier.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer_classifier.zero_grad()
            outputs = classifier(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer_classifier.step()
            # CrossEntropyLoss returns the batch mean; weight it by the batch size
            # so that dividing by the dataset size yields a true per-sample average.
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)

        val_loss = 0.0
        if val_loader and len(val_loader.dataset) > 0:
            classifier.eval()
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = classifier(inputs)
                    loss = criterion(outputs, labels)
                    # Same batch-size weighting as the training loss.
                    val_loss += loss.item() * inputs.size(0)
            val_loss /= len(val_loader.dataset)
            logging.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")
        else:
             logging.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f} (No validation data)")

    print("--- Classifier Training Complete ---")

    torch.save(classifier.state_dict(), CLASSIFIER_SAVE_PATH)
    print(f"Classifier saved to: {CLASSIFIER_SAVE_PATH}")

except Exception as e:
    logging.error(f"Error during classifier training: {e}")

In [5]:
# --- Part 3: Classification Inference and Gemini LLM Interaction ---
print("\n--- Cell 2: Part 3 - Starting V-JEPA Feature-Driven Classification Inference and Gemini LLM Interaction ---")

# Observation sentence handed to the LLM for each label the classifier can emit.
_OBSERVATION_BY_LABEL = {
    "airplane landing": "The visual system detected an airplane landing. ",
    "airplane takeoff": "The visual system detected an airplane takeoff. ",
    "airport ground operations": "The visual system detected airport ground operations. ",
    "in-flight cruise": "The visual system detected an airplane in flight/cruise. ",
    "emergency landing": "The visual system detected a possible emergency landing scenario. ",
    "pre-flight check/maintenance": "The visual system detected pre-flight checks or maintenance activities. ",
}
# Used when the predicted label has no entry above.
_FALLBACK_OBSERVATION = "The visual system detected an unrecognized or ambiguous aviation event. "


def _current_est_timestamp():
    """Return the current time in pytz's 'EST' zone, formatted for the prompt and the log."""
    return datetime.datetime.now(pytz.timezone('EST')).strftime('%Y-%m-%d %H:%M:%S EST')


if pooled_features_for_classifier is None: # Use pooled_features_for_classifier here
    logging.error("ERROR: Cannot perform classifier inference as 'pooled_features_for_classifier' is None. Check video loading in Part 1.")
else:
    try:
        pooled_features_for_inference_on_device = pooled_features_for_classifier.to(device) # Ensure it's on device

        # Reload the weights written at the end of training so inference uses the saved checkpoint.
        classifier.load_state_dict(torch.load(CLASSIFIER_SAVE_PATH, map_location=device))
        logging.info(f"Classifier weights loaded from: {CLASSIFIER_SAVE_PATH}")

        classifier.eval()
        with torch.no_grad():
            logits = classifier(pooled_features_for_inference_on_device)
            probabilities = torch.softmax(logits, dim=1)
            predicted_class_idx = torch.argmax(probabilities, dim=1).item()
            predicted_confidence = probabilities[0, predicted_class_idx].item()
            predicted_label = CLASS_LABELS[predicted_class_idx]

        # Turn the label into a sentence and append the softmax confidence.
        llm_input_description = _OBSERVATION_BY_LABEL.get(predicted_label, _FALLBACK_OBSERVATION)
        llm_input_description += f"(Confidence: {predicted_confidence:.2f})"


        print(f"\n--- AI Agent's Understanding from Classifier ---")
        print(f"**Primary Classification (Predicted by AI):** '{predicted_label}' (Confidence: {predicted_confidence:.2f})")
        print(f"**Description for LLM:** {llm_input_description}")
        print(f"Note: This classification's accuracy depends heavily on the quality and size of the real dataset used for classifier training.")

        print("\n--- Engaging Gemini LLM for Further Reasoning ---")
        try:
            # Relies on genai having been configured in Cell 1.
            llm_model = genai.GenerativeModel(AgentConfig.LLM_MODEL_NAME)

            prompt_for_gemini = f"""
            You are an AI assistant for flight planning operations.
            Current visual observation: {llm_input_description}
            Current time (EST): {_current_est_timestamp()}

            Based on this visual observation, provide a concise operational assessment relevant for flight planning.
            If the observation seems random or uncertain, state that. Do not invent details not present in the observation.
            """

            gemini_response = llm_model.generate_content(prompt_for_gemini)

            print("\n--- Gemini LLM Response ---")
            if not gemini_response.candidates:
                # No candidates: surface whatever diagnostic information the response carries.
                print("Gemini LLM did not provide a text response or candidates.")
                if gemini_response.prompt_feedback:
                    print(f"Prompt Feedback: {gemini_response.prompt_feedback}")
                if hasattr(gemini_response, 'error'):
                    print(f"LLM Error: {gemini_response.error}")
            else:
                # Print every text part of every candidate that has content.
                for candidate in gemini_response.candidates:
                    if candidate.content and candidate.content.parts:
                        for part in candidate.content.parts:
                            print(part.text)


        except Exception as llm_e:
            logging.error(f"Error interacting with Gemini LLM: {llm_e}")
            logging.error("Ensure your GOOGLE_API_KEY is correctly set in Colab Secrets and the model name is valid.")


        print(f"\nThis prediction comes from a classifier that was trained on the provided V-JEPA features.")
        print(f"For *truly accurate and high-confidence predictions* on real videos,")
        print(f"the classifier needs to be trained on a large, diverse dataset of *real V-JEPA features and their corresponding labels*.")
        print(f"The V-JEPA features (shape: {pooled_features_for_inference_on_device.shape}) are the core input that a trained classifier would learn from.")
        print(f"Current time in EST: {_current_est_timestamp()}")

    except Exception as e:
        logging.error(f"Error during classification inference or overall process: {e}")

print("\nCell 2 execution complete.")


--- Cell 2: Part 3 - Starting V-JEPA Feature-Driven Classification Inference and Gemini LLM Interaction ---

--- AI Agent's Understanding from Classifier ---
**Primary Classification (Predicted by AI):** 'airplane landing' (Confidence: 1.00)
**Description for LLM:** The visual system detected an airplane landing. (Confidence: 1.00)
Note: This classification's accuracy depends heavily on the quality and size of the real dataset used for classifier training.

--- Engaging Gemini LLM for Further Reasoning ---

--- Gemini LLM Response ---
A visual observation of an airplane landing with 1.00 confidence indicates current runway occupancy and active air traffic at the observed location. This is relevant for assessing real-time traffic flow and potential sequencing requirements for inbound or outbound flights.

This prediction comes from a classifier that was trained on the provided V-JEPA features.
For *truly accurate and high-confidence predictions* on real videos,
the classifier needs to 

## Cell 3: Conceptual PLDM Latent Dynamics Training

In [None]:
# Cell 3: Conceptual PLDM Latent Dynamics Training
# This cell assumes Cell 1 and Cell 2 have been successfully executed in the current session.
# It uses:
# - 'predictor' (instantiated in Cell 1)
# - 'optimizer_pldm' (instantiated in Cell 1)
# - 'device' (determined in Cell 1)
# - 'latent_dim' (re-declared below; Cell 1 defines the same value as 'latent_dim_pldm')
# - 'action_dim' (from Cell 1, for dimensions)
# - 'EXTRACTED_FEATURES_DIR' (from Cell 1, for data path)

import os
import logging
import torch
from tqdm.auto import tqdm # For progress bars

# --- Configuration repeated from Cell 1 so this cell does not depend on those names ---
# Keep these values identical to Cell 1's.
EXTRACTED_FEATURES_DIR = "/content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/"
action_dim = 8 # From Cell 1

# Size of one fully flattened V-JEPA output (2048 * 1408), used as the PLDM latent size.
TOTAL_FLATTENED_VJEPA_DIM = 2048 * 1408 # From Cell 1
latent_dim = TOTAL_FLATTENED_VJEPA_DIM

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# --- Step 5.1: Data Loading for Latent Dynamics Model Training ---
def load_flight_dynamics_training_data(features_dir, num_dynamics_samples=1000, latent_vec_dim=None, action_vec_dim=None):
    """
    Build placeholder (latent_s_t, action_t, latent_s_t_plus_1) samples for dynamics training.

    NOTE: this is still dummy logic. Nothing is read from ``features_dir``; the
    argument only appears in the printed messages. Every tensor is drawn from
    ``torch.rand``. Replace the body with real loading of (s_t, a_t, s_t_plus_1)
    triplets, with the V-JEPA states flattened to [1, D] (or [D]).

    Args:
        features_dir: Directory the real data is expected to come from (printed only).
        num_dynamics_samples: Number of triplets to generate (default 1000).
        latent_vec_dim: Width of each latent state; defaults to the global
            ``TOTAL_FLATTENED_VJEPA_DIM``.
        action_vec_dim: Width of each action; defaults to the global ``action_dim``.

    Returns:
        A list of ``(latent_s_t, action_t, latent_s_t_plus_1)`` tuples of CPU tensors
        shaped ``[1, latent_vec_dim]``, ``[1, action_vec_dim]`` and ``[1, latent_vec_dim]``.
    """
    print(f"Loading reward-free offline flight dynamics data from: {features_dir}...")

    # Fall back to the notebook-level dimensions when the caller does not override them.
    if latent_vec_dim is None:
        latent_vec_dim = TOTAL_FLATTENED_VJEPA_DIM
    if action_vec_dim is None:
        action_vec_dim = action_dim

    # --- CRITICAL: REPLACE THIS DUMMY LOGIC WITH YOUR ACTUAL DATA LOADING ---
    # Tuple elements are evaluated left to right, so the random draws happen in
    # the order state, action, next state for every sample.
    dynamics_training_data = [
        (
            torch.rand(1, latent_vec_dim),  # latent_s_t, e.g. [1, 2883584]
            torch.rand(1, action_vec_dim),  # action_t, e.g. [1, 8]
            torch.rand(1, latent_vec_dim),  # latent_s_t_plus_1, e.g. [1, 2883584]
        )
        for _ in range(num_dynamics_samples)
    ]

    print(f"Loaded {len(dynamics_training_data)} dynamics training samples from {features_dir}")
    return dynamics_training_data

# --- Step 5.2: Training Loop for Latent Dynamics Model ---
def train_latent_dynamics_model(predictor_model, optimizer, training_data, epochs=5):
    """
    Train the latent dynamics predictor on (state, action, next-state) samples.

    Only the prediction term (MSE between the predicted and the actual next
    latent state) is optimised; the remaining PLDM loss components are still
    placeholders (see the comment in the loop).

    Args:
        predictor_model: Module called as ``predictor_model(latent_s_t, action_t)``.
        optimizer: Optimizer over ``predictor_model``'s parameters. This is the
            optimizer that is zeroed and stepped.
        training_data: Sized iterable of ``(latent_s_t, action_t, latent_s_t_plus_1)`` tensors.
        epochs: Number of passes over ``training_data``.

    Returns:
        None. The average loss of each epoch is printed.
    """
    # Nothing to train on: avoid dividing by zero when averaging the loss.
    if len(training_data) == 0:
        logging.warning("No dynamics training samples provided; skipping latent dynamics training.")
        return

    # Send each sample to wherever the predictor lives instead of relying on a
    # notebook-level `device` that may not match the model.
    first_param = next(predictor_model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    predictor_model.train()

    print("\n--- Training Latent Dynamics Predictor ---")
    for epoch in range(epochs):
        total_loss = 0

        for latent_s_t, action_t, latent_s_t_plus_1 in tqdm(
            training_data,
            total=len(training_data),
            desc=f"Epoch {epoch+1}/{epochs}"
        ):
            latent_s_t = latent_s_t.to(target_device)
            action_t = action_t.to(target_device)
            latent_s_t_plus_1 = latent_s_t_plus_1.to(target_device)

            predicted_z_t_plus_1 = predictor_model(latent_s_t, action_t)

            loss = torch.nn.functional.mse_loss(predicted_z_t_plus_1, latent_s_t_plus_1)

            # --- CRITICAL: Placeholder for full PLDM loss components ---
            # You MUST implement these for robust dynamics learning, as per Appendix C.1.1 in the paper.
            # Example:
            # L_var_val = calculate_L_var(predicted_z_t_plus_1) # You'd need a function for this
            # L_cov_val = calculate_L_cov(predicted_z_t_plus_1)
            # L_IDM_val = calculate_L_IDM(latent_s_t, predicted_z_t_plus_1, action_t)
            # L_time_sim_val = calculate_L_time_sim(latent_s_t, predicted_z_t_plus_1)
            # loss = loss + alpha * L_var_val + beta * L_cov_val + delta * L_time_sim_val + omega * L_IDM_val
            # ---------------------------------------------------

            # Zero the gradients of the optimizer that is about to step; using a
            # different (global) optimizer here would let gradients accumulate.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {total_loss / len(training_data):.4f}")

# --- Cell 3 driver: build the samples, then fit the predictor ---
print("\n--- Cell 3: Starting Conceptual PLDM Latent Dynamics Training ---")

# Gather the (state, action, next-state) samples used to fit the dynamics model.
dynamics_training_data = load_flight_dynamics_training_data(features_dir=EXTRACTED_FEATURES_DIR)

# Fit the predictor for two epochs; 'predictor' and 'optimizer_pldm' are the globals created in Cell 1.
train_latent_dynamics_model(
    predictor_model=predictor,
    optimizer=optimizer_pldm,
    training_data=dynamics_training_data,
    epochs=2,
)

print("\nCell 3 execution complete.")

## Cell 4: Conceptual PLDM Planning

In [7]:
# Cell 4: Conceptual PLDM Planning
# This cell assumes Cell 1, Cell 2, and Cell 3 have been successfully executed in the current session.
# All objects (model, processor, predictor, device, frames_for_pldm_planning, TOTAL_FLATTENED_VJEPA_DIM, action_dim)
# are expected to be available from Cell 1, Cell 2, and Cell 3 execution.

import logging
import torch
import numpy as np # For .cpu().numpy()
import datetime
import pytz # For time stamping
from tqdm.auto import tqdm # For progress bars

# --- Configuration repeated from Cell 1 ---
# Declared again here so this cell still works when only part of the session
# state survives; the numbers must stay in sync with Cell 1.
TOTAL_FLATTENED_VJEPA_DIM = 2048 * 1408  # size of one flattened V-JEPA feature vector
action_dim = 8  # From Cell 1
latent_dim = TOTAL_FLATTENED_VJEPA_DIM  # the predictor works directly on the flattened features

# Resolve the compute device locally instead of trusting a leftover `device` global.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Step 5.3: Conceptual Planning at Test Time (Mock MPPI Algorithm) ---
def plan_flight_path(current_video_frames, target_goal_video_frames, encoder_model, predictor_model,
                     processor_instance, device_instance, planning_horizon=10,
                     num_action_samples=50, action_dimension=None):
    """
    Conceptual function to plan a flight path using the learned latent dynamics.
    This is a mock MPPI implementation demonstrating the flow, not true optimality:
    at every step it samples random actions, scores each by its one-step predicted
    distance to the goal latent (plus a random dummy term), and greedily keeps the best.

    Args:
        current_video_frames: Iterable of frames for the current observation; it is
            converted with ``list(...)`` and handed to ``processor_instance``.
        target_goal_video_frames: Iterable of frames describing the goal, processed
            the same way.
        encoder_model: Module called as ``encoder_model(**inputs)`` whose result
            exposes ``last_hidden_state``; that tensor is flattened from dim 1.
        predictor_model: Module called as ``predictor_model(latent, action)`` that
            returns the predicted next latent state.
        processor_instance: Callable accepting ``videos=`` and ``return_tensors="pt"``
            and returning a mapping of tensors.
        device_instance: Torch device on which inputs, actions and costs are placed.
        planning_horizon: Number of sequential actions to plan.
        num_action_samples: Number of random candidate actions scored per step.
        action_dimension: Size of one action vector. When ``None`` the notebook-level
            ``action_dim`` global is used, which keeps existing callers working.

    Returns:
        list of ``planning_horizon`` NumPy arrays, each of shape ``(action_dimension,)``.

    Raises:
        ValueError: If ``num_action_samples`` is smaller than 1.
    """
    if action_dimension is None:
        # Backward-compatible fallback to the global configured in the notebook.
        action_dimension = action_dim
    if num_action_samples < 1:
        raise ValueError("num_action_samples must be at least 1")

    encoder_model.eval()
    predictor_model.eval()

    current_inputs = processor_instance(videos=list(current_video_frames), return_tensors="pt")
    target_inputs = processor_instance(videos=list(target_goal_video_frames), return_tensors="pt")

    with torch.no_grad():
        # Move inputs to device before model inference
        current_inputs = {k: v.to(device_instance) for k, v in current_inputs.items()}
        target_inputs = {k: v.to(device_instance) for k, v in target_inputs.items()}

        # Encode and flatten features for planning
        current_latent_state = encoder_model(**current_inputs).last_hidden_state.flatten(start_dim=1)
        target_latent_state = encoder_model(**target_inputs).last_hidden_state.flatten(start_dim=1)

    print("\n--- Starting Flight Path Planning ---")
    print(f"Current Latent State Shape: {current_latent_state.shape}")
    print(f"Target Latent State Shape: {target_latent_state.shape}")

    best_action_sequence = []

    # --- Mock MPPI-like Planning Logic ---
    # Planning is pure inference. Keeping the whole loop under no_grad() stops every
    # candidate rollout from building (and retaining) an autograd graph over the
    # full latent vector, which previously happened in the inner loop.
    with torch.no_grad():
        for _ in range(planning_horizon):
            # Sample actions uniformly in [0, 1); drawn on CPU, then moved, so the
            # random stream is the same regardless of the target device.
            candidate_actions = torch.rand(num_action_samples, action_dimension).to(device_instance)

            candidate_costs = []
            for sample_idx in range(num_action_samples):
                # Simulate one step in latent space using the predictor
                simulated_next_latent = predictor_model(
                    current_latent_state, candidate_actions[sample_idx].unsqueeze(0)
                )

                # Calculate mock cost: distance to goal + a mock uncertainty/deviation cost
                # In real MPPI, cost would factor in predicted trajectory, not just one step
                goal_cost = torch.norm(target_latent_state - simulated_next_latent)  # Distance from goal
                mock_uncertainty_cost = torch.rand(1).to(device_instance) * 0.1  # Dummy uncertainty cost

                candidate_costs.append((goal_cost + mock_uncertainty_cost).reshape(1))

            # Select the action with the lowest cost; torch.cat keeps the costs as
            # tensors instead of converting them element by element.
            best_candidate_idx = torch.argmin(torch.cat(candidate_costs))
            optimal_action_for_step = candidate_actions[best_candidate_idx]

            # No squeeze(): the stored action always has shape (action_dimension,),
            # including the action_dimension == 1 case.
            best_action_sequence.append(optimal_action_for_step.cpu().numpy())

            # Update current latent state based on the selected action (for receding horizon control)
            current_latent_state = predictor_model(
                current_latent_state, optimal_action_for_step.unsqueeze(0)
            )  # Advance state

    print(f"Conceptual Plan (first {planning_horizon} steps of actions): {best_action_sequence}")
    print("--- Planning Complete ---")
    return best_action_sequence

# --- Main Execution Flow for Cell 4 ---
# Driver for the planning stage: verify that the frames from Cell 2 exist, build a
# random goal clip, run the mock planner, then print the follow-up checklist.
print("\n--- Cell 4: Starting Conceptual PLDM Planning ---")

# Access frames_for_pldm_planning (should be available from Cell 2 after video loading).
# globals().get() covers both "never defined" and "defined but None" in one lookup.
if globals().get('frames_for_pldm_planning') is None:
    logging.error("ERROR: 'frames_for_pldm_planning' not found or is None. Cannot perform PLDM planning without video frames from Cell 2. Exiting Cell 4.")
    # Raise instead of calling exit(): inside a notebook exit() asks the kernel to
    # shut down, which would discard the models and features loaded by Cells 1-3.
    # An exception stops this cell (and a Run All) while keeping the session alive.
    raise RuntimeError("'frames_for_pldm_planning' is missing or None; run Cell 2 before Cell 4.")

# Prepare dummy goal video for planning.
# Random values only: the resulting plan is a smoke test of the pipeline, not a
# meaningful trajectory towards a real goal.
target_goal_dummy_video = torch.rand(16, 3, 256, 256)

# Plan the flight path using the trained dynamics model.
# 'model', 'processor', 'predictor' are global from Cell 1.
plan_flight_path(frames_for_pldm_planning, target_goal_dummy_video, model, predictor, processor, device, planning_horizon=5)

print("\n--- Summary of Next Steps for a Practical Flight Planning AI ---")
print("1. **Crucial for Dynamics Training:** Replace `load_flight_dynamics_training_data` with detailed loading logic")
print("   to correctly parse your *actual* saved V-JEPA features and associated *actions* into (s_t, a_t, s_t_plus_1) triplets for the dynamics model.")
print("2. **Crucial for Classifier Accuracy:** Ensure `classifier_head_trained_on_tartan_aviation_sample.pth` exists and was trained on a large, diverse dataset of *real V-JEPA features and corresponding labels*.")
print("3. **Essential for Dynamics Robustness:** Refine `LatentDynamicsPredictor` architecture (e.g., use ConvPredictor if appropriate for visual features) and implement **full PLDM loss functions** (variance, covariance, inverse dynamics, time-smoothness) as detailed in the paper.")
print("4. **Essential for Action Generation:** Replace the mock planning in `plan_flight_path` with a robust MPPI algorithm implementation.")
print("5. **Integration:** Integrate this entire pipeline with a flight simulator or real-world sensor data for actual control and evaluation.")
# NOTE(review): pytz's 'EST' zone appears to be a fixed UTC-5 offset without daylight
# saving — confirm that is intended rather than a DST-aware zone such as 'US/Eastern'.
print(f"Current time in EST: {datetime.datetime.now(pytz.timezone('EST')).strftime('%Y-%m-%d %H:%M:%S EST')}")

print("\nCell 4 execution complete.")


--- Cell 4: Starting Conceptual PLDM Planning ---

--- Starting Flight Path Planning ---
Current Latent State Shape: torch.Size([1, 2883584])
Target Latent State Shape: torch.Size([1, 2883584])
Conceptual Plan (first 5 steps of actions): [array([0.6733031 , 0.6641214 , 0.67408574, 0.64152807, 0.7808706 ,
       0.5046251 , 0.8297563 , 0.91077536], dtype=float32), array([0.98800325, 0.90408134, 0.7151991 , 0.5639811 , 0.55738723,
       0.12075198, 0.90880024, 0.9048785 ], dtype=float32), array([0.9866971 , 0.37695998, 0.12013483, 0.19550359, 0.6606016 ,
       0.36345565, 0.96768445, 0.87099475], dtype=float32), array([0.29575574, 0.55870414, 0.38753158, 0.6602698 , 0.63869816,
       0.07558197, 0.9935388 , 0.9712309 ], dtype=float32), array([0.64405507, 0.598058  , 0.7225516 , 0.6947395 , 0.5091697 ,
       0.43505955, 0.5924359 , 0.8532961 ], dtype=float32)]
--- Planning Complete ---

--- Summary of Next Steps for a Practical Flight Planning AI ---
1. **Crucial for Dynamics Trainin