In [13]:
import pandas as pd
import torch as pt
import torch.nn as nn

import tqdm
import numpy as np
from ucf_atd_model.data import data_loc
from ucf_atd_model.datasets.create_link_data import calculate_link_features, haversine_distance_m
from ucf_atd_model.datasets.create_20class_data import subset, setidx, boolfilter, paddata, const_data
from sklearn.metrics import roc_auc_score, classification_report, accuracy_score
from ucf_atd_model.datasets.create_link_data import calculate_link_features, haversine_distance_m, project_forward
from datetime import datetime
import scipy as sp
from ucf_atd_model.c20_consts import *

import atd2025

In [2]:
device = pt.device("cpu")

Preprocessing, loading checkpoints

In [3]:
badnames = [x for x in full_names if x.endswith("_16")]

ynames = [x for x in ynames if not x.endswith("_16")]
xnames = [x for x in colnames if x not in badnames]

# Setup the model
inp_dim = len(colnames) - len(badnames) + 1
h_dim = 2000
out_dim = n_norm_classes + 1 - 1

model = nn.Sequential(
    nn.Linear(inp_dim, h_dim),
    nn.ReLU(),
    nn.Linear(h_dim, h_dim),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(h_dim, h_dim // 2),
    nn.ReLU(),
    nn.Dropout(0.5),
    # nn.Linear(h_dim // 2, h_dim // 2),
    # nn.ReLU(),
    # nn.Dropout(0.2),
    nn.Linear(h_dim // 2, out_dim),
).to(device)

model.load_state_dict(pt.load("checkpoints/epoch_180.pt", weights_only=True, map_location=device))
model.eval()

Sequential(
  (0): Linear(in_features=231, out_features=2000, bias=True)
  (1): ReLU()
  (2): Linear(in_features=2000, out_features=2000, bias=True)
  (3): ReLU()
  (4): Dropout(p=0.1, inplace=False)
  (5): Linear(in_features=2000, out_features=1000, bias=True)
  (6): ReLU()
  (7): Dropout(p=0.5, inplace=False)
  (8): Linear(in_features=1000, out_features=17, bias=True)
)

In [9]:
xmean = pt.load("xmean.pt").float()
xstd = pt.load("xstd.pt").float()

Code to run the model

In [10]:
def run_model(df, model):
    """Implements the final ML-Enhanced Tracking algorithm."""
    df = df.sort_values('time').reset_index(drop=True)
    df['track_id'] = -1

    next_track_id = 0
    
    n = df.shape[0]
    lastPtInTrack = {
        "time": np.repeat(pd.Timestamp(year=1970, month=1, day=1, hour=0, minute=0, second=0).to_numpy(), n), 
        "lat": np.repeat(-1.0, n), 
        "lon": np.repeat(-1.0, n), 
        "speed": np.repeat(-1.0, n), 
        "course": np.repeat(-1.0, n),
        "track_id_true": np.repeat(-1, n)
    }

    for i in tqdm.tqdm(range(len(df))):
        p_current = df.iloc[i]

        if next_track_id == 0:
            df.loc[i, 'track_id'] = next_track_id
            setidx(lastPtInTrack, i, p_current)
            next_track_id += 1
            continue


        active_tracks_df = subset(lastPtInTrack, next_track_id)
        
        time_diff = (p_current["time"].to_numpy() - active_tracks_df["time"]).astype("timedelta64[s]").astype("int")

        max_dist_m = time_diff * 30 * 0.5144
        real_dist = haversine_distance_m(active_tracks_df["lat"], active_tracks_df["lon"], p_current["lat"], p_current["lon"])
        
        kinematic_errors = haversine_distance_m(p_current["lat"], p_current["lon"], *project_forward(active_tracks_df['lat'], active_tracks_df['lon'], active_tracks_df['speed'], active_tracks_df['course'], time_diff))
        error_cutoff = np.sort(kinematic_errors)[:n_norm_classes].max()
        kinematic_filter = kinematic_errors < error_cutoff
        
        loc_filter = real_dist < max_dist_m
        timeCorrect: np.ndarray = (0 < time_diff)
        big_filter = loc_filter & timeCorrect & kinematic_filter
        idxs = np.arange(len(timeCorrect))

        # Create data if we find data points within the filters
        if np.any(big_filter):
            all_data = paddata(calculate_link_features(boolfilter(active_tracks_df, big_filter), p_current, eval=True))
            maindata = all_data[normal_features][:-1].to_numpy()
            otherdata = all_data[currpt_features].iloc[0].to_numpy()
            
            # Predict on this data
            toappend = np.zeros((n_norm_classes - 1) * len(normal_features) + len(currpt_features))
            raveled = np.ravel(maindata)
            toappend[:raveled.shape[0]] = raveled
            toappend[raveled.shape[0]:] = otherdata

            tensorIn = ((pt.from_numpy(toappend).float() - xmean) / xstd).to(device)

            modelOut = model(tensorIn).detach().cpu().numpy()
            outMask = np.zeros_like(modelOut)
            outMask[:-1] = np.all(maindata == -1, axis=1) * -1e8
            outMask[-1] = 0
            modelOut = modelOut + outMask
            
            argMaxModelOut = np.argmax(modelOut)

            # Assignment with a confidence threshold
            if argMaxModelOut != len(modelOut) - 1:
                best_match_track_id = None
                try:
                    best_match_track_id = idxs[big_filter][argMaxModelOut]
                except IndexError:
                    print("Bad")
                    best_match_track_id = next_track_id
                    next_track_id += 1

                df.loc[i, 'track_id'] = best_match_track_id
                setidx(lastPtInTrack, best_match_track_id, p_current)
            else:
                df.loc[i, 'track_id'] = next_track_id
                setidx(lastPtInTrack, next_track_id, p_current)
                next_track_id += 1
        else:
            df.loc[i, 'track_id'] = next_track_id
            setidx(lastPtInTrack, next_track_id, p_current)
            next_track_id += 1

    return df[['point_id', 'track_id']]


In [11]:
truth_df = pd.read_csv(data_loc("dataset1_truth.csv"))

truth_df["time"] = pd.to_datetime(truth_df["time"])
truth_df["time"] = truth_df["time"].apply(lambda x: datetime.combine(datetime(1970, 1, 1, 0, 0, 0).date(), x.time()))
truth_df["track_id_true"] = truth_df["track_id"]

output = run_model(truth_df, model)
output.to_csv("ml_out_ds1.csv", index=False)

100%|██████████| 102861/102861 [07:55<00:00, 216.40it/s]


In [14]:
atd2025.accuracy.evaluate_predictions("ml_out_ds1.csv", data_loc("dataset1_truth.csv"))

0.5275323008720506