<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/V_JEPA_DEMO_JUN2025.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Dataset

https://paperswithcode.com/paper/tartanaviation-image-speech-and-ads-b

In [None]:
!pip install colab-env -q
import colab_env

In [13]:
%cd /content/gdrive/MyDrive/datasets/
!git clone https://github.com/castacks/TartanAviation.git

/content/gdrive/MyDrive/datasets
Cloning into 'TartanAviation'...
remote: Enumerating objects: 180, done.
remote: Counting objects: 100% (180/180), done.
remote: Compressing objects: 100% (141/141), done.
remote: Total 180 (delta 97), reused 97 (delta 33), pack-reused 0 (from 0)
Receiving objects: 100% (180/180), 4.04 MiB | 14.17 MiB/s, done.
Resolving deltas: 100% (97/97), done.


In [14]:
!ls -l /content/gdrive/MyDrive/datasets/TartanAviation/vision/

total 67
-rw------- 1 root root 11153 Jul 23 12:43 dataloader.py
-rw------- 1 root root  7364 Jul 23 12:43 download.py
-rw------- 1 root root  6119 Jul 23 12:43 progress.py
-rw------- 1 root root  3436 Jul 23 12:43 README.md
drwx------ 2 root root  4096 Jul 23 12:43 recording
-rw------- 1 root root 35369 Jul 23 12:43 weather_stats.csv


## Part 0

In [32]:
import os
import sys
import datetime
import pytz
import shutil # For removing directories

# Ensure you are in the correct directory to run the download script
%cd /content/TartanAviation/vision/
print(f"Current working directory set to: {os.getcwd()}")


# 1. Install MinIO library
print("\nInstalling minio...")
!pip install minio


# Define your final desired save directory IN GOOGLE DRIVE
# This is where the *recording folders* (e.g., '1_2023-02-22-15-21-49') will be placed directly.
DRIVE_RECORDINGS_ROOT = "/content/gdrive/MyDrive/datasets/TartanAviation/vision/downloaded_recordings/"

# Create the final destination directory in Drive if it doesn't exist
os.makedirs(DRIVE_RECORDINGS_ROOT, exist_ok=True)
print(f"\nGoogle Drive destination for raw recordings: {DRIVE_RECORDINGS_ROOT}")


# 2. Run the download script to get a sample of the data
print(f"\nAttempting to download TartanAviation sample data (approx. 1.9 GB) directly to DRIVE {DRIVE_RECORDINGS_ROOT}...")

# Optional: Remove previous download if it exists to force a fresh download
sample_recording_folder_name = "1_2023-02-22-15-21-49" # Name of the sample recording folder
full_sample_recording_path = os.path.join(DRIVE_RECORDINGS_ROOT, sample_recording_folder_name)
if os.path.exists(full_sample_recording_path):
    print(f"Found existing sample recording at {full_sample_recording_path}. Removing to force fresh download.")
    shutil.rmtree(full_sample_recording_path)
    print("Removed old sample recording.")


# --- DIRECT EXECUTION OF download.py ---
print("\n--- Executing download.py (Output will stream directly below) ---")
print("Waiting for download.py to complete. This may take a while.")
print("Look for progress/error messages from download.py itself.")

# Execute the command directly, letting its output stream to the console
download_command = f"python3 download.py --option Sample --save_dir \"{DRIVE_RECORDINGS_ROOT}\""
!{download_command}

print("\n--- download.py execution finished ---")


# --- DIAGNOSTIC STEPS POST-DOWNLOAD ATTEMPT ---
print(f"\n--- Post-Download Attempt Verification ({datetime.datetime.now(pytz.timezone('EST')).strftime('%Y-%m-%d %H:%M:%S EST')}) ---")

# Verify the data in the final Google Drive path
ACTUAL_DOWNLOADED_DATA_PATH_IN_DRIVE = os.path.join(DRIVE_RECORDINGS_ROOT, sample_recording_folder_name)

if os.path.exists(ACTUAL_DOWNLOADED_DATA_PATH_IN_DRIVE):
    if len(os.listdir(ACTUAL_DOWNLOADED_DATA_PATH_IN_DRIVE)) > 0:
        print(f"SUCCESS: Data directory found in DRIVE at: {ACTUAL_DOWNLOADED_DATA_PATH_IN_DRIVE}")
        print("Contents of the sample recording folder:")
        !ls -l "{ACTUAL_DOWNLOADED_DATA_PATH_IN_DRIVE}"
        print("\nAttempting to list the contents of its first-level entries (if available):")
        !ls -l "{ACTUAL_DOWNLOADED_DATA_PATH_IN_DRIVE}"/* | head -n 5
    else:
        print(f"FAILURE: Data directory in DRIVE at: {ACTUAL_DOWNLOADED_DATA_PATH_IN_DRIVE} is empty.")
        print("This indicates download.py ran but did not place any data there.")
else:
    print(f"FAILURE: Data directory NOT found in DRIVE at: {ACTUAL_DOWNLOADED_DATA_PATH_IN_DRIVE}.")
    print("This indicates download.py failed to download or place data.")

# Change back to /content/ (comment kept on its own line, since %cd may treat trailing text as part of the path)
%cd /content/
print(f"\nPart 0 execution complete at {datetime.datetime.now(pytz.timezone('EST')).strftime('%Y-%m-%d %H:%M:%S EST')}.")

/content/TartanAviation/vision
Current working directory set to: /content/TartanAviation/vision

Installing minio...

Google Drive destination for raw recordings: /content/gdrive/MyDrive/datasets/TartanAviation/vision/downloaded_recordings/

Attempting to download TartanAviation sample data (approx. 1.9 GB) directly to DRIVE /content/gdrive/MyDrive/datasets/TartanAviation/vision/downloaded_recordings/...

--- Executing download.py (Output will stream directly below) ---
Waiting for download.py to complete. This may take a while.
Look for progress/error messages from download.py itself.
Selected Option: Sample
Number of Video Folders: 1
1_2023-02-22-15-21-49.zip: |####################| 1945.34 MB/1945.34 MB 100% [elapsed: 29:17 left: 00:00,  1.11 MB/sec]Archive:  /content/gdrive/MyDrive/datasets/TartanAviation/vision/downloaded_recordings/1_2023-02-22-15-21-49.zip
   creating: /content/gdrive/MyDrive/datasets/TartanAviation/vision/downloaded_recordings/1_2023-02-22-15-21-49/
 extracting
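
The post-download verification in Part 0 distinguishes three states: the target directory is missing, it exists but is empty, or it contains data. A minimal sketch of that check as a reusable function follows. The name `check_download_dir` and its return strings are hypothetical and are not part of TartanAviation's `download.py`; the example runs against a temporary directory rather than Google Drive.

```python
import os
import tempfile


def check_download_dir(path):
    """Classify a download target as 'missing', 'empty', or 'ok' (hypothetical helper)."""
    if not os.path.isdir(path):
        return "missing"
    # os.scandir is lazy, so any() stops at the first entry it finds.
    with os.scandir(path) as entries:
        return "ok" if any(entries) else "empty"


# Exercise the helper on a throwaway directory so nothing on Drive is touched.
with tempfile.TemporaryDirectory() as root:
    missing_status = check_download_dir(os.path.join(root, "1_2023-02-22-15-21-49"))
    empty_status = check_download_dir(root)
    open(os.path.join(root, "placeholder.txt"), "w").close()
    ok_status = check_download_dir(root)

print(missing_status, empty_status, ok_status)
```

Returning a status string instead of printing keeps the check usable from later cells, for example to skip feature extraction when the download did not complete.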

## Part 1 (Cell 1): Video Loading and V-JEPA Feature Extraction

In [None]:
import torch
import av
import numpy as np
import os
from transformers import AutoVideoProcessor, AutoModel
from tqdm.notebook import tqdm # For progress bars in notebooks
import logging
import json # For saving feature map
import sys # For exiting on critical errors

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- GLOBAL CONFIGURATION PARAMETERS ---
# V-JEPA Model and Processor
HF_REPO = "facebook/vjepa2-vitg-fpc64-256"
NUM_FRAMES_TO_SAMPLE = 16 # Number of frames to sample for V-JEPA

# --- PATHS FOR DATA PROCESSING ---
# Define the path to the single video you want to classify in Part 3 (e.g., your original demo video)
# MAKE SURE THIS VIDEO IS UPLOADED TO /content/ or is accessible here.
DEMO_VIDEO_PATH = '/content/airplane-landing.mp4'

# Define the root directory where your raw TartanAviation videos are stored in Drive (after download.py)
# This path now points directly to the folder containing recording IDs (e.g., '1_2023-02-22-15-21-49')
DATASET_RAW_VIDEO_ROOT = "/content/gdrive/MyDrive/datasets/TartanAviation/vision/downloaded_recordings/"

# Define the directory where extracted V-JEPA features (from TartanAviation) will be saved.
EXTRACTED_FEATURES_SAVE_DIR = "/content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/"
os.makedirs(EXTRACTED_FEATURES_SAVE_DIR, exist_ok=True) # Ensure this directory exists

# Define the class labels (these need to be derived from TartanAviation's actual labels)
CLASS_LABELS = [
    "airplane landing",
    "airplane takeoff",
    "airport ground operations",
    "in-flight cruise",
    "emergency landing",
    "pre-flight check/maintenance"
]

# --- V-JEPA Model/Processor Loading ---
logging.info(f"Loading V-JEPA model and processor from {HF_REPO}...")
try:
    model = AutoModel.from_pretrained(HF_REPO)
    processor = AutoVideoProcessor.from_pretrained(HF_REPO)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    logging.info(f"V-JEPA Model and processor loaded on {device}.")
except Exception as e:
    logging.error(f"Failed to load V-JEPA model/processor: {e}")
    model, processor, device = None, None, "cpu"
    sys.exit("Exiting: V-JEPA model or processor failed to load.")


# --- 1. Extract Features for the SPECIFIC DEMO VIDEO (for Part 3 Inference) ---
video_features_for_inference = None # This will hold the features for the single demo video

logging.info(f"\n--- Extracting features for the specific DEMO VIDEO: {DEMO_VIDEO_PATH} ---")
if not os.path.exists(DEMO_VIDEO_PATH):
    logging.error(f"Error: Demo video file '{DEMO_VIDEO_PATH}' not found.")
    logging.warning("Proceeding with dummy video creation for DEMO INFERENCE purposes.")
    height, width = 256, 256
    dummy_video_tensor = torch.randint(0, 256, (NUM_FRAMES_TO_SAMPLE, 3, height, width), dtype=torch.uint8)  # uint8 pixels, like decoded RGB frames
    frames_demo = [frame.permute(1, 2, 0).numpy() for frame in dummy_video_tensor]
    logging.info(f"Created dummy video frames for demo inference.")
else:
    frames_demo = []
    try:
        container_demo = av.open(DEMO_VIDEO_PATH)
        total_frames_in_video_demo = container_demo.streams.video[0].frames
        sampling_interval_demo = max(1, total_frames_in_video_demo // NUM_FRAMES_TO_SAMPLE)
        for i, frame in enumerate(tqdm(container_demo.decode(video=0), total=total_frames_in_video_demo, leave=False, desc=f"Sampling Demo Video: {DEMO_VIDEO_PATH}")):
            if len(frames_demo) >= NUM_FRAMES_TO_SAMPLE:
                break
            if i % sampling_interval_demo == 0:
                img = frame.to_rgb().to_ndarray()
                frames_demo.append(img)
        container_demo.close()
        logging.info(f"Successfully sampled {len(frames_demo)} frames from demo video.")
    except Exception as e:
        logging.error(f"Error loading demo video '{DEMO_VIDEO_PATH}': {e}")
        logging.warning("Proceeding with dummy video creation for DEMO INFERENCE purposes.")
        height, width = 256, 256
        dummy_video_tensor = torch.randint(0, 256, (NUM_FRAMES_TO_SAMPLE, 3, height, width), dtype=torch.uint8)  # uint8 pixels, like decoded RGB frames
        frames_demo = [frame.permute(1, 2, 0).numpy() for frame in dummy_video_tensor]

if frames_demo:
    try:
        inputs_demo = processor(videos=list(frames_demo), return_tensors="pt").to(device)
        with torch.no_grad():
            outputs_demo = model(**inputs_demo)
            video_features_for_inference = outputs_demo.last_hidden_state
        logging.info(f"Extracted features for demo video. Shape: {video_features_for_inference.shape}")
    except Exception as e:
        logging.error(f"Error extracting features for demo video: {e}")
        video_features_for_inference = None
else:
    logging.error("Failed to get frames for demo video. video_features_for_inference will be None.")


# --- 2. Dataset Processing to Extract V-JEPA Features (for Classifier TRAINING) ---

def get_label_from_video_metadata(recording_dir, class_labels):
    """
    Refined conceptual function for TartanAviation Sample:
    The sample recording is '1_2023-02-22-15-21-49'.
    The README and dataset structure point to '_sink' files for this sample,
    which typically indicates a landing or sink rate.
    This function attempts to infer based on these cues.
    For other full dataset videos, you would parse _labels.zip or dataloader.py.
    """
    dir_name = os.path.basename(recording_dir).lower()

    # Specific to the TartanAviation sample ID if it's '1_2023-02-22-15-21-49'
    if "1_2023-02-22-15-21-49" in dir_name:
        # This specific sample is confirmed to be a landing event.
        # It has "_sink" in associated files and is mentioned as a "sink" example in docs.
        return "airplane landing"

    # Generic (less reliable without actual labels) for other samples if they exist
    elif "sink" in dir_name or "landing" in dir_name:
        return "airplane landing"
    elif "takeoff" in dir_name:
        return "airplane takeoff"

    # Fallback if no specific keyword matches (most likely for a random sample)
    return "airport ground operations"


processed_count = 0
skipped_count = 0
feature_label_map = []
extracted_embedding_dim = -1 # To get actual embedding dim from first feature

print(f"\n--- Starting to extract V-JEPA features for TRAINING from: {DATASET_RAW_VIDEO_ROOT} ---")
if not os.path.exists(DATASET_RAW_VIDEO_ROOT):
    logging.error(f"ERROR: Raw video root directory for training data not found: {DATASET_RAW_VIDEO_ROOT}")
    logging.error("Skipping training data feature extraction.")
else:
    for recording_id_dir in tqdm(os.listdir(DATASET_RAW_VIDEO_ROOT), desc="Processing Recordings for Training Data"):
        recording_path = os.path.join(DATASET_RAW_VIDEO_ROOT, recording_id_dir)
        if os.path.isdir(recording_path):
            mp4_files = [f for f in os.listdir(recording_path) if f.endswith(".mp4")]
            if not mp4_files:
                logging.warning(f"No MP4 found in {recording_path}. Skipping for training data.")
                skipped_count += 1
                continue

            video_file = mp4_files[0]
            video_path_dataset = os.path.join(recording_path, video_file)

            label_str = get_label_from_video_metadata(recording_path, CLASS_LABELS)
            if label_str not in CLASS_LABELS:
                logging.warning(f"Skipping {video_path_dataset}: Cannot derive valid label for '{label_str}'.")
                skipped_count += 1
                continue
            label_idx = CLASS_LABELS.index(label_str)

            frames_dataset = []
            try:
                container_dataset = av.open(video_path_dataset)
                total_frames_in_video_dataset = container_dataset.streams.video[0].frames
                sampling_interval_dataset = max(1, total_frames_in_video_dataset // NUM_FRAMES_TO_SAMPLE)

                for i, frame in enumerate(tqdm(container_dataset.decode(video=0), total=total_frames_in_video_dataset, leave=False, desc=f"Sampling {video_file}")):
                    if len(frames_dataset) >= NUM_FRAMES_TO_SAMPLE:
                        break
                    if i % sampling_interval_dataset == 0:
                        img = frame.to_rgb().to_ndarray()
                        frames_dataset.append(img)
                container_dataset.close()

                if not frames_dataset:
                    logging.warning(f"No frames loaded from {video_file}. Skipping feature extraction for training data.")
                    skipped_count += 1
                    continue

                inputs_dataset = processor(videos=list(frames_dataset), return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs_dataset = model(**inputs_dataset)
                    video_features_raw_dataset = outputs_dataset.last_hidden_state

                pooled_feature_dataset = video_features_raw_dataset.squeeze(0).mean(dim=0).unsqueeze(0)
                if extracted_embedding_dim == -1: # Set it from the first successfully processed feature
                    extracted_embedding_dim = pooled_feature_dataset.shape[1]

                feature_filename = f"{os.path.splitext(video_file)[0]}_feature.pt"
                feature_save_path = os.path.join(EXTRACTED_FEATURES_SAVE_DIR, feature_filename)
                torch.save(pooled_feature_dataset.cpu(), feature_save_path)
                feature_label_map.append({
                    "feature_path": feature_save_path,
                    "label_idx": label_idx,
                    "label_name": label_str
                })
                processed_count += 1

            except Exception as e:
                logging.error(f"Error processing {video_file} for training data: {e}")
                skipped_count += 1

    map_file_path = os.path.join(EXTRACTED_FEATURES_SAVE_DIR, "feature_label_map.json")
    with open(map_file_path, 'w') as f:
        json.dump(feature_label_map, f, indent=4)
    print(f"\n--- Dataset Feature Extraction for Training Complete ---")
    print(f"Total videos processed for training data: {processed_count}")
    print(f"Total videos skipped for training data: {skipped_count}")
    print(f"Feature-label map saved to: {map_file_path}")
    print(f"Extracted features saved to: {EXTRACTED_FEATURES_SAVE_DIR}")
    print(f"Observed V-JEPA embedding dimension: {extracted_embedding_dim}")
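
The sampling loops in Part 1 keep every `total_frames // NUM_FRAMES_TO_SAMPLE`-th frame and stop once 16 frames are collected. When the frame count is not a multiple of 16, the kept frames can cluster at the start of the clip: a 31-frame video gives an interval of `max(1, 31 // 16) = 1`, so only frames 0 to 15 are used. One possible alternative, sketched here under the assumption that the container reports a reliable frame count, is to precompute evenly spaced indices. The function name `evenly_spaced_indices` is hypothetical.

```python
def evenly_spaced_indices(total_frames, num_samples):
    """Return up to num_samples frame indices spread across [0, total_frames)."""
    if total_frames <= 0 or num_samples <= 0:
        return []
    if total_frames <= num_samples:
        # Not enough frames to subsample: keep them all.
        return list(range(total_frames))
    # Split the clip into num_samples equal segments and take each midpoint.
    step = total_frames / num_samples
    return [int(step * i + step / 2) for i in range(num_samples)]


indices = evenly_spaced_indices(31, 16)
print(indices)
```

In a decode loop like the ones above, a frame would then be kept when its position `i` is in `set(indices)`, instead of testing `i % sampling_interval == 0`.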

## Part 2 (Cell 2): Classifier Training

In [None]:
import torch
import numpy as np
import logging
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm.notebook import tqdm
import json # For loading the feature_label_map.json
import os # For path joining

# Configure logging for this cell
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Configuration for Real Data Loading ---
EXTRACTED_FEATURES_DIR = "/content/gdrive/MyDrive/datasets/TartanAviation_VJEPA_Features/" # Must match save dir in Part 1
CLASSIFIER_SAVE_PATH = "classifier_head_trained_on_tartan_aviation_sample.pth" # Specific name for real data

# --- Check for required variables from Part 1 ---
if 'video_features_for_inference' not in locals() or video_features_for_inference is None:
    logging.error("ERROR: 'video_features_for_inference' from Part 1 not found or is None. Please run Part 1 first.")
    raise SystemExit("Part 1 has not been run successfully.")  # stops this cell; exit() can shut down the notebook kernel

if 'extracted_embedding_dim' not in locals() or extracted_embedding_dim == -1:
    logging.error("ERROR: 'extracted_embedding_dim' not found or not set by Part 1. Please ensure Part 1 ran successfully.")
    raise SystemExit("Part 1 did not set 'extracted_embedding_dim'.")
embedding_dim = extracted_embedding_dim # Use the dynamically determined embedding dim

if 'device' not in locals(): # device should be set by Part 1
    logging.warning("WARNING: 'device' variable not found from Part 1. Defaulting to 'cpu'.")
    device = "cpu"

# --- Classifier Training ---
print("\n--- Starting Classifier Training ---")
print(f"Attempting to load real V-JEPA features from: {EXTRACTED_FEATURES_DIR}")

# Define class labels (MUST MATCH LABELS USED DURING FEATURE EXTRACTION IN Part 1)
CLASS_LABELS = [ # Using CLASS_LABELS to be consistent with Part 1
    "airplane landing",
    "airplane takeoff",
    "airport ground operations",
    "in-flight cruise",
    "emergency landing",
    "pre-flight check/maintenance"
]
num_classes = len(CLASS_LABELS)

try:
    # Define a simple classification head
    class ClassifierHead(nn.Module):
        def __init__(self, input_dim, num_classes):
            super().__init__()
            self.fc1 = nn.Linear(input_dim, 256)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(0.3)
            self.fc2 = nn.Linear(256, num_classes)

        def forward(self, x):
            return self.fc2(self.dropout(self.relu(self.fc1(x))))

    classifier = ClassifierHead(input_dim=embedding_dim, num_classes=num_classes).to(device)

    # --- CORE TRAINING LOGIC ---
    print("\n--- Initiating Classifier Training ---")

    train_features_list = []
    train_labels_list = []

    map_file_path = os.path.join(EXTRACTED_FEATURES_DIR, "feature_label_map.json")

    if os.path.exists(EXTRACTED_FEATURES_DIR) and os.path.exists(map_file_path):
        with open(map_file_path, 'r') as f:
            feature_label_map = json.load(f)

        if not feature_label_map:
            logging.error(f"Feature-label map at {map_file_path} is empty. No real data to load for training.")
            logging.info("Classifier will be trained on SYNTHETIC data as a fallback.")
            num_training_samples = 2_000_000 # Fallback synthetic
            train_features = torch.rand(num_training_samples, embedding_dim).to(device)
            train_labels = torch.randint(0, num_classes, (num_training_samples,)).to(device)
            train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=128, shuffle=True)
            val_loader = None
            print(f"Loaded {num_training_samples} SYNTHETIC features for training (due to no real data).")

        else:
            # Load real features
            for item in tqdm(feature_label_map, desc="Loading real V-JEPA features"):
                feature_path = item['feature_path']
                label_idx = item['label_idx']
                try:
                    if not os.path.isabs(feature_path):
                        feature_path = os.path.join(EXTRACTED_FEATURES_DIR, os.path.basename(feature_path))

                    feature = torch.load(feature_path, map_location=device) # Load directly to device
                    if feature.ndim > 1 and feature.shape[0] == 1:
                        train_features_list.append(feature.squeeze(0))
                    elif feature.ndim == 1:
                        train_features_list.append(feature)
                    else:
                        logging.warning(f"Skipping malformed feature at {feature_path}. Expected 1D or [1, D], got {feature.shape}")
                        continue
                    train_labels_list.append(label_idx)
                except Exception as e:
                    logging.error(f"Error loading feature from {feature_path}: {e}")

            if train_features_list:
                train_features = torch.stack(train_features_list)
                train_labels = torch.tensor(train_labels_list)
                num_training_samples = len(train_features)
                print(f"Loaded {num_training_samples} REAL V-JEPA features for training.")

                if num_training_samples < 2: # Check if less than 2 samples
                    print("WARNING: Only 1 real V-JEPA feature loaded. Training will be performed on this single sample (no train/val split).")
                    train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=1, shuffle=False) # Batch size 1, no shuffle
                    val_loader = None # No validation set
                else:
                    # Original split logic for 2+ samples
                    dataset_size = len(train_features)
                    train_size = int(0.8 * dataset_size)
                    val_size = dataset_size - train_size
                    train_dataset_real, val_dataset_real = torch.utils.data.random_split(TensorDataset(train_features, train_labels), [train_size, val_size])

                    train_loader = DataLoader(train_dataset_real, batch_size=32, shuffle=True) # Smaller batch size for real data
                    val_loader = DataLoader(val_dataset_real, batch_size=32, shuffle=False)

                    print(f"Training on {len(train_dataset_real)} samples, Validating on {len(val_dataset_real)} samples.")
            else:
                logging.error("No real V-JEPA features could be loaded from map. Falling back to synthetic data.")
                num_training_samples = 2_000_000
                train_features = torch.rand(num_training_samples, embedding_dim).to(device)
                train_labels = torch.randint(0, num_classes, (num_training_samples,)).to(device)
                train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=128, shuffle=True)
                val_loader = None
                print(f"Loaded {num_training_samples} SYNTHETIC features for training (due to real data load failure).")

    else:
        logging.error(f"Extracted features directory '{EXTRACTED_FEATURES_DIR}' or map file '{map_file_path}' not found.")
        logging.info("Classifier will be trained on SYNTHETIC data as a fallback.")
        num_training_samples = 2_000_000
        train_features = torch.rand(num_training_samples, embedding_dim).to(device)
        train_labels = torch.randint(0, num_classes, (num_training_samples,)).to(device)
        train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=128, shuffle=True)
        val_loader = None
        print(f"Loaded {num_training_samples} SYNTHETIC features for training (no real data found).")


    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=0.001)

    num_epochs = 20
    for epoch in range(num_epochs):
        classifier.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training"):
            # Move the batch to the same device as the classifier
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = classifier(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)  # loss is a batch mean; weight it by batch size
        epoch_loss = running_loss / len(train_loader.dataset)

        val_loss = 0.0
        if val_loader and len(val_loader.dataset) > 0: # Only run validation if val_loader exists and is not empty
            classifier.eval()
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = classifier(inputs)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item() * inputs.size(0)  # loss is a batch mean; weight it by batch size
            val_loss /= len(val_loader.dataset)
            logging.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")
        else:
            logging.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f} (No validation data)")

    print("--- Classifier Training Complete ---")

    torch.save(classifier.state_dict(), CLASSIFIER_SAVE_PATH)
    print(f"Classifier saved to: {CLASSIFIER_SAVE_PATH}")

except Exception as e:
    logging.error(f"Error during classifier training: {e}")
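
`nn.CrossEntropyLoss` uses mean reduction by default, so each `loss.item()` in a training loop is already a per-sample average over its batch. A minimal sketch of how per-batch means could be combined into an epoch-level per-sample mean follows, weighting each batch by its size so that a smaller final batch does not skew the result. The helper name `epoch_mean_loss` is hypothetical, and the numbers are made up for illustration only.

```python
def epoch_mean_loss(batch_means, batch_sizes):
    """Combine per-batch mean losses into a per-sample mean for the whole epoch."""
    if len(batch_means) != len(batch_sizes):
        raise ValueError("batch_means and batch_sizes must have the same length")
    total_samples = sum(batch_sizes)
    if total_samples == 0:
        return 0.0
    # Undo each batch's averaging, then average over all samples.
    weighted_sum = sum(mean * size for mean, size in zip(batch_means, batch_sizes))
    return weighted_sum / total_samples


# Illustrative values only: two full batches of 32 and a final batch of 8.
batch_means = [0.5, 0.25, 1.0]
batch_sizes = [32, 32, 8]
per_sample = epoch_mean_loss(batch_means, batch_sizes)
print(f"per-sample mean loss: {per_sample:.4f}")
```

In a PyTorch loop the same quantity is obtained by accumulating `loss.item() * inputs.size(0)` per batch and dividing the total by `len(loader.dataset)`.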

## Part 3 (Cell 3): Classification Inference and Gemini LLM Interaction

In [38]:
import os  # needed for os.path.exists when loading the classifier weights
import torch
import numpy as np
import logging
import torch.nn as nn
import datetime # Import datetime
import pytz # Import pytz for timezone

# Import Google Generative AI components
from google.colab import userdata
import google.generativeai as genai

# Configure logging for this cell
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- API Key Setup ---
GOOGLE_API_KEY = userdata.get('GEMINI')
if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)
    print("Google Generative AI configured successfully using Colab Secrets.")
else:
    print("WARNING: GOOGLE_API_KEY not found in Colab Secrets. Please ensure 'GEMINI' secret is set.")
    print("API calls will likely fail. Proceeding with unconfigured API.")

# --- Agent Configuration ---
class AgentConfig:
    LLM_MODEL_NAME: str = "gemini-2.5-flash"

# --- Check for required variables from Part 1 ---
if 'video_features_for_inference' not in locals() or video_features_for_inference is None:
    logging.error("ERROR: 'video_features_for_inference' not found or is None. Please run Part 1 (Video Loading and Feature Extraction) first and ensure it completes successfully.")
    raise SystemExit("Part 1 has not been run successfully.")  # stops this cell; exit() can shut down the notebook kernel

if 'device' not in locals():
    logging.warning("WARNING: 'device' variable not found from Part 1. Defaulting to 'cpu'.")
    device = "cpu"

# Define class labels (MUST match the order used during training in Part 2)
CLASS_LABELS = [
    "airplane landing",
    "airplane takeoff",
    "airport ground operations",
    "in-flight cruise",
    "emergency landing",
    "pre-flight check/maintenance"
]
num_classes = len(CLASS_LABELS)
CLASSIFIER_SAVE_PATH = "classifier_head_trained_on_tartan_aviation_sample.pth" # Must match save path in Part 2

# --- Classification Inference and Gemini LLM Interaction ---
print("\n--- Starting V-JEPA Feature-Driven Classification Inference and Gemini LLM Interaction ---")

try:
    # Determine the feature dimension (embedding_dim) from V-JEPA output
    embedding_dim = video_features_for_inference.shape[2]
    # Prepare the single video's feature for inference
    pooled_features_for_inference = video_features_for_inference.squeeze(0).mean(dim=0).unsqueeze(0)
    pooled_features_for_inference = pooled_features_for_inference.to(device)

    # Define the ClassifierHead architecture (must be identical to Part 2)
    class ClassifierHead(nn.Module):
        def __init__(self, input_dim, num_classes):
            super().__init__()
            self.fc1 = nn.Linear(input_dim, 256)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(0.3)
            self.fc2 = nn.Linear(256, num_classes)

        def forward(self, x):
            return self.fc2(self.dropout(self.relu(self.fc1(x))))

    classifier = ClassifierHead(input_dim=embedding_dim, num_classes=num_classes).to(device)

    # Load the trained classifier's weights
    if os.path.exists(CLASSIFIER_SAVE_PATH):
        classifier.load_state_dict(torch.load(CLASSIFIER_SAVE_PATH, map_location=device))
        logging.info(f"Classifier weights loaded from: {CLASSIFIER_SAVE_PATH}")
    else:
        logging.error(f"ERROR: Classifier weights file '{CLASSIFIER_SAVE_PATH}' not found.")
        logging.error("Please ensure Part 2 (Classifier Training) was run successfully and saved the model.")
        raise SystemExit(f"Classifier weights not found at '{CLASSIFIER_SAVE_PATH}'.")

    # --- Making a Prediction Using the TRAINED Classifier ---
    classifier.eval() # Set model to evaluation mode for inference
    with torch.no_grad():
        logits = classifier(pooled_features_for_inference)
        probabilities = torch.softmax(logits, dim=1)
        predicted_class_idx = torch.argmax(probabilities, dim=1).item()
        predicted_confidence = probabilities[0, predicted_class_idx].item()
        predicted_label = CLASS_LABELS[predicted_class_idx]

    # --- Prepare Textual Input for Gemini LLM based on Prediction ---
    llm_input_description = ""
    if predicted_label == "airplane landing":
        llm_input_description = "The visual system detected an airplane landing. "
    elif predicted_label == "airplane takeoff":
        llm_input_description = "The visual system detected an airplane takeoff. "
    elif predicted_label == "airport ground operations":
        llm_input_description = "The visual system detected airport ground operations. "
    elif predicted_label == "in-flight cruise":
        llm_input_description = "The visual system detected an airplane in flight/cruise. "
    elif predicted_label == "emergency landing":
        llm_input_description = "The visual system detected a possible emergency landing scenario. "
    elif predicted_label == "pre-flight check/maintenance":
        llm_input_description = "The visual system detected pre-flight checks or maintenance activities. "
    else:
        llm_input_description = "The visual system detected an unrecognized or ambiguous aviation event. "

    llm_input_description += f"(Confidence: {predicted_confidence:.2f})"


    print(f"\n--- AI Agent's Understanding from Classifier ---")
    print(f"**Primary Classification (Predicted by AI):** '{predicted_label}' (Confidence: {predicted_confidence:.2f})")
    print(f"**Description for LLM:** {llm_input_description}")
    print(f"Note: This classification is from a model trained on synthetic data and is likely random until a real dataset is used for training.")

    # --- Gemini LLM Interaction ---
    print("\n--- Engaging Gemini LLM for Further Reasoning ---")
    try:
        llm_model = genai.GenerativeModel(AgentConfig.LLM_MODEL_NAME)

        prompt_for_gemini = f"""
        You are an AI assistant for flight planning operations.
        Current visual observation: {llm_input_description}
        Current time (EST): {datetime.datetime.now(pytz.timezone('EST')).strftime('%Y-%m-%d %H:%M:%S EST')}
        Location: Montreal, Quebec, Canada.

        Based on this visual observation, provide a concise operational assessment relevant for flight planning.
        If the observation seems random or uncertain, state that. Do not invent details not present in the observation.
        """

        gemini_response = llm_model.generate_content(prompt_for_gemini)

        print("\n--- Gemini LLM Response ---")
        if gemini_response.candidates:
            for candidate in gemini_response.candidates:
                if candidate.content and candidate.content.parts:
                    for part in candidate.content.parts:
                        print(part.text)
        else:
            print("Gemini LLM did not provide a text response or candidates.")
            if gemini_response.prompt_feedback:
                print(f"Prompt Feedback: {gemini_response.prompt_feedback}")
            if hasattr(gemini_response, 'error'):
                print(f"LLM Error: {gemini_response.error}")


    except Exception as llm_e:
        logging.error(f"Error interacting with Gemini LLM: {llm_e}")
        logging.error("Ensure your GOOGLE_API_KEY is correctly set in Colab Secrets and the model name is valid.")


    print(f"\nThis prediction comes from a classifier that was trained on synthetic examples.")
    print(f"Due to the synthetic nature of the training data, this prediction is effectively random.")
    print(f"For *truly accurate and high-confidence predictions* on real videos,")
    print(f"the classifier needs to be trained on a large, diverse dataset of *real V-JEPA features and their corresponding labels*.")
    print(f"The V-JEPA features (shape: {video_features_for_inference.shape}) are the core input that a trained classifier would learn from.")
    print(f"Current time in EST: {datetime.datetime.now(pytz.timezone('EST')).strftime('%Y-%m-%d %H:%M:%S EST')}")

except Exception as e:
    logging.error(f"Error during classification inference or overall process: {e}")

Google Generative AI configured successfully using Colab Secrets.

--- Starting V-JEPA Feature-Driven Classification Inference and Gemini LLM Interaction ---

--- AI Agent's Understanding from Classifier ---
**Primary Classification (Predicted by AI):** 'airplane landing' (Confidence: 1.00)
**Description for LLM:** The visual system detected an airplane landing. (Confidence: 1.00)
Note: This classification is from a model trained on synthetic data and is likely random until a real dataset is used for training.

--- Engaging Gemini LLM for Further Reasoning ---

--- Gemini LLM Response ---
**Operational Assessment:**

An airplane is currently landing at Montreal (likely CYUL). This indicates active runway occupancy and short-term traffic flow at the airport. Flight planners should note this for immediate inbound or outbound operations in the Montreal area.

This prediction comes from a classifier that was trained on synthetic examples.
Due to the synthetic nature of the training data, t
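
The confidence printed in Part 3 is the softmax probability of the highest-scoring class. It therefore measures how far apart the classifier's logits are, not how reliable the classifier is. The sketch below reproduces that step in plain Python so it can be inspected without PyTorch. The function names `softmax` and `top_prediction` and the example logits are hypothetical; only the label list is taken from the notebook.

```python
import math

# Same label order as the notebook's CLASS_LABELS.
CLASS_LABELS = [
    "airplane landing",
    "airplane takeoff",
    "airport ground operations",
    "in-flight cruise",
    "emergency landing",
    "pre-flight check/maintenance",
]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def top_prediction(logits, labels):
    """Return (label, probability) for the highest-scoring class."""
    probabilities = softmax(logits)
    best = max(range(len(probabilities)), key=probabilities.__getitem__)
    return labels[best], probabilities[best]


# Made-up logits in which the first class clearly dominates.
label, confidence = top_prediction([4.0, 0.5, 0.1, -1.0, -2.0, 0.0], CLASS_LABELS)
print(f"{label} (confidence: {confidence:.2f})")
```

As the gap between the largest logit and the rest grows, the reported confidence approaches 1.00 regardless of whether the prediction is correct.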