<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/JEPA_AGI_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. World models that can predict and reason about real situations, not just text (LeCun’s Joint Embedding Predictive Architecture, 2022).
2. Autonomous learning that discovers causal structure instead of memorizing patterns.
3. Energy-based or modular systems that reason, plan, and act coherently within physical and ethical boundaries.
4. Embodied sentience and salience — systems grounded in sensory experience, capable of focusing on what truly matters and aligning ethically with human values.
5. Cognitive world models and evolutionary learning modules — hybrid systems that combine:
• Common-sense reasoning about space, time, and agency,
• Evolutionary and meta-learning algorithms that improve over generations of experience, and
• Analog–digital integration layers that bridge symbolic reasoning with continuous perception.

In [None]:
!pip install av -q

In [2]:
!nvidia-smi

Mon Oct 20 20:21:57 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   35C    P0             52W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [14]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [4]:
!df -h /content/gdrive

Filesystem      Size  Used Avail Use% Mounted on
drive           236G   50G  187G  21% /content/gdrive


In [5]:
!git clone https://github.com/castacks/TartanAviation.git
%cd TartanAviation

Cloning into 'TartanAviation'...
remote: Enumerating objects: 180, done.
remote: Counting objects: 100% (180/180), done.
remote: Compressing objects: 100% (141/141), done.
remote: Total 180 (delta 97), reused 101 (delta 33), pack-reused 0 (from 0)
Receiving objects: 100% (180/180), 4.04 MiB | 17.91 MiB/s, done.
Resolving deltas: 100% (97/97), done.
/content/TartanAviation


In [7]:
%cd /content

/content


In [None]:
!pip install minio boto3 -q
!apt-get install -y unzip ffmpeg

## TartanAviation-vision

In [None]:
%cd /content/TartanAviation/vision

In [None]:
!python /content/TartanAviation/vision/download.py --save_dir /content/gdrive/MyDrive/datasets/TartanAviation/vision --option Sample

In [None]:
!python /content/TartanAviation/vision/download.py --save_dir /content/gdrive/MyDrive/datasets/TartanAviation/vision --option Sample --extract_frames

In [31]:
!ls -lh /content/gdrive/MyDrive/datasets/TartanAviation/vision/
!ls -lh /content/gdrive/MyDrive/datasets/TartanAviation/vision/1_2023-02-22-15-21-49/

total 79K
drwx------ 3 root root 4.0K Oct 20 10:24 1_2023-02-22-15-21-49
-rw------- 1 root root  11K Jul 23 12:43 dataloader.py
drwx------ 2 root root 4.0K Jul 23 14:08 downloaded_recordings
-rw------- 1 root root 7.2K Jul 23 12:43 download.py
-rw------- 1 root root 6.0K Jul 23 12:43 progress.py
drwx------ 2 root root 4.0K Jul 23 12:49 __pycache__
-rw------- 1 root root 3.4K Jul 23 12:43 README.md
drwx------ 2 root root 4.0K Jul 23 12:43 recording
-rw------- 1 root root  35K Jul 23 12:43 weather_stats.csv
total 2.0G
-rw------- 1 root root  16K Aug 26  2023 1_2023-02-22-15-21-49_acft_sink.pkl
-rw------- 1 root root  37K Aug 26  2023 1_2023-02-22-15-21-49_labels.zip
-rw------- 1 root root 890M Aug 26  2023 1_2023-02-22-15-21-49.mp4
drwx------ 2 root root 4.0K Oct 20 10:29 1_2023-02-22-15-21-49_sink
-rw------- 1 root root  234 Aug 26  2023 1_2023-02-22-15-21-49_sink_adsb.pkl
-rw------- 1 root root 1.1G Aug 26  2023 1_2023-02-22-15-21-49_sink_verified.avi
-rw------- 1 root root 266K Aug 26

## TartanAviation-adsb

In [16]:
!ls -lh /content/gdrive/MyDrive/datasets/TartanAviation/adsb/kbtp/raw/2022/

total 700M
-rw------- 1 root root 700M Oct 20 11:25 2022.zip


In [17]:
!mkdir -p /content/adsb/kbtp/raw/2022

In [None]:
%cd /content/gdrive/MyDrive/datasets/TartanAviation/adsb/kbtp/raw/2022
!unzip 2022.zip -d /content/adsb/kbtp/raw/2022/

In [4]:
# Please install OpenAI SDK first: `pip3 install openai`

from openai import OpenAI

from google.colab import userdata

api_key=userdata.get("DEEPSEEK_API_KEY")

client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")

## CELL1

In [None]:
# Cell 1: Conceptual Modifications - Aviation Data Definitions

import torch
import numpy as np
import os
import glob
import av
import json
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoVideoProcessor, AutoModel
from tqdm.auto import tqdm
import logging
import datetime
import pytz
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')



class AgentConfig:
    LLM_MODEL_NAME: str = "deepseek-reasoner"
    CLASS_LABELS = [
        "airplane landing",
        "airplane takeoff",
        "airport ground operations",
        "in-flight cruise",
        "emergency landing",
        "pre-flight check/maintenance",
        "en-route cruise",
        "climb phase",
        "descent phase",
        "holding pattern"
    ]

# Define num_classes globally
num_classes = len(AgentConfig.CLASS_LABELS)

# Path where the trained classifier head weights are saved (module-level so later cells can use it)
CLASSIFIER_SAVE_PATH = "classifier_head_trained_on_tartan_aviation_sample.pth"

AIRPORTS = {
    "CYUL": {"name": "Montreal-Trudeau International", "lat": 45.4706, "lon": -73.7408, "elevation_ft": 118},
    "LFPG": {"name": "Paris-Charles de Gaulle", "lat": 49.0097, "lon": 2.5479, "elevation_ft": 392},
}

AIRCRAFT_PERFORMANCE = {
    "Boeing777_300ER": {
        "cruise_speed_knots": 490,
        "fuel_burn_kg_per_hour": 7000,
        "max_range_nm": 7900,
        "climb_rate_fpm": 2500,
        "descent_rate_fpm": 2000,
        "typical_cruise_altitude_ft": 37000,
        "fuel_capacity_kg": 145000
    }
}

hf_repo = "facebook/vjepa2-vitg-fpc64-256"
EXTRACTED_FEATURES_DIR = "/content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/"

TOTAL_FLATTENED_VJEPA_DIM = 2048 * 1408

CONCEPTUAL_PLDM_LATENT_DIM = 1024

latent_dim_pldm = CONCEPTUAL_PLDM_LATENT_DIM
action_dim = 8

def load_and_process_video(video_path, processor_instance, model_instance, device_instance, num_frames_to_sample=16):
    """
    Loads a video, samples frames, and extracts V-JEPA features.
    Returns extracted features (torch.Tensor, shape like [1, 2048, 1408]) and the frames.
    Does NOT flatten the V-JEPA output here, keeping it as model's raw output.
    """
    frames = []
    if not os.path.exists(video_path):
        logging.error(f"ERROR: Video file '{video_path}' not found.")
        return None, None
    try:
        # Open with a context manager so the container is closed once decoding ends
        with av.open(video_path) as container:
            # Note: `frames` can be 0 when the container does not report a frame count
            total_frames_in_video = container.streams.video[0].frames
            sampling_interval = max(1, total_frames_in_video // num_frames_to_sample)
            logging.info(f"Total frames in video: {total_frames_in_video}")
            logging.info(f"Sampling interval: {sampling_interval} frames")

            for i, frame in enumerate(container.decode(video=0)):
                if len(frames) >= num_frames_to_sample:
                    break
                if i % sampling_interval == 0:
                    img = frame.to_rgb().to_ndarray()
                    frames.append(img)

        if not frames:
            logging.error(f"ERROR: No frames could be loaded from '{video_path}'.")
            return None, None
        elif len(frames) < num_frames_to_sample:
            logging.warning(f"WARNING: Only {len(frames)} frames loaded. Requested: {num_frames_to_sample}.")

        inputs = processor_instance(videos=list(frames), return_tensors="pt")
        inputs = {k: v.to(device_instance) for k, v in inputs.items()}

        with torch.no_grad():
            features = model_instance(**inputs).last_hidden_state

        logging.info(f"Successfully extracted V-JEPA features with raw shape: {features.shape}")
        return features, frames

    except av.FFmpegError as e:
        logging.error(f"Error loading video with PyAV: {e}")
        logging.error("This might indicate an issue with the video file itself or PyAV installation.")
        return None, None
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}")
        logging.error("Ensure 'av' library is installed (`pip install av`) and video file is not corrupt.")
        return None, None

class ClassifierHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        return self.fc2(self.dropout(self.relu(self.fc1(x))))

class LatentDynamicsPredictor(torch.nn.Module):
    def __init__(self, latent_dim, action_dim):
        super().__init__()
        self.layers = torch.nn.Sequential(
            nn.Linear(latent_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim)
        )

    def forward(self, latent_state, action):
        combined_input = torch.cat([latent_state, action], dim=-1)
        predicted_next_latent_state = self.layers(combined_input)
        return predicted_next_latent_state

class LatentProjector_old(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.projector = nn.Linear(input_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.projector(x))


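class LatentProjector(nn.Module):
    def __init__(self, input_dim=4, output_dim=1024):
        super().__init__()
        self.projector = nn.Linear(input_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Project the input into the latent space, then apply the non-linearity
        return self.relu(self.projector(x))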



print("\n--- Instantiating Models and Optimizers ---")
model = AutoModel.from_pretrained(hf_repo)
processor = AutoVideoProcessor.from_pretrained(hf_repo)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

latent_projector = LatentProjector(TOTAL_FLATTENED_VJEPA_DIM, CONCEPTUAL_PLDM_LATENT_DIM)
latent_projector.to(device)

predictor = LatentDynamicsPredictor(latent_dim_pldm, action_dim)
predictor.to(device)
optimizer_pldm = torch.optim.Adam(list(predictor.parameters()) + list(latent_projector.parameters()), lr=0.001)

classifier = ClassifierHead(input_dim=1408, num_classes=num_classes)
classifier.to(device)

print(f"Models instantiated and moved to {device}.")
print("\nCell 1 setup complete for conceptual flight planning.")
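
`load_and_process_video` picks frames with a fixed stride, `max(1, total_frames // num_frames_to_sample)`, and stops once it has enough. The sketch below replays that rule on plain frame indices so the coverage is easy to inspect. The helper name `sampled_frame_indices` is hypothetical and not part of the notebook, and it assumes the container reports an accurate frame count.

```python
def sampled_frame_indices(total_frames: int, num_frames_to_sample: int = 16) -> list[int]:
    """Return the frame indices that a stride-based sampler would keep."""
    # Same stride rule as the loader: never below 1, even for short clips
    sampling_interval = max(1, total_frames // num_frames_to_sample)
    kept = []
    for i in range(total_frames):
        if len(kept) >= num_frames_to_sample:
            break
        if i % sampling_interval == 0:
            kept.append(i)
    return kept


print(sampled_frame_indices(160))  # evenly spread over the whole clip
print(sampled_frame_indices(31))   # stride collapses to 1: only the first 16 frames
print(sampled_frame_indices(10))   # clip shorter than the request: fewer frames back
```

One consequence worth noting: when a clip has fewer than twice the requested number of frames, the stride is 1 and only the opening frames are sampled, so the tail of the clip is never seen.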

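The `AIRPORTS` and `AIRCRAFT_PERFORMANCE` tables in Cell 1 are enough for a rough cruise estimate between the two listed airports. The sketch below is an illustration only: it copies the values from those tables, uses the haversine formula with an assumed mean Earth radius, and ignores winds, climb, descent, routing, and fuel reserves. The function name `great_circle_nm` is hypothetical.

```python
import math

# Values copied from the AIRPORTS and AIRCRAFT_PERFORMANCE tables in Cell 1
CYUL = (45.4706, -73.7408)   # (lat, lon) in degrees
LFPG = (49.0097, 2.5479)
CRUISE_SPEED_KNOTS = 490
FUEL_BURN_KG_PER_HOUR = 7000
FUEL_CAPACITY_KG = 145000

EARTH_RADIUS_NM = 3440.065  # assumed mean Earth radius (6371 km) in nautical miles


def great_circle_nm(a, b):
    """Haversine distance between two (lat, lon) points, in nautical miles."""
    lat1, lon1, lat2, lon2 = map(math.radians, (*a, *b))
    h = (math.sin((lat2 - lat1) / 2) ** 2
         + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2)
    return 2 * EARTH_RADIUS_NM * math.asin(math.sqrt(h))


distance_nm = great_circle_nm(CYUL, LFPG)
cruise_hours = distance_nm / CRUISE_SPEED_KNOTS          # still-air cruise time
cruise_fuel_kg = cruise_hours * FUEL_BURN_KG_PER_HOUR    # cruise burn only
print(f"{distance_nm:.0f} nm, {cruise_hours:.1f} h, {cruise_fuel_kg:.0f} kg of cruise fuel")
```
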
## CELL2

In [None]:
# Cell 2: Core execution (feature extraction, classifier training and inference, LLM interaction, and PLDM training/planning)
# This cell assumes Cell 1 has been successfully executed in the current session.
# All objects (model, processor, classifier, predictor, device, optimizer_pldm)
# and all function definitions (load_and_process_video, ClassifierHead, LatentDynamicsPredictor)
# are expected to be available from Cell 1's execution.
import os
import logging
import torch
import json
from google.colab import drive
from tqdm.auto import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import datetime
import pytz

#Mounting Google Drive
print("\n--- Cell 2: Mounting Google Drive for dataset access ---")
drive.mount('/content/gdrive')
print("Google Drive mounted.")

print(f"Checking for extracted features directory: {EXTRACTED_FEATURES_DIR}")
if not os.path.exists(EXTRACTED_FEATURES_DIR):
    logging.error(f"ERROR: Extracted features directory '{EXTRACTED_FEATURES_DIR}' not found. Please create it and upload V-JEPA features.")
    raise FileNotFoundError(f"Extracted features directory '{EXTRACTED_FEATURES_DIR}' not found.")
else:
    print(f"Extracted features directory found at {EXTRACTED_FEATURES_DIR}")

# Part 1: Load and process airplane-landing.mp4 for initial observation
print(f"\n--- Cell 2: Part 1 - Loading actual video '/content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/airplane-landing.mp4' for feature extraction ---")
flight_video_path = '/content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/airplane-landing.mp4'
# Use the defined load_and_process_video helper function. It now returns features and frames.
video_features_for_inference_raw, frames_for_pldm_planning = load_and_process_video(flight_video_path, processor, model, device_instance=device)

# Pool the raw V-JEPA features to match the ClassifierHead's expected input_dim
if video_features_for_inference_raw is not None:
    # V-JEPA output shape is typically [1, 2048, 1408], i.e. (batch, tokens, hidden_dim).
    # Dropping the batch axis and averaging over the token axis leaves one 1408-dim
    # embedding, which unsqueeze(0) restores to shape [1, 1408] for the classifier.
    pooled_features_for_classifier = video_features_for_inference_raw.squeeze(0).mean(dim=0).unsqueeze(0)
    extracted_embedding_dim_for_classifier = pooled_features_for_classifier.shape[-1]
    logging.info(f"Dynamically determined extracted_embedding_dim for ClassifierHead: {extracted_embedding_dim_for_classifier}")
else:
    pooled_features_for_classifier = None
    extracted_embedding_dim_for_classifier = -1
    logging.error("Failed to extract video features for classifier. Stopping Cell 2.")
    # Raise instead of calling exit(), which would end the notebook kernel session
    raise RuntimeError("V-JEPA feature extraction failed.")

# Part 2: Classifier Training
print(f"\n--- Cell 2: Part 2 - Starting Classifier Training ---")
print(f"Attempting to load real V-JEPA features for classifier training or generate synthetic data.")
print(f"Using device for classifier training: {device}")

try:
    # Re-initialize classifier with the correct, dynamically determined input_dim
    classifier = ClassifierHead(input_dim=extracted_embedding_dim_for_classifier, num_classes=num_classes)
    classifier.to(device)

    train_features_list = []
    train_labels_list = []

    map_file_path = os.path.join(EXTRACTED_FEATURES_DIR, "feature_label_map.json")
    if not os.path.exists(map_file_path):
        logging.warning(f"Feature-label map file '{map_file_path}' not found. Generating synthetic data.")
        feature_label_map = {}
    else:
        with open(map_file_path, 'r') as f:
            feature_label_map = json.load(f)

    if not feature_label_map:
        logging.warning(f"Feature-label map at {map_file_path} is empty. Generating synthetic data.")
        num_training_samples = 2_000_000
        # Synthetic data generation uses the dynamically determined input_dim
        train_features = torch.rand(num_training_samples, extracted_embedding_dim_for_classifier)
        train_labels = torch.randint(0, num_classes, (num_training_samples,))
        train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=32, shuffle=True)
        val_loader = None
        print(f"Loaded {num_training_samples} SYNTHETIC features for training.")
    else:
        for item in tqdm(feature_label_map, desc="Loading real V-JEPA features"):
            feature_path = item['feature_path']
            label_idx = item['label_idx']
            try:
                if not os.path.isabs(feature_path):
                    feature_path = os.path.join(EXTRACTED_FEATURES_DIR, feature_path)

                if not os.path.exists(feature_path):
                    logging.warning(f"Feature file not found at {feature_path}. Skipping.")
                    continue

                feature = torch.load(feature_path, map_location=device)

                # Match your working code's pooling/squashing logic to get [1408] dim
                if feature.ndim == 3:
                    feature = feature.squeeze(0).mean(dim=0)
                elif feature.ndim == 2:
                    if feature.shape[0] == 1 and feature.shape[1] == 1408:
                        feature = feature.squeeze(0)
                    elif feature.shape[1] == 1408:
                        feature = feature.mean(dim=0)
                    else:
                        feature = feature.flatten()
                elif feature.ndim == 1:
                    pass
                else:
                    logging.warning(f"Skipping malformed feature from {feature_path}. Unexpected dimensions: {feature.ndim}")
                    continue

                # Final check after processing. Should be 1D with 1408 elements.
                if feature.shape[0] != extracted_embedding_dim_for_classifier:
                    logging.warning(f"Skipping feature at {feature_path}. Dimension mismatch: expected {extracted_embedding_dim_for_classifier}, got {feature.shape[0]}.")
                    continue

                train_features_list.append(feature)
                train_labels_list.append(label_idx)

            except Exception as e:
                logging.error(f"Error loading feature from {feature_path}: {e}. Skipping.")

        if train_features_list:
            train_features = torch.stack(train_features_list).to(device)
            train_labels = torch.tensor(train_labels_list).to(device)
            num_training_samples = len(train_features)
            print(f"Loaded {num_training_samples} REAL V-JEPA features for training.")

            if num_training_samples < 2:
                print("WARNING: Only 1 real V-JEPA feature loaded. Training may be unstable. Consider more data.")
                train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=1, shuffle=True)
                val_loader = None
            else:
                dataset_size = len(train_features)
                train_size = int(0.8 * dataset_size)
                val_size = dataset_size - train_size

                if val_size == 0 and train_size > 0:
                    train_size = dataset_size
                    train_dataset_real = TensorDataset(train_features, train_labels)
                    val_dataset_real = None
                else:
                    train_dataset_real, val_dataset_real = torch.utils.data.random_split(
                        TensorDataset(train_features, train_labels), [train_size, val_size]
                    )
                train_loader = DataLoader(train_dataset_real, batch_size=32, shuffle=True)
                val_loader = DataLoader(val_dataset_real, batch_size=32, shuffle=False) if val_dataset_real else None
                print(f"Training on {len(train_dataset_real)} samples, Validation on {len(val_dataset_real) if val_dataset_real else 0} samples.")
        else:
            logging.error("No real V-JEPA features could be loaded from map file. Generating synthetic data as fallback.")
            num_training_samples = 2_000_000
            train_features = torch.rand(num_training_samples, extracted_embedding_dim_for_classifier)
            train_labels = torch.randint(0, num_classes, (num_training_samples,))
            train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=32, shuffle=True)
            val_loader = None
            print(f"Loaded {num_training_samples} SYNTHETIC features for training as fallback.")

    criterion = torch.nn.CrossEntropyLoss()
    optimizer_classifier = torch.optim.Adam(classifier.parameters(), lr=0.001)

    num_epochs = 20
    for epoch in range(num_epochs):
        classifier.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer_classifier.zero_grad()
            outputs = classifier(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer_classifier.step()
            running_loss += loss.item()

        # CrossEntropyLoss averages within each batch, so divide by the number of batches
        epoch_loss = running_loss / len(train_loader)

        val_loss = 0.0
        if val_loader and len(val_loader.dataset) > 0:
            classifier.eval()
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = classifier(inputs)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
            val_loss /= len(val_loader)
            logging.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")
        else:
            logging.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}")
            print("No validation data available or validation dataset is empty.")


    print("\n--- Classifier Training Complete ---")
    torch.save(classifier.state_dict(), CLASSIFIER_SAVE_PATH)
    print(f"Classifier saved to: {CLASSIFIER_SAVE_PATH}")
except Exception as e:
    logging.error(f"Error during classifier training: {e}")


print("Cell 2 execution complete.")
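
The training loop adds `loss.item()` once per batch, and `CrossEntropyLoss` with its default `reduction='mean'` already averages within each batch. The plain-Python sketch below compares three ways of turning those batch means into one epoch figure. The loss values and batch sizes are invented for illustration, and `epoch_mean_loss` is a hypothetical helper, not part of the notebook.

```python
def epoch_mean_loss(batch_mean_losses, batch_sizes):
    """Combine per-batch mean losses into one per-sample mean for the epoch."""
    total = sum(loss * n for loss, n in zip(batch_mean_losses, batch_sizes))
    return total / sum(batch_sizes)


# Two full batches of 32 and a final partial batch of 8 (made-up numbers)
losses = [0.90, 0.70, 0.50]
sizes = [32, 32, 8]

per_sample = epoch_mean_loss(losses, sizes)   # exact per-sample mean
per_batch = sum(losses) / len(losses)         # mean over batches, close when batches are equal
over_dataset = sum(losses) / sum(sizes)       # batch means divided by the sample count

print(f"per-sample: {per_sample:.4f}")
print(f"per-batch:  {per_batch:.4f}")
print(f"sum of batch means / dataset size: {over_dataset:.4f}")
```

Dividing the summed batch means by the dataset size shrinks the reported loss by roughly the batch size, which makes runs with different batch sizes hard to compare. The per-batch mean is usually close enough; the weighted version is exact when the last batch is smaller.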

In [11]:
# Import OpenAI client for DeepSeek API
from openai import OpenAI
from google.colab import userdata

# --- DeepSeek API Setup ---
api_key = userdata.get("DEEPSEEK_API_KEY")

if not api_key:
    print("Error: DEEPSEEK_API_KEY not found in userdata.")
    print("Please set your DeepSeek API key in Colab secrets.")
    raise RuntimeError("DEEPSEEK_API_KEY is not set in Colab secrets.")

client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
MODEL = "deepseek-reasoner"

In [12]:

#Part 3: Classification Inference and DEEPSEEK LLM Interaction
print("\n--- Cell 2: Part 3 - Starting V-JEPA Feature-Driven Classification Inference and DEEPSEEK LLM Interaction ---")
if pooled_features_for_classifier is None:
    logging.error("ERROR: Cannot perform classifier inference as 'pooled_features_for_classifier' is None.")
else:
    try:
        pooled_features_for_inference_on_device = pooled_features_for_classifier.to(device)

        classifier.load_state_dict(torch.load(CLASSIFIER_SAVE_PATH, map_location=device))
        logging.info(f"Classifier weights loaded from: {CLASSIFIER_SAVE_PATH}")

        classifier.eval()
        with torch.no_grad():
            logits = classifier(pooled_features_for_inference_on_device)
            probabilities = torch.softmax(logits, dim=1)

        predicted_class_idx = torch.argmax(probabilities, dim=1).item()
        predicted_confidence = probabilities[0, predicted_class_idx].item()


        predicted_label = AgentConfig.CLASS_LABELS[predicted_class_idx]

        # Map every class label to a sentence for the LLM; unknown labels get a generic message
        label_descriptions = {
            "airplane landing": "The visual system detected an airplane landing.",
            "airplane takeoff": "The visual system detected an airplane takeoff.",
            "airport ground operations": "The visual system detected airport ground operations.",
            "in-flight cruise": "The visual system detected an airplane in-flight cruise.",
            "emergency landing": "The visual system detected a possible emergency landing.",
            "pre-flight check/maintenance": "The visual system detected pre-flight check or maintenance activity.",
            "en-route cruise": "The visual system detected an airplane in en-route cruise.",
            "climb phase": "The visual system detected an airplane in its climb phase.",
            "descent phase": "The visual system detected an airplane in its descent phase.",
            "holding pattern": "The visual system detected an airplane in a holding pattern.",
        }
        llm_input_description = label_descriptions.get(
            predicted_label, "The visual system detected an unrecognised flight activity."
        )

        llm_input_description += f" (Confidence: {predicted_confidence:.2f})"

        print(f"\n--- AI Agent's Understanding from Classifier ---")
        print(f"**Primary Classification (Predicted by AI):** '{predicted_label}' (Confidence: {predicted_confidence:.2f})")
        print(f"**Description for LLM:** {llm_input_description}")
        print(f"Note: This classification's accuracy depends heavily on the quality and size of the real dataset used for classifier training.")

        print("\n--- Engaging DEEPSEEK LLM for Further Reasoning ---")
        try:
            llm_model = AgentConfig.LLM_MODEL_NAME

            prompt_for_deepseek = f"""
                  You are an AI assistant for flight planning operations.
                  Current visual observation: {llm_input_description}
                  Current time (EST): {datetime.datetime.now(pytz.timezone('EST')).strftime('%Y-%m-%d %H:%M:%S EST')}

                  Based on this visual observation, provide a concise operational assessment relevant to flight planning.
                  If the observation seems random or uncertain, state that. Do not add any conversational filler.
                  """


            deepseek_response = client.chat.completions.create(
                model=llm_model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant"},
                    {"role": "user", "content": prompt_for_deepseek},
                ],
                stream=False,
            )


            print("\n--- DEEPSEEK LLM Response ---")
            if deepseek_response.choices and deepseek_response.choices[0].message.content:
                print(deepseek_response.choices[0].message.content)
                print("--- DEEPSEEK LLM Response - END ---")
                print('\n')
            else:

                print("DEEPSEEK LLM did not provide a text response or cannot provide one.")
                # Check if there's an error attribute before trying to print it
                if hasattr(deepseek_response, 'error') and deepseek_response.error:
                     print(f"LLM Error: {deepseek_response.error}")
                # Check if there's a prompt_feedback attribute before trying to print it
                if hasattr(deepseek_response, 'prompt_feedback') and deepseek_response.prompt_feedback:
                    print(f"Prompt Feedback: {deepseek_response.prompt_feedback}")


        except Exception as llm_e:
            logging.error(f"Error interacting with DEEPSEEK LLM: {llm_e}")
            logging.error("Ensure your DEEPSEEK_API_KEY is correctly set in Colab Secrets.")

    except Exception as e:
        logging.error(f"Error during classification inference or overall Cell 2 execution: {e}")

print(f"The V-JEPA features (shape: {pooled_features_for_classifier.shape}) are the core input that a trained classifier would learn from.")
print(f"Current time in EST: {datetime.datetime.now(pytz.timezone('EST')).strftime('%Y-%m-%d %H:%M:%S EST')}")

print("\nCell 2 execution complete.")


--- Cell 2: Part 3 - Starting V-JEPA Feature-Driven Classification Inference and DEEPSEEK LLM Interaction ---

--- AI Agent's Understanding from Classifier ---
**Primary Classification (Predicted by AI):** 'airplane landing' 1.00)
**Description for LLM:** The visual system detected an airplane landing. (Confidence: 1.00)
Note: This classification's accuracy depends heavily on the quality and size of the real dataset used for classifier training.

--- Engaging DEEPSEEK LLM for Further Reasoning ---

--- DEEPSEEK LLM Response ---
Observation confirms active landing operations, which may affect arrival sequencing and runway availability for flight planning.
--- DEEPSEEK LLM Response - END ---


The V-JEPA features (shape: torch.Size([1, 1408])) are the core input that a trained classifier would learn from.
Current time in EST: 2025-10-20 15:29:23 EST

Cell 2 execution complete.
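
The confidence printed above is the largest softmax probability over the classifier's logits. The sketch below reproduces that step in plain Python with hypothetical logits and a shortened label list, to show how the number is derived. It says nothing about whether the prediction is correct: a value near 1.00 only means one logit is far above the others.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)  # subtracting the maximum avoids overflow in exp
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


labels = ["airplane landing", "airplane takeoff", "airport ground operations"]
logits = [4.0, 1.0, 0.5]  # hypothetical classifier outputs, not taken from the run above

probabilities = softmax(logits)
predicted_idx = max(range(len(probabilities)), key=probabilities.__getitem__)
predicted_confidence = probabilities[predicted_idx]
print(f"{labels[predicted_idx]} (Confidence: {predicted_confidence:.2f})")
```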


## cell3-prepdata - STEP2

In [7]:
import ast
import pandas as pd
df = pd.read_csv('/content/adsb/kbtp/raw/2022/10-25-22/1.csv', on_bad_lines='skip', engine='python')
print("Raw Date sample:\n", df['Date'].head())
def clean_date(x):
    try:
        if isinstance(x, (list, tuple)):
            return '-'.join(str(i) for i in x)
        if isinstance(x, str) and x.startswith('['):
            x = ast.literal_eval(x)  # Parse the stringified list without executing code
            return '-'.join(str(i) for i in x)
        return str(x)
    except Exception as e:
        print(f"Error cleaning Date: {e}")
        return 'unknown'
df['Date'] = df['Date'].apply(clean_date)
df['traj_id'] = df['Tail'].astype(str) + '_' + df['Date'].astype(str)
print("Sample traj_id values:\n", df['traj_id'].head())
print("Rows per traj_id:\n", df.groupby('traj_id').size().describe())
print("Missing traj_id values:", df['traj_id'].isna().sum())

Raw Date sample:
 0    ['2022', '10', '25']
1    ['2022', '10', '25']
2    ['2022', '10', '25']
3    ['2022', '10', '25']
4    ['2022', '10', '25']
Name: Date, dtype: object
Sample traj_id values:
 0    N649DL_2022-10-25
1    N649DL_2022-10-25
2    N649DL_2022-10-25
3    N649DL_2022-10-25
4    N649DL_2022-10-25
Name: traj_id, dtype: object
Rows per traj_id:
 count      592.000000
mean       321.209459
std        911.335438
min          1.000000
25%         85.750000
50%        150.000000
75%        256.000000
max      11073.000000
dtype: float64
Missing traj_id values: 0
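
The `Date` column arrives as the string form of a list, such as `['2022', '10', '25']`, and the trajectory key is the tail number joined to the cleaned date. The standalone sketch below shows the same clean-up without pandas, using `ast.literal_eval` so that only Python literals are parsed. The function here is a simplified stand-in for the cell's `clean_date`, and the tail number is the one shown in the sample output.

```python
import ast


def clean_date(value):
    """Turn a date stored as a list, or as the string form of a list, into 'YYYY-MM-DD'."""
    if isinstance(value, str) and value.startswith("["):
        # literal_eval accepts only Python literals, so arbitrary expressions are rejected
        value = ast.literal_eval(value)
    if isinstance(value, (list, tuple)):
        return "-".join(str(part) for part in value)
    return str(value)


raw = "['2022', '10', '25']"            # format shown in the raw Date sample
traj_id = "N649DL" + "_" + clean_date(raw)
print(traj_id)
```

Because the key is only tail number plus date, every flight a given aircraft makes on one day shares a single `traj_id`, which may explain the long tail in the rows-per-trajectory summary.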


In [8]:
import numpy as np
import pandas as pd
import torch
df = pd.read_csv('/content/adsb/kbtp/raw/2022/10-25-22/1.csv', on_bad_lines='skip', engine='python')
required_cols = ['Lat', 'Lon', 'Altitude', 'Speed']
for col in required_cols:
    df[col] = pd.to_numeric(df[col], errors='coerce')
df.dropna(subset=required_cols, inplace=True)
traj = df.iloc[:5][required_cols]
for i in range(len(traj)):
    row = traj.iloc[i]
    try:
        tensor = torch.tensor(row.values.astype(np.float64), dtype=torch.float32)
        print(f"Row {i} tensor: {tensor}")
    except Exception as e:
        print(f"Error converting row {i}: {e}")

Row 0 tensor: tensor([   40.5879,   -79.7172, 33000.0000,   460.0000])
Row 1 tensor: tensor([   40.5900,   -79.7167, 33000.0000,   460.0000])
Row 2 tensor: tensor([   40.5939,   -79.7159, 33000.0000,   460.0000])
Row 3 tensor: tensor([   40.5960,   -79.7154, 33000.0000,   460.0000])
Row 4 tensor: tensor([   40.5960,   -79.7154, 33000.0000,   460.0000])


In [3]:
import pandas as pd
import numpy as np
df = pd.read_csv('/content/adsb/kbtp/raw/2022/10-25-22/1.csv', on_bad_lines='skip', engine='python')
required_cols = ['Lat', 'Lon', 'Altitude', 'Speed']
for col in required_cols:
    df[col] = pd.to_numeric(df[col], errors='coerce')
df.dropna(subset=required_cols, inplace=True)
traj = df.iloc[:5][required_cols]
print("Raw values:\n", traj.values)
print("Raw dtype:", traj.values.dtype)
print("Converted values:\n", traj.values.astype(np.float64))
print("Converted dtype:", traj.values.astype(np.float64).dtype)

Raw values:
 [[   40.587936   -79.717224 33000.         460.      ]
 [   40.590042   -79.716736 33000.         460.      ]
 [   40.593925   -79.71586  33000.         460.      ]
 [   40.59602    -79.71542  33000.         460.      ]
 [   40.59602    -79.71542  33000.         460.      ]]
Raw dtype: float64
Converted values:
 [[   40.587936   -79.717224 33000.         460.      ]
 [   40.590042   -79.716736 33000.         460.      ]
 [   40.593925   -79.71586  33000.         460.      ]
 [   40.59602    -79.71542  33000.         460.      ]
 [   40.59602    -79.71542  33000.         460.      ]]
Converted dtype: float64


In [9]:
import pandas as pd
import torch
df = pd.read_csv('/content/adsb/kbtp/raw/2022/10-25-22/1.csv', on_bad_lines='skip', engine='python')
required_cols = ['Lat', 'Lon', 'Altitude', 'Speed']
for col in required_cols:
    df[col] = pd.to_numeric(df[col], errors='coerce')
df.dropna(subset=required_cols, inplace=True)
traj = df.iloc[:5][required_cols]
for i in range(len(traj)):
    row = traj.iloc[i]
    try:
        tensor = torch.tensor(row.values.astype(np.float64), dtype=torch.float32)
        print(f"Row {i} tensor: {tensor}")
    except Exception as e:
        print(f"Error converting row {i}: {e}")

Row 0 tensor: tensor([   40.5879,   -79.7172, 33000.0000,   460.0000])
Row 1 tensor: tensor([   40.5900,   -79.7167, 33000.0000,   460.0000])
Row 2 tensor: tensor([   40.5939,   -79.7159, 33000.0000,   460.0000])
Row 3 tensor: tensor([   40.5960,   -79.7154, 33000.0000,   460.0000])
Row 4 tensor: tensor([   40.5960,   -79.7154, 33000.0000,   460.0000])
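
The per-row loop above works, but the same conversion can be done for a whole block of rows at once. The following is a minimal NumPy sketch, not the notebook's own loader: it pairs each `[Lat, Lon, Altitude, Speed]` row with its successor and derives an action vector. The 8-slot action layout (delta speed, delta altitude, zero padding) follows the convention used in Cell 3; the helper name `build_transitions` is hypothetical, and the sample rows are illustrative values modelled on the outputs in this notebook.

```python
import numpy as np

def build_transitions(states, action_dim=8):
    """Pair each row with its successor and derive a simple action vector.

    `states` is an (N, 4) array of [Lat, Lon, Altitude, Speed] rows that are
    already sorted in time. Hypothetical helper, for illustration only.
    """
    states = np.asarray(states, dtype=np.float32)
    if states.ndim != 2 or states.shape[0] < 2:
        raise ValueError("need at least two rows to form a transition")
    s_t, s_tp1 = states[:-1], states[1:]
    actions = np.zeros((len(s_t), action_dim), dtype=np.float32)
    actions[:, 0] = s_tp1[:, 3] - s_t[:, 3]  # delta speed
    actions[:, 1] = s_tp1[:, 2] - s_t[:, 2]  # delta altitude
    return s_t, actions, s_tp1

# Illustrative rows (assumed values, not a specific trajectory from the dataset)
rows = [
    [40.587936, -79.717224, 33000.0, 460.0],
    [40.590042, -79.716736, 33000.0, 460.0],
    [40.593925, -79.715860, 33000.0, 461.0],
]
s_t, actions, s_tp1 = build_transitions(rows)
print(s_t.shape, actions.shape, s_tp1.shape)
```

With PyTorch available, `torch.from_numpy(s_t)` would hand the whole array to a projector in a single batched call instead of one call per row.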


In [2]:
import pandas as pd
df = pd.read_csv('/content/adsb/kbtp/raw/2022/10-25-22/1.csv', on_bad_lines='skip', engine='python')
required_cols = ['Lat', 'Lon', 'Altitude', 'Speed']
for col in required_cols:
    df[col] = pd.to_numeric(df[col], errors='coerce')
df.dropna(subset=required_cols, inplace=True)
df['traj_id'] = df['Tail'].astype(str) + '_' + df['Date'].apply(lambda x: '-'.join(map(str, x)) if isinstance(x, (list, tuple)) else str(x))
traj = df[df['traj_id'] == df['traj_id'].iloc[0]].sort_values('Time')
for i in range(min(5, len(traj))):
    row = traj.iloc[i][required_cols]
    print(f"Row {i}: {row.values}, Type: {row.values.dtype}")
    print(f"Is numeric: {all(row.apply(lambda x: isinstance(x, (int, float)) and not pd.isna(x)))}")

Row 0: [np.float64(40.634125) np.float64(-79.70679) np.float64(33000.0)
 np.float64(461.0)], Type: object
Is numeric: True
Row 1: [np.float64(40.634125) np.float64(-79.70679) np.float64(33000.0)
 np.float64(461.0)], Type: object
Is numeric: True
Row 2: [np.float64(40.638107) np.float64(-79.70587) np.float64(33000.0)
 np.float64(461.0)], Type: object
Is numeric: True
Row 3: [np.float64(40.587936) np.float64(-79.717224) np.float64(33000.0)
 np.float64(460.0)], Type: object
Is numeric: True
Row 4: [np.float64(40.590042) np.float64(-79.716736) np.float64(33000.0)
 np.float64(460.0)], Type: object
Is numeric: True
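
The `Type: object` in the output above is a consequence of indexing order. `traj.iloc[i]` first pulls one row across every column, including string columns such as `Tail`, so the resulting Series falls back to object dtype even though the four selected values are floats. A small sketch of the difference, using an illustrative frame (the values are assumptions, not rows read from the CSV):

```python
import numpy as np
import pandas as pd

# Illustrative frame mixing numeric columns with a string column, like the ADS-B CSVs
df = pd.DataFrame({
    "Lat": [40.634125, 40.638107],
    "Lon": [-79.70679, -79.70587],
    "Altitude": [33000.0, 33000.0],
    "Speed": [461.0, 461.0],
    "Tail": ["N649DL", "N649DL"],
})
required_cols = ["Lat", "Lon", "Altitude", "Speed"]

# Row first: the row Series spans float and string columns, so it is object dtype
row_first = df.iloc[0][required_cols].values

# Columns first: the sub-frame is all float64, so each row stays float64
cols_first = df[required_cols].iloc[0].values

print(row_first.dtype, cols_first.dtype)
```

Selecting the numeric columns before indexing rows avoids the later `astype(np.float64)` round trip.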


In [1]:
import pandas as pd
df = pd.read_csv('/content/adsb/kbtp/raw/2022/10-25-22/1.csv', on_bad_lines='skip', engine='python')
required_cols = ['Lat', 'Lon', 'Altitude', 'Speed']
for col in required_cols:
    df[col] = pd.to_numeric(df[col], errors='coerce')
df.dropna(subset=required_cols, inplace=True)
print("Data types after cleaning:\n", df[required_cols].dtypes)
print("Missing values after cleaning:\n", df[required_cols].isna().sum())
print("Sample data:\n", df[required_cols].head())

Data types after cleaning:
 Lat         float64
Lon         float64
Altitude    float64
Speed       float64
dtype: object
Missing values after cleaning:
 Lat         0
Lon         0
Altitude    0
Speed       0
dtype: int64
Sample data:
          Lat        Lon  Altitude  Speed
0  40.587936 -79.717224   33000.0  460.0
1  40.590042 -79.716736   33000.0  460.0
2  40.593925 -79.715860   33000.0  460.0
3  40.596020 -79.715420   33000.0  460.0
4  40.596020 -79.715420   33000.0  460.0


In [12]:
import pandas as pd
df = pd.read_csv('/content/adsb/kbtp/raw/2022/10-25-22/1.csv', on_bad_lines='skip', engine='python')
print("Unique Tail values:", df['Tail'].nunique())
print("Sample Tail values:\n", df['Tail'].head())
print("Missing Tail values:", df['Tail'].isna().sum())

Unique Tail values: 578
Sample Tail values:
 0    N649DL
1    N649DL
2    N649DL
3    N649DL
4    N649DL
Name: Tail, dtype: object
Missing Tail values: 1450
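
Because 1450 rows have no `Tail`, any trajectory ID built from `Tail` and `Date` needs an explicit rule for those rows. The sketch below shows one option, under the assumption that each tail-less row should get its own `missing_<index>` ID rather than being merged with other tail-less rows. The frame is illustrative, and this is not necessarily how the notebook's loader groups rows.

```python
import numpy as np
import pandas as pd

# Illustrative frame: two rows with a tail number, two without
df = pd.DataFrame({
    "Tail": ["N649DL", np.nan, "N649DL", np.nan],
    "Date": ["2022-10-25"] * 4,
})

# Per-row fallback ID for rows that have no tail number
fallback = "missing_" + pd.Series(df.index, index=df.index).astype(str)

# String concatenation with a missing Tail yields a missing value,
# which .where() then swaps for the fallback ID
combined = df["Tail"] + "_" + df["Date"].astype(str)
df["traj_id"] = combined.where(df["Tail"].notna(), fallback)

print(df["traj_id"].tolist())
```

Grouping on this column keeps the two `N649DL` rows together while leaving each tail-less row in a group of its own.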


In [None]:
import pandas as pd
df = pd.read_csv('/content/adsb/kbtp/raw/2022/10-25-22/1.csv', on_bad_lines='skip', engine='python')
df['traj_id'] = df['Tail'] + '_' + df['Date'].astype(str)
print("Rows per Tail_Date:\n", df.groupby('traj_id').size().describe())
print("Sample traj_id values:\n", df['traj_id'].head())
print("Missing traj_id values:", df['traj_id'].isna().sum())

In [None]:
import pandas as pd
import glob
csv_files = glob.glob('/content/adsb/kbtp/raw/2022/*/*.csv')[:5]
for csv_file in csv_files:
    try:
        df = pd.read_csv(csv_file, on_bad_lines='skip', engine='python')
        df['Date'] = df['Date'].apply(lambda x: '-'.join(x) if isinstance(x, list) else x)
        df['traj_id'] = df['Tail'].astype(str) + '_' + df['Date'].astype(str)
        print(f"\nFile: {csv_file}")
        print("Columns:", df.columns.tolist())
        print("Rows:", len(df))
        print("Unique Tail values:", df.get('Tail', pd.Series([])).nunique())
        print("Missing Tail values:", df.get('Tail', pd.Series([])).isna().sum())
        print("Missing Date values:", df.get('Date', pd.Series([])).isna().sum())
        print("Rows per Tail_Date:\n", df.groupby('traj_id').size().describe())
        print("Missing values:\n", df[['Lat', 'Lon', 'Altitude', 'Speed', 'Time']].isna().sum())
        print("Data types:\n", df[['Lat', 'Lon', 'Altitude', 'Speed', 'Time']].dtypes)
        print("Sample data:\n", df.head())
    except Exception as e:
        print(f"Error reading {csv_file}: {e}")

In [None]:
import os
from tqdm import tqdm

def get_human_readable_size(size_bytes):
    """Converts bytes into a human-readable format."""
    if size_bytes == 0:
        return "0 B"
    size_names = ("B", "KB", "MB", "GB", "TB")
    i = 0
    while size_bytes >= 1024 and i < len(size_names) - 1:
        size_bytes /= 1024
        i += 1
    return f"{size_bytes:.2f} {size_names[i]}"

def check_file_sizes(base_dir):
    """
    Walks through a directory, checks the size of all .csv files,
    calculates the total size, and deletes any 0-byte files.
    """
    total_size_bytes = 0
    file_count = 0
    zero_byte_count = 0
    deleted_count = 0

    print(f"Scanning directory for files to process and clean: {base_dir}")

    # First pass: Collect all CSV files
    all_files = []
    for root, _, files in os.walk(base_dir):
        for file in files:
            if file.endswith('.csv'):
                all_files.append(os.path.join(root, file))

    print(f"Found {len(all_files)} total CSV files to inspect.")

    # Second pass: Process, total file sizes, and delete empty files
    for file_path in tqdm(all_files, desc="Checking and Cleaning Files"):
        file_count += 1
        try:
            size_bytes = os.path.getsize(file_path)

            if size_bytes == 0:
                zero_byte_count += 1

                # 🔥 ACTION: DELETE THE 0-BYTE FILE 🔥
                os.remove(file_path)
                deleted_count += 1

            else:
                # Only count size of files that were NOT deleted
                total_size_bytes += size_bytes

        except FileNotFoundError:
            print(f"Warning: File not found at {file_path}")
        except Exception as e:
            print(f"Error processing {file_path}: {e}")

    # --- Summary ---
    print("\n--- Scan and Clean Results ---")
    print(f"Total CSV files inspected: {file_count}")
    print(f"Files found to be 0 bytes: {zero_byte_count}")
    print(f"Files successfully deleted: {deleted_count}")
    print(f"Remaining data volume (non-empty files): {get_human_readable_size(total_size_bytes)}")
    print("----------------------------")

    return total_size_bytes

# --- Execution ---
# Replace this path with the correct root directory of your ADS-B data,
# e.g., the folder containing the date-based directories (10-27-22, 10-26-22, etc.)
BASE_DATA_PATH = "/content/adsb/kbtp/raw/2022"

if __name__ == "__main__":
    if not os.path.exists(BASE_DATA_PATH):
        print(f"Error: Base directory not found at {BASE_DATA_PATH}")
    else:
        check_file_sizes(BASE_DATA_PATH)


In [None]:
import os
from tqdm import tqdm

def get_human_readable_size(size_bytes):
    """Converts bytes into a human-readable format."""
    if size_bytes == 0:
        return "0 B"
    size_names = ("B", "KB", "MB", "GB", "TB")
    i = 0
    while size_bytes >= 1024 and i < len(size_names) - 1:
        size_bytes /= 1024
        i += 1
    return f"{size_bytes:.2f} {size_names[i]}"

def check_file_sizes(base_dir):
    """
    Walks through a directory, checks the size of all .csv files,
    and calculates the total size.
    """
    total_size_bytes = 0
    file_count = 0
    zero_byte_count = 0

    print(f"Scanning directory: {base_dir}")

    # First pass: Collect all CSV files
    all_files = []
    for root, _, files in os.walk(base_dir):
        for file in files:
            if file.endswith('.csv'):
                all_files.append(os.path.join(root, file))

    print(f"Found {len(all_files)} total CSV files to process.")

    # Second pass: Process and total the file sizes with a progress bar
    for file_path in tqdm(all_files, desc="Checking File Sizes"):
        file_count += 1
        try:
            size_bytes = os.path.getsize(file_path)
            total_size_bytes += size_bytes

            if size_bytes == 0:
                zero_byte_count += 1
                # Optional: Uncomment the line below to list each empty file
                # print(f"  [Empty] {file_path}")

        except FileNotFoundError:
            print(f"Warning: File not found at {file_path}")
        except Exception as e:
            print(f"Error processing {file_path}: {e}")

    # --- Summary ---
    print("\n--- Scan Results ---")
    print(f"Total CSV files found: {file_count}")
    print(f"Files that are 0 bytes: {zero_byte_count}")
    print(f"Total size of all CSV files: {get_human_readable_size(total_size_bytes)}")
    print("--------------------")

    return total_size_bytes

# --- Execution ---
# Replace this path with the correct root directory of your ADS-B data,
# e.g., the folder containing the date-based directories (10-27-22, 10-26-22, etc.)
BASE_DATA_PATH = "/content/adsb/kbtp/raw/2022"

if __name__ == "__main__":
    if not os.path.exists(BASE_DATA_PATH):
        print(f"Error: Base directory not found at {BASE_DATA_PATH}")
    else:
        check_file_sizes(BASE_DATA_PATH)


## Cell 3

In [1]:
import pandas as pd
import torch
import torch.nn as nn
import os
from torch import optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from tqdm import tqdm
import random
import logging
import psutil
import gc

# Set logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s:%(message)s')

# Constants
STATE_DIM = 4      # Lat, Lon, Altitude, Speed
LATENT_DIM = 16
ACTION_DIM = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Latent Projector
class LatentProjector(nn.Module):
    def __init__(self, state_dim=STATE_DIM, latent_dim=LATENT_DIM):
        super().__init__()
        self.encoder_net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim)
        )

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        return self.encoder_net(x)

latent_projector = LatentProjector(state_dim=STATE_DIM, latent_dim=LATENT_DIM).to(device)
latent_projector.eval()
print("LatentProjector instance:", latent_projector)

# Latent Dynamics Predictor
class LatentDynamicsPredictor(nn.Module):
    def __init__(self, latent_dim=LATENT_DIM, action_dim=ACTION_DIM):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim + action_dim, 64)
        self.fc2 = nn.Linear(64, latent_dim)

    def forward(self, latent, action):
        x = torch.cat([latent, action], dim=-1)
        x = torch.relu(self.fc1(x))
        z_tp1 = self.fc2(x)
        return z_tp1

predictor = LatentDynamicsPredictor(latent_dim=LATENT_DIM, action_dim=ACTION_DIM).to(device)
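# Note: the projector's parameters are registered with the optimizer, but the training
# tuples built below are projected under torch.no_grad(), so in this cell only the
# predictor actually receives gradient updates.
optimizer_pldm = optim.Adam(list(predictor.parameters()) + list(latent_projector.parameters()), lr=1e-3)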
print("Predictor instance:", predictor)

# Training Function
def train_latent_dynamics_model(predictor_model, optimizer, training_data_loader, epochs=5):
    predictor_model.train()
    criterion = torch.nn.functional.mse_loss
    print("\n--- Training Latent Dynamics Predictor for Conceptual Real Flights (Causal Focus) ---")
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (latent_s_t, action_t, latent_s_t_plus_1) in tqdm(
            enumerate(training_data_loader),
            total=len(training_data_loader),
            desc=f"Epoch {epoch+1}/{epochs}"
        ):
            latent_s_t, action_t, latent_s_t_plus_1 = (
                latent_s_t.to(device),
                action_t.to(device),
                latent_s_t_plus_1.to(device)
            )
            predicted_z_t_plus_1 = predictor_model(latent_s_t, action_t)
            loss = criterion(predicted_z_t_plus_1, latent_s_t_plus_1)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {total_loss / len(training_data_loader):.6f}")
    print("--- Training Complete ---")

# Data Loading Function
def load_real_dynamics_data(device, adsb_dir="/content/adsb/kbtp/raw/2022", num_trajectories=10, max_traj_per_file=100):
    """
    Loads ADS-B trajectories and transforms them into (latent_s_t, action_t, latent_s_t_plus_1) tuples.
    """
    data = []
    trajectory_files = []

    try:
        for root, _, files in os.walk(adsb_dir):
            for file in files:
                if file.endswith('.csv'):
                    trajectory_files.append(os.path.join(root, file))
    except Exception as e:
        logging.critical(f"Failed to walk directory {adsb_dir}. Error: {e}")
        return []

    if not trajectory_files:
        logging.warning(f"No CSV files found in {adsb_dir}. Falling back to synthetic data.")
        return []

    print(f"\nAttempting to load REAL ADS-B data...")
    print(f"Found {len(trajectory_files)} CSV files. Limiting to {num_trajectories} files for demo speed.")

    trajectory_files = trajectory_files[:num_trajectories]

    for file in tqdm(trajectory_files, desc="Processing Trajectories"):
        try:
            print(f"\nProcessing {file}")
            print(f"Memory usage before: {psutil.virtual_memory().percent}%")
            if os.path.getsize(file) == 0:
                logging.warning(f"Skipping {file}: Empty file.")
                continue

            df = pd.read_csv(file, on_bad_lines='skip', engine='python')
            required_cols = ['Lat', 'Lon', 'Altitude', 'Speed']
            time_col = 'Time'

            available_cols = df.columns.tolist()
            print(f"Columns in {file}: {available_cols}")
            if not all(col in available_cols for col in required_cols):
                logging.warning(f"Skipping {file}: Missing columns. Required: {required_cols}, Available: {available_cols}")
                continue

            for col in required_cols:
                df[col] = pd.to_numeric(df[col], errors='coerce')

            df.dropna(subset=required_cols, inplace=True)
            print(f"Rows after cleaning {file}: {len(df)}")

            if df.empty or len(df) < 2:
                logging.warning(f"Skipping {file}: Fewer than 2 valid rows after cleaning.")
                continue

            # Clean Date
            import ast  # local import: literal_eval parses list-like strings without executing code

            def clean_date(x):
                try:
                    if isinstance(x, (list, tuple)):
                        return '-'.join(str(i).zfill(2) if idx > 0 else str(i) for idx, i in enumerate(x))
                    if isinstance(x, str):
                        if x.startswith('['):
                            # Safer than eval(): only Python literals are accepted
                            x = ast.literal_eval(x)
                            return '-'.join(str(i).zfill(2) if idx > 0 else str(i) for idx, i in enumerate(x))
                        if '-' in x or '/' in x:
                            return pd.to_datetime(x, errors='coerce').strftime('%Y-%m-%d')
                        return x
                    return str(x)
                except Exception as e:
                    logging.warning(f"Error cleaning Date in {file}: {e}")
                    return 'unknown'

            # Clean Time
            def clean_time(x):
                try:
                    if isinstance(x, (list, tuple)):
                        return ':'.join(str(i).zfill(2) for i in x)
                    if isinstance(x, str):
                        if x.startswith('['):
                            # Safer than eval(): only Python literals are accepted
                            x = ast.literal_eval(x)
                            return ':'.join(str(i).zfill(2) for i in x)
                        # Plain strings (e.g. '06:05:56') are passed through unchanged
                        return x
                    return str(x)
                except Exception as e:
                    logging.warning(f"Error cleaning Time in {file}: {e}")
                    return '00:00:00'

            if 'Date' in df.columns:
                df['Date'] = df['Date'].apply(clean_date)
                print(f"Cleaned Date sample in {file}:\n", df['Date'].head(5).to_list())
            if 'Time' in df.columns:
                df['Time'] = df['Time'].apply(clean_time)
                print(f"Cleaned Time sample in {file}:\n", df['Time'].head(5).to_list())
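                # Parse as a time-of-day offset. Unlike format='%H:%M:%S', this also accepts
                # fractional seconds such as '07:08:25.238' instead of coercing them to NaT,
                # so the later sort by Time is meaningful for every file.
                df['Time'] = pd.to_timedelta(df['Time'], errors='coerce')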

            if 'Tail' in df.columns and 'Date' in df.columns:
                df['traj_id'] = df['Tail'].astype(str) + '_' + df['Date'].astype(str)
                df['traj_id'] = df['traj_id'].fillna('missing_' + pd.Series(range(len(df)), index=df.index).astype(str))
            else:
                logging.warning(f"Missing 'Tail' or 'Date' in {file}. Using index as traj_id.")
                df['traj_id'] = pd.Series(range(len(df)), index=df.index).astype(str)

            print(f"Unique trajectories in {file}: {df['traj_id'].nunique()}")
            print("Rows per traj_id:\n", df.groupby('traj_id').size().describe())

            # Limit trajectories to manage memory
            traj_ids = df['traj_id'].unique()[:max_traj_per_file]
            print(f"Limiting to {len(traj_ids)} trajectories in {file}")

            latent_projector.eval()

            for traj_id in traj_ids:
                traj = df[df['traj_id'] == traj_id].sort_values(time_col)
                #print(f"Trajectory {traj_id} in {file}: {len(traj)} rows")
                if len(traj) < 2:
                    logging.warning(f"Skipping trajectory {traj_id} in {file}: Fewer than 2 points.")
                    continue

                for i in range(len(traj) - 1):
                    try:
                        row_t = traj.iloc[i][required_cols]
                        row_tp1 = traj.iloc[i+1][required_cols]
                        if not all(row_t.apply(lambda x: isinstance(x, (int, float)) and not pd.isna(x))) or \
                           not all(row_tp1.apply(lambda x: isinstance(x, (int, float)) and not pd.isna(x))):
                            logging.warning(f"Non-numeric data in trajectory {traj_id} at index {i}: {row_t.values}, {row_tp1.values}")
                            continue
                        state_t = torch.tensor(row_t.values.astype(np.float64), dtype=torch.float32, device=device)
                        state_tp1 = torch.tensor(row_tp1.values.astype(np.float64), dtype=torch.float32, device=device)
                        action = torch.tensor([
                            state_tp1[3] - state_t[3],  # Delta Speed
                            state_tp1[2] - state_t[2],  # Delta Altitude
                            0.0, 0.0, 0.0, 0.0, 0.0, 0.0
                        ], dtype=torch.float32, device=device)
                        with torch.no_grad():
                            projected_state_t = latent_projector(state_t).squeeze(0).cpu()
                            projected_state_tp1 = latent_projector(state_tp1).squeeze(0).cpu()
                        data.append((projected_state_t, action.cpu(), projected_state_tp1))
                    except Exception as e:
                        logging.warning(f"Error processing trajectory {traj_id} in {file}: {e}")

            del df
            gc.collect()
            print(f"Memory usage after {file}: {psutil.virtual_memory().percent}%")
        except Exception as e:
            logging.warning(f"Error reading {file}: {e}")

    print(f"Loaded {len(data)} real dynamics samples from {len(trajectory_files)} files.")
    return data

# Synthetic Data Fallback
def generate_synthetic_data(num_trajectories=1000, trajectory_length=20):
    latent_projector.eval()
    synthetic_data = []
    for _ in range(num_trajectories):
        base_state = torch.tensor([45.47, -73.74, 37000.0, 490.0], dtype=torch.float32, device=device)
        base_state += torch.randn(STATE_DIM, device=device) * 0.1
        current_state = base_state.clone()
        for _ in range(trajectory_length):
            delta_v = random.uniform(-10.0, 10.0)
            delta_alt = random.uniform(-50.0, 50.0)
            action = torch.tensor([delta_v, delta_alt, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                                  dtype=torch.float32, device='cpu').detach()
            next_state = current_state.clone()
            next_state[3] += delta_v
            next_state[2] += delta_alt
            with torch.no_grad():
                projected_state_t = latent_projector(current_state.to(device)).squeeze(0).cpu().detach()
                projected_state_tp1 = latent_projector(next_state.to(device)).squeeze(0).cpu().detach()
            synthetic_data.append((projected_state_t, action, projected_state_tp1))
            current_state = next_state.clone()
    return synthetic_data

# Execution Block
FILE_LIMIT = 5  # Process 5 files for demo
print("\n--- Starting AGI Demo Pipeline (Cell 3) ---")
dynamics_training_data = load_real_dynamics_data(device, num_trajectories=FILE_LIMIT, max_traj_per_file=100)
if not dynamics_training_data:
    print(f"--- CRITICAL WARNING: NO REAL DATA LOADED from {FILE_LIMIT} files. CREATING SYNTHETIC FALLBACK. ---")
    dynamics_training_data = generate_synthetic_data(num_trajectories=1000)
    print(f"Successfully generated {len(dynamics_training_data)} synthetic causal samples for demonstration.")
if dynamics_training_data:
    print("\nPreparing DataLoader for batch training...")
    z_t_list, a_t_list, z_tp1_list = zip(*dynamics_training_data)
    Z_T = torch.stack(z_t_list)
    A_T = torch.stack(a_t_list)
    Z_TP1 = torch.stack(z_tp1_list)
    dataset = TensorDataset(Z_T, A_T, Z_TP1)
    BATCH_SIZE = 64
    dynamics_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    print(f"Total batches for training: {len(dynamics_dataloader)}")
    train_latent_dynamics_model(predictor, optimizer_pldm, dynamics_dataloader, epochs=5)
else:
    print("FATAL ERROR: No data (real or synthetic) could be prepared. Training aborted.")
print("\nCell 3 execution complete.")
print("Cell 3 completed. Predictor and latent_projector defined.")

Using device: cuda
LatentProjector instance: LatentProjector(
  (encoder_net): Sequential(
    (0): Linear(in_features=4, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=16, bias=True)
  )
)
Predictor instance: LatentDynamicsPredictor(
  (fc1): Linear(in_features=24, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=16, bias=True)
)

--- Starting AGI Demo Pipeline (Cell 3) ---

Attempting to load REAL ADS-B data...
Found 246 CSV files. Limiting to 5 files for demo speed.


Processing Trajectories:   0%|          | 0/5 [00:00<?, ?it/s]


Processing /content/adsb/kbtp/raw/2022/05-13-22/1.csv
Memory usage before: 3.7%
Columns in /content/adsb/kbtp/raw/2022/05-13-22/1.csv: ['ID', 'Time', 'Date', 'Altitude', 'Speed', 'Heading', 'Lat', 'Lon', 'Age', 'Range', 'Bearing', 'Tail', 'AltisGNSS']
Rows after cleaning /content/adsb/kbtp/raw/2022/05-13-22/1.csv: 126665
Cleaned Date sample in /content/adsb/kbtp/raw/2022/05-13-22/1.csv:
 ['2022-05-13', '2022-05-13', '2022-05-13', '2022-05-13', '2022-05-13']
Cleaned Time sample in /content/adsb/kbtp/raw/2022/05-13-22/1.csv:
 ['06:05:56', '06:05:57', '06:05:59', '06:05:59', '06:06:01']
Unique trajectories in /content/adsb/kbtp/raw/2022/05-13-22/1.csv: 413
Rows per traj_id:
 count     413.000000
mean      306.694915
std       870.596904
min         1.000000
25%        90.000000
50%       132.000000
75%       200.000000
max      9031.000000
dtype: float64
Limiting to 100 trajectories in /content/adsb/kbtp/raw/2022/05-13-22/1.csv


Processing Trajectories:  20%|██        | 1/5 [01:15<05:00, 75.20s/it]

Memory usage after /content/adsb/kbtp/raw/2022/05-13-22/1.csv: 4.0%

Processing /content/adsb/kbtp/raw/2022/03-12-22/1.csv
Memory usage before: 4.0%
Columns in /content/adsb/kbtp/raw/2022/03-12-22/1.csv: ['ID', 'Time', 'Date', 'Altitude', 'Speed', 'Heading', 'Lat', 'Lon', 'Age', 'Range', 'Bearing', 'Tail', 'AltisGNSS']
Rows after cleaning /content/adsb/kbtp/raw/2022/03-12-22/1.csv: 11208
Cleaned Date sample in /content/adsb/kbtp/raw/2022/03-12-22/1.csv:
 ['2022-03-12', '2022-03-12', '2022-03-12', '2022-03-12', '2022-03-12']
Cleaned Time sample in /content/adsb/kbtp/raw/2022/03-12-22/1.csv:
 ['07:08:25.238', '07:08:27.143', '07:08:28.939', '07:08:29.058', '07:08:29.058']
Unique trajectories in /content/adsb/kbtp/raw/2022/03-12-22/1.csv: 168
Rows per traj_id:
 count    168.000000
mean      66.714286
std       62.286007
min        1.000000
25%       15.750000
50%       67.000000
75%       89.250000
max      530.000000
dtype: float64
Limiting to 100 trajectories in /content/adsb/kbtp/raw/2

Processing Trajectories:  40%|████      | 2/5 [01:26<01:52, 37.54s/it]

Memory usage after /content/adsb/kbtp/raw/2022/03-12-22/1.csv: 4.0%

Processing /content/adsb/kbtp/raw/2022/06-06-22/1.csv
Memory usage before: 4.0%
Columns in /content/adsb/kbtp/raw/2022/06-06-22/1.csv: ['ID', 'Time', 'Date', 'Altitude', 'Speed', 'Heading', 'Lat', 'Lon', 'Age', 'Range', 'Bearing', 'Tail', 'AltisGNSS']
Rows after cleaning /content/adsb/kbtp/raw/2022/06-06-22/1.csv: 258261
Cleaned Date sample in /content/adsb/kbtp/raw/2022/06-06-22/1.csv:
 ['2022-06-06', '2022-06-06', '2022-06-06', '2022-06-06', '2022-06-06']
Cleaned Time sample in /content/adsb/kbtp/raw/2022/06-06-22/1.csv:
 ['06:03:19', '06:03:20', '06:03:21', '06:03:23', '06:03:24']
Unique trajectories in /content/adsb/kbtp/raw/2022/06-06-22/1.csv: 863
Rows per traj_id:
 count      863.000000
mean       299.259560
std        707.961323
min          1.000000
25%         76.500000
50%        192.000000
75%        335.500000
max      13278.000000
dtype: float64
Limiting to 100 trajectories in /content/adsb/kbtp/raw/2022

Processing Trajectories:  60%|██████    | 3/5 [02:17<01:27, 43.86s/it]

Memory usage after /content/adsb/kbtp/raw/2022/06-06-22/1.csv: 4.0%

Processing /content/adsb/kbtp/raw/2022/04-14-22/2.csv
Memory usage before: 4.0%
Columns in /content/adsb/kbtp/raw/2022/04-14-22/2.csv: ['ID', 'Time', 'Date', 'Altitude', 'Speed', 'Heading', 'Lat', 'Lon', 'Age', 'Range', 'Bearing', 'Tail', 'AltisGNSS']
Rows after cleaning /content/adsb/kbtp/raw/2022/04-14-22/2.csv: 617
Cleaned Date sample in /content/adsb/kbtp/raw/2022/04-14-22/2.csv:
 ['2022-04-14', '2022-04-14', '2022-04-14', '2022-04-14', '2022-04-14']
Cleaned Time sample in /content/adsb/kbtp/raw/2022/04-14-22/2.csv:
 ['17:49:23.137', '17:49:23.137', '17:49:23.137', '17:49:23.137', '17:49:23.137']
Unique trajectories in /content/adsb/kbtp/raw/2022/04-14-22/2.csv: 6
Rows per traj_id:
 count      6.000000
mean     102.833333
std      103.902679
min        1.000000
25%       39.000000
50%       90.000000
75%      111.750000
max      295.000000
dtype: float64
Limiting to 6 trajectories in /content/adsb/kbtp/raw/2022/04

Processing Trajectories:  80%|████████  | 4/5 [02:19<00:27, 27.05s/it]

Memory usage after /content/adsb/kbtp/raw/2022/04-14-22/2.csv: 4.0%

Processing /content/adsb/kbtp/raw/2022/04-14-22/1.csv
Memory usage before: 4.0%
Columns in /content/adsb/kbtp/raw/2022/04-14-22/1.csv: ['ID', 'Time', 'Date', 'Altitude', 'Speed', 'Heading', 'Lat', 'Lon', 'Age', 'Range', 'Bearing', 'Tail', 'AltisGNSS']
Rows after cleaning /content/adsb/kbtp/raw/2022/04-14-22/1.csv: 2693
Cleaned Date sample in /content/adsb/kbtp/raw/2022/04-14-22/1.csv:
 ['2022-04-14', '2022-04-14', '2022-04-14', '2022-04-14', '2022-04-14']
Cleaned Time sample in /content/adsb/kbtp/raw/2022/04-14-22/1.csv:
 ['08:57:08.527', '08:57:09.616', '08:57:10.986', '08:57:12.037', '08:57:13.154']
Unique trajectories in /content/adsb/kbtp/raw/2022/04-14-22/1.csv: 24
Rows per traj_id:
 count     24.000000
mean     112.208333
std      152.939311
min        1.000000
25%       12.250000
50%       62.500000
75%      108.500000
max      529.000000
dtype: float64
Limiting to 24 trajectories in /content/adsb/kbtp/raw/2022

Processing Trajectories: 100%|██████████| 5/5 [02:23<00:00, 28.76s/it]


Memory usage after /content/adsb/kbtp/raw/2022/04-14-22/1.csv: 4.0%
Loaded 74836 real dynamics samples from 5 files.

Preparing DataLoader for batch training...
Total batches for training: 1170

--- Training Latent Dynamics Predictor for Conceptual Real Flights (Causal Focus) ---


Epoch 1/5: 100%|██████████| 1170/1170 [00:02<00:00, 496.97it/s]


Epoch 1/5, Average Loss: 136538.918775


Epoch 2/5: 100%|██████████| 1170/1170 [00:02<00:00, 526.79it/s]


Epoch 2/5, Average Loss: 108.723427


Epoch 3/5: 100%|██████████| 1170/1170 [00:02<00:00, 531.23it/s]


Epoch 3/5, Average Loss: 70.215588


Epoch 4/5: 100%|██████████| 1170/1170 [00:02<00:00, 546.47it/s]


Epoch 4/5, Average Loss: 194.888320


Epoch 5/5: 100%|██████████| 1170/1170 [00:02<00:00, 535.72it/s]

Epoch 5/5, Average Loss: 83.070327
--- Training Complete ---

Cell 3 execution complete.
Cell 3 completed. Predictor and latent_projector defined.
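
The loss values above are large and not monotonic (about 136,539 in epoch 1, then between roughly 70 and 195). One plausible contributor, offered here as an assumption rather than a diagnosis, is feature scale: altitude sits near 33,000 while latitude changes in the third decimal place, and the raw values pass through an untrained projector. A common remedy is to standardize each feature before projection. The sketch below uses NumPy only; `fit_standardizer` and `standardize` are hypothetical helper names, and the sample rows are taken from outputs printed earlier in this notebook.

```python
import numpy as np

# [Lat, Lon, Altitude, Speed] rows as printed earlier in the notebook
states = np.array([
    [40.587936, -79.717224, 33000.0, 460.0],
    [40.590042, -79.716736, 33000.0, 460.0],
    [40.593925, -79.715860, 33000.0, 460.0],
    [40.596020, -79.715420, 33000.0, 460.0],
    [40.634125, -79.706790, 33000.0, 461.0],
], dtype=np.float64)

def fit_standardizer(x, eps=1e-8):
    """Return per-feature mean and std; constant features get std 1."""
    mean = x.mean(axis=0)
    std = x.std(axis=0)
    std = np.where(std < eps, 1.0, std)  # avoid dividing by ~0 for constant columns
    return mean, std

def standardize(x, mean, std):
    """Shift and scale each feature to zero mean and (where possible) unit variance."""
    return (x - mean) / std

mean, std = fit_standardizer(states)
z = standardize(states, mean, std)
print(np.abs(z).max(axis=0))
```

In practice the statistics would be fitted once on the training trajectories and reused unchanged at inference time, so that both see the same scaling.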





## Cell 4

In [2]:
class LatentProjector(nn.Module):
    def __init__(self, state_dim=4, latent_dim=16):
        super().__init__()
        self.encoder_net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim)
        )
    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        return self.encoder_net(x)

In [None]:
import torch
import torch.nn as nn
import os
import pandas as pd
import numpy as np
from transformers import CLIPProcessor, CLIPModel
import gc # Import garbage collector

from warnings import filterwarnings
filterwarnings("ignore")

# --- Define constants used by predictor/projector from Cell 3 (for robustness) ---
LATENT_DIM = 16
ACTION_DIM = 8
# ---------------------------------------------------------------------------------

# Define missing variables
AIRPORTS = {
    "CYUL": {"lat": 45.4706, "lon": -73.7408, "name": "Montreal-Trudeau International"},
    "LFPG": {"lat": 49.0128, "lon": 2.5500, "name": "Paris-Charles de Gaulle"}
}
AIRCRAFT_PERFORMANCE = {
    "Boeing777_300ER": {
        "max_speed_kts": 490.0,
        "cruise_altitude_ft": 37000.0,
        "range_nm": 7370.0
    }
}
# Move model and processor to the specified device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Ensure device is defined
# Note: patch32 is the cheapest of OpenAI's ViT CLIP checkpoints; "openai/clip-vit-base-patch16"
# processes 4x as many patches per 224x224 image and therefore needs more memory, not less.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False) # Processor doesn't need .to(device)

# Redefine LatentProjector (re-using definition from prior cell)
class LatentProjector(nn.Module):
    def __init__(self, state_dim=4, latent_dim=LATENT_DIM): # Using LATENT_DIM from Cell 3/here
        super().__init__()
        self.encoder_net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim)
        )
    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        return self.encoder_net(x)

# Placeholder for load_and_process_video
def load_and_process_video(video_path, processor, model, device, num_frames=1):
    try:
        from torchvision.io import read_video
        # read_video decodes the file into a (T, H, W, C) uint8 tensor on the CPU
        video, _, _ = read_video(video_path, pts_unit='sec')
        # Keep only the first num_frames frames and move them to the target device
        video = video[:num_frames].to(device)
        # Encode one frame at a time to limit memory use; the per-frame features
        # are averaged below when num_frames > 1
        features_list = []
        for frame in video:
            inputs = processor(images=frame.unsqueeze(0), return_tensors="pt").to(device) # Process one frame at a time
            with torch.no_grad():
                features = model.get_image_features(**inputs)
                features_list.append(features)

        if not features_list:
             return None, "No frames processed"

        # Average features if more than one frame was processed
        averaged_features = torch.mean(torch.stack(features_list), dim=0)

        # Clear intermediate tensors
        del features_list, features, inputs, video, frame
        torch.cuda.empty_cache() # Clear GPU cache if using CUDA
        gc.collect() # Collect garbage

        return averaged_features, None

    except Exception as e:
        print(f"Error processing video {video_path}: {e}")
        return None, str(e) # Return error message


def plan_montreal_to_paris_flight(start_airport_data, target_airport_data, aircraft_model_data,
                                  encoder_model, processor_instance, predictor_model, latent_projector_instance,
                                  planning_horizon=50, action_dim=ACTION_DIM, num_action_samples=100): # Using ACTION_DIM
    encoder_model.eval()
    predictor_model.eval()
    latent_projector_instance.eval()
    print("AIRPORTS:", AIRPORTS)
    print("AIRCRAFT_PERFORMANCE:", AIRCRAFT_PERFORMANCE)
    print("model:", encoder_model)
    print("processor:", processor_instance)
    print("predictor:", predictor_model)
    print("latent_projector:", latent_projector_instance)
    initial_video_path = '/content/gdrive/MyDrive/datasets/TartanAviation/vision/1_2023-02-22-15-21-49/1_2023-02-22-15-21-49.mp4'
    target_video_path = initial_video_path  # Update with landing video

    # Process a minimal number of frames for memory efficiency
    initial_features, initial_error = load_and_process_video(initial_video_path, processor_instance, encoder_model, device, num_frames=1)
    target_features, target_error = load_and_process_video(target_video_path, processor_instance, encoder_model, device, num_frames=1) # Use num_frames=1 here too

    if initial_features is None or target_features is None:
        error_message = f"Video load failed. Initial: {initial_error}, Target: {target_error}. Using dummy features."
        print(error_message)
        # Ensure dummy features are on the correct device and have correct shape
        dummy_feature_shape = (1, 512) # CLIP image features shape
        initial_features = torch.rand(dummy_feature_shape).to(device)
        target_features = torch.rand(dummy_feature_shape).to(device)


    # Latent size produced by the projector (LATENT_DIM, 16 in this notebook)
    latent_state_dim = LATENT_DIM

    # CLIP ViT-B/32 image features have shape (1, 512); read the size from the tensor itself.
    # initial_features is never None here because of the dummy-feature fallback above.
    visual_feature_dim = initial_features.shape[-1]


    # Re-initialize latent_projector_instance if its input dimension is incorrect
    if latent_projector_instance.encoder_net[0].in_features != visual_feature_dim:
         print(f"Warning: LatentProjector input dimension mismatch. Expected {visual_feature_dim}, got {latent_projector_instance.encoder_net[0].in_features}. Re-initializing.")
         # Ensure latent_projector is created with the correct dimensions (visual feature dim -> latent state dim)
         # Using LATENT_DIM=16
         latent_projector_instance = LatentProjector(state_dim=visual_feature_dim, latent_dim=latent_state_dim).to(device)

    with torch.no_grad():
        # Ensure input to latent_projector is the correct shape (batch_size, visual_feature_dim)
        current_latent_state = latent_projector_instance(initial_features)
        target_latent_state = latent_projector_instance(target_features)


    # Clear CLIP features after projection
    del initial_features, target_features
    torch.cuda.empty_cache()
    gc.collect()


    ETHICAL_BOUNDARY_LATENT_VECTOR = torch.zeros(1, latent_state_dim).to(device) # Ensure size matches latent_state_dim
    weather_path = '/content/TartanAviation/vision/weather_stats.csv'
    salience = torch.rand(1).to(device) * 0.8
    if os.path.exists(weather_path):
        try:
            weather_df = pd.read_csv(weather_path)
            if 'visibility' in weather_df.columns:
                # Ensure salience is a tensor on the correct device
                salience = torch.tensor(weather_df['visibility'].mean() / 10.0, device=device)
        except Exception as we:
            print(f"Error loading weather data: {we}. Using default salience.")
            salience = torch.rand(1).to(device) * 0.8 # Fallback if weather loading fails


    print("\n--- Starting Real Flight Plan ---")
    print(f"Current Latent State Shape: {current_latent_state.shape}")
    print(f"Target Latent State Shape: {target_latent_state.shape}")
    print(f"Salience Level: {salience.item():.2f}")
    print('\n')
    best_action_sequence = []
    # num_action_samples and action_dim come from the function arguments (action_dim defaults to ACTION_DIM)

    # Ensure predictor input dimension matches (latent_state_dim + action_dim)
    # Assuming predictor has an attribute like fc1.in_features
    try:
        predicted_state_dim = predictor_model.fc1.in_features
        expected_predictor_input_dim = latent_state_dim + action_dim
        if predicted_state_dim != expected_predictor_input_dim:
            print(f"Warning: Predictor input dimension mismatch. Expected {expected_predictor_input_dim}, got {predicted_state_dim}. Predictor may not be compatible.")
            # We cannot re-initialize the predictor here as it's passed in.
            # This warning alerts the user to a potential issue with the provided predictor model.
    except AttributeError:
        print("Warning: Could not check predictor input dimension (no fc1 attribute). Predictor may not be compatible.")


    current_latent = current_latent_state # Rename for clarity in loop

    for step in range(planning_horizon):
        # Generate candidate actions - reduced number for memory
        candidate_actions = torch.rand(num_action_samples, action_dim).to(device) * 2.0 - 1.0

        # Prepare inputs for the Predictor, which expects the state and action separately.
        # Repeat current_latent (z_t) N times to match the number of candidate actions
        repeated_current_latent = current_latent.repeat(num_action_samples, 1)

        with torch.no_grad():
            # FIX: Pass current state and actions as separate arguments
            simulated_next_latents = predictor_model(repeated_current_latent, candidate_actions)

        # Clear the repeated tensor immediately to save memory
        del repeated_current_latent

        simulated_trajectories_cost = []
        for i in range(num_action_samples):
            simulated_next_latent = simulated_next_latents[i].unsqueeze(0) # Get the result for this sample

            # --- Cost Calculation ---
            goal_proximity_cost = torch.norm(target_latent_state - simulated_next_latent) * 1.0
            conceptual_fuel_cost = torch.norm(candidate_actions[i]) * 0.05
            conceptual_weather_cost = torch.rand(1).to(device) * 0.02 # This should ideally use actual weather data/model

            # --- Ethical and Salience Costs (Pillars 3 and 4) ---
            ethical_cost = 5.0 * torch.norm(ETHICAL_BOUNDARY_LATENT_VECTOR - simulated_next_latent)
            cautious_action_penalty = torch.norm(candidate_actions[i]) * salience # salience is a tensor
            salience_alignment_cost = 2.0 * cautious_action_penalty # cautious_action_penalty is a scalar cost

            total_cost = goal_proximity_cost + conceptual_fuel_cost + conceptual_weather_cost + ethical_cost + salience_alignment_cost
            simulated_trajectories_cost.append(total_cost.item()) # Append scalar cost

        # Pick the lowest-cost candidate BEFORE freeing the candidate tensor, so the chosen
        # action is the one that was actually scored above. Regenerating random candidates
        # after deletion would select an unrelated action at the same index.
        best_candidate_idx = int(torch.argmin(torch.tensor(simulated_trajectories_cost)))
        optimal_action_for_step = candidate_actions[best_candidate_idx].clone()

        # Clear intermediate tensors now that the best action has been copied out
        del simulated_next_latents, candidate_actions
        torch.cuda.empty_cache()
        gc.collect()

        best_action_sequence.append(optimal_action_for_step.squeeze().cpu().numpy())

        with torch.no_grad():
            # Update current_latent using the chosen optimal action
            optimal_action_input = optimal_action_for_step.unsqueeze(0)
            # FIX: Predict the next latent state using the correct two-argument call
            current_latent = predictor_model(current_latent, optimal_action_input)

        # Clear tensors used in this step
        del optimal_action_for_step, optimal_action_input
        torch.cuda.empty_cache()
        gc.collect()


    print(f"Real Plan for {planning_horizon} steps (first 5 actions shown):")
    for i, action in enumerate(best_action_sequence[:5]):
        print(f"Step {i+1}: {np.round(action, 4)}")
    return best_action_sequence
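
The selection step inside the planning loop above is a form of random-shooting planning: sample candidate actions, score each predicted next state, and apply the lowest-cost action. A minimal, framework-free sketch of that step follows. `toy_predictor`, `step_cost`, and `random_shooting_step` are hypothetical names introduced for illustration, the linear dynamics stand in for the learned predictor, and the ethical and weather terms are left out for brevity. The point it illustrates is that the applied action must come from the same candidate list that was scored.

```python
import math
import random


def toy_predictor(state, action):
    """Hypothetical stand-in for the learned predictor: next = state + 0.1 * action."""
    return [s + 0.1 * a for s, a in zip(state, action)]


def l2_norm(vec):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(v * v for v in vec))


def step_cost(next_state, action, target, salience,
              fuel_weight=0.05, salience_weight=2.0):
    """Goal distance plus two action-magnitude penalties (fuel and salience)."""
    goal_cost = l2_norm([t - n for t, n in zip(target, next_state)])
    action_norm = l2_norm(action)
    return (goal_cost
            + fuel_weight * action_norm
            + salience_weight * salience * action_norm)


def random_shooting_step(state, target, num_samples, salience, rng):
    """Sample candidates, score each once, return (best_action, next_state, best_cost)."""
    dim = len(state)  # assumption for this toy: action and state share a dimension
    candidates = [[rng.uniform(-1.0, 1.0) for _ in range(dim)]
                  for _ in range(num_samples)]
    costs = [step_cost(toy_predictor(state, a), a, target, salience)
             for a in candidates]
    # Index into the SAME candidate list that produced the costs
    best_idx = min(range(num_samples), key=costs.__getitem__)
    best_action = candidates[best_idx]
    return best_action, toy_predictor(state, best_action), costs[best_idx]


# Usage: roll the planner forward a few steps from the origin toward a target
rng = random.Random(0)
state, target = [0.0, 0.0], [1.0, -1.0]
plan = []
for _ in range(5):
    action, state, cost = random_shooting_step(state, target, 50, 0.1, rng)
    plan.append(action)
```

Keeping `candidates` alive until `best_idx` has been used is the design choice that matters here; memory can be released afterwards without changing which action is applied.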

# Run in Cell 4
# Ensure predictor and latent_projector are defined before calling this function.
# They are likely defined in Cell 1 and potentially trained in Cell 3.
# Assuming 'predictor' and 'latent_projector' are available in the global scope from previous cells.
try:
    # Pass a smaller number of action samples to reduce memory
    conceptual_flight_plan_actions = plan_montreal_to_paris_flight(
        AIRPORTS["CYUL"], AIRPORTS["LFPG"], AIRCRAFT_PERFORMANCE["Boeing777_300ER"],
        model, processor, predictor, latent_projector, # Pass the loaded/defined predictor and latent_projector
        num_action_samples=50 # Reduced number of action samples
    )
except NameError as ne:
    print(f"Error: {ne}. Make sure 'predictor' and 'latent_projector' are defined by running previous cells.")
except Exception as e:
    print(f"An error occurred during flight planning: {e}")


AIRPORTS: {'CYUL': {'lat': 45.4706, 'lon': -73.7408, 'name': 'Montreal-Trudeau International'}, 'LFPG': {'lat': 49.0128, 'lon': 2.55, 'name': 'Paris-Charles de Gaulle'}}
AIRCRAFT_PERFORMANCE: {'Boeing777_300ER': {'max_speed_kts': 490.0, 'cruise_altitude_ft': 37000.0, 'range_nm': 7370.0}}
model: CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affin