In [1]:
import torch
import sys
import folium
import matplotlib.pyplot as plt
sys.path.append('../')

from models.model_utils import AttrDict, load_rid_freqs, get_rid_rnfea_dict, load_rn_dict, toseq
from models.models_attn_tandem import Encoder, DecoderMulti, Seq2SeqMulti
from models.multi_train import evaluate, init_weights, train
from common.mbr import MBR
from common.road_network import load_rn_shp
from models.datasets import Dataset, collate_fn

In [2]:
def plot_attention(attention, input_seq, output_seq):
    """
    Plots a heatmap of attention weights.
    
    :param attention: A 2D numpy array of attention weights.
    :param input_seq: List of input tokens.
    :param output_seq: List of output tokens.
    """
    fig, ax = plt.subplots()
    cax = ax.matshow(attention, cmap='bone')
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([''] + input_seq, rotation=90)
    ax.set_yticklabels([''] + output_seq)

    # Show label at every tick
    ax.xaxis.set_major_locator(plt.MultipleLocator(1))
    ax.yaxis.set_major_locator(plt.MultipleLocator(1))

    plt.show()

In [3]:
args = AttrDict()
device = torch.device('cuda')
args_dict = {'module_type': 'debug', 'debug': False, 'device': device, 'load_pretrained_flag': False, 'model_old_path': '', 'train_flag': True, 'test_flag': False, 'attn_flag': True, 'dis_prob_mask_flag': False, 'search_dist': 50, 'beta': 15, 'tandem_fea_flag': False, 'pro_features_flag': False, 'online_features_flag': False, 'rid_fea_dim': 8, 'pro_input_dim': 25, 'pro_output_dim': 8, 'poi_num': 5, 'online_dim': 10, 'poi_type': 'company,food,shopping,viewpoint,house', 'min_lat': 41.15, 'min_lng': -8.633, 'max_lat': 41.153, 'max_lng': -8.63, 'keep_ratio': 1, 'grid_size': 50, 'time_span': 15, 'win_size': 25, 'ds_type': 'random', 'split_flag': True, 'shuffle': True, 'hid_dim': 512, 'id_emb_dim': 128, 'dropout': 0.5, 'id_size': 263, 'lambda1': 10, 'n_epochs': 20, 'batch_size': 128, 'learning_rate': 0.001, 'tf_ratio': 0.5, 'clip': 1, 'log_step': 1, 'nhead': 8, 'nlayers': 2, 'max_xid': 7, 'max_yid': 6}
args.update(args_dict)

In [4]:
enc = Encoder(args)
dec = DecoderMulti(args)
model = Seq2SeqMulti(enc, dec, args).to(device)
model.apply(init_weights)

Seq2SeqMulti(
  (encoder): Encoder(
    (rnn): GRU(3, 512)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): DecoderMulti(
    (emb_id): Embedding(263, 128)
    (tandem_fc): Sequential(
      (0): Linear(in_features=640, out_features=512, bias=True)
      (1): ReLU()
    )
    (attn): Attention(
      (attn): Linear(in_features=1024, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=1, bias=False)
    )
    (rnn): GRU(641, 512)
    (fc_id_out): Linear(in_features=512, out_features=263, bias=True)
    (fc_rate_out): Linear(in_features=512, out_features=1, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [5]:
model_path = '../results/debug1_gs_50_lam_10_attn_True_prob_False_fea_False_20240518_102344/'
model.load_state_dict(torch.load(model_path + 'val-best-model.pt'))

<All keys matched successfully>

In [6]:
test_trajs_dir = "../data/model_data/test_data/"
rn_dir = "../data/map/road_network/"
extra_info_dir = "../data/map/extra_info/"

new2raw_rid_dict = load_rid_freqs(extra_info_dir, file_name='new2raw_rid.json')
rn_dict = load_rn_dict(extra_info_dir, file_name='rn_dict.json')
rid_features_dict = get_rid_rnfea_dict(rn_dict, args)
rn = load_rn_shp(rn_dir, is_directed=True)
mbr = MBR(args.min_lat, args.min_lng, args.max_lat, args.max_lng)
test_dataset = Dataset(
    test_trajs_dir,
    mbr,
    None,
    None,
    None,
    rn,
    new2raw_rid_dict,
    parameters=args,
    is_test=True,
    debug=False
)

# of nodes:172
# of edges:262


100%|███████████████████████████████████████████████████████████████████████████████| 500/500 [00:01<00:00, 361.11it/s]


In [7]:
print('test dataset shape: ' + str(len(test_dataset)))
test_iterator = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
    shuffle=False, collate_fn=collate_fn,
    num_workers=4, pin_memory=True)

test dataset shape: 5857


In [8]:
for i, batch in enumerate(test_iterator):
#     if i<5:
#         continue
    src_grid_seqs, src_gps_seqs, src_pro_feas, src_lengths, trg_gps_seqs, trg_rids, trg_rates, trg_lengths = batch
    src_pro_feas = src_pro_feas.float().to(device)
    max_trg_len = max(trg_lengths)
    batch_size = src_grid_seqs.size(0)
    constraint_mat = torch.zeros(max_trg_len, batch_size, args.id_size).to(device)
    pre_grids = torch.zeros(max_trg_len, batch_size, 3).to(device)
    next_grids = torch.zeros(max_trg_len, batch_size, 3).to(device)
    src_grid_seqs = src_grid_seqs.permute(1, 0, 2).to(device)
    trg_gps_seqs = trg_gps_seqs.permute(1, 0, 2).to(device)
    trg_rids = trg_rids.permute(1, 0, 2).long().to(device)
    trg_rates = trg_rates.permute(1, 0, 2).to(device)
    break

In [9]:
output_ids, output_rates, attention_weights = model(src_grid_seqs, src_lengths, trg_rids, trg_rates, trg_lengths,
     pre_grids, next_grids, constraint_mat,
     src_pro_feas, None, rid_features_dict,
     teacher_forcing_ratio=0)

output_rates = output_rates.squeeze(2)
output_seqs = toseq(rn_dict, output_ids, output_rates, args)
output_seqs = output_seqs.permute(1,0,2)
trg_gps_seqs = trg_gps_seqs.permute(1,0,2)

In [10]:
gt = trg_gps_seqs[:,1:]
pred = output_seqs[:,1:]
ls = src_gps_seqs[:,1:]

In [44]:
traj_index = 8
m = folium.Map(location=gt[traj_index][0], zoom_start=12)

# folium.Marker(gt[traj_index].tolist()[0]).add_to(m)
# folium.Marker(gt[traj_index].tolist()[5]).add_to(m)
for point in gt[traj_index].tolist():
    folium.Marker(point).add_to(m)
    
m

In [45]:
gt[traj_index].tolist()

[[41.15163803100586, -8.632606506347656],
 [41.151668548583984, -8.632688522338867],
 [41.151676177978516, -8.632711410522461],
 [41.15193176269531, -8.632638931274414],
 [41.151954650878906, -8.6326322555542],
 [41.15199279785156, -8.632621765136719],
 [41.15228271484375, -8.632538795471191],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0]]

In [46]:
m = folium.Map(location=pred[traj_index][0], zoom_start=12)
for point in pred[traj_index].tolist():
    folium.Marker(point).add_to(m)
    
m

In [47]:
m = folium.Map(location=ls[traj_index][0], zoom_start=12)
for point in ls[traj_index].tolist():
    folium.Marker(point).add_to(m)
    
m

In [48]:
ls[traj_index]

tensor([[41.1518, -8.6327],
        [41.1519, -8.6326],
        [41.1519, -8.6326],
        [41.1519, -8.6326],
        [41.1520, -8.6327],
        [41.1520, -8.6326],
        [41.1523, -8.6326],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000]])

In [43]:
gt[traj_index]

tensor([[41.1526, -8.6302],
        [41.1526, -8.6307],
        [41.1526, -8.6307],
        [41.1527, -8.6316],
        [41.1527, -8.6317],
        [41.1527, -8.6317],
        [41.1527, -8.6328],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000]], device='cuda:0')