<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/SEQ2SEQ_fptransformer_model_GPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Environment

In [9]:
!nvidia-smi

Tue May  6 16:29:15 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   77C    P0             36W /   72W |     905MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
!pip install colab-env -q
import colab_env
!pip install transformers datasets torch -q
!pip install geopy -q

In [2]:
import colab_env
import os

access_token_write = os.getenv("HUGGINGFACE_ACCESS_TOKEN_WRITE")

from huggingface_hub import login

login(
  token=access_token_write, # ADD YOUR TOKEN HERE
  add_to_git_credential=True
)

The code covers the complete process of training, evaluation, and validation for a built-from-scratch Transformer model for waypoint coordinate prediction using a Sequence-to-Sequence (Seq2Seq) architecture.

Here's a breakdown of how the code addresses each stage:

1. Building the Model:

* It defines a custom Transformer model (Seq2SeqCoordsTransformer) with an encoder-decoder structure and attention mechanisms.
* It includes positional encoding to capture sequence order information.
* Two output heads are used: a linear projection followed by a sigmoid on the decoder output predicts normalized waypoint coordinates, and a classification head on the mean-pooled encoder output predicts the waypoint count.

2. Training:

* It uses a training loop to update the model's parameters using the training dataset.
* It employs the AdamW optimizer and a combined loss function (CombinedLossSeq2Seq) that adds a masked mean-squared-error term on the coordinates to a weighted cross-entropy term on the count.
* Data augmentation (small uniform noise added to the waypoint coordinates) is applied to the training set to improve the model's robustness.

3. Evaluation and Validation:

* The code splits the data 80/10/10 into training, validation, and test sets.
* After each training epoch, the model is evaluated on the validation set to monitor its performance on unseen data.
* Early stopping halts training once the validation loss has not improved for a set number of epochs, and the checkpoint with the lowest validation loss is kept.

4. Inference and Testing:

* After training, the best model is loaded and used for inference on the test set.
* The code calculates various evaluation metrics, including average coordinate loss, count loss, and average absolute count difference, to assess the model's accuracy and generalization ability.
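
To make the combined loss from step 2 concrete, here is a minimal pure-Python sketch for a single sample. It illustrates the idea (masked mean squared error on coordinates plus a weighted cross-entropy on the count) and is not the notebook's `CombinedLossSeq2Seq` class; the helper names and the toy numbers are assumptions chosen for illustration.

```python
import math

def masked_mse(pred, target, mask):
    """Mean squared error over coordinate values whose position mask is 1."""
    total, n = 0.0, 0.0
    for p, t, m in zip(pred, target, mask):
        for pv, tv in zip(p, t):
            total += m * (pv - tv) ** 2
            n += m  # each unmasked position contributes two values (lat, lon)
    return total / n if n > 0 else 0.0

def cross_entropy(logits, label):
    """Cross-entropy of one example: log-sum-exp(logits) - logits[label]."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[label]

def combined_loss(pred, target, mask, count_logits, count_label, count_loss_weight=0.5):
    """Return (total, coordinate term, count term) for one sample."""
    coord_loss = masked_mse(pred, target, mask)
    count_loss = cross_entropy(count_logits, count_label)
    return coord_loss + count_loss_weight * count_loss, coord_loss, count_loss

# Toy sample (hypothetical values): one real waypoint, then one padded position.
pred = [[0.6, 0.4], [0.9, 0.9]]
target = [[0.5, 0.5], [0.5, 0.5]]
mask = [1.0, 0.0]
total, coord, count = combined_loss(pred, target, mask, [0.0, 0.0, 0.0], 1)
print(f"coord={coord:.4f} count={count:.4f} total={total:.4f}")
```

Because the second position is masked out, its large prediction error does not affect the coordinate term. The weight decides how strongly the count term pulls on the shared encoder relative to the coordinate term.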

In summary, the code implements a Seq2Seq Transformer for flight plan waypoint prediction end to end: preprocessing, training with validation and early stopping, and testing of the best checkpoint.
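
The early-stopping rule used during validation can be sketched on its own. The `EarlyStopper` class below is a hypothetical helper, not part of the notebook; its defaults for `patience` and `min_delta` are taken from the notebook's configuration, and the loss values are made up.

```python
class EarlyStopper:
    """Tracks the best validation loss and signals when to stop."""

    def __init__(self, patience=10, min_delta=0.0001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.epochs_no_improve = 0

    def update(self, eval_loss):
        """Record one epoch's loss; return (is_new_best, should_stop)."""
        if eval_loss < self.best_loss - self.min_delta:
            self.best_loss = eval_loss
            self.epochs_no_improve = 0
            return True, False
        self.epochs_no_improve += 1
        return False, self.epochs_no_improve >= self.patience

# Made-up validation losses, with a short patience so the stop is visible.
stopper = EarlyStopper(patience=2)
history = []
for epoch, loss in enumerate([1.00, 0.90, 0.95, 0.93, 0.80], start=1):
    is_best, stop = stopper.update(loss)
    history.append((epoch, is_best, stop))
    if stop:
        break
print(history)
```

In a training loop, `is_new_best` would be the point at which the checkpoint is written, so the saved weights always correspond to the lowest validation loss seen so far.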

## Training - Seq2SeqCoordsTransformer

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer
# Ensure tqdm.notebook is used for interactive environments like Colab/Jupyter
try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm
import math
import numpy as np
import os
import json
import shutil
import random
import traceback # For printing full tracebacks on error

from warnings import filterwarnings
filterwarnings("ignore")


import torch.distributed as dist


# --- Configuration ---
# ... (other configuration settings) ...
world_size = torch.cuda.device_count()  # Get the number of available GPUs
rank = int(os.environ.get('RANK', 0)) # Process rank from the RANK environment variable (set by launchers such as torchrun); defaults to 0
print(f"World size: {world_size}, Rank: {rank}")
if world_size > 1:
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    device = torch.device(f"cuda:{rank}")
    print(f"Process {rank} using device: {device}")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using single device: {device}")


World size: 1, Rank: 0
Using single device: cuda


In [None]:
# ==============================================================================
# FINAL CODE (Seq2Seq Arch + Classification Count Head)
# Includes: Seq2Seq, LR=1e-5, Coord Norm+Sigmoid, Learned SOS, isclose Mask,
#           count_loss_weight=0.5, patience=10, Augmentation.
# WARNING: hf_repo_id points to frankmorales2020/FlightPlan_Transformer_LLM_1GPU_Colab.
# Loading from this ID later will likely FAIL due to incompatible architecture.
# ==============================================================================

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer
# Ensure tqdm.notebook is used for interactive environments like Colab/Jupyter
try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm
import math
import numpy as np
import os
import json
import shutil
import random
import traceback # For printing full tracebacks on error

from warnings import filterwarnings
filterwarnings("ignore")


import torch.distributed as dist

# --- Hugging Face Hub Integration ---
try:
    from huggingface_hub import HfApi, HfFolder, login, create_repo, upload_file, notebook_login, hf_hub_download
    access_token_write = os.getenv("HUGGINGFACE_ACCESS_TOKEN_WRITE")
    # Suppressing login attempt messages
except ImportError:
    print("Warning: huggingface_hub not found. Deployment/loading features unavailable.")
    HfApi = None; hf_hub_download = None

# --- Configuration ---
hf_repo_id = "frankmorales2020/FlightPlan_Transformer_LLM_1GPU_Colab"
tokenizer_name = "gpt2"
dataset_name = "frankmorales2020/flight_plan_waypoints"
# Model Hyperparameters
embedding_dimension = 256; nhead = 8; num_encoder_layers = 4; num_decoder_layers = 4
dim_feedforward = 1024; transformer_dropout = 0.1
# Training Hyperparameters
batch_size = 16
learning_rate = 1e-5 # Keeping reduced LR
num_epochs = 20
# Weight applied to the count cross-entropy term in the combined loss
count_loss_weight = 0.5
coordinate_pad_value = 0.0
train_subset_size = None; eval_subset_size = None
# Early Stopping Configuration
early_stopping_patience = 10 # Keeping increased patience
min_delta = 0.0001
best_model_save_path = "./best_seq2seq_model_clf_count.bin"
# Data Augmentation Config
augment_training_data = True
coord_noise_level = 0.01
# Coordinate Scaling Params
LAT_MIN, LAT_MAX = -90.0, 90.0; LON_MIN, LON_MAX = -180.0, 180.0
COORD_EPSILON = 1e-6
print(f"Using Coord Scaling: Lat ({LAT_MIN}, {LAT_MAX}), Lon ({LON_MIN}, {LON_MAX})")

# --- Explicitly Setting max_waypoints ---
max_waypoints = 10
num_count_classes = max_waypoints + 1
max_coord_seq_len = max_waypoints + 1
max_text_seq_len = 128
print(f"Using max_waypoints: {max_waypoints} => Num Count Classes: {num_count_classes}")
print(f"Max Decoder Seq Len: {max_coord_seq_len}")


# >>> FORCING CPU EXECUTION FOR DEBUGGING <<<
#device = torch.device("cpu")
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#print(f"*** RUNNING ON CPU FOR DEBUGGING ***")
#print('\n\n')

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# --- Tokenizer Setup ---
print(f"Loading tokenizer: {tokenizer_name}")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
special_tokens_dict = {}
if tokenizer.bos_token is None: special_tokens_dict['bos_token'] = '[SOS]'
if tokenizer.eos_token is None: special_tokens_dict['eos_token'] = '[EOS]'
if tokenizer.pad_token is None: special_tokens_dict['pad_token'] = '[PAD]'
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
if num_added_toks > 0: print(f"Added {num_added_toks} special tokens: {special_tokens_dict}")
sos_token_id = tokenizer.bos_token_id; eos_token_id = tokenizer.eos_token_id; pad_token_id = tokenizer.pad_token_id
vocab_size = len(tokenizer)
print(f"Tokenizer vocabulary size: {vocab_size}")
print(f"SOS ID: {sos_token_id}, EOS ID: {eos_token_id}, PAD ID: {pad_token_id}")
print('\n')

# --- Load Dataset ---
print(f"Loading dataset: {dataset_name}")
try: dataset = load_dataset(dataset_name); print("Dataset loaded.")
except Exception as e: raise SystemExit(f"ERROR: Failed to load dataset '{dataset_name}'. Error: {e}") from e


# --- Coordinate Normalization / Denormalization ---
PAD_COORD_NORM_LAT = max(0.0, min(1.0, (coordinate_pad_value - LAT_MIN) / (LAT_MAX - LAT_MIN + COORD_EPSILON)))
PAD_COORD_NORM_LON = max(0.0, min(1.0, (coordinate_pad_value - LON_MIN) / (LON_MAX - LON_MIN + COORD_EPSILON)))
PAD_COORD_NORM = [PAD_COORD_NORM_LAT, PAD_COORD_NORM_LON]
print(f"Using PAD_COORD_NORM: {PAD_COORD_NORM}")
print('\n')

def normalize_coords(coords_list):
    normalized = []
    for coords in coords_list:
        lat, lon = coords[0], coords[1]
        norm_lat = (lat - LAT_MIN) / (LAT_MAX - LAT_MIN + COORD_EPSILON)
        norm_lon = (lon - LON_MIN) / (LON_MAX - LON_MIN + COORD_EPSILON)
        norm_lat = max(0.0, min(1.0, norm_lat)); norm_lon = max(0.0, min(1.0, norm_lon))
        normalized.append([norm_lat, norm_lon])
    return normalized

def denormalize_coords(norm_coords_list):
    denormalized = []
    for norm_coords in norm_coords_list:
        norm_lat, norm_lon = norm_coords[0], norm_coords[1]
        if abs(norm_lat - PAD_COORD_NORM[0]) < COORD_EPSILON and abs(norm_lon - PAD_COORD_NORM[1]) < COORD_EPSILON: lat, lon = coordinate_pad_value, coordinate_pad_value
        else: lat = norm_lat * (LAT_MAX - LAT_MIN + COORD_EPSILON) + LAT_MIN; lon = norm_lon * (LON_MAX - LON_MIN + COORD_EPSILON) + LON_MIN
        denormalized.append([lat, lon])
    return denormalized

# --- Data Preprocessing Function (Seq2Seq, Norm, Correct SOS/EOS Handling v2) ---
print("Defining data preprocessing function...")
def preprocess_seq2seq_data(examples, is_training=False):
    # Returns integer target_count
    if "input" not in examples or "waypoints" not in examples or "label" not in examples: return {"input_ids": [], "attention_mask": [], "decoder_input_coords_norm": [], "target_coords_output_norm": [], "target_count": [], "coord_mask": []}
    encoder_inputs = tokenizer(examples["input"], padding="max_length", truncation=True, max_length=max_text_seq_len)
    input_ids = encoder_inputs["input_ids"]; attention_mask = encoder_inputs["attention_mask"]
    decoder_input_batch_norm, target_output_batch_norm, target_counts_batch, coord_masks_batch = [], [], [], []
    waypoints_list = examples["waypoints"] if isinstance(examples["waypoints"], list) else []; labels_list = examples["label"] if isinstance(examples["label"], list) else []
    min_len = min(len(waypoints_list), len(labels_list))

    for i in range(min_len):
        waypoints, label = waypoints_list[i], labels_list[i]
        try:
            if isinstance(waypoints, list) and all(isinstance(wp, (list, tuple)) and len(wp) == 2 for wp in waypoints): waypoints_float = [[float(lat), float(lon)] for lat, lon in waypoints]
            else: raise TypeError("Waypoints format incorrect")
        except (ValueError, TypeError, IndexError): waypoints_float = []

        if is_training and augment_training_data and waypoints_float:
            augmented_waypoints = []
            for lat, lon in waypoints_float: noise_lat=random.uniform(-coord_noise_level,coord_noise_level); noise_lon=random.uniform(-coord_noise_level,coord_noise_level); augmented_waypoints.append([lat+noise_lat,lon+noise_lon])
            coords_processed = augmented_waypoints
        else: coords_processed = waypoints_float

        coords_truncated = coords_processed[:max_waypoints]; num_actual_waypoints = len(coords_truncated)
        coords_normalized = normalize_coords(coords_truncated)
        decoder_input_seq_norm = coords_normalized
        target_output_seq_norm = coords_normalized + [PAD_COORD_NORM]
        decoder_input_padding_len = max_waypoints - len(decoder_input_seq_norm); decoder_input_seq_norm.extend([PAD_COORD_NORM] * decoder_input_padding_len)
        target_output_padding_len = max_coord_seq_len - len(target_output_seq_norm); target_output_seq_norm.extend([PAD_COORD_NORM] * target_output_padding_len)
        coord_mask = [1.0] * (num_actual_waypoints + 1) + [0.0] * target_output_padding_len

        decoder_input_batch_norm.append(decoder_input_seq_norm)
        target_output_batch_norm.append(target_output_seq_norm)
        coord_masks_batch.append(coord_mask)
        try:
            count_label = int(round(float(label))); count_label = max(0, min(max_waypoints, count_label))
            target_counts_batch.append(count_label)
        except (ValueError, TypeError): target_counts_batch.append(0)

    return {"input_ids": input_ids, "attention_mask": attention_mask, "decoder_input_coords_norm": decoder_input_batch_norm, "target_coords_output_norm": target_output_batch_norm, "target_count": target_counts_batch, "coord_mask": coord_masks_batch}


# --- Apply Preprocessing and Split ---
print("Applying preprocessing (with augmentation for training set)...")
columns_to_remove_post_preprocess = ["distance", "distance_category", "waypoint_names"]
columns_to_remove_train_val = ['input', 'waypoints', 'label'] + columns_to_remove_post_preprocess
columns_to_remove_test = ['waypoints', 'label'] + columns_to_remove_post_preprocess
print('\n')
try:
    train_testvalid_original = dataset['train'].train_test_split(test_size=0.2, seed=42)
    test_valid_original = train_testvalid_original['test'].train_test_split(test_size=0.5, seed=42)
    processed_train = train_testvalid_original['train'].map(lambda examples: preprocess_seq2seq_data(examples, is_training=True), batched=True, remove_columns=columns_to_remove_train_val)
    processed_validation = test_valid_original['test'].map(lambda examples: preprocess_seq2seq_data(examples, is_training=False), batched=True, remove_columns=columns_to_remove_train_val)
    # Reuse the train/val column list; concatenating it with columns_to_remove_test would list 'waypoints' and 'label' twice
    processed_test_for_loss = test_valid_original['train'].map(lambda examples: preprocess_seq2seq_data(examples, is_training=False), batched=True, remove_columns=columns_to_remove_train_val)
    original_test_set_for_comparison = test_valid_original['train']
    processed_train.set_format("torch"); processed_validation.set_format("torch"); processed_test_for_loss.set_format("torch")
    print("Preprocessing complete.")
except Exception as e: raise SystemExit(f"ERROR during preprocessing: {e}") from e

# Select data for training/evaluation/testing
train_data = processed_train.shuffle(seed=42).select(range(min(train_subset_size, len(processed_train)))) if train_subset_size else processed_train
eval_data = processed_validation.shuffle(seed=42).select(range(min(eval_subset_size, len(processed_validation)))) if eval_subset_size else processed_validation
test_data_processed_for_loss = processed_test_for_loss
print(f"Using Train: {len(train_data)}, Validation: {len(eval_data)}, Test (for loss): {len(test_data_processed_for_loss)} samples.")
print('\n')


# --- Data Loaders ---
print("Creating DataLoaders...")
try:
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    eval_dataloader = DataLoader(eval_data, batch_size=batch_size, drop_last=False)
    test_dataloader_for_loss = DataLoader(test_data_processed_for_loss, batch_size=batch_size, drop_last=False)

    print(f"Loaders created (Train/Eval/Test batches): {len(train_dataloader)} / {len(eval_dataloader)} / {len(test_dataloader_for_loss)}")
except Exception as e: raise SystemExit(f"ERROR creating DataLoaders: {e}") from e

# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__(); self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1); div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model); pe[:, 0::2] = torch.sin(position * div_term); pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    def forward(self, x): x = x + self.pe[:x.size(1), :].unsqueeze(0); return self.dropout(x)

# --- Model Definition (Encoder-Decoder Transformer with Classification Count Head) ---
print("Defining the Seq2SeqCoordsTransformer model (Classification Count Head)...")
print('\n')
class Seq2SeqCoordsTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int, emb_size: int, nhead: int, src_vocab_size: int, num_count_classes: int, tgt_coord_dim: int = 2, dim_feedforward: int = 512, dropout: float = 0.1, max_text_len: int = 128, max_coord_len: int = 12):
        super().__init__(); self.emb_size = emb_size; self.max_coord_len = max_coord_len
        self.src_tok_emb = nn.Embedding(src_vocab_size, emb_size); self.pos_encoder_enc = PositionalEncoding(emb_size, dropout, max_len=max_text_len); self.pos_encoder_dec = PositionalEncoding(emb_size, dropout, max_len=max_coord_len)
        self.coord_input_proj = nn.Linear(tgt_coord_dim, emb_size); self.coord_output_proj = nn.Linear(emb_size, tgt_coord_dim)
        self.sos_embedding = nn.Parameter(torch.randn(1, 1, emb_size) * 0.02)
        self.transformer = nn.Transformer(d_model=emb_size, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
        self.encoder_pooler = lambda x: x.mean(dim=1)
        self.count_head = nn.Linear(emb_size, num_count_classes) # Classification head
        self._reset_parameters()
    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1: nn.init.xavier_uniform_(p)
    def forward(self, src_input_ids: torch.Tensor, tgt_input_coords_norm: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor, src_padding_mask: torch.Tensor, tgt_padding_mask: torch.Tensor, memory_key_padding_mask: torch.Tensor):
        src_input_ids_clamped = src_input_ids.clamp(0, self.src_tok_emb.num_embeddings - 1); src_emb_lookup = self.src_tok_emb(src_input_ids_clamped); src_emb = self.pos_encoder_enc(src_emb_lookup)
        memory = self.transformer.encoder(src_emb, src_key_padding_mask=src_padding_mask); pooled_encoder_output = self.encoder_pooler(memory)
        predicted_count_logits = self.count_head(pooled_encoder_output) # Output logits
        batch_size = tgt_input_coords_norm.size(0)
        coord_vals_emb = self.coord_input_proj(tgt_input_coords_norm); sos_emb_batch = self.sos_embedding.repeat(batch_size, 1, 1)
        tgt_emb = torch.cat([sos_emb_batch, coord_vals_emb], dim=1); tgt_emb = self.pos_encoder_dec(tgt_emb)
        decoder_output = self.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask, memory_key_padding_mask=memory_key_padding_mask, tgt_key_padding_mask=tgt_padding_mask)
        projected_coords = self.coord_output_proj(decoder_output); predicted_coords_normalized = torch.sigmoid(projected_coords)
        return predicted_coords_normalized, predicted_count_logits # Return logits

    def encode(self, src_input_ids: torch.Tensor, src_mask: torch.Tensor): # src_mask is padding mask
        src_input_ids_clamped = src_input_ids.clamp(0, self.src_tok_emb.num_embeddings - 1); src_emb_lookup = self.src_tok_emb(src_input_ids_clamped); src_emb = self.pos_encoder_enc(src_emb_lookup)
        memory = self.transformer.encoder(src_emb, src_key_padding_mask=src_mask); pooled_memory = self.encoder_pooler(memory)
        predicted_count_logits = self.count_head(pooled_memory); return memory, predicted_count_logits # Return logits

# --- Utility Functions for Seq2Seq ---
def generate_square_subsequent_mask(sz, device): return torch.triu(torch.ones(sz, sz, device=device) * float('-inf'), diagonal=1)
def create_mask(src_input_ids, target_output_norm, pad_idx, device): # Use target output shape for target mask dims
    src_seq_len = src_input_ids.shape[1]; tgt_seq_len = target_output_norm.shape[1]
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device)
    src_padding_mask = (src_input_ids == pad_idx)
    pad_tensor = torch.tensor(PAD_COORD_NORM, device=device).unsqueeze(0).unsqueeze(0)
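    target_is_pad = torch.all(torch.isclose(target_output_norm, pad_tensor), dim=-1) # Check against normalized pad
    # The decoder input is [SOS] + waypoints (the target shifted right by one), so shift the pad mask as well;
    # this keeps SOS and the last real waypoint unmasked, including for samples with zero waypoints.
    tgt_padding_mask = torch.cat([torch.zeros_like(target_is_pad[:, :1]), target_is_pad[:, :-1]], dim=1)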
    memory_key_padding_mask = src_padding_mask
    return tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask

# --- Loss Function Definition (Classification Count Loss) ---
print("Defining the CombinedLoss function (Classification Count Loss)...")
print('\n')
class CombinedLossSeq2Seq(nn.Module):
    def __init__(self, count_loss_weight=100.0): # Default only; the configured count_loss_weight is passed at instantiation
        super().__init__(); self.coord_loss_fn = nn.MSELoss(reduction='none')
        self.count_loss_fn = nn.CrossEntropyLoss() # Using CrossEntropy
        self.count_loss_weight = count_loss_weight
    def forward(self, predicted_coords_norm, predicted_count_logits, target_coords_output_norm, target_count_labels, coord_mask):
        # predicted_count_logits: (N, num_classes), target_count_labels: (N,) LongTensor
        effective_coord_mask = coord_mask.unsqueeze(-1).expand_as(predicted_coords_norm)
        coord_loss_elementwise = self.coord_loss_fn(predicted_coords_norm, target_coords_output_norm)
        masked_coord_loss = coord_loss_elementwise * effective_coord_mask
        num_actual_elements = effective_coord_mask.sum(); mean_coord_loss = masked_coord_loss.sum() / num_actual_elements if num_actual_elements > 0 else torch.tensor(0.0, device=predicted_coords_norm.device)
        # Ensure labels are long and clamped
        target_count_labels = target_count_labels.long().clamp(0, predicted_count_logits.size(1) - 1)
        count_loss = self.count_loss_fn(predicted_count_logits, target_count_labels) # CrossEntropy loss calculation
        total_loss = mean_coord_loss + self.count_loss_weight * count_loss
        if not torch.isfinite(total_loss): total_loss = torch.tensor(0.0, requires_grad=True, device=predicted_coords_norm.device); mean_coord_loss = torch.tensor(0.0); count_loss = torch.tensor(0.0)
        return total_loss, mean_coord_loss, count_loss # Return CE loss for count

# --- Instantiate Model, Loss, Optimizer ---
print("Instantiating Seq2Seq model (Classification Count), loss function, and optimizer...")
model = Seq2SeqCoordsTransformer(num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, emb_size=embedding_dimension, nhead=nhead, src_vocab_size=vocab_size, num_count_classes=num_count_classes, tgt_coord_dim=2, dim_feedforward=dim_feedforward, dropout=transformer_dropout, max_text_len=max_text_seq_len, max_coord_len=max_coord_seq_len)

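model.to(device)  # Move to the target device before wrapping in DDP and before building the optimizer

if world_size > 1:
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

loss_fn = CombinedLossSeq2Seq(count_loss_weight=count_loss_weight) # Uses the configured count_loss_weight
# The model was built with src_vocab_size=vocab_size, so its token embedding already has the right size.
# Do not replace model.src_tok_emb after this point: parameters created after the optimizer is built are never updated.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)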
print(f"Model moved to: {device}. Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
print('\n')

# --- Training Loop with Early Stopping (Classification Count) ---
print(f"Starting training for up to {num_epochs} epochs with Early Stopping (patience={early_stopping_patience}) on {device}...")
training_stats = []
best_eval_loss = float('inf')
epochs_no_improve = 0

epoch_iterator = tqdm(range(num_epochs), desc="Overall Training Progress")
for epoch in epoch_iterator:
    model.train()
    batch_iterator_train = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs} Training", leave=False)
    for batch in batch_iterator_train:
        try:
            input_ids = batch['input_ids'].to(device)
            decoder_input_norm_wp_only = batch['decoder_input_coords_norm'].float().to(device)
            target_output_norm = batch['target_coords_output_norm'].float().to(device)
            target_cnt_labels = batch['target_count'].long().to(device) # Target is LONG type
            output_coord_mask = batch['coord_mask'].float().to(device)
            tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = create_mask(input_ids, target_output_norm, pad_token_id, device) # Use target shape for mask
            optimizer.zero_grad()
            predicted_coords_norm, predicted_count_logits = model(src_input_ids=input_ids, tgt_input_coords_norm=decoder_input_norm_wp_only, src_mask=None, tgt_mask=tgt_mask, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
            loss, coord_loss_norm, count_loss = loss_fn(predicted_coords_norm, predicted_count_logits, target_output_norm, target_cnt_labels, output_coord_mask) # Pass logits/labels
            if torch.isfinite(loss) and loss > 0:
                loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0); optimizer.step()
                batch_iterator_train.set_postfix({'loss': f"{loss.item():.4f}", 'coord_N': f"{coord_loss_norm.item():.4f}", 'count_CE': f"{count_loss.item():.4f}"}) # Use count_CE
        except Exception as e: print(f"\nERROR training batch: {e}\n{traceback.format_exc()}"); continue

    # --- Evaluation Phase ---
    model.eval()
    eval_losses, eval_coord_losses_norm, eval_count_losses = [], [], []
    batch_iterator_eval = tqdm(eval_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs} Evaluation", leave=False)
    with torch.no_grad():
        for batch in batch_iterator_eval:
            try:
                input_ids = batch['input_ids'].to(device)
                decoder_input_norm_wp_only = batch['decoder_input_coords_norm'].float().to(device)
                target_output_norm = batch['target_coords_output_norm'].float().to(device)
                target_cnt_labels = batch['target_count'].long().to(device) # Target is LONG type
                output_coord_mask = batch['coord_mask'].float().to(device)
                tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = create_mask(input_ids, target_output_norm, pad_token_id, device)
                predicted_coords_norm, predicted_count_logits = model(src_input_ids=input_ids, tgt_input_coords_norm=decoder_input_norm_wp_only, src_mask=None, tgt_mask=tgt_mask, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
                loss, coord_loss_norm, count_loss = loss_fn(predicted_coords_norm, predicted_count_logits, target_output_norm, target_cnt_labels, output_coord_mask) # Pass logits/labels
                if torch.isfinite(loss): eval_losses.append(loss.item()); eval_coord_losses_norm.append(coord_loss_norm.item()); eval_count_losses.append(count_loss.item())
                batch_iterator_eval.set_postfix({'loss': f"{loss.item():.4f}", 'coord_N': f"{coord_loss_norm.item():.4f}", 'count_CE': f"{count_loss.item():.4f}"}) # Use count_CE
            except Exception as e: print(f"\nERROR eval batch: {e}\n{traceback.format_exc()}"); continue

    avg_eval_loss = np.mean(eval_losses) if eval_losses else float('inf')
    avg_eval_coord_loss_norm = np.mean(eval_coord_losses_norm) if eval_coord_losses_norm else float('inf')
    avg_eval_count_loss = np.mean(eval_count_losses) if eval_count_losses else float('inf') # Avg CrossEntropy loss
    print(f"\n--- Epoch {epoch + 1}/{num_epochs} Eval Summary ---")
    print(f"  Avg Eval Loss: {avg_eval_loss:.4f} (CoordNorm: {avg_eval_coord_loss_norm:.4f}, CountCE: {avg_eval_count_loss:.4f})") # Use CountCE
    training_stats.append({'epoch': epoch + 1, 'eval_loss': avg_eval_loss, 'eval_coord_loss_norm': avg_eval_coord_loss_norm, 'eval_count_loss_ce': avg_eval_count_loss}) # Use count_CE
    epoch_iterator.set_postfix({'Avg Eval Loss': f"{avg_eval_loss:.4f}", 'Avg CoordNorm Loss': f"{avg_eval_coord_loss_norm:.4f}", 'Avg CountCE Loss': f"{avg_eval_count_loss:.4f}"}) # Use CountCE

    # --- Early Stopping Check ---
    if avg_eval_loss < best_eval_loss - min_delta:
        best_eval_loss = avg_eval_loss; epochs_no_improve = 0
        try: torch.save(model.state_dict(), best_model_save_path); print(f"  New best model saved (Eval Loss: {best_eval_loss:.4f})")
        except Exception as e: print(f"  ERROR saving best model: {e}")
    else:
        epochs_no_improve += 1; print(f"  No improvement in eval loss for {epochs_no_improve} epoch(s).")
    if epochs_no_improve >= early_stopping_patience:
        print(f"\n--- Early stopping triggered after {epoch + 1} epochs ---"); break

print("\n--- Training loop finished ---")
print(f"Best validation loss achieved: {best_eval_loss:.4f}")
print('\n')

# --- Load the best model state before saving/deploying ---
print(f"\nLoading best model state from {best_model_save_path} for final steps...")
try:
    if os.path.exists(best_model_save_path): state_dict = torch.load(best_model_save_path, map_location=device); model.load_state_dict(state_dict); print("Loaded best model weights.")
    else: print(f"Warning: Best model checkpoint not found. Using state from last epoch.")
except Exception as e: print(f"ERROR loading best model state: {e}. Using state from last epoch.")

# --- Saving the model Locally ---
print("\nSaving best model locally...")
model_save_path = "./flight_plan_seq2seq_clf_model_final" # New path name
os.makedirs(model_save_path, exist_ok=True)
try:
    torch.save(model.state_dict(), os.path.join(model_save_path, "pytorch_model.bin"))
    tokenizer.save_pretrained(model_save_path)
    config_to_save = {"vocab_size": vocab_size, "emb_size": embedding_dimension, "nhead": nhead, "num_encoder_layers": num_encoder_layers, "num_decoder_layers": num_decoder_layers, "dim_feedforward": dim_feedforward, "dropout": transformer_dropout, "max_text_len": max_text_seq_len, "max_coord_len": max_coord_seq_len, "max_waypoints": max_waypoints,
                      "num_count_classes": num_count_classes, # Save num classes info
                      "architecture": model.__class__.__name__}
    with open(os.path.join(model_save_path, "config.json"), "w") as f: json.dump(config_to_save, f, indent=4)
    print(f"Model saved to {model_save_path}")
except Exception as e: print(f"ERROR saving model locally: {e}")

# --- Deployment to Hugging Face Hub ---
print(f"\n--- Attempting Deployment of Best Model to Hugging Face Hub: {hf_repo_id} ---")
# WARNING: This will overwrite the target repo ID with the new Seq2Seq model!
if HfApi and hf_hub_download:
    try:
        print(f"Creating/accessing repository '{hf_repo_id}'...")
        create_repo(hf_repo_id, private=False, exist_ok=True)
        api = HfApi()
        # --- README Generation (Updated for Classification Count) ---
        print("Generating README.md content...")
        readme_content = f"""---
license: apache-2.0
tags:
  - flight-planning
  - transformer
  - coordinate-prediction
  - sequence-to-sequence
  - count-classification
---
# Flight Plan Coordinate Prediction Model ({model.__class__.__name__})
Encoder-decoder Transformer model trained for an AI flight planning project. It predicts normalized coordinates directly and the waypoint count via classification.
## Model Description
{model.__class__.__name__} architecture built on `torch.nn.Transformer`. It predicts normalized lat/lon coordinates autoregressively and the waypoint count (0-{max_waypoints}) via a classification head on the encoder output.
* Embed Dim: {embedding_dimension}, Heads: {nhead}, Enc Layers: {num_encoder_layers}, Dec Layers: {num_decoder_layers}, Max Waypoints: {max_waypoints}
## Intended Use
Research prototype. **Not for real-world navigation.**
## Limitations
Accuracy depends on the training data and hyperparameter tuning. The waypoint count is capped at {max_waypoints}. Not certified. **Architecture differs significantly from previous versions in this repo.**
## How to Use
Load the custom `{model.__class__.__name__}` class and its weights. Generation decodes coordinates autoregressively and takes the argmax of the count logits to decide how many waypoints to emit.
## Training Data
Trained on `{dataset_name}` - https://huggingface.co/datasets/frankmorales2020/flight_plan_waypoints.
## Contact
Frank Morales, BEng, MEng, SMIEEE (Boeing ATF) - https://www.linkedin.com/in/frank-morales1964/"""
        try:
            with open("README.md", "w", encoding="utf-8") as f:
                f.write(readme_content)
            print("Uploading README.md...")
            api.upload_file(path_or_fileobj="README.md", path_in_repo="README.md", repo_id=hf_repo_id, repo_type="model", commit_message="Update README (Seq2Seq Clf Count)")
            os.remove("README.md")
            print("README.md uploaded.")
        except Exception as e:
            print(f"ERROR creating/uploading README.md: {e}")

        print(f"Uploading model files from {model_save_path}...")
        api.upload_folder(folder_path=model_save_path, repo_id=hf_repo_id, repo_type="model", commit_message=f"Upload trained {model.__class__.__name__} (Seq2Seq, Clf Count)")
        print(f"Model files uploaded: https://huggingface.co/{hf_repo_id}")
    except Exception as e: print(f"ERROR deploying to HF Hub: {e}")
else: print("Skipping deployment: huggingface_hub library/login unavailable.")
print('\n')

Using Coord Scaling: Lat (-90.0, 90.0), Lon (-180.0, 180.0)
Using max_waypoints: 10 => Num Count Classes: 11
Max Decoder Seq Len: 11
Loading tokenizer: gpt2
Added 1 special tokens: {'pad_token': '[PAD]'}
Tokenizer vocabulary size: 50258
SOS ID: 50256, EOS ID: 50256, PAD ID: 50257


Loading dataset: frankmorales2020/flight_plan_waypoints
Dataset loaded.
Using PAD_COORD_NORM: [0.49999999722222227, 0.4999999986111111]


Defining data preprocessing function...
Applying preprocessing (with augmentation for training set)...


Map: 100% 200/200 [00:00<00:00, 4845.77 examples/s]
Preprocessing complete.
Using Train: 1600, Validation: 200, Test (for loss): 200 samples.


Creating DataLoaders...
Loaders created (Train/Eval/Test batches): 100 / 13 / 13
Defining the Seq2SeqCoordsTransformer model (Classification Count Head)...


Defining the CombinedLoss function (Classification Count Loss)...


Instantiating Seq2Seq model (Classification Count), loss function, and optimizer...
Model moved to: cuda. Parameters: 20,244,237


Starting training for up to 20 epochs with Early Stopping (patience=10) on CPU...
Overall Training Progress: 100% 20/20 [01:35<00:00,  4.75s/it, Avg Eval Loss=0.9785, Avg CoordNorm Loss=0.0170, Avg CountCE Loss=1.9230]

--- Epoch 1/20 Eval Summary ---
  Avg Eval Loss: 1.1007 (CoordNorm: 0.0222, CountCE: 2.1571)
  New best model saved (Eval Loss: 1.1007)

--- Epoch 2/20 Eval Summary ---
  Avg Eval Loss: 1.1003 (CoordNorm: 0.0221, CountCE: 2.1564)
  New best model saved (Eval Loss: 1.1003)

--- Epoch 3/20 Eval Summary ---
  Avg Eval Loss: 1.0853 (CoordNorm: 0.0215, CountCE: 2.1276)
  New best model saved (Eval Loss: 1.0853)

--- Epoch 4/20 Eval Summary ---
  Avg Eval Loss: 1.0790 (CoordNorm: 0.0216, CountCE: 2.1147)
  New best model saved (Eval Loss: 1.0790)

--- Epoch 5/20 Eval Summary ---
  Avg Eval Loss: 1.0780 (CoordNorm: 0.0209, CountCE: 2.1141)
  New best model saved (Eval Loss: 1.0780)

--- Epoch 6/20 Eval Summary ---
  Avg Eval Loss: 1.0649 (CoordNorm: 0.0209, CountCE: 2.0880)
  New best model saved (Eval Loss: 1.0649)

--- Epoch 7/20 Eval Summary ---
  Avg Eval Loss: 1.0502 (CoordNorm: 0.0203, CountCE: 2.0598)
  New best model saved (Eval Loss: 1.0502)

--- Epoch 8/20 Eval Summary ---
  Avg Eval Loss: 1.0443 (CoordNorm: 0.0204, CountCE: 2.0480)
  New best model saved (Eval Loss: 1.0443)

--- Epoch 9/20 Eval Summary ---
  Avg Eval Loss: 1.0308 (CoordNorm: 0.0189, CountCE: 2.0239)
  New best model saved (Eval Loss: 1.0308)

--- Epoch 10/20 Eval Summary ---
  Avg Eval Loss: 1.0245 (CoordNorm: 0.0191, CountCE: 2.0107)
  New best model saved (Eval Loss: 1.0245)

--- Epoch 11/20 Eval Summary ---
  Avg Eval Loss: 1.0196 (CoordNorm: 0.0180, CountCE: 2.0032)
  New best model saved (Eval Loss: 1.0196)

--- Epoch 12/20 Eval Summary ---
  Avg Eval Loss: 1.0165 (CoordNorm: 0.0196, CountCE: 1.9939)
  New best model saved (Eval Loss: 1.0165)

--- Epoch 13/20 Eval Summary ---
  Avg Eval Loss: 0.9984 (CoordNorm: 0.0180, CountCE: 1.9609)
  New best model saved (Eval Loss: 0.9984)

--- Epoch 14/20 Eval Summary ---
  Avg Eval Loss: 0.9899 (CoordNorm: 0.0181, CountCE: 1.9437)
  New best model saved (Eval Loss: 0.9899)

--- Epoch 15/20 Eval Summary ---
  Avg Eval Loss: 0.9968 (CoordNorm: 0.0175, CountCE: 1.9587)
  No improvement in eval loss for 1 epoch(s).

--- Epoch 16/20 Eval Summary ---
  Avg Eval Loss: 0.9980 (CoordNorm: 0.0192, CountCE: 1.9576)
  No improvement in eval loss for 2 epoch(s).

--- Epoch 17/20 Eval Summary ---
  Avg Eval Loss: 0.9875 (CoordNorm: 0.0165, CountCE: 1.9420)
  New best model saved (Eval Loss: 0.9875)

--- Epoch 18/20 Eval Summary ---
  Avg Eval Loss: 0.9863 (CoordNorm: 0.0162, CountCE: 1.9402)
  New best model saved (Eval Loss: 0.9863)

--- Epoch 19/20 Eval Summary ---
  Avg Eval Loss: 0.9842 (CoordNorm: 0.0168, CountCE: 1.9349)
  New best model saved (Eval Loss: 0.9842)

--- Epoch 20/20 Eval Summary ---
  Avg Eval Loss: 0.9785 (CoordNorm: 0.0170, CountCE: 1.9230)
  New best model saved (Eval Loss: 0.9785)

--- Training loop finished ---
Best validation loss achieved: 0.9785



Loading best model state from ./best_seq2seq_model_clf_count.bin for final steps...
Loaded best model weights.

Saving best model locally...
Model saved to ./flight_plan_seq2seq_clf_model_final

--- Attempting Deployment of Best Model to Hugging Face Hub: frankmorales2020/FlightPlan_Transformer_LLM_1GPU_Colab ---
Creating/accessing repository 'frankmorales2020/FlightPlan_Transformer_LLM_1GPU_Colab'...
Generating README.md content...
Uploading README.md...
README.md uploaded.
Uploading model files from ./flight_plan_seq2seq_clf_model_final...
pytorch_model.bin: 100% 81.2M/81.2M [00:09<00:00, 7.33MB/s]
Model files uploaded: https://huggingface.co/frankmorales2020/FlightPlan_Transformer_LLM_1GPU_Colab
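Each epoch summary above reports a total eval loss next to its two components. The logged numbers are consistent with the total being the normalized coordinate loss plus half of the count cross-entropy. That weighting is inferred from the log only, since the `CombinedLoss` definition is not part of this cell, so treat the 0.5 weight below as an assumption. A minimal check on a few logged epochs:

```python
# (epoch, total eval loss, CoordNorm, CountCE) copied from the eval summaries above
logged = [
    (1, 1.1007, 0.0222, 2.1571),
    (7, 1.0502, 0.0203, 2.0598),
    (14, 0.9899, 0.0181, 1.9437),
    (20, 0.9785, 0.0170, 1.9230),
]

# ASSUMPTION: inferred from the log, not read from the CombinedLoss source
ASSUMED_COUNT_WEIGHT = 0.5


def reconstructed_total(coord_norm, count_ce, count_weight=ASSUMED_COUNT_WEIGHT):
    """Rebuild the total loss from its two logged components."""
    return coord_norm + count_weight * count_ce


for epoch, total, coord_norm, count_ce in logged:
    estimate = reconstructed_total(coord_norm, count_ce)
    print(f"epoch {epoch:2d}: logged={total:.4f} "
          f"reconstructed={estimate:.4f} diff={abs(total - estimate):.4f}")
```

If that reading is right, the count term makes up most of the total, so the total eval loss mainly tracks the classification head rather than coordinate quality.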


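The training log above also shows the early-stopping bookkeeping: a counter of epochs without improvement that resets whenever a new best eval loss is saved. The sketch below replays the logged eval losses through a patience rule. The function name and the strict "lower than best" comparison are assumptions for illustration, not the notebook's actual training loop.

```python
# Avg eval loss per epoch, copied from the log above (epochs 1..20)
eval_losses = [
    1.1007, 1.1003, 1.0853, 1.0790, 1.0780, 1.0649, 1.0502, 1.0443, 1.0308, 1.0245,
    1.0196, 1.0165, 0.9984, 0.9899, 0.9968, 0.9980, 0.9875, 0.9863, 0.9842, 0.9785,
]


def replay_early_stopping(losses, patience):
    """Return (best_loss, best_epoch, stopped_epoch_or_None, longest_streak)."""
    best_loss, best_epoch = float("inf"), None
    streak = longest = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:          # new best: save checkpoint, reset counter
            best_loss, best_epoch, streak = loss, epoch, 0
        else:                         # no improvement this epoch
            streak += 1
            longest = max(longest, streak)
            if streak >= patience:    # patience exhausted: stop training
                return best_loss, best_epoch, epoch, longest
    return best_loss, best_epoch, None, longest


best, best_epoch, stopped, longest = replay_early_stopping(eval_losses, patience=10)
print(f"best={best:.4f} at epoch {best_epoch}, stopped={stopped}, longest streak={longest}")
```

With a patience of 10 and a 20-epoch budget, the longest run without improvement in this log is two epochs, so early stopping never triggers and the best checkpoint is the one from the final epoch.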

## Evaluation - Seq2SeqCoordsTransformer

In [None]:
from warnings import simplefilter
simplefilter(action='ignore')

# --- Model Loading and Test Set Evaluation (Refactored for Seq2Seq Clf Count) ---
print("\n--- Loading Model and Evaluating on Test Set ---")
print('\n')

# --- Generation Function (Updated for Classification Count Head) ---
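@torch.no_grad()  # Inference only: disable gradient tracking for the whole function, including the decode loop
def generate_flight_plan_seq2seq(trained_model, tokenizer_instance, query_text, device_instance, max_len=max_coord_seq_len):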
    trained_model.eval(); trained_model.to(device_instance)
    try:
        inputs = tokenizer_instance(query_text, return_tensors='pt', padding='longest', truncation=True, max_length=max_text_seq_len)
        src_input_ids = inputs['input_ids'].to(device_instance); src_padding_mask = (src_input_ids == pad_token_id)
        with torch.no_grad():
            # encode() returns the encoder memory and the waypoint-count logits
            memory, predicted_count_logits = trained_model.encode(src_input_ids, src_padding_mask)
            # The argmax index is itself the predicted count (0 to max_waypoints)
            pred_count_index = torch.argmax(predicted_count_logits, dim=1).item()

        # Start decoding with SOS embedding
        decoder_input_embeddings = trained_model.sos_embedding.repeat(1, 1, 1) # (N=1, 1, E)
        generated_denorm_coords_list = []

        for step in range(min(pred_count_index, max_waypoints)): # Generate at most the predicted number of waypoints (none when the predicted count is 0)
            current_tgt_len = decoder_input_embeddings.size(1)
            tgt_mask = generate_square_subsequent_mask(current_tgt_len, device_instance) # (T, T)
            memory_key_padding_mask = src_padding_mask # (N, S)
            tgt_padding_mask = torch.zeros(1, current_tgt_len, dtype=torch.bool, device=device_instance) # (N, T)

            tgt_emb_with_pe = trained_model.pos_encoder_dec(decoder_input_embeddings) # (N, T, E)
            decoder_output = trained_model.transformer.decoder(tgt_emb_with_pe, memory, tgt_mask=tgt_mask, memory_key_padding_mask=memory_key_padding_mask, tgt_key_padding_mask=tgt_padding_mask)
            predicted_norm_coord_next = torch.sigmoid(trained_model.coord_output_proj(decoder_output[:, -1:, :])) # (1, 1, 2) Normalized

            # Denormalize prediction to store
            predicted_denorm_coord_next = denormalize_coords(predicted_norm_coord_next.detach().squeeze(0).cpu().numpy().tolist())
            new_denorm_coord_value = predicted_denorm_coord_next[0]
            if isinstance(new_denorm_coord_value, list) and len(new_denorm_coord_value) == 2: generated_denorm_coords_list.append(new_denorm_coord_value)
            else: print(f"Warning: Invalid denorm coord predicted: {new_denorm_coord_value}"); break

            # Prepare next input embedding
            next_input_emb = trained_model.coord_input_proj(predicted_norm_coord_next) # Project normalized prediction
            decoder_input_embeddings = torch.cat([decoder_input_embeddings, next_input_emb], dim=1)

            # Stop if we've generated the predicted number of points
            if len(generated_denorm_coords_list) >= pred_count_index: break

        final_waypoints = generated_denorm_coords_list
        # Return the predicted count index (which is the count) and raw logits if needed
        return final_waypoints, pred_count_index, predicted_count_logits.squeeze().cpu().numpy() # Return index and logits

    except Exception as e: print(f"ERROR generating for query '{query_text}': {e}\n{traceback.format_exc()}"); return [], 0, np.array([])


# --- Load the model ---
load_from_hf = False # Defaulting to False
model_load_id = hf_repo_id if load_from_hf else model_save_path
loaded_model = None; loaded_tokenizer = None
print(f"Attempting to load Seq2Seq model from: {model_load_id} (Using {'HF Hub' if load_from_hf else 'Local Path'})")
print('\n')
#print(device)


if hf_hub_download:
    try:
        if load_from_hf:
             if "YOUR_USERNAME" in model_load_id or model_load_id == "frankmorales2020/FlightPlan_Transformer_LLM": raise ValueError(f"Refusing to load potentially incompatible model from '{model_load_id}'. Update hf_repo_id or set load_from_hf=False.")
             tokenizer_load_path = model_load_id; config_load_path = hf_hub_download(repo_id=model_load_id, filename="config.json"); weights_load_path = hf_hub_download(repo_id=model_load_id, filename="pytorch_model.bin")
        else: # Loading from local
             tokenizer_load_path = model_save_path; config_load_path = os.path.join(model_save_path, "config.json"); weights_load_path = best_model_save_path
             if not os.path.exists(weights_load_path): print(f"Warning: Best model file {weights_load_path} not found, trying final..."); weights_load_path = os.path.join(model_save_path, "pytorch_model.bin")
             if not all(os.path.exists(p) for p in [config_load_path, weights_load_path, os.path.join(tokenizer_load_path, 'tokenizer_config.json')]): raise FileNotFoundError(f"Required model files not found locally.")

        loaded_tokenizer = AutoTokenizer.from_pretrained(tokenizer_load_path)
        with open(config_load_path, 'r') as f: config_dict = json.load(f)
        expected_arch = "Seq2SeqCoordsTransformer"; loaded_arch = config_dict.get("architecture")
        if loaded_arch != expected_arch: print(f"\n>>> WARNING: Config architecture ('{loaded_arch}') != Expected ('{expected_arch}'). Ensure correct model type is loaded. <<<\n")

        # Use loaded config values, providing defaults
        loaded_model = Seq2SeqCoordsTransformer(
            num_encoder_layers=config_dict.get('num_encoder_layers', num_encoder_layers),
            num_decoder_layers=config_dict.get('num_decoder_layers', num_decoder_layers),
            emb_size=config_dict.get('emb_size', embedding_dimension),
            nhead=config_dict.get('nhead', nhead),
            src_vocab_size=len(loaded_tokenizer),
            # Get num_count_classes from config
            num_count_classes=config_dict.get('num_count_classes', num_count_classes),
            tgt_coord_dim=2,
            dim_feedforward=config_dict.get('dim_feedforward', dim_feedforward),
            dropout=config_dict.get('dropout', transformer_dropout),
            max_text_len=config_dict.get('max_text_len', max_text_seq_len),
            max_coord_len=config_dict.get('max_coord_len', max_coord_seq_len)
        )
        loaded_model.to(device)

        state_dict = torch.load(weights_load_path, map_location=device)
        # Handle embedding resize AFTER model instantiation and moving to device
        current_tokenizer_vocab_size = len(loaded_tokenizer)
        if state_dict.get('src_tok_emb.weight') is not None and state_dict['src_tok_emb.weight'].size(0) != current_tokenizer_vocab_size:
            print(f"Resizing embedding weights from {state_dict['src_tok_emb.weight'].size(0)} to {current_tokenizer_vocab_size}")
            loaded_model.src_tok_emb = nn.Embedding(current_tokenizer_vocab_size, embedding_dimension).to(device)
            new_emb = loaded_model.src_tok_emb.weight.data
            common_size = min(state_dict['src_tok_emb.weight'].size(0), new_emb.size(0))
            new_emb[:common_size, :] = state_dict['src_tok_emb.weight'][:common_size, :]
            state_dict['src_tok_emb.weight'] = new_emb
        elif 'src_tok_emb.weight' not in state_dict:
             print("Warning: src_tok_emb.weight not found in state_dict. Initializing embedding layer.")
             loaded_model.src_tok_emb = nn.Embedding(current_tokenizer_vocab_size, embedding_dimension).to(device)
        else: # Ensure model's embedding layer matches state dict if no resize needed
             loaded_model.src_tok_emb = nn.Embedding(current_tokenizer_vocab_size, embedding_dimension).to(device)

        # Load the state dict with strict=False to tolerate key mismatches; check the missing/unexpected keys printed below
        load_result = loaded_model.load_state_dict(state_dict, strict=False)
        print(f"Model load result (strict=False): Missing keys: {load_result.missing_keys}, Unexpected keys: {load_result.unexpected_keys}")
        loaded_model.eval(); print("Model loading successful.")

    except Exception as e: print(f"\n>>> ERROR loading model from {model_load_id}: {e}\n{traceback.format_exc()}"); loaded_model = None
else: print("Skipping model loading: huggingface_hub library not available.")


# --- Run Inference Loop and Calculate Loss on Test Set (Classification Count) ---
if loaded_model and loaded_tokenizer:
    print("\nRunning inference and loss calculation on the test set using Seq2Seq model...")
    test_results = []
    test_iterator_batches = tqdm(test_dataloader_for_loss, desc="Processing Test Set Batches for Loss")
    #test_iterator_batches = tqdm(train_dataloader, desc="Processing Train Set Batches for Loss")


    #train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    #eval_dataloader = DataLoader(eval_data, batch_size=batch_size, drop_last=False)
    #test_dataloader_for_loss = DataLoader(test_data_processed_for_loss, batch_size=batch_size, drop_last=False)

    total_test_samples_loss = 0
    test_coord_losses_norm = []
    test_count_losses_ce = []

    loaded_model.eval()
    with torch.no_grad():
        for batch in test_iterator_batches:
            try:

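                # NOTE: `batch` is never unpacked in this loop, so input_ids, decoder_input_norm_wp_only,
                # target_output_norm, target_cnt_labels and output_coord_mask still hold values left over
                # from an earlier cell. Unpack them from `batch` here (as the training loop does) so that
                # every test batch is actually scored.

                # Get predictions from the reloaded model (not the training-time `model`)
                tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = create_mask(input_ids, target_output_norm, pad_token_id, device)
                predicted_coords_norm, predicted_count_logits = loaded_model(src_input_ids=input_ids, tgt_input_coords_norm=decoder_input_norm_wp_only, src_mask=None, tgt_mask=tgt_mask, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_key_padding_mask)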

                # Calculate loss using logits and labels
                loss, coord_loss_norm, count_loss_ce = loss_fn(predicted_coords_norm, predicted_count_logits, target_output_norm, target_cnt_labels, output_coord_mask)

                # Accumulate losses
                if torch.isfinite(coord_loss_norm):
                    test_coord_losses_norm.append(coord_loss_norm.item() * input_ids.size(0))
                if torch.isfinite(count_loss_ce):
                    test_count_losses_ce.append(count_loss_ce.item() * input_ids.size(0))
                total_test_samples_loss += input_ids.size(0)

                # Display progress
                test_iterator_batches.set_postfix({'batch_coord_norm_loss': f"{coord_loss_norm.item():.4f}", 'batch_count_CE_loss': f"{count_loss_ce.item():.4f}"})

            except Exception as e:
                print(f"\nERROR processing test batch for loss: {e}")
                continue

    # --- Loop for Generation Metrics ---
    print("\nCalculating generation metrics on test samples...")
    original_test_set_for_gen = original_test_set_for_comparison
    test_iterator_samples = tqdm(range(len(original_test_set_for_gen)), desc="Generating Test Samples")
    total_count_diff_gen = 0
    total_test_samples_gen = len(original_test_set_for_gen)
    count_correct = 0  # Track number of perfectly predicted counts

    for i in test_iterator_samples:
        try:
            sample = original_test_set_for_gen[i]
            query = sample.get('input', '')
            actual_waypoints_raw = sample.get('waypoints', [])

            # Handle waypoints (convert to list if necessary)
            if isinstance(actual_waypoints_raw, np.ndarray):
                actual_waypoints = actual_waypoints_raw.tolist()
            elif isinstance(actual_waypoints_raw, list):
                actual_waypoints = actual_waypoints_raw
            else:
                actual_waypoints = []

            # Get actual count label (handle errors)
            try:
                actual_count_label = int(round(float(sample.get('label', 0))))
                actual_count_label = max(0, min(max_waypoints, actual_count_label))
            except (ValueError, TypeError):
                actual_count_label = 0

            # Skip empty queries
            if not query:
                total_test_samples_gen -= 1
                continue

            # Generate flight plan and calculate count difference
            pred_waypoints, pred_count_index, pred_count_logits = generate_flight_plan_seq2seq(loaded_model, loaded_tokenizer, query, device)
            count_diff = abs(pred_count_index - actual_count_label)
            total_count_diff_gen += count_diff

            #print("\n--- Generated Flight Plan ---")
            #print(f"Query: {query}")
            #print(f"Predicted Waypoints: {pred_waypoints}")
            #print(f"Actual Waypoints: {actual_waypoints}")
            #print(f"Predicted Count: {pred_count_index}")
            #print(f"Actual Count: {actual_count_label}")
            #print('\n')

            # Check count accuracy
            if pred_count_index == actual_count_label:
                count_correct += 1

            # Store results
            test_results.append({'query': query, 'predicted_waypoints': pred_waypoints, 'predicted_count': pred_count_index, 'actual_count': actual_count_label})

            # Display progress
            test_iterator_samples.set_postfix({'avg_count_diff': f"{total_count_diff_gen / (i + 1):.2f}"})

        except Exception as e:
            print(f"\nERROR generating for test sample {i}: {e}")
            total_test_samples_gen -= 1
            continue

    # --- Calculate and Print Overall Metrics ---
    avg_test_coord_loss_norm = np.sum(test_coord_losses_norm) / total_test_samples_loss if total_test_samples_loss > 0 else 0
    avg_test_count_loss_ce = np.sum(test_count_losses_ce) / total_test_samples_loss if total_test_samples_loss > 0 else 0
    avg_test_count_difference = total_count_diff_gen / total_test_samples_gen if total_test_samples_gen > 0 else 0
    test_count_accuracy = count_correct / total_test_samples_gen if total_test_samples_gen > 0 else 0

    print(f"\n--- Final Test Set Evaluation Summary ---")
    print(f"  Average Absolute Count Difference: {avg_test_count_difference:.4f}")
    print(f"  Count Prediction Accuracy:         {test_count_accuracy:.4f}")
    print(f"  Average Coordinate Loss (MSE, Normalized): {avg_test_coord_loss_norm:.4f}")
    print(f"  Average Count Loss (CrossEntropy):       {avg_test_count_loss_ce:.4f}")

else:
    print("\nSkipping test set evaluation: model/tokenizer loading failed or unavailable.")

print("\n--- Script Finished ---")


--- Loading Model and Evaluating on Test Set ---


Attempting to load Seq2Seq model from: ./flight_plan_seq2seq_clf_model_final (Using Local Path)


Model load result (strict=False): Missing keys: [], Unexpected keys: []
Model loading successful.

Running inference and loss calculation on the test set using Seq2Seq model...
Processing Test Set Batches for Loss: 100% 13/13 [00:00<00:00, 51.48it/s, batch_coord_norm_loss=0.0169, batch_count_CE_loss=1.9580]

Calculating generation metrics on test samples...
Generating Test Samples: 100% 200/200 [00:08<00:00, 24.10it/s, avg_count_diff=1.35]

--- Final Test Set Evaluation Summary ---
  Average Absolute Count Difference: 1.3550
  Count Prediction Accuracy:         0.1800
  Average Coordinate Loss (MSE, Normalized): 0.0169
  Average Count Loss (CrossEntropy):       1.9580

--- Script Finished ---
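
To put the count metrics above in context, it helps to compare them with what a model with no information would score. The sketch below uses a uniform guess over the 11 count classes as the reference. That baseline is an assumption: the label distribution of the test set is not shown here, and a majority-class baseline computed from the real labels could be noticeably stronger.

```python
import math

num_count_classes = 11        # from the log: max_waypoints = 10 => 11 classes
logged_count_ce = 1.9580      # test-set count cross-entropy from the summary
logged_count_accuracy = 0.18  # test-set count accuracy from the summary

# Cross-entropy of a predictor that always outputs a uniform distribution
uniform_ce = math.log(num_count_classes)
# Expected accuracy of guessing a class uniformly at random
uniform_accuracy = 1 / num_count_classes

print(f"uniform CE       = {uniform_ce:.4f} (model: {logged_count_ce:.4f})")
print(f"uniform accuracy = {uniform_accuracy:.4f} (model: {logged_count_accuracy:.4f})")
```

Against this reference the count head does better than uniform guessing on both measures, but the margin is modest, which matches the average absolute count difference of about 1.36 waypoints reported above.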