In [1]:
%matplotlib inline

from utils import (
    create_mmsi_dict_from_file,
    filter_stationary_ships,
    segment_and_renumber,
    haversine_m,
    prepare_training_data,
)

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# Fecthing data

In [2]:
file_name = "data/mmsi_type.txt"
mmsi_map = create_mmsi_dict_from_file(file_name)

if mmsi_map:
    print("--- Successfully created dictionary ---")

--- Successfully created dictionary ---


In [5]:
df = pd.read_csv("data/ais_combined.csv")
df_with_types = df.copy()
df_with_types['Type'] = df_with_types['MMSI'].astype(str).map(mmsi_map)
allowed_type = ['Cargo ship', 'Cargo ship (HAZ-A)', 'Cargo ship (HAZ-B)', 'Cargo ship (HAZ-D)', 'Tanker', 'Tanker (HAZ-A)', 'Tanker (HAZ-B)', 'Tanker (HAZ-C)', 'Tanker (HAZ-D)']
df_cargo = df_with_types[df_with_types['Type'].isin(allowed_type)]

df_cargo = df_cargo.drop(columns=["Type"], axis= 1)

df_cargo_filtered = filter_stationary_ships(df_cargo) # This df has dropped stationary ships

Found 6 stationary ships out of 291.
Cleaned DF contains 285 ships.


# Data preprocessing

In [6]:
# Configuration Parameters
GAP_BREAK_MIN = 180
INTERPOLATION_LIMIT_MIN = None
MAX_DISTANCE_M = 3000
MAX_SOG_KNOTS = 40
OUTPUT_PATH = "data/ais_data_5min_clean.csv"
NUM_COLS = ["SOG", "COG", "Longtitude", "Latitude"]
MIN_SEGMENT_LENGTH = 25
INTERVAL = 5

In [8]:
print("="*60)
print("STEP 1: Data Preprocessing")
print("="*60)

# Sort data by MMSI and Timestamp
df_cargo = df_cargo.sort_values(["MMSI", "Timestamp"]).reset_index(drop=True)
df_cargo["Timestamp"] = pd.to_datetime(df_cargo["Timestamp"], errors="coerce")

print(f"Initial data shape: {df_cargo.shape}")
print(f"Data types:\n{df_cargo.dtypes}\n")

# Segment trajectories based on time gaps
print("Segmenting trajectories...")
df = segment_and_renumber(df_cargo, GAP_BREAK_MIN)

# Downsample & interpolate per segment
print(f"Downsampling to {INTERVAL}-minute intervals and interpolating...")
results = []

for (mmsi, seg), g in df.groupby(["MMSI", "Segment"], observed=True):
    g = g.set_index("Timestamp")
    
    # Downsample to 5-minute intervals (keep last observation)
    g1 = g.resample(f"{INTERVAL}min").last()
    
    # Interpolate numeric columns for short gaps only
    g1[NUM_COLS] = g1[NUM_COLS].interpolate(
        method="time", limit=INTERPOLATION_LIMIT_MIN, limit_direction="both"
    )
    
    # Drop rows where ANY of the critical columns are NaN
    g1 = g1.dropna(subset=NUM_COLS, how="any")
    
    # Fill identifiers
    g1["MMSI"] = mmsi
    g1["Segment"] = seg
    
    # Calculate distance and speed between consecutive points
    lat = g1["Latitude"].to_numpy()
    lon = g1["Longtitude"].to_numpy()
    lat_prev, lon_prev = np.roll(lat, 1), np.roll(lon, 1)
    lat_prev[0], lon_prev[0] = lat[0], lon[0]
    
    g1["distance_m"] = haversine_m(lat, lon, lat_prev, lon_prev)
    g1.loc[g1.index[0], "distance_m"] = 0.0
    g1["speed_mps_track"] = g1["distance_m"] / 60.0
    
    # Filter unrealistic movement or SOG
    g1 = g1[(g1["distance_m"] < MAX_DISTANCE_M) & (g1["SOG"] <= MAX_SOG_KNOTS)]
    
    results.append(g1)

# Combine all segments
df_clean = pd.concat(results).reset_index()

STEP 1: Data Preprocessing
Initial data shape: (1235258, 7)
Data types:
MMSI                   int64
SOG                  float64
COG                  float64
Longtitude           float64
Latitude             float64
Timestamp     datetime64[ns]
Segment                int64
dtype: object

Segmenting trajectories...
Downsampling to 5-minute intervals and interpolating...


In [9]:
print("="*60)
print("STEP 2: Data Quality Check")
print("="*60)
print(f"Rows before cleaning: {len(df_clean)}")

# Check for missing data
missing = df_clean[df_clean[["SOG", "COG", "Latitude", "Longtitude"]].isna().any(axis=1)]
print(f"Rows with missing numeric data: {len(missing)} ({len(missing)/len(df_clean)*100:.2f}%)")
print(f"MMSI with missing data: {missing['MMSI'].nunique()}")

# Remove rows with missing critical data
df_clean = df_clean.dropna(subset=["SOG", "COG", "Latitude", "Longtitude", "MMSI", "Segment"])
print(f"Rows after cleaning: {len(df_clean)}")

# Verify time gaps before re-segmentation
max_gap = df_clean.groupby(["MMSI","Segment"])["Timestamp"].diff().dt.total_seconds().div(60).max()
print(f"Maximum time gap before re-segmentation: {max_gap:.2f} minutes")

# Re-segment based on time gaps created by filtering (gaps > INTERVAL min = new segment)
print(f"\nRe-segmenting based on gaps > {INTERVAL} minutes...")
df_clean = df_clean.sort_values(["MMSI", "Segment", "Timestamp"]).reset_index(drop=True)
df_clean["time_gap"] = df_clean.groupby(["MMSI", "Segment"])["Timestamp"].diff().dt.total_seconds().div(60)
df_clean["new_seg"] = (df_clean["time_gap"] > INTERVAL) | (df_clean["time_gap"].isna())
df_clean["Segment"] = df_clean.groupby("MMSI")["new_seg"].cumsum()
df_clean = df_clean.drop(columns=["time_gap", "new_seg"])

# Verify time gaps after re-segmentation
max_gap_after = df_clean.groupby(["MMSI","Segment"])["Timestamp"].diff().dt.total_seconds().div(60).max()
print(f"Maximum time gap after re-segmentation: {max_gap_after:.2f} minutes")
has_large_gaps = (df_clean.groupby(["MMSI","Segment"])["Timestamp"]
                  .diff().dt.total_seconds().div(60).max() > INTERVAL).any()
print(f"Has gaps > {INTERVAL} minutes: {has_large_gaps}")

STEP 2: Data Quality Check
Rows before cleaning: 51903
Rows with missing numeric data: 0 (0.00%)
MMSI with missing data: 0
Rows after cleaning: 51903
Maximum time gap before re-segmentation: 250.00 minutes

Re-segmenting based on gaps > 5 minutes...
Maximum time gap after re-segmentation: 5.00 minutes
Has gaps > 5 minutes: False


In [10]:
print("="*60)
print("STEP 3: Segment Length Filtering")
print("="*60)
# Filter out short segments
print(f"\nFiltering segments with < {MIN_SEGMENT_LENGTH} points...")
print(f"Segments before filtering: {df_clean.groupby(['MMSI', 'Segment']).ngroups}")
print(f"Rows before filtering: {len(df_clean)}")

segment_sizes = df_clean.groupby(["MMSI", "Segment"]).size()
valid_segments = segment_sizes[segment_sizes >= MIN_SEGMENT_LENGTH].index
df_clean = df_clean.set_index(["MMSI", "Segment"]).loc[valid_segments].reset_index()

print(f"Segments after filtering: {df_clean.groupby(['MMSI', 'Segment']).ngroups}")
print(f"Rows after filtering: {len(df_clean)}")

STEP 3: Segment Length Filtering

Filtering segments with < 25 points...
Segments before filtering: 453
Rows before filtering: 51903
Segments after filtering: 309
Rows after filtering: 51023


In [11]:
print("\n" + "="*60)
print("STEP 4: Final Dataset Summary")
print("="*60)
print(f"Total rows: {len(df_clean)}")
print(f"Unique vessels (MMSI): {df_clean['MMSI'].nunique()}")
print(f"Total segments: {df_clean.groupby(['MMSI', 'Segment']).ngroups}")
print(f"Average segment length: {df_clean.groupby(['MMSI', 'Segment']).size().mean()*INTERVAL:.1f} minutes")
print(f"Columns: {list(df_clean.columns)}")

# Save cleaned data
df_clean.to_csv(OUTPUT_PATH, index=False)
print(f"\nCleaned data saved to: {OUTPUT_PATH}")



STEP 4: Final Dataset Summary
Total rows: 51023
Unique vessels (MMSI): 279
Total segments: 309
Average segment length: 825.6 minutes
Columns: ['MMSI', 'Segment', 'Timestamp', 'SOG', 'COG', 'Longtitude', 'Latitude', 'distance_m', 'speed_mps_track']

Cleaned data saved to: data/ais_data_5min_clean.csv


# Preparing Dataset

In [12]:
# Configuration for model training
SEQUENCE_LENGTH = 20
FEATURES = ["Latitude", "Longtitude", "SOG", "COG"] 
TARGET_FEATURES = ["Latitude", "Longtitude", "SOG", "COG"]
MIN_SEGMENT_LENGTH = SEQUENCE_LENGTH + 5

In [13]:
#Prepare sequences
X, y, segment_info = prepare_training_data(
    df_clean,
    SEQUENCE_LENGTH,
    FEATURES,
    TARGET_FEATURES,
    MIN_SEGMENT_LENGTH
)
print(f"Total sequences created: {len(X)}")
print(f"Input shape: {X.shape}")
print(f"Target shape: {y.shape}")
print(f"Segments used: {len(segment_info)}")
print(f"Average sequences per segment: {len(X) / len(segment_info):.1f}")

# Display some statistics
segment_lengths = [s['length'] for s in segment_info]
print(f"\nSegment statistics:")
print(f"  Min length: {min(segment_lengths)*INTERVAL} minutes")
print(f"  Max length: {max(segment_lengths)*INTERVAL} minutes")
print(f"  Mean length: {np.mean(segment_lengths)*INTERVAL:.1f} minutes")
print(f"  Median length: {np.median(segment_lengths)*INTERVAL:.1f} minutes")

Total sequences created: 44843
Input shape: (44843, 20, 4)
Target shape: (44843, 4)
Segments used: 44843
Average sequences per segment: 1.0

Segment statistics:
  Min length: 125 minutes
  Max length: 1440 minutes
  Mean length: 1149.0 minutes
  Median length: 1370.0 minutes


In [14]:
# Split data by ships (MMSI)
print("="*60)
print("Splitting Data by Ships (MMSI)")
print("="*60)

# Get unique MMSIs from segment_info
unique_mmsis = list(set([seg['mmsi'] for seg in segment_info]))
n_ships = len(unique_mmsis)

print(f"Total unique ships: {n_ships}")

# Split ships into train (64%), val (16%), test (20%)
mmsi_temp, mmsi_test = train_test_split(
    unique_mmsis, test_size=0.2, random_state=42, shuffle=True
)

mmsi_train, mmsi_val = train_test_split(
    mmsi_temp, test_size=0.2, random_state=42, shuffle=True
)

print(f"\nShips in training set: {len(mmsi_train)} ({len(mmsi_train)/n_ships*100:.1f}%)")
print(f"Ships in validation set: {len(mmsi_val)} ({len(mmsi_val)/n_ships*100:.1f}%)")
print(f"Ships in test set: {len(mmsi_test)} ({len(mmsi_test)/n_ships*100:.1f}%)")

# Create sets of MMSIs for fast lookup
mmsi_train_set = set(mmsi_train)
mmsi_val_set = set(mmsi_val)
mmsi_test_set = set(mmsi_test)

# Split sequences based on which ship they belong to
train_indices = [i for i, seg in enumerate(segment_info) if seg['mmsi'] in mmsi_train_set]
val_indices = [i for i, seg in enumerate(segment_info) if seg['mmsi'] in mmsi_val_set]
test_indices = [i for i, seg in enumerate(segment_info) if seg['mmsi'] in mmsi_test_set]

# Get the actual sequences for each set (RAW)
X_train_raw = X[train_indices]
y_train_raw = y[train_indices]

X_val_raw = X[val_indices]
y_val_raw = y[val_indices]

X_test_raw = X[test_indices]
y_test_raw = y[test_indices]

print(f"\nSequences in training set: {X_train_raw.shape[0]} ({X_train_raw.shape[0]/X.shape[0]*100:.1f}%)")
print(f"Sequences in validation set: {X_val_raw.shape[0]} ({X_val_raw.shape[0]/X.shape[0]*100:.1f}%)")
print(f"Sequences in test set: {X_test_raw.shape[0]} ({X_test_raw.shape[0]/X.shape[0]*100:.1f}%)")

Splitting Data by Ships (MMSI)
Total unique ships: 279

Ships in training set: 178 (63.8%)
Ships in validation set: 45 (16.1%)
Ships in test set: 56 (20.1%)

Sequences in training set: 28586 (63.7%)
Sequences in validation set: 7689 (17.1%)
Sequences in test set: 8568 (19.1%)


In [15]:
# Normalize the data
print("="*60)
print("Normalizing Data")
print("="*60)

# Reshape X_train for normalization
n_samples_train, n_timesteps, n_features = X_train_raw.shape
X_train_reshaped = X_train_raw.reshape(-1, n_features)

# Fit scaler on training data
scaler_X = StandardScaler()
X_train_normalized_reshaped = scaler_X.fit_transform(X_train_reshaped)
X_train = X_train_normalized_reshaped.reshape(n_samples_train, n_timesteps, n_features)

# Transform val and test
X_val_reshaped = X_val_raw.reshape(-1, n_features)
X_val = scaler_X.transform(X_val_reshaped).reshape(X_val_raw.shape[0], n_timesteps, n_features)

X_test_reshaped = X_test_raw.reshape(-1, n_features)
X_test = scaler_X.transform(X_test_reshaped).reshape(X_test_raw.shape[0], n_timesteps, n_features)

# Normalize targets
scaler_y = StandardScaler()
y_train = scaler_y.fit_transform(y_train_raw)
y_val = scaler_y.transform(y_val_raw)
y_test = scaler_y.transform(y_test_raw)

print(f"Input data normalized")
print(f"Target data normalized")
print(f"\nFeature means: {scaler_X.mean_}")
print(f"Feature stds: {scaler_X.scale_}")
print(f"\nTarget means: {scaler_y.mean_}")
print(f"Target stds: {scaler_y.scale_}")

Normalizing Data
Input data normalized
Target data normalized

Feature means: [ 56.00486371  11.13368651   3.58015288 174.77718504]
Feature stds: [  1.09773542   2.05185728   2.78642971 106.00857452]

Target means: [ 56.00520574  11.14744614   3.57007037 174.82785785]
Target stds: [  1.09808968   2.05125326   2.78513621 105.8189964 ]
