# Bi-LSTM GNSS Spoofing Detection

A complete workflow for GNSS spoofing detection using a Bi-LSTM model in PyTorch. This notebook covers data loading, preprocessing, feature engineering, sequence creation, model training, evaluation, and export.

## 1. Environment Setup and Imports

Import required libraries for GNSS spoofing detection.

In [1]:
# Environment setup and imports
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Deep Learning - PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

# Sklearn
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, classification_report

print(f"PyTorch version: {torch.__version__}")
print(f"GPU Available: {torch.cuda.is_available()}")
print(f"GPU Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
print(f"Numpy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
print("All packages loaded successfully")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

PyTorch version: 2.9.1+cu126
GPU Available: True
GPU Device: NVIDIA GeForce RTX 4060 Laptop GPU
Numpy version: 1.26.0
Pandas version: 2.3.3
All packages loaded successfully
Using device: cuda


## 2. GNSS Data Loading

Load the GNSS dataset from CSV or parquet files. Display basic statistics and sample rows.

In [2]:
# GNSS dataset configuration
from pathlib import Path
import json
import pandas as pd
import os
import glob

# Update this path to your actual GNSS dataset directory
gnss_data_dir = Path("./GNSS3/Processed data")
print(f"GNSS data directory exists: {gnss_data_dir.exists()}")

obs_files = []
pvt_files = []
sat_files = []
# List JSON files for observation, pvtSolution, satelliteInfomation
for root, dirs, files in os.walk(gnss_data_dir):
    for dir in dirs: # dirs 21...30
        obs_files.extend(sorted(glob.glob(str(gnss_data_dir / dir / "observation*.json"))))
        pvt_files.extend(sorted(glob.glob(str(gnss_data_dir / dir / "pvtSolution*.json"))))
        sat_files.extend(sorted(glob.glob(str(gnss_data_dir / dir / "satelliteInfomation*.json"))))

print(f"Found {len(obs_files)} observation JSON files")
print(f"Found {len(pvt_files)} pvtSolution JSON files")
print(f"Found {len(sat_files)} satelliteInfomation JSON files")

# Load a sample of each type for inspection
def load_json_file(filepath):
    with open(filepath, "r") as f:
        return json.load(f)

# Inspect first file of each type
sample_obs = load_json_file(obs_files[0]) if obs_files else None
sample_pvt = load_json_file(pvt_files[0]) if pvt_files else None
sample_sat = load_json_file(sat_files[0]) if sat_files else None

print("Sample observation JSON keys:", list(sample_obs.keys()) if sample_obs else "None")
print("Sample pvtSolution JSON keys:", list(sample_pvt.keys()) if sample_pvt else "None")
print("Sample satelliteInfomation JSON keys:", list(sample_sat.keys()) if sample_sat else "None")

# Convert all observation JSONs to DataFrame (example for observation)
obs_data = []
for f in obs_files:
    data = load_json_file(f)
    if isinstance(data, dict):
        obs_data.append(data)
    elif isinstance(data, list):
        obs_data.extend(data)
gnss_df = pd.DataFrame(obs_data)
print("GNSS DataFrame loaded from observation JSONs:")
print(f"  Records: {len(gnss_df):,}")
print(f"  Columns: {list(gnss_df.columns)}")
print(gnss_df.head())

GNSS data directory exists: True
Found 240 observation JSON files
Found 240 pvtSolution JSON files
Found 240 satelliteInfomation JSON files
Sample observation JSON keys: ['recordTime', 'VSG', 'VSE', 'VSB', 'VSQ', 'VSR', 'prMes_G1', 'doMes_G1', 'cpMes_G1', 'cn0_G1', 'prStd_G1', 'cpStd_G1', 'doStd_G1', 'prMes_G2', 'doMes_G2', 'cpMes_G2', 'cn0_G2', 'prStd_G2', 'cpStd_G2', 'doStd_G2', 'prMes_E1', 'doMes_E1', 'cpMes_E1', 'cn0_E1', 'prStd_E1', 'cpStd_E1', 'doStd_E1', 'prMes_E2', 'doMes_E2', 'cpMes_E2', 'cn0_E2', 'prStd_E2', 'cpStd_E2', 'doStd_E2', 'prMes_B1', 'doMes_B1', 'cpMes_B1', 'cn0_B1', 'prStd_B1', 'cpStd_B1', 'doStd_B1', 'prMes_B2', 'doMes_B2', 'cpMes_B2', 'cn0_B2', 'prStd_B2', 'cpStd_B2', 'doStd_B2', 'prMes_Q1', 'doMes_Q1', 'cpMes_Q1', 'cn0_Q1', 'prStd_Q1', 'cpStd_Q1', 'doStd_Q1', 'prMes_Q2', 'doMes_Q2', 'cpMes_Q2', 'cn0_Q2', 'prStd_Q2', 'cpStd_Q2', 'doStd_Q2', 'prMes_R1', 'doMes_R1', 'cpMes_R1', 'cn0_R1', 'prStd_R1', 'cpStd_R1', 'doStd_R1', 'prMes_R2', 'doMes_R2', 'cpMes_R2', 'cn0_R

MemoryError: 

In [None]:
!nvidia-smi

In [None]:
# --- GNSS Data EDA (Exploratory Data Analysis) ---
print(f"GNSS DataFrame shape: {gnss_df.shape}")
print(f"GNSS DataFrame columns: {list(gnss_df.columns)}")
print("\nHead:")
print(gnss_df.head())
print("\nDescribe:")
print(gnss_df.describe(include='all'))

# Check for unique device or identifier columns
for col in ['device_id', 'mmsi']:
    if col in gnss_df.columns:
        print(f"Unique {col} values: {gnss_df[col].nunique()}")
        print(f"{col} values: {gnss_df[col].unique()}")

# Check timestamp column
if 'timestamp' in gnss_df.columns:
    try:
        gnss_df['datetime'] = pd.to_datetime(gnss_df['timestamp'], errors='coerce')
        print(f"\nDate range: {gnss_df['datetime'].min()} to {gnss_df['datetime'].max()}")
    except Exception as e:
        print(f"Error converting timestamp: {e}")

# Value counts and non-zero checks for key columns
for col in ['speed', 'course', 'lat', 'lon']:
    if col in gnss_df.columns:
        print(f"\n{col} value counts (top 10):")
        print(gnss_df[col].value_counts().head(10))
        print(f"\n{col} summary:")
        print(gnss_df[col].describe())
        if gnss_df[col].dtype in [np.float64, np.int64]:
            print(f"Non-zero {col}:", (gnss_df[col] != 0).sum())

GNSS DataFrame shape: (24, 77)
GNSS DataFrame columns: ['recordTime', 'VSG', 'VSE', 'VSB', 'VSQ', 'VSR', 'prMes_G1', 'doMes_G1', 'cpMes_G1', 'cn0_G1', 'prStd_G1', 'cpStd_G1', 'doStd_G1', 'prMes_G2', 'doMes_G2', 'cpMes_G2', 'cn0_G2', 'prStd_G2', 'cpStd_G2', 'doStd_G2', 'prMes_E1', 'doMes_E1', 'cpMes_E1', 'cn0_E1', 'prStd_E1', 'cpStd_E1', 'doStd_E1', 'prMes_E2', 'doMes_E2', 'cpMes_E2', 'cn0_E2', 'prStd_E2', 'cpStd_E2', 'doStd_E2', 'prMes_B1', 'doMes_B1', 'cpMes_B1', 'cn0_B1', 'prStd_B1', 'cpStd_B1', 'doStd_B1', 'prMes_B2', 'doMes_B2', 'cpMes_B2', 'cn0_B2', 'prStd_B2', 'cpStd_B2', 'doStd_B2', 'prMes_Q1', 'doMes_Q1', 'cpMes_Q1', 'cn0_Q1', 'prStd_Q1', 'cpStd_Q1', 'doStd_Q1', 'prMes_Q2', 'doMes_Q2', 'cpMes_Q2', 'cn0_Q2', 'prStd_Q2', 'cpStd_Q2', 'doStd_Q2', 'prMes_R1', 'doMes_R1', 'cpMes_R1', 'cn0_R1', 'prStd_R1', 'cpStd_R1', 'doStd_R1', 'prMes_R2', 'doMes_R2', 'cpMes_R2', 'cn0_R2', 'prStd_R2', 'cpStd_R2', 'doStd_R2', 'scenario']

Head:


Label the data with attack type

In [None]:
import numpy as np

# Example: label by file path
def get_label_from_filepath(filepath):
    if "Spoofing" in str(filepath):
        return "spoofing"
    elif "Jamming" in str(filepath):
        return "jamming"
    else:
        return "clean"

# Apply labels to the DataFrame
gnss_df["scenario"] = [get_label_from_filepath(f) for f in obs_files]


In [None]:
all_dfs = []

for day_folder in range(21, 31):  # days 21 to 30
    path = f'processed/{day_folder}'
    obs_files = glob.glob(f'{path}/observation*.json')
    pvt_files = glob.glob(f'{path}/pvtSolution*.json')
    sat_files = glob.glob(f'{path}/satelliteInformation*.json')

    obs_data = [load_json_file(f) for f in obs_files]
    pvt_data = [load_json_file(f) for f in pvt_files]
    sat_data = [load_json_file(f) for f in sat_files]

    obs_df = pd.DataFrame(obs_data)
    pvt_df = pd.DataFrame(pvt_data)
    sat_df = pd.DataFrame(sat_data)

    merged = obs_df.merge(pvt_df, on="recordTime", how="left")
    merged = merged.merge(sat_df, on="recordTime", how="left")

    all_dfs.append(merged)

full_df = pd.concat(all_dfs)
full_df['recordTime'] = pd.to_datetime(full_df['recordTime'])
full_df.set_index('recordTime', inplace=True)


In [None]:
print(gnss_df['recordTime'].apply(type).value_counts())
print(pvt_df['recordTime'].apply(type).value_counts())
print(sat_df['recordTime'].apply(type).value_counts())


recordTime
<class 'list'>    24
Name: count, dtype: int64
recordTime
<class 'list'>    24
Name: count, dtype: int64
recordTime
<class 'list'>    24
Name: count, dtype: int64


recordTime column in all three DataFrames is a list, so flatten

In [None]:
# Flatten recordTime lists to single values
gnss_df['recordTime'] = gnss_df['recordTime'].apply(lambda x: x[0] if isinstance(x, list) else x)
pvt_df['recordTime'] = pvt_df['recordTime'].apply(lambda x: x[0] if isinstance(x, list) else x)
sat_df['recordTime'] = sat_df['recordTime'].apply(lambda x: x[0] if isinstance(x, list) else x)


In [None]:
gnss_df['recordTime'] = pd.to_datetime(gnss_df['recordTime'])
pvt_df['recordTime'] = pd.to_datetime(pvt_df['recordTime'])
sat_df['recordTime'] = pd.to_datetime(sat_df['recordTime'])

Merge with PVT or satellite info

In [None]:
# Load PVT and Satellite info DataFrames (similar to obs_data)
pvt_data = [load_json_file(f) for f in pvt_files]
pvt_df = pd.DataFrame(pvt_data)

sat_data = [load_json_file(f) for f in sat_files]
sat_df = pd.DataFrame(sat_data)

# Make sure recordTime is scalar (flatten lists if needed)
gnss_df['recordTime'] = gnss_df['recordTime'].apply(lambda x: x[0] if isinstance(x, list) else x)
pvt_df['recordTime'] = pvt_df['recordTime'].apply(lambda x: x[0] if isinstance(x, list) else x)
sat_df['recordTime'] = sat_df['recordTime'].apply(lambda x: x[0] if isinstance(x, list) else x)

# Optional: convert to datetime for consistent merging
gnss_df['recordTime'] = pd.to_datetime(gnss_df['recordTime'])
pvt_df['recordTime'] = pd.to_datetime(pvt_df['recordTime'])
sat_df['recordTime'] = pd.to_datetime(sat_df['recordTime'])

# Merge with GNSS observation by timestamp
merged_df = gnss_df.merge(pvt_df, on="recordTime", how="left")
merged_df = merged_df.merge(sat_df, on="recordTime", how="left")

# Inspect merged DataFrame
print("Merged DataFrame columns:", merged_df.columns)
print(f"Records after merge: {len(merged_df):,}")


Merged DataFrame columns: Index(['recordTime', 'VSG', 'VSE', 'VSB', 'VSQ', 'VSR', 'prMes_G1', 'doMes_G1',
       'cpMes_G1', 'cn0_G1',
       ...
       'qualityInd_Q', 'health_Q', 'svId_R', 'svUsed_R', 'cno_R', 'elev_R',
       'azim_R', 'prRes_R', 'qualityInd_R', 'health_R'],
      dtype='object', length=147)
Records after merge: 24


Since merged_df comes from GNSS, PVT, and satellite JSONs, likely we have columns where missing values mean “data not available” rather than “carry the last value forward.”

In [None]:
missing_summary = merged_df.isna().sum().sort_values(ascending=False)
print(missing_summary[missing_summary > 0])

Series([], dtype: int64)


In [None]:
# # Forward fill / backward fill only numeric columns
# numeric_cols = merged_df.select_dtypes(include='number').columns
# merged_df[numeric_cols] = merged_df[numeric_cols].fillna(method='ffill')
# merged_df[numeric_cols] = merged_df[numeric_cols].fillna(method='bfill')

# # For object or list columns, you may choose to fill with a placeholder
# object_cols = merged_df.select_dtypes(include='object').columns
# merged_df[object_cols] = merged_df[object_cols].fillna('missing')


In [None]:
# missing_summary = merged_df.isna().sum().sort_values(ascending=False)
# print(missing_summary[missing_summary > 0])


In [None]:
print(merged_df.dtypes)
merged_df['recordTime'] = pd.to_datetime(merged_df['recordTime'])


recordTime      datetime64[ns]
VSG                     object
VSE                     object
VSB                     object
VSQ                     object
                     ...      
elev_R                  object
azim_R                  object
prRes_R                 object
qualityInd_R            object
health_R                object
Length: 147, dtype: object


In [None]:
# Check for duplicate timestamps
duplicate_count = merged_df.duplicated(subset=['recordTime']).sum()
print(f"Number of duplicate recordTime entries: {duplicate_count}")

Number of duplicate recordTime entries: 0


In [None]:
# Check first few entries
print(merged_df['recordTime'].head())

# Check the type of each entry
print(merged_df['recordTime'].map(type).value_counts())

# Optional: try parsing with errors='coerce' to see if any fail
invalid_times = pd.to_datetime(merged_df['recordTime'], errors='coerce').isna().sum()
print(f"Number of invalid timestamps: {invalid_times}")


0   2023-09-30 00:00:00
1   2023-09-30 01:00:00
2   2023-09-30 10:00:00
3   2023-09-30 11:00:00
4   2023-09-30 12:00:00
Name: recordTime, dtype: datetime64[ns]
recordTime
<class 'pandas._libs.tslibs.timestamps.Timestamp'>    24
Name: count, dtype: int64
Number of invalid timestamps: 0


In [None]:
merged_df.set_index('recordTime', inplace=True)

In [None]:
print(merged_df.index)       # shows the current index object
print(merged_df.index.name)  # name of the index column, if any

DatetimeIndex(['2023-09-30 00:00:00', '2023-09-30 01:00:00',
               '2023-09-30 10:00:00', '2023-09-30 11:00:00',
               '2023-09-30 12:00:00', '2023-09-30 13:00:00',
               '2023-09-30 14:00:00', '2023-09-30 15:00:00',
               '2023-09-30 16:00:00', '2023-09-30 17:00:00',
               '2023-09-30 18:00:00', '2023-09-30 19:00:00',
               '2023-09-30 02:00:00', '2023-09-30 20:00:00',
               '2023-09-30 21:00:00', '2023-09-30 22:00:00',
               '2023-09-30 23:00:00', '2023-09-30 03:00:00',
               '2023-09-30 04:00:00', '2023-09-30 05:00:00',
               '2023-09-30 06:00:00', '2023-09-30 07:00:00',
               '2023-09-30 08:00:00', '2023-09-30 09:00:00'],
              dtype='datetime64[ns]', name='recordTime', freq=None)
recordTime


In [None]:
print(merged_df.info())        # check column types, memory usage
print(merged_df.describe())    # numeric stats
print(merged_df.head())        # first few rows

<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 24 entries, 2023-09-30 00:00:00 to 2023-09-30 09:00:00
Columns: 146 entries, VSG to health_R
dtypes: object(146)
memory usage: 27.6+ KB
None
                                                      VSG  \
count                                                  24   
unique                                                 24   
top     [[0.0, 2.0, 0.0, 4.0, 0.0, 0.0, 7.0, 8.0, 0.0,...   
freq                                                    1   

                                                      VSE  \
count                                                  24   
unique                                                 24   
top     [[0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...   
freq                                                    1   

                                                      VSB  \
count                                                  24   
unique                                                 24   
top     [[0

## 3. Data Cleaning and Preprocessing

Apply domain-specific cleaning rules: remove invalid coordinates, handle missing values, and filter outliers.

In [None]:
# Print columns of gnss_df to debug KeyError
print("gnss_df columns:", list(gnss_df.columns))

gnss_df columns: ['recordTime', 'VSG', 'VSE', 'VSB', 'VSQ', 'VSR', 'prMes_G1', 'doMes_G1', 'cpMes_G1', 'cn0_G1', 'prStd_G1', 'cpStd_G1', 'doStd_G1', 'prMes_G2', 'doMes_G2', 'cpMes_G2', 'cn0_G2', 'prStd_G2', 'cpStd_G2', 'doStd_G2', 'prMes_E1', 'doMes_E1', 'cpMes_E1', 'cn0_E1', 'prStd_E1', 'cpStd_E1', 'doStd_E1', 'prMes_E2', 'doMes_E2', 'cpMes_E2', 'cn0_E2', 'prStd_E2', 'cpStd_E2', 'doStd_E2', 'prMes_B1', 'doMes_B1', 'cpMes_B1', 'cn0_B1', 'prStd_B1', 'cpStd_B1', 'doStd_B1', 'prMes_B2', 'doMes_B2', 'cpMes_B2', 'cn0_B2', 'prStd_B2', 'cpStd_B2', 'doStd_B2', 'prMes_Q1', 'doMes_Q1', 'cpMes_Q1', 'cn0_Q1', 'prStd_Q1', 'cpStd_Q1', 'doStd_Q1', 'prMes_Q2', 'doMes_Q2', 'cpMes_Q2', 'cn0_Q2', 'prStd_Q2', 'cpStd_Q2', 'doStd_Q2', 'prMes_R1', 'doMes_R1', 'cpMes_R1', 'cn0_R1', 'prStd_R1', 'cpStd_R1', 'doStd_R1', 'prMes_R2', 'doMes_R2', 'cpMes_R2', 'cn0_R2', 'prStd_R2', 'cpStd_R2', 'doStd_R2']


In [None]:
# Data cleaning for GNSS

def clean_gnss_data(df):
    """
    Clean GNSS data by removing invalid coordinates, handling missing values, and filtering outliers.
    """
    original_len = len(df)
    df = df.copy()
    # Print columns for debugging
    print("[clean_gnss_data] Columns:", list(df.columns))
    # Remove invalid lat/lon
    df = df[(df['lat'].between(-90, 90)) & (df['lon'].between(-180, 180))]
    # Drop rows with missing critical fields
    df = df.dropna(subset=['lat', 'lon', 'timestamp'])
    # Remove outliers in speed (e.g., > 100 m/s)
    if 'speed' in df.columns:
        df = df[df['speed'] <= 100]
    removed = original_len - len(df)
    print(f"Data Cleaning:")
    print(f"  Original records: {original_len:,}")
    print(f"  After cleaning: {len(df):,}")
    print(f"  Removed: {removed:,} ({removed/original_len*100:.2f}%)")
    return df

gnss_clean = clean_gnss_data(gnss_df)
print(gnss_clean.describe())

[clean_gnss_data] Columns: ['recordTime', 'VSG', 'VSE', 'VSB', 'VSQ', 'VSR', 'prMes_G1', 'doMes_G1', 'cpMes_G1', 'cn0_G1', 'prStd_G1', 'cpStd_G1', 'doStd_G1', 'prMes_G2', 'doMes_G2', 'cpMes_G2', 'cn0_G2', 'prStd_G2', 'cpStd_G2', 'doStd_G2', 'prMes_E1', 'doMes_E1', 'cpMes_E1', 'cn0_E1', 'prStd_E1', 'cpStd_E1', 'doStd_E1', 'prMes_E2', 'doMes_E2', 'cpMes_E2', 'cn0_E2', 'prStd_E2', 'cpStd_E2', 'doStd_E2', 'prMes_B1', 'doMes_B1', 'cpMes_B1', 'cn0_B1', 'prStd_B1', 'cpStd_B1', 'doStd_B1', 'prMes_B2', 'doMes_B2', 'cpMes_B2', 'cn0_B2', 'prStd_B2', 'cpStd_B2', 'doStd_B2', 'prMes_Q1', 'doMes_Q1', 'cpMes_Q1', 'cn0_Q1', 'prStd_Q1', 'cpStd_Q1', 'doStd_Q1', 'prMes_Q2', 'doMes_Q2', 'cpMes_Q2', 'cn0_Q2', 'prStd_Q2', 'cpStd_Q2', 'doStd_Q2', 'prMes_R1', 'doMes_R1', 'cpMes_R1', 'cn0_R1', 'prStd_R1', 'cpStd_R1', 'doStd_R1', 'prMes_R2', 'doMes_R2', 'cpMes_R2', 'cn0_R2', 'prStd_R2', 'cpStd_R2', 'doStd_R2']


KeyError: 'lat'

Output: [clean_gnss_data] Columns: ['recordTime', 'VSG', 'VSE', 'VSB', 'VSQ', 'VSR', 'prMes_G1', 'doMes_G1', 'cpMes_G1', 'cn0_G1', 'prStd_G1', 'cpStd_G1', 'doStd_G1', 'prMes_G2', 'doMes_G2', 'cpMes_G2', 'cn0_G2', 'prStd_G2', 'cpStd_G2', 'doStd_G2', 'prMes_E1', 'doMes_E1', 'cpMes_E1', 'cn0_E1', 'prStd_E1', 'cpStd_E1', 'doStd_E1', 'prMes_E2', 'doMes_E2', 'cpMes_E2', 'cn0_E2', 'prStd_E2', 'cpStd_E2', 'doStd_E2', 'prMes_B1', 'doMes_B1', 'cpMes_B1', 'cn0_B1', 'prStd_B1', 'cpStd_B1', 'doStd_B1', 'prMes_B2', 'doMes_B2', 'cpMes_B2', 'cn0_B2', 'prStd_B2', 'cpStd_B2', 'doStd_B2', 'prMes_Q1', 'doMes_Q1', 'cpMes_Q1', 'cn0_Q1', 'prStd_Q1', 'cpStd_Q1', 'doStd_Q1', 'prMes_Q2', 'doMes_Q2', 'cpMes_Q2', 'cn0_Q2', 'prStd_Q2', 'cpStd_Q2', 'doStd_Q2', 'prMes_R1', 'doMes_R1', 'cpMes_R1', 'cn0_R1', 'prStd_R1', 'cpStd_R1', 'doStd_R1', 'prMes_R2', 'doMes_R2', 'cpMes_R2', 'cn0_R2', 'prStd_R2', 'cpStd_R2', 'doStd_R2']

## 4. Feature Engineering for GNSS Spoofing Detection

Extract features such as latitude, longitude, speed, course, time-based features, and calculate movement deltas. Fill NaN values and display feature matrix.

In [None]:
len(gnss_df["recordTime"]), type(gnss_df["recordTime"])

(24, pandas.core.series.Series)

In [None]:
# Feature extraction for GNSS spoofing detection

def extract_gnss_features(df):
    df = df.copy()
    features = ['lat', 'lon']
    if 'speed' in df.columns:
        features.append('speed')
    if 'course' in df.columns:
        features.append('course')
    # Time-based features
    df['recordTime'] = pd.to_datetime(df['recordTime'])
    df['hour'] = df['recordTime'].dt.hour
    df['day_of_week'] = df['recordTime'].dt.dayofweek
    features.extend(['hour', 'day_of_week'])
    # Movement deltas
    df = df.sort_values(['device_id', 'recordTime']) if 'device_id' in df.columns else df.sort_values('recordTime')
    df['lat_diff'] = df['lat'].diff()
    df['lon_diff'] = df['lon'].diff()
    df['distance'] = np.sqrt(df['lat_diff']**2 + df['lon_diff']**2)
    features.append('distance')
    # Fill NaN values
    df[features] = df[features].fillna(method='bfill').fillna(method='ffill').fillna(0)
    print(f"Feature Extraction Complete:")
    print(f"  Features: {features}")
    print(f"  Feature matrix shape: {df[features].shape}")
    return df, features

# gnss_featured, feature_cols = extract_gnss_features(gnss_clean)
gnss_featured, feature_cols = extract_gnss_features(gnss_df)
print(gnss_featured[feature_cols].head(10))

TypeError: <class 'list'> is not convertible to datetime, at position 0

## 5. Sequence Creation for LSTM

Create temporal sequences from the feature matrix for LSTM input. Use sliding windows and assign spoofing labels to sequences.

In [None]:
# Sequence creation for LSTM

def create_gnss_sequences(df, feature_cols, sequence_length=128, stride=32, label_col='is_spoofed'):
    X_list, y_list = [], []
    # If device_id exists, group by device, else treat as single trajectory
    group_key = 'device_id' if 'device_id' in df.columns else None
    groups = df.groupby(group_key) if group_key else [(None, df)]
    for _, group in groups:
        group = group.sort_values('timestamp')
        features = group[feature_cols].values
        labels = group[label_col].values if label_col in group.columns else np.zeros(len(group))
        if len(features) < sequence_length:
            continue
        for i in range(0, len(features) - sequence_length + 1, stride):
            seq_x = features[i:i+sequence_length]
            seq_y = labels[i:i+sequence_length]
            # Label as spoofed if majority of points are spoofed
            is_spoofed = int(seq_y.sum() > (sequence_length // 2))
            X_list.append(seq_x)
            y_list.append(is_spoofed)
    X_array = np.array(X_list, dtype=np.float32)
    y_array = np.array(y_list, dtype=np.uint8)
    print(f"Sequence Creation:")
    print(f"  Sequence length: {sequence_length}")
    print(f"  Stride: {stride}")
    print(f"  Total sequences: {len(X_array):,}")
    print(f"  Spoofed sequences: {(y_array == 1).sum():,}")
    print(f"  Genuine sequences: {(y_array == 0).sum():,}")
    return X_array, y_array

# Example: If you have spoofing labels, add a column 'is_spoofed' to gnss_featured, else all zeros
gnss_featured['is_spoofed'] = 0  # Replace with actual labels if available
SEQUENCE_LENGTH = 128
STRIDE = 32
X_sequences, y_labels = create_gnss_sequences(gnss_featured, feature_cols, sequence_length=SEQUENCE_LENGTH, stride=STRIDE, label_col='is_spoofed')

## 6. Train/Test Split

Split the sequences into training, validation, and test sets. Ensure no data leakage by splitting by device or trajectory if possible.

In [None]:
# Train/test/validation split
# If possible, split by device_id to avoid leakage. Otherwise, use stratified random split.

X_train, X_test, y_train, y_test = train_test_split(
    X_sequences, y_labels, test_size=0.2, random_state=42, stratify=y_labels
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.125, random_state=42, stratify=y_train  # 0.125 * 0.8 = 0.1 of total
)
print(f"Data Split:")
print(f"  Train: {len(X_train):,} sequences")
print(f"  Val:   {len(X_val):,} sequences")
print(f"  Test:  {len(X_test):,} sequences")
print(f"  Train - Genuine: {(y_train==0).sum()}, Spoofed: {(y_train==1).sum()}")
print(f"  Val   - Genuine: {(y_val==0).sum()}, Spoofed: {(y_val==1).sum()}")
print(f"  Test  - Genuine: {(y_test==0).sum()}, Spoofed: {(y_test==1).sum()}")

## 7. Bi-LSTM Model Definition (PyTorch)

Define the Bi-LSTM model architecture in PyTorch, including input, LSTM layers, and output layer.

In [None]:
# Bi-LSTM model definition for GNSS spoofing detection
class BiLSTMModel(nn.Module):
    def __init__(self, input_size, lstm_units_1=62, lstm_units_2=30):
        super(BiLSTMModel, self).__init__()
        self.bilstm1 = nn.LSTM(
            input_size=input_size,
            hidden_size=lstm_units_1,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )
        self.bilstm2 = nn.LSTM(
            input_size=lstm_units_1 * 2,
            hidden_size=lstm_units_2,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )
        self.fc = nn.Linear(lstm_units_2 * 2, 1)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        lstm1_out, _ = self.bilstm1(x)
        lstm2_out, _ = self.bilstm2(lstm1_out)
        last_hidden = lstm2_out[:, -1, :]
        output = self.fc(last_hidden)
        output = self.sigmoid(output)
        return output

def build_bilstm_model(input_size, lstm_units_1=62, lstm_units_2=30):
    model = BiLSTMModel(input_size, lstm_units_1, lstm_units_2)
    return model.to(device)

print("Building Bi-LSTM Model (PyTorch)...")
print(f"  Input size: {X_train.shape[2]} features")
model = build_bilstm_model(input_size=X_train.shape[2])
print(model)

## 8. Model Training Loop

Implement the training loop with weighted loss for class imbalance, early stopping, and metric calculation.

In [None]:
# Training configuration
BATCH_SIZE = 30
EPOCHS = 50
LEARNING_RATE = 0.004
PATIENCE = 10

# Class weights for imbalance
n_genuine = (y_train == 0).sum()
n_spoofed = (y_train == 1).sum()
total_samples = len(y_train)
weight_genuine = total_samples / (2.0 * n_genuine)
weight_spoofed = total_samples / (2.0 * n_spoofed) if n_spoofed > 0 else 1.0

print(f"Class Distribution in Training Set:")
print(f"  Genuine: {n_genuine:,} ({n_genuine/total_samples*100:.2f}%)")
print(f"  Spoofed: {n_spoofed:,} ({n_spoofed/total_samples*100:.2f}%)")
print(f"Class Weights:")
print(f"  Genuine weight: {weight_genuine:.4f}")
print(f"  Spoofed weight: {weight_spoofed:.4f}")

# Data loaders
train_dataset = TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(y_train).float().reshape(-1, 1))
val_dataset = TensorDataset(torch.from_numpy(X_val).float(), torch.from_numpy(y_val).float().reshape(-1, 1))
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Weighted BCE loss
pos_weight = torch.tensor([weight_spoofed / weight_genuine]).to(device) if n_spoofed > 0 else torch.tensor([1.0]).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) if n_spoofed > 0 else nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop with early stopping
history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': [], 'train_precision': [], 'val_precision': [], 'train_recall': [], 'val_recall': []}
best_val_loss = float('inf')
patience_counter = 0
best_model_state = None

print(f"\nTraining Configuration:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Early stopping patience: {PATIENCE}")
print(f"  Loss function: BCEWithLogitsLoss (weighted)")
print(f"  Device: {device}")
print(f"\nStarting training...\n")

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    all_preds, all_labels = [], []
    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch_x.size(0)
        all_preds.extend(torch.sigmoid(outputs).detach().cpu().numpy())
        all_labels.extend(batch_y.detach().cpu().numpy())
    return total_loss / len(loader.dataset), np.array(all_preds), np.array(all_labels)

def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            total_loss += loss.item() * batch_x.size(0)
            all_preds.extend(torch.sigmoid(outputs).cpu().numpy())
            all_labels.extend(batch_y.cpu().numpy())
    return total_loss / len(loader.dataset), np.array(all_preds), np.array(all_labels)

for epoch in range(EPOCHS):
    train_loss, train_preds, train_labels = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_preds, val_labels = validate(model, val_loader, criterion, device)
    train_preds_binary = (train_preds > 0.5).astype(int)
    val_preds_binary = (val_preds > 0.5).astype(int)
    train_acc = (train_preds_binary == train_labels).mean()
    val_acc = (val_preds_binary == val_labels).mean()
    train_prec = precision_score(train_labels, train_preds_binary, zero_division=0)
    val_prec = precision_score(val_labels, val_preds_binary, zero_division=0)
    train_rec = recall_score(train_labels, train_preds_binary, zero_division=0)
    val_rec = recall_score(val_labels, val_preds_binary, zero_division=0)
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_accuracy'].append(train_acc)
    history['val_accuracy'].append(val_acc)
    history['train_precision'].append(train_prec)
    history['val_precision'].append(val_prec)
    history['train_recall'].append(train_rec)
    history['val_recall'].append(val_rec)
    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Recall: {val_rec:.4f}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        best_model_state = model.state_dict().copy()
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"\nEarly stopping at epoch {epoch+1}")
            model.load_state_dict(best_model_state)
            break
print("\nTraining Complete!")

## 9. Training History Visualization

Plot training and validation loss, accuracy, precision, and recall over epochs using matplotlib.

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# Loss
axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Model Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Accuracy
axes[1].plot(history['train_accuracy'], label='Train Accuracy', linewidth=2)
axes[1].plot(history['val_accuracy'], label='Val Accuracy', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Model Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
# Precision & Recall
axes[2].plot(history['train_precision'], label='Train Precision', linewidth=2)
axes[2].plot(history['val_precision'], label='Val Precision', linewidth=2, linestyle='--')
axes[2].plot(history['train_recall'], label='Train Recall', linewidth=2)
axes[2].plot(history['val_recall'], label='Val Recall', linewidth=2, linestyle='--')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Score')
axes[2].set_title('Precision & Recall')
axes[2].legend()
axes[2].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 10. Threshold Optimization

Find the optimal decision threshold on the validation set by maximizing F1-score.

In [None]:
# Threshold optimization on validation set
thresholds = np.arange(0.1, 0.95, 0.05)
best_f1 = 0
best_threshold = 0.5
best_metrics = {}
val_probs_all = []
val_labels_all = []
model.eval()
with torch.no_grad():
    for batch_x, batch_y in val_loader:
        batch_x = batch_x.to(device)
        outputs = model(batch_x).cpu().numpy().flatten()
        val_probs_all.extend(outputs)
        val_labels_all.extend(batch_y.numpy().flatten())
val_probs_all = np.array(val_probs_all)
val_labels_all = np.array(val_labels_all)
print(f"{'Threshold':<12} {'Precision':<12} {'Recall':<12} {'F1-Score':<12}")
print("-" * 50)
for threshold in thresholds:
    preds = (val_probs_all > threshold).astype(int)
    f1 = f1_score(val_labels_all, preds, zero_division=0)
    precision = precision_score(val_labels_all, preds, zero_division=0)
    recall = recall_score(val_labels_all, preds, zero_division=0)
    print(f"τ={threshold:<8.2f}   {precision:<12.4f} {recall:<12.4f} {f1:<12.4f}")
    if f1 > best_f1:
        best_f1 = f1
        best_threshold = threshold
        best_metrics = {'f1': f1, 'precision': precision, 'recall': recall}
THRESHOLD = best_threshold
print(f"\nOptimal Threshold: {THRESHOLD:.2f}")
if best_metrics:
    print(f"  F1-Score:  {best_metrics['f1']:.4f}")
    print(f"  Precision: {best_metrics['precision']:.4f}")
    print(f"  Recall:    {best_metrics['recall']:.4f}")

## 11. Model Evaluation on Test Set

Evaluate the trained model on the test set. Calculate accuracy, precision, recall, F1-score, and detection rate.

In [None]:
# Evaluate on test set
BATCH_SIZE_INFERENCE = 256
model.eval()
y_pred_probs = []
test_losses = []
with torch.no_grad():
    for i in range(0, len(X_test), BATCH_SIZE_INFERENCE):
        batch_X = torch.from_numpy(X_test[i:i+BATCH_SIZE_INFERENCE]).float().to(device)
        batch_y = torch.from_numpy(y_test[i:i+BATCH_SIZE_INFERENCE]).float().reshape(-1, 1).to(device)
        batch_probs = model(batch_X)
        y_pred_probs.extend(batch_probs.cpu().numpy())
        batch_loss = criterion(batch_probs, batch_y)
        test_losses.append(batch_loss.item() * len(batch_X))
y_pred_probs = np.array(y_pred_probs).flatten()
y_pred = (y_pred_probs > THRESHOLD).astype(int)
test_loss = sum(test_losses) / len(X_test)
test_acc = (y_pred == y_test).mean()
precision = precision_score(y_test, y_pred, zero_division=0)
recall = recall_score(y_test, y_pred, zero_division=0)
f1 = f1_score(y_test, y_pred, zero_division=0)
detection_rate = recall
print("Test Set Evaluation:")
print(f"  Accuracy:       {test_acc:.4f}")
print(f"  Precision:      {precision:.4f}")
print(f"  Recall:         {recall:.4f}")
print(f"  F1-Score:       {f1:.4f}")
print(f"  Detection Rate: {detection_rate:.4f}")

## 12. Confusion Matrix and Classification Report

Plot the confusion matrix and print the detailed classification report for test predictions.

In [None]:
# Confusion matrix and classification report
cm = confusion_matrix(y_test, y_pred)
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(cm, cmap='Blues')
ax.set_xticks([0, 1])
ax.set_yticks([0, 1])
ax.set_xticklabels(['Genuine', 'Spoofed'])
ax.set_yticklabels(['Genuine', 'Spoofed'])
for i in range(2):
    for j in range(2):
        text = ax.text(j, i, cm[i, j], ha="center", va="center", color="white" if cm[i, j] > cm.max() / 2 else "black", fontsize=16, fontweight='bold')
ax.set_xlabel('Predicted Label', fontsize=12)
ax.set_ylabel('True Label', fontsize=12)
ax.set_title('Confusion Matrix', fontsize=14, fontweight='bold')
plt.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()
print("\nDetailed Classification Report:")
print(classification_report(y_test, y_pred, target_names=['Genuine', 'Spoofed']))

## 13. Model Export (Save Model and Scaler)

Save the trained PyTorch model, feature scaler, and configuration to disk for deployment.

In [None]:
# Save model artifacts
import pickle
output_root = Path("./models")
output_root.mkdir(parents=True, exist_ok=True)
model_path = output_root / 'gnss_bilstm_model.pt'
torch.save(model.state_dict(), model_path)
# Save scaler (if used)
# scaler_path = output_root / 'feature_scaler.pkl'
# with open(scaler_path, 'wb') as f:
#     pickle.dump(feature_scaler, f)
# Save configuration
config = {
    'feature_cols': feature_cols,
    'sequence_length': SEQUENCE_LENGTH,
    'threshold': THRESHOLD,
    'lstm_units_1': 62,
    'lstm_units_2': 30,
    'learning_rate': LEARNING_RATE,
    'batch_size': BATCH_SIZE,
    'input_size': X_train.shape[2],
    'test_metrics': {
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1),
        'detection_rate': float(detection_rate),
        'accuracy': float(test_acc)
    }
}
config_path = output_root / 'gnss_model_config.pkl'
with open(config_path, 'wb') as f:
    pickle.dump(config, f)
print(f"Model saved to: {model_path}")
print(f"Config saved to: {config_path}")